imt-ring 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. imt_ring-1.2.1.dist-info/METADATA +91 -0
  2. imt_ring-1.2.1.dist-info/RECORD +83 -0
  3. imt_ring-1.2.1.dist-info/WHEEL +5 -0
  4. imt_ring-1.2.1.dist-info/top_level.txt +1 -0
  5. ring/__init__.py +63 -0
  6. ring/algebra.py +100 -0
  7. ring/algorithms/__init__.py +45 -0
  8. ring/algorithms/_random.py +403 -0
  9. ring/algorithms/custom_joints/__init__.py +6 -0
  10. ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
  11. ring/algorithms/custom_joints/rr_joint.py +33 -0
  12. ring/algorithms/custom_joints/suntay.py +424 -0
  13. ring/algorithms/dynamics.py +345 -0
  14. ring/algorithms/generator/__init__.py +25 -0
  15. ring/algorithms/generator/base.py +414 -0
  16. ring/algorithms/generator/batch.py +282 -0
  17. ring/algorithms/generator/motion_artifacts.py +222 -0
  18. ring/algorithms/generator/pd_control.py +182 -0
  19. ring/algorithms/generator/randomize.py +119 -0
  20. ring/algorithms/generator/transforms.py +410 -0
  21. ring/algorithms/generator/types.py +36 -0
  22. ring/algorithms/jcalc.py +840 -0
  23. ring/algorithms/kinematics.py +202 -0
  24. ring/algorithms/sensors.py +582 -0
  25. ring/base.py +1046 -0
  26. ring/io/__init__.py +9 -0
  27. ring/io/examples/branched.xml +24 -0
  28. ring/io/examples/exclude/knee_trans_dof.xml +26 -0
  29. ring/io/examples/exclude/standard_sys.xml +106 -0
  30. ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
  31. ring/io/examples/inv_pendulum.xml +14 -0
  32. ring/io/examples/knee_flexible_imus.xml +22 -0
  33. ring/io/examples/spherical_stiff.xml +11 -0
  34. ring/io/examples/symmetric.xml +12 -0
  35. ring/io/examples/test_all_1.xml +39 -0
  36. ring/io/examples/test_all_2.xml +39 -0
  37. ring/io/examples/test_ang0_pos0.xml +9 -0
  38. ring/io/examples/test_control.xml +16 -0
  39. ring/io/examples/test_double_pendulum.xml +14 -0
  40. ring/io/examples/test_free.xml +11 -0
  41. ring/io/examples/test_kinematics.xml +23 -0
  42. ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
  43. ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
  44. ring/io/examples/test_randomize_position.xml +26 -0
  45. ring/io/examples/test_sensors.xml +13 -0
  46. ring/io/examples/test_three_seg_seg2.xml +23 -0
  47. ring/io/examples.py +42 -0
  48. ring/io/test_examples.py +6 -0
  49. ring/io/xml/__init__.py +6 -0
  50. ring/io/xml/abstract.py +300 -0
  51. ring/io/xml/from_xml.py +299 -0
  52. ring/io/xml/test_from_xml.py +56 -0
  53. ring/io/xml/test_to_xml.py +31 -0
  54. ring/io/xml/to_xml.py +94 -0
  55. ring/maths.py +397 -0
  56. ring/ml/__init__.py +33 -0
  57. ring/ml/base.py +292 -0
  58. ring/ml/callbacks.py +434 -0
  59. ring/ml/ml_utils.py +272 -0
  60. ring/ml/optimizer.py +149 -0
  61. ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  62. ring/ml/ringnet.py +279 -0
  63. ring/ml/train.py +318 -0
  64. ring/ml/training_loop.py +131 -0
  65. ring/rendering/__init__.py +2 -0
  66. ring/rendering/base_render.py +271 -0
  67. ring/rendering/mujoco_render.py +222 -0
  68. ring/rendering/vispy_render.py +340 -0
  69. ring/rendering/vispy_visuals.py +290 -0
  70. ring/sim2real/__init__.py +7 -0
  71. ring/sim2real/sim2real.py +288 -0
  72. ring/spatial.py +126 -0
  73. ring/sys_composer/__init__.py +5 -0
  74. ring/sys_composer/delete_sys.py +114 -0
  75. ring/sys_composer/inject_sys.py +110 -0
  76. ring/sys_composer/morph_sys.py +361 -0
  77. ring/utils/__init__.py +21 -0
  78. ring/utils/batchsize.py +51 -0
  79. ring/utils/colab.py +48 -0
  80. ring/utils/hdf5.py +198 -0
  81. ring/utils/normalizer.py +56 -0
  82. ring/utils/path.py +44 -0
  83. ring/utils/utils.py +161 -0
ring/base.py ADDED
@@ -0,0 +1,1046 @@
1
+ from pathlib import Path
2
+ from typing import Any, Callable, Optional, Sequence, Union
3
+
4
+ from flax import struct
5
+ import jax
6
+ from jax.core import Tracer
7
+ import jax.numpy as jnp
8
+ from jax.tree_util import tree_map
9
+ import numpy as np
10
+ import tree_utils as tu
11
+
12
+ import ring
13
+ from ring import maths
14
+ from ring import spatial
15
+
16
# Annotation-only aliases: at runtime all three are simply `jax.Array`.
Scalar = Vector = Quaternion = jax.Array


