imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
@@ -0,0 +1,582 @@
|
|
1
|
+
from collections import defaultdict
|
2
|
+
from typing import Optional
|
3
|
+
|
4
|
+
import jax
|
5
|
+
import jax.numpy as jnp
|
6
|
+
from ring import algebra
|
7
|
+
from ring import algorithms
|
8
|
+
from ring import base
|
9
|
+
from ring import io
|
10
|
+
from ring import maths
|
11
|
+
from ring import sim2real
|
12
|
+
|
13
|
+
|
14
|
+
def accelerometer(
    xs: base.Transform, gravity: jax.Array, dt: float, n: int
) -> jax.Array:
    """Simulate accelerometer readings for a frame that moves along `xs`.

    `xs` are the transforms from base to link. The linear acceleration is
    approximated by a central second-order finite difference of `xs.pos` whose
    stencil reaches `n` samples to each side, `gravity` is added, and the result
    is passed through `maths.rotate` together with `xs.rot`.

    Args:
        xs: Trajectory of transforms; `xs.pos` is differentiated along its
            leading axis.
        gravity: Vector that is added to the finite-difference acceleration.
        dt: Step between two consecutive entries of `xs`.
        n: Number of samples between the stencil points.

    Returns:
        The rotated acceleration, padded by `n` entries on both ends.
    """
    positions = xs.pos
    behind = positions[: -2 * n]
    center = positions[n:-n]
    ahead = positions[2 * n :]

    # central 2nd-order difference, shrinks the time axis (N,) -> (N-2n,)
    inner = (behind + ahead - 2 * center) / (n * dt) ** 2
    inner = inner + gravity

    # re-use the first / last `n` values to restore the lost 2n entries
    head = jnp.atleast_2d(inner[:n])
    tail = jnp.atleast_2d(inner[-n:])
    padded = jnp.vstack((head, inner, tail))

    return maths.rotate(padded, xs.rot)
|
29
|
+
|
30
|
+
|
31
|
+
def gyroscope(rot: jax.Array, dt: float, second_order: bool) -> jax.Array:
    """Simulate gyroscope readings for a frame whose orientation follows `rot`.

    Args:
        rot: Trajectory of quaternions, time along the leading axis.
        dt: Step between two consecutive quaternions.
        second_order: If True, difference `rot[t+1]` against `rot[t-1]`
            (spanning `2 * dt`); otherwise difference `rot[t+1]` against
            `rot[t]`.

    Returns:
        `axis * angle / dt` of the incremental rotations, with entries whose
        absolute angle is not above 1e-10 replaced by zeros. Has as many
        entries along the leading axis as `rot`.
    """
    # NOTE: earlier versions inverted `rot` first because q used to represent
    # FROM LOCAL TO EPS; q now represents FROM EPS TO LOCAL, so it is used as is.
    # q = maths.quat_inv(rot)
    q = rot

    if second_order:
        # (N,) -> (N-2,); duplicate the first and the last increment
        increments = maths.quat_mul(q[2:], maths.quat_inv(q[:-2]))
        increments = jnp.vstack(
            (increments[0][None], increments, increments[-1][None])
        )
        span = 2 * dt
    else:
        # 1st-order approximation: (N,) -> (N-1,); duplicate the last increment
        increments = maths.quat_mul(q[1:], maths.quat_inv(q[:-1]))
        increments = jnp.vstack((increments, increments[-1][None]))
        span = dt

    axis, angle = maths.quat_to_rot_axis(increments)
    angle = angle[:, None]

    rate = axis * angle / span
    is_rotating = jnp.abs(angle) > 1e-10
    return jnp.where(is_rotating, rate, jnp.zeros(3))
