imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
- imt_ring-1.2.1.dist-info/RECORD +83 -0
- imt_ring-1.2.1.dist-info/WHEEL +5 -0
- imt_ring-1.2.1.dist-info/top_level.txt +1 -0
- ring/__init__.py +63 -0
- ring/algebra.py +100 -0
- ring/algorithms/__init__.py +45 -0
- ring/algorithms/_random.py +403 -0
- ring/algorithms/custom_joints/__init__.py +6 -0
- ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
- ring/algorithms/custom_joints/rr_joint.py +33 -0
- ring/algorithms/custom_joints/suntay.py +424 -0
- ring/algorithms/dynamics.py +345 -0
- ring/algorithms/generator/__init__.py +25 -0
- ring/algorithms/generator/base.py +414 -0
- ring/algorithms/generator/batch.py +282 -0
- ring/algorithms/generator/motion_artifacts.py +222 -0
- ring/algorithms/generator/pd_control.py +182 -0
- ring/algorithms/generator/randomize.py +119 -0
- ring/algorithms/generator/transforms.py +410 -0
- ring/algorithms/generator/types.py +36 -0
- ring/algorithms/jcalc.py +840 -0
- ring/algorithms/kinematics.py +202 -0
- ring/algorithms/sensors.py +582 -0
- ring/base.py +1046 -0
- ring/io/__init__.py +9 -0
- ring/io/examples/branched.xml +24 -0
- ring/io/examples/exclude/knee_trans_dof.xml +26 -0
- ring/io/examples/exclude/standard_sys.xml +106 -0
- ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
- ring/io/examples/inv_pendulum.xml +14 -0
- ring/io/examples/knee_flexible_imus.xml +22 -0
- ring/io/examples/spherical_stiff.xml +11 -0
- ring/io/examples/symmetric.xml +12 -0
- ring/io/examples/test_all_1.xml +39 -0
- ring/io/examples/test_all_2.xml +39 -0
- ring/io/examples/test_ang0_pos0.xml +9 -0
- ring/io/examples/test_control.xml +16 -0
- ring/io/examples/test_double_pendulum.xml +14 -0
- ring/io/examples/test_free.xml +11 -0
- ring/io/examples/test_kinematics.xml +23 -0
- ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
- ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
- ring/io/examples/test_randomize_position.xml +26 -0
- ring/io/examples/test_sensors.xml +13 -0
- ring/io/examples/test_three_seg_seg2.xml +23 -0
- ring/io/examples.py +42 -0
- ring/io/test_examples.py +6 -0
- ring/io/xml/__init__.py +6 -0
- ring/io/xml/abstract.py +300 -0
- ring/io/xml/from_xml.py +299 -0
- ring/io/xml/test_from_xml.py +56 -0
- ring/io/xml/test_to_xml.py +31 -0
- ring/io/xml/to_xml.py +94 -0
- ring/maths.py +397 -0
- ring/ml/__init__.py +33 -0
- ring/ml/base.py +292 -0
- ring/ml/callbacks.py +434 -0
- ring/ml/ml_utils.py +272 -0
- ring/ml/optimizer.py +149 -0
- ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- ring/ml/ringnet.py +279 -0
- ring/ml/train.py +318 -0
- ring/ml/training_loop.py +131 -0
- ring/rendering/__init__.py +2 -0
- ring/rendering/base_render.py +271 -0
- ring/rendering/mujoco_render.py +222 -0
- ring/rendering/vispy_render.py +340 -0
- ring/rendering/vispy_visuals.py +290 -0
- ring/sim2real/__init__.py +7 -0
- ring/sim2real/sim2real.py +288 -0
- ring/spatial.py +126 -0
- ring/sys_composer/__init__.py +5 -0
- ring/sys_composer/delete_sys.py +114 -0
- ring/sys_composer/inject_sys.py +110 -0
- ring/sys_composer/morph_sys.py +361 -0
- ring/utils/__init__.py +21 -0
- ring/utils/batchsize.py +51 -0
- ring/utils/colab.py +48 -0
- ring/utils/hdf5.py +198 -0
- ring/utils/normalizer.py +56 -0
- ring/utils/path.py +44 -0
- ring/utils/utils.py +161 -0
ring/base.py
ADDED
@@ -0,0 +1,1046 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from typing import Any, Callable, Optional, Sequence, Union
|
3
|
+
|
4
|
+
from flax import struct
|
5
|
+
import jax
|
6
|
+
from jax.core import Tracer
|
7
|
+
import jax.numpy as jnp
|
8
|
+
from jax.tree_util import tree_map
|
9
|
+
import numpy as np
|
10
|
+
import tree_utils as tu
|
11
|
+
|
12
|
+
import ring
|
13
|
+
from ring import maths
|
14
|
+
from ring import spatial
|
15
|
+
|
16
|
+
Scalar = jax.Array
|
17
|
+
Vector = jax.Array
|
18
|
+
Quaternion = jax.Array
|
19
|
+
|
20
|
+
|
21
|
+
Color = Optional[str | tuple[float, float, float] | tuple[float, float, float, float]]
|
22
|
+
|
23
|
+
|
24
|
+
class _Base:
    """Leaf-wise arithmetic and array-style helpers shared by the spatial datatypes.

    Every method maps an operation over the pytree leaves of `self` (and of any
    other operands) with `jax.tree_util.tree_map`, so subclasses are expected to
    be registered pytrees (in this module they are `flax.struct.dataclass`es).

    Copied and modified from https://github.com/google/brax/blob/main/brax/v2/base.py
    """

    # ---------------------------------------------------------------- arithmetic

    def __add__(self, o: Any) -> Any:
        """Leaf-wise `self + o`; `o` must have the same pytree structure."""
        return tree_map(lambda lhs, rhs: lhs + rhs, self, o)

    def __sub__(self, o: Any) -> Any:
        """Leaf-wise `self - o`; `o` must have the same pytree structure."""
        return tree_map(lambda lhs, rhs: lhs - rhs, self, o)

    def __mul__(self, o: Any) -> Any:
        """Multiply every leaf by `o` (`o` is broadcast, not tree-mapped)."""
        return tree_map(lambda leaf: leaf * o, self)

    def __neg__(self) -> Any:
        """Leaf-wise negation."""
        return tree_map(lambda leaf: -leaf, self)

    def __truediv__(self, o: Any) -> Any:
        """Divide every leaf by `o` (`o` is broadcast, not tree-mapped)."""
        return tree_map(lambda leaf: leaf / o, self)

    # ------------------------------------------------------ indexing / reshaping

    def __getitem__(self, i: int) -> Any:
        """Alias for `self.take(i)`, i.e. indexing along the first axis."""
        return self.take(i)

    def reshape(self, shape: Sequence[int]) -> Any:
        """Reshape every leaf to `shape`."""
        return tree_map(lambda leaf: leaf.reshape(shape), self)

    def slice(self, beg: int, end: int) -> Any:
        """Slice `[beg:end]` along the first axis of every leaf."""
        return tree_map(lambda leaf: leaf[beg:end], self)

    def take(self, i, axis=0) -> Any:
        """`jnp.take(leaf, i, axis=axis)` for every leaf."""
        return tree_map(lambda leaf: jnp.take(leaf, i, axis=axis), self)

    # -------------------------------------------------------------- combination

    def hstack(self, *others: Any) -> Any:
        """Horizontally stack corresponding leaves of `self` and `others`."""
        return tree_map(lambda *leaves: jnp.hstack(leaves), self, *others)

    def vstack(self, *others: Any) -> Any:
        """Vertically stack corresponding leaves of `self` and `others`."""
        return tree_map(lambda *leaves: jnp.vstack(leaves), self, *others)

    def concatenate(self, *others: Any, axis: int = 0) -> Any:
        """Concatenate corresponding leaves of `self` and `others` along `axis`."""
        return tree_map(
            lambda *leaves: jnp.concatenate(leaves, axis=axis), self, *others
        )

    def batch(self, *others, along_existing_first_axis: bool = False) -> Any:
        """Batch `self` and `others` together via `tree_utils.tree_batch`."""
        return tu.tree_batch((self, *others), along_existing_first_axis, "jax")

    # ------------------------------------------------------- functional updates

    def index_set(self, idx: Union[jnp.ndarray, Sequence[jnp.ndarray]], o: Any) -> Any:
        """Leaf-wise `leaf.at[idx].set(other_leaf)`."""
        return tree_map(lambda leaf, other: leaf.at[idx].set(other), self, o)

    def index_sum(self, idx: Union[jnp.ndarray, Sequence[jnp.ndarray]], o: Any) -> Any:
        """Leaf-wise `leaf.at[idx].add(other_leaf)`."""
        return tree_map(lambda leaf, other: leaf.at[idx].add(other), self, o)

    # ------------------------------------------------------------ shape helpers

    @property
    def T(self):
        """Transpose of every leaf."""
        return tree_map(lambda leaf: leaf.T, self)

    def flatten(self, num_batch_dims: int = 0) -> jax.Array:
        """Concatenate all leaves into one array via `tree_utils.batch_concat`."""
        return tu.batch_concat(self, num_batch_dims)

    def squeeze(self):
        """Remove all length-1 axes of every leaf."""
        return tree_map(jnp.squeeze, self)

    def squeeze_1d(self):
        """Like `squeeze`, but leaves are kept at least 1-dimensional."""
        return tree_map(lambda leaf: jnp.atleast_1d(jnp.squeeze(leaf)), self)

    def batch_dim(self) -> int:
        """Shape along axis 0 as reported by `tree_utils.tree_shape`."""
        return tu.tree_shape(self)

    def transpose(self, axes: Sequence[int]) -> Any:
        """Permute the axes of every leaf."""
        return tree_map(lambda leaf: jnp.transpose(leaf, axes), self)

    def __iter__(self):
        # Iteration is deliberately disabled (otherwise Python would fall back
        # to `__getitem__` with increasing integers).
        raise NotImplementedError

    def repeat(self, repeats, axis=0):
        """`jnp.repeat` every leaf."""
        return tree_map(lambda leaf: jnp.repeat(leaf, repeats, axis), self)

    def ndim(self):
        """Number of dimensions as reported by `tree_utils.tree_ndim`."""
        return tu.tree_ndim(self)

    def shape(self, axis=0) -> int:
        """Size along `axis` as reported by `tree_utils.tree_shape`."""
        return tu.tree_shape(self, axis)

    def __len__(self) -> int:
        """Common size of the first axis of all leaves.

        Raises:
            AssertionError: if the leaves disagree on their first-axis size.
        """
        leading_sizes = {leaf.shape[0] for leaf in jax.tree_util.tree_leaves(self)}
        assert len(leading_sizes) == 1
        (size,) = leading_sizes
        return size
