imt-ring 1.6.47__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: imt-ring
3
- Version: 1.6.47
3
+ Version: 1.7.1
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,27 +1,39 @@
1
- ring/__init__.py,sha256=H1Rd2uXVkux4Z792XyHIkQ8OpDSZBiPqFwyAFDWDU3E,5260
1
+ ring/__init__.py,sha256=y3LuDekHyOCYdzaEDJM5dodClfderAKH-0ufklrwtHY,5266
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
- ring/base.py,sha256=zromjIuMpNBoyiwHa9OCyZvAz7jHjXHZIdRt8fN8PoA,50481
3
+ ring/base.py,sha256=AkG_Gpk7i2j77MzxzjiolJ9WNGcSq_3aOcpu0l6-0e0,50543
4
4
  ring/maths.py,sha256=R22SNQutkf9v7Hp9klo0wvJVIyBQz0O8_5oJaDQcFis,12652
5
5
  ring/spatial.py,sha256=nmZ-UhRanhyM34bez8uCS4wMwaKqLkuEbgKGP5XNH60,2351
6
6
  ring/algorithms/__init__.py,sha256=IiK9EN5Xgs3dB075-A-H-Yad0Z7vzvKIJF2g6X_-C_8,1224
7
- ring/algorithms/_random.py,sha256=UMyv-VPZLcErrKqs0XB83QJjs8GrmoNsv-zRSxGXvnI,14490
7
+ ring/algorithms/_random.py,sha256=9mSP7M1On_CBz-dEbvE0iDzl_buf125EiYb1y10Ylxc,15805
8
8
  ring/algorithms/dynamics.py,sha256=NFOZawjrFoS5RgiWOpG1pQCC8l7RBOEZIi9ok6gvf9U,12268
9
- ring/algorithms/jcalc.py,sha256=l6BXOmXwrZ_AKKRm4gEHq_k2LSUQ4wd--1qL1qNTcKk,46794
9
+ ring/algorithms/jcalc.py,sha256=zUMKEndSF6zaU7qHjJEU-iEQ_YAFNTNPcOwzdyu33uk,47632
10
10
  ring/algorithms/kinematics.py,sha256=IXeTQ-afzeEzLVmQVQ1oTXJxz_lTwvaWlgHeJkhO_8o,7423
11
11
  ring/algorithms/sensors.py,sha256=v_TZMyWjffDpPwoyS1fy8X-1i9y1vDf6mk1EmGS2ztc,18251
12
12
  ring/algorithms/custom_joints/__init__.py,sha256=3pQ-Is_HBTQDkzESCNg9VfoP8wvseWmooryG8ERnu_A,366
13
- ring/algorithms/custom_joints/rr_imp_joint.py,sha256=_YJK0p8_0MHFtr1NuGnNZoxTbwaMQyUjYv7EtsPiU3A,2402
13
+ ring/algorithms/custom_joints/rr_imp_joint.py,sha256=1sOej-D3q4bSh5KJqcK5fZA3iOHXUlrlT9u-tIa-D6c,2351
14
14
  ring/algorithms/custom_joints/rr_joint.py,sha256=jnRtjtOCALMaq2_0bcu2d7qgfQ6etXpoh43MioRaDmY,1000
15
15
  ring/algorithms/custom_joints/rsaddle_joint.py,sha256=QoMo6NXdYgA9JygSzBvr0eCdd3qKhUgCrGPNO2Qdxko,1200
16
16
  ring/algorithms/custom_joints/suntay.py,sha256=TZG307NqdMiXnNY63xEx8AkAjbQBQ4eO6DQ7R4j4D08,16726
17
17
  ring/algorithms/generator/__init__.py,sha256=bF-CW3x2x-o6KWESKy-DuxzZPh3UNSjJb_MaAcSHGsQ,277
18
- ring/algorithms/generator/base.py,sha256=yPH_RIQPU_nlq58HyZ6T3RUm1S5chA3-Ro__-ArYTq0,22669
18
+ ring/algorithms/generator/base.py,sha256=sLIXfFliRUzUKaf84rBQjsExEfmU3XjENrYGD4fm1Q0,23808
19
19
  ring/algorithms/generator/batch.py,sha256=xp1X8oYtwI6l2cH4GRu9zw-P8dnh-X1FWTSyixEfgr8,2652
20
20
  ring/algorithms/generator/finalize_fns.py,sha256=ty1NaU-Mghx1RL-voivDjS0TWSKNtjTmbdmBnShhn7k,10398
21
21
  ring/algorithms/generator/motion_artifacts.py,sha256=2VJbldVDbI3PSyboshIbtYvSAKzBBwGV7cQfYjqvluM,9167
22
22
  ring/algorithms/generator/pd_control.py,sha256=dHnhJZx_FqrHD4xFXpQZH-R7rputFkAVGwoBGccZnz4,5767
23
23
  ring/algorithms/generator/setup_fns.py,sha256=MFz3czHBeWs1Zk1A8O02CyQpQ-NCyW9PMpbqmKit6es,1455
24
24
  ring/algorithms/generator/types.py,sha256=HjNyATFSLfHkXlzdJhvUkiqnhzpXFDDXmWS3LYBlOtU,721
25
+ ring/extras/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
+ ring/extras/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
27
+ ring/extras/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
28
+ ring/extras/dataloader.py,sha256=dfNPjnxDoKxWGKSImuJ_49CWgBn73vxSEek8COq9nNk,3749
29
+ ring/extras/dataloader_torch.py,sha256=t2DDiB9ZHb_SzFlVbntCGGIybj4F-NoA0PaB4_afjGw,3983
30
+ ring/extras/hdf5.py,sha256=XPIrwogD-d544yy08UJyfLVp1ZKRUtiZukW7RA8VUxQ,5856
31
+ ring/extras/interactive_viewer.py,sha256=-jpoSsDrdzshZw-_MSI96QLvm9foRWTUS_aa_CNFk74,3867
32
+ ring/extras/normalizer.py,sha256=o26stPP6EHasZQxQX0vKqTrhUNZBaJ2O17L6W_gBMN4,1699
33
+ ring/extras/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
34
+ ring/extras/torch_loss_fn.py,sha256=1LnWTmtxXPxoQFr4QixW12AjpRUfrseSDBmifhu6ErE,2676
35
+ ring/extras/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
36
+ ring/extras/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
25
37
  ring/io/__init__.py,sha256=1gEJdyDCbldbbm8QeZbLmhzSKmaQ-UqTmQgu4DBH2Z4,328