|
60
|
+
|
61
|
+
|
62
|
+
def _draw_random_magvec(key):
    """Draw a random magnetic field vector; unit is a.u. (1 a.u. ~ 40 microTesla).

    The dip angle is uniform in [-85, -50) degrees and the norm is uniform
    between 15 and 65 microTesla (converted to a.u.). The x-component of the
    returned vector is always zero.
    """
    key_dip, key_norm = jax.random.split(key)

    # dip angle, drawn in degrees
    dip_deg = jax.random.uniform(key_dip, minval=-85.0, maxval=-50.0)
    dip = jnp.deg2rad(dip_deg)

    # field strength bounds in Tesla (from lecture script page 10), then
    # converted Tesla -> a.u.
    tesla_per_au = 40e-6
    lower, upper = 15e-6 / tesla_per_au, 65e-6 / tesla_per_au
    strength = jax.random.uniform(key_norm, minval=lower, maxval=upper)

    direction = jnp.array([0.0, jnp.cos(dip), jnp.sin(dip)])
    return direction * strength
|
78
|
+
|
79
|
+
|
80
|
+
def magnetometer(rot: jax.Array, magvec: jax.Array) -> jax.Array:
    """Return `magvec` passed through `maths.rotate` with the quaternions `rot`."""
    measured_field = maths.rotate(magvec, rot)
    return measured_field
|
82
|
+
|
83
|
+
|
84
|
+
# Xsens MTI 10
|
85
|
+
# gyr:
|
86
|
+
# - bias error: 0.2 deg/s
|
87
|
+
# - bias stability: 18 deg/h
|
88
|
+
# - noise density: 0.03 deg/s/sqrt(hz)
|
89
|
+
# -> 40 hz: 0.2 deg/s
|
90
|
+
# -> 100 hz: 0.3 deg/s
|
91
|
+
# acc:
|
92
|
+
# - bias error: 0.05 m/s/s
|
93
|
+
# - bias stability: 15 micro g (<- gravity)
|
94
|
+
# - noise density: 60 micro g/sqrt(hz)
|
95
|
+
# -> 40 hz: 0.0036 m/s/s
|
96
|
+
# -> 100 hz: 0.006 m/s/s
|
97
|
+
# mag:
|
98
|
+
# - Total RMS noise: 0.5 milliGauss (1 Gauss = 1e-4 Tesla)
|
99
|
+
# -------------
|
100
|
+
# Xsens MTI 100
|
101
|
+
# gyr:
|
102
|
+
# - bias error: 0.2 deg/s
|
103
|
+
# - bias stability: 10 deg/h
|
104
|
+
# - noise density: 0.01 deg/s/sqrt(hz)
|
105
|
+
# -> 40 hz: 0.067 deg/s
|
106
|
+
# -> 100 hz: 0.1 deg/s
|
107
|
+
# acc:
|
108
|
+
# - bias error: 0.05 m/s/s
|
109
|
+
# - bias stability: 15 micro g (<- gravity)
|
110
|
+
# - noise density: 60 micro g/sqrt(hz)
|
111
|
+
# -> 40 hz: 0.0036 m/s/s
|
112
|
+
# -> 100 hz: 0.006 m/s/s
|
113
|
+
# mag:
|
114
|
+
# - Total RMS noise: 0.5 milliGauss
|
115
|
+
# -------------
|
116
|
+
# Movella Dot
|
117
|
+
# gyr:
|
118
|
+
# - bias error: ?
|
119
|
+
# - bias stability: 10 deg/h
|
120
|
+
# - noise density: 0.007 deg/s/sqrt(hz)
|
121
|
+
# acc:
|
122
|
+
# - bias error: ?
|
123
|
+
# - bias stability: 30 micro g
|
124
|
+
# - noise density: 120 micro g/sqrt(hz)
|
125
|
+
# mag:
|
126
|
+
# - Total RMS noise: 0.5 milliGauss = 5e-8 Tesla
|
127
|
+
|
128
|
+
# units are:
|
129
|
+
# - acc: m/s/s
|
130
|
+
# - gyr: rad/s
|
131
|
+
# - mag: a.u.
|
132
|
+
# Defaults used by `add_noise_bias` when the caller gives no per-sensor value.
# NOISE_LEVELS scales standard-normal noise; BIAS_LEVELS is the half-width of
# the uniform interval the constant bias is drawn from.
NOISE_LEVELS = dict(acc=0.03, gyr=jnp.deg2rad(0.5), mag=0.01)
BIAS_LEVELS = dict(acc=0.1, gyr=jnp.deg2rad(0.5), mag=0.0)
|
134
|
+
|
135
|
+
|
136
|
+
def add_noise_bias(
|
137
|
+
key: jax.random.PRNGKey,
|
138
|
+
imu_measurements: dict[str, jax.Array],
|
139
|
+
noise_levels: Optional[dict[str, float | None]] = None,
|
140
|
+
bias_levels: Optional[dict[str, float | None]] = None,
|
141
|
+
) -> dict[str, jax.Array]:
|
142
|
+
"""Add noise and bias to 6D or 9D imu measurements.
|
143
|
+
|
144
|
+
Args:
|
145
|
+
key (jax.random.PRNGKey): Random seed.
|
146
|
+
imu_measurements (dict): IMU measurements without noise and bias.
|
147
|
+
Format is {"gyr": Array, "acc": Array, "mag": Array}.
|
148
|
+
|
149
|
+
Returns:
|
150
|
+
dict: IMU measurements with noise and bias.
|
151
|
+
"""
|
152
|
+
noise_levels = {} if noise_levels is None else noise_levels
|
153
|
+
bias_levels = {} if bias_levels is None else bias_levels
|
154
|
+
|
155
|
+
noisy_imu_measurements = {}
|
156
|
+
for sensor in imu_measurements:
|
157
|
+
key, c1, c2 = jax.random.split(key, 3)
|
158
|
+
|
159
|
+
noise_scale = noise_levels.get(sensor, NOISE_LEVELS[sensor])
|
160
|
+
if noise_scale is not None:
|
161
|
+
noise = (
|
162
|
+
jax.random.normal(c1, shape=imu_measurements[sensor].shape)
|
163
|
+
* noise_scale
|
164
|
+
)
|
165
|
+
else:
|
166
|
+
noise = 0.0
|
167
|
+
|
168
|
+
bias_maxval = bias_levels.get(sensor, BIAS_LEVELS[sensor])
|
169
|
+
if bias_maxval is not None:
|
170
|
+
bias = jax.random.uniform(
|
171
|
+
c2, minval=-bias_maxval, maxval=bias_maxval, shape=(3,)
|
172
|
+
)
|
173
|
+
else:
|
174
|
+
bias = 0.0
|
175
|
+
|
176
|
+
noisy_imu_measurements[sensor] = imu_measurements[sensor] + noise + bias
|
177
|
+
|
178
|
+
return noisy_imu_measurements
|
179
|
+
|
180
|
+
|
181
|
+
def imu(
    xs: base.Transform,
    gravity: jax.Array,
    dt: float,
    key: Optional[jax.random.PRNGKey] = None,
    noisy: bool = False,
    smoothen_degree: Optional[int] = None,
    delay: Optional[int] = None,
    random_s2s_ori: Optional[float] = None,
    quasi_physical: bool = False,
    low_pass_filter_pos_f_cutoff: Optional[float] = None,
    low_pass_filter_rot_cutoff: Optional[float] = None,
    has_magnetometer: bool = False,
    magvec: Optional[jax.Array] = None,
    gyro_second_order: bool = False,
    natural_units: bool = False,
    acc_xinyuyi_n: int = 1,
) -> dict:
    """Simulates a 6D IMU (9D if `has_magnetometer`), `xs` should be Transforms
    from eps-to-imu.

    NOTE: `smoothen_degree` is used as window size for moving average.
    NOTE: If `smoothen_degree` is given, and `delay` is not, then delay is chosen
    such moving average window is delayed to just be causal.

    Args:
        xs: Trajectory of transforms; must satisfy `xs.ndim() == 2`.
        gravity: Gravity vector, forwarded to `accelerometer`.
        dt: Step between consecutive transforms.
        key: Random seed. Required if `noisy`, if `random_s2s_ori` is given, or
            if `has_magnetometer` is set without a `magvec`.
        noisy: If True, `add_noise_bias` is applied to the measurements.
        smoothen_degree: Window size of the moving average applied to all
            measurements.
        delay: Number of timesteps the measurements are shifted back in time;
            the first `delay` entries become zeros.
        random_s2s_ori: If given, a random rotation drawn with
            `maths.quat_random(..., maxval=random_s2s_ori)` is applied to `xs`.
        quasi_physical: If True, `xs` is passed through
            `_quasi_physical_simulation` first.
        low_pass_filter_pos_f_cutoff: Cutoff of the Butterworth low-pass
            applied to `xs.pos` (sampling frequency is `1 / dt`).
        low_pass_filter_rot_cutoff: Cutoff passed to
            `maths.quat_lowpassfilter` for `xs.rot`.
        has_magnetometer: If True, a "mag" entry is added.
        magvec: Magnetic field vector; drawn randomly if missing.
        gyro_second_order: Forwarded to `gyroscope` as `second_order`.
        natural_units: If True, `rescale_natural_units` is applied last.
        acc_xinyuyi_n: Forwarded to `accelerometer` as `n`.

    Returns:
        dict with keys "acc" and "gyr", plus "mag" if `has_magnetometer`.
    """
    assert xs.ndim() == 2

    if random_s2s_ori is not None:
        assert key is not None, "`random_s2s_ori` requires a random seed via `key`"
        # `xs` are now from eps-to-segment, so add another final rotation from
        # segment-to-sensor where this transform is only rotational
        key, consume = jax.random.split(key)
        xs_s2s = base.Transform.create(
            rot=maths.quat_random(consume, maxval=random_s2s_ori)
        )
        xs = jax.vmap(algebra.transform_mul, in_axes=(None, 0))(xs_s2s, xs)

    if quasi_physical:
        xs = _quasi_physical_simulation(xs, dt)

    if low_pass_filter_pos_f_cutoff is not None:
        xs = xs.replace(
            pos=_butterworth(
                xs.pos, f_sampling=1 / dt, f_cutoff=low_pass_filter_pos_f_cutoff
            )
        )

    if low_pass_filter_rot_cutoff is not None:
        xs = xs.replace(
            rot=maths.quat_lowpassfilter(
                xs.rot, cutoff_freq=low_pass_filter_rot_cutoff, samp_freq=1 / dt
            )
        )

    measurements = {
        "acc": accelerometer(xs, gravity, dt, acc_xinyuyi_n),
        "gyr": gyroscope(xs.rot, dt, gyro_second_order),
    }

    if has_magnetometer:
        if magvec is None:
            assert key is not None
            key, consume = jax.random.split(key)
            magvec = _draw_random_magvec(consume)
        measurements["mag"] = magnetometer(xs.rot, magvec)

    if smoothen_degree is not None:
        # `jax.tree_util.tree_map` instead of the deprecated `jax.tree_map`
        # alias, which is no longer available in recent JAX releases
        measurements = jax.tree_util.tree_map(
            lambda arr: _moving_average(arr, smoothen_degree),
            measurements,
        )

        # if you low-pass filter the imu measurements through a moving average which
        # effectively uses future values, then it also makes sense to delay the imu
        # measurements by this amount such that no future information is used
        if delay is None:
            half_window = (smoothen_degree - 1) // 2
            delay = half_window

    if delay is not None and delay > 0:
        # shift in time: zero-pad at the front, drop the same amount at the end
        measurements = jax.tree_util.tree_map(
            lambda arr: (jnp.pad(arr, ((delay, 0), (0, 0)))[:-delay]), measurements
        )

    if noisy:
        assert key is not None, "For noisy sensors random seed `key` must be provided."
        measurements = add_noise_bias(key, measurements)

    if natural_units:
        measurements = rescale_natural_units(measurements)

    return measurements
|
271
|
+
|
272
|
+
|
273
|
+
# Per-sensor rescaling functions; sensors without an entry are left unchanged.
_rescale_natural_units_fns = defaultdict(
    lambda: (lambda arr: arr),
    gyr=lambda gyr: gyr / jnp.pi,
    acc=lambda acc: acc / 9.81,
)


def rescale_natural_units(X: dict[str, jax.Array]):
    """Rescale measurements: "gyr" is divided by pi, "acc" by 9.81, and every
    other entry is returned as is."""
    rescaled = {}
    for name, values in X.items():
        rescale = _rescale_natural_units_fns[name]
        rescaled[name] = rescale(values)
    return rescaled
|
280
|
+
|
281
|
+
|
282
|
+
def rel_pose(
    sys: base.System, xs: base.Transform, sys_xs: Optional[base.System] = None
) -> dict:
    """Relative pose of the entire system. `sys` defines the parent-child ordering,
    relative pose is from child to parent in local coordinates. Bodies that connect
    to the base are skipped (that would be absolute pose).

    Args:
        sys (base.System): System defining parent-child ordering.
        xs (base.Transform): Body transforms from base to body.
        sys_xs (base.System): System that defines the stacking order of `xs`.
            Defaults to `sys`.

    Returns:
        dict: Child-to-parent quaternions, keyed by the name of the child.
    """
    sys_xs = sys if sys_xs is None else sys_xs

    if xs.pos.ndim == 3:
        # swap (n_timesteps, n_links) axes
        xs = xs.transpose([1, 0, 2])

    assert xs.batch_dim() == sys_xs.num_links()

    relative_quats = {}

    def store_child_to_parent(_, __, name_i: str, p: int):
        # body connects to base
        if p == -1:
            return

        name_p = sys.idx_to_name(p)

        # the link indices of `sys` need not match the stacking order of `xs`
        idx_child = sys_xs.name_to_idx(name_i)
        idx_parent = sys_xs.name_to_idx(name_p)

        q_parent = xs.take(idx_parent).rot
        q_child = xs.take(idx_child).rot

        relative_quats[name_i] = maths.quat_mul(q_parent, maths.quat_inv(q_child))

    sys.scan(store_child_to_parent, "ll", sys.link_names, sys.link_parents)

    return relative_quats