|
110
|
+
|
111
|
+
|
112
|
+
@struct.dataclass
class Transform(_Base):
    """Represents the Transformation from Plücker A to Plücker B,
    where B is located relative to A at `pos` in frame A and `rot` is the
    relative quaternion from A to B."""

    pos: Vector
    rot: Quaternion

    @classmethod
    def create(cls, pos=None, rot=None):
        """Create a transform, filling in whichever of `pos` / `rot` is missing.

        A missing `pos` becomes zeros and a missing `rot` becomes the identity
        quaternion `[1, 0, 0, 0]`. In both cases the filled-in value gets the
        leading (batch) shape of the argument that was given.

        Args:
            pos: translation array with trailing dimension 3, or None.
            rot: quaternion array with trailing dimension 4, or None.

        Raises:
            AssertionError: if neither argument is given, or if the leading
                shapes of `pos` and `rot` disagree.
        """
        assert not (pos is None and rot is None), "One must be given."
        shape_rot = rot.shape[:-1] if rot is not None else ()
        shape_pos = pos.shape[:-1] if pos is not None else ()

        if pos is None:
            pos = jnp.zeros(shape_rot + (3,))
        if rot is None:
            # identity quaternion, tiled over the batch shape of `pos`
            rot = jnp.tile(jnp.array([1.0, 0.0, 0.0, 0.0]), shape_pos + (1,))

        assert pos.shape[:-1] == rot.shape[:-1]

        return Transform(pos, rot)

    @classmethod
    def zero(cls, shape=()) -> "Transform":
        """Returns a zero transform with a batch shape.

        The result has zero translation and an identity rotation. `shape` must be
        a tuple.
        """
        pos = jnp.zeros(shape + (3,))
        rot = jnp.tile(jnp.array([1.0, 0.0, 0.0, 0.0]), shape + (1,))
        return Transform(pos, rot)

    def as_matrix(self) -> jax.Array:
        """Matrix form of the transform.

        Builds `quadrants(aa=E, bb=E) @ xlt(pos)`, where `E` is the 3x3 matrix of
        `rot`. This is presumably the 6x6 Plücker transform matrix; the exact
        layout is defined by `ring.spatial`.
        """
        E = maths.quat_to_3x3(self.rot)
        return spatial.quadrants(aa=E, bb=E) @ spatial.xlt(self.pos)
|
147
|
+
|
148
|
+
|
149
|
+
@struct.dataclass
class Motion(_Base):
    "Coordinate vector that represents a spatial motion vector in Plücker Coordinates."

    ang: Vector
    vel: Vector

    @classmethod
    def create(cls, ang=None, vel=None):
        """Create a motion vector, filling a missing component with zeros.

        The zero component gets the leading (batch) shape of the component that
        was given, so batched inputs yield a consistently batched pytree. This
        mirrors `Transform.create`. For an unbatched 3-vector the filled-in
        component is `zeros((3,))`, exactly as before.

        Raises:
            AssertionError: if neither `ang` nor `vel` is given.
        """
        assert not (ang is None and vel is None), "One must be given."
        if ang is None:
            ang = jnp.zeros(jnp.shape(vel)[:-1] + (3,))
        if vel is None:
            vel = jnp.zeros(jnp.shape(ang)[:-1] + (3,))
        return Motion(ang, vel)

    @classmethod
    def zero(cls, shape=()) -> "Motion":
        """All-zero motion vector with batch shape `shape` (a tuple)."""
        ang = jnp.zeros(shape + (3,))
        vel = jnp.zeros(shape + (3,))
        return Motion(ang, vel)

    def as_matrix(self):
        """Flat array of the leaves (see `_Base.flatten`)."""
        return self.flatten()
|
172
|
+
|
173
|
+
|
174
|
+
@struct.dataclass
class Force(_Base):
    "Coordinate vector that represents a spatial force vector in Plücker Coordinates."

    ang: Vector
    vel: Vector

    @classmethod
    def create(cls, ang=None, vel=None):
        """Create a force vector, filling a missing component with zeros.

        The zero component gets the leading (batch) shape of the component that
        was given, so batched inputs yield a consistently batched pytree. This
        mirrors `Transform.create`. For an unbatched 3-vector the filled-in
        component is `zeros((3,))`, exactly as before.

        Raises:
            AssertionError: if neither `ang` nor `vel` is given.
        """
        assert not (ang is None and vel is None), "One must be given."
        if ang is None:
            ang = jnp.zeros(jnp.shape(vel)[:-1] + (3,))
        if vel is None:
            vel = jnp.zeros(jnp.shape(ang)[:-1] + (3,))
        return Force(ang, vel)

    @classmethod
    def zero(cls, shape=()) -> "Force":
        """All-zero force vector with batch shape `shape` (a tuple)."""
        ang = jnp.zeros(shape + (3,))
        vel = jnp.zeros(shape + (3,))
        return Force(ang, vel)

    def as_matrix(self):
        """Flat array of the leaves (see `_Base.flatten`)."""
        return self.flatten()
|
197
|
+
|
198
|
+
|
199
|
+
@struct.dataclass
class Inertia(_Base):
    """Spatial Inertia Matrix in Plücker Coordinates.
    Note that `h` is *not* the center of mass."""

    it_3x3: jax.Array
    h: Vector
    mass: Vector

    @classmethod
    def create(cls, mass: Vector, transform: Transform, it_3x3: jnp.ndarray):
        """Spatial inertia of a body of mass `mass` whose local frame is `transform`.

        `transform` goes from the parent to the local geometry coordinates. The
        local 3x3 inertia `it_3x3` is first rotated by the inverse of
        `transform.rot`. Then `spatial.mcI` shifts it by `transform.pos`, and the
        upper-left 3x3 block of that result is kept. `h` is `mass * transform.pos`.
        """
        rotated = maths.rotate_matrix(it_3x3, maths.quat_inv(transform.rot))
        shifted = spatial.mcI(mass, transform.pos, rotated)
        return cls(shifted[:3, :3], mass * transform.pos, mass)

    @classmethod
    def zero(cls, shape=()) -> "Inertia":
        """All-zero inertia with batch shape `shape` (a tuple)."""
        return cls(
            jnp.zeros(shape + (3, 3)),
            jnp.zeros(shape + (3,)),
            jnp.zeros(shape + (1,)),
        )

    def as_matrix(self):
        """Assemble the block matrix `[[it_3x3, hx], [-hx, mass * I3]]`.

        `hx = spatial.cross(h)`. The block placement follows the argument order
        of `spatial.quadrants` (presumably top-left, top-right, bottom-left,
        bottom-right).
        """
        h_skew = spatial.cross(self.h)
        mass_block = self.mass * jnp.eye(3)
        return spatial.quadrants(self.it_3x3, h_skew, -h_skew, mass_block)
|
229
|
+
|
230
|
+
|
231
|
+
@struct.dataclass
class Geometry(_Base):
    """Base record for a geometry attached to a link of a `System`.

    Subclasses add shape parameters and a `get_it_3x3` method that returns the
    shape's 3x3 inertia computed from `mass`. Field order defines the
    positional constructor signature used by subclasses.
    """

    # mass of the geometry (used by the subclasses' inertia formulas)
    mass: jax.Array
    # pose of the geometry; presumably relative to its parent link (cf. the
    # `transform` argument of `Inertia.create`) -- TODO confirm
    transform: Transform
    # index of the link this geometry belongs to; static (not a pytree leaf)
    link_idx: int = struct.field(pytree_node=False)

    # optional face / edge colors; static (not pytree leaves)
    color: Color = struct.field(pytree_node=False)
    edge_color: Color = struct.field(pytree_node=False)
|
239
|
+
|
240
|
+
|
241
|
+
@struct.dataclass
class XYZ(Geometry):
    """Massless geometry with zero inertia, parameterized only by `size`.

    The name suggests a coordinate-axes marker (NOTE(review): confirm against
    the renderers).
    """

    # TODO: possibly subclass this of _Base? does this need a mass, transform, and
    # link_idx? maybe just transform?
    size: float

    @classmethod
    def create(cls, link_idx: int, size: float):
        """Zero-mass, identity-transform, uncolored instance attached to `link_idx`."""
        return cls(
            mass=0.0,
            transform=Transform.zero(),
            link_idx=link_idx,
            color=None,
            edge_color=None,
            size=size,
        )

    def get_it_3x3(self) -> jax.Array:
        """Inertia is identically zero."""
        zero_inertia = jnp.zeros((3, 3))
        return zero_inertia
|
253
|
+
|
254
|
+
|
255
|
+
@struct.dataclass
class Sphere(Geometry):
    """Solid sphere of radius `radius`."""

    radius: float

    def get_it_3x3(self) -> jax.Array:
        """Solid-sphere inertia `2/5 * m * r^2` on every diagonal entry."""
        coefficient = 2 / 5 * self.mass * self.radius**2
        return coefficient * jnp.eye(3)
|
262
|
+
|
263
|
+
|
264
|
+
@struct.dataclass
class Box(Geometry):
    """Solid cuboid with edge lengths `dim_x`, `dim_y`, `dim_z`."""

    dim_x: float
    dim_y: float
    dim_z: float

    def get_it_3x3(self) -> jax.Array:
        """Solid-cuboid inertia: `m/12 * diag(y²+z², x²+z², x²+y²)`."""
        x_sq, y_sq, z_sq = self.dim_x**2, self.dim_y**2, self.dim_z**2
        diagonal = jnp.array([y_sq + z_sq, x_sq + z_sq, x_sq + y_sq])
        return 1 / 12 * self.mass * jnp.diag(diagonal)
|
286
|
+
|
287
|
+
|
288
|
+
@struct.dataclass
class Cylinder(Geometry):
    """Length is along x-axis."""

    radius: float
    length: float

    def get_it_3x3(self) -> jax.Array:
        """Solid-cylinder inertia with the symmetry axis along x.

        Returns `m/12 * diag(6r², 3r² + L², 3r² + L²)`.
        """
        r_sq = self.radius**2
        axial = 6 * r_sq
        transverse = 3 * r_sq + self.length**2
        return 1 / 12 * self.mass * jnp.diag(jnp.array([axial, transverse, transverse]))