# A color is either unspecified (None), a color string, an RGB triple or an
# RGBA quadruple.
Color = Optional[
    Union[str, tuple[float, float, float], tuple[float, float, float, float]]
]
22
+
23
+
24
+ class _Base:
25
+ """Base functionality of all spatial datatypes.
26
+ Copied and modified from https://github.com/google/brax/blob/main/brax/v2/base.py
27
+ """
28
+
29
+ def __add__(self, o: Any) -> Any:
30
+ return tree_map(lambda x, y: x + y, self, o)
31
+
32
+ def __sub__(self, o: Any) -> Any:
33
+ return tree_map(lambda x, y: x - y, self, o)
34
+
35
+ def __mul__(self, o: Any) -> Any:
36
+ return tree_map(lambda x: x * o, self)
37
+
38
+ def __neg__(self) -> Any:
39
+ return tree_map(lambda x: -x, self)
40
+
41
+ def __truediv__(self, o: Any) -> Any:
42
+ return tree_map(lambda x: x / o, self)
43
+
44
+ def __getitem__(self, i: int) -> Any:
45
+ return self.take(i)
46
+
47
+ def reshape(self, shape: Sequence[int]) -> Any:
48
+ return tree_map(lambda x: x.reshape(shape), self)
49
+
50
+ def slice(self, beg: int, end: int) -> Any:
51
+ return tree_map(lambda x: x[beg:end], self)
52
+
53
+ def take(self, i, axis=0) -> Any:
54
+ return tree_map(lambda x: jnp.take(x, i, axis=axis), self)
55
+
56
+ def hstack(self, *others: Any) -> Any:
57
+ return tree_map(lambda *x: jnp.hstack(x), self, *others)
58
+
59
+ def vstack(self, *others: Any) -> Any:
60
+ return tree_map(lambda *x: jnp.vstack(x), self, *others)
61
+
62
+ def concatenate(self, *others: Any, axis: int = 0) -> Any:
63
+ return tree_map(lambda *x: jnp.concatenate(x, axis=axis), self, *others)
64
+
65
+ def batch(self, *others, along_existing_first_axis: bool = False) -> Any:
66
+ return tu.tree_batch((self,) + others, along_existing_first_axis, "jax")
67
+
68
+ def index_set(self, idx: Union[jnp.ndarray, Sequence[jnp.ndarray]], o: Any) -> Any:
69
+ return tree_map(lambda x, y: x.at[idx].set(y), self, o)
70
+
71
+ def index_sum(self, idx: Union[jnp.ndarray, Sequence[jnp.ndarray]], o: Any) -> Any:
72
+ return tree_map(lambda x, y: x.at[idx].add(y), self, o)
73
+
74
+ @property
75
+ def T(self):
76
+ return tree_map(lambda x: x.T, self)
77
+
78
+ def flatten(self, num_batch_dims: int = 0) -> jax.Array:
79
+ return tu.batch_concat(self, num_batch_dims)
80
+
81
+ def squeeze(self):
82
+ return tree_map(lambda x: jnp.squeeze(x), self)
83
+
84
+ def squeeze_1d(self):
85
+ return tree_map(lambda x: jnp.atleast_1d(jnp.squeeze(x)), self)
86
+
87
+ def batch_dim(self) -> int:
88
+ return tu.tree_shape(self)
89
+
90
+ def transpose(self, axes: Sequence[int]) -> Any:
91
+ return tree_map(lambda x: jnp.transpose(x, axes), self)
92
+
93
+ def __iter__(self):
94
+ raise NotImplementedError
95
+
96
+ def repeat(self, repeats, axis=0):
97
+ return tree_map(lambda x: jnp.repeat(x, repeats, axis), self)
98
+
99
+ def ndim(self):
100
+ return tu.tree_ndim(self)
101
+
102
+ def shape(self, axis=0) -> int:
103
+ return tu.tree_shape(self, axis)
104
+
105
+ def __len__(self) -> int:
106
+ Bs = tree_map(lambda arr: arr.shape[0], self)
107
+ Bs = set(jax.tree_util.tree_flatten(Bs)[0])
108
+ assert len(Bs) == 1
109
+ return list(Bs)[0]
110
+
111
+
112
@struct.dataclass
class Transform(_Base):
    """Represents the Transformation from Plücker A to Plücker B,
    where B is located relative to A at `pos` in frame A and `rot` is the
    relative quaternion from A to B."""

    pos: Vector
    rot: Quaternion

    @classmethod
    def create(cls, pos=None, rot=None):
        """Build a transform from `pos` and/or `rot`; the missing one is neutral.

        A missing `pos` becomes zeros and a missing `rot` becomes the
        quaternion [1, 0, 0, 0], both with the batch shape (all but the last
        axis) of the argument that was given.

        Raises:
            AssertionError: if neither is given, or their batch shapes differ.
        """
        assert not (pos is None and rot is None), "One must be given."

        if pos is None:
            pos = jnp.zeros(rot.shape[:-1] + (3,))
        if rot is None:
            rot = jnp.tile(jnp.array([1.0, 0.0, 0.0, 0.0]), pos.shape[:-1] + (1,))

        assert pos.shape[:-1] == rot.shape[:-1]

        # `cls` (not a hard-coded class name) so subclasses construct themselves
        return cls(pos, rot)

    @classmethod
    def zero(cls, shape=()) -> "Transform":
        """Returns a zero transform with a batch shape."""
        pos = jnp.zeros(shape + (3,))
        rot = jnp.tile(jnp.array([1.0, 0.0, 0.0, 0.0]), shape + (1,))
        return cls(pos, rot)

    def as_matrix(self) -> jax.Array:
        """Return the matrix form: block-diagonal rotation times `spatial.xlt(pos)`."""
        E = maths.quat_to_3x3(self.rot)
        return spatial.quadrants(aa=E, bb=E) @ spatial.xlt(self.pos)
147
+
148
+
149
@struct.dataclass
class Motion(_Base):
    "Coordinate vector that represents a spatial motion vector in Plücker Coordinates."
    ang: Vector
    vel: Vector

    @classmethod
    def create(cls, ang=None, vel=None):
        """Build a motion vector; the missing component is filled with zeros.

        The zeros take the batch shape (all but the last axis) of the given
        component, so batched inputs yield consistently shaped leaves. For an
        unbatched input this is `(3,)`, as before.
        """
        assert not (ang is None and vel is None), "One must be given."
        if ang is None:
            ang = jnp.zeros(jnp.shape(vel)[:-1] + (3,))
        if vel is None:
            vel = jnp.zeros(jnp.shape(ang)[:-1] + (3,))
        return cls(ang, vel)

    @classmethod
    def zero(cls, shape=()) -> "Motion":
        """Return an all-zero motion vector with batch shape `shape`."""
        ang = jnp.zeros(shape + (3,))
        vel = jnp.zeros(shape + (3,))
        return cls(ang, vel)

    def as_matrix(self):
        """Return the concatenated leaves via `flatten()`."""
        return self.flatten()