|
329
|
+
|
330
|
+
|
331
|
+
def root_incl(
    sys: base.System, x: base.Transform, sys_x: base.System
) -> dict[str, jax.Array]:
    """For every link of `sys` whose parent is -1, return element `[1]` of
    `maths.quat_project(rot, z-axis)`, keyed by link name.

    NOTE(review): presumably this is the inclination part of the orientation —
    confirm against `maths.quat_project`.

    `sys_x` defines the stacking order of the links in `x`.
    """
    # (time, nlinks, 4) -> (nlinks, time, 4)
    rots_by_link = x.rot.transpose((1, 0, 2))
    name_to_index = sys_x.idx_map("l")
    z_axis = jnp.array([0.0, 0, 1])

    incl = dict()

    def collect_roots(_, __, name: str, parent: int):
        if parent == -1:
            incl[name] = maths.quat_project(
                rots_by_link[name_to_index[name]], z_axis
            )[1]

    sys.scan(collect_roots, "ll", sys.link_names, sys.link_parents)

    return incl
|
348
|
+
|
349
|
+
|
350
|
+
def root_full(
    sys: base.System, x: base.Transform, sys_x: base.System
) -> dict[str, jax.Array]:
    """Return the quaternion trajectory of every link of `sys` whose parent is
    -1, keyed by link name. `sys_x` defines the stacking order of the links in
    `x`."""
    # (time, nlinks, 4) -> (nlinks, time, 4)
    rots_by_link = x.rot.transpose((1, 0, 2))
    name_to_index = sys_x.idx_map("l")

    full = dict()

    def collect_roots(_, __, name: str, parent: int):
        if parent == -1:
            full[name] = rots_by_link[name_to_index[name]]

    sys.scan(collect_roots, "ll", sys.link_names, sys.link_parents)

    return full