|
304
|
+
|
305
|
+
|
306
|
+
@struct.dataclass
class Capsule(Geometry):
    """Length is along x-axis."""

    radius: float
    length: float

    def get_it_3x3(self) -> jax.Array:
        """https://github.com/thomasmarsh/ODE/blob/master/ode/src/mass.cpp#L141"""
        r, d = self.radius, self.length

        # Split the total mass between the cylindrical body and the two
        # hemispherical caps (together one full sphere) in proportion to volume.
        vol_body = jnp.pi * r**2 * d
        vol_caps = 4 / 3 * jnp.pi * r**3
        vol_total = vol_body + vol_caps
        m_body = self.mass * vol_body / vol_total
        m_caps = self.mass * vol_caps / vol_total

        # about the two transverse axes (y and z)
        transverse = m_body * (0.25 * r**2 + 1 / 12 * d**2) + m_caps * (
            0.4 * r**2 + 0.375 * r * d + 0.25 * d**2
        )
        # about the long (x) axis
        axial = (0.5 * m_body + 0.4 * m_caps) * r**2

        return jnp.diag(jnp.array([axial, transverse, transverse]))
|
332
|
+
|
333
|
+
|
334
|
+
# Default value of `Link.joint_params`: a single "default" entry holding an
# empty array.
_DEFAULT_JOINT_PARAMS_DICT: dict[str, tu.PyTree] = dict(default=jnp.array([]))
|
335
|
+
|
336
|
+
|
337
|
+
@struct.dataclass
class Link(_Base):
    """Per-link data of a `System` (stacked over links in `System.links`)."""

    # transform of the link; the other `transform*` fields are derived
    transform1: Transform

    # only used by `setup_fn_randomize_positions`
    pos_min: jax.Array = struct.field(default_factory=lambda: jnp.zeros((3,)))
    pos_max: jax.Array = struct.field(default_factory=lambda: jnp.zeros((3,)))

    # these parameters can be used to model joints that have parameters
    # they are directly fed into the `jcalc` routines
    # Each Link gets its own (shallow) copy of the module-level default. If
    # every Link shared the same dict object, mutating one link's params would
    # silently change all other links and the default itself.
    joint_params: dict[str, tu.PyTree] = struct.field(
        default_factory=lambda: dict(_DEFAULT_JOINT_PARAMS_DICT)
    )

    # internal usage
    # gets populated by `parse_system`
    inertia: Inertia = Inertia.zero()
    # gets populated by `forward_kinematics`
    transform2: Transform = Transform.zero()
    transform: Transform = Transform.zero()
|
357
|
+
|
358
|
+
|
359
|
+
@struct.dataclass
class MaxCoordOMC(_Base):
    """Per-link marker description; the name suggests optical-motion-capture
    (OMC) data in maximal coordinates. NOTE(review): semantics are not visible
    in this file -- confirm against the OMC loading code.

    `System.omc` holds one optional entry of this type per link (`System.parse`
    checks that n_links == len(sys.omc)).
    """

    # static (non-pytree) field; presumably names the coordinate system the
    # marker data refers to -- TODO confirm
    coordinate_system_name: str = struct.field(False)
    # static (non-pytree) field; presumably selects which position marker to use
    pos_marker_number: int = struct.field(False)
    # pytree leaf; presumably a constant positional offset of the marker
    pos_marker_constant_offset: jax.Array