172
+
173
+
174
@struct.dataclass
class Force(_Base):
    "Coordinate vector that represents a spatial force vector in Plücker Coordinates."
    ang: Vector
    vel: Vector

    @classmethod
    def create(cls, ang=None, vel=None):
        """Build a force vector; the missing component is filled with zeros.

        The zeros take the batch shape (all but the last axis) of the given
        component, so batched inputs yield consistently shaped leaves. For an
        unbatched input this is `(3,)`, as before.
        """
        assert not (ang is None and vel is None), "One must be given."
        if ang is None:
            ang = jnp.zeros(jnp.shape(vel)[:-1] + (3,))
        if vel is None:
            vel = jnp.zeros(jnp.shape(ang)[:-1] + (3,))
        return cls(ang, vel)

    @classmethod
    def zero(cls, shape=()) -> "Force":
        """Return an all-zero force vector with batch shape `shape`."""
        ang = jnp.zeros(shape + (3,))
        vel = jnp.zeros(shape + (3,))
        return cls(ang, vel)

    def as_matrix(self):
        """Return the concatenated leaves via `flatten()`."""
        return self.flatten()
197
+
198
+
199
@struct.dataclass
class Inertia(_Base):
    """Spatial inertia in Plücker coordinates, stored in compact form.

    `h` equals `mass * pos` for an inertia built by `create`; it is *not* the
    center of mass itself. `as_matrix` assembles the full 6x6 matrix from
    `it_3x3`, `h` and `mass`.
    """

    it_3x3: jax.Array
    h: Vector
    mass: Vector

    @classmethod
    def create(cls, mass: Vector, transform: Transform, it_3x3: jnp.ndarray):
        """Spatial inertia of a body with mass `mass` and 3x3 inertia `it_3x3`.

        The body is located and aligned with the coordinate system given by
        `transform`, which maps from the parent to the local geometry frame.
        """
        # express the local inertia in the parent's orientation ...
        rotated = maths.rotate_matrix(it_3x3, maths.quat_inv(transform.rot))
        # ... then shift it to the parent's origin; keep the rotational block
        shifted = spatial.mcI(mass, transform.pos, rotated)[:3, :3]
        return cls(shifted, mass * transform.pos, mass)

    @classmethod
    def zero(cls, shape=()) -> "Inertia":
        """Return an all-zero inertia with batch shape `shape`."""
        return cls(
            jnp.zeros(shape + (3, 3)),
            jnp.zeros(shape + (3,)),
            jnp.zeros(shape + (1,)),
        )

    def as_matrix(self):
        """Return the 6x6 matrix [[it_3x3, cross(h)], [-cross(h), mass*I3]]."""
        h_cross = spatial.cross(self.h)
        lower_right = self.mass * jnp.eye(3)
        return spatial.quadrants(self.it_3x3, h_cross, -h_cross, lower_right)
229
+
230
+
231
@struct.dataclass
class Geometry(_Base):
    """Fields shared by all geometric primitives attached to a link.

    `mass` and `transform` are pytree leaves. `link_idx`, `color` and
    `edge_color` are static metadata. `System.parse` accepts `link_idx`
    values in `[-1, num_links)`.
    """

    mass: jax.Array
    transform: Transform
    link_idx: int = struct.field(False)

    color: Color = struct.field(False)
    edge_color: Color = struct.field(False)
239
+
240
+
241
@struct.dataclass
class XYZ(Geometry):
    """Massless coordinate-axes marker of a given `size` attached to a link."""

    # TODO: possibly subclass this of _Base? does this need a mass, transform, and
    # link_idx? maybe just transform?
    size: float

    @classmethod
    def create(cls, link_idx: int, size: float):
        """Create a zero-mass marker at the identity transform of `link_idx`."""
        return cls(
            mass=0.0,
            transform=Transform.zero(),
            link_idx=link_idx,
            color=None,
            edge_color=None,
            size=size,
        )

    def get_it_3x3(self) -> jax.Array:
        """A marker contributes no inertia."""
        return jnp.zeros((3, 3))
253
+
254
+
255
@struct.dataclass
class Sphere(Geometry):
    """Solid sphere of radius `radius`."""

    radius: float

    def get_it_3x3(self) -> jax.Array:
        """Inertia of a solid sphere: 2/5 * m * r^2 on the diagonal."""
        return 2 / 5 * self.mass * self.radius**2 * jnp.eye(3)
262
+
263
+
264
@struct.dataclass
class Box(Geometry):
    """Solid cuboid with edge lengths `dim_x`, `dim_y`, `dim_z`."""

    dim_x: float
    dim_y: float
    dim_z: float

    def get_it_3x3(self) -> jax.Array:
        """Inertia of a solid cuboid: m/12 * (sum of the other two edges squared)."""
        sq_x, sq_y, sq_z = self.dim_x**2, self.dim_y**2, self.dim_z**2
        diagonal = jnp.array([sq_y + sq_z, sq_x + sq_z, sq_x + sq_y])
        return 1 / 12 * self.mass * jnp.diag(diagonal)
286
+
287
+
288
@struct.dataclass
class Cylinder(Geometry):
    """Solid cylinder. Length is along x-axis."""

    radius: float
    length: float

    def get_it_3x3(self) -> jax.Array:
        """Solid-cylinder inertia with the symmetry axis along x."""
        r_sq = self.radius**2
        # about the symmetry (x) axis: m/12 * 6r^2 = m r^2 / 2
        axial = 6 * r_sq
        # about the two axes perpendicular to it: m/12 * (3r^2 + L^2)
        transverse = 3 * r_sq + self.length**2
        return 1 / 12 * self.mass * jnp.diag(jnp.array([axial, transverse, transverse]))
304
+
305
+
306
@struct.dataclass
class Capsule(Geometry):
    """Cylinder with hemispherical caps. Length is along x-axis."""

    radius: float
    length: float

    def get_it_3x3(self) -> jax.Array:
        """https://github.com/thomasmarsh/ODE/blob/master/ode/src/mass.cpp#L141

        The total mass is split between the cylindrical body and the two caps
        in proportion to their volumes, as in ODE's `dMassSetCapsule`.
        """
        rad, seg_len = self.radius, self.length

        vol_body = jnp.pi * rad**2 * seg_len
        vol_caps = 4 / 3 * jnp.pi * rad**3
        vol_total = vol_body + vol_caps

        mass_body = self.mass * vol_body / vol_total
        mass_caps = self.mass * vol_caps / vol_total

        # perpendicular to the x-axis
        inertia_perp = mass_body * (0.25 * rad**2 + 1 / 12 * seg_len**2) + mass_caps * (
            0.4 * rad**2 + 0.375 * rad * seg_len + 0.25 * seg_len**2
        )
        # about the x-axis
        inertia_axial = (0.5 * mass_body + 0.4 * mass_caps) * rad**2

        return jnp.diag(jnp.array([inertia_axial, inertia_perp, inertia_perp]))
