jaxsim 0.2.dev108__py3-none-any.whl → 0.2.dev166__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/api/joint.py ADDED
@@ -0,0 +1,148 @@
1
+ import functools
2
+ from typing import Sequence
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+ import jaxsim.typing as jtp
8
+
9
+ from . import model as Model
10
+
11
+ # =======================
12
+ # Index-related functions
13
+ # =======================
14
+
15
+
16
def name_to_idx(model: Model.JaxSimModel, *, joint_name: str) -> jtp.Int:
    """
    Convert the name of a joint to its index.

    Args:
        model: The model to consider.
        joint_name: The name of the joint.

    Returns:
        The 0-based index of the joint, i.e. the same convention used by
        `names_to_idxs` for accessing joint-related arrays.

    Raises:
        KeyError: If `joints_dict` has no entry for `joint_name`
            (assuming `joints_dict` behaves like a plain mapping).
    """

    # Note: the index of the joint for RBDAs starts from 1, but
    # the index for accessing the right element starts from 0.
    # Therefore, there is a -1.
    return jnp.array(
        model.physics_model.description.joints_dict[joint_name].index - 1,
        dtype=int,
    )


def idx_to_name(model: Model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:
    """
    Convert the index of a joint to its name.

    Args:
        model: The model to consider.
        joint_index:
            The 0-based index of the joint, as returned by `name_to_idx`.
            It can be a Python integer or a scalar array.

    Returns:
        The name of the joint.

    Raises:
        KeyError: If no joint has the given index.
    """

    # Use the same 0-based convention of `name_to_idx` (RBDA index minus one),
    # so that the two functions are the inverse of each other.
    d = {
        j.index - 1: j.name
        for j in model.physics_model.description.joints_dict.values()
    }

    # JAX arrays are not hashable, therefore the index is converted to a
    # plain Python integer before being used as dictionary key.
    return d[int(joint_index)]
47
+
48
+
49
def names_to_idxs(model: Model.JaxSimModel, *, joint_names: Sequence[str]) -> jax.Array:
    """
    Map a sequence of joint names to the indices of the corresponding joints.

    Args:
        model: The model to consider.
        joint_names: The names of the joints.

    Returns:
        An integer array with one 0-based index per joint name, in the same
        order as `joint_names`.
    """

    joints = model.physics_model.description.joints_dict

    # The index stored in the joint description follows the RBDA convention
    # and starts from 1, while the elements of joint-related arrays are
    # accessed with indices starting from 0. This is why 1 is subtracted.
    zero_based_idxs = [joints[joint_name].index - 1 for joint_name in joint_names]

    return jnp.array(zero_based_idxs, dtype=int)
71
+
72
+
73
def idxs_to_names(
    model: Model.JaxSimModel, *, joint_indices: Sequence[jtp.IntLike] | jtp.VectorLike
) -> tuple[str, ...]:
    """
    Convert a sequence of joint indices to their corresponding names.

    Args:
        model: The model to consider.
        joint_indices:
            The 0-based indices of the joints. Both sequences of integers and
            1D integer arrays (e.g. the output of `names_to_idxs`) are accepted.

    Returns:
        The names of the joints, in the same order as `joint_indices`.

    Raises:
        KeyError: If any of the indices does not correspond to a joint.
    """

    # The joint description stores 1-based RBDA indices, while this function
    # operates on the 0-based indices used for accessing arrays.
    d = {
        j.index - 1: j.name
        for j in model.physics_model.description.joints_dict.values()
    }

    # Iterating over a JAX array yields 0-dimensional arrays, which are not
    # hashable. Convert each element to a Python integer before the lookup.
    return tuple(d[int(i)] for i in joint_indices)
93
+
94
+
95
+ # ============
96
+ # Joint limits
97
+ # ============
98
+
99
+
100
@jax.jit
def position_limit(
    model: Model.JaxSimModel, *, joint_index: jtp.IntLike
) -> tuple[jtp.Float, jtp.Float]:
    """
    Get the position limits of a joint.

    Args:
        model: The model to consider.
        joint_index:
            The index of the joint, used to access the arrays of joint
            position limits stored in the physics model.

    Returns:
        A tuple with the lower and upper position limit of the joint,
        both cast to float.
    """

    physics_model = model.physics_model

    lower = physics_model._joint_position_limits_min[joint_index]
    upper = physics_model._joint_position_limits_max[joint_index]

    return lower.astype(float), upper.astype(float)
110
+
111
+
112
@functools.partial(jax.jit, static_argnames=["joint_names"])
def position_limits(
    model: Model.JaxSimModel, *, joint_names: Sequence[str] | None = None
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Get the position limits of multiple joints.

    Args:
        model: The model to consider.
        joint_names:
            The names of the joints. If None, the names returned by
            `model.joint_names()` are used. Since this is a static argument
            of the jit-compiled function, it has to be hashable (e.g. a tuple).

    Returns:
        A tuple with the lower and upper position limits of the joints,
        obtained by mapping `position_limit` over the joint indices.
    """

    if joint_names is None:
        joint_names = model.joint_names()

    idxs = names_to_idxs(model=model, joint_names=joint_names)

    def limit_of(idx):
        return position_limit(model=model, joint_index=idx)

    return jax.vmap(limit_of)(idxs)
121
+
122
+
123
+ # ======================
124
+ # Random data generation
125
+ # ======================
126
+
127
+
128
@functools.partial(jax.jit, static_argnames=["joint_names"])
def random_joint_positions(
    model: Model.JaxSimModel,
    *,
    joint_names: Sequence[str] | None = None,
    key: jax.Array | None = None,
) -> jtp.Vector:
    """
    Sample random joint positions within the joint position limits.

    Args:
        model: The model to consider.
        joint_names:
            The names of the joints to consider, forwarded to `position_limits`.
        key:
            The PRNG key used for sampling. If None, `jax.random.PRNGKey(seed=0)`
            is used, therefore repeated calls without a key use the same key.

    Returns:
        The joint positions sampled uniformly between the lower and upper
        position limits, with the same shape as the lower limits.
    """

    if key is None:
        key = jax.random.PRNGKey(seed=0)

    lower, upper = position_limits(model=model, joint_names=joint_names)

    return jax.random.uniform(
        key=key,
        shape=lower.shape,
        minval=lower,
        maxval=upper,
    )
jaxsim/api/link.py ADDED
@@ -0,0 +1,262 @@
1
+ import functools
2
+ from typing import Sequence
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import jaxlie
7
+
8
+ import jaxsim.physics.algos.jacobian
9
+ import jaxsim.typing as jtp
10
+ from jaxsim.high_level.common import VelRepr
11
+
12
+ from . import data as Data
13
+ from . import model as Model
14
+
15
+ # =======================
16
+ # Index-related functions
17
+ # =======================
18
+
19
+
20
def name_to_idx(model: Model.JaxSimModel, *, link_name: str) -> jtp.Int:
    """
    Look up the index of a link from its name.

    Args:
        model: The model to consider.
        link_name: The name of the link.

    Returns:
        The index stored in the description of the link, as integer array.
    """

    link_description = model.physics_model.description.links_dict[link_name]

    return jnp.array(link_description.index, dtype=int)
35
+
36
+
37
def idx_to_name(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> str:
    """
    Convert the index of a link to its name.

    Args:
        model: The model to consider.
        link_index:
            The index of the link. It can be a Python integer or a scalar
            array (e.g. the output of `name_to_idx`).

    Returns:
        The name of the link.

    Raises:
        KeyError: If no link has the given index.
    """

    d = {
        link.index: link.name
        for link in model.physics_model.description.links_dict.values()
    }

    # JAX arrays are not hashable, therefore the index is converted to a
    # plain Python integer before being used as dictionary key.
    return d[int(link_index)]
51
+
52
+
53
def names_to_idxs(model: Model.JaxSimModel, *, link_names: Sequence[str]) -> jax.Array:
    """
    Map a sequence of link names to the indices of the corresponding links.

    Args:
        model: The model to consider.
        link_names: The names of the links.

    Returns:
        An integer array with one index per link name, in the same order
        as `link_names`.
    """

    links = model.physics_model.description.links_dict
    idxs = [links[link_name].index for link_name in link_names]

    return jnp.array(idxs, dtype=int)
69
+
70
+
71
def idxs_to_names(
    model: Model.JaxSimModel, *, link_indices: Sequence[jtp.IntLike] | jtp.VectorLike
) -> tuple[str, ...]:
    """
    Convert a sequence of link indices to their corresponding names.

    Args:
        model: The model to consider.
        link_indices:
            The indices of the links. Both sequences of integers and 1D
            integer arrays (e.g. the output of `names_to_idxs`) are accepted.

    Returns:
        The names of the links, in the same order as `link_indices`.

    Raises:
        KeyError: If any of the indices does not correspond to a link.
    """

    d = {
        link.index: link.name
        for link in model.physics_model.description.links_dict.values()
    }

    # Iterating over a JAX array yields 0-dimensional arrays, which are not
    # hashable. Convert each element to a Python integer before the lookup.
    return tuple(d[int(i)] for i in link_indices)
87
+
88
+
89
+ # =========
90
+ # Link APIs
91
+ # =========
92
+
93
+
94
def mass(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Float:
    """
    Get the mass of a link.

    Args:
        model: The model to consider.
        link_index: The index of the link.

    Returns:
        The entry of the link masses stored in the physics model that
        corresponds to `link_index`, cast to float.
    """

    link_masses = model.physics_model._link_masses

    return link_masses[link_index].astype(float)
98
+
99
+
100
def spatial_inertia(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Matrix:
    """
    Get the spatial inertia of a link.

    Args:
        model: The model to consider.
        link_index: The index of the link.

    Returns:
        The entry of the link spatial inertias stored in the physics model
        that corresponds to `link_index` (presumably a 6x6 matrix).
    """

    link_spatial_inertias = model.physics_model._link_spatial_inertias

    return link_spatial_inertias[link_index]
104
+
105
+
106
@jax.jit
def transform(
    model: Model.JaxSimModel, data: Data.JaxSimModelData, *, link_index: jtp.IntLike
) -> jtp.Matrix:
    """
    Compute the SE(3) transform from the world frame to the link frame.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_index: The index of the link.

    Returns:
        The 4x4 matrix representing the transform, selected from the output
        of the forward kinematics of the model.
    """

    # The forward kinematics is indexed by link.
    link_transforms = Model.forward_kinematics(model=model, data=data)

    return link_transforms[link_index]
123
+
124
+
125
@jax.jit
def com_position(
    model: Model.JaxSimModel,
    data: Data.JaxSimModelData,
    *,
    link_index: jtp.IntLike,
    in_link_frame: jtp.BoolLike = True,
) -> jtp.Vector:
    """
    Compute the position of the center of mass of the link.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_index: The index of the link.
        in_link_frame:
            Whether to return the position in the link frame or in the world frame.

    Returns:
        The 3D position of the center of mass of the link.

    Note:
        Both the link-frame and the world-frame positions are always computed,
        and `in_link_frame` only selects which one is returned.
    """

    from jaxsim.math.inertia import Inertia

    # Extract the CoM position in link frame from the spatial inertia.
    M_L = spatial_inertia(model=model, link_index=link_index)
    _, L_p_CoM, _ = Inertia.to_params(M=M_L)
    L_p_CoM = L_p_CoM.squeeze()

    # Express the CoM in the world frame by applying the link transform
    # to its homogeneous coordinates.
    W_H_L = transform(link_index=link_index, model=model, data=data)
    W_p_CoM_homogeneous = W_H_L @ jnp.hstack([L_p_CoM, 1])
    W_p_CoM = W_p_CoM_homogeneous[0:3].squeeze()

    return jax.lax.select(
        pred=in_link_frame,
        on_true=L_p_CoM,
        on_false=W_p_CoM,
    )
167
+
168
+
169
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def jacobian(
    model: Model.JaxSimModel,
    data: Data.JaxSimModelData,
    *,
    link_index: jtp.IntLike,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
    """
    Compute the free-floating jacobian of the link.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_index: The index of the link.
        output_vel_repr:
            The output velocity representation of the free-floating jacobian.
            If None, the velocity representation of `data` is used.

    Returns:
        The 6x(6+dofs) free-floating jacobian of the link.

    Raises:
        ValueError:
            If either the velocity representation of `data` or the output
            velocity representation is not one of Body, Inertial, Mixed.

    Note:
        The input representation of the free-floating jacobian is the active
        velocity representation.
    """

    input_vel_repr = data.velocity_representation

    if output_vel_repr is None:
        output_vel_repr = input_vel_repr

    # Compute the doubly left-trivialized free-floating jacobian
    L_J_WL_B = jaxsim.physics.algos.jacobian.jacobian(
        model=model.physics_model,
        body_index=link_index,
        q=data.joint_positions(),
    )

    def input_transform(X_H_B: jtp.Matrix) -> jtp.Matrix:
        # Assemble the block-diagonal matrix diag(B_X_X, I_dofs) that is
        # right-multiplied to the jacobian to change its input representation.
        dofs = model.dofs()
        B_X_X = jaxlie.SE3.from_matrix(X_H_B).inverse().adjoint()
        zero_6n = jnp.zeros(shape=(6, dofs))

        return jnp.vstack(
            [
                jnp.block([B_X_X, zero_6n]),
                jnp.block([zero_6n.T, jnp.eye(dofs)]),
            ]
        )

    # Adjust the input representation of the jacobian
    if input_vel_repr == VelRepr.Body:
        L_J_WL_target = L_J_WL_B

    elif input_vel_repr == VelRepr.Inertial:
        W_H_B = data.base_transform()
        L_J_WL_target = L_J_WL_B @ input_transform(W_H_B)

    elif input_vel_repr == VelRepr.Mixed:
        W_H_B = data.base_transform()
        # Same orientation of the base transform, with zeroed translation
        BW_H_B = jnp.array(W_H_B).at[0:3, 3].set(jnp.zeros(3))
        L_J_WL_target = L_J_WL_B @ input_transform(BW_H_B)

    else:
        raise ValueError(input_vel_repr)

    # Adjust the output representation of the jacobian
    if output_vel_repr == VelRepr.Body:
        return L_J_WL_target

    if output_vel_repr == VelRepr.Inertial:
        W_H_L = transform(model=model, data=data, link_index=link_index)
        W_X_L = jaxlie.SE3.from_matrix(W_H_L).adjoint()
        return W_X_L @ L_J_WL_target

    if output_vel_repr == VelRepr.Mixed:
        W_H_L = transform(model=model, data=data, link_index=link_index)
        # Same orientation of the link transform, with zeroed translation
        LW_H_L = jnp.array(W_H_L).at[0:3, 3].set(jnp.zeros(3))
        LW_X_L = jaxlie.SE3.from_matrix(LW_H_L).adjoint()
        return LW_X_L @ L_J_WL_target

    raise ValueError(output_vel_repr)