26
38
  ring/io/examples.py,sha256=KLf2iCagvRfjs9MCnQsLUlfGBjrQKrD-Qv8U0TtX6Ek,1114
27
39
  ring/io/test_examples.py,sha256=htpnSgLG9Fi9_qwSL4F1yLi9sN7ZUrF8dDmiqU3B510,117
@@ -63,8 +75,8 @@ ring/ml/training_loop.py,sha256=yxuUua_4RExq_0GUYm4eUZJsBmtrwDSVL94bWUpYfdo,3586
63
75
  ring/ml/params/0x13e3518065c21cd8.pickle,sha256=Zh2k1zK-TNxJl5F7nyTeQ9001qqRE_dfvaq1HWV287A,9355838
64
76
  ring/ml/params/0x1d76628065a71e0f.pickle,sha256=YTNVuvfw-nCRD9BH1PZYcR9uCFpNWDhw8Lc50eDn_EE,9351038
65
77
  ring/rendering/__init__.py,sha256=Zf7qOdzK3t2hljIrs5P4zFhzHljLSMRyDDZO2YlZk4k,75
66
- ring/rendering/base_render.py,sha256=Mv9SRLEmuoPVhi46UIjb6xCkKmbWCwIyENGx7nu9REM,9617
67
- ring/rendering/mujoco_render.py,sha256=HMvZc04I0-lXPBL3hcnBzV2bNiXQAQM7QcHlG_Obmj4,8757
78
+ ring/rendering/base_render.py,sha256=O8Oo9znAgWRE09R7B2yecpwNDJ5veIRoMci144oHwF8,10554
79
+ ring/rendering/mujoco_render.py,sha256=eCmnnzwVZ3BeIo1INswXMZaZ9TDaF1HO50f70spXX2E,9704
68
80
  ring/rendering/vispy_render.py,sha256=6Z6S5LNZ7iy9BN1GVb9EDe-Tix5N_SQ1s7ZsfiTSDEA,10261
69
81
  ring/rendering/vispy_visuals.py,sha256=ooBZqppnebeL0ANe6V6zUgnNTtDcdkOsa4vZuM4sx-I,7873
70
82
  ring/sim2real/__init__.py,sha256=gCLYg8IoMdzUagzhCFcfjZ5GavtIU772L7HR0G5hUtM,251
@@ -73,21 +85,12 @@ ring/sys_composer/__init__.py,sha256=5J_JJJIHfTPcpxh0v4FqiOs81V1REPUd7pgiw2nAN5E
73
85
  ring/sys_composer/delete_sys.py,sha256=cIM9KbyLfg7B9121g7yjzuFbjeNu9cil1dPavAYEgzk,3408
74
86
  ring/sys_composer/inject_sys.py,sha256=PLuxLbXU7hPtAsqvpsEim9hkoVE26ddrg3OipZNvnhU,3504
75
87
  ring/sys_composer/morph_sys.py,sha256=2GpPtS5hT0eZMptdGpt30Hc97OykJNE67lEVRf7sHrc,12700
76
- ring/utils/__init__.py,sha256=MHHavc8YfjBlmB-zAV42QEQS_ebW7cy0lhWXEVyQU7s,720
77
- ring/utils/backend.py,sha256=cKSi9sB59texqKzNVASTDczGKLCBL8VVDiP7TNdj41k,1294
88
+ ring/utils/__init__.py,sha256=Q37bjy2wjRGggd77MHlgl_50i2zOuVnPny4yOLiTe-8,567
78
89
  ring/utils/batchsize.py,sha256=uCj8LG7elbjEUUzuK29Z3I9T8bxJTcsybY3DdGeqhQs,1786
79
- ring/utils/colab.py,sha256=ZLHwP0jNQUsmZJU4l68a5djULPi6T-jYNNHevjIoMn8,1631
80
- ring/utils/dataloader.py,sha256=dfNPjnxDoKxWGKSImuJ_49CWgBn73vxSEek8COq9nNk,3749
81
- ring/utils/dataloader_torch.py,sha256=t2DDiB9ZHb_SzFlVbntCGGIybj4F-NoA0PaB4_afjGw,3983
82
- ring/utils/hdf5.py,sha256=XPIrwogD-d544yy08UJyfLVp1ZKRUtiZukW7RA8VUxQ,5856
83
- ring/utils/normalizer.py,sha256=o26stPP6EHasZQxQX0vKqTrhUNZBaJ2O17L6W_gBMN4,1699
84
90
  ring/utils/path.py,sha256=zRPfxYNesvgefkddd26oar6f9433LkMGkhp9dF3rPUs,1926
85
- ring/utils/randomize_sys.py,sha256=G_vBIo0OwQkXL2u0djwbaoaeb02C4LQCTNNloOYIU2M,3699
86
91
  ring/utils/utils.py,sha256=gKwOXLxWraeZfX6EbBcg3hkq30DcXN0mcRUeOSTNiMo,7336