332
+
333
+
334
# Joint parameters used by `Link` when none are given: a single, empty
# `"default"` entry.
_DEFAULT_JOINT_PARAMS_DICT: dict[str, tu.PyTree] = dict(default=jnp.array([]))
335
+
336
+
337
@struct.dataclass
class Link(_Base):
    """Per-link kinematic and inertial data.

    Inside a `System`, `links` holds the data of all links stacked along a
    leading axis of length `num_links()`.
    """

    transform1: Transform

    # only used by `setup_fn_randomize_positions`
    pos_min: jax.Array = struct.field(default_factory=lambda: jnp.zeros((3,)))
    pos_max: jax.Array = struct.field(default_factory=lambda: jnp.zeros((3,)))

    # these parameters can be used to model joints that have parameters
    # they are directly fed into the `jcalc` routines.
    # A shallow copy is made per instance: returning the module-level dict
    # itself would let a mutation of one link's `joint_params` leak into
    # every other link created with the default.
    joint_params: dict[str, tu.PyTree] = struct.field(
        default_factory=lambda: dict(_DEFAULT_JOINT_PARAMS_DICT)
    )

    # internal usage
    # gets populated by `parse_system`
    inertia: Inertia = Inertia.zero()
    # gets populated by `forward_kinematics`
    transform2: Transform = Transform.zero()
    transform: Transform = Transform.zero()
357
+
358
+
359
@struct.dataclass
class MaxCoordOMC(_Base):
    """Optical-motion-capture marker metadata for one link.

    Only `pos_marker_constant_offset` is a pytree leaf; the other two fields
    are static.
    """

    coordinate_system_name: str = struct.field(pytree_node=False)
    pos_marker_number: int = struct.field(pytree_node=False)
    pos_marker_constant_offset: jax.Array
