jaxsim 0.2.dev188__py3-none-any.whl → 0.2.dev364__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +3 -4
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +13 -2
- jaxsim/api/contact.py +120 -43
- jaxsim/api/data.py +112 -71
- jaxsim/api/joint.py +77 -36
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +150 -75
- jaxsim/api/model.py +542 -269
- jaxsim/api/ode.py +88 -72
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +12 -11
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +110 -24
- jaxsim/integrators/fixed_step.py +11 -67
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +93 -0
- jaxsim/parsers/descriptions/collision.py +14 -0
- jaxsim/parsers/descriptions/link.py +13 -2
- jaxsim/parsers/kinematic_graph.py +5 -0
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/{physics/algos → terrain}/terrain.py +4 -6
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -30
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
- jaxsim-0.2.dev364.dist-info/RECORD +64 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev188.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
jaxsim/api/link.py
CHANGED
@@ -4,20 +4,21 @@ from typing import Sequence
|
|
4
4
|
import jax
|
5
5
|
import jax.numpy as jnp
|
6
6
|
import jaxlie
|
7
|
+
import numpy as np
|
7
8
|
|
8
|
-
import jaxsim.
|
9
|
+
import jaxsim.api as js
|
10
|
+
import jaxsim.rbda
|
9
11
|
import jaxsim.typing as jtp
|
10
|
-
from jaxsim.high_level.common import VelRepr
|
11
12
|
|
12
|
-
from . import
|
13
|
-
from . import model as Model
|
13
|
+
from .common import VelRepr
|
14
14
|
|
15
15
|
# =======================
|
16
16
|
# Index-related functions
|
17
17
|
# =======================
|
18
18
|
|
19
19
|
|
20
|
-
|
20
|
+
@functools.partial(jax.jit, static_argnames="link_name")
|
21
|
+
def name_to_idx(model: js.model.JaxSimModel, *, link_name: str) -> jtp.Int:
|
21
22
|
"""
|
22
23
|
Convert the name of a link to its index.
|
23
24
|
|
@@ -29,12 +30,18 @@ def name_to_idx(model: Model.JaxSimModel, *, link_name: str) -> jtp.Int:
|
|
29
30
|
The index of the link.
|
30
31
|
"""
|
31
32
|
|
32
|
-
|
33
|
-
|
34
|
-
|
33
|
+
if link_name in model.kin_dyn_parameters.link_names:
|
34
|
+
return (
|
35
|
+
jnp.array(
|
36
|
+
np.argwhere(np.array(model.kin_dyn_parameters.link_names) == link_name)
|
37
|
+
)
|
38
|
+
.squeeze()
|
39
|
+
.astype(int)
|
40
|
+
)
|
41
|
+
return jnp.array(-1).astype(int)
|
35
42
|
|
36
43
|
|
37
|
-
def idx_to_name(model:
|
44
|
+
def idx_to_name(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
|
38
45
|
"""
|
39
46
|
Convert the index of a link to its name.
|
40
47
|
|
@@ -46,11 +53,13 @@ def idx_to_name(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
|
|
46
53
|
The name of the link.
|
47
54
|
"""
|
48
55
|
|
49
|
-
|
50
|
-
return d[link_index]
|
56
|
+
return model.kin_dyn_parameters.link_names[link_index]
|
51
57
|
|
52
58
|
|
53
|
-
|
59
|
+
@functools.partial(jax.jit, static_argnames="link_names")
|
60
|
+
def names_to_idxs(
|
61
|
+
model: js.model.JaxSimModel, *, link_names: Sequence[str]
|
62
|
+
) -> jax.Array:
|
54
63
|
"""
|
55
64
|
Convert a sequence of link names to their corresponding indices.
|
56
65
|
|
@@ -63,13 +72,12 @@ def names_to_idxs(model: Model.JaxSimModel, *, link_names: Sequence[str]) -> jax
|
|
63
72
|
"""
|
64
73
|
|
65
74
|
return jnp.array(
|
66
|
-
[model
|
67
|
-
|
68
|
-
)
|
75
|
+
[name_to_idx(model=model, link_name=name) for name in link_names],
|
76
|
+
).astype(int)
|
69
77
|
|
70
78
|
|
71
79
|
def idxs_to_names(
|
72
|
-
model:
|
80
|
+
model: js.model.JaxSimModel, *, link_indices: Sequence[jtp.IntLike] | jtp.VectorLike
|
73
81
|
) -> tuple[str, ...]:
|
74
82
|
"""
|
75
83
|
Convert a sequence of link indices to their corresponding names.
|
@@ -82,8 +90,7 @@ def idxs_to_names(
|
|
82
90
|
The names of the links.
|
83
91
|
"""
|
84
92
|
|
85
|
-
|
86
|
-
return tuple(d[i] for i in link_indices)
|
93
|
+
return tuple(idx_to_name(model=model, link_index=idx) for idx in link_indices)
|
87
94
|
|
88
95
|
|
89
96
|
# =========
|
@@ -91,21 +98,51 @@ def idxs_to_names(
|
|
91
98
|
# =========
|
92
99
|
|
93
100
|
|
94
|
-
|
95
|
-
|
101
|
+
@jax.jit
|
102
|
+
def mass(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Float:
|
103
|
+
"""
|
104
|
+
Return the mass of the link.
|
105
|
+
|
106
|
+
Args:
|
107
|
+
model: The model to consider.
|
108
|
+
link_index: The index of the link.
|
96
109
|
|
97
|
-
|
110
|
+
Returns:
|
111
|
+
The mass of the link.
|
112
|
+
"""
|
98
113
|
|
114
|
+
return model.kin_dyn_parameters.link_parameters.mass[link_index].astype(float)
|
99
115
|
|
100
|
-
def spatial_inertia(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Matrix:
|
101
|
-
""""""
|
102
116
|
|
103
|
-
|
117
|
+
@jax.jit
|
118
|
+
def spatial_inertia(
|
119
|
+
model: js.model.JaxSimModel, *, link_index: jtp.IntLike
|
120
|
+
) -> jtp.Matrix:
|
121
|
+
"""
|
122
|
+
Compute the 6D spatial inertial of the link.
|
123
|
+
|
124
|
+
Args:
|
125
|
+
model: The model to consider.
|
126
|
+
link_index: The index of the link.
|
127
|
+
|
128
|
+
Returns:
|
129
|
+
The 6×6 matrix representing the spatial inertia of the link expressed in
|
130
|
+
the link frame (body-fixed representation).
|
131
|
+
"""
|
132
|
+
|
133
|
+
link_parameters = jax.tree_util.tree_map(
|
134
|
+
lambda l: l[link_index], model.kin_dyn_parameters.link_parameters
|
135
|
+
)
|
136
|
+
|
137
|
+
return js.kin_dyn_parameters.LinkParameters.spatial_inertia(link_parameters)
|
104
138
|
|
105
139
|
|
106
140
|
@jax.jit
|
107
141
|
def transform(
|
108
|
-
model:
|
142
|
+
model: js.model.JaxSimModel,
|
143
|
+
data: js.data.JaxSimModelData,
|
144
|
+
*,
|
145
|
+
link_index: jtp.IntLike,
|
109
146
|
) -> jtp.Matrix:
|
110
147
|
"""
|
111
148
|
Compute the SE(3) transform from the world frame to the link frame.
|
@@ -119,13 +156,13 @@ def transform(
|
|
119
156
|
The 4x4 matrix representing the transform.
|
120
157
|
"""
|
121
158
|
|
122
|
-
return
|
159
|
+
return js.model.forward_kinematics(model=model, data=data)[link_index]
|
123
160
|
|
124
161
|
|
125
162
|
@jax.jit
|
126
163
|
def com_position(
|
127
|
-
model:
|
128
|
-
data:
|
164
|
+
model: js.model.JaxSimModel,
|
165
|
+
data: js.data.JaxSimModelData,
|
129
166
|
*,
|
130
167
|
link_index: jtp.IntLike,
|
131
168
|
in_link_frame: jtp.BoolLike = True,
|
@@ -168,8 +205,8 @@ def com_position(
|
|
168
205
|
|
169
206
|
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
|
170
207
|
def jacobian(
|
171
|
-
model:
|
172
|
-
data:
|
208
|
+
model: js.model.JaxSimModel,
|
209
|
+
data: js.data.JaxSimModelData,
|
173
210
|
*,
|
174
211
|
link_index: jtp.IntLike,
|
175
212
|
output_vel_repr: VelRepr | None = None,
|
@@ -185,78 +222,116 @@ def jacobian(
|
|
185
222
|
The output velocity representation of the free-floating jacobian.
|
186
223
|
|
187
224
|
Returns:
|
188
|
-
The
|
225
|
+
The 6×(6+n) free-floating jacobian of the link.
|
189
226
|
|
190
227
|
Note:
|
191
228
|
The input representation of the free-floating jacobian is the active
|
192
229
|
velocity representation.
|
193
230
|
"""
|
194
231
|
|
195
|
-
|
196
|
-
output_vel_repr
|
197
|
-
|
198
|
-
# Compute the doubly left-trivialized free-floating jacobian
|
199
|
-
L_J_WL_B = jaxsim.physics.algos.jacobian.jacobian(
|
200
|
-
model=model.physics_model,
|
201
|
-
body_index=link_index,
|
202
|
-
q=data.joint_positions(),
|
232
|
+
output_vel_repr = (
|
233
|
+
output_vel_repr if output_vel_repr is not None else data.velocity_representation
|
203
234
|
)
|
204
235
|
|
205
|
-
|
236
|
+
# Compute the doubly-left free-floating full jacobian.
|
237
|
+
B_J_full_WX_B, B_H_Li = jaxsim.rbda.jacobian_full_doubly_left(
|
238
|
+
model=model,
|
239
|
+
joint_positions=data.joint_positions(),
|
240
|
+
)
|
206
241
|
|
207
|
-
|
208
|
-
|
242
|
+
# Compute the actual doubly-left free-floating jacobian of the link.
|
243
|
+
κ = model.kin_dyn_parameters.support_body_array_bool[link_index]
|
244
|
+
B_J_WL_B = jnp.hstack([jnp.ones(5), κ]) * B_J_full_WX_B
|
209
245
|
|
246
|
+
# Adjust the input representation such that `J_WL_I @ I_ν`.
|
247
|
+
match data.velocity_representation:
|
210
248
|
case VelRepr.Inertial:
|
211
|
-
dofs = model.dofs()
|
212
249
|
W_H_B = data.base_transform()
|
213
|
-
|
214
250
|
B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
|
215
|
-
|
216
|
-
|
217
|
-
B_T_W = jnp.vstack(
|
218
|
-
[
|
219
|
-
jnp.block([B_X_W, zero_6n]),
|
220
|
-
jnp.block([zero_6n.T, jnp.eye(dofs)]),
|
221
|
-
]
|
251
|
+
B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag(
|
252
|
+
B_X_W, jnp.eye(model.dofs())
|
222
253
|
)
|
223
254
|
|
224
|
-
|
255
|
+
case VelRepr.Body:
|
256
|
+
B_J_WL_I = B_J_WL_B
|
225
257
|
|
226
258
|
case VelRepr.Mixed:
|
227
|
-
|
228
|
-
|
229
|
-
BW_H_B = jnp.array(W_H_B).at[0:3, 3].set(jnp.zeros(3))
|
230
|
-
|
259
|
+
W_R_B = data.base_orientation(dcm=True)
|
260
|
+
BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
|
231
261
|
B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
|
232
|
-
|
233
|
-
|
234
|
-
B_T_BW = jnp.vstack(
|
235
|
-
[
|
236
|
-
jnp.block([B_X_BW, zero_6n]),
|
237
|
-
jnp.block([zero_6n.T, jnp.eye(dofs)]),
|
238
|
-
]
|
262
|
+
B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag(
|
263
|
+
B_X_BW, jnp.eye(model.dofs())
|
239
264
|
)
|
240
265
|
|
241
|
-
L_J_WL_target = L_J_WL_B @ B_T_BW
|
242
|
-
|
243
266
|
case _:
|
244
267
|
raise ValueError(data.velocity_representation)
|
245
268
|
|
246
|
-
|
247
|
-
case VelRepr.Body:
|
248
|
-
return L_J_WL_target
|
269
|
+
B_H_L = B_H_Li[link_index]
|
249
270
|
|
271
|
+
# Adjust the output representation such that `O_v_WL_I = O_J_WL_I @ I_ν`.
|
272
|
+
match output_vel_repr:
|
250
273
|
case VelRepr.Inertial:
|
251
|
-
|
252
|
-
|
253
|
-
|
274
|
+
W_H_B = data.base_transform()
|
275
|
+
W_X_B = jaxlie.SE3.from_matrix(W_H_B).adjoint()
|
276
|
+
O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I
|
277
|
+
|
278
|
+
case VelRepr.Body:
|
279
|
+
L_X_B = jaxlie.SE3.from_matrix(B_H_L).inverse().adjoint()
|
280
|
+
L_J_WL_I = L_X_B @ B_J_WL_I
|
281
|
+
O_J_WL_I = L_J_WL_I
|
254
282
|
|
255
283
|
case VelRepr.Mixed:
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
284
|
+
W_H_B = data.base_transform()
|
285
|
+
W_H_L = W_H_B @ B_H_L
|
286
|
+
LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3))
|
287
|
+
LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
|
288
|
+
LW_X_B = jaxlie.SE3.from_matrix(LW_H_B).adjoint()
|
289
|
+
LW_J_WL_I = LW_X_B @ B_J_WL_I
|
290
|
+
O_J_WL_I = LW_J_WL_I
|
260
291
|
|
261
292
|
case _:
|
262
293
|
raise ValueError(output_vel_repr)
|
294
|
+
|
295
|
+
return O_J_WL_I
|
296
|
+
|
297
|
+
|
298
|
+
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
|
299
|
+
def velocity(
|
300
|
+
model: js.model.JaxSimModel,
|
301
|
+
data: js.data.JaxSimModelData,
|
302
|
+
*,
|
303
|
+
link_index: jtp.IntLike,
|
304
|
+
output_vel_repr: VelRepr | None = None,
|
305
|
+
) -> jtp.Vector:
|
306
|
+
"""
|
307
|
+
Compute the 6D velocity of the link.
|
308
|
+
|
309
|
+
Args:
|
310
|
+
model: The model to consider.
|
311
|
+
data: The data of the considered model.
|
312
|
+
link_index: The index of the link.
|
313
|
+
output_vel_repr:
|
314
|
+
The output velocity representation of the link velocity.
|
315
|
+
|
316
|
+
Returns:
|
317
|
+
The 6D velocity of the link in the specified velocity representation.
|
318
|
+
"""
|
319
|
+
|
320
|
+
output_vel_repr = (
|
321
|
+
output_vel_repr if output_vel_repr is not None else data.velocity_representation
|
322
|
+
)
|
323
|
+
|
324
|
+
# Get the link jacobian having I as input representation (taken from data)
|
325
|
+
# and O as output representation, specified by the user (or taken from data).
|
326
|
+
O_J_WL_I = jacobian(
|
327
|
+
model=model,
|
328
|
+
data=data,
|
329
|
+
link_index=link_index,
|
330
|
+
output_vel_repr=output_vel_repr,
|
331
|
+
)
|
332
|
+
|
333
|
+
# Get the generalized velocity in the input velocity representation.
|
334
|
+
I_ν = data.generalized_velocity()
|
335
|
+
|
336
|
+
# Compute the link velocity in the output velocity representation.
|
337
|
+
return O_J_WL_I @ I_ν
|