87
- ring/utils/register_gym_envs/__init__.py,sha256=PtPIRBQJ16339xZ9G9VpvqrvcGbQ_Pk_SUz4tQPa9nQ,94
88
- ring/utils/register_gym_envs/saddle.py,sha256=tA5CyW_akSXyDm0xJ83CtOrUMVElH0f9vZtEDDJQalI,4422
89
- imt_ring-1.6.47.dist-info/METADATA,sha256=4acmyig9LGSCfOTMNNOBLrwOaLbCa2EtYXJEfvuMEpc,5888
90
- imt_ring-1.6.47.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
91
- imt_ring-1.6.47.dist-info/entry_points.txt,sha256=npNqSOvNiBR0BNa_GL3J66q8Gky3h0G_PHzHzk8oyE0,66
92
- imt_ring-1.6.47.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
93
- imt_ring-1.6.47.dist-info/RECORD,,
92
+ imt_ring-1.7.1.dist-info/METADATA,sha256=cFujMyd1Xpqa8NiYn3iMqnhPHzGvlb_mLBbIPplQA6M,5887
93
+ imt_ring-1.7.1.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
94
+ imt_ring-1.7.1.dist-info/entry_points.txt,sha256=npNqSOvNiBR0BNa_GL3J66q8Gky3h0G_PHzHzk8oyE0,66
95
+ imt_ring-1.7.1.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
96
+ imt_ring-1.7.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.2)
2
+ Generator: setuptools (79.0.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
ring/__init__.py CHANGED
@@ -35,12 +35,12 @@ def RING(lam: list[int] | None, Ts: float | None, **kwargs) -> ml.AbstractFilter
35
35
  >>> import ring
36
36
  >>> import numpy as np
37
37
  >>>
38
- >>> T : int = 30 # sequence length [s]
39
- >>> Ts : float = 0.01 # sampling interval [s]
40
- >>> B : int = 1 # batch size
41
- >>> lam: list[int] = [0, 1, 2] # parent array
42
- >>> N : int = len(lam) # number of bodies
43
- >>> T_i: int = int(T/Ts) # number of timesteps
38
+ >>> T : int = 30 # sequence length [s]
39
+ >>> Ts : float = 0.01 # sampling interval [s]
40
+ >>> B : int = 1 # batch size
41
+ >>> lam: list[int] = [-1, 0, 1] # parent array
42
+ >>> N : int = len(lam) # number of bodies
43
+ >>> T_i: int = int(T/Ts) # number of timesteps
44
44
  >>>
45
45
  >>> X = np.zeros((B, T_i, N, 9))
46
46
  >>> # where X is structured as follows:
@@ -41,30 +41,48 @@ def random_angle_over_time(
41
41
  cdf_bins_min: int = 5,
42
42
  cdf_bins_max: Optional[int] = None,
43
43
  interpolation_method: str = "cosine",
44
+ include_standstills_prob: float = 0.0, # 0.0 means no standstills
45
+ include_standstills_t_min: float = 0.5,
46
+ include_standstills_t_max: float = 5.0,
44
47
  ) -> jax.Array:
45
48
  def body_fn_outer(val):
46
49
  i, t, phi, key_t, key_ang, ANG = val
47
50
 
48
- key_t, consume_t = random.split(key_t)
51
+ key_t, consume_t, consume_standstill = random.split(key_t, 3)
49
52
  key_ang, consume_ang = random.split(key_ang)
50
53
  rom_halfsize_float = _to_float(rom_halfsize, t)
51
54
  rom_lower = ANG_0 - rom_halfsize_float
52
55
  rom_upper = ANG_0 + rom_halfsize_float
53
- dt, phi = _resolve_range_of_motion(
54
- range_of_motion,
55
- range_of_motion_method,
56
- rom_lower,
57
- rom_upper,
58
- _to_float(dang_min, t),
59
- _to_float(dang_max, t),
60
- _to_float(delta_ang_min, t),
61
- _to_float(delta_ang_max, t),
62
- t_min,
63
- _to_float(t_max, t),
64
- phi,
65
- consume_t,
66
- consume_ang,
67
- max_iter,
56
+
57
+ is_standstill = jax.random.bernoulli(
58
+ consume_standstill, include_standstills_prob
59
+ )
60
+ dt, phi = jax.lax.cond(
61
+ is_standstill,
62
+ lambda: (
63
+ jax.random.uniform(
64
+ consume_t,
65
+ minval=include_standstills_t_min,
66
+ maxval=include_standstills_t_max,
67
+ ),
68
+ phi,
69
+ ),
70
+ lambda: _resolve_range_of_motion(
71
+ range_of_motion,
72
+ range_of_motion_method,
73
+ rom_lower,
74
+ rom_upper,
75
+ _to_float(dang_min, t),
76
+ _to_float(dang_max, t),
77
+ _to_float(delta_ang_min, t),
78
+ _to_float(delta_ang_max, t),
79
+ t_min,
80
+ _to_float(t_max, t),
81
+ phi,
82
+ consume_t,
83
+ consume_ang,
84
+ max_iter,
85
+ ),
68
86
  )
69
87
  t += dt
70
88
 
@@ -119,7 +137,8 @@ def random_angle_over_time(
119
137
 
120
138
  # APPROVED
121
139
  def random_position_over_time(
122
- key: random.PRNGKey,
140
+ key_t: random.PRNGKey,
141
+ key_value: random.PRNGKey,
123
142
  POS_0: float,
124
143
  pos_min: float | TimeDependentFloat,
125
144
  pos_max: float | TimeDependentFloat,
@@ -135,19 +154,14 @@ def random_position_over_time(
135
154
  cdf_bins_min: int = 5,
136
155
  cdf_bins_max: Optional[int] = None,
137
156
  interpolation_method: str = "cosine",
157
+ include_standstills_prob: float = 0.0, # 0.0 means no standstills
158
+ include_standstills_t_min: float = 0.5,
159
+ include_standstills_t_max: float = 5.0,
138
160
  ) -> jax.Array:
139
161
  def body_fn_inner(val):
140
162
  i, t, t_pre, x, x_pre, key = val
141
163
  dt = t - t_pre
142
164
 
143
- def sample_dx_squared(key):
144
- key, consume = random.split(key)
145
- dx = (
146
- random.uniform(consume) * (2 * dpos_max * t_max**2)
147
- - dpos_max * t_max**2
148
- )
149
- return key, dx
150
-
151
165
  def sample_dx(key):
152
166
  key, consume1, consume2 = random.split(key, 3)
153
167
  sign = random.choice(consume1, jnp.array([-1.0, 1.0]))
@@ -182,24 +196,43 @@ def random_position_over_time(
182
196
  return jnp.logical_not(break_if_true1 | break_if_true2)
183
197
 
184
198
  def body_fn_outer(val):
185
- i, t, t_pre, x, x_pre, key, POS = val
186
- key, consume = random.split(key)
187
- t += random.uniform(consume, minval=t_min, maxval=_to_float(t_max, t_pre))
199
+ i, t, t_pre, x, x_pre, key_t, key_value, POS = val
200
+ key_t, consume_t, consume_standstill = random.split(key_t, 3)
188
201
 
189
- # that zero resets the max_it count
190
- val_inner = (0, t, t_pre, x, x_pre, key)
191
- _, t, t_pre, x, x_pre, key = jax.lax.while_loop(
192
- cond_fn_inner, body_fn_inner, val_inner
202
+ is_standstill = jax.random.bernoulli(
203
+ consume_standstill, include_standstills_prob
204
+ )
205
+
206
+ def is_standstill_branch():
207
+ dt = random.uniform(
208
+ consume_t,
209
+ minval=include_standstills_t_min,
210
+ maxval=include_standstills_t_max,
211
+ )
212
+ t = t_pre + dt
213
+ return 0, t, t_pre, x, x_pre, key_value
214
+
215
+ def no_standstill_branch():
216
+ dt = random.uniform(consume_t, minval=t_min, maxval=_to_float(t_max, t_pre))
217
+ t = t_pre + dt
218
+ # that zero resets the max_it count
219
+ val_inner = (0, t, t_pre, x, x_pre, key_value)
220
+ return jax.lax.while_loop(cond_fn_inner, body_fn_inner, val_inner)
221
+
222
+ _, t, t_pre, x, x_pre, key_value = jax.lax.cond(
223
+ is_standstill,
224
+ is_standstill_branch,
225
+ no_standstill_branch,
193
226
  )
194
227
 
195
228
  POS_i = jnp.array([[jnp.floor(t / Ts) * Ts, x]])
196
229
  POS = jax.lax.dynamic_update_slice_in_dim(POS, POS_i, start_index=i, axis=0)
197
230
  t_pre = t
198
231
  x_pre = x
199
- return i + 1, t, t_pre, x, x_pre, key, POS
232
+ return i + 1, t, t_pre, x, x_pre, key_t, key_value, POS
200
233
 
201
234
  def cond_fn_outer(val):
202
- i, t, t_pre, x, x_pre, key, POS = val
235
+ i, t, t_pre, x, x_pre, key_t, key_value, POS = val
203
236
  return t <= T
204
237
 
205
238
  # preallocate POS array
@@ -207,7 +240,7 @@ def random_position_over_time(
207
240
  POS = jnp.zeros((int(T // t_min) + 1, 2))
208
241
  POS = POS.at[0, 1].set(POS_0)
209
242
 
210
- val_outer = (1, 0.0, 0.0, POS_0, POS_0, key, POS)
243
+ val_outer = (1, 0.0, 0.0, POS_0, POS_0, key_t, key_value, POS)
211
244
  end, *_, consume, POS = jax.lax.while_loop(cond_fn_outer, body_fn_outer, val_outer)
212
245
  POS = jnp.where(
213
246
  (jnp.arange(len(POS)) < end)[:, None],
@@ -23,11 +23,10 @@ def register_rr_imp_joint(
23
23
  return ring.Transform.create(rot=rot)
24
24
 
25
25
  def _draw_rr_imp(config, key_t, key_value, dt, N, _):
26
- key_t1, key_t2 = jax.random.split(key_t)
27
26
  key_value1, key_value2 = jax.random.split(key_value)
28
- q_traj_pri = _draw_rxyz(config, key_t1, key_value1, dt, N, _)
27
+ q_traj_pri = _draw_rxyz(config, key_t, key_value1, dt, N, _)
29
28
  q_traj_res = _draw_rxyz(
30
- replace(config_res, T=config.T), key_t2, key_value2, dt, N, _
29
+ replace(config_res, T=config.T), key_t, key_value2, dt, N, _
31
30
  )
32
31
  # scale to be within bounds
33
32
  q_traj_res = q_traj_res * (jnp.deg2rad(ang_max_deg) / jnp.pi)
@@ -1,5 +1,6 @@
1
1
  from dataclasses import replace
2
2
  from functools import partial
3
+ import json
3
4
  import logging
4
5
  import random
5
6
  from typing import Callable, Optional
@@ -136,6 +137,14 @@ class RCMG:
136
137
  affecting joint motion behavior.
137
138
  """ # noqa: E501
138
139
 
140
+ # capture all funtion arguments before creating local variables
141
+ to_json_kwargs = locals()
142
+ # the purpose is to not capture the RCMG itself since we want to make it
143
+ # serialisable in the first place
144
+ to_json_kwargs.pop("self")
145
+ to_json_kwargs.pop("sys")
146
+ to_json_kwargs.pop("config")
147
+
139
148
  # add some default values
140
149
  randomize_hz_kwargs_defaults = dict(add_dt=True)
141
150
  randomize_hz_kwargs_defaults.update(randomize_hz_kwargs)
@@ -186,6 +195,11 @@ class RCMG:
186
195
 
187
196
  self._disable_tqdm = disable_tqdm
188
197
 
198
+ # store arguments that fully define the RCMG objects for use in `.to_json`
199
+ self._to_json_sys = sys
200
+ self._to_json_mconfig = config
201
+ self._to_json_kwargs = to_json_kwargs
202
+
189
203
  def _compute_repeats(self, sizes: int | list[int]) -> list[int]:
190
204
  "how many times the generators are repeated to create a batch of `sizes`"
191
205
 
@@ -355,6 +369,21 @@ class RCMG:
355
369
 
356
370
  return generator
357
371
 
372
+ def serialise_to_dict(self) -> dict:
373
+ dict_representation = {
374
+ "system": [_sys.to_str(warn=False) for _sys in self._to_json_sys],
375
+ "motion_configs": [_config.__dict__ for _config in self._to_json_mconfig],
376
+ "kwargs": self._to_json_kwargs,
377
+ }
378
+ return dict_representation
379
+
380
+ def serialise_to_json(self, path_of_json: str) -> None:
381
+ with open(path_of_json, "w") as file:
382
+ json.dump(self.serialise_to_dict(), file, indent=4)
383
+
384
+ def from_json(self, path_to_json: str) -> "RCMG":
385
+ raise NotImplementedError
386
+
358
387
 
359
388
  def _copy_dicts(f) -> dict:
360
389
  def _f(*args, **kwargs):
@@ -526,7 +555,7 @@ def draw_random_q(
526
555
  sys: base.System,
527
556
  config: jcalc.MotionConfig,
528
557
  N: int | None,
529
- ) -> tuple[types.Xy, types.OutputExtras]:
558
+ ) -> tuple[jax.random.PRNGKey, jax.Array]:
530
559
 
531
560
  key_start = key
532
561
  # build generalized coordintes vector `q`
ring/algorithms/jcalc.py CHANGED
@@ -174,6 +174,16 @@ class MotionConfig:
174
174
  default_factory=lambda: dict()
175
175
  )
176
176
 
177
+ # fields related to simulating standstills (no motion time periods)
178
+ # these are "Joint Standstills" so the standstills are calculated on
179
+ # a joint level, for each joint independently
180
+ # This means that a `standstills_prob` of 20% means that each joint
181
+ # has at each dt \in [t_min, t_max] drawing process a probability of
182
+ # 20% that it will just stay at its current joint value
183
+ include_standstills_prob: float = 0.0 # in %; 0% means no standstills
184
+ include_standstills_t_min: float = 0.5
185
+ include_standstills_t_max: float = 5.0
186
+
177
187
  def is_feasible(self) -> bool:
178
188
  return _is_feasible_config1(self)
179
189
 
@@ -791,12 +801,15 @@ def _draw_rxyz(
791
801
  config.cdf_bins_min,
792
802
  config.cdf_bins_max,
793
803
  config.interpolation_method,
804
+ config.include_standstills_prob,
805
+ config.include_standstills_t_min,
806
+ config.include_standstills_t_max,
794
807
  )
795
808
 
796
809
 
797
810
  def _draw_pxyz(
798
811
  config: MotionConfig,
799
- _: jax.random.PRNGKey,
812
+ key_t: jax.random.PRNGKey,
800
813
  key_value: jax.random.PRNGKey,
801
814
  dt: float | jax.Array,
802
815
  N: int | None,
@@ -811,6 +824,7 @@ def _draw_pxyz(
811
824
  )
812
825
  max_iter = 100
813
826
  return _random.random_position_over_time(
827
+ key_t,
814
828
  key_value,
815
829
  POS_0,
816
830
  config.cor_pos_min if cor else config.pos_min,
@@ -827,6 +841,9 @@ def _draw_pxyz(
827
841
  config.cdf_bins_min,
828
842
  config.cdf_bins_max,
829
843
  config.interpolation_method,
844
+ config.include_standstills_prob,
845
+ config.include_standstills_t_min,
846
+ config.include_standstills_t_max,
830
847
  )
831
848
 
832
849
 
@@ -840,7 +857,6 @@ def _draw_spherical(
840
857
  ) -> jax.Array:
841
858
  # NOTE: We draw 3 euler angles and then build a quaternion.
842
859
  # Not ideal, but i am unaware of a better way.
843
- @jax.vmap
844
860
  def draw_euler_angles(key_t, key_value):
845
861
  return _draw_rxyz(
846
862
  config,
@@ -853,8 +869,9 @@ def _draw_spherical(
853
869
  free_spherical=True,
854
870
  )
855
871
 
856
- triple = lambda key: jax.random.split(key, 3)
857
- euler_angles = draw_euler_angles(triple(key_t), triple(key_value)).T
872
+ euler_angles = jax.vmap(draw_euler_angles, in_axes=(None, 0))(
873
+ key_t, jax.random.split(key_value, 3)
874
+ ).T
858
875
  q = maths.quat_euler(euler_angles)
859
876
  return q
860
877
 
@@ -867,7 +884,6 @@ def _draw_saddle(
867
884
  N: int | None,
868
885
  _: jax.Array,
869
886
  ) -> jax.Array:
870
- @jax.vmap
871
887
  def draw_euler_angles(key_t, key_value):
872
888
  return _draw_rxyz(
873
889
  config,
@@ -880,14 +896,15 @@ def _draw_saddle(
880
896
  free_spherical=False,
881
897
  )
882
898
 
883
- double = lambda key: jax.random.split(key)
884
- yz_euler_angles = draw_euler_angles(double(key_t), double(key_value)).T
899
+ yz_euler_angles = jax.vmap(draw_euler_angles, in_axes=(None, 0))(
900
+ key_t, jax.random.split(key_value)
901
+ ).T
885
902
  return yz_euler_angles
886
903
 
887
904
 
888
905
  def _draw_p3d_and_cor(
889
906
  config: MotionConfig,
890
- _: jax.random.PRNGKey,
907
+ key_t: jax.random.PRNGKey,
891
908
  key_value: jax.random.PRNGKey,
892
909
  dt: float | jax.Array,
893
910
  N: int | None,
@@ -896,7 +913,7 @@ def _draw_p3d_and_cor(
896
913
  ) -> jax.Array:
897
914
  keys = jax.random.split(key_value, 3)
898
915
 
899
- def draw(key, xyz: str):
916
+ def draw(key_value, xyz: str):
900
917
  return _draw_pxyz(
901
918
  replace(
902
919
  config,
@@ -905,8 +922,8 @@ def _draw_p3d_and_cor(
905
922
  pos0_min=getattr(config, f"pos0_min_p3d_{xyz}"),
906
923
  pos0_max=getattr(config, f"pos0_max_p3d_{xyz}"),
907
924
  ),
908
- None,
909
- key,
925
+ key_t,
926
+ key_value,
910
927
  dt,
911
928
  N,
912
929
  None,
@@ -919,26 +936,26 @@ def _draw_p3d_and_cor(
919
936
 
920
937
  def _draw_p3d(
921
938
  config: MotionConfig,
922
- _: jax.random.PRNGKey,
939
+ key_t: jax.random.PRNGKey,
923
940
  key_value: jax.random.PRNGKey,
924
941
  dt: float | jax.Array,
925
942
  N: int | None,
926
943
  __: jax.Array,
927
944
  ) -> jax.Array:
928
- return _draw_p3d_and_cor(config, _, key_value, dt, N, None, cor=False)
945
+ return _draw_p3d_and_cor(config, key_t, key_value, dt, N, None, cor=False)
929
946
 
930
947
 
931
948
  def _draw_cor(
932
949
  config: MotionConfig,
933
- _: jax.random.PRNGKey,
950
+ key_t: jax.random.PRNGKey,
934
951
  key_value: jax.random.PRNGKey,
935
952
  dt: float | jax.Array,
936
953
  N: int | None,
937
954
  __: jax.Array,
938
955
  ) -> jax.Array:
939
956
  key_value1, key_value2 = jax.random.split(key_value)
940
- q_free = _draw_free(config, _, key_value1, dt, N, None)
941
- q_p3d = _draw_p3d_and_cor(config, _, key_value2, dt, N, None, cor=True)
957
+ q_free = _draw_free(config, key_t, key_value1, dt, N, None)
958
+ q_p3d = _draw_p3d_and_cor(config, key_t, key_value2, dt, N, None, cor=True)
942
959
  return jnp.concatenate((q_free, q_p3d), axis=1)
943
960
 
944
961
 
@@ -952,7 +969,7 @@ def _draw_free(
952
969
  ) -> jax.Array:
953
970
  key_value1, key_value2 = jax.random.split(key_value)
954
971
  q = _draw_spherical(config, key_t, key_value1, dt, N, None)
955
- pos = _draw_p3d(config, None, key_value2, dt, N, None)
972
+ pos = _draw_p3d(config, key_t, key_value2, dt, N, None)
956
973
  return jnp.concatenate((q, pos), axis=1)
957
974
 
958
975
 
@@ -975,7 +992,7 @@ def _draw_free_2d(
975
992
  enable_range_of_motion=False,
976
993
  free_spherical=True,
977
994
  )[:, None]
978
- pos_yz = _draw_p3d(config, None, key_value2, dt, N, None)[:, :2]
995
+ pos_yz = _draw_p3d(config, key_t, key_value2, dt, N, None)[:, :2]
979
996
  return jnp.concatenate((angle_x, pos_yz), axis=1)
980
997
 
981
998
 
ring/base.py CHANGED
@@ -981,6 +981,7 @@ class System(_Base):
981
981
 
982
982
  def render(
983
983
  self,
984
+ qs: Optional[jax.Array | list[jax.Array]] = None,
984
985
  xs: Optional[Transform | list[Transform]] = None,
985
986
  camera: Optional[str] = None,
986
987
  show_pbar: bool = True,
@@ -1001,7 +1002,7 @@ class System(_Base):
1001
1002
  list[np.ndarray]: Stacked rendered frames. Length == len(xs).
1002
1003
  """
1003
1004
  return ring.rendering.render(
1004
- self, xs, camera, show_pbar, backend, render_every_nth, **scene_kwargs
1005
+ self, qs, xs, camera, show_pbar, backend, render_every_nth, **scene_kwargs
1005
1006
  )
1006
1007
 
1007
1008
  def render_prediction(
File without changes
@@ -0,0 +1,114 @@
1
+ import multiprocessing
2
+ import time
3
+ from typing import Optional
4
+
5
+ import fire
6
+ import jax.numpy as jnp
7
+ import numpy as np
8
+
9
+ import ring
10
+ from ring import System
11
+
12
+
13
+ class InteractiveViewer:
14
+ def __init__(self, sys: ring.System, **scene_kwargs):
15
+ self._mp_dict = multiprocessing.Manager().dict()
16
+ self._geom_dict = multiprocessing.Manager().dict()
17
+ self.update_q(np.array(ring.State.create(sys).q))
18
+ self.process = multiprocessing.Process(
19
+ target=self._worker,
20
+ args=(self._mp_dict, self._geom_dict, sys.to_str(), scene_kwargs),
21
+ )
22
+ self.process.start()
23
+
24
+ def update_q(self, q: np.ndarray):
25
+ self._mp_dict["q"] = q
26
+
27
+ def make_geometry_transparent(self, body_number: int, geom_number: int):
28
+ geom_name = f"body{body_number}_geom{geom_number}"
29
+ # the value is not used
30
+ self._geom_dict[geom_name] = None
31
+
32
+ def _worker(self, mp_dict, geom_dict, sys_str, scene_kwargs):
33
+ from ring.rendering import base_render
34
+
35
+ sys = System.from_str(sys_str)
36
+ while base_render._scene is None or base_render._scene._renderer.is_alive:
37
+ sys.render(jnp.array(mp_dict["q"]), interactive=True, **scene_kwargs)
38
+
39
+ if len(geom_dict) > 0:
40
+ model = base_render._scene._model
41
+ processed = []
42
+ for geom_name in list(geom_dict.keys()):
43
+ # Get the geometry ID
44
+ geom_id = model.geom(geom_name).id
45
+ # Set transparency to 0 (fully transparent)
46
+ model.geom_rgba[geom_id, 3] = 0
47
+ print(f"Made geom with name={geom_name} transparent (worker)")
48
+ processed.append(geom_name)
49
+
50
+ for geom_name in processed:
51
+ geom_dict.pop(geom_name)
52
+
53
+ def __enter__(self):
54
+ return self
55
+
56
+ def close(self):
57
+ self.process.terminate()
58
+ self.process.join()
59
+
60
+ def __exit__(self, exc_type, exc_value, traceback):
61
+ self.close()
62
+
63
+
64
+ def _fire_main(path_sys_xml: str, path_qs_np: Optional[str] = None, **scene_kwargs):
65
+ """View motion given by trajectory of minimal coordinates in interactive viewer.
66
+
67
+ Args:
68
+ path_sys_xml (str): Path to xml file defining the system.
69
+ path_qs_np (str | None, optional): Path to numpy array containing the timeseries of minimal coordinates with
70
+ shape (T, DOF) where DOF is equal to `sys.q_size()`. Each minimal coordiante is from parent
71
+ to child. So for example a `spherical` joint that connects the first body to the worldbody
72
+ has a minimal coordinate of a quaternion that gives from worldbody to first body. The sampling
73
+ rate of the motion is inferred from the `sys.dt` attribute. If `None` (default), then simply renders the
74
+ unarticulated pose of the system.
75
+ """ # noqa: E501
76
+
77
+ sys = ring.System.from_xml(path_sys_xml)
78
+ if path_qs_np is None:
79
+ qs = np.array(ring.State.create(sys).q)[None]
80
+ else:
81
+ qs: np.ndarray = np.load(path_qs_np)
82
+
83
+ assert qs.ndim == 2, f"qs.shape = {qs.shape}"
84
+ T, Q = qs.shape
85
+ assert Q == sys.q_size(), f"Q={Q} != sys.q_size={sys.q_size()}"
86
+ dt_target = sys.dt
87
+
88
+ with InteractiveViewer(sys, width=640, height=480, **scene_kwargs) as viewer:
89
+ dt = dt_target
90
+ last_t = time.time()
91
+ t = -1
92
+
93
+ while True:
94
+ t = (t + 1) % T
95
+
96
+ while dt < dt_target:
97
+ time.sleep(0.001)
98
+ dt = time.time() - last_t
99
+
100
+ last_t = time.time()
101
+ viewer.update_q(qs[t])
102
+ dt = time.time() - last_t
103
+
104
+ # process will be stopped if the window is closed
105
+ if not viewer.process.is_alive():
106
+ break
107
+
108
+
109
+ def main():
110
+ fire.Fire(_fire_main)
111
+
112
+
113
+ if __name__ == "__main__":
114
+ main()
@@ -0,0 +1,93 @@
1
+ """This module exports a loss function `loss_fn` for training neural networks that
2
+ output quaternions in PyTorch"""
3
+
4
+ from typing import Sequence
5
+
6
+ import torch
7
+
8
+
9
+ def quat_mul(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
10
+ "Multiplies two quaternions."
11
+ q = torch.stack(
12
+ [
13
+ u[..., 0] * v[..., 0]
14
+ - u[..., 1] * v[..., 1]
15
+ - u[..., 2] * v[..., 2]
16
+ - u[..., 3] * v[..., 3],
17
+ u[..., 0] * v[..., 1]
18
+ + u[..., 1] * v[..., 0]
19
+ + u[..., 2] * v[..., 3]
20
+ - u[..., 3] * v[..., 2],
21
+ u[..., 0] * v[..., 2]
22
+ - u[..., 1] * v[..., 3]
23
+ + u[..., 2] * v[..., 0]
24
+ + u[..., 3] * v[..., 1],
25
+ u[..., 0] * v[..., 3]
26
+ + u[..., 1] * v[..., 2]
27
+ - u[..., 2] * v[..., 1]
28
+ + u[..., 3] * v[..., 0],
29
+ ],
30
+ dim=-1,
31
+ )
32
+ return q
33
+
34
+
35
+ def quat_inv(q: torch.Tensor):
36
+ return torch.concat([q[..., :1], -q[..., 1:]], dim=-1)
37
+
38
+
39
+ def wrap_to_pi(phi):
40
+ "Wraps angle `phi` (radians) to interval [-pi, pi]."
41
+ return (phi + torch.pi) % (2 * torch.pi) - torch.pi
42
+
43
+
44
+ def quat_angle(q: torch.Tensor):
45
+ phi = 2 * torch.arctan2(torch.norm(q[..., 1:], dim=-1), q[..., 0])
46
+ return wrap_to_pi(phi)
47
+
48
+
49
+ def safe_normalize(x):
50
+ return x / (1e-6 + torch.norm(x, dim=-1, keepdim=True))
51
+
52
+
53
+ def quat_qrel(q1, q2):
54
+ "q1^-1 * q2"
55
+ return quat_mul(quat_inv(q1), q2)
56
+
57
+
58
+ @torch.jit.script
59
+ def angle_error(q, qhat):
60
+ "Absolute angle error in radians"
61
+ return torch.abs(quat_angle(quat_qrel(q, qhat)))
62
+
63
+
64
+ @torch.jit.script
65
+ def inclination_error(q, qhat):
66
+ "Absolute inclination error in radians. `q`s are from body-to-eps"
67
+ q_rel = quat_mul(q, quat_inv(qhat))
68
+ phi_pri = 2 * torch.arctan2(q_rel[..., 3], q_rel[..., 0])
69
+ q_pri = torch.zeros_like(q)
70
+ q_pri[..., 0] = torch.cos(phi_pri / 2)
71
+ q_pri[..., 3] = torch.sin(phi_pri / 2)
72
+ q_res = quat_mul(q_rel, quat_inv(q_pri))
73
+ return torch.abs(quat_angle(q_res))
74
+
75
+
76
+ def loss_fn(lam: Sequence[int], q: torch.Tensor, qhat: torch.Tensor) -> torch.Tensor:
77
+ "(..., N, 4) -> (..., N)"
78
+ *batch_dims, N, F = q.shape
79
+ assert q.shape == qhat.shape
80
+ assert F == 4
81
+ assert N == len(lam)
82
+ permu = list(reversed(range(q.ndim - 1)))
83
+ loss_incl = inclination_error(q, qhat).permute(*permu)
84
+ loss_mae = angle_error(q, qhat).permute(*permu)
85
+ lam = torch.tensor(lam, device=q.device)
86
+ return torch.where(
87
+ lam.reshape(-1, *[1] * len(batch_dims)) == -1, loss_incl, loss_mae
88
+ ).permute(*permu)
89
+
90
+
91
+ def quat_rand(*size: tuple[int]):
92
+ qs = torch.randn(size=size + (4,))
93
+ return qs / torch.norm(qs, dim=-1, keepdim=True)
@@ -1,3 +1,4 @@
1
+ from functools import partial
1
2
  from typing import Optional
2
3
 
3
4
  import jax
@@ -93,15 +94,29 @@ def _load_scene(sys, backend, **scene_kwargs):
93
94
  return _scene
94
95
 
95
96
 
97
+ @jax.jit
98
+ def _jit_forward_kinematics(sys):
99
+ _, state = kinematics.forward_kinematics(sys, base.State.create(sys))
100
+ return state.x
101
+
102
+
103
+ @jax.jit
104
+ @partial(jax.vmap, in_axes=(None, 0))
105
+ def _jit_vmap_forward_kinematics(sys, q):
106
+ _, state = kinematics.forward_kinematics(sys, base.State.create(sys, q=q))
107
+ return state.x
108
+
109
+
96
110
  def render(
97
111
  sys: base.System,
112
+ qs: Optional[jax.Array | list[jax.Array]] = None,
98
113
  xs: Optional[base.Transform | list[base.Transform]] = None,
99
114
  camera: Optional[str] = None,
100
115
  show_pbar: bool = True,
101
116
  backend: str = "mujoco",
102
117
  render_every_nth: int = 1,
103
118
  **scene_kwargs,
104
- ) -> list[np.ndarray]:
119
+ ) -> list[np.ndarray | None]:
105
120
  """Render frames from system and trajectory of maximal coordinates `xs`.
106
121
 
107
122
  Args:
@@ -114,9 +129,18 @@ def render(
114
129
  Returns:
115
130
  list[np.ndarray]: Stacked rendered frames. Length == len(xs).
116
131
  """
132
+ assert not (qs is not None and xs is not None)
117
133
 
134
+ if xs is None and qs is None:
135
+ xs = _jit_forward_kinematics(sys)
118
136
  if xs is None:
119
- xs = kinematics.forward_kinematics(sys, base.State.create(sys))[1].x
137
+ # throw error if `xs` has been given by accident as `qs` argument
138
+ qs = utils.to_list(qs)
139
+ assert not isinstance(
140
+ qs[0], base.Transform
141
+ ), "`qs` should be `jax.Array` and not `Transform`; maybe you want to pass `xs` as keyword argument `xs=xs`?" # noqa: E501
142
+ qs = jnp.stack(qs, axis=0)
143
+ xs = _jit_vmap_forward_kinematics(sys, qs)
120
144
 
121
145
  # convert time-axis of batched xs object into a list of unbatched x objects
122
146
  if isinstance(xs, base.Transform) and xs.ndim() == 3:
@@ -144,6 +168,9 @@ def render(
144
168
 
145
169
  scene = _load_scene(sys, backend, **scene_kwargs)
146
170
 
171
+ if scene_kwargs.get("interactive", False):
172
+ show_pbar = False
173
+
147
174
  frames = []
148
175
  for x in tqdm.tqdm(xs, "Rendering frames..", disable=not show_pbar):
149
176
  scene.update(x)
@@ -241,7 +268,7 @@ def render_prediction(
241
268
  sys, xs, yhat, transparent_segment_to_root, offset_truth, offset_pred
242
269
  )
243
270
 
244
- frames = render(sys_render, xs_render, **kwargs)
271
+ frames = render(sys=sys_render, xs=xs_render, **kwargs)
245
272
  return frames
246
273
 
247
274
 
@@ -118,8 +118,10 @@ def _xml_str_one_body(
118
118
  body_number: int, geoms: list[base.Geometry], cameras: list[str], lights: list[str]
119
119
  ) -> str:
120
120
  inside_body_geoms = ""
121
- for geom in geoms:
122
- inside_body_geoms += _xml_str_one_geom(geom)
121
+ for geom_number, geom in enumerate(geoms):
122
+ inside_body_geoms += _xml_str_one_geom(
123
+ geom, name=f"body{body_number}_geom{geom_number}"
124
+ )
123
125
 
124
126
  inside_body_cameras = ""
125
127
  for camera in cameras:
@@ -138,7 +140,7 @@ def _xml_str_one_body(
138
140
  """
139
141
 
140
142
 
141
- def _xml_str_one_geom(geom: base.Geometry) -> str:
143
+ def _xml_str_one_geom(geom: base.Geometry, name: str) -> str:
142
144
  rgba = f'rgba="{_array_to_str(geom.color)}"'
143
145
 
144
146
  if isinstance(geom, base.Box):
@@ -158,7 +160,8 @@ def _xml_str_one_geom(geom: base.Geometry) -> str:
158
160
 
159
161
  rot, pos = maths.quat_inv(geom.transform.rot), geom.transform.pos
160
162
  rot, pos = f'pos="{_array_to_str(pos)}"', f'quat="{_array_to_str(rot)}"'
161
- return f"<geom {type_size} {rgba} {rot} {pos}/>"
163
+ name = f'name="{name}"'
164
+ return f"<geom {type_size} {rgba} {rot} {pos} {name}/>"
162
165
 
163
166
 
164
167
  def _array_to_str(arr: Sequence[float]) -> str:
@@ -181,6 +184,8 @@ class MujocoScene:
181
184
  floor_z: float = -0.84,
182
185
  floor_material: str = "matplane",
183
186
  debug: bool = False,
187
+ interactive: bool = False,
188
+ interactive_hide_menu: bool = False,
184
189
  ) -> None:
185
190
  self.debug = debug
186
191
  self.height, self.width = height, width
@@ -195,6 +200,8 @@ class MujocoScene:
195
200
  self.show_stars = show_stars
196
201
  self.show_floor = show_floor
197
202
  self.floor_kwargs = dict(z=floor_z, material=floor_material)
203
+ self.interactive = interactive
204
+ self.interactive_hide_menu = interactive_hide_menu
198
205
 
199
206
  def init(self, geoms: list[base.Geometry]):
200
207
  self._parent_ids = list(set([geom.link_idx for geom in geoms]))
@@ -208,7 +215,22 @@ class MujocoScene:
208
215
  debug=self.debug,
209
216
  )
210
217
  self._data = mujoco.MjData(self._model)
211
- self._renderer = mujoco.Renderer(self._model, self.height, self.width)
218
+ if self.interactive:
219
+ import mujoco_viewer
220
+
221
+ self._renderer = mujoco_viewer.MujocoViewer(
222
+ self._model,
223
+ self._data,
224
+ width=self.width,
225
+ height=self.height,
226
+ hide_menus=self.interactive_hide_menu,
227
+ )
228
+
229
+ if self.interactive_hide_menu:
230
+ print("Menu can be shown with key `H` for H(elp)")
231
+
232
+ else:
233
+ self._renderer = mujoco.Renderer(self._model, self.height, self.width)
212
234
 
213
235
  def update(self, x: base.Transform):
214
236
  rot, pos = maths.quat_inv(x.rot), x.pos
@@ -234,6 +256,15 @@ class MujocoScene:
234
256
 
235
257
  mujoco.mj_forward(self._model, self._data)
236
258
 
237
- def render(self, camera: Optional[str] = None):
238
- self._renderer.update_scene(self._data, camera=-1 if camera is None else camera)
259
+ def render(self, camera: Optional[str] = None) -> np.ndarray | None:
260
+ if not self.interactive:
261
+ self._renderer.update_scene(
262
+ self._data, camera=-1 if camera is None else camera
263
+ )
239
264
  return self._renderer.render()
265
+
266
+ def close(self):
267
+ self._renderer.close()
268
+
269
+ def __del__(self):
270
+ self.close()
ring/utils/__init__.py CHANGED
@@ -1,11 +1,7 @@
1
- from . import randomize_sys
2
1
  from .batchsize import batchsize_thresholds
3
2
  from .batchsize import distribute_batchsize
4
3
  from .batchsize import expand_batchsize
5
4
  from .batchsize import merge_batchsize
6
- from .colab import setup_colab_env
7
- from .normalizer import make_normalizer_from_generator
8
- from .normalizer import Normalizer
9
5
  from .path import parse_path
10
6
  from .utils import dict_to_nested
11
7
  from .utils import dict_union
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes