imt-ring 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
@@ -0,0 +1,840 @@
1
+ from dataclasses import asdict
2
+ from dataclasses import dataclass
3
+ from dataclasses import field
4
+ from dataclasses import replace
5
+ from typing import Any, Callable, get_type_hints, Optional
6
+
7
+ import jax
8
+ import jax.numpy as jnp
9
+ import tree_utils
10
+
11
+ from ring import algebra
12
+ from ring import base
13
+ from ring import maths
14
+ from ring.algorithms import _random
15
+ from ring.algorithms._random import _to_float
16
+ from ring.algorithms._random import TimeDependentFloat
17
+
18
+
19
@dataclass
class MotionConfig:
    """Parameters of the random motion generator.

    Fields annotated ``float | TimeDependentFloat`` accept either a constant or a
    time-dependent value; `join_motionconfigs` uses exactly this annotation to
    decide which fields may change over time.
    """

    T: float = 60.0  # length of the random motion
    t_min: float = 0.05  # minimum time between two generated angles
    t_max: float | TimeDependentFloat = 0.30  # maximum time between two generated angles

    dang_min: float | TimeDependentFloat = 0.1  # minimum angular velocity in rad/s
    dang_max: float | TimeDependentFloat = 3.0  # maximum angular velocity in rad/s

    # angular velocity bounds of the euler angles used for `free` and `spherical` joints
    dang_min_free_spherical: float | TimeDependentFloat = 0.1
    dang_max_free_spherical: float | TimeDependentFloat = 3.0

    # minimum / maximum allowed actual delta values in radians
    delta_ang_min: float | TimeDependentFloat = 0.0
    delta_ang_max: float | TimeDependentFloat = 2 * jnp.pi
    delta_ang_min_free_spherical: float | TimeDependentFloat = 0.0
    delta_ang_max_free_spherical: float | TimeDependentFloat = 2 * jnp.pi

    dpos_min: float | TimeDependentFloat = 0.001  # speed of translation
    dpos_max: float | TimeDependentFloat = 0.7
    pos_min: float | TimeDependentFloat = -2.5
    pos_max: float | TimeDependentFloat = +2.5

    # used by both `random_angle_*` and `random_pos_*`, and only when randomized
    # interpolation is enabled
    cdf_bins_min: int = 5
    # None means: same as `cdf_bins_min`
    cdf_bins_max: Optional[int] = None

    # flags
    randomized_interpolation_angle: bool = False
    randomized_interpolation_position: bool = False
    interpolation_method: str = "cosine"
    range_of_motion_hinge: bool = True
    range_of_motion_hinge_method: str = "uniform"

    # bounds of the initial joint values
    ang0_min: float = -jnp.pi
    ang0_max: float = jnp.pi
    pos0_min: float = 0.0
    pos0_max: float = 0.0

    # settings of the `cor` (center of rotation) joint
    cor: bool = False
    cor_t_min: float = 0.2
    cor_t_max: float | TimeDependentFloat = 2.0
    cor_dpos_min: float | TimeDependentFloat = 0.00001
    cor_dpos_max: float | TimeDependentFloat = 0.5
    cor_pos_min: float | TimeDependentFloat = -0.4
    cor_pos_max: float | TimeDependentFloat = 0.4

    def is_feasible(self) -> bool:
        """Return whether the velocity bounds are compatible with the delta bounds."""
        return _is_feasible_config1(self)

    def to_nomotion_config(self) -> "MotionConfig":
        """Return a copy with velocity bounds and minimum delta angles set to 0.0.

        The zeroed fields are the hinge and free/spherical angular velocity bounds,
        the minimum delta angles, and the translational velocity bounds. All other
        fields, including the `cor_*` bounds, are copied unchanged. The result is
        asserted to be feasible.
        """
        zeroed_fields = (
            "dang_min",
            "dang_max",
            "delta_ang_min",
            "dang_min_free_spherical",
            "dang_max_free_spherical",
            "delta_ang_min_free_spherical",
            "dpos_min",
            "dpos_max",
        )
        values = asdict(self)
        values.update(dict.fromkeys(zeroed_fields, 0.0))
        still_config = MotionConfig(**values)
        assert still_config.is_feasible()
        return still_config
90
+
91
+
92
def _is_feasible_config1(c: MotionConfig) -> bool:
    """Check that angular velocity bounds are reachable given the delta bounds.

    The check runs separately for the hinge parameters and for the
    free/spherical parameters. A group is infeasible when either of these holds:
      - the largest allowed delta, covered in the shortest time `t_min`, is still
        slower than the minimum velocity;
      - the smallest allowed delta, covered in the longest time `t_max`, is still
        faster than the maximum velocity.

    Values are converted with `_to_float(value, 0.0)`, presumably evaluating
    time-dependent values at t=0 (TODO confirm).
    """
    t_min = c.t_min
    t_max = _to_float(c.t_max, 0.0)

    def _group_is_feasible(dx_min, dx_max, deltax_min, deltax_max) -> bool:
        dx_min, dx_max, deltax_min, deltax_max = (
            _to_float(value, 0.0)
            for value in (dx_min, dx_max, deltax_min, deltax_max)
        )
        if (deltax_max / t_min) < dx_min:
            return False
        return not ((deltax_min / t_max) > dx_max)

    hinge_group = (c.dang_min, c.dang_max, c.delta_ang_min, c.delta_ang_max)
    free_spherical_group = (
        c.dang_min_free_spherical,
        c.dang_max_free_spherical,
        c.delta_ang_min_free_spherical,
        c.delta_ang_max_free_spherical,
    )
    # evaluate every group (no short-circuit), then combine
    results = [
        _group_is_feasible(*group) for group in (hinge_group, free_spherical_group)
    ]
    return all(results)
116
+
117
+
118
def _find_interval(t: jax.Array, boundaries: jax.Array):
    """Count how many entries of `boundaries` are less than or equal to `t`.

    For ascending `boundaries`, this is the index of the interval that `t` falls
    into. Index 0 means `t` lies before the first boundary.

    Args:
        t: Scalar float (e.g. time)
        boundaries: 1D array of floats

    Example: (from `test_jcalc.py`)
        >> _find_interval(1.5, jnp.array([0.0, 1.0, 2.0])) -> 2
        >> _find_interval(0.5, jnp.array([0.0])) -> 1
        >> _find_interval(-0.5, jnp.array([0.0])) -> 0
    """
    assert boundaries.ndim == 1
    reached = jnp.where(t >= boundaries, 1, 0)
    return jnp.sum(reached)
137
+
138
+
139
def join_motionconfigs(
    configs: list[MotionConfig], boundaries: list[float]
) -> MotionConfig:
    """Join several `MotionConfig`s into one whose fields switch over time.

    `configs[0]` is active for `t < boundaries[0]`, and `configs[i]` for
    `boundaries[i-1] <= t < boundaries[i]`. The last config is active from the
    last boundary onward.

    Fields annotated ``float | TimeDependentFloat`` are replaced by a function
    ``t -> value`` that selects the value of the active config. All other fields
    must have one identical value across all `configs`.

    Args:
        configs: the configs to join; must contain one more entry than `boundaries`.
        boundaries: switching times in ascending order.

    Returns:
        A copy of `configs[0]` with every time-dependent field replaced by the
        switching function described above.

    Raises:
        AssertionError: on a length mismatch, on unsorted `boundaries`, or when a
            time-independent field differs between configs.
    """
    assert len(configs) == (
        len(boundaries) + 1
    ), "length of `boundaries` should be one less than length of `configs`"
    # `_find_interval` counts the boundaries that were passed, so the boundaries
    # must be ascending; otherwise the wrong config would be silently selected
    assert list(boundaries) == sorted(
        boundaries
    ), "`boundaries` must be sorted in ascending order"
    boundaries_arr = jnp.array(boundaries, dtype=float)

    def _switching_value(name: str):
        # one value per config, indexed by the interval `t` falls into
        options = jnp.array([getattr(c, name) for c in configs])

        def value_at(t):
            return jax.lax.dynamic_index_in_dim(
                options, _find_interval(t, boundaries_arr), keepdims=False
            )

        return value_at

    hints = get_type_hints(MotionConfig)
    time_dependent_type = float | TimeDependentFloat
    field_names = list(MotionConfig().__dict__)
    is_time_dependent = {
        name: hints[name] == time_dependent_type for name in field_names
    }
    time_dependent_fields = [n for n in field_names if is_time_dependent[n]]
    time_independent_fields = [n for n in field_names if not is_time_dependent[n]]

    for name in time_independent_fields:
        field_values = {getattr(config, name) for config in configs}
        assert (
            len(field_values) == 1
        ), f"MotionConfig.{name}={field_values}. Should be one unique value.."

    changes = {name: _switching_value(name) for name in time_dependent_fields}
    return replace(configs[0], **changes)
171
+
172
+
173
# Random-motion draw function of a joint:
#   (config, key_t, key_value, dt, joint_params) -> trajectory of q
DRAW_FN = Callable[
    [MotionConfig, jax.random.PRNGKey, jax.random.PRNGKey, float, jax.Array], jax.Array
]

# Proportional control term: (q, q_ref) -> qdd
#   shapes: (q_size,), (q_size,) -> (qd_size,)
P_CONTROL_TERM = Callable[[jax.Array, jax.Array], jax.Array]

# Builds the velocity reference trajectory from the reference trajectory q.
# PD control needs both; PD control in turn is only needed when the simulation
# is dynamic rather than purely kinematic.
#   (qs, dt) -> dqs ; shapes: (N, q_size), (1,) -> (N, qd_size)
QD_FROM_Q = Callable[[jax.Array, jax.Array], jax.Array]

# Maps an unconstrained coordinate vector in [-inf, inf] onto the feasible joint
# value range. Used by ring.algorithms.inverse_kinematics_endeffector. Examples in
# this module: hinge joints use `maths.wrap_to_pi`; spherical joints normalize the
# quaternion part.
#   (q_size,) -> (q_size,)
COORDINATE_VECTOR_TO_Q = Callable[[jax.Array], jax.Array]

# Used only by `sim2real.project_xs`. It receives a transform and the joint params
# and returns the transform projected onto the subspace the joint can reach.
PROJECT_TRANSFORM_TO_FEASIBLE = Callable[
    [base.Transform, tree_utils.PyTree], base.Transform
]