|
364
|
+
|
365
|
+
|
366
|
+
# Number of entries each joint type contributes to the coordinate vector `q`
# (`System.q_size` is the sum over all links).
Q_WIDTHS = dict(
    free=7,
    free_2d=3,
    frozen=0,
    spherical=4,
    p3d=3,
    # center of rotation, a `free` joint and then a `p3d` joint with custom
    # parameter fields in `RMCG_Config`
    cor=10,
    px=1,
    py=1,
    pz=1,
    rx=1,
    ry=1,
    rz=1,
    saddle=2,
)
# Number of entries each joint type contributes to the velocity vector `qd`
# (`System.qd_size` is the sum over all links).
QD_WIDTHS = dict(
    free=6,
    free_2d=3,
    frozen=0,
    spherical=3,
    p3d=3,
    cor=9,
    px=1,
    py=1,
    pz=1,
    rx=1,
    ry=1,
    rz=1,
    saddle=2,
)
|
398
|
+
|
399
|
+
|
400
|
+
@struct.dataclass
class System(_Base):
    """Tree of links (joints) plus geometries; the central model object.

    Per-link metadata (`link_parents`, `link_types`, `link_names`, ...) is static
    (non-pytree) data. `link_damping`, `link_armature`, `link_spring_stiffness`
    and `link_spring_zeropoint` are flat arrays that `scan` splits per link
    ('d' = per-qd ranges, 'q' = per-q ranges). A parent index of -1 denotes the
    worldbody. A `System` must be `parse`d before use.
    """

    link_parents: list[int] = struct.field(False)
    links: Link
    link_types: list[str] = struct.field(False)
    link_damping: jax.Array
    link_armature: jax.Array
    link_spring_stiffness: jax.Array
    link_spring_zeropoint: jax.Array
    # simulation timestep size
    dt: float = struct.field(False)
    # geometries in the system
    geoms: list[Geometry]
    # root / base acceleration offset
    gravity: jax.Array = struct.field(default_factory=lambda: jnp.array([0, 0, -9.81]))

    integration_method: str = struct.field(
        False, default_factory=lambda: "semi_implicit_euler"
    )
    mass_mat_iters: int = struct.field(False, default_factory=lambda: 0)

    link_names: list[str] = struct.field(False, default_factory=lambda: [])

    model_name: Optional[str] = struct.field(False, default_factory=lambda: None)

    omc: list[MaxCoordOMC | None] = struct.field(True, default_factory=lambda: [])

    def num_links(self) -> int:
        """Number of links in the system."""
        return len(self.link_parents)

    def q_size(self) -> int:
        """Length of the coordinate vector `q` (sum of `Q_WIDTHS` over links)."""
        return sum(Q_WIDTHS[typ] for typ in self.link_types)

    def qd_size(self) -> int:
        """Length of the velocity vector `qd` (sum of `QD_WIDTHS` over links)."""
        return sum(QD_WIDTHS[typ] for typ in self.link_types)

    def name_to_idx(self, name: str) -> int:
        """Index of the link called `name`; raises `ValueError` if absent."""
        return self.link_names.index(name)

    def idx_to_name(self, idx: int, allow_world: bool = False) -> str:
        """Name of link `idx`; index -1 maps to "world" if `allow_world`."""
        if allow_world and idx == -1:
            return "world"
        assert idx >= 0, "Worldbody index has no name."
        return self.link_names[idx]

    def idx_map(self, type: str) -> dict:
        """Map each link name to its index entry of kind `type`.

        `type` is either `l` (link), `q` (coordinate) or `d` (velocity). The values
        are whatever the scan's index map returns for that link.
        """
        dict_int_slices = {}

        def f(_, idx_map, name: str, link_idx: int):
            dict_int_slices[name] = idx_map[type](link_idx)

        self.scan(f, "ll", self.link_names, list(range(self.num_links())))

        return dict_int_slices

    def parent_name(self, name: str) -> str:
        """Name of the parent of link `name` (fails for links attached to world)."""
        return self.idx_to_name(self.link_parents[self.name_to_idx(name)])

    def add_prefix(self, prefix: str = "") -> "System":
        """Return a copy with `prefix` prepended to every link name."""
        return self.replace(link_names=[prefix + name for name in self.link_names])

    def change_model_name(
        self,
        new_name: Optional[str] = None,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
    ) -> "System":
        """Return a copy whose model name is `prefix + new_name + suffix`.

        `new_name` defaults to the current `model_name`.
        """
        if prefix is None:
            prefix = ""
        if suffix is None:
            suffix = ""
        if new_name is None:
            new_name = self.model_name
        name = prefix + new_name + suffix
        return self.replace(model_name=name)

    def change_link_name(self, old_name: str, new_name: str) -> "System":
        """Return a copy in which link `old_name` is renamed to `new_name`."""
        old_idx = self.name_to_idx(old_name)
        new_link_names = self.link_names.copy()
        new_link_names[old_idx] = new_name
        return self.replace(link_names=new_link_names)

    def add_prefix_suffix(
        self, prefix: Optional[str] = None, suffix: Optional[str] = None
    ) -> "System":
        """Return a copy with `prefix`/`suffix` added to every link name."""
        if prefix is None:
            prefix = ""
        if suffix is None:
            suffix = ""
        new_link_names = [prefix + name + suffix for name in self.link_names]
        return self.replace(link_names=new_link_names)

    @staticmethod
    def deep_equal(a, b):
        """Recursive structural equality for systems, dicts, sequences and arrays."""
        if type(a) is not type(b):
            return False
        if isinstance(a, _Base):
            return System.deep_equal(a.__dict__, b.__dict__)
        if isinstance(a, dict):
            if a.keys() != b.keys():
                return False
            return all(System.deep_equal(a[k], b[k]) for k in a.keys())
        if isinstance(a, (list, tuple)):
            if len(a) != len(b):
                return False
            return all(System.deep_equal(a[i], b[i]) for i in range(len(a)))
        if isinstance(a, (np.ndarray, jnp.ndarray, jax.Array)):
            return jnp.array_equal(a, b)
        return a == b

    def _replace_free_with_cor(self) -> "System":
        """Replace every `free` joint by a `cor` joint (free + p3d stacked).

        Raises:
            InvalidSystemError: unless all free joints connect to the worldbody
                and all joints connecting to the worldbody are free.
        """
        # check that
        # - all free joints connect to -1
        # - all joints connecting to -1 are free joints
        for i, p in enumerate(self.link_parents):
            link_type = self.link_types[i]
            if (p == -1 and link_type != "free") or (link_type == "free" and p != -1):
                raise InvalidSystemError(
                    f"link={self.idx_to_name(i)}, parent="
                    f"{self.idx_to_name(p, allow_world=True)},"
                    f" joint={link_type}. Hint: Try setting `config.cor` to false."
                )

        def logic_replace_free_with_cor(name, olt, ola, old, ols, olz):
            # by default new is equal to old
            nlt, nla, nld, nls, nlz = olt, ola, old, ols, olz

            # old link type == free
            if olt == "free":
                # cor joint is (free, p3d) stacked
                nlt = "cor"
                # entries of old armature are 3*ang (spherical), 3*pos (p3d)
                nla = jnp.concatenate((ola, ola[3:]))
                nld = jnp.concatenate((old, old[3:]))
                nls = jnp.concatenate((ols, ols[3:]))
                # zeropoint is 4 (quaternion) + 3 (pos); append the pos part
                nlz = jnp.concatenate((olz, olz[4:]))

            return nlt, nla, nld, nls, nlz

        return _update_sys_if_replace_joint_type(self, logic_replace_free_with_cor)

    def freeze(self, name: str | list[str]):
        """Return a copy where link(s) `name` have joint type `frozen`."""
        if isinstance(name, list):
            sys = self
            for n in name:
                sys = sys.freeze(n)
            return sys

        def logic_freeze(link_name, olt, ola, old, ols, olz):
            nlt, nla, nld, nls, nlz = olt, ola, old, ols, olz

            if link_name == name:
                nlt = "frozen"
                nla = nld = nls = nlz = jnp.array([])

            return nlt, nla, nld, nls, nlz

        return _update_sys_if_replace_joint_type(self, logic_freeze)

    def unfreeze(self, name: str, new_joint_type: str):
        """Give the frozen link `name` the (non-frozen) joint type `new_joint_type`."""
        assert self.link_types[self.name_to_idx(name)] == "frozen"
        assert new_joint_type != "frozen"

        return self.change_joint_type(name, new_joint_type)

    def change_joint_type(
        self,
        name: str,
        new_joint_type: str,
        new_arma: Optional[jax.Array] = None,
        new_damp: Optional[jax.Array] = None,
        new_stif: Optional[jax.Array] = None,
        new_zero: Optional[jax.Array] = None,
    ):
        """Change the joint type of link `name`.

        By default armature, damping, stiffness and zeropoint are set to zero.
        For `spherical`, `free` and `cor` joints the default zeropoint has its
        first entry set to 1 (unit quaternion).
        """
        q_size, qd_size = Q_WIDTHS[new_joint_type], QD_WIDTHS[new_joint_type]

        def logic_change_joint_type(link_name, olt, ola, old, ols, olz):
            nlt, nla, nld, nls, nlz = olt, ola, old, ols, olz

            if link_name == name:
                nlt = new_joint_type
                q_zeros = jnp.zeros((q_size,))
                qd_zeros = jnp.zeros((qd_size,))

                nla = qd_zeros if new_arma is None else new_arma
                nld = qd_zeros if new_damp is None else new_damp
                nls = qd_zeros if new_stif is None else new_stif
                nlz = q_zeros if new_zero is None else new_zero

                # unit quaternion
                if new_joint_type in ["spherical", "free", "cor"] and new_zero is None:
                    nlz = nlz.at[0].set(1.0)

            return nlt, nla, nld, nls, nlz

        return _update_sys_if_replace_joint_type(self, logic_change_joint_type)

    def findall_imus(self) -> list[str]:
        """Names of all links whose name starts with "imu"."""
        return [name for name in self.link_names if name.startswith("imu")]

    def findall_segments(self) -> list[str]:
        """Names of all links that are not IMUs (see `findall_imus`)."""
        imus = set(self.findall_imus())
        return [name for name in self.link_names if name not in imus]

    def _bodies_indices_to_bodies_name(self, bodies: list[int]) -> list[str]:
        return [self.idx_to_name(i) for i in bodies]

    def findall_bodies_to_world(self, names: bool = False) -> list[int] | list[str]:
        """Links whose parent is the worldbody (indices, or names if `names`)."""
        bodies = [i for i, p in enumerate(self.link_parents) if p == -1]
        return self._bodies_indices_to_bodies_name(bodies) if names else bodies

    def find_body_to_world(self, name: bool = False) -> int | str:
        """The unique link attached to the worldbody (asserts there is one)."""
        bodies = self.findall_bodies_to_world(names=name)
        assert len(bodies) == 1
        return bodies[0]

    def findall_bodies_with_jointtype(
        self, typ: str, names: bool = False
    ) -> list[int] | list[str]:
        """Links with joint type `typ` (indices, or names if `names`)."""
        bodies = [i for i, _typ in enumerate(self.link_types) if _typ == typ]
        return self._bodies_indices_to_bodies_name(bodies) if names else bodies

    def scan(self, f: Callable, in_types: str, *args, reverse: bool = False):
        """Scan `f` along each link in system whilst carrying along state.

        Args:
            f (Callable[..., Y]): f(y: Y, *args) -> y
            in_types: string specifying the type of each input arg:
                'l' is an input to be split according to link ranges
                'q' is an input to be split according to q ranges
                'd' is an input to be split according to qd ranges
            args: Arguments passed to `f`, and split to match the link.
            reverse (bool, optional): If `true` from leaves to root. Defaults to False.

        Returns:
            ys: Stacked output y of f.
        """
        return _scan_sys(self, f, in_types, *args, reverse=reverse)

    def parse(self) -> "System":
        """Initial setup of system. System object does not work unless it is parsed.
        Currently it does:
        - some consistency checks
        - populate the spatial inertia tensors
        - check that all names are unique
        - check that names are strings
        - check that all pos_min <= pos_max (unless traced)
        - order geoms in ascending order based on their parent link idx
        - check that all links have the correct size of
            - damping
            - armature
            - stiffness
            - zeropoint
        - check that n_links == len(sys.omc)
        """
        return _parse_system(self)

    def render(
        self,
        xs: Optional[Transform | list[Transform]] = None,
        camera: Optional[str] = None,
        show_pbar: bool = True,
        backend: str = "mujoco",
        render_every_nth: int = 1,
        **scene_kwargs,
    ) -> list[np.ndarray]:
        """Render frames from this system and trajectory of maximal coordinates `xs`.

        Delegates to `ring.rendering.render`.

        Args:
            xs (base.Transform | list[base.Transform], optional): Single or
                time-series of maximal coordinates `xs`. Handling of `None` is
                up to `ring.rendering.render`.
            camera (str, optional): Camera to render from; forwarded as-is.
            show_pbar (bool, optional): Whether or not to show a progress bar.
                Defaults to True.
            backend (str, optional): Rendering backend. Defaults to "mujoco".
            render_every_nth (int, optional): Forwarded as-is; presumably renders
                only every n-th frame.
            **scene_kwargs: Forwarded to the renderer.

        Returns:
            list[np.ndarray]: Stacked rendered frames.
        """
        return ring.rendering.render(
            self, xs, camera, show_pbar, backend, render_every_nth, **scene_kwargs
        )

    def render_prediction(
        self,
        xs: Transform | list[Transform],
        yhat: dict | jax.Array | np.ndarray,
        stepframe: int = 1,
        # by default we don't predict the global rotation
        transparent_segment_to_root: bool = True,
        **kwargs,
    ):
        """`xs` matches `sys`. `yhat` matches `sys_noimu`. `yhat` are child-to-parent."""
        return ring.rendering.render_prediction(
            self, xs, yhat, stepframe, transparent_segment_to_root, **kwargs
        )

    def delete_system(self, link_name: str | list[str], strict: bool = True):
        """Cut subsystem starting at `link_name` (inclusive) from tree."""
        return ring.sys_composer.delete_subsystem(self, link_name, strict)

    def make_sys_noimu(self, imu_link_names: Optional[list[str]] = None):
        """Delegate to `ring.sys_composer.make_sys_noimu`.

        Per the original note, the returned data includes e.g.
        imu_attachment = {'imu1': 'seg1', 'imu2': 'seg3'}.
        """
        return ring.sys_composer.make_sys_noimu(self, imu_link_names)

    def inject_system(self, other_system: "System", at_body: Optional[str] = None):
        """Combine two systems into one.

        Args:
            other_system (base.System): Small system that will be included into
                this (large) system.
            at_body (Optional[str], optional): Into which body of this system the
                small system will be included. Defaults to `worldbody`.

        Returns:
            base.System: The combined system.
        """
        return ring.sys_composer.inject_system(self, other_system, at_body)

    def morph_system(
        self,
        new_parents: Optional[list[int | str]] = None,
        new_anchor: Optional[int | str] = None,
    ):
        """Re-orders the graph underlying the system. Returns a new system.

        Args:
            new_parents (list[int | str], optional): Let the i-th entry have value
                j. Then, after morphing, the link corresponding to the i-th link
                in the old system will have as parent the link corresponding to
                the j-th link in the old system.
            new_anchor (int | str, optional): Forwarded to
                `ring.sys_composer.morph_system`; presumably the link that becomes
                attached to the worldbody.

        Returns:
            base.System: Modified system.
        """
        return ring.sys_composer.morph_system(self, new_parents, new_anchor)

    @staticmethod
    def from_xml(path: str, seed: int = 1):
        """Load a system from an XML file."""
        return ring.io.load_sys_from_xml(path, seed)

    @staticmethod
    def from_str(xml: str, seed: int = 1):
        """Load a system from an XML string."""
        return ring.io.load_sys_from_str(xml, seed)

    def to_str(self) -> str:
        """Serialize this system to an XML string."""
        return ring.io.save_sys_to_str(self)

    def to_xml(self, path: str) -> None:
        """Write this system as XML to `path`."""
        ring.io.save_sys_to_xml(self, path)

    @classmethod
    def create(cls, path_or_str: str, seed: int = 1) -> "System":
        """Load a system from an XML file path or from an XML string.

        `path_or_str` is first interpreted as a file path with its suffix replaced
        by `.xml`. If that file exists it is loaded; otherwise `path_or_str` itself
        is parsed as an XML string. `seed` is forwarded in both cases; previously
        it was silently dropped for XML strings.
        """
        path = Path(path_or_str).with_suffix(".xml")

        exists = False
        try:
            exists = path.exists()
        except OSError:
            # e.g. `path_or_str` is a long XML string and the resulting
            # "file name" is too long for the OS
            pass

        if exists:
            return cls.from_xml(path, seed=seed)
        return cls.from_str(path_or_str, seed=seed)

    def coordinate_vector_to_q(
        self,
        q: jax.Array,
        custom_joints: Optional[dict[str, Callable]] = None,
    ) -> jax.Array:
        """Map a coordinate vector `q` to the minimal coordinates vector of the sys.

        Each link's slice of `q` is passed through its joint model's
        `coordinate_vector_to_q`. Per the original notes this e.g. normalizes
        quaternions and maps hinge joints into [-pi, pi].

        Args:
            q: Coordinate vector of length `q_size()`.
            custom_joints: Optional mapping from joint type to a conversion
                function; it takes priority over the joint model's own function.

        Raises:
            NotImplementedError: if no conversion function exists for a link's
                joint type.
        """
        # `None` instead of a shared mutable `{}` default
        if custom_joints is None:
            custom_joints = {}

        q_preproc = []

        def preprocess(_, __, link_type, q):
            to_q = ring.algorithms.jcalc.get_joint_model(
                link_type
            ).coordinate_vector_to_q
            # function in custom_joints has priority over JointModel
            if link_type in custom_joints:
                to_q = custom_joints[link_type]
            if to_q is None:
                raise NotImplementedError(
                    f"Please specify the custom joint `{link_type}`"
                    " either using the `custom_joints` arguments or using the"
                    " JointModel.coordinate_vector_to_q field."
                )
            new_q = to_q(q)
            q_preproc.append(new_q)

        self.scan(preprocess, "lq", self.link_types, q)
        return jnp.concatenate(q_preproc)
