imt-ring 1.6.46__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: imt-ring
3
- Version: 1.6.46
3
+ Version: 1.7.0
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,6 +1,6 @@
1
- ring/__init__.py,sha256=H1Rd2uXVkux4Z792XyHIkQ8OpDSZBiPqFwyAFDWDU3E,5260
1
+ ring/__init__.py,sha256=y3LuDekHyOCYdzaEDJM5dodClfderAKH-0ufklrwtHY,5266
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=zromjIuMpNBoyiwHa9OCyZvAz7jHjXHZIdRt8fN8PoA,50481
3
+ ring/base.py,sha256=AkG_Gpk7i2j77MzxzjiolJ9WNGcSq_3aOcpu0l6-0e0,50543
4
4
  ring/maths.py,sha256=R22SNQutkf9v7Hp9klo0wvJVIyBQz0O8_5oJaDQcFis,12652
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
@@ -15,13 +15,25 @@ ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXp
15
15
  ring/algorithms/custom_joints/rsaddle_joint.py,sha256=QoMo6NXdYgA9JygSzBvr0eCdd3qKhUgCrGPNO2Qdxko,1200
16
16
  ring/algorithms/custom_joints/suntay.py,sha256=TZG307NqdMiXnNY63xEx8AkAjbQBQ4eO6DQ7R4j4D08,16726
17
17
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
18
- ring/algorithms/generator/base.py,sha256=yPH_RIQPU_nlq58HyZ6T3RUm1S5chA3-Ro__-ArYTq0,22669
18
+ ring/algorithms/generator/base.py,sha256=sLIXfFliRUzUKaf84rBQjsExEfmU3XjENrYGD4fm1Q0,23808
19
19
  ring/algorithms/generator/batch.py,sha256=xp1X8oYtwI6l2cH4GRu9zw-P8dnh-X1FWTSyixEfgr8,2652
20
20
  ring/algorithms/generator/finalize_fns.py,sha256=ty1NaU-Mghx1RL-voivDjS0TWSKNtjTmbdmBnShhn7k,10398
21
21
  ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
22
22
  ring/algorithms/generator/pd_control.py,sha256=dHnhJZx_FqrHD4xFXpQZH-R7rputFkAVGwoBGccZnz4,5767
23
23
  ring/algorithms/generator/setup_fns.py,sha256=MFz3czHBeWs1Zk1A8O02CyQpQ-NCyW9PMpbqmKit6es,1455
24
24
  ring/algorithms/generator/types.py,sha256=HjNyATFSLfHkXlzdJhvUkiqnhzpXFDDXmWS3LYBlOtU,721
25
+ ring/extras/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
+ ring/extras/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
27
+ ring/extras/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
28
+ ring/extras/dataloader.py,sha256=dfNPjnxDoKxWGKSImuJ_49CWgBn73vxSEek8COq9nNk,3749
29
+ ring/extras/dataloader_torch.py,sha256=t2DDiB9ZHb_SzFlVbntCGGIybj4F-NoA0PaB4_afjGw,3983
30
+ ring/extras/hdf5.py,sha256=XPIrwogD-d544yy08UJyfLVp1ZKRUtiZukW7RA8VUxQ,5856
31
+ ring/extras/interactive_viewer.py,sha256=vQEzcBDdG3BPqTGEktC74DsCfvgKktj9DKWK8gBzRtE,3805
32
+ ring/extras/normalizer.py,sha256=o26stPP6EHasZQxQX0vKqTrhUNZBaJ2O17L6W_gBMN4,1699
33
+ ring/extras/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
34
+ ring/extras/torch_loss_fn.py,sha256=1LnWTmtxXPxoQFr4QixW12AjpRUfrseSDBmifhu6ErE,2676
35
+ ring/extras/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
36
+ ring/extras/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
25
37
  ring/io/__init__.py,sha256=1gEJdyDCbldbbm8QeZbLmhzSKmaQ-UqTmQgu4DBH2Z4,328
26
38
  ring/io/examples.py,sha256=KLf2iCagvRfjs9MCnQsLUlfGBjrQKrD-Qv8U0TtX6Ek,1114
27
39
  ring/io/test_examples.py,sha256=htpnSgLG9Fi9_qwSL4F1yLi9sN7ZUrF8dDmiqU3B510,117