# Used by ring.System.from_xml and by ring.RCMG: (key) -> Pytree.
# If a joint does not provide it (None), the joint gets no specific parameters
# and simply receives the defaults, i.e. joint_params['default'].
INIT_JOINT_PARAMS = Callable[[jax.Array], tree_utils.PyTree]

# Inverse kinematics: (transform2_p_to_i, joint_params) -> (q_size,)
INV_KIN = Callable[[base.Transform, tree_utils.PyTree], jax.Array]
221
+
222
+
223
@dataclass
class JointModel:
    """The set of callables that define one joint type.

    Instances are stored per joint-type name in the module registry; see
    `register_new_joint_type`.
    """

    # (q, joint_params) -> base.Transform
    transform: Callable[[jax.Array, jax.Array], base.Transform]
    # one entry per velocity degree of freedom, i.e. len(motion) == len(qd);
    # an entry is either a fixed base.Motion or a callable joint_params -> base.Motion
    motion: list[base.Motion | Callable[[jax.Array], base.Motion]] = field(
        default_factory=list
    )
    # (config, key_t, key_value, dt, joint_params) -> jax.Array
    rcmg_draw_fn: Optional[DRAW_FN] = None

    # only used by `pd_control`
    p_control_term: Optional[P_CONTROL_TERM] = None
    qd_from_q: Optional[QD_FROM_Q] = None

    # used by `inverse_kinematics_endeffector` and System.coordinate_vector_to_q
    coordinate_vector_to_q: Optional[COORDINATE_VECTOR_TO_Q] = None

    # only used by `inverse_kinematics`
    inv_kin: Optional[INV_KIN] = None

    init_joint_params: Optional[INIT_JOINT_PARAMS] = None

    utilities: Optional[dict[str, Any]] = field(default_factory=dict)
250
+
251
+
252
def _free_transform(q, _):
    """Free joint: the first 4 entries of q are the rotation (quaternion) and the rest is the position."""
    position = q[4:]
    quaternion = q[:4]
    return base.Transform(position, quaternion)
255
+
256
+
257
def _free_2d_transform(q, _):
    """Planar free joint: q = [angle about x, y position, z position]; the x position is fixed at 0."""
    rotation = maths.quat_rot_axis(maths.x_unit_vector, q[0])
    translation = jnp.concatenate((jnp.array([0.0]), q[1:]))
    return base.Transform(translation, rotation)
262
+
263
+
264
def _rxyz_transform(q, _, axis):
    """Revolute joint: rotate about `axis` by the (squeezed) angle q."""
    angle = jnp.squeeze(q)
    return base.Transform.create(rot=maths.quat_rot_axis(axis, angle))
268
+
269
+
270
def _pxyz_transform(q, _, direction):
    """Prismatic joint: translate along `direction` scaled by q."""
    return base.Transform.create(pos=direction * q)
273
+
274
+
275
def _frozen_transform(_, __):
    """Frozen joint: ignores q and joint params and returns `base.Transform.zero()`."""
    constant = base.Transform.zero()
    return constant
277
+
278
+
279
def _spherical_transform(q, _):
    """Spherical joint: q is used directly as the rotation quaternion."""
    quaternion = q
    return base.Transform.create(rot=quaternion)
281
+
282
+
283
def _saddle_transform(q, _):
    """Saddle joint: two angles q[0], q[1] form euler angles [0, q[0], q[1]]."""
    euler = jnp.array([0.0, q[0], q[1]])
    return base.Transform.create(rot=maths.euler_to_quat(euler))
286
+
287
+
288
def _p3d_transform(q, _):
    """3D prismatic joint: q is used directly as the translation."""
    translation = q
    return base.Transform.create(pos=translation)
290
+
291
+
292
def _cor_transform(q, _):
    """Center-of-rotation joint: a free joint (q[:7]) composed with a 3D offset (q[7:]).

    The result is `algebra.transform_mul(offset, free)`.
    """
    free_part = _free_transform(q[:7], _)
    offset_part = _p3d_transform(q[7:], _)
    return algebra.transform_mul(offset_part, free_part)