364
+
365
+
366
# Number of entries each joint type occupies in the coordinate vector `q`.
Q_WIDTHS = dict(
    free=7,
    free_2d=3,
    frozen=0,
    spherical=4,
    p3d=3,
    # center of rotation, a `free` joint and then a `p3d` joint with custom
    # parameter fields in `RMCG_Config`
    cor=10,
    px=1,
    py=1,
    pz=1,
    rx=1,
    ry=1,
    rz=1,
    saddle=2,
)
# Number of entries each joint type occupies in the vector `qd`.
QD_WIDTHS = dict(
    free=6,
    free_2d=3,
    frozen=0,
    spherical=3,
    p3d=3,
    cor=9,
    px=1,
    py=1,
    pz=1,
    rx=1,
    ry=1,
    rz=1,
    saddle=2,
)
398
+
399
+
400
@struct.dataclass
class System(_Base):
    """A tree of links connected by joints, plus attached geometries.

    Per-link metadata is kept in parallel structures:

    - `link_parents[i]` is the parent index of link `i` (`-1` = world body);
    - `link_types[i]` is its joint type (a key of `Q_WIDTHS` / `QD_WIDTHS`);
    - `link_damping`, `link_armature` and `link_spring_stiffness` are flat
      concatenations over all links, split per link by `qd` width;
    - `link_spring_zeropoint` is split per link by `q` width.

    Call `parse()` before using a system.
    """

    link_parents: list[int] = struct.field(False)
    links: Link
    link_types: list[str] = struct.field(False)
    link_damping: jax.Array
    link_armature: jax.Array
    link_spring_stiffness: jax.Array
    link_spring_zeropoint: jax.Array
    # simulation timestep size
    dt: float = struct.field(False)
    # geometries in the system
    geoms: list[Geometry]
    # root / base acceleration offset
    gravity: jax.Array = struct.field(default_factory=lambda: jnp.array([0, 0, -9.81]))

    integration_method: str = struct.field(
        False, default_factory=lambda: "semi_implicit_euler"
    )
    mass_mat_iters: int = struct.field(False, default_factory=lambda: 0)

    link_names: list[str] = struct.field(False, default_factory=lambda: [])

    model_name: Optional[str] = struct.field(False, default_factory=lambda: None)

    omc: list[MaxCoordOMC | None] = struct.field(True, default_factory=lambda: [])

    # -- sizes --------------------------------------------------------------

    def num_links(self) -> int:
        """Return the number of links in the system."""
        return len(self.link_parents)

    def q_size(self) -> int:
        """Return the length of `q`: the sum of `Q_WIDTHS` over all joints."""
        return sum(Q_WIDTHS[typ] for typ in self.link_types)

    def qd_size(self) -> int:
        """Return the length of `qd`: the sum of `QD_WIDTHS` over all joints."""
        return sum(QD_WIDTHS[typ] for typ in self.link_types)

    # -- names and indices --------------------------------------------------

    def name_to_idx(self, name: str) -> int:
        """Return the index of link `name` (raises `ValueError` if unknown)."""
        return self.link_names.index(name)

    def idx_to_name(self, idx: int, allow_world: bool = False) -> str:
        """Return the name of link `idx`.

        Args:
            idx: Link index; `-1` denotes the world body.
            allow_world: If True, `idx == -1` maps to `"world"`. Otherwise a
                negative index fails the assertion below.
        """
        if allow_world and idx == -1:
            return "world"
        assert idx >= 0, "Worldbody index has no name."
        return self.link_names[idx]

    def idx_map(self, type: str) -> dict:
        """Map every link name to its `scan` index-map entry of kind `type`.

        type: is either `l` or `q` or `d` (see `scan` for their meaning).
        """
        dict_int_slices = {}

        def f(_, idx_map, name: str, link_idx: int):
            dict_int_slices[name] = idx_map[type](link_idx)

        self.scan(f, "ll", self.link_names, list(range(self.num_links())))

        return dict_int_slices

    def parent_name(self, name: str) -> str:
        """Return the name of the parent of link `name`.

        For a link attached to the world, the `idx_to_name` assertion fails.
        """
        return self.idx_to_name(self.link_parents[self.name_to_idx(name)])

    def add_prefix(self, prefix: str = "") -> "System":
        """Return a copy with `prefix` prepended to every link name."""
        return self.replace(link_names=[prefix + name for name in self.link_names])

    def change_model_name(
        self,
        new_name: Optional[str] = None,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ) -> "System":
        """Return a copy whose model name is `prefix + new_name + suffix`.

        `new_name` defaults to the current `model_name`. If that is None too,
        an empty base name is used; previously this raised a `TypeError`.
        `prefix` and `suffix` default to "".
        """
        if prefix is None:
            prefix = ""
        if suffix is None:
            suffix = ""
        if new_name is None:
            new_name = self.model_name
        if new_name is None:
            new_name = ""
        name = prefix + new_name + suffix
        return self.replace(model_name=name)

    def change_link_name(self, old_name: str, new_name: str) -> "System":
        """Return a copy where link `old_name` is renamed to `new_name`."""
        old_idx = self.name_to_idx(old_name)
        new_link_names = self.link_names.copy()
        new_link_names[old_idx] = new_name
        return self.replace(link_names=new_link_names)

    def add_prefix_suffix(
        self, prefix: Optional[str] = None, suffix: Optional[str] = None
    ) -> "System":
        """Return a copy with `prefix`/`suffix` (default "") around every link name."""
        if prefix is None:
            prefix = ""
        if suffix is None:
            suffix = ""
        new_link_names = [prefix + name + suffix for name in self.link_names]
        return self.replace(link_names=new_link_names)

    @staticmethod
    def deep_equal(a, b):
        """Recursively compare two (nested) values for equality.

        Recurses into `_Base` objects (via `__dict__`), dicts, lists and tuples.
        Arrays are compared with `jnp.array_equal`, so for array inputs the
        result is a JAX boolean rather than a Python `bool`.
        """
        if type(a) is not type(b):
            return False
        if isinstance(a, _Base):
            return System.deep_equal(a.__dict__, b.__dict__)
        if isinstance(a, dict):
            if a.keys() != b.keys():
                return False
            return all(System.deep_equal(a[k], b[k]) for k in a.keys())
        if isinstance(a, (list, tuple)):
            if len(a) != len(b):
                return False
            return all(System.deep_equal(a[i], b[i]) for i in range(len(a)))
        if isinstance(a, (np.ndarray, jnp.ndarray, jax.Array)):
            return jnp.array_equal(a, b)
        return a == b

    # -- joint-type manipulation --------------------------------------------

    def _replace_free_with_cor(self) -> "System":
        """Return a parsed copy in which every `free` joint becomes a `cor` joint.

        Raises:
            InvalidSystemError: unless every free joint connects to the world
                and every joint connecting to the world is free.
        """
        # check that
        # - all free joints connect to -1
        # - all joints connecting to -1 are free joints
        for i, p in enumerate(self.link_parents):
            link_type = self.link_types[i]
            if (p == -1 and link_type != "free") or (link_type == "free" and p != -1):
                raise InvalidSystemError(
                    f"link={self.idx_to_name(i)}, parent="
                    f"{self.idx_to_name(p, allow_world=True)},"
                    f" joint={link_type}. Hint: Try setting `config.cor` to false."
                )

        def logic_replace_free_with_cor(name, olt, ola, old, ols, olz):
            # by default new is equal to old
            nlt, nla, nld, nls, nlz = olt, ola, old, ols, olz

            # old link type == free
            if olt == "free":
                # cor joint is (free, p3d) stacked
                nlt = "cor"
                # entries of old armature are 3*ang (spherical), 3*pos (p3d);
                # append the positional part once more for the p3d part
                nla = jnp.concatenate((ola, ola[3:]))
                nld = jnp.concatenate((old, old[3:]))
                nls = jnp.concatenate((ols, ols[3:]))
                # q of `free` is 4 quaternion + 3 position entries
                nlz = jnp.concatenate((olz, olz[4:]))

            return nlt, nla, nld, nls, nlz

        return _update_sys_if_replace_joint_type(self, logic_replace_free_with_cor)

    def freeze(self, name: str | list[str]):
        """Return a parsed copy where link(s) `name` get a `frozen` joint.

        A frozen joint has empty armature/damping/stiffness/zeropoint.
        """
        if isinstance(name, list):
            sys = self
            for n in name:
                sys = sys.freeze(n)
            return sys

        def logic_freeze(link_name, olt, ola, old, ols, olz):
            nlt, nla, nld, nls, nlz = olt, ola, old, ols, olz

            if link_name == name:
                nlt = "frozen"
                nla = nld = nls = nlz = jnp.array([])

            return nlt, nla, nld, nls, nlz

        return _update_sys_if_replace_joint_type(self, logic_freeze)

    def unfreeze(self, name: str, new_joint_type: str):
        """Replace the `frozen` joint of link `name` by `new_joint_type`."""
        assert self.link_types[self.name_to_idx(name)] == "frozen"
        assert new_joint_type != "frozen"

        return self.change_joint_type(name, new_joint_type)

    def change_joint_type(
        self,
        name: str,
        new_joint_type: str,
        new_arma: Optional[jax.Array] = None,
        new_damp: Optional[jax.Array] = None,
        new_stif: Optional[jax.Array] = None,
        new_zero: Optional[jax.Array] = None,
    ):
        """Return a parsed copy where link `name` has joint type `new_joint_type`.

        By default damping, stiffness are set to zero. So are armature and
        zeropoint, except that for `spherical`/`free`/`cor` joints the default
        zeropoint starts with the unit quaternion [1, 0, 0, 0].
        """
        q_size, qd_size = Q_WIDTHS[new_joint_type], QD_WIDTHS[new_joint_type]

        def logic_change_joint_type(link_name, olt, ola, old, ols, olz):
            nlt, nla, nld, nls, nlz = olt, ola, old, ols, olz

            if link_name == name:
                nlt = new_joint_type
                q_zeros = jnp.zeros((q_size,))
                qd_zeros = jnp.zeros((qd_size,))

                nla = qd_zeros if new_arma is None else new_arma
                nld = qd_zeros if new_damp is None else new_damp
                nls = qd_zeros if new_stif is None else new_stif
                nlz = q_zeros if new_zero is None else new_zero

                # unit quaternion
                if new_joint_type in ["spherical", "free", "cor"] and new_zero is None:
                    nlz = nlz.at[0].set(1.0)

            return nlt, nla, nld, nls, nlz

        return _update_sys_if_replace_joint_type(self, logic_change_joint_type)

    # -- queries ------------------------------------------------------------

    def findall_imus(self) -> list[str]:
        """Return the names of all links whose name starts with `imu`."""
        return [name for name in self.link_names if name.startswith("imu")]

    def findall_segments(self) -> list[str]:
        """Return the names of all links that are not IMUs (see `findall_imus`)."""
        imus = set(self.findall_imus())
        return [name for name in self.link_names if name not in imus]

    def _bodies_indices_to_bodies_name(self, bodies: list[int]) -> list[str]:
        return [self.idx_to_name(i) for i in bodies]

    def findall_bodies_to_world(self, names: bool = False) -> list[int] | list[str]:
        """Return the links whose parent is the world (indices, or names)."""
        bodies = [i for i, p in enumerate(self.link_parents) if p == -1]
        return self._bodies_indices_to_bodies_name(bodies) if names else bodies

    def find_body_to_world(self, name: bool = False) -> int | str:
        """Return the unique link attached to the world (asserts uniqueness)."""
        bodies = self.findall_bodies_to_world(names=name)
        assert len(bodies) == 1
        return bodies[0]

    def findall_bodies_with_jointtype(
        self, typ: str, names: bool = False
    ) -> list[int] | list[str]:
        """Return the links whose joint type is `typ` (indices, or names)."""
        bodies = [i for i, _typ in enumerate(self.link_types) if _typ == typ]
        return self._bodies_indices_to_bodies_name(bodies) if names else bodies

    def scan(self, f: Callable, in_types: str, *args, reverse: bool = False):
        """Scan `f` along each link in system whilst carrying along state.

        Args:
            f (Callable[..., Y]): called once per link. The callbacks in this
                module take `(carry, idx_map, *split_args)`.
            in_types: string specifying the type of each input arg:
                'l' is an input to be split according to link ranges
                'q' is an input to be split according to q ranges
                'd' is an input to be split according to qd ranges
            args: Arguments passed to `f`, and split to match the link.
            reverse (bool, optional): If `true` from leaves to root. Defaults to False.

        Returns:
            ys: Stacked output y of f.
        """
        return _scan_sys(self, f, in_types, *args, reverse=reverse)

    def parse(self) -> "System":
        """Initial setup of system. System object does not work unless it is parsed.
        Currently it does:
        - some consistency checks
        - populate the spatial inertia tensors
        - check that all names are unique
        - check that names are strings
        - check that all pos_min <= pos_max (unless traced)
        - order geoms in ascending order based on their parent link idx
        - check that all links have the correct size of
            - damping
            - armature
            - stiffness
            - zeropoint
        - check that n_links == len(sys.omc)
        """
        return _parse_system(self)

    # -- rendering ----------------------------------------------------------

    def render(
        self,
        xs: Optional[Transform | list[Transform]] = None,
        camera: Optional[str] = None,
        show_pbar: bool = True,
        backend: str = "mujoco",
        render_every_nth: int = 1,
        **scene_kwargs,
    ) -> list[np.ndarray]:
        """Render frames from system and trajectory of maximal coordinates `xs`.

        All arguments are forwarded to `ring.rendering.render`.

        Args:
            xs (base.Transform | list[base.Transform]): Single or time-series
                of maximal coordinates `xs`.
            camera (Optional[str]): Camera to render from, if any.
            show_pbar (bool, optional): Whether or not to show a progress bar.
                Defaults to True.
            backend (str, optional): Rendering backend. Defaults to "mujoco".
            render_every_nth (int, optional): Presumably renders only every
                n-th frame of `xs` (TODO confirm). Defaults to 1.
            **scene_kwargs: Extra scene options.

        Returns:
            list[np.ndarray]: Stacked rendered frames.
        """
        return ring.rendering.render(
            self, xs, camera, show_pbar, backend, render_every_nth, **scene_kwargs
        )

    def render_prediction(
        self,
        xs: Transform | list[Transform],
        yhat: dict | jax.Array | np.ndarray,
        stepframe: int = 1,
        # by default we don't predict the global rotation
        transparent_segment_to_root: bool = True,
        **kwargs,
    ):
        "`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."
        return ring.rendering.render_prediction(
            self, xs, yhat, stepframe, transparent_segment_to_root, **kwargs
        )

    # -- composition --------------------------------------------------------

    def delete_system(self, link_name: str | list[str], strict: bool = True):
        "Cut subsystem starting at `link_name` (inclusive) from tree."
        return ring.sys_composer.delete_subsystem(self, link_name, strict)

    def make_sys_noimu(self, imu_link_names: Optional[list[str]] = None):
        "Returns, e.g., imu_attachment = {'imu1': 'seg1', 'imu2': 'seg3'}"
        return ring.sys_composer.make_sys_noimu(self, imu_link_names)

    def inject_system(self, other_system: "System", at_body: Optional[str] = None):
        """Combine two systems into one.

        Args:
            other_system (base.System): Small system that will be included into
                this (large) system.
            at_body (Optional[str], optional): Into which body of this system
                the small system will be included. Defaults to `worldbody`.

        Returns:
            base.System: The combined system.
        """
        return ring.sys_composer.inject_system(self, other_system, at_body)

    def morph_system(
        self,
        new_parents: Optional[list[int | str]] = None,
        new_anchor: Optional[int | str] = None,
    ):
        """Re-orders the graph underlying the system. Returns a new system.

        Args:
            new_parents (list[int | str]): Let the i-th entry have value j. Then,
                after morphing the system the system will be such that the link
                corresponding to the i-th link in the old system will have as
                parent the link corresponding to the j-th link in the old system.
            new_anchor (int | str, optional): Forwarded to
                `ring.sys_composer.morph_system`.

        Returns:
            base.System: Modified system.
        """
        return ring.sys_composer.morph_system(self, new_parents, new_anchor)

    # -- (de)serialization --------------------------------------------------

    @staticmethod
    def from_xml(path: str, seed: int = 1):
        """Load a system from the XML file at `path`."""
        return ring.io.load_sys_from_xml(path, seed)

    @staticmethod
    def from_str(xml: str, seed: int = 1):
        """Load a system from the XML string `xml`."""
        return ring.io.load_sys_from_str(xml, seed)

    def to_str(self) -> str:
        """Serialize the system to an XML string."""
        return ring.io.save_sys_to_str(self)

    def to_xml(self, path: str) -> None:
        """Write the system as XML to `path`."""
        ring.io.save_sys_to_xml(self, path)

    @classmethod
    def create(cls, path_or_str: str, seed: int = 1) -> "System":
        """Create a system from an XML file path or from an XML string.

        If `path_or_str` (with its suffix replaced by `.xml`) is an existing
        file, that file is loaded; otherwise `path_or_str` is parsed as an XML
        string. `seed` is forwarded in both cases (previously it was dropped
        for XML strings).
        """
        exists = False
        try:
            path = Path(path_or_str).with_suffix(".xml")
            exists = path.exists()
        except (OSError, ValueError):
            # An XML string is usually not a usable path: it may be too long
            # for the OS (`OSError`) or have no valid file name (`ValueError`).
            pass

        if exists:
            return cls.from_xml(path, seed=seed)
        return cls.from_str(path_or_str, seed=seed)

    def coordinate_vector_to_q(
        self,
        q: jax.Array,
        custom_joints: Optional[dict[str, Callable]] = None,
    ) -> jax.Array:
        """Map a coordinate vector `q` to the minimal coordinates vector of the sys.

        Does, e.g.,
        - normalize quaternions
        - hinge joints in [-pi, pi]

        Args:
            q: Coordinate vector, split per link by `q` width.
            custom_joints: Optional mapping from joint type to a `q -> q`
                function. It takes priority over the registered joint model, and
                joint types only present here need not be registered.

        Raises:
            NotImplementedError: if no mapping is available for a joint type.
        """
        if custom_joints is None:
            custom_joints = {}
        q_preproc = []

        def preprocess(_, __, link_type, q):
            # function in custom_joints has priority over JointModel; check it
            # first so the joint-model registry is not consulted for it
            if link_type in custom_joints:
                to_q = custom_joints[link_type]
            else:
                to_q = ring.algorithms.jcalc.get_joint_model(
                    link_type
                ).coordinate_vector_to_q
            if to_q is None:
                raise NotImplementedError(
                    f"Please specify the custom joint `{link_type}`"
                    " either using the `custom_joints` arguments or using the"
                    " JointModel.coordinate_vector_to_q field."
                )
            q_preproc.append(to_q(q))

        self.scan(preprocess, "lq", self.link_types, q)
        return jnp.concatenate(q_preproc)