|
799
|
+
|
800
|
+
|
801
|
+
def _update_sys_if_replace_joint_type(sys: System, logic) -> System:
    """Rebuild the per-link joint properties of `sys` using `logic`.

    `logic(name, link_type, armature, damping, stiffness, zeropoint)` is called
    once per link and must return the 5-tuple
    `(link_type, armature, damping, stiffness, zeropoint)` for that link. The
    new per-link arrays are concatenated, written back into the system, and the
    result is re-parsed.
    """
    new_types = []
    # buckets in order: armature, damping, spring stiffness, spring zeropoint
    new_arrays = ([], [], [], [])

    def collect(_, __, name, typ, arma, damp, stiff, zero):
        typ_new, arma_new, damp_new, stiff_new, zero_new = logic(
            name, typ, arma, damp, stiff, zero
        )
        new_types.append(typ_new)
        for bucket, value in zip(new_arrays, (arma_new, damp_new, stiff_new, zero_new)):
            bucket.append(value)

    sys.scan(
        collect,
        "lldddq",
        sys.link_names,
        sys.link_types,
        sys.link_armature,
        sys.link_damping,
        sys.link_spring_stiffness,
        sys.link_spring_zeropoint,
    )

    # link types stay a plain list of strings; only the arrays are concatenated
    armature, damping, stiffness, zeropoint = (
        jnp.concatenate(bucket) for bucket in new_arrays
    )

    updated = sys.replace(
        link_types=new_types,
        link_armature=armature,
        link_damping=damping,
        link_spring_stiffness=stiffness,
        link_spring_zeropoint=zeropoint,
    )

    # re-parsing validates that every joint type has correctly sized
    # damping / stiffness / zeropoint / armature entries
    return updated.parse()
