jaxsim 0.2.dev101__py3-none-any.whl → 0.2.dev166__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +1 -0
- jaxsim/api/contact.py +194 -0
- jaxsim/api/data.py +951 -0
- jaxsim/api/joint.py +148 -0
- jaxsim/api/link.py +262 -0
- jaxsim/api/model.py +1099 -0
- jaxsim/api/ode.py +280 -0
- jaxsim/integrators/__init__.py +2 -0
- jaxsim/integrators/common.py +508 -0
- jaxsim/integrators/fixed_step.py +158 -0
- jaxsim/mujoco/__init__.py +1 -1
- jaxsim/mujoco/loaders.py +30 -18
- jaxsim/mujoco/visualizer.py +3 -1
- jaxsim/physics/algos/soft_contacts.py +97 -28
- jaxsim/physics/model/physics_model.py +30 -0
- jaxsim/physics/model/physics_model_state.py +110 -11
- jaxsim/simulation/ode_data.py +43 -0
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/METADATA +2 -1
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/RECORD +23 -13
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/WHEEL +0 -0
- {jaxsim-0.2.dev101.dist-info → jaxsim-0.2.dev166.dist-info}/top_level.txt +0 -0
jaxsim/api/joint.py
ADDED
@@ -0,0 +1,148 @@
|
|
1
|
+
import functools
|
2
|
+
from typing import Sequence
|
3
|
+
|
4
|
+
import jax
|
5
|
+
import jax.numpy as jnp
|
6
|
+
|
7
|
+
import jaxsim.typing as jtp
|
8
|
+
|
9
|
+
from . import model as Model
|
10
|
+
|
11
|
+
# =======================
|
12
|
+
# Index-related functions
|
13
|
+
# =======================
|
14
|
+
|
15
|
+
|
16
|
+
def name_to_idx(model: Model.JaxSimModel, *, joint_name: str) -> jtp.Int:
    """
    Convert the name of a joint to its index.

    Args:
        model: The model to consider.
        joint_name: The name of the joint.

    Returns:
        The 0-based index of the joint. This is the same convention used by
        `names_to_idxs`, and the one expected by functions indexing per-joint
        quantities (e.g. `position_limit`).
    """

    # Note: the index of the joint for RBDAs starts from 1, but
    # the index for accessing the right element starts from 0.
    # Therefore, there is a -1.
    return jnp.array(
        model.physics_model.description.joints_dict[joint_name].index - 1, dtype=int
    )