799
+
800
+
801
def _update_sys_if_replace_joint_type(sys: System, logic) -> System:
    """Rebuild the per-link joint metadata of `sys` through `logic`.

    `logic(name, type, armature, damping, stiffness, zeropoint)` is called once
    per link via `System.scan`. It must return the new 5-tuple for that link.
    The results are re-assembled into a new system, which is then parsed again.
    """
    new_types, new_arma, new_damp, new_stif, new_zero = [], [], [], [], []
    buckets = (new_types, new_arma, new_damp, new_stif, new_zero)

    def collect(_, __, name, *old_entries):
        nlt, nla, nld, nls, nlz = logic(name, *old_entries)
        for bucket, value in zip(buckets, (nlt, nla, nld, nls, nlz)):
            bucket.append(value)

    sys.scan(
        collect,
        "lldddq",
        sys.link_names,
        sys.link_types,
        sys.link_armature,
        sys.link_damping,
        sys.link_spring_stiffness,
        sys.link_spring_zeropoint,
    )

    # joint types stay a plain list of strings; only the arrays are joined
    arma, damp, stif, zero = (
        jnp.concatenate(chunks) for chunks in (new_arma, new_damp, new_stif, new_zero)
    )

    updated = sys.replace(
        link_types=new_types,
        link_armature=arma,
        link_damping=damp,
        link_spring_stiffness=stif,
        link_spring_zeropoint=zero,
    )

    # re-parse so that the dimensionality of damping / stiffness / zeropoint /
    # armature is checked against the (possibly new) joint types
    return updated.parse()