|
838
|
+
|
839
|
+
|
840
|
+
class InvalidSystemError(Exception):
    """Exception type signalling an invalid system.

    NOTE(review): not raised within this section of the file; `_parse_system`
    currently validates via `assert` statements.
    """
|
842
|
+
|
843
|
+
|
844
|
+
def _parse_system(sys: System) -> System:
    """Validate `sys` and return a normalized copy.

    Checks (all via `assert`):
      - link parents, link types and the `links` batch have equal length, and
        there is one `omc` entry per link;
      - link names are unique strings;
      - `pos_max >= pos_min` element-wise (skipped when the values are traced
        or when the installed jax lacks `TracerBoolConversionError`);
      - every geometry is attached to a valid link index or to -1;
      - per link, damping / armature / spring stiffness have `QD_WIDTHS[typ]`
        entries and the spring zeropoint has `Q_WIDTHS[typ]` entries;
      - for spherical / free / cor joints, the first four zeropoint entries
        have unit norm (skipped for traced values).

    Normalization: link inertias are recomputed from the geometries, and the
    geometries are sorted by ascending `link_idx`.

    Returns:
        System: The validated system with recomputed inertia and sorted geoms.
    """
    assert len(sys.link_parents) == len(sys.link_types) == sys.links.batch_dim()
    assert len(sys.omc) == sys.num_links()

    for name in sys.link_names:
        assert sys.link_names.count(name) == 1, f"Duplicated name=`{name}` in system"
        assert isinstance(name, str), f"link name `{name}` is not a string"

    pos_min, pos_max = sys.links.pos_min, sys.links.pos_max

    try:
        from jax.errors import TracerBoolConversionError
    except ImportError:
        # on older versions of jax this import is not possible; skip the check
        TracerBoolConversionError = None

    if TracerBoolConversionError is not None:
        try:
            assert jnp.all(pos_max >= pos_min), f"min={pos_min}, max={pos_max}"
        except TracerBoolConversionError:
            # bounds are abstract tracers here and cannot be checked concretely
            pass

    # build the allowed indices once instead of once per geometry
    valid_link_idxs = list(range(sys.num_links())) + [-1]
    for geom in sys.geoms:
        assert geom.link_idx in valid_link_idxs

    inertia = _parse_system_calculate_inertia(sys)
    sys = sys.replace(links=sys.links.replace(inertia=inertia))

    # sort geoms in ascending order of `link_idx` (stable sort)
    geoms = sorted(sys.geoms, key=lambda geom: geom.link_idx)
    sys = sys.replace(geoms=geoms)

    # check sizes of damping / arma / stiff / zeropoint
    def check_dasz_unitq(_, __, name, typ, d, a, s, z):
        q_size, qd_size = Q_WIDTHS[typ], QD_WIDTHS[typ]

        error_msg = (
            f"wrong size for link `{name}` of typ `{typ}` in model {sys.model_name}"
        )

        assert d.size == a.size == s.size == qd_size, error_msg
        assert z.size == q_size, error_msg

        if typ in ["spherical", "free", "cor"] and not isinstance(z, Tracer):
            # Parenthesize the message: previously its second half was a
            # separate (dead) expression statement, so the model name was lost.
            assert jnp.allclose(jnp.linalg.norm(z[:4]), 1.0), (
                f"not unit quat for link `{name}` of typ `{typ}` in model"
                f" {sys.model_name}"
            )

    sys.scan(
        check_dasz_unitq,
        "lldddq",
        sys.link_names,
        sys.link_types,
        sys.link_damping,
        sys.link_armature,
        sys.link_spring_stiffness,
        sys.link_spring_zeropoint,
    )

    return sys