296
+
297
+
298
# Unit motion subspaces: rotation about x/y/z (mr*) and translation along x/y/z (mp*).
mrx, mry, mrz = (base.Motion.create(ang=unit) for unit in jnp.eye(3))
mpx, mpy, mpz = (base.Motion.create(vel=unit) for unit in jnp.eye(3))
304
+
305
+
306
def _draw_rxyz(
    config: MotionConfig,
    key_t: jax.random.PRNGKey,
    key_value: jax.random.PRNGKey,
    dt: float,
    _: jax.Array,
    # TODO, delete these args and pass a modified `config` with `replace` instead
    enable_range_of_motion: bool = True,
    free_spherical: bool = False,
) -> jax.Array:
    """Draw a random angle trajectory for one rotational degree of freedom.

    The initial angle is drawn uniformly from [ang0_min, ang0_max] and wrapped with
    `maths.wrap_to_pi`. The trajectory itself comes from
    `_random.random_angle_over_time`.

    Args:
        config: motion parameters.
        key_t: PRNG key forwarded to `random_angle_over_time`.
        key_value: PRNG key; split once for the initial angle, and the remainder is
            forwarded.
        dt: sampling interval.
        _: joint parameters (unused).
        enable_range_of_motion: if False, range-of-motion handling is off regardless
            of `config.range_of_motion_hinge`.
        free_spherical: use the `*_free_spherical` velocity/delta bounds instead of
            the hinge bounds.
    """
    key_value, consume = jax.random.split(key_value)
    # `random_angle_over_time` always returns wrapped angles, so allowing an
    # unwrapped initial value would be inconsistent
    start_angle = maths.wrap_to_pi(
        jax.random.uniform(consume, minval=config.ang0_min, maxval=config.ang0_max)
    )

    if free_spherical:
        velocity_and_delta_bounds = (
            config.dang_min_free_spherical,
            config.dang_max_free_spherical,
            config.delta_ang_min_free_spherical,
            config.delta_ang_max_free_spherical,
        )
    else:
        velocity_and_delta_bounds = (
            config.dang_min,
            config.dang_max,
            config.delta_ang_min,
            config.delta_ang_max,
        )
    range_of_motion = config.range_of_motion_hinge if enable_range_of_motion else False
    # only used by the `delta_ang_min/max` logic
    max_iter = 5

    return _random.random_angle_over_time(
        key_t,
        key_value,
        start_angle,
        *velocity_and_delta_bounds,
        config.t_min,
        config.t_max,
        config.T,
        dt,
        max_iter,
        config.randomized_interpolation_angle,
        range_of_motion,
        config.range_of_motion_hinge_method,
        config.cdf_bins_min,
        config.cdf_bins_max,
        config.interpolation_method,
    )
343
+
344
+
345
def _draw_pxyz(
    config: MotionConfig,
    _: jax.random.PRNGKey,
    key_value: jax.random.PRNGKey,
    dt: float,
    __: jax.Array,
    cor: bool = False,
) -> jax.Array:
    """Draw a random position trajectory for one translational degree of freedom.

    The initial position is drawn uniformly from [pos0_min, pos0_max]. With
    `cor=True`, the `cor_*` position, velocity and timing bounds are used instead
    of the regular ones. The time key (`_`) and the joint parameters (`__`) are
    unused.
    """
    key_value, consume = jax.random.split(key_value)
    start_pos = jax.random.uniform(
        consume, minval=config.pos0_min, maxval=config.pos0_max
    )

    if cor:
        bounds = (
            config.cor_pos_min,
            config.cor_pos_max,
            config.cor_dpos_min,
            config.cor_dpos_max,
            config.cor_t_min,
            config.cor_t_max,
        )
    else:
        bounds = (
            config.pos_min,
            config.pos_max,
            config.dpos_min,
            config.dpos_max,
            config.t_min,
            config.t_max,
        )
    max_iter = 100

    return _random.random_position_over_time(
        key_value,
        start_pos,
        *bounds,
        config.T,
        dt,
        max_iter,
        config.randomized_interpolation_position,
        config.cdf_bins_min,
        config.cdf_bins_max,
        config.interpolation_method,
    )
373
+
374
+
375
def _draw_spherical(
    config: MotionConfig,
    key_t: jax.random.PRNGKey,
    key_value: jax.random.PRNGKey,
    dt: float,
    _: jax.Array,
) -> jax.Array:
    """Draw a random quaternion trajectory for a spherical joint.

    Three independent angle trajectories are drawn with the free/spherical bounds
    and range of motion disabled. They are converted to quaternions with
    `maths.quat_euler`.
    """
    # NOTE: drawing three euler angles and building a quaternion from them is not
    # ideal, but no better approach is known.
    def one_euler_angle(k_t, k_value):
        return _draw_rxyz(
            config,
            k_t,
            k_value,
            dt,
            None,
            enable_range_of_motion=False,
            free_spherical=True,
        )

    keys_t = jax.random.split(key_t, 3)
    keys_value = jax.random.split(key_value, 3)
    # vmap yields one row per euler angle; transpose to one row per timestep
    euler_angles = jax.vmap(one_euler_angle)(keys_t, keys_value).T
    return maths.quat_euler(euler_angles)
400
+
401
+
402
def _draw_saddle(
    config: MotionConfig,
    key_t: jax.random.PRNGKey,
    key_value: jax.random.PRNGKey,
    dt: float,
    _: jax.Array,
) -> jax.Array:
    """Draw two independent angle trajectories (one column each) for a saddle joint.

    The hinge velocity/delta bounds are used and range of motion is disabled.
    """
    def one_angle(k_t, k_value):
        return _draw_rxyz(
            config,
            k_t,
            k_value,
            dt,
            None,
            enable_range_of_motion=False,
            free_spherical=False,
        )

    keys_t = jax.random.split(key_t)
    keys_value = jax.random.split(key_value)
    return jax.vmap(one_angle)(keys_t, keys_value).T