838
+
839
+
840
class InvalidSystemError(Exception):
    """Raised when a system's structure is invalid for a requested operation."""
842
+
843
+
844
def _parse_system(sys: System) -> System:
    """Validate `sys` and populate derived fields; see `System.parse`.

    Returns a new system with link inertias computed from the geometries and
    with `geoms` sorted by `link_idx`.

    Raises:
        AssertionError: if any consistency check fails.
    """
    from collections import Counter

    assert len(sys.link_parents) == len(sys.link_types) == sys.links.batch_dim()
    assert len(sys.omc) == sys.num_links()

    # count all names once instead of `list.count` per name (quadratic)
    name_counts = Counter(sys.link_names)
    for name in sys.link_names:
        assert name_counts[name] == 1, f"Duplicated name=`{name}` in system"
        assert isinstance(name, str)

    pos_min, pos_max = sys.links.pos_min, sys.links.pos_max

    try:
        from jax.errors import TracerBoolConversionError

        try:
            assert jnp.all(pos_max >= pos_min), f"min={pos_min}, max={pos_max}"
        except TracerBoolConversionError:
            # bounds are traced (e.g. under `jax.jit`); cannot be checked here
            pass
    # on older versions of jax this import is not possible
    except ImportError:
        pass

    n_links = sys.num_links()
    for geom in sys.geoms:
        # -1 (world) is allowed in addition to every link index
        assert geom.link_idx in range(-1, n_links), (
            f"geometry has invalid link_idx={geom.link_idx} in model"
            f" {sys.model_name}"
        )

    inertia = _parse_system_calculate_inertia(sys)
    sys = sys.replace(links=sys.links.replace(inertia=inertia))

    # sort geoms in ascending order
    geoms = sys.geoms.copy()
    geoms.sort(key=lambda geom: geom.link_idx)
    sys = sys.replace(geoms=geoms)

    # round dt
    # sys = sys.replace(dt=round(sys.dt, 8))

    # check sizes of damping / arma / stiff / zeropoint
    def check_dasz_unitq(_, __, name, typ, d, a, s, z):
        q_size, qd_size = Q_WIDTHS[typ], QD_WIDTHS[typ]

        error_msg = (
            f"wrong size for link `{name}` of typ `{typ}` in model {sys.model_name}"
        )

        assert d.size == a.size == s.size == qd_size, error_msg
        assert z.size == q_size, error_msg

        if typ in ["spherical", "free", "cor"] and not isinstance(z, Tracer):
            # the whole message is parenthesized: previously the second
            # f-string was a separate no-op statement, not part of the message
            assert jnp.allclose(jnp.linalg.norm(z[:4]), 1.0), (
                f"not unit quat for link `{name}` of typ `{typ}` in model"
                f" {sys.model_name}"
            )

    sys.scan(
        check_dasz_unitq,
        "lldddq",
        sys.link_names,
        sys.link_types,
        sys.link_damping,
        sys.link_armature,
        sys.link_spring_stiffness,
        sys.link_spring_zeropoint,
    )

    return sys
