imt-ring 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
@@ -0,0 +1,582 @@
1
+ from collections import defaultdict
2
+ from typing import Optional
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from ring import algebra
7
+ from ring import algorithms
8
+ from ring import base
9
+ from ring import io
10
+ from ring import maths
11
+ from ring import sim2real
12
+
13
+
14
def accelerometer(
    xs: base.Transform, gravity: jax.Array, dt: float, n: int
) -> jax.Array:
    """Simulate accelerometer readings for a frame moving along `xs`.

    `xs` are the transforms from base to link, stacked along time.

    Args:
        xs: Trajectory of transforms; `xs.pos` and `xs.rot` are used.
        gravity: Gravity vector that is added to the differentiated positions.
        dt: Sampling interval of the trajectory.
        n: Stride (in samples) of the central finite difference.

    Returns:
        Acceleration (including `gravity`) passed through
        `maths.rotate(..., xs.rot)`, with as many timesteps as `xs.pos`.
    """
    pos = xs.pos
    stride_dt = n * dt

    # central 2nd-order finite difference with stride n: (N,) -> (N-2n,)
    acc = (pos[: -2 * n] + pos[2 * n :] - 2 * pos[n:-n]) / stride_dt**2
    acc = acc + gravity

    # repeat the first / last n rows so the time dimension is N again
    head = jnp.atleast_2d(acc[:n])
    tail = jnp.atleast_2d(acc[-n:])
    acc = jnp.vstack((head, acc, tail))

    return maths.rotate(acc, xs.rot)
29
+
30
+
31
def gyroscope(rot: jax.Array, dt: float, second_order: bool) -> jax.Array:
    """Simulate gyroscope readings for a frame whose orientation over time is
    given by the quaternion trajectory `rot`.

    Args:
        rot: Quaternions stacked along the first (time) axis.
        dt: Sampling interval.
        second_order: If True, difference samples that are two steps apart
            (central difference); otherwise use consecutive samples.

    Returns:
        Angular rate array with the same number of timesteps as `rot`.
    """
    # NOTE: `rot` is used as-is, without inversion. According to the original
    # author's note the quaternions used to represent local-to-eps, but now
    # represent eps-to-local, which is why no `maths.quat_inv(rot)` is applied.
    q = rot

    # number of samples spanned by one finite difference
    step = 2 if second_order else 1
    dq = maths.quat_mul(q[step:], maths.quat_inv(q[:-step]))

    if second_order:
        # central difference, (N,) -> (N-2,): repeat first and last element
        pieces = (dq[0][None], dq, dq[-1][None])
    else:
        # 1st-order difference, (N,) -> (N-1,): repeat the last element
        pieces = (dq, dq[-1][None])
    dq = jnp.vstack(pieces)

    axis, angle = maths.quat_to_rot_axis(dq)
    angle = angle[:, None]

    gyr = axis * angle / (step * dt)
    # for (numerically) zero rotation report zero rate
    return jnp.where(jnp.abs(angle) > 1e-10, gyr, jnp.zeros(3))
60
+
61
+
62
+ def _draw_random_magvec(key):
63
+ "Unit is in a.u. (40 microTesla)"
64
+ c1, c2 = jax.random.split(key)
65
+
66
+ dip_angle_min, dip_angle_max = -85.0, -50.0 # degrees
67
+ dip_angle = jnp.deg2rad(
68
+ jax.random.uniform(c1, minval=dip_angle_min, maxval=dip_angle_max)
69
+ )
70
+
71
+ norm_minval, norm_maxval = 15e-6, 65e-6 # Tesla; from lecture script page 10
72
+ # convert Tesla -> a.u. where (1 a.u. ~ 40 microTesla)
73
+ au = 40e-6
74
+ norm_minval, norm_maxval = norm_minval / au, norm_maxval / au
75
+ norm = jax.random.uniform(c2, minval=norm_minval, maxval=norm_maxval)
76
+
77
+ return jnp.array([0.0, jnp.cos(dip_angle), jnp.sin(dip_angle)]) * norm
78
+
79
+
80
def magnetometer(rot: jax.Array, magvec: jax.Array) -> jax.Array:
    """Simulate magnetometer readings by passing the magnetic field vector
    `magvec` through `maths.rotate` with the orientation quaternions `rot`."""
    measured = maths.rotate(magvec, rot)
    return measured