|
367
|
+
|
368
|
+
|
369
|
+
def joint_axes(
    sys: base.System,
    xs: base.Transform,
    sys_xs: base.System,
    key: Optional[jax.Array] = None,
    noisy: bool = False,
    from_sys: bool = False,
    randomly_flip: bool = False,
):
    """Return `{link_name: {"joint_axes": Array}}` for all links of `sys`.

    The joint-axes to world is always zeros.

    Args:
        sys: System whose links are reported.
        xs: Transforms; the size of their leading axis is the number of
            timesteps `N`.
        sys_xs: System that belongs to `xs`; only used if not `from_sys`.
        key: Random seed, required if `noisy` or `randomly_flip`.
        noisy: If True, every joint axis is rotated by a constant random bias
            (up to 5 degrees) composed with per-timestep noise (up to 2 degrees).
        from_sys: If True, read the axes from `sys` instead of estimating them
            from `xs`.
        randomly_flip: If True, the axis of every link that does not connect to
            the world is multiplied by a random sign.
    """
    if key is None:
        assert not noisy
        assert not randomly_flip

    N = xs.shape(axis=0)

    X = _joint_axes_from_sys(sys, N) if from_sys else _joint_axes_from_xs(sys, xs, sys_xs)

    if noisy:
        for name in X:
            key, key_bias, key_noise = jax.random.split(key, 3)
            bias = maths.quat_random(key_bias, maxval=jnp.deg2rad(5.0))
            noise = maths.quat_random(key_noise, (N,), maxval=jnp.deg2rad(2.0))
            disturbance = maths.quat_mul(noise, bias)
            X[name]["joint_axes"] = maths.rotate(X[name]["joint_axes"], disturbance)

    for name, parent in zip(sys.link_names, sys.link_parents):
        if parent == -1:
            # joint axes to world must be zeros
            X[name]["joint_axes"] = jnp.zeros((N, 3))
            continue

        if randomly_flip:
            key, consume = jax.random.split(key)
            sign = jax.random.choice(consume, jnp.array([1.0, -1.0]))
            X[name]["joint_axes"] = sign * X[name]["joint_axes"]

    return X
|
413
|
+
|
414
|
+
|
415
|
+
def _joint_axes_from_xs(sys, xs, sys_xs):
    """Estimate one constant joint axis per link of `sys` from the transforms `xs`.

    The vector part of the second output of `sim2real.unzip_xs` is normalized,
    sign-aligned with the first timestep, averaged over time and re-normalized;
    the average is then repeated for every timestep.
    """
    xs = sim2real.match_xs(sys, xs, sys_xs)

    _, joint_transforms = sim2real.unzip_xs(sys, xs)
    # (time, nlinks, 4) -> (nlinks, time, 4)
    quats = joint_transforms.rot.transpose((1, 0, 2))

    def l2norm(x):
        return jnp.sqrt(jnp.sum(x**2, axis=-1))

    @jax.vmap
    def ensure_axis_convention(qs):
        vec = qs[..., 1:]
        # small constant avoids division by zero
        axis = vec / (jnp.linalg.norm(vec, axis=-1, keepdims=True) + 1e-6)
        reference = axis[0]
        # flip every axis that is closer to `-reference` than to `reference`
        needs_flip = (l2norm(reference - axis) > l2norm(reference + axis))[..., None]
        return jnp.where(needs_flip, -axis, axis)

    aligned = ensure_axis_convention(quats)

    # TODO
    # not ideal to average vectors that live on a sphere
    n_timesteps = aligned.shape[1]
    mean_axis = jnp.mean(aligned, axis=1)
    mean_axis /= jnp.linalg.norm(mean_axis, axis=-1, keepdims=True)
    constant_axes = jnp.repeat(mean_axis[:, None], n_timesteps, axis=1)

    X = {}
    for name in sys.link_names:
        X[name] = {"joint_axes": constant_axes[sys.name_to_idx(name)]}
    return X