|
908
|
+
|
909
|
+
|
910
|
+
def _inertia_from_geometries(geometries: list[Geometry]) -> Inertia:
    """Sum the inertia contributions of all `geometries`.

    Each geometry contributes `Inertia.create(mass, transform, 3x3 inertia)`;
    an empty list yields `Inertia.zero()`.
    """
    total = Inertia.zero()
    for g in geometries:
        contribution = Inertia.create(g.mass, g.transform, g.get_it_3x3())
        total += contribution
    return total
|
915
|
+
|
916
|
+
|
917
|
+
def _parse_system_calculate_inertia(sys: System):
    """Compute each link's inertia from its geometries, batched over links.

    A link without attached geometries gets the result of
    `_inertia_from_geometries([])`.
    """

    def inertia_of_link(_, __, link_idx: int):
        attached = [geom for geom in sys.geoms if geom.link_idx == link_idx]
        return _inertia_from_geometries(attached)

    link_indices = list(range(sys.num_links()))
    return sys.scan(inertia_of_link, "l", link_indices)
|
928
|
+
|
929
|
+
|
930
|
+
def _scan_sys(sys: System, f: Callable, in_types: str, *args, reverse: bool = False):
    """Scan `f` over the links of `sys`, threading the previous output as carry.

    Args:
        sys (System): System whose links are iterated.
        f (Callable): Called once per link as `f(carry, idx_map, *args_link)`.
            `carry` is `None` for the first visited link and otherwise the
            value `f` returned for the previously visited link. `idx_map`
            maps "l" / "q" / "d" to functions from a link index to the index
            or slice selecting that link's entries.
        in_types (str): One character per entry of `args`: "l" (one entry
            per link), "q" (length `sys.q_size()`) or "d" (length
            `sys.qd_size()`).
        *args: Sequences sliced per link according to `in_types`.
        reverse (bool, optional): Visit links from last to first. Outputs are
            still returned in link-index order. Defaults to False.

    Returns:
        The per-link outputs of `f`, stacked via `tu.tree_batch(backend="jax")`.

    Raises:
        ValueError: If `in_types` contains a character other than "l", "q"
            or "d" (a subclass of the previously raised `Exception`).
    """
    assert len(args) == len(in_types), (
        f"got {len(args)} args but {len(in_types)} in_types"
    )
    for in_type, arg in zip(in_types, args):
        B = len(arg)
        if in_type == "l":
            assert B == sys.num_links(), f"expected {sys.num_links()} links, got {B}"
        elif in_type == "q":
            assert B == sys.q_size(), f"expected q_size={sys.q_size()}, got {B}"
        elif in_type == "d":
            assert B == sys.qd_size(), f"expected qd_size={sys.qd_size()}, got {B}"
        else:
            raise ValueError(
                "`in_types` must be one of `l` or `q` or `d`, got "
                f"`{in_type}`"
            )

    # build maps link-idx -> slice into q, and link-idx -> slice into qd
    q_idxs, qd_idxs = {}, {}
    q_idx, qd_idx = 0, 0
    for link_idx, link_type in zip(range(sys.num_links()), sys.link_types):
        q_width, qd_width = Q_WIDTHS[link_type], QD_WIDTHS[link_type]
        q_idxs[link_idx] = slice(q_idx, q_idx + q_width)
        qd_idxs[link_idx] = slice(qd_idx, qd_idx + qd_width)
        q_idx += q_width
        qd_idx += qd_width

    idx_map = {
        "l": lambda link_idx: link_idx,
        "q": lambda link_idx: q_idxs[link_idx],
        "d": lambda link_idx: qd_idxs[link_idx],
    }

    if reverse:
        order = range(sys.num_links() - 1, -1, -1)
    else:
        order = range(sys.num_links())

    y, ys = None, []
    for link_idx in order:
        args_link = [arg[idx_map[t](link_idx)] for arg, t in zip(args, in_types)]
        y = f(y, idx_map, *args_link)
        ys.append(y)

    # keep outputs aligned with link indices regardless of visiting order
    if reverse:
        ys.reverse()

    return tu.tree_batch(ys, backend="jax")
|
975
|
+
|
976
|
+
|
977
|
+
@struct.dataclass
class State(_Base):
    """The static and dynamic state of a system in minimal and maximal coordinates.
    Use `.create()` to create this object.

    Args:
        q (jax.Array): System state in minimal coordinates (length `sys.q_size()`)
        qd (jax.Array): System velocity in minimal coordinates (length `sys.qd_size()`)
        x (Transform): Maximal coordinates of all links. From epsilon-to-link.
        mass_mat_inv (jax.Array): Inverse of the mass matrix. Internal usage.
    """

    q: jax.Array
    qd: jax.Array
    x: Transform
    mass_mat_inv: jax.Array

    @classmethod
    def create(
        cls,
        sys: System,
        q: Optional[jax.Array] = None,
        qd: Optional[jax.Array] = None,
        x: Optional[Transform] = None,
        key: Optional[jax.Array] = None,
        custom_joints: Optional[dict[str, Callable]] = None,
    ):
        """Create state of system.

        Args:
            sys (System): The system for which to create a state.
            q (jax.Array, optional): The joint values of the system. Defaults to
                None, which then defaults to zeros, except that the first entry
                of every free / cor / spherical joint is set to 1 so that these
                joints hold unit quaternions.
            qd (jax.Array, optional): The joint velocities of the system.
                Defaults to None, which then defaults to zeros.
            x (Transform, optional): Maximal coordinates of all links. Defaults
                to None, which then defaults to
                `Transform.zero((sys.num_links(),))`.
            key (jax.Array, optional): If given, `q` is sampled from a standard
                normal distribution and mapped through
                `sys.coordinate_vector_to_q`. Must not be combined with `q`.
            custom_joints (dict[str, Callable], optional): Forwarded to
                `sys.coordinate_vector_to_q` when `key` is given. Defaults to
                None (no overrides).

        Returns:
            (State): Created State object.
        """
        # avoid a shared mutable default argument
        if custom_joints is None:
            custom_joints = {}

        if key is not None:
            assert q is None, "Only one of `q` and `key` may be given"
            q = jax.random.normal(key, shape=(sys.q_size(),))
            q = sys.coordinate_vector_to_q(q, custom_joints)
        elif q is None:
            q = jnp.zeros((sys.q_size(),))

            # free, cor, spherical joints are not zeros but have unit quaternions
            def replace_by_unit_quat(_, idx_map, link_typ, link_idx):
                nonlocal q

                if link_typ in ["free", "cor", "spherical"]:
                    q_idxs_link = idx_map["q"](link_idx)
                    q = q.at[q_idxs_link.start].set(1.0)

            sys.scan(
                replace_by_unit_quat,
                "ll",
                sys.link_types,
                list(range(sys.num_links())),
            )

        if qd is None:
            qd = jnp.zeros((sys.qd_size(),))

        if x is None:
            x = Transform.zero((sys.num_links(),))

        # the inverse mass matrix starts out as the identity
        return cls(q, qd, x, jnp.diag(jnp.ones((sys.qd_size(),))))
|