424
+
425
+
426
def _draw_p3d_and_cor(
    config: MotionConfig,
    _: jax.random.PRNGKey,
    key_value: jax.random.PRNGKey,
    dt: float,
    __: jax.Array,
    cor: bool,
) -> jax.Array:
    """Draw three independent position trajectories (x, y, z columns).

    `cor` selects the regular bounds or the `cor_*` bounds of `_draw_pxyz`.
    """
    def one_axis(k):
        return _draw_pxyz(config, None, k, dt, None, cor)

    per_axis = jax.vmap(one_axis)(jax.random.split(key_value, 3))
    return per_axis.T
438
+
439
+
440
def _draw_p3d(
    config: MotionConfig,
    _: jax.random.PRNGKey,
    key_value: jax.random.PRNGKey,
    dt: float,
    __: jax.Array,
) -> jax.Array:
    """Draw a 3D position trajectory using the regular (non-`cor`) bounds."""
    time_key = _
    return _draw_p3d_and_cor(config, time_key, key_value, dt, None, cor=False)
448
+
449
+
450
def _draw_cor(
    config: MotionConfig,
    _: jax.random.PRNGKey,
    key_value: jax.random.PRNGKey,
    dt: float,
    __: jax.Array,
) -> jax.Array:
    """Draw a `cor` joint trajectory: free-joint columns followed by a cor offset.

    Despite its name, `_` (the time key) is used: it is forwarded to `_draw_free`.
    """
    free_key, offset_key = jax.random.split(key_value)
    free_part = _draw_free(config, _, free_key, dt, None)
    offset_part = _draw_p3d_and_cor(config, _, offset_key, dt, None, cor=True)
    return jnp.concatenate((free_part, offset_part), axis=1)
461
+
462
+
463
def _draw_free(
    config: MotionConfig,
    key_t: jax.random.PRNGKey,
    key_value: jax.random.PRNGKey,
    dt: float,
    __: jax.Array,
) -> jax.Array:
    """Draw a free-joint trajectory: quaternion columns followed by position columns."""
    rot_key, pos_key = jax.random.split(key_value)
    quaternions = _draw_spherical(config, key_t, rot_key, dt, None)
    positions = _draw_p3d(config, None, pos_key, dt, None)
    return jnp.concatenate((quaternions, positions), axis=1)
474
+
475
+
476
def _draw_free_2d(
    config: MotionConfig,
    key_t: jax.random.PRNGKey,
    key_value: jax.random.PRNGKey,
    dt: float,
    __: jax.Array,
) -> jax.Array:
    """Draw a `free_2d` trajectory with columns [angle about x, pos_y, pos_z].

    The angle uses the free/spherical bounds with range of motion disabled. The
    two positions are the first two columns of a 3D position draw.
    """
    angle_key, pos_key = jax.random.split(key_value)
    angle_traj = _draw_rxyz(
        config,
        key_t,
        angle_key,
        dt,
        None,
        enable_range_of_motion=False,
        free_spherical=True,
    )
    pos_traj = _draw_p3d(config, None, pos_key, dt, None)
    return jnp.concatenate((angle_traj[:, None], pos_traj[:, :2]), axis=1)
495
+
496
+
497
def _draw_frozen(config: MotionConfig, _, __, dt: float, ___) -> jax.Array:
    """A frozen joint has no coordinates: return an (int(T / dt), 0) array."""
    n_timesteps = int(config.T / dt)
    return jnp.zeros((n_timesteps, 0))
500
+
501
+
502
def qrel(q1, q2):
    """Relative rotation between two quaternions: ``quat_mul(q1, quat_inv(q2))``."""
    # a named function instead of a lambda assignment (PEP 8, E731)
    return maths.quat_mul(q1, maths.quat_inv(q2))
503
+
504
+
505
def _qd_from_q_quaternion(qs, dt):
    """Angular velocity of a quaternion trajectory via central differences.

    The rotation between samples i-1 and i+1 is converted to axis-angle form and
    divided by 2*dt. The first and last rows are set to zero.
    """
    relative = qrel(qs[2:], qs[:-2])
    axis, angle = maths.quat_to_rot_axis(relative)
    # axis has shape (n_timesteps, 3) and angle has shape (n_timesteps,), so a
    # trailing singleton dimension is needed to broadcast
    omega = axis * angle[:, None] / (2 * dt)
    edge = jnp.zeros((3,))
    return jnp.vstack((edge, omega, edge))
512
+
513
+
514
def _qd_from_q_cartesian(qs, dt):
    """Velocity of a (timesteps, width) trajectory via central differences.

    Row i is (qs[i+1] - qs[i-1]) / (2*dt). The first and last rows are zero.
    """
    edge = jnp.zeros_like(qs[0])
    central = (qs[2:] - qs[:-2]) / (2 * dt)
    return jnp.vstack((edge, central, edge))
519
+
520
+
521
def _p_control_quaternion(q, q_ref):
    """Rotation error from q to q_ref as an axis-angle vector (axis * angle)."""
    error_rotation = qrel(q_ref, q)
    axis, angle = maths.quat_to_rot_axis(error_rotation)
    return axis * angle
524
+
525
+
526
def _p_control_term_rxyz(q, q_ref):
    """Wrapped angle error q_ref - q for revolute joints."""
    # q_ref comes from rcmg and is therefore already wrapped
    # TODO: state.q is currently not wrapped. Change that?
    q_wrapped = maths.wrap_to_pi(q)
    return maths.wrap_to_pi(q_ref - q_wrapped)