|
444
|
+
|
445
|
+
|
446
|
+
def _joint_axes_from_sys(sys: base.System, N: int) -> dict:
    """Read the joint axis of every link from `sys` and repeat it `N` times.

    `sys` should be `sys_noimu`. `N` is number of timesteps.

    Links of type "rx", "ry", "rz" use the matching unit vector, "rr" and
    "rr_imp*" links read "joint_axes" from their joint parameters, and every
    other link type falls back to the x-axis.

    Returns:
        dict: `{link_name: {"joint_axes": Array}}` with `N` as leading axis.
    """
    xaxis = jnp.array([1.0, 0, 0])
    yaxis = jnp.array([0.0, 1, 0])
    zaxis = jnp.array([0.0, 0, 1])
    id_to_axis = {"x": xaxis, "y": yaxis, "z": zaxis}
    X = {}

    def f(_, __, name, link_type, link):
        joint_params = link.joint_params
        if link_type in ["rx", "ry", "rz"]:
            # second character of the link type names the axis
            joint_axes = id_to_axis[link_type[1]]
        elif link_type == "rr":
            joint_axes = joint_params["rr"]["joint_axes"]
        elif link_type.startswith("rr_imp"):
            joint_axes = joint_params[link_type]["joint_axes"]
        else:
            joint_axes = xaxis
        X[name] = {"joint_axes": joint_axes}

    sys.scan(f, "lll", sys.link_names, sys.link_types, sys.links)
    # `jax.tree_util.tree_map` instead of the deprecated `jax.tree_map` alias,
    # which is no longer available in recent JAX releases
    X = jax.tree_util.tree_map(lambda arr: jnp.repeat(arr[None], N, axis=0), X)
    return X
|
469
|
+
|
470
|
+
|
471
|
+
def _moving_average(arr: jax.Array, window: int) -> jax.Array:
    """Centered moving average along the leading axis.

    The array is padded with its first and last value so that the output has
    the same length as the input. `window` must be odd and larger than one.
    """
    assert window % 2 == 1
    assert window > 1, "Window size of 1 would be a no-op"

    n = len(arr)
    half_window = (window - 1) // 2

    # edge-padded copy of `arr` with `half_window` extra entries on both sides
    padded = jnp.zeros((n + window - 1,) + arr.shape[1:])
    padded = padded.at[half_window : (n + half_window)].set(arr)
    padded = padded.at[:half_window].set(arr[0])
    padded = padded.at[-half_window:].set(arr[-1])

    # sum the `window` shifted views, from the most advanced shift downwards
    total = jnp.zeros((n,) + arr.shape[1:])
    for offset in range(window - 1, -1, -1):
        total += padded[offset : offset + n]

    return total / window
|
486
|
+
|
487
|
+
|
488
|
+
_quasi_physical_sys_str = r"""
|
489
|
+
<x_xy>
|
490
|
+
<options gravity="0 0 0"/>
|
491
|
+
<worldbody>
|
492
|
+
<body name="IMU" joint="p3d" damping="0.1 0.1 0.1" spring_stiff="3 3 3">
|
493
|
+
<geom type="box" mass="0.002" dim="0.01 0.01 0.01"/>
|
494
|
+
</body>
|
495
|
+
</worldbody>
|
496
|
+
</x_xy>
|
497
|
+
"""
|
498
|
+
|
499
|
+
|
500
|
+
def _quasi_physical_simulation_beautiful(
    xs: base.Transform, dt: float
) -> base.Transform:
    """Replace `xs.pos` by the simulated position of the system described in
    `_quasi_physical_sys_str`, whose spring zeropoint is set to `xs.pos` at
    every timestep. The system uses timestep `dt` and starts at `xs.pos[0]`."""
    imu_sys = io.load_sys_from_str(_quasi_physical_sys_str).replace(dt=dt)

    def advance(state: base.State, x):
        # the spring pulls towards the position of the current transform
        pulled_sys = imu_sys.replace(link_spring_zeropoint=x.pos)
        state = algorithms.step(pulled_sys, state)
        return state, state.q

    initial_state = base.State.create(imu_sys, q=xs.pos[0])
    _, simulated_pos = jax.lax.scan(advance, initial_state, xs)
    return xs.replace(pos=simulated_pos)