82
+
83
+
84
+ # Xsens MTI 10
85
+ # gyr:
86
+ # - bias error: 0.2 deg/s
87
+ # - bias stability: 18 deg/h
88
+ # - noise density: 0.03 deg/s/sqrt(hz)
89
+ # -> 40 hz: 0.2 deg/s
90
+ # -> 100 hz: 0.3 deg/s
91
+ # acc:
92
+ # - bias error: 0.05 m/s/s
93
+ # - bias stability: 15 micro g (<- gravity)
94
+ # - noise density: 60 micro g/sqrt(hz)
95
+ # -> 40 hz: 0.0036 m/s/s
96
+ # -> 100 hz: 0.006 m/s/s
97
+ # mag:
98
+ # - Total RMS noise: 0.5 milliGauss (1 Gauss = 1e-4 Tesla)
99
+ # -------------
100
+ # Xsens MTI 100
101
+ # gyr:
102
+ # - bias error: 0.2 deg/s
103
+ # - bias stability: 10 deg/h
104
+ # - noise density: 0.01 deg/s/sqrt(hz)
105
+ # -> 40 hz: 0.067 deg/s
106
+ # -> 100 hz: 0.1 deg/s
107
+ # acc:
108
+ # - bias error: 0.05 m/s/s
109
+ # - bias stability: 15 micro g (<- gravity)
110
+ # - noise density: 60 micro g/sqrt(hz)
111
+ # -> 40 hz: 0.0036 m/s/s
112
+ # -> 100 hz: 0.006 m/s/s
113
+ # mag:
114
+ # - Total RMS noise: 0.5 milliGauss
115
+ # -------------
116
+ # Movella Dot
117
+ # gyr:
118
+ # - bias error: ?
119
+ # - bias stability: 10 deg/h
120
+ # - noise density: 0.007 deg/s/sqrt(hz)
121
+ # acc:
122
+ # - bias error: ?
123
+ # - bias stability: 30 micro g
124
+ # - noise density: 120 micro g/sqrt(hz)
125
+ # mag:
126
+ # - Total RMS noise: 0.5 milliGauss = 5e-8 Tesla
127
+
128
# Default error magnitudes used by `add_noise_bias`, keyed by sensor name.
# NOISE_LEVELS scales zero-mean Gaussian noise; BIAS_LEVELS is the maximal
# absolute value of a uniformly drawn constant offset.
# Units:
# - acc: m/s/s
# - gyr: rad/s
# - mag: a.u.
NOISE_LEVELS = dict(acc=0.03, gyr=jnp.deg2rad(0.5), mag=0.01)
BIAS_LEVELS = dict(acc=0.1, gyr=jnp.deg2rad(0.5), mag=0.0)
134
+
135
+
136
+ def add_noise_bias(
137
+ key: jax.random.PRNGKey,
138
+ imu_measurements: dict[str, jax.Array],
139
+ noise_levels: Optional[dict[str, float | None]] = None,
140
+ bias_levels: Optional[dict[str, float | None]] = None,
141
+ ) -> dict[str, jax.Array]:
142
+ """Add noise and bias to 6D or 9D imu measurements.
143
+
144
+ Args:
145
+ key (jax.random.PRNGKey): Random seed.
146
+ imu_measurements (dict): IMU measurements without noise and bias.
147
+ Format is {"gyr": Array, "acc": Array, "mag": Array}.
148
+
149
+ Returns:
150
+ dict: IMU measurements with noise and bias.
151
+ """
152
+ noise_levels = {} if noise_levels is None else noise_levels
153
+ bias_levels = {} if bias_levels is None else bias_levels
154
+
155
+ noisy_imu_measurements = {}
156
+ for sensor in imu_measurements:
157
+ key, c1, c2 = jax.random.split(key, 3)
158
+
159
+ noise_scale = noise_levels.get(sensor, NOISE_LEVELS[sensor])
160
+ if noise_scale is not None:
161
+ noise = (
162
+ jax.random.normal(c1, shape=imu_measurements[sensor].shape)
163
+ * noise_scale
164
+ )
165
+ else:
166
+ noise = 0.0
167
+
168
+ bias_maxval = bias_levels.get(sensor, BIAS_LEVELS[sensor])
169
+ if bias_maxval is not None:
170
+ bias = jax.random.uniform(
171
+ c2, minval=-bias_maxval, maxval=bias_maxval, shape=(3,)
172
+ )
173
+ else:
174
+ bias = 0.0
175
+
176
+ noisy_imu_measurements[sensor] = imu_measurements[sensor] + noise + bias
177
+
178
+ return noisy_imu_measurements
179
+
180
+
181
def imu(
    xs: base.Transform,
    gravity: jax.Array,
    dt: float,
    key: Optional[jax.random.PRNGKey] = None,
    noisy: bool = False,
    smoothen_degree: Optional[int] = None,
    delay: Optional[int] = None,
    random_s2s_ori: Optional[float] = None,
    quasi_physical: bool = False,
    low_pass_filter_pos_f_cutoff: Optional[float] = None,
    low_pass_filter_rot_cutoff: Optional[float] = None,
    has_magnetometer: bool = False,
    magvec: Optional[jax.Array] = None,
    gyro_second_order: bool = False,
    natural_units: bool = False,
    acc_xinyuyi_n: int = 1,
) -> dict:
    """Simulates a 6D IMU, `xs` should be Transforms from eps-to-imu.

    NOTE: `smoothen_degree` is used as window size for moving average.
    NOTE: If `smoothen_degree` is given, and `delay` is not, then delay is chosen
    such moving average window is delayed to just be causal.

    Args:
        xs: Trajectory of transforms with `xs.ndim() == 2`.
        gravity: Gravity vector forwarded to `accelerometer`.
        dt: Sampling interval.
        key: Random seed. Required if `noisy`, if `random_s2s_ori` is given,
            or if `has_magnetometer` is set without a `magvec`.
        noisy: If True, add noise and bias via `add_noise_bias` (defaults).
        smoothen_degree: Window size of the moving average applied to all
            measurements (must be odd and larger than one).
        delay: Number of samples by which the measurements are delayed; the
            beginning is filled with zeros.
        random_s2s_ori: If given, a random rotation drawn with
            `maths.quat_random(..., maxval=random_s2s_ori)` is composed with
            every transform (segment-to-sensor orientation).
        quasi_physical: If True, positions are replaced by the output of
            `_quasi_physical_simulation`.
        low_pass_filter_pos_f_cutoff: Cutoff frequency of the Butterworth
            low-pass filter applied to the positions.
        low_pass_filter_rot_cutoff: Cutoff frequency of the low-pass filter
            applied to the orientations.
        has_magnetometer: If True, also simulate the "mag" measurement.
        magvec: Magnetic field vector; drawn randomly if not given.
        gyro_second_order: Forwarded as `second_order` to `gyroscope`.
        natural_units: If True, rescale via `rescale_natural_units`.
        acc_xinyuyi_n: Finite-difference stride `n` of `accelerometer`.

    Returns:
        dict: Measurements with keys "acc", "gyr" and, optionally, "mag".
    """
    assert xs.ndim() == 2

    if random_s2s_ori is not None:
        assert key is not None, "`random_s2s_ori` requires a random seed via `key`"
        # `xs` are now from eps-to-segment, so add another final rotation from
        # segment-to-sensor where this transform is only rotational
        key, consume = jax.random.split(key)
        xs_s2s = base.Transform.create(
            rot=maths.quat_random(consume, maxval=random_s2s_ori)
        )
        xs = jax.vmap(algebra.transform_mul, in_axes=(None, 0))(xs_s2s, xs)

    if quasi_physical:
        xs = _quasi_physical_simulation(xs, dt)

    if low_pass_filter_pos_f_cutoff is not None:
        xs = xs.replace(
            pos=_butterworth(
                xs.pos, f_sampling=1 / dt, f_cutoff=low_pass_filter_pos_f_cutoff
            )
        )

    if low_pass_filter_rot_cutoff is not None:
        xs = xs.replace(
            rot=maths.quat_lowpassfilter(
                xs.rot, cutoff_freq=low_pass_filter_rot_cutoff, samp_freq=1 / dt
            )
        )

    measurements = {
        "acc": accelerometer(xs, gravity, dt, acc_xinyuyi_n),
        "gyr": gyroscope(xs.rot, dt, gyro_second_order),
    }

    if has_magnetometer:
        if magvec is None:
            assert key is not None
            key, consume = jax.random.split(key)
            magvec = _draw_random_magvec(consume)
        measurements["mag"] = magnetometer(xs.rot, magvec)

    if smoothen_degree is not None:
        # `jax.tree_util.tree_map` is the stable spelling; the top-level
        # `jax.tree_map` alias is deprecated and absent from newer JAX.
        measurements = jax.tree_util.tree_map(
            lambda arr: _moving_average(arr, smoothen_degree),
            measurements,
        )

        # if you low-pass filter the imu measurements through a moving average which
        # effectively uses future values, then it also makes sense to delay the imu
        # measurements by this amount such that no future information is used
        if delay is None:
            half_window = (smoothen_degree - 1) // 2
            delay = half_window

    if delay is not None and delay > 0:
        # shift in time: prepend `delay` zero rows and drop the last `delay`
        measurements = jax.tree_util.tree_map(
            lambda arr: (jnp.pad(arr, ((delay, 0), (0, 0)))[:-delay]), measurements
        )

    if noisy:
        assert key is not None, "For noisy sensors random seed `key` must be provided."
        measurements = add_noise_bias(key, measurements)

    if natural_units:
        measurements = rescale_natural_units(measurements)

    return measurements