530
+
531
+
532
def _p_control_term_pxyz_p3d(q, q_ref):
    """Plain difference q_ref - q for translational joints."""
    error = q_ref - q
    return error
534
+
535
+
536
def _p_control_term_frozen(q, q_ref):
    """A frozen joint has no degrees of freedom, so its control term is empty."""
    del q, q_ref  # unused
    return jnp.array([])
538
+
539
+
540
def _p_control_term_spherical(q, q_ref):
    """A spherical joint's q is a quaternion; use the quaternion rotation error."""
    quaternion_error = _p_control_quaternion(q, q_ref)
    return quaternion_error
542
+
543
+
544
def _p_control_term_free(q, q_ref):
    """Rotation error of q[:4] (quaternion) followed by the difference of q[4:]."""
    rot_error = _p_control_quaternion(q[:4], q_ref[:4])
    pos_error = q_ref[4:] - q[4:]
    return jnp.concatenate((rot_error, pos_error))
551
+
552
+
553
def _p_control_term_free_2d(q, q_ref):
    """Wrapped angle error of q[:1] followed by the difference of the positions q[1:]."""
    angle_error = _p_control_term_rxyz(q[:1], q_ref[:1])
    pos_error = q_ref[1:] - q[1:]
    return jnp.concatenate((angle_error, pos_error))
560
+
561
+
562
def _p_control_term_cor(q, q_ref):
    """Same as the free joint: quaternion error on q[:4], difference on all of q[4:]."""
    free_term = _p_control_term_free(q, q_ref)
    return free_term
564
+
565
+
566
def _qd_from_q_free(qs, dt):
    """Velocity of a free-joint trajectory: angular part from qs[:, :4], linear from qs[:, 4:]."""
    rot_part, pos_part = qs[:, :4], qs[:, 4:]
    return jnp.hstack(
        (_qd_from_q_quaternion(rot_part, dt), _qd_from_q_cartesian(pos_part, dt))
    )
570
+
571
+
572
def _coordinate_vector_to_q_free_spherical_cor(q):
    """Normalize the quaternion part q[:4]; leave the other entries unchanged."""
    unit_quaternion = maths.safe_normalize(q[:4])
    return q.at[:4].set(unit_quaternion)
574
+
575
+
576
def _coordinate_vector_to_q_free_2d(q):
    """Wrap the angle q[0] to [-pi, pi]; leave the positions unchanged."""
    wrapped_angle = maths.wrap_to_pi(q[0])
    return q.at[0].set(wrapped_angle)
578
+
579
+
580
# Axis name -> length-1 slice into a 3-vector ("x" -> [0:1], "y" -> [1:2], "z" -> [2:3]).
_str2idx = {axis: slice(i, i + 1) for i, axis in enumerate("xyz")}
581
+
582
+
583
def _inv_kin_rxyz_factory(xyz: str):
    """Build the inverse-kinematics function of a revolute joint about axis `xyz`.

    The returned function maps ``(transform, joint_params)`` to a length-1 array
    holding the joint angle read from ``transform.rot``.
    """
    k = maths.unit_vectors(xyz)

    def _inv_kin_rxyz(x: base.Transform, _) -> jax.Array:
        # TODO
        # NOTE: CONVENTION
        # This closed form is fast but suffers from a sign-convention issue. That
        # is why the angle is negated to match the rest of this library (TODO:
        # confirm against `maths.quat_rot_axis`).
        # An equivalent but much slower alternative avoids the convention issue.
        # It projects onto `k` with `maths.quat_project(q, k)[0]`, converts with
        # `maths.quat_to_rot_axis`, and flips the angle's sign when the resulting
        # axis points against `k`. It was previously left here as unreachable
        # code after this return.
        q = x.rot
        angle = 2 * jnp.arctan2(q[1:] @ k, q[0])
        return -angle[None]

    return _inv_kin_rxyz
599
+
600
+
601
def _inv_kin_pxyz_factory(xyz: str):
    """Build the inverse-kinematics function of a prismatic joint along axis `xyz`.

    The returned function reads the matching component of ``transform.pos`` as a
    length-1 slice.
    """
    component = _str2idx[xyz]

    def _inv_kin_pxyz(x: base.Transform, _) -> jax.Array:
        position = x.pos
        return position[component]

    return _inv_kin_pxyz
608
+
609
+
610
def _inv_kin_free_2d(x: base.Transform, _) -> jax.Array:
    """Inverse kinematics of the `free_2d` joint.

    Returns [angle about x, pos_y, pos_z]: the rx inverse kinematics of `x`,
    followed by ``x.pos[1:]``.
    """
    inv_kin_rx = _inv_kin_rxyz_factory("x")
    # the rx inverse-kinematics function takes (transform, joint_params); calling
    # it with the transform alone raises a TypeError
    angle_x = inv_kin_rx(x, _)
    return jnp.concatenate((angle_x, x.pos[1:]))