|
512
|
+
|
513
|
+
|
514
|
+
# Damping and stiffness of the spring-damper used by `_quasi_physical_simulation`.
_constants = dict(
    qp_damp=35.0,
    qp_stif=625.0,
)
|
518
|
+
|
519
|
+
|
520
|
+
def _quasi_physical_simulation(xs: base.Transform, dt: float) -> base.Transform:
    """Replace `xs.pos` by the position of a unit mass that is attached to the
    original trajectory through a spring-damper.

    Damping and stiffness are read from `_constants`. The mass starts at
    `xs.pos[0]` with zero velocity and is integrated with step `dt`.
    """
    mass = 1.0
    damping = _constants["qp_damp"]
    stiffness = _constants["qp_stif"]

    def advance(carry, target):
        pos, vel = carry
        target_pos, target_vel = target
        damper_force = damping * (target_vel - vel)
        spring_force = stiffness * (target_pos - pos)
        acc = (damper_force + spring_force) / mass
        vel += dt * acc
        # semi-implicit, so use already next velocity
        pos += dt * vel
        return (pos, vel), pos

    rest_vel = jnp.zeros_like(xs.pos[0])
    # finite-difference velocity of the trajectory; the first entry is zero
    target_vels = jnp.vstack((rest_vel, jnp.diff(xs.pos, axis=0) / dt))

    initial = (xs.pos[0], rest_vel)
    _, simulated_pos = jax.lax.scan(advance, initial, (xs.pos, target_vels))
    return xs.replace(pos=simulated_pos)
|
540
|
+
|
541
|
+
|
542
|
+
def _butterworth(
    signal: jax.Array,
    f_sampling: float,
    f_cutoff: float,
    method: str = "forward_backward",
) -> jax.Array:
    """2nd-order Butterworth low-pass filter along the leading axis.

    https://stackoverflow.com/questions/20924868/calculate-coefficients-of-2nd-order
    -butterworth-low-pass-filter

    Args:
        signal: Signal to filter, time along the leading axis.
        f_sampling: Sampling frequency of `signal`.
        f_cutoff: Cutoff frequency, in the same unit as `f_sampling`.
        method: "forward" filters in time order, "backward" filters the
            time-reversed signal, "forward_backward" applies both in sequence.

    Returns:
        The filtered signal with the same length as `signal`. Signals with
        fewer than three samples are returned unchanged because the recursion
        needs two samples of history.

    Raises:
        NotImplementedError: If `method` is not one of the supported values.
    """
    if method not in ("forward_backward", "forward", "backward"):
        raise NotImplementedError(f"Unknown filter method: {method!r}")

    # With fewer than three samples the scan below runs over an empty array and
    # the result would lose its time axis entries; pass such signals through.
    if signal.shape[0] < 3:
        return signal

    if method == "forward_backward":
        signal = _butterworth(signal, f_sampling, f_cutoff, "forward")
        return _butterworth(signal, f_sampling, f_cutoff, "backward")

    if method == "backward":
        signal = jnp.flip(signal, axis=0)

    # filter coefficients
    ff = f_cutoff / f_sampling
    ita = 1.0 / jnp.tan(jnp.pi * ff)
    q = jnp.sqrt(2.0)
    b0 = 1.0 / (1.0 + q * ita + ita**2)
    b1 = 2 * b0
    b2 = b0
    a1 = 2.0 * (ita**2 - 1.0) * b0
    a2 = -(1.0 - q * ita + ita**2) * b0

    def f(carry, x_i):
        x_im1, x_im2, y_im1, y_im2 = carry
        y_i = b0 * x_i + b1 * x_im1 + b2 * x_im2 + a1 * y_im1 + a2 * y_im2
        return (x_i, x_im1, y_i, y_im1), y_i

    # the first two samples seed both the input and the output history
    init = (signal[1], signal[0]) * 2
    signal = jax.lax.scan(f, init, signal[2:])[1]
    # restore the original length by repeating the first filtered value twice
    signal = jnp.concatenate((signal[0:1],) * 2 + (signal,))

    if method == "backward":
        signal = jnp.flip(signal, axis=0)

    return signal
|