271
+
272
+
273
# Per-sensor rescaling into "natural units"; sensors without an entry are
# returned unchanged (identity).
_rescale_natural_units_fns = defaultdict(lambda: (lambda arr: arr))
_rescale_natural_units_fns.update(
    {
        "gyr": lambda gyr: gyr / jnp.pi,
        "acc": lambda acc: acc / 9.81,
    }
)


def rescale_natural_units(X: dict[str, jax.Array]):
    """Rescale every entry of `X` with the function registered for its key:
    "gyr" is divided by pi, "acc" by 9.81, all other keys pass through."""
    rescaled = {}
    for name, values in X.items():
        rescaled[name] = _rescale_natural_units_fns[name](values)
    return rescaled
280
+
281
+
282
def rel_pose(
    sys: base.System, xs: base.Transform, sys_xs: Optional[base.System] = None
) -> dict:
    """Relative pose of the entire system. `sys` defines the parent-child ordering,
    relative pose is from child to parent in local coordinates. Bodies that connect
    to the base are skipped (that would be absolute pose).

    Args:
        sys (base.System): System defining parent-child ordering.
        xs (base.Transform): Body transforms from base to body.
        sys_xs (base.System): System that defines the stacking order of `xs`.
            Defaults to `sys`.

    Returns:
        dict: Child-to-parent quaternions, keyed by the child's link name.
    """
    sys_xs = sys if sys_xs is None else sys_xs

    if xs.pos.ndim == 3:
        # bring links to the front: (n_timesteps, n_links) -> (n_links, n_timesteps)
        xs = xs.transpose([1, 0, 2])

    assert xs.batch_dim() == sys_xs.num_links()

    rel_quats = {}

    def collect(_, __, child_name: str, parent_idx: int):
        # links attached to the base have no relative pose
        if parent_idx == -1:
            return

        parent_name = sys.idx_to_name(parent_idx)

        # `xs` is stacked in the order of `sys_xs`, so look indices up by name
        idx_child = sys_xs.name_to_idx(child_name)
        idx_parent = sys_xs.name_to_idx(parent_name)

        q_parent = xs.take(idx_parent).rot
        q_child = xs.take(idx_child).rot

        rel_quats[child_name] = maths.quat_mul(q_parent, maths.quat_inv(q_child))

    sys.scan(collect, "ll", sys.link_names, sys.link_parents)

    return rel_quats