908
+
909
+
910
+ def _inertia_from_geometries(geometries: list[Geometry]) -> Inertia:
911
+ inertia = Inertia.zero()
912
+ for geom in geometries:
913
+ inertia += Inertia.create(geom.mass, geom.transform, geom.get_it_3x3())
914
+ return inertia
915
+
916
+
917
+ def _parse_system_calculate_inertia(sys: System):
918
+ def compute_inertia_per_link(_, __, link_idx: int):
919
+ geoms_link = []
920
+ for geom in sys.geoms:
921
+ if geom.link_idx == link_idx:
922
+ geoms_link.append(geom)
923
+
924
+ it = _inertia_from_geometries(geoms_link)
925
+ return it
926
+
927
+ return sys.scan(compute_inertia_per_link, "l", list(range(sys.num_links())))
928
+
929
+
930
+ def _scan_sys(sys: System, f: Callable, in_types: str, *args, reverse: bool = False):
931
+ assert len(args) == len(in_types)
932
+ for in_type, arg in zip(in_types, args):
933
+ B = len(arg)
934
+ if in_type == "l":
935
+ assert B == sys.num_links()
936
+ elif in_type == "q":
937
+ assert B == sys.q_size()
938
+ elif in_type == "d":
939
+ assert B == sys.qd_size()
940
+ else:
941
+ raise Exception("`in_types` must be one of `l` or `q` or `d`")
942
+
943
+ order = range(sys.num_links())
944
+ q_idx, qd_idx = 0, 0
945
+ q_idxs, qd_idxs = {}, {}
946
+ for link_idx, link_type in zip(order, sys.link_types):
947
+ # build map from
948
+ # link-idx -> q_idx
949
+ # link-idx -> qd_idx
950
+ q_idxs[link_idx] = slice(q_idx, q_idx + Q_WIDTHS[link_type])
951
+ qd_idxs[link_idx] = slice(qd_idx, qd_idx + QD_WIDTHS[link_type])
952
+ q_idx += Q_WIDTHS[link_type]
953
+ qd_idx += QD_WIDTHS[link_type]
954
+
955
+ idx_map = {
956
+ "l": lambda link_idx: link_idx,
957
+ "q": lambda link_idx: q_idxs[link_idx],
958
+ "d": lambda link_idx: qd_idxs[link_idx],
959
+ }
960
+
961
+ if reverse:
962
+ order = range(sys.num_links() - 1, -1, -1)
963
+
964
+ y, ys = None, []
965
+ for link_idx in order:
966
+ args_link = [arg[idx_map[t](link_idx)] for arg, t in zip(args, in_types)]
967
+ y = f(y, idx_map, *args_link)
968
+ ys.append(y)
969
+
970
+ if reverse:
971
+ ys.reverse()
972
+
973
+ ys = tu.tree_batch(ys, backend="jax")
974
+ return ys
975
+
976
+
977
+ @struct.dataclass
978
+ class State(_Base):
979
+ """The static and dynamic state of a system in minimal and maximal coordinates.
980
+ Use `.create()` to create this object.
981
+
982
+ Args:
983
+ q (jax.Array): System state in minimal coordinates (equals `sys.q_size()`)
984
+ qd (jax.Array): System velocity in minimal coordinates (equals `sys.qd_size()`)
985
+ x: (Transform): Maximal coordinates of all links. From epsilon-to-link.
986
+ mass_mat_inv (jax.Array): Inverse of the mass matrix. Internal usage.
987
+ """
988
+
989
+ q: jax.Array
990
+ qd: jax.Array
991
+ x: Transform
992
+ mass_mat_inv: jax.Array
993
+
994
+ @classmethod
995
+ def create(
996
+ cls,
997
+ sys: System,
998
+ q: Optional[jax.Array] = None,
999
+ qd: Optional[jax.Array] = None,
1000
+ x: Optional[Transform] = None,
1001
+ key: Optional[jax.Array] = None,
1002
+ custom_joints: dict[str, Callable] = {},
1003
+ ):
1004
+ """Create state of system.
1005
+
1006
+ Args:
1007
+ sys (System): The system for which to create a state.
1008
+ q (jax.Array, optional): The joint values of the system. Defaults to None.
1009
+ Which then defaults to zeros.
1010
+ qd (jax.Array, optional): The joint velocities of the system.
1011
+ Defaults to None. Which then defaults to zeros.
1012
+
1013
+ Returns:
1014
+ (State): Create State object.
1015
+ """
1016
+ if key is not None:
1017
+ assert q is None
1018
+ q = jax.random.normal(key, shape=(sys.q_size(),))
1019
+ q = sys.coordinate_vector_to_q(q, custom_joints)
1020
+ elif q is None:
1021
+ q = jnp.zeros((sys.q_size(),))
1022
+
1023
+ # free, cor, spherical joints are not zeros but have unit quaternions
1024
+ def replace_by_unit_quat(_, idx_map, link_typ, link_idx):
1025
+ nonlocal q
1026
+
1027
+ if link_typ in ["free", "cor", "spherical"]:
1028
+ q_idxs_link = idx_map["q"](link_idx)
1029
+ q = q.at[q_idxs_link.start].set(1.0)
1030
+
1031
+ sys.scan(
1032
+ replace_by_unit_quat,
1033
+ "ll",
1034
+ sys.link_types,
1035
+ list(range(sys.num_links())),
1036
+ )
1037
+ else:
1038
+ pass
1039
+
1040
+ if qd is None:
1041
+ qd = jnp.zeros((sys.qd_size(),))
1042
+
1043
+ if x is None:
1044
+ x = Transform.zero((sys.num_links(),))
1045
+
1046
+ return cls(q, qd, x, jnp.diag(jnp.ones((sys.qd_size(),))))