@@ -63,8 +75,8 @@ ring/ml/training_loop.py,sha256=yxuUua_4RExq_0GUYm4eUZJsBmtrwDSVL94bWUpYfdo,3586
63
75
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
64
76
  ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
65
77
  ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
66
- ring/rendering/base_render.py,sha256=Mv9SRLEmuoPVhi46UIjb6xCkKmbWCwIyENGx7nu9REM,9617
67
- ring/rendering/mujoco_render.py,sha256=HMvZc04I0-lXPBL3hcnBzV2bNiXQAQM7QcHlG_Obmj4,8757
78
+ ring/rendering/base_render.py,sha256=O8Oo9znAgWRE09R7B2yecpwNDJ5veIRoMci144oHwF8,10554
79
+ ring/rendering/mujoco_render.py,sha256=eCmnnzwVZ3BeIo1INswXMZaZ9TDaF1HO50f70spXX2E,9704
68
80
  ring/rendering/vispy_render.py,sha256=6Z6S5LNZ7iy9BN1GVb9EDe-Tix5N_SQ1s7ZsfiTSDEA,10261
69
81
  ring/rendering/vispy_visuals.py,sha256=ooBZqppnebeL0ANe6V6zUgnNTtDcdkOsa4vZuM4sx-I,7873
70
82
  ring/sim2real/__init__.py,sha256=gCLYg8IoMdzUagzhCFcfjZ5GavtIU772L7HR0G5hUtM,251
@@ -73,21 +85,12 @@ ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E
73
85
  ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
74
86
  ring/sys_composer/inject_sys.py,sha256=PLuxLbXU7hPtAsqvpsEim9hkoVE26ddrg3OipZNvnhU,3504
75
87
  ring/sys_composer/morph_sys.py,sha256=2GpPtS5hT0eZMptdGpt30Hc97OykJNE67lEVRf7sHrc,12700
76
- ring/utils/__init__.py,sha256=MHHavc8YfjBlmB-zAV42QEQS_ebW7cy0lhWXEVyQU7s,720
77
- ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
88
+ ring/utils/__init__.py,sha256=Q37bjy2wjRGggd77MHlgl_50i2zOuVnPny4yOLiTe-8,567
78
89
  ring/utils/batchsize.py,sha256=uCj8LG7elbjEUUzuK29Z3I9T8bxJTcsybY3DdGeqhQs,1786
79
- ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
80
- ring/utils/dataloader.py,sha256=dfNPjnxDoKxWGKSImuJ_49CWgBn73vxSEek8COq9nNk,3749
81
- ring/utils/dataloader_torch.py,sha256=t2DDiB9ZHb_SzFlVbntCGGIybj4F-NoA0PaB4_afjGw,3983
82
- ring/utils/hdf5.py,sha256=XPIrwogD-d544yy08UJyfLVp1ZKRUtiZukW7RA8VUxQ,5856
83
- ring/utils/normalizer.py,sha256=o26stPP6EHasZQxQX0vKqTrhUNZBaJ2O17L6W_gBMN4,1699
84
90
  ring/utils/path.py,sha256=zRPfxYNesvgefkddd26oar6f9433LkMGkhp9dF3rPUs,1926
85
- ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
86
91
  ring/utils/utils.py,sha256=gKwOXLxWraeZfX6EbBcg3hkq30DcXN0mcRUeOSTNiMo,7336