329
+
330
+
331
def root_incl(
    sys: base.System, x: base.Transform, sys_x: base.System
) -> dict[str, jax.Array]:
    """For every link of `sys` that connects to the base, return the second
    output of `maths.quat_project` of its orientation w.r.t. the z-axis.

    Args:
        sys: System whose base-connected links are reported.
        x: Transforms; `x.rot` is stacked as (time, nlinks, 4).
        sys_x: System that defines the link ordering of `x`.

    Returns:
        dict mapping link name to the projected quaternion trajectory.
    """
    # (time, nlinks, 4) -> (nlinks, time, 4)
    rots_by_link = x.rot.transpose((1, 0, 2))
    name_to_row = sys_x.idx_map("l")
    z_axis = jnp.array([0.0, 0, 1])

    result = {}

    def visit(_, __, link_name: str, parent_idx: int):
        # only links attached to the base are of interest
        if parent_idx == -1:
            q_link = rots_by_link[name_to_row[link_name]]
            result[link_name] = maths.quat_project(q_link, z_axis)[1]

    sys.scan(visit, "ll", sys.link_names, sys.link_parents)

    return result
348
+
349
+
350
def root_full(
    sys: base.System, x: base.Transform, sys_x: base.System
) -> dict[str, jax.Array]:
    """For every link of `sys` that connects to the base, return its full
    orientation trajectory.

    Args:
        sys: System whose base-connected links are reported.
        x: Transforms; `x.rot` is stacked as (time, nlinks, 4).
        sys_x: System that defines the link ordering of `x`.

    Returns:
        dict mapping link name to its quaternion trajectory.
    """
    # (time, nlinks, 4) -> (nlinks, time, 4)
    rots_by_link = x.rot.transpose((1, 0, 2))
    name_to_row = sys_x.idx_map("l")

    result = {}

    def visit(_, __, link_name: str, parent_idx: int):
        # only links attached to the base are of interest
        if parent_idx == -1:
            result[link_name] = rots_by_link[name_to_row[link_name]]

    sys.scan(visit, "ll", sys.link_names, sys.link_parents)

    return result