613
+
614
+
615
# Registry of built-in joint types; extended at runtime by `register_new_joint_type`.
_joint_types = {
    "free": JointModel(
        transform=_free_transform,
        motion=[mrx, mry, mrz, mpx, mpy, mpz],
        rcmg_draw_fn=_draw_free,
        p_control_term=_p_control_term_free,
        qd_from_q=_qd_from_q_free,
        coordinate_vector_to_q=_coordinate_vector_to_q_free_spherical_cor,
        inv_kin=lambda x, _: jnp.concatenate((x.rot, x.pos)),
    ),
    "free_2d": JointModel(
        transform=_free_2d_transform,
        motion=[mrx, mpy, mpz],
        rcmg_draw_fn=_draw_free_2d,
        p_control_term=_p_control_term_free_2d,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=_coordinate_vector_to_q_free_2d,
        inv_kin=_inv_kin_free_2d,
    ),
    "frozen": JointModel(
        transform=_frozen_transform,
        motion=[],
        rcmg_draw_fn=_draw_frozen,
        p_control_term=_p_control_term_frozen,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=lambda q: q,
        inv_kin=lambda x, _: jnp.array([]),
    ),
    "spherical": JointModel(
        transform=_spherical_transform,
        motion=[mrx, mry, mrz],
        rcmg_draw_fn=_draw_spherical,
        p_control_term=_p_control_term_spherical,
        qd_from_q=_qd_from_q_quaternion,
        coordinate_vector_to_q=_coordinate_vector_to_q_free_spherical_cor,
        inv_kin=lambda x, _: x.rot,
    ),
    "p3d": JointModel(
        transform=_p3d_transform,
        motion=[mpx, mpy, mpz],
        rcmg_draw_fn=_draw_p3d,
        p_control_term=_p_control_term_pxyz_p3d,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=lambda q: q,
        inv_kin=lambda x, _: x.pos,
    ),
    "cor": JointModel(
        transform=_cor_transform,
        motion=[mrx, mry, mrz, mpx, mpy, mpz, mpx, mpy, mpz],
        rcmg_draw_fn=_draw_cor,
        p_control_term=_p_control_term_cor,
        qd_from_q=_qd_from_q_free,
        coordinate_vector_to_q=_coordinate_vector_to_q_free_spherical_cor,
    ),
    "rx": JointModel(
        transform=lambda q, _: _rxyz_transform(q, _, jnp.array([1.0, 0, 0])),
        motion=[mrx],
        rcmg_draw_fn=_draw_rxyz,
        p_control_term=_p_control_term_rxyz,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=maths.wrap_to_pi,
        inv_kin=_inv_kin_rxyz_factory("x"),
    ),
    "ry": JointModel(
        transform=lambda q, _: _rxyz_transform(q, _, jnp.array([0.0, 1, 0])),
        motion=[mry],
        rcmg_draw_fn=_draw_rxyz,
        p_control_term=_p_control_term_rxyz,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=maths.wrap_to_pi,
        inv_kin=_inv_kin_rxyz_factory("y"),
    ),
    "rz": JointModel(
        transform=lambda q, _: _rxyz_transform(q, _, jnp.array([0.0, 0, 1])),
        motion=[mrz],
        rcmg_draw_fn=_draw_rxyz,
        p_control_term=_p_control_term_rxyz,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=maths.wrap_to_pi,
        inv_kin=_inv_kin_rxyz_factory("z"),
    ),
    "px": JointModel(
        transform=lambda q, _: _pxyz_transform(q, _, jnp.array([1.0, 0, 0])),
        motion=[mpx],
        rcmg_draw_fn=_draw_pxyz,
        p_control_term=_p_control_term_pxyz_p3d,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=lambda q: q,
        inv_kin=_inv_kin_pxyz_factory("x"),
    ),
    "py": JointModel(
        transform=lambda q, _: _pxyz_transform(q, _, jnp.array([0.0, 1, 0])),
        motion=[mpy],
        rcmg_draw_fn=_draw_pxyz,
        p_control_term=_p_control_term_pxyz_p3d,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=lambda q: q,
        inv_kin=_inv_kin_pxyz_factory("y"),
    ),
    "pz": JointModel(
        transform=lambda q, _: _pxyz_transform(q, _, jnp.array([0.0, 0, 1])),
        motion=[mpz],
        rcmg_draw_fn=_draw_pxyz,
        p_control_term=_p_control_term_pxyz_p3d,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=lambda q: q,
        inv_kin=_inv_kin_pxyz_factory("z"),
    ),
    "saddle": JointModel(
        transform=_saddle_transform,
        motion=[mry, mrz],
        rcmg_draw_fn=_draw_saddle,
        p_control_term=_p_control_term_rxyz,
        qd_from_q=_qd_from_q_cartesian,
        coordinate_vector_to_q=maths.wrap_to_pi,
    ),
}
732
+
733
+
734
def get_joint_model(joint_type: str) -> JointModel:
    """Return the registered `JointModel` of `joint_type` (AssertionError if unknown)."""
    registry = _joint_types
    assert joint_type in registry, f"{joint_type} not in {list(registry.keys())}"
    return registry[joint_type]