def idx_to_name(model: Model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:
    """
    Convert the index of a joint to its name.

    Args:
        model: The model to consider.
        joint_index:
            The 0-based index of the joint, as returned by `name_to_idx`
            and `names_to_idxs`.

    Returns:
        The name of the joint.

    Raises:
        KeyError: If no joint has the given index.
    """

    # Same convention as `idxs_to_names`: RBDA joint indices start from 1,
    # while the exposed indices start from 0. Therefore, there is a -1.
    d = {
        j.index - 1: j.name
        for j in model.physics_model.description.joints_dict.values()
    }

    # JAX arrays are not hashable: convert to a Python int before the lookup.
    return d[int(joint_index)]
|
47
|
+
|
48
|
+
|
49
|
+
def names_to_idxs(model: Model.JaxSimModel, *, joint_names: Sequence[str]) -> jax.Array:
    """
    Map a sequence of joint names to their 0-based indices.

    Args:
        model: The model to consider.
        joint_names: The names of the joints.

    Returns:
        An integer array with the index of each joint, in the same order as
        `joint_names`.
    """

    joints = model.physics_model.description.joints_dict

    # RBDAs number the joints starting from 1, whereas the arrays storing
    # per-joint quantities are addressed starting from 0: shift by one.
    zero_based = [joints[joint_name].index - 1 for joint_name in joint_names]

    return jnp.array(zero_based, dtype=int)
|
71
|
+
|
72
|
+
|
73
|
+
def idxs_to_names(
    model: Model.JaxSimModel, *, joint_indices: Sequence[jtp.IntLike] | jtp.VectorLike
) -> tuple[str, ...]:
    """
    Convert a sequence of joint indices to their corresponding names.

    Args:
        model: The model to consider.
        joint_indices:
            The 0-based indices of the joints. Either a sequence of integers
            or a vector (e.g. a NumPy or JAX array).

    Returns:
        The names of the joints, in the same order as `joint_indices`.

    Raises:
        KeyError: If an index does not correspond to any joint.
    """

    # Note: the index of the joint for RBDAs starts from 1, but the
    # exposed indices start from 0. Therefore, there is a -1.
    d = {
        j.index - 1: j.name
        for j in model.physics_model.description.joints_dict.values()
    }

    # Iterating a JAX array yields unhashable 0-d arrays: convert every
    # element to a Python int before the dictionary lookup.
    return tuple(d[int(i)] for i in joint_indices)
|
93
|
+
|
94
|
+
|
95
|
+
# ============
|
96
|
+
# Joint limits
|
97
|
+
# ============
|
98
|
+
|
99
|
+
|
100
|
+
@jax.jit
def position_limit(
    model: Model.JaxSimModel, *, joint_index: jtp.IntLike
) -> tuple[jtp.Float, jtp.Float]:
    """
    Get the position limits of a joint.

    Args:
        model: The model to consider.
        joint_index:
            The 0-based index of the joint (e.g. an element of the array
            returned by `names_to_idxs`).

    Returns:
        A tuple `(min, max)` with the lower and upper position limits of
        the joint, cast to float.
    """

    # Use dedicated names to avoid shadowing the `min`/`max` builtins.
    s_min = model.physics_model._joint_position_limits_min[joint_index]
    s_max = model.physics_model._joint_position_limits_max[joint_index]

    return s_min.astype(float), s_max.astype(float)
|
110
|
+
|
111
|
+
|
112
|
+
@functools.partial(jax.jit, static_argnames=["joint_names"])
def position_limits(
    model: Model.JaxSimModel, *, joint_names: Sequence[str] | None = None
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Get the position limits of a set of joints.

    Args:
        model: The model to consider.
        joint_names:
            The names of the joints. If None, all the joints of the model are
            considered. Since this is a static argument of `jax.jit`, it must
            be hashable (e.g. a tuple rather than a list).

    Returns:
        A tuple `(min, max)` of vectors with the lower and upper position
        limits of the selected joints, ordered as `joint_names`.
    """

    joint_names = joint_names if joint_names is not None else model.joint_names()

    # Nothing to select (empty selection or model without joints): return
    # empty float vectors instead of vmapping/indexing over zero-size arrays.
    if len(joint_names) == 0:
        return jnp.empty(0).astype(float), jnp.empty(0).astype(float)

    joint_idxs = names_to_idxs(joint_names=joint_names, model=model)
    return jax.vmap(lambda i: position_limit(model=model, joint_index=i))(joint_idxs)
|
121
|
+
|
122
|
+
|
123
|
+
# ======================
|
124
|
+
# Random data generation
|
125
|
+
# ======================
|
126
|
+
|
127
|
+
|
128
|
+
@functools.partial(jax.jit, static_argnames=["joint_names"])
def random_joint_positions(
    model: Model.JaxSimModel,
    *,
    joint_names: Sequence[str] | None = None,
    key: jax.Array | None = None,
) -> jtp.Vector:
    """
    Sample random joint positions uniformly within the joint position limits.

    Args:
        model: The model to consider.
        joint_names:
            The names of the joints. If None, all the joints of the model are
            considered. Since this is a static argument of `jax.jit`, it must
            be hashable (e.g. a tuple rather than a list).
        key:
            The PRNG key used for sampling. If None, a fixed key created with
            `jax.random.PRNGKey(seed=0)` is used, therefore repeated calls
            without an explicit key return the same sample.

    Returns:
        The random joint positions, one per selected joint, each drawn
        uniformly between the corresponding lower and upper position limits.
    """

    key = key if key is not None else jax.random.PRNGKey(seed=0)

    s_min, s_max = position_limits(model=model, joint_names=joint_names)

    s_random = jax.random.uniform(
        minval=s_min,
        maxval=s_max,
        key=key,
        shape=s_min.shape,
    )

    return s_random
|
jaxsim/api/link.py
ADDED
@@ -0,0 +1,262 @@
|
|
1
|
+
import functools
|
2
|
+
from typing import Sequence
|
3
|
+
|
4
|
+
import jax
|
5
|
+
import jax.numpy as jnp
|
6
|
+
import jaxlie
|
7
|
+
|
8
|
+
import jaxsim.physics.algos.jacobian
|
9
|
+
import jaxsim.typing as jtp
|
10
|
+
from jaxsim.high_level.common import VelRepr
|
11
|
+
|
12
|
+
from . import data as Data
|
13
|
+
from . import model as Model
|
14
|
+
|
15
|
+
# =======================
|
16
|
+
# Index-related functions
|
17
|
+
# =======================
|
18
|
+
|
19
|
+
|
20
|
+
def name_to_idx(model: Model.JaxSimModel, *, link_name: str) -> jtp.Int:
    """
    Look up the index of a link from its name.

    Args:
        model: The model to consider.
        link_name: The name of the link.

    Returns:
        The index of the link, as an integer JAX array.
    """

    link = model.physics_model.description.links_dict[link_name]
    return jnp.array(link.index, dtype=int)
|
35
|
+
|
36
|
+
|
37
|
+
def idx_to_name(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
    """
    Convert the index of a link to its name.

    Args:
        model: The model to consider.
        link_index: The index of the link (a Python int or a scalar array).

    Returns:
        The name of the link.

    Raises:
        KeyError: If no link has the given index.
    """

    d = {
        link.index: link.name
        for link in model.physics_model.description.links_dict.values()
    }

    # JAX arrays are not hashable: convert to a Python int before the lookup.
    return d[int(link_index)]
|
51
|
+
|
52
|
+
|
53
|
+
def names_to_idxs(model: Model.JaxSimModel, *, link_names: Sequence[str]) -> jax.Array:
    """
    Map a sequence of link names to their indices.

    Args:
        model: The model to consider.
        link_names: The names of the links.

    Returns:
        An integer array with the index of each link, in the same order as
        `link_names`.
    """

    links = model.physics_model.description.links_dict
    indices = [links[link_name].index for link_name in link_names]

    return jnp.array(indices, dtype=int)
|
69
|
+
|
70
|
+
|
71
|
+
def idxs_to_names(
    model: Model.JaxSimModel, *, link_indices: Sequence[jtp.IntLike] | jtp.VectorLike
) -> tuple[str, ...]:
    """
    Convert a sequence of link indices to their corresponding names.

    Args:
        model: The model to consider.
        link_indices:
            The indices of the links. Either a sequence of integers or a
            vector (e.g. a NumPy or JAX array).

    Returns:
        The names of the links, in the same order as `link_indices`.

    Raises:
        KeyError: If an index does not correspond to any link.
    """

    d = {
        link.index: link.name
        for link in model.physics_model.description.links_dict.values()
    }

    # Iterating a JAX array yields unhashable 0-d arrays: convert every
    # element to a Python int before the dictionary lookup.
    return tuple(d[int(i)] for i in link_indices)
|
87
|
+
|
88
|
+
|
89
|
+
# =========
|
90
|
+
# Link APIs
|
91
|
+
# =========
|
92
|
+
|
93
|
+
|
94
|
+
def mass(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Float:
    """
    Get the mass of a link.

    Args:
        model: The model to consider.
        link_index: The index of the link.

    Returns:
        The mass of the link, cast to float.
    """

    return model.physics_model._link_masses[link_index].astype(float)
|
98
|
+
|
99
|
+
|
100
|
+
def spatial_inertia(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Matrix:
    """
    Get the spatial inertia matrix of a link.

    Args:
        model: The model to consider.
        link_index: The index of the link.

    Returns:
        The spatial inertia matrix of the link, as stored in the physics model.
        NOTE(review): presumably 6x6 and expressed in the link frame (this is
        how `com_position` interprets it); confirm against the physics model.
    """

    return model.physics_model._link_spatial_inertias[link_index]
|
104
|
+
|
105
|
+
|
106
|
+
@jax.jit
def transform(
    model: Model.JaxSimModel, data: Data.JaxSimModelData, *, link_index: jtp.IntLike
) -> jtp.Matrix:
    """
    Get the SE(3) transform from the world frame to the frame of a link.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_index: The index of the link.

    Returns:
        The 4x4 homogeneous matrix representing the transform.
    """

    # Forward kinematics yields the transforms of all the links; pick one.
    W_H_all_links = Model.forward_kinematics(model=model, data=data)
    W_H_L = W_H_all_links[link_index]

    return W_H_L
|
123
|
+
|
124
|
+
|
125
|
+
@jax.jit
def com_position(
    model: Model.JaxSimModel,
    data: Data.JaxSimModelData,
    *,
    link_index: jtp.IntLike,
    in_link_frame: jtp.BoolLike = True,
) -> jtp.Vector:
    """
    Get the position of the center of mass of a link.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_index: The index of the link.
        in_link_frame:
            If true, the position is expressed in the link frame; otherwise,
            it is expressed in the world frame.

    Returns:
        The 3D position of the center of mass of the link.
    """

    from jaxsim.math.inertia import Inertia

    # Extract the CoM position, in the link frame, from the spatial inertia.
    M = spatial_inertia(model=model, link_index=link_index)
    _, L_p_CoM, _ = Inertia.to_params(M=M)
    L_p_CoM = L_p_CoM.squeeze()

    # Express the same point in the world frame using homogeneous coordinates.
    W_H_L = transform(link_index=link_index, model=model, data=data)
    W_p_CoM_homogeneous = W_H_L @ jnp.hstack([L_p_CoM, 1])
    W_p_CoM = W_p_CoM_homogeneous[0:3].squeeze()

    # Both candidates are computed; the requested one is selected, which also
    # works when `in_link_frame` is a traced value under jit.
    return jax.lax.select(pred=in_link_frame, on_true=L_p_CoM, on_false=W_p_CoM)
|
167
|
+
|
168
|
+
|
169
|
+
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def jacobian(
    model: Model.JaxSimModel,
    data: Data.JaxSimModelData,
    *,
    link_index: jtp.IntLike,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
    """
    Compute the free-floating jacobian of a link.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_index: The index of the link.
        output_vel_repr:
            The output velocity representation of the free-floating jacobian.
            If None, the active velocity representation of `data` is used.

    Returns:
        The 6x(6+dofs) free-floating jacobian of the link.

    Raises:
        ValueError: If the active or the requested output velocity
            representation is not supported.

    Note:
        The input representation of the free-floating jacobian is the active
        velocity representation.
    """

    output_vel_repr = (
        data.velocity_representation if output_vel_repr is None else output_vel_repr
    )

    def lift_base_adjoint(B_X_F: jtp.Matrix) -> jtp.Matrix:
        # Embed a 6x6 base adjoint into a (6+dofs)x(6+dofs) block-diagonal
        # matrix that leaves the joint velocities untouched.
        n = model.dofs()
        zero_6n = jnp.zeros(shape=(6, n))
        return jnp.vstack(
            [
                jnp.block([B_X_F, zero_6n]),
                jnp.block([zero_6n.T, jnp.eye(n)]),
            ]
        )

    # Doubly left-trivialized free-floating jacobian (body-fixed input/output).
    L_J_WL_B = jaxsim.physics.algos.jacobian.jacobian(
        model=model.physics_model,
        body_index=link_index,
        q=data.joint_positions(),
    )

    # Convert the input side to the active velocity representation.
    input_vel_repr = data.velocity_representation

    if input_vel_repr == VelRepr.Body:
        L_J_WL_target = L_J_WL_B

    elif input_vel_repr == VelRepr.Inertial:
        W_H_B = data.base_transform()
        B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
        L_J_WL_target = L_J_WL_B @ lift_base_adjoint(B_X_W)

    elif input_vel_repr == VelRepr.Mixed:
        W_H_B = data.base_transform()
        # Mixed frame: base orientation, world-frame origin removed.
        BW_H_B = jnp.array(W_H_B).at[0:3, 3].set(jnp.zeros(3))
        B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
        L_J_WL_target = L_J_WL_B @ lift_base_adjoint(B_X_BW)

    else:
        raise ValueError(data.velocity_representation)

    # Convert the output side to the requested velocity representation.
    if output_vel_repr == VelRepr.Body:
        return L_J_WL_target

    if output_vel_repr == VelRepr.Inertial:
        W_H_L = transform(model=model, data=data, link_index=link_index)
        W_X_L = jaxlie.SE3.from_matrix(W_H_L).adjoint()
        return W_X_L @ L_J_WL_target

    if output_vel_repr == VelRepr.Mixed:
        W_H_L = transform(model=model, data=data, link_index=link_index)
        # Mixed frame: link orientation, world-frame origin removed.
        LW_H_L = jnp.array(W_H_L).at[0:3, 3].set(jnp.zeros(3))
        LW_X_L = jaxlie.SE3.from_matrix(LW_H_L).adjoint()
        return LW_X_L @ L_J_WL_target

    raise ValueError(output_vel_repr)
|