367
+
368
+
369
def joint_axes(
    sys: base.System,
    xs: base.Transform,
    sys_xs: base.System,
    key: Optional[jax.Array] = None,
    noisy: bool = False,
    from_sys: bool = False,
    randomly_flip: bool = False,
):
    """Joint axes of every link of `sys`, repeated over all timesteps of `xs`.

    The joint-axes to world is always zeros.

    Args:
        sys: System for which the joint axes are computed.
        xs: Transforms; only used for the number of timesteps if `from_sys`.
        sys_xs: System that defines the link ordering of `xs`.
        key: Random seed; required if `noisy` or `randomly_flip`.
        noisy: If True, rotate the axes by a random disturbance composed of a
            constant part (up to 5 degrees) and a per-timestep part (up to
            2 degrees).
        from_sys: If True, read the axes from the system definition instead
            of estimating them from `xs`.
        randomly_flip: If True, randomly multiply each non-root axis by +1/-1.

    Returns:
        dict of the form {link_name: {"joint_axes": Array}}.
    """
    if key is None:
        # without a seed nothing random can be requested
        assert not noisy
        assert not randomly_flip

    N = xs.shape(axis=0)

    X = _joint_axes_from_sys(sys, N) if from_sys else _joint_axes_from_xs(sys, xs, sys_xs)

    if noisy:
        for name in X:
            key, key_bias, key_noise = jax.random.split(key, 3)
            bias = maths.quat_random(key_bias, maxval=jnp.deg2rad(5.0))
            noise = maths.quat_random(key_noise, (N,), maxval=jnp.deg2rad(2.0))
            disturbance = maths.quat_mul(noise, bias)
            X[name]["joint_axes"] = maths.rotate(X[name]["joint_axes"], disturbance)

    for name, parent in zip(sys.link_names, sys.link_parents):
        if parent == -1:
            # joint axes to world must be zeros
            X[name]["joint_axes"] = jnp.zeros((N, 3))
        elif randomly_flip:
            key, consume = jax.random.split(key)
            sign = jax.random.choice(consume, jnp.array([1.0, -1.0]))
            X[name]["joint_axes"] = sign * X[name]["joint_axes"]

    return X
