imt-ring 1.6.46__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.6.46.dist-info → imt_ring-1.7.0.dist-info}/METADATA +2 -2
- {imt_ring-1.6.46.dist-info → imt_ring-1.7.0.dist-info}/RECORD +23 -20
- {imt_ring-1.6.46.dist-info → imt_ring-1.7.0.dist-info}/WHEEL +1 -1
- ring/__init__.py +6 -6
- ring/algorithms/generator/base.py +30 -1
- ring/base.py +2 -1
- ring/extras/__init__.py +0 -0
- ring/extras/interactive_viewer.py +114 -0
- ring/extras/torch_loss_fn.py +93 -0
- ring/rendering/base_render.py +30 -3
- ring/rendering/mujoco_render.py +38 -7
- ring/utils/__init__.py +0 -4
- {imt_ring-1.6.46.dist-info → imt_ring-1.7.0.dist-info}/entry_points.txt +0 -0
- {imt_ring-1.6.46.dist-info → imt_ring-1.7.0.dist-info}/top_level.txt +0 -0
- /ring/{utils → extras}/backend.py +0 -0
- /ring/{utils → extras}/colab.py +0 -0
- /ring/{utils → extras}/dataloader.py +0 -0
- /ring/{utils → extras}/dataloader_torch.py +0 -0
- /ring/{utils → extras}/hdf5.py +0 -0
- /ring/{utils → extras}/normalizer.py +0 -0
- /ring/{utils → extras}/randomize_sys.py +0 -0
- /ring/{utils → extras}/register_gym_envs/__init__.py +0 -0
- /ring/{utils → extras}/register_gym_envs/saddle.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: imt-ring
|
3
|
-
Version: 1.6.46
|
3
|
+
Version: 1.7.0
|
4
4
|
Summary: RING: Recurrent Inertial Graph-based Estimator
|
5
5
|
Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
|
6
6
|
Project-URL: Homepage, https://github.com/SimiPixel/ring
|
@@ -1,6 +1,6 @@
|
|
1
|
-
ring/__init__.py,sha256=
|
1
|
+
ring/__init__.py,sha256=y3LuDekHyOCYdzaEDJM5dodClfderAKH-0ufklrwtHY,5266
|
2
2
|
ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
|
3
|
-
ring/base.py,sha256=
|
3
|
+
ring/base.py,sha256=AkG_Gpk7i2j77MzxzjiolJ9WNGcSq_3aOcpu0l6-0e0,50543
|
4
4
|
ring/maths.py,sha256=R22SNQutkf9v7Hp9klo0wvJVIyBQz0O8_5oJaDQcFis,12652
|
5
5
|
ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
|
6
6
|
ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
|
@@ -15,13 +15,25 @@ ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXp
|
|
15
15
|
ring/algorithms/custom_joints/rsaddle_joint.py,sha256=QoMo6NXdYgA9JygSzBvr0eCdd3qKhUgCrGPNO2Qdxko,1200
|
16
16
|
ring/algorithms/custom_joints/suntay.py,sha256=TZG307NqdMiXnNY63xEx8AkAjbQBQ4eO6DQ7R4j4D08,16726
|
17
17
|
ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
|
18
|
-
ring/algorithms/generator/base.py,sha256=
|
18
|
+
ring/algorithms/generator/base.py,sha256=sLIXfFliRUzUKaf84rBQjsExEfmU3XjENrYGD4fm1Q0,23808
|
19
19
|
ring/algorithms/generator/batch.py,sha256=xp1X8oYtwI6l2cH4GRu9zw-P8dnh-X1FWTSyixEfgr8,2652
|
20
20
|
ring/algorithms/generator/finalize_fns.py,sha256=ty1NaU-Mghx1RL-voivDjS0TWSKNtjTmbdmBnShhn7k,10398
|
21
21
|
ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
|
22
22
|
ring/algorithms/generator/pd_control.py,sha256=dHnhJZx_FqrHD4xFXpQZH-R7rputFkAVGwoBGccZnz4,5767
|
23
23
|
ring/algorithms/generator/setup_fns.py,sha256=MFz3czHBeWs1Zk1A8O02CyQpQ-NCyW9PMpbqmKit6es,1455
|
24
24
|
ring/algorithms/generator/types.py,sha256=HjNyATFSLfHkXlzdJhvUkiqnhzpXFDDXmWS3LYBlOtU,721
|
25
|
+
ring/extras/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
26
|
+
ring/extras/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
|
27
|
+
ring/extras/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
|
28
|
+
ring/extras/dataloader.py,sha256=dfNPjnxDoKxWGKSImuJ_49CWgBn73vxSEek8COq9nNk,3749
|
29
|
+
ring/extras/dataloader_torch.py,sha256=t2DDiB9ZHb_SzFlVbntCGGIybj4F-NoA0PaB4_afjGw,3983
|
30
|
+
ring/extras/hdf5.py,sha256=XPIrwogD-d544yy08UJyfLVp1ZKRUtiZukW7RA8VUxQ,5856
|
31
|
+
ring/extras/interactive_viewer.py,sha256=vQEzcBDdG3BPqTGEktC74DsCfvgKktj9DKWK8gBzRtE,3805
|
32
|
+
ring/extras/normalizer.py,sha256=o26stPP6EHasZQxQX0vKqTrhUNZBaJ2O17L6W_gBMN4,1699
|
33
|
+
ring/extras/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
|
34
|
+
ring/extras/torch_loss_fn.py,sha256=1LnWTmtxXPxoQFr4QixW12AjpRUfrseSDBmifhu6ErE,2676
|
35
|
+
ring/extras/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
|
36
|
+
ring/extras/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
|
25
37
|
ring/io/__init__.py,sha256=1gEJdyDCbldbbm8QeZbLmhzSKmaQ-UqTmQgu4DBH2Z4,328
|
26
38
|
ring/io/examples.py,sha256=KLf2iCagvRfjs9MCnQsLUlfGBjrQKrD-Qv8U0TtX6Ek,1114
|
27
39
|
ring/io/test_examples.py,sha256=htpnSgLG9Fi9_qwSL4F1yLi9sN7ZUrF8dDmiqU3B510,117
|
@@ -63,8 +75,8 @@ ring/ml/training_loop.py,sha256=yxuUua_4RExq_0GUYm4eUZJsBmtrwDSVL94bWUpYfdo,3586
|
|
63
75
|
ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
|
64
76
|
ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
|
65
77
|
ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
|
66
|
-
ring/rendering/base_render.py,sha256=
|
67
|
-
ring/rendering/mujoco_render.py,sha256=
|
78
|
+
ring/rendering/base_render.py,sha256=O8Oo9znAgWRE09R7B2yecpwNDJ5veIRoMci144oHwF8,10554
|
79
|
+
ring/rendering/mujoco_render.py,sha256=eCmnnzwVZ3BeIo1INswXMZaZ9TDaF1HO50f70spXX2E,9704
|
68
80
|
ring/rendering/vispy_render.py,sha256=6Z6S5LNZ7iy9BN1GVb9EDe-Tix5N_SQ1s7ZsfiTSDEA,10261
|
69
81
|
ring/rendering/vispy_visuals.py,sha256=ooBZqppnebeL0ANe6V6zUgnNTtDcdkOsa4vZuM4sx-I,7873
|
70
82
|
ring/sim2real/__init__.py,sha256=gCLYg8IoMdzUagzhCFcfjZ5GavtIU772L7HR0G5hUtM,251
|
@@ -73,21 +85,12 @@ ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E
|
|
73
85
|
ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
|
74
86
|
ring/sys_composer/inject_sys.py,sha256=PLuxLbXU7hPtAsqvpsEim9hkoVE26ddrg3OipZNvnhU,3504
|
75
87
|
ring/sys_composer/morph_sys.py,sha256=2GpPtS5hT0eZMptdGpt30Hc97OykJNE67lEVRf7sHrc,12700
|
76
|
-
ring/utils/__init__.py,sha256=
|
77
|
-
ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
|
88
|
+
ring/utils/__init__.py,sha256=Q37bjy2wjRGggd77MHlgl_50i2zOuVnPny4yOLiTe-8,567
|
78
89
|
ring/utils/batchsize.py,sha256=uCj8LG7elbjEUUzuK29Z3I9T8bxJTcsybY3DdGeqhQs,1786
|
79
|
-
ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
|
80
|
-
ring/utils/dataloader.py,sha256=dfNPjnxDoKxWGKSImuJ_49CWgBn73vxSEek8COq9nNk,3749
|
81
|
-
ring/utils/dataloader_torch.py,sha256=t2DDiB9ZHb_SzFlVbntCGGIybj4F-NoA0PaB4_afjGw,3983
|
82
|
-
ring/utils/hdf5.py,sha256=XPIrwogD-d544yy08UJyfLVp1ZKRUtiZukW7RA8VUxQ,5856
|
83
|
-
ring/utils/normalizer.py,sha256=o26stPP6EHasZQxQX0vKqTrhUNZBaJ2O17L6W_gBMN4,1699
|
84
90
|
ring/utils/path.py,sha256=zRPfxYNesvgefkddd26oar6f9433LkMGkhp9dF3rPUs,1926
|
85
|
-
ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
|
86
91
|
ring/utils/utils.py,sha256=gKwOXLxWraeZfX6EbBcg3hkq30DcXN0mcRUeOSTNiMo,7336
|
87
|
-
|
88
|
-
|
89
|
-
imt_ring-1.
|
90
|
-
imt_ring-1.
|
91
|
-
imt_ring-1.6.46.dist-info/entry_points.txt,sha256=npNqSOvNiBR0BNa_GL3J66q8Gky3h0G_PHzHzk8oyE0,66
|
92
|
-
imt_ring-1.6.46.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
93
|
-
imt_ring-1.6.46.dist-info/RECORD,,
|
92
|
+
imt_ring-1.7.0.dist-info/METADATA,sha256=CNwgvWr9Yu7MgIfcNwXkuByr7_8vkxvO5IkJg3iDKbs,5887
|
93
|
+
imt_ring-1.7.0.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
|
94
|
+
imt_ring-1.7.0.dist-info/entry_points.txt,sha256=npNqSOvNiBR0BNa_GL3J66q8Gky3h0G_PHzHzk8oyE0,66
|
95
|
+
imt_ring-1.7.0.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
96
|
+
imt_ring-1.7.0.dist-info/RECORD,,
|
ring/__init__.py
CHANGED
@@ -35,12 +35,12 @@ def RING(lam: list[int] | None, Ts: float | None, **kwargs) -> ml.AbstractFilter
|
|
35
35
|
>>> import ring
|
36
36
|
>>> import numpy as np
|
37
37
|
>>>
|
38
|
-
>>> T : int = 30
|
39
|
-
>>> Ts : float = 0.01
|
40
|
-
>>> B : int = 1
|
41
|
-
>>> lam: list[int] = [0, 1
|
42
|
-
>>> N : int = len(lam)
|
43
|
-
>>> T_i: int = int(T/Ts)
|
38
|
+
>>> T : int = 30 # sequence length [s]
|
39
|
+
>>> Ts : float = 0.01 # sampling interval [s]
|
40
|
+
>>> B : int = 1 # batch size
|
41
|
+
>>> lam: list[int] = [-1, 0, 1] # parent array
|
42
|
+
>>> N : int = len(lam) # number of bodies
|
43
|
+
>>> T_i: int = int(T/Ts) # number of timesteps
|
44
44
|
>>>
|
45
45
|
>>> X = np.zeros((B, T_i, N, 9))
|
46
46
|
>>> # where X is structured as follows:
|
@@ -1,5 +1,6 @@
|
|
1
1
|
from dataclasses import replace
|
2
2
|
from functools import partial
|
3
|
+
import json
|
3
4
|
import logging
|
4
5
|
import random
|
5
6
|
from typing import Callable, Optional
|
@@ -136,6 +137,14 @@ class RCMG:
|
|
136
137
|
affecting joint motion behavior.
|
137
138
|
""" # noqa: E501
|
138
139
|
|
140
|
+
# capture all function arguments before creating local variables
|
141
|
+
to_json_kwargs = locals()
|
142
|
+
# the purpose is to not capture the RCMG itself since we want to make it
|
143
|
+
# serialisable in the first place
|
144
|
+
to_json_kwargs.pop("self")
|
145
|
+
to_json_kwargs.pop("sys")
|
146
|
+
to_json_kwargs.pop("config")
|
147
|
+
|
139
148
|
# add some default values
|
140
149
|
randomize_hz_kwargs_defaults = dict(add_dt=True)
|
141
150
|
randomize_hz_kwargs_defaults.update(randomize_hz_kwargs)
|
@@ -186,6 +195,11 @@ class RCMG:
|
|
186
195
|
|
187
196
|
self._disable_tqdm = disable_tqdm
|
188
197
|
|
198
|
+
# store arguments that fully define the RCMG objects for use in `.to_json`
|
199
|
+
self._to_json_sys = sys
|
200
|
+
self._to_json_mconfig = config
|
201
|
+
self._to_json_kwargs = to_json_kwargs
|
202
|
+
|
189
203
|
def _compute_repeats(self, sizes: int | list[int]) -> list[int]:
|
190
204
|
"how many times the generators are repeated to create a batch of `sizes`"
|
191
205
|
|
@@ -355,6 +369,21 @@ class RCMG:
|
|
355
369
|
|
356
370
|
return generator
|
357
371
|
|
372
|
+
def serialise_to_dict(self) -> dict:
|
373
|
+
dict_representation = {
|
374
|
+
"system": [_sys.to_str(warn=False) for _sys in self._to_json_sys],
|
375
|
+
"motion_configs": [_config.__dict__ for _config in self._to_json_mconfig],
|
376
|
+
"kwargs": self._to_json_kwargs,
|
377
|
+
}
|
378
|
+
return dict_representation
|
379
|
+
|
380
|
+
def serialise_to_json(self, path_of_json: str) -> None:
|
381
|
+
with open(path_of_json, "w") as file:
|
382
|
+
json.dump(self.serialise_to_dict(), file, indent=4)
|
383
|
+
|
384
|
+
def from_json(self, path_to_json: str) -> "RCMG":
|
385
|
+
raise NotImplementedError
|
386
|
+
|
358
387
|
|
359
388
|
def _copy_dicts(f) -> dict:
|
360
389
|
def _f(*args, **kwargs):
|
@@ -526,7 +555,7 @@ def draw_random_q(
|
|
526
555
|
sys: base.System,
|
527
556
|
config: jcalc.MotionConfig,
|
528
557
|
N: int | None,
|
529
|
-
) -> tuple[
|
558
|
+
) -> tuple[jax.random.PRNGKey, jax.Array]:
|
530
559
|
|
531
560
|
key_start = key
|
532
561
|
# build generalized coordinates vector `q`
|
ring/base.py
CHANGED
@@ -981,6 +981,7 @@ class System(_Base):
|
|
981
981
|
|
982
982
|
def render(
|
983
983
|
self,
|
984
|
+
qs: Optional[jax.Array | list[jax.Array]] = None,
|
984
985
|
xs: Optional[Transform | list[Transform]] = None,
|
985
986
|
camera: Optional[str] = None,
|
986
987
|
show_pbar: bool = True,
|
@@ -1001,7 +1002,7 @@ class System(_Base):
|
|
1001
1002
|
list[np.ndarray]: Stacked rendered frames. Length == len(xs).
|
1002
1003
|
"""
|
1003
1004
|
return ring.rendering.render(
|
1004
|
-
self, xs, camera, show_pbar, backend, render_every_nth, **scene_kwargs
|
1005
|
+
self, qs, xs, camera, show_pbar, backend, render_every_nth, **scene_kwargs
|
1005
1006
|
)
|
1006
1007
|
|
1007
1008
|
def render_prediction(
|
ring/extras/__init__.py
ADDED
File without changes
|
@@ -0,0 +1,114 @@
|
|
1
|
+
import multiprocessing
|
2
|
+
import time
|
3
|
+
from typing import Optional
|
4
|
+
|
5
|
+
import fire
|
6
|
+
import jax.numpy as jnp
|
7
|
+
import numpy as np
|
8
|
+
|
9
|
+
import ring
|
10
|
+
from ring import System
|
11
|
+
|
12
|
+
|
13
|
+
class InteractiveViewer:
|
14
|
+
def __init__(self, sys: ring.System, **scene_kwargs):
|
15
|
+
self._mp_dict = multiprocessing.Manager().dict()
|
16
|
+
self._geom_dict = multiprocessing.Manager().dict()
|
17
|
+
self.update_q(np.array(ring.State.create(sys).q))
|
18
|
+
self.process = multiprocessing.Process(
|
19
|
+
target=self._worker,
|
20
|
+
args=(self._mp_dict, self._geom_dict, sys.to_str(), scene_kwargs),
|
21
|
+
)
|
22
|
+
self.process.start()
|
23
|
+
|
24
|
+
def update_q(self, q: np.ndarray):
|
25
|
+
self._mp_dict["q"] = q
|
26
|
+
|
27
|
+
def make_geometry_transparent(self, body_number: int, geom_number: int):
|
28
|
+
geom_name = f"body{body_number}_geom{geom_number}"
|
29
|
+
# the value is not used
|
30
|
+
self._geom_dict[geom_name] = None
|
31
|
+
|
32
|
+
def _worker(self, mp_dict, geom_dict, sys_str, scene_kwargs):
|
33
|
+
from ring.rendering import base_render
|
34
|
+
|
35
|
+
sys = System.from_str(sys_str)
|
36
|
+
while base_render._scene is None or base_render._scene._renderer.is_alive:
|
37
|
+
sys.render(jnp.array(mp_dict["q"]), interactive=True, **scene_kwargs)
|
38
|
+
|
39
|
+
if len(geom_dict) > 0:
|
40
|
+
model = base_render._scene._model
|
41
|
+
processed = []
|
42
|
+
for geom_name in list(geom_dict.keys()):
|
43
|
+
# Get the geometry ID
|
44
|
+
geom_id = model.geom(geom_name).id
|
45
|
+
# Set transparency to 0 (fully transparent)
|
46
|
+
model.geom_rgba[geom_id, 3] = 0
|
47
|
+
print(f"Made geom with name={geom_name} transparent (worker)")
|
48
|
+
processed.append(geom_name)
|
49
|
+
|
50
|
+
for geom_name in processed:
|
51
|
+
geom_dict.pop(geom_name)
|
52
|
+
|
53
|
+
def __enter__(self):
|
54
|
+
return self
|
55
|
+
|
56
|
+
def close(self):
|
57
|
+
self.process.terminate()
|
58
|
+
self.process.join()
|
59
|
+
|
60
|
+
def __exit__(self, exc_type, exc_value, traceback):
|
61
|
+
self.close()
|
62
|
+
|
63
|
+
|
64
|
+
def _fire_main(path_sys_xml: str, path_qs_np: Optional[str] = None, **scene_kwargs):
|
65
|
+
"""View motion given by trajectory of minimal coordinates in interactive viewer.
|
66
|
+
|
67
|
+
Args:
|
68
|
+
path_sys_xml (str): Path to xml file defining the system.
|
69
|
+
path_qs_np (str | None, optional): Path to numpy array containing the timeseries of minimal coordinates with
|
70
|
+
shape (T, DOF) where DOF is equal to `sys.q_size()`. Each minimal coordinate is from parent
|
71
|
+
to child. So for example a `spherical` joint that connects the first body to the worldbody
|
72
|
+
has a minimal coordinate of a quaternion that gives from worldbody to first body. The sampling
|
73
|
+
rate of the motion is inferred from the `sys.dt` attribute. If `None` (default), then simply renders the
|
74
|
+
unarticulated pose of the system.
|
75
|
+
""" # noqa: E501
|
76
|
+
|
77
|
+
sys = ring.System.from_xml(path_sys_xml)
|
78
|
+
if path_qs_np is None:
|
79
|
+
qs = np.array(ring.State.create(sys).q)[None]
|
80
|
+
else:
|
81
|
+
qs: np.ndarray = np.load(path_qs_np)
|
82
|
+
|
83
|
+
assert qs.ndim == 2, f"qs.shape = {qs.shape}"
|
84
|
+
T, Q = qs.shape
|
85
|
+
assert Q == sys.q_size()
|
86
|
+
dt_target = sys.dt
|
87
|
+
|
88
|
+
with InteractiveViewer(sys, **scene_kwargs) as viewer:
|
89
|
+
dt = dt_target
|
90
|
+
last_t = time.time()
|
91
|
+
t = -1
|
92
|
+
|
93
|
+
while True:
|
94
|
+
t = (t + 1) % T
|
95
|
+
|
96
|
+
while dt < dt_target:
|
97
|
+
time.sleep(0.001)
|
98
|
+
dt = time.time() - last_t
|
99
|
+
|
100
|
+
last_t = time.time()
|
101
|
+
viewer.update_q(qs[t])
|
102
|
+
dt = time.time() - last_t
|
103
|
+
|
104
|
+
# process will be stopped if the window is closed
|
105
|
+
if not viewer.process.is_alive():
|
106
|
+
break
|
107
|
+
|
108
|
+
|
109
|
+
def main():
|
110
|
+
fire.Fire(_fire_main)
|
111
|
+
|
112
|
+
|
113
|
+
if __name__ == "__main__":
|
114
|
+
main()
|
@@ -0,0 +1,93 @@
|
|
1
|
+
"""This module exports a loss function `loss_fn` for training neural networks that
|
2
|
+
output quaternions in PyTorch"""
|
3
|
+
|
4
|
+
from typing import Sequence
|
5
|
+
|
6
|
+
import torch
|
7
|
+
|
8
|
+
|
9
|
+
def quat_mul(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
10
|
+
"Multiplies two quaternions."
|
11
|
+
q = torch.stack(
|
12
|
+
[
|
13
|
+
u[..., 0] * v[..., 0]
|
14
|
+
- u[..., 1] * v[..., 1]
|
15
|
+
- u[..., 2] * v[..., 2]
|
16
|
+
- u[..., 3] * v[..., 3],
|
17
|
+
u[..., 0] * v[..., 1]
|
18
|
+
+ u[..., 1] * v[..., 0]
|
19
|
+
+ u[..., 2] * v[..., 3]
|
20
|
+
- u[..., 3] * v[..., 2],
|
21
|
+
u[..., 0] * v[..., 2]
|
22
|
+
- u[..., 1] * v[..., 3]
|
23
|
+
+ u[..., 2] * v[..., 0]
|
24
|
+
+ u[..., 3] * v[..., 1],
|
25
|
+
u[..., 0] * v[..., 3]
|
26
|
+
+ u[..., 1] * v[..., 2]
|
27
|
+
- u[..., 2] * v[..., 1]
|
28
|
+
+ u[..., 3] * v[..., 0],
|
29
|
+
],
|
30
|
+
dim=-1,
|
31
|
+
)
|
32
|
+
return q
|
33
|
+
|
34
|
+
|
35
|
+
def quat_inv(q: torch.Tensor):
|
36
|
+
return torch.concat([q[..., :1], -q[..., 1:]], dim=-1)
|
37
|
+
|
38
|
+
|
39
|
+
def wrap_to_pi(phi):
|
40
|
+
"Wraps angle `phi` (radians) to interval [-pi, pi]."
|
41
|
+
return (phi + torch.pi) % (2 * torch.pi) - torch.pi
|
42
|
+
|
43
|
+
|
44
|
+
def quat_angle(q: torch.Tensor):
|
45
|
+
phi = 2 * torch.arctan2(torch.norm(q[..., 1:], dim=-1), q[..., 0])
|
46
|
+
return wrap_to_pi(phi)
|
47
|
+
|
48
|
+
|
49
|
+
def safe_normalize(x):
|
50
|
+
return x / (1e-6 + torch.norm(x, dim=-1, keepdim=True))
|
51
|
+
|
52
|
+
|
53
|
+
def quat_qrel(q1, q2):
|
54
|
+
"q1^-1 * q2"
|
55
|
+
return quat_mul(quat_inv(q1), q2)
|
56
|
+
|
57
|
+
|
58
|
+
@torch.jit.script
|
59
|
+
def angle_error(q, qhat):
|
60
|
+
"Absolute angle error in radians"
|
61
|
+
return torch.abs(quat_angle(quat_qrel(q, qhat)))
|
62
|
+
|
63
|
+
|
64
|
+
@torch.jit.script
|
65
|
+
def inclination_error(q, qhat):
|
66
|
+
"Absolute inclination error in radians. `q`s are from body-to-eps"
|
67
|
+
q_rel = quat_mul(q, quat_inv(qhat))
|
68
|
+
phi_pri = 2 * torch.arctan2(q_rel[..., 3], q_rel[..., 0])
|
69
|
+
q_pri = torch.zeros_like(q)
|
70
|
+
q_pri[..., 0] = torch.cos(phi_pri / 2)
|
71
|
+
q_pri[..., 3] = torch.sin(phi_pri / 2)
|
72
|
+
q_res = quat_mul(q_rel, quat_inv(q_pri))
|
73
|
+
return torch.abs(quat_angle(q_res))
|
74
|
+
|
75
|
+
|
76
|
+
def loss_fn(lam: Sequence[int], q: torch.Tensor, qhat: torch.Tensor) -> torch.Tensor:
|
77
|
+
"(..., N, 4) -> (..., N)"
|
78
|
+
*batch_dims, N, F = q.shape
|
79
|
+
assert q.shape == qhat.shape
|
80
|
+
assert F == 4
|
81
|
+
assert N == len(lam)
|
82
|
+
permu = list(reversed(range(q.ndim - 1)))
|
83
|
+
loss_incl = inclination_error(q, qhat).permute(*permu)
|
84
|
+
loss_mae = angle_error(q, qhat).permute(*permu)
|
85
|
+
lam = torch.tensor(lam, device=q.device)
|
86
|
+
return torch.where(
|
87
|
+
lam.reshape(-1, *[1] * len(batch_dims)) == -1, loss_incl, loss_mae
|
88
|
+
).permute(*permu)
|
89
|
+
|
90
|
+
|
91
|
+
def quat_rand(*size: tuple[int]):
|
92
|
+
qs = torch.randn(size=size + (4,))
|
93
|
+
return qs / torch.norm(qs, dim=-1, keepdim=True)
|
ring/rendering/base_render.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
from functools import partial
|
1
2
|
from typing import Optional
|
2
3
|
|
3
4
|
import jax
|
@@ -93,15 +94,29 @@ def _load_scene(sys, backend, **scene_kwargs):
|
|
93
94
|
return _scene
|
94
95
|
|
95
96
|
|
97
|
+
@jax.jit
|
98
|
+
def _jit_forward_kinematics(sys):
|
99
|
+
_, state = kinematics.forward_kinematics(sys, base.State.create(sys))
|
100
|
+
return state.x
|
101
|
+
|
102
|
+
|
103
|
+
@jax.jit
|
104
|
+
@partial(jax.vmap, in_axes=(None, 0))
|
105
|
+
def _jit_vmap_forward_kinematics(sys, q):
|
106
|
+
_, state = kinematics.forward_kinematics(sys, base.State.create(sys, q=q))
|
107
|
+
return state.x
|
108
|
+
|
109
|
+
|
96
110
|
def render(
|
97
111
|
sys: base.System,
|
112
|
+
qs: Optional[jax.Array | list[jax.Array]] = None,
|
98
113
|
xs: Optional[base.Transform | list[base.Transform]] = None,
|
99
114
|
camera: Optional[str] = None,
|
100
115
|
show_pbar: bool = True,
|
101
116
|
backend: str = "mujoco",
|
102
117
|
render_every_nth: int = 1,
|
103
118
|
**scene_kwargs,
|
104
|
-
) -> list[np.ndarray]:
|
119
|
+
) -> list[np.ndarray | None]:
|
105
120
|
"""Render frames from system and trajectory of maximal coordinates `xs`.
|
106
121
|
|
107
122
|
Args:
|
@@ -114,9 +129,18 @@ def render(
|
|
114
129
|
Returns:
|
115
130
|
list[np.ndarray]: Stacked rendered frames. Length == len(xs).
|
116
131
|
"""
|
132
|
+
assert not (qs is not None and xs is not None)
|
117
133
|
|
134
|
+
if xs is None and qs is None:
|
135
|
+
xs = _jit_forward_kinematics(sys)
|
118
136
|
if xs is None:
|
119
|
-
xs
|
137
|
+
# throw error if `xs` has been given by accident as `qs` argument
|
138
|
+
qs = utils.to_list(qs)
|
139
|
+
assert not isinstance(
|
140
|
+
qs[0], base.Transform
|
141
|
+
), "`qs` should be `jax.Array` and not `Transform`; maybe you want to pass `xs` as keyword argument `xs=xs`?" # noqa: E501
|
142
|
+
qs = jnp.stack(qs, axis=0)
|
143
|
+
xs = _jit_vmap_forward_kinematics(sys, qs)
|
120
144
|
|
121
145
|
# convert time-axis of batched xs object into a list of unbatched x objects
|
122
146
|
if isinstance(xs, base.Transform) and xs.ndim() == 3:
|
@@ -144,6 +168,9 @@ def render(
|
|
144
168
|
|
145
169
|
scene = _load_scene(sys, backend, **scene_kwargs)
|
146
170
|
|
171
|
+
if scene_kwargs.get("interactive", False):
|
172
|
+
show_pbar = False
|
173
|
+
|
147
174
|
frames = []
|
148
175
|
for x in tqdm.tqdm(xs, "Rendering frames..", disable=not show_pbar):
|
149
176
|
scene.update(x)
|
@@ -241,7 +268,7 @@ def render_prediction(
|
|
241
268
|
sys, xs, yhat, transparent_segment_to_root, offset_truth, offset_pred
|
242
269
|
)
|
243
270
|
|
244
|
-
frames = render(sys_render, xs_render, **kwargs)
|
271
|
+
frames = render(sys=sys_render, xs=xs_render, **kwargs)
|
245
272
|
return frames
|
246
273
|
|
247
274
|
|
ring/rendering/mujoco_render.py
CHANGED
@@ -118,8 +118,10 @@ def _xml_str_one_body(
|
|
118
118
|
body_number: int, geoms: list[base.Geometry], cameras: list[str], lights: list[str]
|
119
119
|
) -> str:
|
120
120
|
inside_body_geoms = ""
|
121
|
-
for geom in geoms:
|
122
|
-
inside_body_geoms += _xml_str_one_geom(
|
121
|
+
for geom_number, geom in enumerate(geoms):
|
122
|
+
inside_body_geoms += _xml_str_one_geom(
|
123
|
+
geom, name=f"body{body_number}_geom{geom_number}"
|
124
|
+
)
|
123
125
|
|
124
126
|
inside_body_cameras = ""
|
125
127
|
for camera in cameras:
|
@@ -138,7 +140,7 @@ def _xml_str_one_body(
|
|
138
140
|
"""
|
139
141
|
|
140
142
|
|
141
|
-
def _xml_str_one_geom(geom: base.Geometry) -> str:
|
143
|
+
def _xml_str_one_geom(geom: base.Geometry, name: str) -> str:
|
142
144
|
rgba = f'rgba="{_array_to_str(geom.color)}"'
|
143
145
|
|
144
146
|
if isinstance(geom, base.Box):
|
@@ -158,7 +160,8 @@ def _xml_str_one_geom(geom: base.Geometry) -> str:
|
|
158
160
|
|
159
161
|
rot, pos = maths.quat_inv(geom.transform.rot), geom.transform.pos
|
160
162
|
rot, pos = f'pos="{_array_to_str(pos)}"', f'quat="{_array_to_str(rot)}"'
|
161
|
-
|
163
|
+
name = f'name="{name}"'
|
164
|
+
return f"<geom {type_size} {rgba} {rot} {pos} {name}/>"
|
162
165
|
|
163
166
|
|
164
167
|
def _array_to_str(arr: Sequence[float]) -> str:
|
@@ -181,6 +184,8 @@ class MujocoScene:
|
|
181
184
|
floor_z: float = -0.84,
|
182
185
|
floor_material: str = "matplane",
|
183
186
|
debug: bool = False,
|
187
|
+
interactive: bool = False,
|
188
|
+
interactive_hide_menu: bool = False,
|
184
189
|
) -> None:
|
185
190
|
self.debug = debug
|
186
191
|
self.height, self.width = height, width
|
@@ -195,6 +200,8 @@ class MujocoScene:
|
|
195
200
|
self.show_stars = show_stars
|
196
201
|
self.show_floor = show_floor
|
197
202
|
self.floor_kwargs = dict(z=floor_z, material=floor_material)
|
203
|
+
self.interactive = interactive
|
204
|
+
self.interactive_hide_menu = interactive_hide_menu
|
198
205
|
|
199
206
|
def init(self, geoms: list[base.Geometry]):
|
200
207
|
self._parent_ids = list(set([geom.link_idx for geom in geoms]))
|
@@ -208,7 +215,22 @@ class MujocoScene:
|
|
208
215
|
debug=self.debug,
|
209
216
|
)
|
210
217
|
self._data = mujoco.MjData(self._model)
|
211
|
-
|
218
|
+
if self.interactive:
|
219
|
+
import mujoco_viewer
|
220
|
+
|
221
|
+
self._renderer = mujoco_viewer.MujocoViewer(
|
222
|
+
self._model,
|
223
|
+
self._data,
|
224
|
+
width=self.width,
|
225
|
+
height=self.height,
|
226
|
+
hide_menus=self.interactive_hide_menu,
|
227
|
+
)
|
228
|
+
|
229
|
+
if self.interactive_hide_menu:
|
230
|
+
print("Menu can be shown with key `H` for H(elp)")
|
231
|
+
|
232
|
+
else:
|
233
|
+
self._renderer = mujoco.Renderer(self._model, self.height, self.width)
|
212
234
|
|
213
235
|
def update(self, x: base.Transform):
|
214
236
|
rot, pos = maths.quat_inv(x.rot), x.pos
|
@@ -234,6 +256,15 @@ class MujocoScene:
|
|
234
256
|
|
235
257
|
mujoco.mj_forward(self._model, self._data)
|
236
258
|
|
237
|
-
def render(self, camera: Optional[str] = None):
|
238
|
-
|
259
|
+
def render(self, camera: Optional[str] = None) -> np.ndarray | None:
|
260
|
+
if not self.interactive:
|
261
|
+
self._renderer.update_scene(
|
262
|
+
self._data, camera=-1 if camera is None else camera
|
263
|
+
)
|
239
264
|
return self._renderer.render()
|
265
|
+
|
266
|
+
def close(self):
|
267
|
+
self._renderer.close()
|
268
|
+
|
269
|
+
def __del__(self):
|
270
|
+
self.close()
|
ring/utils/__init__.py
CHANGED
@@ -1,11 +1,7 @@
|
|
1
|
-
from . import randomize_sys
|
2
1
|
from .batchsize import batchsize_thresholds
|
3
2
|
from .batchsize import distribute_batchsize
|
4
3
|
from .batchsize import expand_batchsize
|
5
4
|
from .batchsize import merge_batchsize
|
6
|
-
from .colab import setup_colab_env
|
7
|
-
from .normalizer import make_normalizer_from_generator
|
8
|
-
from .normalizer import Normalizer
|
9
5
|
from .path import parse_path
|
10
6
|
from .utils import dict_to_nested
|
11
7
|
from .utils import dict_union
|
File without changes
|
File without changes
|
File without changes
|
/ring/{utils → extras}/colab.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
/ring/{utils → extras}/hdf5.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|