739
+
740
+
741
def register_new_joint_type(
    joint_type: str,
    joint_model: JointModel,
    q_width: int,
    qd_width: Optional[int] = None,
    overwrite: bool = False,
):
    """Register a custom joint type so that it can be used like a built-in one.

    Args:
        joint_type: name of the new joint type. "default" is reserved, because it
            is the key of the default joint parameters.
        joint_model: the joint definition; `len(joint_model.motion)` must equal
            `qd_width`.
        q_width: number of entries of the joint's q.
        qd_width: number of entries of the joint's qd; defaults to `q_width`.
        overwrite: allow replacing an already registered `joint_type`.

    Raises:
        AssertionError: if the name is reserved, if the motion length does not
            match `qd_width`, or if `joint_type` exists and `overwrite` is False.
    """
    # this name is used for the default joint parameters
    assert joint_type != "default", "Please use another name."

    if qd_width is None:
        qd_width = q_width

    # Validate before touching any registry. A failing check must not leave an
    # overwritten joint type half-removed.
    assert len(joint_model.motion) == qd_width

    exists = joint_type in _joint_types
    if exists and overwrite:
        for dic in [
            base.Q_WIDTHS,
            base.QD_WIDTHS,
            _joint_types,
        ]:
            # tolerate registries that do not (yet) know this joint type
            dic.pop(joint_type, None)
    else:
        assert (
            not exists
        ), f"joint type `{joint_type}`already exists, use `overwrite=True`"

    _joint_types.update({joint_type: joint_model})
    base.Q_WIDTHS.update({joint_type: q_width})
    base.QD_WIDTHS.update({joint_type: qd_width})
772
+
773
+
774
def _limit_scope_of_joint_params(
    joint_type: str, joint_params: dict[str, tree_utils.PyTree]
) -> tree_utils.PyTree:
    """Select the parameters of `joint_type`, falling back to `joint_params["default"]`."""
    # conditional expression so that "default" is only looked up when needed
    return (
        joint_params[joint_type]
        if joint_type in joint_params
        else joint_params["default"]
    )
781
+
782
+
783
def jcalc_transform(
    joint_type: str, q: jax.Array, joint_params: dict[str, tree_utils.PyTree]
) -> base.Transform:
    """Evaluate the joint transform of `joint_type` at `q` with its scoped parameters."""
    scoped_params = _limit_scope_of_joint_params(joint_type, joint_params)
    model = _joint_types[joint_type]
    return model.transform(q, scoped_params)
788
+
789
+
790
def _to_motion(
    m: base.Motion | Callable[[jax.Array], base.Motion], joint_params: tree_utils.PyTree
) -> base.Motion:
    """Return `m` itself if it is a `base.Motion`, otherwise `m(joint_params)`."""
    return m if isinstance(m, base.Motion) else m(joint_params)
796
+
797
+
798
def jcalc_motion(
    joint_type: str, qd: jax.Array, joint_params: dict[str, tree_utils.PyTree]
) -> base.Motion:
    """Joint velocity as a spatial motion: the sum over DoFs of motion[dof] * qd[dof].

    Callable motion entries are resolved with the joint's scoped parameters.
    """
    joint_params = _limit_scope_of_joint_params(joint_type, joint_params)
    list_motion = _joint_types[joint_type].motion
    m = base.Motion.zero()
    for dof, motion in enumerate(list_motion):
        m += _to_motion(motion, joint_params) * qd[dof]
    return m
807
+
808
+
809
def jcalc_tau(
    joint_type: str, f: base.Force, joint_params: dict[str, tree_utils.PyTree]
) -> jax.Array:
    """Project the force `f` onto every DoF motion of `joint_type` via `algebra.motion_dot`."""
    scoped_params = _limit_scope_of_joint_params(joint_type, joint_params)
    resolved = (
        _to_motion(motion, scoped_params)
        for motion in _joint_types[joint_type].motion
    )
    return jnp.array([algebra.motion_dot(motion, f) for motion in resolved])
817
+
818
+
819
def _init_joint_params(key: jax.Array, sys: base.System) -> base.System:
    """Initialize the joint parameters of the custom joints in `sys`.

    For every joint type in `sys.link_types` whose JointModel defines
    `init_joint_params`, one parameter set is drawn per link (vmapped over
    `sys.num_links()` fresh keys) and stored under the joint-type name. The entry
    "default" is always an empty (n_links, 0) array. Returns `sys` with
    `links.joint_params` replaced.
    """
    init_fns = {}
    # dict.fromkeys keeps first-occurrence order, which fixes the order in which
    # PRNG keys are consumed below
    for typ in dict.fromkeys(sys.link_types):
        init_fn = _joint_types[typ].init_joint_params
        if init_fn is not None:
            init_fns[typ] = init_fn

    joint_params: dict[str, tree_utils.PyTree] = {}
    n_links = sys.num_links()
    for typ, init_fn in init_fns.items():
        split_keys = jax.random.split(key, num=n_links + 1)
        key, per_link_keys = split_keys[0], split_keys[1:]
        joint_params[typ] = jax.vmap(init_fn)(per_link_keys)

    # batched default parameters
    joint_params["default"] = jnp.zeros((n_links, 0))

    return sys.replace(links=sys.links.replace(joint_params=joint_params))