413
+
414
+
415
def _joint_axes_from_xs(sys, xs, sys_xs):
    """Estimate one constant joint axis per link from the rotational part of
    the joint transforms contained in `xs`.

    Returns a dict {link_name: {"joint_axes": Array}} where the (normalized,
    time-averaged) axis of each link is repeated for every timestep.
    """
    xs = sim2real.match_xs(sys, xs, sys_xs)

    _, transform2_rot = sim2real.unzip_xs(sys, xs)
    # links first, time second
    qs = transform2_rot.rot.transpose((1, 0, 2))

    def l2norm(x):
        return jnp.sqrt(jnp.sum(x**2, axis=-1))

    @jax.vmap
    def ensure_axis_convention(qs):
        # normalized vector part of the quaternions (epsilon avoids 0 / 0)
        vec = qs[..., 1:]
        axis = vec / (jnp.linalg.norm(vec, axis=-1, keepdims=True) + 1e-6)
        # the first timestep fixes the sign; flip axes that are closer to
        # the negated reference than to the reference itself
        convention = axis[0]
        flip = (l2norm(convention - axis) > l2norm(convention + axis))[..., None]
        return jnp.where(flip, -axis, axis)

    axes = ensure_axis_convention(qs)

    # TODO
    # not ideal to average vectors that live on a sphere
    N = axes.shape[1]
    axes_average = jnp.mean(axes, axis=1)
    axes_average = axes_average / jnp.linalg.norm(
        axes_average, axis=-1, keepdims=True
    )
    axes = jnp.repeat(axes_average[:, None], N, axis=1)

    X = {}
    for name in sys.link_names:
        X[name] = {"joint_axes": axes[sys.name_to_idx(name)]}
    return X
444
+
445
+
446
def _joint_axes_from_sys(sys: base.System, N: int) -> dict:
    """Read the joint axis of every link from the system definition.

    `sys` should be `sys_noimu`. `N` is number of timesteps.

    Link types "rx", "ry", "rz" map to the corresponding unit axis; "rr" and
    link types starting with "rr_imp" read the axis from the link's
    `joint_params`; every other link type falls back to the x-axis.

    Returns:
        dict of the form {link_name: {"joint_axes": Array}} where each axis
        is repeated `N` times along a new leading axis.
    """
    xaxis = jnp.array([1.0, 0, 0])
    yaxis = jnp.array([0.0, 1, 0])
    zaxis = jnp.array([0.0, 0, 1])
    id_to_axis = {"x": xaxis, "y": yaxis, "z": zaxis}
    X = {}

    def f(_, __, name, link_type, link):
        joint_params = link.joint_params
        if link_type in ["rx", "ry", "rz"]:
            joint_axes = id_to_axis[link_type[1]]
        elif link_type == "rr":
            joint_axes = joint_params["rr"]["joint_axes"]
        elif link_type.startswith("rr_imp"):
            joint_axes = joint_params[link_type]["joint_axes"]
        else:
            joint_axes = xaxis
        X[name] = {"joint_axes": joint_axes}

    sys.scan(f, "lll", sys.link_names, sys.link_types, sys.links)
    # `jax.tree_util.tree_map` is the stable spelling; the top-level
    # `jax.tree_map` alias is deprecated and absent from newer JAX.
    X = jax.tree_util.tree_map(lambda arr: jnp.repeat(arr[None], N, axis=0), X)
    return X
469
+
470
+
471
+ def _moving_average(arr: jax.Array, window: int) -> jax.Array:
472
+ "Padds with left and right values of array."
473
+ assert window % 2 == 1
474
+ assert window > 1, "Window size of 1 would be a no-op"
475
+ arr_smooth = jnp.zeros((len(arr) + window - 1,) + arr.shape[1:])
476
+ half_window = (window - 1) // 2
477
+ arr_padded = arr_smooth.at[half_window : (len(arr) + half_window)].set(arr)
478
+ arr_padded = arr_padded.at[:half_window].set(arr[0])
479
+ arr_padded = arr_padded.at[-half_window:].set(arr[-1])
480
+
481
+ for i in range(-half_window, half_window + 1):
482
+ rolled = jnp.roll(arr_padded, i, axis=0)
483
+ arr_smooth += rolled
484
+ arr_smooth = arr_smooth / window
485
+ return arr_smooth[half_window : (len(arr) + half_window)]
486
+
487
+
488
# System description (XML) that `_quasi_physical_simulation_beautiful` loads via
# `io.load_sys_from_str`: a single body "IMU" with a "p3d" joint, damping and
# spring stiffness, in a world without gravity.
_quasi_physical_sys_str = r"""
<x_xy>
<options gravity="0 0 0"/>
<worldbody>
<body name="IMU" joint="p3d" damping="0.1 0.1 0.1" spring_stiff="3 3 3">
<geom type="box" mass="0.002" dim="0.01 0.01 0.01"/>
</body>
</worldbody>
</x_xy>
"""
498
+
499
+
500
def _quasi_physical_simulation_beautiful(
    xs: base.Transform, dt: float
) -> base.Transform:
    """Replace the positions of `xs` by the simulated positions of the system
    described by `_quasi_physical_sys_str`, whose spring zeropoint is set to
    the position of `xs` at every timestep."""
    sys = io.load_sys_from_str(_quasi_physical_sys_str).replace(dt=dt)

    def step_dynamics(state: base.State, x):
        # let the spring pull towards the current reference position
        sys_at_step = sys.replace(link_spring_zeropoint=x.pos)
        next_state = algorithms.step(sys_at_step, state)
        return next_state, next_state.q

    initial_state = base.State.create(sys, q=xs.pos[0])
    _, simulated_pos = jax.lax.scan(step_dynamics, initial_state, xs)
    return xs.replace(pos=simulated_pos)
512
+
513
+
514
+ _constants = {
515
+ "qp_damp": 35.0,
516
+ "qp_stif": 625.0,
517
+ }
518
+
519
+
520
def _quasi_physical_simulation(xs: base.Transform, dt: float) -> base.Transform:
    """Replace the positions of `xs` by those of a unit mass that is coupled
    to the original positions by a spring and a damper.

    Stiffness and damping are read from `_constants`; integration uses a
    semi-implicit Euler step of size `dt`.
    """
    mass = 1.0
    damp, stif = _constants["qp_damp"], _constants["qp_stif"]

    def step_dynamics(state, zeropoint):
        pos, vel = state
        target_pos, target_vel = zeropoint
        force = damp * (target_vel - vel) + stif * (target_pos - pos)
        acc = force / mass
        vel = vel + dt * acc
        # semi-implicit, so use already next velocity
        pos = pos + dt * vel
        return (pos, vel), pos

    start_pos = xs.pos[0]
    zero_vel = jnp.zeros_like(start_pos)
    # velocity of the reference trajectory; the first sample is at rest
    target_vels = jnp.vstack((zero_vel, jnp.diff(xs.pos, axis=0) / dt))
    _, pos = jax.lax.scan(
        step_dynamics, (start_pos, zero_vel), (xs.pos, target_vels)
    )
    return xs.replace(pos=pos)
540
+
541
+
542
+ def _butterworth(
543
+ signal: jax.Array,
544
+ f_sampling: float,
545
+ f_cutoff: int,
546
+ method: str = "forward_backward",
547
+ ) -> jax.Array:
548
+ """https://stackoverflow.com/questions/20924868/calculate-coefficients-of-2nd-order
549
+ -butterworth-low-pass-filter"""
550
+
551
+ if method == "forward_backward":
552
+ signal = _butterworth(signal, f_sampling, f_cutoff, "forward")
553
+ return _butterworth(signal, f_sampling, f_cutoff, "backward")
554
+ elif method == "forward":
555
+ pass
556
+ elif method == "backward":
557
+ signal = jnp.flip(signal, axis=0)
558
+ else:
559
+ raise NotImplementedError
560
+
561
+ ff = f_cutoff / f_sampling
562
+ ita = 1.0 / jnp.tan(jnp.pi * ff)
563
+ q = jnp.sqrt(2.0)
564
+ b0 = 1.0 / (1.0 + q * ita + ita**2)
565
+ b1 = 2 * b0
566
+ b2 = b0
567
+ a1 = 2.0 * (ita**2 - 1.0) * b0
568
+ a2 = -(1.0 - q * ita + ita**2) * b0
569
+
570
+ def f(carry, x_i):
571
+ x_im1, x_im2, y_im1, y_im2 = carry
572
+ y_i = b0 * x_i + b1 * x_im1 + b2 * x_im2 + a1 * y_im1 + a2 * y_im2
573
+ return (x_i, x_im1, y_i, y_im1), y_i
574
+
575
+ init = (signal[1], signal[0]) * 2
576
+ signal = jax.lax.scan(f, init, signal[2:])[1]
577
+ signal = jnp.concatenate((signal[0:1],) * 2 + (signal,))
578
+
579
+ if method == "backward":
580
+ signal = jnp.flip(signal, axis=0)
581
+
582
+ return signal