87
- ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
88
- ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
89
- imt_ring-1.6.46.dist-info/METADATA,sha256=9NOkzI2PpdcJpw90_ZV0smHEWyhkrZyEF211Wy1gNpg,5888
90
- imt_ring-1.6.46.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
91
- imt_ring-1.6.46.dist-info/entry_points.txt,sha256=npNqSOvNiBR0BNa_GL3J66q8Gky3h0G_PHzHzk8oyE0,66
92
- imt_ring-1.6.46.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
93
- imt_ring-1.6.46.dist-info/RECORD,,
92
+ imt_ring-1.7.0.dist-info/METADATA,sha256=CNwgvWr9Yu7MgIfcNwXkuByr7_8vkxvO5IkJg3iDKbs,5887
93
+ imt_ring-1.7.0.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
94
+ imt_ring-1.7.0.dist-info/entry_points.txt,sha256=npNqSOvNiBR0BNa_GL3J66q8Gky3h0G_PHzHzk8oyE0,66
95
+ imt_ring-1.7.0.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
96
+ imt_ring-1.7.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.2)
2
+ Generator: setuptools (79.0.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
ring/__init__.py CHANGED
@@ -35,12 +35,12 @@ def RING(lam: list[int] | None, Ts: float | None, **kwargs) -> ml.AbstractFilter
35
35
  >>> import ring
36
36
  >>> import numpy as np
37
37
  >>>
38
- >>> T : int = 30 # sequence length [s]
39
- >>> Ts : float = 0.01 # sampling interval [s]
40
- >>> B : int = 1 # batch size
41
- >>> lam: list[int] = [0, 1, 2] # parent array
42
- >>> N : int = len(lam) # number of bodies
43
- >>> T_i: int = int(T/Ts) # number of timesteps
38
+ >>> T : int = 30 # sequence length [s]
39
+ >>> Ts : float = 0.01 # sampling interval [s]
40
+ >>> B : int = 1 # batch size
41
+ >>> lam: list[int] = [-1, 0, 1] # parent array
42
+ >>> N : int = len(lam) # number of bodies
43
+ >>> T_i: int = int(T/Ts) # number of timesteps
44
44
  >>>
45
45
  >>> X = np.zeros((B, T_i, N, 9))
46
46
  >>> # where X is structured as follows:
@@ -1,5 +1,6 @@
1
1
  from dataclasses import replace
2
2
  from functools import partial
3
+ import json
3
4
  import logging
4
5
  import random
5
6
  from typing import Callable, Optional
@@ -136,6 +137,14 @@ class RCMG:
136
137
  affecting joint motion behavior.
137
138
  """ # noqa: E501
138
139
 
140
+ # capture all funtion arguments before creating local variables
141
+ to_json_kwargs = locals()
142
+ # the purpose is to not capture the RCMG itself since we want to make it
143
+ # serialisable in the first place
144
+ to_json_kwargs.pop("self")
145
+ to_json_kwargs.pop("sys")
146
+ to_json_kwargs.pop("config")
147
+
139
148
  # add some default values
140
149
  randomize_hz_kwargs_defaults = dict(add_dt=True)
141
150
  randomize_hz_kwargs_defaults.update(randomize_hz_kwargs)
@@ -186,6 +195,11 @@ class RCMG:
186
195
 
187
196
  self._disable_tqdm = disable_tqdm
188
197
 
198
+ # store arguments that fully define the RCMG objects for use in `.to_json`
199
+ self._to_json_sys = sys
200
+ self._to_json_mconfig = config
201
+ self._to_json_kwargs = to_json_kwargs
202
+
189
203
  def _compute_repeats(self, sizes: int | list[int]) -> list[int]:
190
204
  "how many times the generators are repeated to create a batch of `sizes`"
191
205
 
@@ -355,6 +369,21 @@ class RCMG:
355
369
 
356
370
  return generator
357
371
 
372
+ def serialise_to_dict(self) -> dict:
373
+ dict_representation = {
374
+ "system": [_sys.to_str(warn=False) for _sys in self._to_json_sys],
375
+ "motion_configs": [_config.__dict__ for _config in self._to_json_mconfig],
376
+ "kwargs": self._to_json_kwargs,
377
+ }
378
+ return dict_representation
379
+
380
+ def serialise_to_json(self, path_of_json: str) -> None:
381
+ with open(path_of_json, "w") as file:
382
+ json.dump(self.serialise_to_dict(), file, indent=4)
383
+
384
+ def from_json(self, path_to_json: str) -> "RCMG":
385
+ raise NotImplementedError
386
+
358
387
 
359
388
  def _copy_dicts(f) -> dict:
360
389
  def _f(*args, **kwargs):
@@ -526,7 +555,7 @@ def draw_random_q(
526
555
  sys: base.System,
527
556
  config: jcalc.MotionConfig,
528
557
  N: int | None,
529
- ) -> tuple[types.Xy, types.OutputExtras]:
558
+ ) -> tuple[jax.random.PRNGKey, jax.Array]:
530
559
 
531
560
  key_start = key
532
561
  # build generalized coordintes vector `q`
ring/base.py CHANGED
@@ -981,6 +981,7 @@ class System(_Base):
981
981
 
982
982
  def render(
983
983
  self,
984
+ qs: Optional[jax.Array | list[jax.Array]] = None,
984
985
  xs: Optional[Transform | list[Transform]] = None,
985
986
  camera: Optional[str] = None,
986
987
  show_pbar: bool = True,
@@ -1001,7 +1002,7 @@ class System(_Base):
1001
1002
  list[np.ndarray]: Stacked rendered frames. Length == len(xs).
1002
1003
  """
1003
1004
  return ring.rendering.render(
1004
- self, xs, camera, show_pbar, backend, render_every_nth, **scene_kwargs
1005
+ self, qs, xs, camera, show_pbar, backend, render_every_nth, **scene_kwargs
1005
1006
  )
1006
1007
 
1007
1008
  def render_prediction(
File without changes
@@ -0,0 +1,114 @@
1
+ import multiprocessing
2
+ import time
3
+ from typing import Optional
4
+
5
+ import fire
6
+ import jax.numpy as jnp
7
+ import numpy as np
8
+
9
+ import ring
10
+ from ring import System
11
+
12
+
13
+ class InteractiveViewer:
14
+ def __init__(self, sys: ring.System, **scene_kwargs):
15
+ self._mp_dict = multiprocessing.Manager().dict()
16
+ self._geom_dict = multiprocessing.Manager().dict()
17
+ self.update_q(np.array(ring.State.create(sys).q))
18
+ self.process = multiprocessing.Process(
19
+ target=self._worker,
20
+ args=(self._mp_dict, self._geom_dict, sys.to_str(), scene_kwargs),
21
+ )
22
+ self.process.start()
23
+
24
+ def update_q(self, q: np.ndarray):
25
+ self._mp_dict["q"] = q
26
+
27
+ def make_geometry_transparent(self, body_number: int, geom_number: int):
28
+ geom_name = f"body{body_number}_geom{geom_number}"
29
+ # the value is not used
30
+ self._geom_dict[geom_name] = None
31
+
32
+ def _worker(self, mp_dict, geom_dict, sys_str, scene_kwargs):
33
+ from ring.rendering import base_render
34
+
35
+ sys = System.from_str(sys_str)
36
+ while base_render._scene is None or base_render._scene._renderer.is_alive:
37
+ sys.render(jnp.array(mp_dict["q"]), interactive=True, **scene_kwargs)
38
+
39
+ if len(geom_dict) > 0:
40
+ model = base_render._scene._model
41
+ processed = []
42
+ for geom_name in list(geom_dict.keys()):
43
+ # Get the geometry ID
44
+ geom_id = model.geom(geom_name).id
45
+ # Set transparency to 0 (fully transparent)
46
+ model.geom_rgba[geom_id, 3] = 0
47
+ print(f"Made geom with name={geom_name} transparent (worker)")
48
+ processed.append(geom_name)
49
+
50
+ for geom_name in processed:
51
+ geom_dict.pop(geom_name)
52
+
53
+ def __enter__(self):
54
+ return self
55
+
56
+ def close(self):
57
+ self.process.terminate()
58
+ self.process.join()
59
+
60
+ def __exit__(self, exc_type, exc_value, traceback):
61
+ self.close()
62
+
63
+
64
+ def _fire_main(path_sys_xml: str, path_qs_np: Optional[str] = None, **scene_kwargs):
65
+ """View motion given by trajectory of minimal coordinates in interactive viewer.
66
+
67
+ Args:
68
+ path_sys_xml (str): Path to xml file defining the system.
69
+ path_qs_np (str | None, optional): Path to numpy array containing the timeseries of minimal coordinates with
70
+ shape (T, DOF) where DOF is equal to `sys.q_size()`. Each minimal coordiante is from parent
71
+ to child. So for example a `spherical` joint that connects the first body to the worldbody
72
+ has a minimal coordinate of a quaternion that gives from worldbody to first body. The sampling
73
+ rate of the motion is inferred from the `sys.dt` attribute. If `None` (default), then simply renders the
74
+ unarticulated pose of the system.
75
+ """ # noqa: E501
76
+
77
+ sys = ring.System.from_xml(path_sys_xml)
78
+ if path_qs_np is None:
79
+ qs = np.array(ring.State.create(sys).q)[None]
80
+ else:
81
+ qs: np.ndarray = np.load(path_qs_np)
82
+
83
+ assert qs.ndim == 2, f"qs.shape = {qs.shape}"
84
+ T, Q = qs.shape
85
+ assert Q == sys.q_size()
86
+ dt_target = sys.dt
87
+
88
+ with InteractiveViewer(sys, **scene_kwargs) as viewer:
89
+ dt = dt_target
90
+ last_t = time.time()
91
+ t = -1
92
+
93
+ while True:
94
+ t = (t + 1) % T
95
+
96
+ while dt < dt_target:
97
+ time.sleep(0.001)
98
+ dt = time.time() - last_t
99
+
100
+ last_t = time.time()
101
+ viewer.update_q(qs[t])
102
+ dt = time.time() - last_t
103
+
104
+ # process will be stopped if the window is closed
105
+ if not viewer.process.is_alive():
106
+ break
107
+
108
+
109
+ def main():
110
+ fire.Fire(_fire_main)
111
+
112
+
113
+ if __name__ == "__main__":
114
+ main()
@@ -0,0 +1,93 @@
1
+ """This module exports a loss function `loss_fn` for training neural networks that
2
+ output quaternions in PyTorch"""
3
+
4
+ from typing import Sequence
5
+
6
+ import torch
7
+
8
+
9
+ def quat_mul(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
10
+ "Multiplies two quaternions."
11
+ q = torch.stack(
12
+ [
13
+ u[..., 0] * v[..., 0]
14
+ - u[..., 1] * v[..., 1]
15
+ - u[..., 2] * v[..., 2]
16
+ - u[..., 3] * v[..., 3],
17
+ u[..., 0] * v[..., 1]
18
+ + u[..., 1] * v[..., 0]
19
+ + u[..., 2] * v[..., 3]
20
+ - u[..., 3] * v[..., 2],
21
+ u[..., 0] * v[..., 2]
22
+ - u[..., 1] * v[..., 3]
23
+ + u[..., 2] * v[..., 0]
24
+ + u[..., 3] * v[..., 1],
25
+ u[..., 0] * v[..., 3]
26
+ + u[..., 1] * v[..., 2]
27
+ - u[..., 2] * v[..., 1]
28
+ + u[..., 3] * v[..., 0],
29
+ ],
30
+ dim=-1,
31
+ )
32
+ return q
33
+
34
+
35
+ def quat_inv(q: torch.Tensor):
36
+ return torch.concat([q[..., :1], -q[..., 1:]], dim=-1)
37
+
38
+
39
+ def wrap_to_pi(phi):
40
+ "Wraps angle `phi` (radians) to interval [-pi, pi]."
41
+ return (phi + torch.pi) % (2 * torch.pi) - torch.pi
42
+
43
+
44
+ def quat_angle(q: torch.Tensor):
45
+ phi = 2 * torch.arctan2(torch.norm(q[..., 1:], dim=-1), q[..., 0])
46
+ return wrap_to_pi(phi)
47
+
48
+
49
+ def safe_normalize(x):
50
+ return x / (1e-6 + torch.norm(x, dim=-1, keepdim=True))
51
+
52
+
53
+ def quat_qrel(q1, q2):
54
+ "q1^-1 * q2"
55
+ return quat_mul(quat_inv(q1), q2)
56
+
57
+
58
+ @torch.jit.script
59
+ def angle_error(q, qhat):
60
+ "Absolute angle error in radians"
61
+ return torch.abs(quat_angle(quat_qrel(q, qhat)))
62
+
63
+
64
+ @torch.jit.script
65
+ def inclination_error(q, qhat):
66
+ "Absolute inclination error in radians. `q`s are from body-to-eps"
67
+ q_rel = quat_mul(q, quat_inv(qhat))
68
+ phi_pri = 2 * torch.arctan2(q_rel[..., 3], q_rel[..., 0])
69
+ q_pri = torch.zeros_like(q)
70
+ q_pri[..., 0] = torch.cos(phi_pri / 2)
71
+ q_pri[..., 3] = torch.sin(phi_pri / 2)
72
+ q_res = quat_mul(q_rel, quat_inv(q_pri))
73
+ return torch.abs(quat_angle(q_res))
74
+
75
+
76
+ def loss_fn(lam: Sequence[int], q: torch.Tensor, qhat: torch.Tensor) -> torch.Tensor:
77
+ "(..., N, 4) -> (..., N)"
78
+ *batch_dims, N, F = q.shape
79
+ assert q.shape == qhat.shape
80
+ assert F == 4
81
+ assert N == len(lam)
82
+ permu = list(reversed(range(q.ndim - 1)))
83
+ loss_incl = inclination_error(q, qhat).permute(*permu)
84
+ loss_mae = angle_error(q, qhat).permute(*permu)
85
+ lam = torch.tensor(lam, device=q.device)
86
+ return torch.where(
87
+ lam.reshape(-1, *[1] * len(batch_dims)) == -1, loss_incl, loss_mae
88
+ ).permute(*permu)
89
+
90
+
91
+ def quat_rand(*size: tuple[int]):
92
+ qs = torch.randn(size=size + (4,))
93
+ return qs / torch.norm(qs, dim=-1, keepdim=True)
@@ -1,3 +1,4 @@
1
+ from functools import partial
1
2
  from typing import Optional
2
3
 
3
4
  import jax
@@ -93,15 +94,29 @@ def _load_scene(sys, backend, **scene_kwargs):
93
94
  return _scene
94
95
 
95
96
 
97
+ @jax.jit
98
+ def _jit_forward_kinematics(sys):
99
+ _, state = kinematics.forward_kinematics(sys, base.State.create(sys))
100
+ return state.x
101
+
102
+
103
+ @jax.jit
104
+ @partial(jax.vmap, in_axes=(None, 0))
105
+ def _jit_vmap_forward_kinematics(sys, q):
106
+ _, state = kinematics.forward_kinematics(sys, base.State.create(sys, q=q))
107
+ return state.x
108
+
109
+
96
110
  def render(
97
111
  sys: base.System,
112
+ qs: Optional[jax.Array | list[jax.Array]] = None,
98
113
  xs: Optional[base.Transform | list[base.Transform]] = None,
99
114
  camera: Optional[str] = None,
100
115
  show_pbar: bool = True,
101
116
  backend: str = "mujoco",
102
117
  render_every_nth: int = 1,
103
118
  **scene_kwargs,
104
- ) -> list[np.ndarray]:
119
+ ) -> list[np.ndarray | None]:
105
120
  """Render frames from system and trajectory of maximal coordinates `xs`.
106
121
 
107
122
  Args:
@@ -114,9 +129,18 @@ def render(
114
129
  Returns:
115
130
  list[np.ndarray]: Stacked rendered frames. Length == len(xs).
116
131
  """
132
+ assert not (qs is not None and xs is not None)
117
133
 
134
+ if xs is None and qs is None:
135
+ xs = _jit_forward_kinematics(sys)
118
136
  if xs is None:
119
- xs = kinematics.forward_kinematics(sys, base.State.create(sys))[1].x
137
+ # throw error if `xs` has been given by accident as `qs` argument
138
+ qs = utils.to_list(qs)
139
+ assert not isinstance(
140
+ qs[0], base.Transform
141
+ ), "`qs` should be `jax.Array` and not `Transform`; maybe you want to pass `xs` as keyword argument `xs=xs`?" # noqa: E501
142
+ qs = jnp.stack(qs, axis=0)
143
+ xs = _jit_vmap_forward_kinematics(sys, qs)
120
144
 
121
145
  # convert time-axis of batched xs object into a list of unbatched x objects
122
146
  if isinstance(xs, base.Transform) and xs.ndim() == 3:
@@ -144,6 +168,9 @@ def render(
144
168
 
145
169
  scene = _load_scene(sys, backend, **scene_kwargs)
146
170
 
171
+ if scene_kwargs.get("interactive", False):
172
+ show_pbar = False
173
+
147
174
  frames = []
148
175
  for x in tqdm.tqdm(xs, "Rendering frames..", disable=not show_pbar):
149
176
  scene.update(x)
@@ -241,7 +268,7 @@ def render_prediction(
241
268
  sys, xs, yhat, transparent_segment_to_root, offset_truth, offset_pred
242
269
  )
243
270
 
244
- frames = render(sys_render, xs_render, **kwargs)
271
+ frames = render(sys=sys_render, xs=xs_render, **kwargs)
245
272
  return frames
246
273
 
247
274
 
@@ -118,8 +118,10 @@ def _xml_str_one_body(
118
118
  body_number: int, geoms: list[base.Geometry], cameras: list[str], lights: list[str]
119
119
  ) -> str:
120
120
  inside_body_geoms = ""
121
- for geom in geoms:
122
- inside_body_geoms += _xml_str_one_geom(geom)
121
+ for geom_number, geom in enumerate(geoms):
122
+ inside_body_geoms += _xml_str_one_geom(
123
+ geom, name=f"body{body_number}_geom{geom_number}"
124
+ )
123
125
 
124
126
  inside_body_cameras = ""
125
127
  for camera in cameras:
@@ -138,7 +140,7 @@ def _xml_str_one_body(
138
140
  """
139
141
 
140
142
 
141
- def _xml_str_one_geom(geom: base.Geometry) -> str:
143
+ def _xml_str_one_geom(geom: base.Geometry, name: str) -> str:
142
144
  rgba = f'rgba="{_array_to_str(geom.color)}"'
143
145
 
144
146
  if isinstance(geom, base.Box):
@@ -158,7 +160,8 @@ def _xml_str_one_geom(geom: base.Geometry) -> str:
158
160
 
159
161
  rot, pos = maths.quat_inv(geom.transform.rot), geom.transform.pos
160
162
  rot, pos = f'pos="{_array_to_str(pos)}"', f'quat="{_array_to_str(rot)}"'
161
- return f"<geom {type_size} {rgba} {rot} {pos}/>"
163
+ name = f'name="{name}"'
164
+ return f"<geom {type_size} {rgba} {rot} {pos} {name}/>"
162
165
 
163
166
 
164
167
  def _array_to_str(arr: Sequence[float]) -> str:
@@ -181,6 +184,8 @@ class MujocoScene:
181
184
  floor_z: float = -0.84,
182
185
  floor_material: str = "matplane",
183
186
  debug: bool = False,
187
+ interactive: bool = False,
188
+ interactive_hide_menu: bool = False,
184
189
  ) -> None:
185
190
  self.debug = debug
186
191
  self.height, self.width = height, width
@@ -195,6 +200,8 @@ class MujocoScene:
195
200
  self.show_stars = show_stars
196
201
  self.show_floor = show_floor
197
202
  self.floor_kwargs = dict(z=floor_z, material=floor_material)
203
+ self.interactive = interactive
204
+ self.interactive_hide_menu = interactive_hide_menu
198
205
 
199
206
  def init(self, geoms: list[base.Geometry]):
200
207
  self._parent_ids = list(set([geom.link_idx for geom in geoms]))
@@ -208,7 +215,22 @@ class MujocoScene:
208
215
  debug=self.debug,
209
216
  )
210
217
  self._data = mujoco.MjData(self._model)
211
- self._renderer = mujoco.Renderer(self._model, self.height, self.width)
218
+ if self.interactive:
219
+ import mujoco_viewer
220
+
221
+ self._renderer = mujoco_viewer.MujocoViewer(
222
+ self._model,
223
+ self._data,
224
+ width=self.width,
225
+ height=self.height,
226
+ hide_menus=self.interactive_hide_menu,
227
+ )
228
+
229
+ if self.interactive_hide_menu:
230
+ print("Menu can be shown with key `H` for H(elp)")
231
+
232
+ else:
233
+ self._renderer = mujoco.Renderer(self._model, self.height, self.width)
212
234
 
213
235
  def update(self, x: base.Transform):
214
236
  rot, pos = maths.quat_inv(x.rot), x.pos
@@ -234,6 +256,15 @@ class MujocoScene:
234
256
 
235
257
  mujoco.mj_forward(self._model, self._data)
236
258
 
237
- def render(self, camera: Optional[str] = None):
238
- self._renderer.update_scene(self._data, camera=-1 if camera is None else camera)
259
+ def render(self, camera: Optional[str] = None) -> np.ndarray | None:
260
+ if not self.interactive:
261
+ self._renderer.update_scene(
262
+ self._data, camera=-1 if camera is None else camera
263
+ )
239
264
  return self._renderer.render()
265
+
266
+ def close(self):
267
+ self._renderer.close()
268
+
269
+ def __del__(self):
270
+ self.close()
ring/utils/__init__.py CHANGED
@@ -1,11 +1,7 @@
1
- from . import randomize_sys
2
1
  from .batchsize import batchsize_thresholds
3
2
  from .batchsize import distribute_batchsize
4
3
  from .batchsize import expand_batchsize
5
4
  from .batchsize import merge_batchsize
6
- from .colab import setup_colab_env
7
- from .normalizer import make_normalizer_from_generator
8
- from .normalizer import Normalizer
9
5
  from .path import parse_path
10
6
  from .utils import dict_to_nested
11
7
  from .utils import dict_union
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes