imt-ring 1.6.47__tar.gz → 1.7.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.6.47 → imt_ring-1.7.1}/PKG-INFO +2 -2
- {imt_ring-1.6.47 → imt_ring-1.7.1}/pyproject.toml +1 -1
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/imt_ring.egg-info/PKG-INFO +2 -2
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/imt_ring.egg-info/SOURCES.txt +12 -9
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/__init__.py +6 -6
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algorithms/_random.py +68 -35
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algorithms/custom_joints/rr_imp_joint.py +2 -3
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algorithms/generator/base.py +30 -1
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algorithms/jcalc.py +35 -18
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/base.py +2 -1
- imt_ring-1.7.1/src/ring/extras/__init__.py +0 -0
- imt_ring-1.7.1/src/ring/extras/interactive_viewer.py +114 -0
- imt_ring-1.7.1/src/ring/extras/torch_loss_fn.py +93 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/rendering/base_render.py +30 -3
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/rendering/mujoco_render.py +38 -7
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/utils/__init__.py +0 -4
- {imt_ring-1.6.47 → imt_ring-1.7.1}/readme.md +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/setup.cfg +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/imt_ring.egg-info/dependency_links.txt +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/imt_ring.egg-info/entry_points.txt +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/imt_ring.egg-info/requires.txt +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/imt_ring.egg-info/top_level.txt +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algebra.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algorithms/__init__.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algorithms/custom_joints/__init__.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algorithms/custom_joints/rsaddle_joint.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algorithms/custom_joints/suntay.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algorithms/dynamics.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algorithms/generator/__init__.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algorithms/generator/batch.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algorithms/generator/finalize_fns.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algorithms/generator/motion_artifacts.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algorithms/generator/pd_control.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algorithms/generator/setup_fns.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algorithms/generator/types.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algorithms/kinematics.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/algorithms/sensors.py +0 -0
- {imt_ring-1.6.47/src/ring/utils → imt_ring-1.7.1/src/ring/extras}/backend.py +0 -0
- {imt_ring-1.6.47/src/ring/utils → imt_ring-1.7.1/src/ring/extras}/colab.py +0 -0
- {imt_ring-1.6.47/src/ring/utils → imt_ring-1.7.1/src/ring/extras}/dataloader.py +0 -0
- {imt_ring-1.6.47/src/ring/utils → imt_ring-1.7.1/src/ring/extras}/dataloader_torch.py +0 -0
- {imt_ring-1.6.47/src/ring/utils → imt_ring-1.7.1/src/ring/extras}/hdf5.py +0 -0
- {imt_ring-1.6.47/src/ring/utils → imt_ring-1.7.1/src/ring/extras}/normalizer.py +0 -0
- {imt_ring-1.6.47/src/ring/utils → imt_ring-1.7.1/src/ring/extras}/randomize_sys.py +0 -0
- {imt_ring-1.6.47/src/ring/utils → imt_ring-1.7.1/src/ring/extras}/register_gym_envs/__init__.py +0 -0
- {imt_ring-1.6.47/src/ring/utils → imt_ring-1.7.1/src/ring/extras}/register_gym_envs/saddle.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/__init__.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/branched.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/inv_pendulum.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/spherical_stiff.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/symmetric.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/test_all_1.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/test_all_2.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/test_control.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/test_double_pendulum.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/test_free.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/test_kinematics.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/test_randomize_position.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/test_sensors.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/examples.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/test_examples.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/xml/__init__.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/xml/abstract.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/xml/from_xml.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/xml/test_from_xml.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/xml/test_to_xml.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/io/xml/to_xml.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/maths.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/ml/__init__.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/ml/base.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/ml/callbacks.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/ml/ml_utils.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/ml/optimizer.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/ml/ringnet.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/ml/rnno_v1.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/ml/train.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/ml/training_loop.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/rendering/__init__.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/rendering/vispy_render.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/rendering/vispy_visuals.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/sim2real/__init__.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/sim2real/sim2real.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/spatial.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/sys_composer/__init__.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/sys_composer/delete_sys.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/sys_composer/inject_sys.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/sys_composer/morph_sys.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/utils/batchsize.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/utils/path.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/src/ring/utils/utils.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_algebra.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_base.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_custom_joints.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_dynamics.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_generator.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_jcalc.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_jit.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_kinematics.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_maths.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_ml_utils.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_motion_artifacts.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_pd_control.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_quickstart_example.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_random.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_randomize.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_rcmg.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_render.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_sensors.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_sim2real.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_sys_composer.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_train.py +0 -0
- {imt_ring-1.6.47 → imt_ring-1.7.1}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: imt-ring
|
3
|
-
Version: 1.
|
3
|
+
Version: 1.7.1
|
4
4
|
Summary: RING: Recurrent Inertial Graph-based Estimator
|
5
5
|
Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
|
6
6
|
Project-URL: Homepage, https://github.com/SimiPixel/ring
|
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: imt-ring
|
3
|
-
Version: 1.
|
3
|
+
Version: 1.7.1
|
4
4
|
Summary: RING: Recurrent Inertial Graph-based Estimator
|
5
5
|
Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
|
6
6
|
Project-URL: Homepage, https://github.com/SimiPixel/ring
|
@@ -31,6 +31,18 @@ src/ring/algorithms/generator/motion_artifacts.py
|
|
31
31
|
src/ring/algorithms/generator/pd_control.py
|
32
32
|
src/ring/algorithms/generator/setup_fns.py
|
33
33
|
src/ring/algorithms/generator/types.py
|
34
|
+
src/ring/extras/__init__.py
|
35
|
+
src/ring/extras/backend.py
|
36
|
+
src/ring/extras/colab.py
|
37
|
+
src/ring/extras/dataloader.py
|
38
|
+
src/ring/extras/dataloader_torch.py
|
39
|
+
src/ring/extras/hdf5.py
|
40
|
+
src/ring/extras/interactive_viewer.py
|
41
|
+
src/ring/extras/normalizer.py
|
42
|
+
src/ring/extras/randomize_sys.py
|
43
|
+
src/ring/extras/torch_loss_fn.py
|
44
|
+
src/ring/extras/register_gym_envs/__init__.py
|
45
|
+
src/ring/extras/register_gym_envs/saddle.py
|
34
46
|
src/ring/io/__init__.py
|
35
47
|
src/ring/io/examples.py
|
36
48
|
src/ring/io/test_examples.py
|
@@ -83,18 +95,9 @@ src/ring/sys_composer/delete_sys.py
|
|
83
95
|
src/ring/sys_composer/inject_sys.py
|
84
96
|
src/ring/sys_composer/morph_sys.py
|
85
97
|
src/ring/utils/__init__.py
|
86
|
-
src/ring/utils/backend.py
|
87
98
|
src/ring/utils/batchsize.py
|
88
|
-
src/ring/utils/colab.py
|
89
|
-
src/ring/utils/dataloader.py
|
90
|
-
src/ring/utils/dataloader_torch.py
|
91
|
-
src/ring/utils/hdf5.py
|
92
|
-
src/ring/utils/normalizer.py
|
93
99
|
src/ring/utils/path.py
|
94
|
-
src/ring/utils/randomize_sys.py
|
95
100
|
src/ring/utils/utils.py
|
96
|
-
src/ring/utils/register_gym_envs/__init__.py
|
97
|
-
src/ring/utils/register_gym_envs/saddle.py
|
98
101
|
tests/test_algebra.py
|
99
102
|
tests/test_base.py
|
100
103
|
tests/test_custom_joints.py
|
@@ -35,12 +35,12 @@ def RING(lam: list[int] | None, Ts: float | None, **kwargs) -> ml.AbstractFilter
|
|
35
35
|
>>> import ring
|
36
36
|
>>> import numpy as np
|
37
37
|
>>>
|
38
|
-
>>> T : int = 30
|
39
|
-
>>> Ts : float = 0.01
|
40
|
-
>>> B : int = 1
|
41
|
-
>>> lam: list[int] = [0, 1
|
42
|
-
>>> N : int = len(lam)
|
43
|
-
>>> T_i: int = int(T/Ts)
|
38
|
+
>>> T : int = 30 # sequence length [s]
|
39
|
+
>>> Ts : float = 0.01 # sampling interval [s]
|
40
|
+
>>> B : int = 1 # batch size
|
41
|
+
>>> lam: list[int] = [-1, 0, 1] # parent array
|
42
|
+
>>> N : int = len(lam) # number of bodies
|
43
|
+
>>> T_i: int = int(T/Ts) # number of timesteps
|
44
44
|
>>>
|
45
45
|
>>> X = np.zeros((B, T_i, N, 9))
|
46
46
|
>>> # where X is structured as follows:
|
@@ -41,30 +41,48 @@ def random_angle_over_time(
|
|
41
41
|
cdf_bins_min: int = 5,
|
42
42
|
cdf_bins_max: Optional[int] = None,
|
43
43
|
interpolation_method: str = "cosine",
|
44
|
+
include_standstills_prob: float = 0.0, # 0.0 means no standstills
|
45
|
+
include_standstills_t_min: float = 0.5,
|
46
|
+
include_standstills_t_max: float = 5.0,
|
44
47
|
) -> jax.Array:
|
45
48
|
def body_fn_outer(val):
|
46
49
|
i, t, phi, key_t, key_ang, ANG = val
|
47
50
|
|
48
|
-
key_t, consume_t = random.split(key_t)
|
51
|
+
key_t, consume_t, consume_standstill = random.split(key_t, 3)
|
49
52
|
key_ang, consume_ang = random.split(key_ang)
|
50
53
|
rom_halfsize_float = _to_float(rom_halfsize, t)
|
51
54
|
rom_lower = ANG_0 - rom_halfsize_float
|
52
55
|
rom_upper = ANG_0 + rom_halfsize_float
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
56
|
+
|
57
|
+
is_standstill = jax.random.bernoulli(
|
58
|
+
consume_standstill, include_standstills_prob
|
59
|
+
)
|
60
|
+
dt, phi = jax.lax.cond(
|
61
|
+
is_standstill,
|
62
|
+
lambda: (
|
63
|
+
jax.random.uniform(
|
64
|
+
consume_t,
|
65
|
+
minval=include_standstills_t_min,
|
66
|
+
maxval=include_standstills_t_max,
|
67
|
+
),
|
68
|
+
phi,
|
69
|
+
),
|
70
|
+
lambda: _resolve_range_of_motion(
|
71
|
+
range_of_motion,
|
72
|
+
range_of_motion_method,
|
73
|
+
rom_lower,
|
74
|
+
rom_upper,
|
75
|
+
_to_float(dang_min, t),
|
76
|
+
_to_float(dang_max, t),
|
77
|
+
_to_float(delta_ang_min, t),
|
78
|
+
_to_float(delta_ang_max, t),
|
79
|
+
t_min,
|
80
|
+
_to_float(t_max, t),
|
81
|
+
phi,
|
82
|
+
consume_t,
|
83
|
+
consume_ang,
|
84
|
+
max_iter,
|
85
|
+
),
|
68
86
|
)
|
69
87
|
t += dt
|
70
88
|
|
@@ -119,7 +137,8 @@ def random_angle_over_time(
|
|
119
137
|
|
120
138
|
# APPROVED
|
121
139
|
def random_position_over_time(
|
122
|
-
|
140
|
+
key_t: random.PRNGKey,
|
141
|
+
key_value: random.PRNGKey,
|
123
142
|
POS_0: float,
|
124
143
|
pos_min: float | TimeDependentFloat,
|
125
144
|
pos_max: float | TimeDependentFloat,
|
@@ -135,19 +154,14 @@ def random_position_over_time(
|
|
135
154
|
cdf_bins_min: int = 5,
|
136
155
|
cdf_bins_max: Optional[int] = None,
|
137
156
|
interpolation_method: str = "cosine",
|
157
|
+
include_standstills_prob: float = 0.0, # 0.0 means no standstills
|
158
|
+
include_standstills_t_min: float = 0.5,
|
159
|
+
include_standstills_t_max: float = 5.0,
|
138
160
|
) -> jax.Array:
|
139
161
|
def body_fn_inner(val):
|
140
162
|
i, t, t_pre, x, x_pre, key = val
|
141
163
|
dt = t - t_pre
|
142
164
|
|
143
|
-
def sample_dx_squared(key):
|
144
|
-
key, consume = random.split(key)
|
145
|
-
dx = (
|
146
|
-
random.uniform(consume) * (2 * dpos_max * t_max**2)
|
147
|
-
- dpos_max * t_max**2
|
148
|
-
)
|
149
|
-
return key, dx
|
150
|
-
|
151
165
|
def sample_dx(key):
|
152
166
|
key, consume1, consume2 = random.split(key, 3)
|
153
167
|
sign = random.choice(consume1, jnp.array([-1.0, 1.0]))
|
@@ -182,24 +196,43 @@ def random_position_over_time(
|
|
182
196
|
return jnp.logical_not(break_if_true1 | break_if_true2)
|
183
197
|
|
184
198
|
def body_fn_outer(val):
|
185
|
-
i, t, t_pre, x, x_pre,
|
186
|
-
|
187
|
-
t += random.uniform(consume, minval=t_min, maxval=_to_float(t_max, t_pre))
|
199
|
+
i, t, t_pre, x, x_pre, key_t, key_value, POS = val
|
200
|
+
key_t, consume_t, consume_standstill = random.split(key_t, 3)
|
188
201
|
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
202
|
+
is_standstill = jax.random.bernoulli(
|
203
|
+
consume_standstill, include_standstills_prob
|
204
|
+
)
|
205
|
+
|
206
|
+
def is_standstill_branch():
|
207
|
+
dt = random.uniform(
|
208
|
+
consume_t,
|
209
|
+
minval=include_standstills_t_min,
|
210
|
+
maxval=include_standstills_t_max,
|
211
|
+
)
|
212
|
+
t = t_pre + dt
|
213
|
+
return 0, t, t_pre, x, x_pre, key_value
|
214
|
+
|
215
|
+
def no_standstill_branch():
|
216
|
+
dt = random.uniform(consume_t, minval=t_min, maxval=_to_float(t_max, t_pre))
|
217
|
+
t = t_pre + dt
|
218
|
+
# that zero resets the max_it count
|
219
|
+
val_inner = (0, t, t_pre, x, x_pre, key_value)
|
220
|
+
return jax.lax.while_loop(cond_fn_inner, body_fn_inner, val_inner)
|
221
|
+
|
222
|
+
_, t, t_pre, x, x_pre, key_value = jax.lax.cond(
|
223
|
+
is_standstill,
|
224
|
+
is_standstill_branch,
|
225
|
+
no_standstill_branch,
|
193
226
|
)
|
194
227
|
|
195
228
|
POS_i = jnp.array([[jnp.floor(t / Ts) * Ts, x]])
|
196
229
|
POS = jax.lax.dynamic_update_slice_in_dim(POS, POS_i, start_index=i, axis=0)
|
197
230
|
t_pre = t
|
198
231
|
x_pre = x
|
199
|
-
return i + 1, t, t_pre, x, x_pre,
|
232
|
+
return i + 1, t, t_pre, x, x_pre, key_t, key_value, POS
|
200
233
|
|
201
234
|
def cond_fn_outer(val):
|
202
|
-
i, t, t_pre, x, x_pre,
|
235
|
+
i, t, t_pre, x, x_pre, key_t, key_value, POS = val
|
203
236
|
return t <= T
|
204
237
|
|
205
238
|
# preallocate POS array
|
@@ -207,7 +240,7 @@ def random_position_over_time(
|
|
207
240
|
POS = jnp.zeros((int(T // t_min) + 1, 2))
|
208
241
|
POS = POS.at[0, 1].set(POS_0)
|
209
242
|
|
210
|
-
val_outer = (1, 0.0, 0.0, POS_0, POS_0,
|
243
|
+
val_outer = (1, 0.0, 0.0, POS_0, POS_0, key_t, key_value, POS)
|
211
244
|
end, *_, consume, POS = jax.lax.while_loop(cond_fn_outer, body_fn_outer, val_outer)
|
212
245
|
POS = jnp.where(
|
213
246
|
(jnp.arange(len(POS)) < end)[:, None],
|
@@ -23,11 +23,10 @@ def register_rr_imp_joint(
|
|
23
23
|
return ring.Transform.create(rot=rot)
|
24
24
|
|
25
25
|
def _draw_rr_imp(config, key_t, key_value, dt, N, _):
|
26
|
-
key_t1, key_t2 = jax.random.split(key_t)
|
27
26
|
key_value1, key_value2 = jax.random.split(key_value)
|
28
|
-
q_traj_pri = _draw_rxyz(config,
|
27
|
+
q_traj_pri = _draw_rxyz(config, key_t, key_value1, dt, N, _)
|
29
28
|
q_traj_res = _draw_rxyz(
|
30
|
-
replace(config_res, T=config.T),
|
29
|
+
replace(config_res, T=config.T), key_t, key_value2, dt, N, _
|
31
30
|
)
|
32
31
|
# scale to be within bounds
|
33
32
|
q_traj_res = q_traj_res * (jnp.deg2rad(ang_max_deg) / jnp.pi)
|
@@ -1,5 +1,6 @@
|
|
1
1
|
from dataclasses import replace
|
2
2
|
from functools import partial
|
3
|
+
import json
|
3
4
|
import logging
|
4
5
|
import random
|
5
6
|
from typing import Callable, Optional
|
@@ -136,6 +137,14 @@ class RCMG:
|
|
136
137
|
affecting joint motion behavior.
|
137
138
|
""" # noqa: E501
|
138
139
|
|
140
|
+
# capture all funtion arguments before creating local variables
|
141
|
+
to_json_kwargs = locals()
|
142
|
+
# the purpose is to not capture the RCMG itself since we want to make it
|
143
|
+
# serialisable in the first place
|
144
|
+
to_json_kwargs.pop("self")
|
145
|
+
to_json_kwargs.pop("sys")
|
146
|
+
to_json_kwargs.pop("config")
|
147
|
+
|
139
148
|
# add some default values
|
140
149
|
randomize_hz_kwargs_defaults = dict(add_dt=True)
|
141
150
|
randomize_hz_kwargs_defaults.update(randomize_hz_kwargs)
|
@@ -186,6 +195,11 @@ class RCMG:
|
|
186
195
|
|
187
196
|
self._disable_tqdm = disable_tqdm
|
188
197
|
|
198
|
+
# store arguments that fully define the RCMG objects for use in `.to_json`
|
199
|
+
self._to_json_sys = sys
|
200
|
+
self._to_json_mconfig = config
|
201
|
+
self._to_json_kwargs = to_json_kwargs
|
202
|
+
|
189
203
|
def _compute_repeats(self, sizes: int | list[int]) -> list[int]:
|
190
204
|
"how many times the generators are repeated to create a batch of `sizes`"
|
191
205
|
|
@@ -355,6 +369,21 @@ class RCMG:
|
|
355
369
|
|
356
370
|
return generator
|
357
371
|
|
372
|
+
def serialise_to_dict(self) -> dict:
    """Return a plain-dict description of the arguments this RCMG was built from.

    Returns:
        dict: A dictionary with the keys
            - ``"system"``: ``to_str(warn=False)`` of every stored system,
            - ``"motion_configs"``: one dict of attribute values per stored
              motion config,
            - ``"kwargs"``: the remaining captured constructor arguments.

    The returned dictionaries are (shallow) copies. Mutating them does not
    alter the stored motion configs or the stored keyword arguments.
    """
    # `vars(obj)` is the object's live `__dict__`; copy it so that callers
    # cannot silently modify the config objects through the returned dict
    motion_configs = [dict(vars(_config)) for _config in self._to_json_mconfig]
    return {
        "system": [_sys.to_str(warn=False) for _sys in self._to_json_sys],
        "motion_configs": motion_configs,
        "kwargs": dict(self._to_json_kwargs),
    }
|
379
|
+
|
380
|
+
def serialise_to_json(self, path_of_json: str) -> None:
    """Write ``self.serialise_to_dict()`` as indented JSON to ``path_of_json``.

    Args:
        path_of_json: Destination file path; an existing file is overwritten.

    Raises:
        TypeError: If a stored value is not JSON-serialisable. In that case
            the destination file is left untouched.
    """
    # Serialise completely before opening the file. `open(..., "w")` truncates
    # immediately, so dumping straight into it would destroy an existing file
    # and leave a half-written one behind whenever `json` raises mid-way.
    text = json.dumps(self.serialise_to_dict(), indent=4)
    with open(path_of_json, "w", encoding="utf-8") as file:
        file.write(text)
|
383
|
+
|
384
|
+
@classmethod
def from_json(cls, path_to_json: str) -> "RCMG":
    """Alternate constructor: rebuild an RCMG from a ``serialise_to_json`` file.

    Declared as a classmethod so it can be called on the class itself
    (``RCMG.from_json(path)``) as well as on an instance.

    Raises:
        NotImplementedError: Deserialisation is not implemented yet.
    """
    raise NotImplementedError
|
386
|
+
|
358
387
|
|
359
388
|
def _copy_dicts(f) -> dict:
|
360
389
|
def _f(*args, **kwargs):
|
@@ -526,7 +555,7 @@ def draw_random_q(
|
|
526
555
|
sys: base.System,
|
527
556
|
config: jcalc.MotionConfig,
|
528
557
|
N: int | None,
|
529
|
-
) -> tuple[
|
558
|
+
) -> tuple[jax.random.PRNGKey, jax.Array]:
|
530
559
|
|
531
560
|
key_start = key
|
532
561
|
# build generalized coordintes vector `q`
|
@@ -174,6 +174,16 @@ class MotionConfig:
|
|
174
174
|
default_factory=lambda: dict()
|
175
175
|
)
|
176
176
|
|
177
|
+
# fields related to simulating standstills (no motion time periods)
|
178
|
+
# these are "Joint Standstills" so the standstills are calculated on
|
179
|
+
# a joint level, for each joint independently
|
180
|
+
# This means that a `standstills_prob` of 20% means that each joint
|
181
|
+
# has at each dt \in [t_min, t_max] drawing process a probability of
|
182
|
+
# 20% that it will just stay at its current joint value
|
183
|
+
include_standstills_prob: float = 0.0 # in %; 0% means no standstills
|
184
|
+
include_standstills_t_min: float = 0.5
|
185
|
+
include_standstills_t_max: float = 5.0
|
186
|
+
|
177
187
|
def is_feasible(self) -> bool:
|
178
188
|
return _is_feasible_config1(self)
|
179
189
|
|
@@ -791,12 +801,15 @@ def _draw_rxyz(
|
|
791
801
|
config.cdf_bins_min,
|
792
802
|
config.cdf_bins_max,
|
793
803
|
config.interpolation_method,
|
804
|
+
config.include_standstills_prob,
|
805
|
+
config.include_standstills_t_min,
|
806
|
+
config.include_standstills_t_max,
|
794
807
|
)
|
795
808
|
|
796
809
|
|
797
810
|
def _draw_pxyz(
|
798
811
|
config: MotionConfig,
|
799
|
-
|
812
|
+
key_t: jax.random.PRNGKey,
|
800
813
|
key_value: jax.random.PRNGKey,
|
801
814
|
dt: float | jax.Array,
|
802
815
|
N: int | None,
|
@@ -811,6 +824,7 @@ def _draw_pxyz(
|
|
811
824
|
)
|
812
825
|
max_iter = 100
|
813
826
|
return _random.random_position_over_time(
|
827
|
+
key_t,
|
814
828
|
key_value,
|
815
829
|
POS_0,
|
816
830
|
config.cor_pos_min if cor else config.pos_min,
|
@@ -827,6 +841,9 @@ def _draw_pxyz(
|
|
827
841
|
config.cdf_bins_min,
|
828
842
|
config.cdf_bins_max,
|
829
843
|
config.interpolation_method,
|
844
|
+
config.include_standstills_prob,
|
845
|
+
config.include_standstills_t_min,
|
846
|
+
config.include_standstills_t_max,
|
830
847
|
)
|
831
848
|
|
832
849
|
|
@@ -840,7 +857,6 @@ def _draw_spherical(
|
|
840
857
|
) -> jax.Array:
|
841
858
|
# NOTE: We draw 3 euler angles and then build a quaternion.
|
842
859
|
# Not ideal, but i am unaware of a better way.
|
843
|
-
@jax.vmap
|
844
860
|
def draw_euler_angles(key_t, key_value):
|
845
861
|
return _draw_rxyz(
|
846
862
|
config,
|
@@ -853,8 +869,9 @@ def _draw_spherical(
|
|
853
869
|
free_spherical=True,
|
854
870
|
)
|
855
871
|
|
856
|
-
|
857
|
-
|
872
|
+
euler_angles = jax.vmap(draw_euler_angles, in_axes=(None, 0))(
|
873
|
+
key_t, jax.random.split(key_value, 3)
|
874
|
+
).T
|
858
875
|
q = maths.quat_euler(euler_angles)
|
859
876
|
return q
|
860
877
|
|
@@ -867,7 +884,6 @@ def _draw_saddle(
|
|
867
884
|
N: int | None,
|
868
885
|
_: jax.Array,
|
869
886
|
) -> jax.Array:
|
870
|
-
@jax.vmap
|
871
887
|
def draw_euler_angles(key_t, key_value):
|
872
888
|
return _draw_rxyz(
|
873
889
|
config,
|
@@ -880,14 +896,15 @@ def _draw_saddle(
|
|
880
896
|
free_spherical=False,
|
881
897
|
)
|
882
898
|
|
883
|
-
|
884
|
-
|
899
|
+
yz_euler_angles = jax.vmap(draw_euler_angles, in_axes=(None, 0))(
|
900
|
+
key_t, jax.random.split(key_value)
|
901
|
+
).T
|
885
902
|
return yz_euler_angles
|
886
903
|
|
887
904
|
|
888
905
|
def _draw_p3d_and_cor(
|
889
906
|
config: MotionConfig,
|
890
|
-
|
907
|
+
key_t: jax.random.PRNGKey,
|
891
908
|
key_value: jax.random.PRNGKey,
|
892
909
|
dt: float | jax.Array,
|
893
910
|
N: int | None,
|
@@ -896,7 +913,7 @@ def _draw_p3d_and_cor(
|
|
896
913
|
) -> jax.Array:
|
897
914
|
keys = jax.random.split(key_value, 3)
|
898
915
|
|
899
|
-
def draw(
|
916
|
+
def draw(key_value, xyz: str):
|
900
917
|
return _draw_pxyz(
|
901
918
|
replace(
|
902
919
|
config,
|
@@ -905,8 +922,8 @@ def _draw_p3d_and_cor(
|
|
905
922
|
pos0_min=getattr(config, f"pos0_min_p3d_{xyz}"),
|
906
923
|
pos0_max=getattr(config, f"pos0_max_p3d_{xyz}"),
|
907
924
|
),
|
908
|
-
|
909
|
-
|
925
|
+
key_t,
|
926
|
+
key_value,
|
910
927
|
dt,
|
911
928
|
N,
|
912
929
|
None,
|
@@ -919,26 +936,26 @@ def _draw_p3d_and_cor(
|
|
919
936
|
|
920
937
|
def _draw_p3d(
|
921
938
|
config: MotionConfig,
|
922
|
-
|
939
|
+
key_t: jax.random.PRNGKey,
|
923
940
|
key_value: jax.random.PRNGKey,
|
924
941
|
dt: float | jax.Array,
|
925
942
|
N: int | None,
|
926
943
|
__: jax.Array,
|
927
944
|
) -> jax.Array:
|
928
|
-
return _draw_p3d_and_cor(config,
|
945
|
+
return _draw_p3d_and_cor(config, key_t, key_value, dt, N, None, cor=False)
|
929
946
|
|
930
947
|
|
931
948
|
def _draw_cor(
|
932
949
|
config: MotionConfig,
|
933
|
-
|
950
|
+
key_t: jax.random.PRNGKey,
|
934
951
|
key_value: jax.random.PRNGKey,
|
935
952
|
dt: float | jax.Array,
|
936
953
|
N: int | None,
|
937
954
|
__: jax.Array,
|
938
955
|
) -> jax.Array:
|
939
956
|
key_value1, key_value2 = jax.random.split(key_value)
|
940
|
-
q_free = _draw_free(config,
|
941
|
-
q_p3d = _draw_p3d_and_cor(config,
|
957
|
+
q_free = _draw_free(config, key_t, key_value1, dt, N, None)
|
958
|
+
q_p3d = _draw_p3d_and_cor(config, key_t, key_value2, dt, N, None, cor=True)
|
942
959
|
return jnp.concatenate((q_free, q_p3d), axis=1)
|
943
960
|
|
944
961
|
|
@@ -952,7 +969,7 @@ def _draw_free(
|
|
952
969
|
) -> jax.Array:
|
953
970
|
key_value1, key_value2 = jax.random.split(key_value)
|
954
971
|
q = _draw_spherical(config, key_t, key_value1, dt, N, None)
|
955
|
-
pos = _draw_p3d(config,
|
972
|
+
pos = _draw_p3d(config, key_t, key_value2, dt, N, None)
|
956
973
|
return jnp.concatenate((q, pos), axis=1)
|
957
974
|
|
958
975
|
|
@@ -975,7 +992,7 @@ def _draw_free_2d(
|
|
975
992
|
enable_range_of_motion=False,
|
976
993
|
free_spherical=True,
|
977
994
|
)[:, None]
|
978
|
-
pos_yz = _draw_p3d(config,
|
995
|
+
pos_yz = _draw_p3d(config, key_t, key_value2, dt, N, None)[:, :2]
|
979
996
|
return jnp.concatenate((angle_x, pos_yz), axis=1)
|
980
997
|
|
981
998
|
|
@@ -981,6 +981,7 @@ class System(_Base):
|
|
981
981
|
|
982
982
|
def render(
|
983
983
|
self,
|
984
|
+
qs: Optional[jax.Array | list[jax.Array]] = None,
|
984
985
|
xs: Optional[Transform | list[Transform]] = None,
|
985
986
|
camera: Optional[str] = None,
|
986
987
|
show_pbar: bool = True,
|
@@ -1001,7 +1002,7 @@ class System(_Base):
|
|
1001
1002
|
list[np.ndarray]: Stacked rendered frames. Length == len(xs).
|
1002
1003
|
"""
|
1003
1004
|
return ring.rendering.render(
|
1004
|
-
self, xs, camera, show_pbar, backend, render_every_nth, **scene_kwargs
|
1005
|
+
self, qs, xs, camera, show_pbar, backend, render_every_nth, **scene_kwargs
|
1005
1006
|
)
|
1006
1007
|
|
1007
1008
|
def render_prediction(
|
File without changes
|
@@ -0,0 +1,114 @@
|
|
1
|
+
import multiprocessing
|
2
|
+
import time
|
3
|
+
from typing import Optional
|
4
|
+
|
5
|
+
import fire
|
6
|
+
import jax.numpy as jnp
|
7
|
+
import numpy as np
|
8
|
+
|
9
|
+
import ring
|
10
|
+
from ring import System
|
11
|
+
|
12
|
+
|
13
|
+
class InteractiveViewer:
    """Render a `ring.System` interactively in a separate worker process.

    The parent process pushes new minimal coordinates with `update_q` and can
    hide geometries with `make_geometry_transparent`. Both are communicated to
    the worker through dictionaries hosted by a single `multiprocessing.Manager`.

    Use as a context manager (or call `close`) so that both the worker process
    and the manager's server process are shut down.
    """

    def __init__(self, sys: ring.System, **scene_kwargs):
        # One manager (= one extra server process) hosts both shared dicts.
        # Keep a reference so that `close` can shut it down explicitly.
        self._manager = multiprocessing.Manager()
        self._mp_dict = self._manager.dict()
        self._geom_dict = self._manager.dict()
        # publish the default pose so the worker has something to render
        self.update_q(np.array(ring.State.create(sys).q))
        self.process = multiprocessing.Process(
            # `_worker` is a staticmethod: starting the process does not need to
            # pickle `self` (which holds the manager and the process itself)
            target=self._worker,
            args=(self._mp_dict, self._geom_dict, sys.to_str(), scene_kwargs),
        )
        try:
            self.process.start()
        except BaseException:
            # do not leak the manager's server process if the worker fails to start
            self._manager.shutdown()
            raise

    def update_q(self, q: np.ndarray):
        """Set the minimal coordinates `q` that the worker renders next."""
        self._mp_dict["q"] = q

    def make_geometry_transparent(self, body_number: int, geom_number: int):
        """Queue a request to hide geom `body{body_number}_geom{geom_number}`.

        The request is applied by the worker process after its next render call.
        """
        geom_name = f"body{body_number}_geom{geom_number}"
        # the value is not used; the worker only reads the keys
        self._geom_dict[geom_name] = None

    @staticmethod
    def _worker(mp_dict, geom_dict, sys_str, scene_kwargs):
        """Worker loop: re-render the latest `q` until the interactive renderer stops."""
        from ring.rendering import base_render

        sys = System.from_str(sys_str)
        # `_scene` is None before the first render; afterwards keep looping for
        # as long as the interactive renderer reports being alive
        while base_render._scene is None or base_render._scene._renderer.is_alive:
            sys.render(jnp.array(mp_dict["q"]), interactive=True, **scene_kwargs)

            if len(geom_dict) > 0:
                model = base_render._scene._model
                for geom_name in list(geom_dict.keys()):
                    try:
                        # Get the geometry ID
                        geom_id = model.geom(geom_name).id
                    except KeyError:
                        # an unknown name must not kill the viewer process
                        print(f"No geom with name={geom_name} found (worker)")
                    else:
                        # Set transparency to 0 (fully transparent)
                        model.geom_rgba[geom_id, 3] = 0
                        print(f"Made geom with name={geom_name} transparent (worker)")
                    # drop the request in both cases so it is not retried forever
                    geom_dict.pop(geom_name, None)

    def __enter__(self):
        return self

    def close(self):
        """Stop the worker process and shut down the shared-state manager."""
        self.process.terminate()
        self.process.join()
        self._manager.shutdown()

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
|
62
|
+
|
63
|
+
|
64
|
+
def _fire_main(path_sys_xml: str, path_qs_np: Optional[str] = None, **scene_kwargs):
    """View motion given by trajectory of minimal coordinates in interactive viewer.

    Args:
        path_sys_xml (str): Path to xml file defining the system.
        path_qs_np (str | None, optional): Path to numpy array containing the
            timeseries of minimal coordinates with shape (T, DOF) where DOF is
            equal to `sys.q_size()`. Each minimal coordinate is from parent to
            child. So for example a `spherical` joint that connects the first
            body to the worldbody has a minimal coordinate of a quaternion from
            the worldbody to the first body. The sampling rate of the motion is
            inferred from the `sys.dt` attribute. If `None` (default), then
            simply renders the unarticulated pose of the system.
        **scene_kwargs: Forwarded to the viewer. `width` and `height` default
            to 640 and 480 and may be overridden here.

    Raises:
        ValueError: If the trajectory is not 2D, has no timesteps, or its
            second dimension does not equal `sys.q_size()`.
    """  # noqa: E501

    system = ring.System.from_xml(path_sys_xml)
    if path_qs_np is None:
        qs = np.array(ring.State.create(system).q)[None]
    else:
        qs = np.load(path_qs_np)

    # explicit checks (not `assert`) so validation also runs under `python -O`
    if qs.ndim != 2:
        raise ValueError(f"qs.shape = {qs.shape}")
    T, Q = qs.shape
    if T == 0:
        # `t % T` below would raise ZeroDivisionError
        raise ValueError(f"qs.shape = {qs.shape} contains no timesteps")
    if Q != system.q_size():
        raise ValueError(f"Q={Q} != sys.q_size={system.q_size()}")
    dt_target = system.dt

    # user-supplied `width`/`height` override these defaults instead of
    # causing "got multiple values for keyword argument"
    scene_kwargs = {"width": 640, "height": 480, **scene_kwargs}

    with InteractiveViewer(system, **scene_kwargs) as viewer:
        dt = dt_target
        # monotonic clock: pacing is unaffected by wall-clock adjustments
        last_t = time.monotonic()
        t = -1

        while True:
            t = (t + 1) % T

            # wait until `dt_target` seconds have passed since the last update
            while dt < dt_target:
                time.sleep(0.001)
                dt = time.monotonic() - last_t

            last_t = time.monotonic()
            viewer.update_q(qs[t])
            dt = time.monotonic() - last_t

            # process will be stopped if the window is closed
            if not viewer.process.is_alive():
                break
|
107
|
+
|
108
|
+
|
109
|
+
def main():
    """Console entry point: dispatch command-line arguments to `_fire_main`."""
    cli_command = _fire_main
    fire.Fire(cli_command)


if __name__ == "__main__":
    # allow running this module directly as a script
    main()
|