jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -129
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +87 -16
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +62 -24
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +607 -225
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -80
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev188.dist-info/METADATA +0 -184
- jaxsim-0.2.dev188.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/api/joint.py
CHANGED
@@ -1,19 +1,21 @@
|
|
1
1
|
import functools
|
2
|
-
from
|
2
|
+
from collections.abc import Sequence
|
3
3
|
|
4
4
|
import jax
|
5
5
|
import jax.numpy as jnp
|
6
6
|
|
7
|
+
import jaxsim.api as js
|
7
8
|
import jaxsim.typing as jtp
|
8
|
-
|
9
|
-
from . import model as Model
|
9
|
+
from jaxsim import exceptions
|
10
10
|
|
11
11
|
# =======================
|
12
12
|
# Index-related functions
|
13
13
|
# =======================
|
14
14
|
|
15
15
|
|
16
|
-
|
16
|
+
@functools.partial(jax.jit, static_argnames="joint_name")
|
17
|
+
@js.common.named_scope
|
18
|
+
def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int:
|
17
19
|
"""
|
18
20
|
Convert the name of a joint to its index.
|
19
21
|
|
@@ -25,12 +27,21 @@ def name_to_idx(model: Model.JaxSimModel, *, joint_name: str) -> jtp.Int:
|
|
25
27
|
The index of the joint.
|
26
28
|
"""
|
27
29
|
|
28
|
-
|
29
|
-
model.
|
30
|
+
if joint_name not in model.joint_names():
|
31
|
+
raise ValueError(f"Joint '{joint_name}' not found in the model.")
|
32
|
+
|
33
|
+
# Note: the index of the joint for RBDAs starts from 1, but the index for
|
34
|
+
# accessing the right element starts from 0. Therefore, there is a -1.
|
35
|
+
return (
|
36
|
+
jnp.array(
|
37
|
+
model.kin_dyn_parameters.joint_model.joint_names.index(joint_name) - 1
|
38
|
+
)
|
39
|
+
.astype(int)
|
40
|
+
.squeeze()
|
30
41
|
)
|
31
42
|
|
32
43
|
|
33
|
-
def idx_to_name(model:
|
44
|
+
def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:
|
34
45
|
"""
|
35
46
|
Convert the index of a joint to its name.
|
36
47
|
|
@@ -42,11 +53,20 @@ def idx_to_name(model: Model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:
|
|
42
53
|
The name of the joint.
|
43
54
|
"""
|
44
55
|
|
45
|
-
|
46
|
-
|
56
|
+
exceptions.raise_value_error_if(
|
57
|
+
condition=joint_index < 0,
|
58
|
+
msg="Invalid joint index '{idx}'",
|
59
|
+
idx=joint_index,
|
60
|
+
)
|
61
|
+
|
62
|
+
return model.kin_dyn_parameters.joint_model.joint_names[joint_index + 1]
|
47
63
|
|
48
64
|
|
49
|
-
|
65
|
+
@functools.partial(jax.jit, static_argnames="joint_names")
|
66
|
+
@js.common.named_scope
|
67
|
+
def names_to_idxs(
|
68
|
+
model: js.model.JaxSimModel, *, joint_names: Sequence[str]
|
69
|
+
) -> jax.Array:
|
50
70
|
"""
|
51
71
|
Convert a sequence of joint names to their corresponding indices.
|
52
72
|
|
@@ -59,19 +79,14 @@ def names_to_idxs(model: Model.JaxSimModel, *, joint_names: Sequence[str]) -> ja
|
|
59
79
|
"""
|
60
80
|
|
61
81
|
return jnp.array(
|
62
|
-
[
|
63
|
-
|
64
|
-
# the index for accessing the right element starts from 0.
|
65
|
-
# Therefore, there is a -1.
|
66
|
-
model.physics_model.description.joints_dict[name].index - 1
|
67
|
-
for name in joint_names
|
68
|
-
],
|
69
|
-
dtype=int,
|
70
|
-
)
|
82
|
+
[name_to_idx(model=model, joint_name=name) for name in joint_names],
|
83
|
+
).astype(int)
|
71
84
|
|
72
85
|
|
73
86
|
def idxs_to_names(
|
74
|
-
model:
|
87
|
+
model: js.model.JaxSimModel,
|
88
|
+
*,
|
89
|
+
joint_indices: Sequence[jtp.IntLike] | jtp.VectorLike,
|
75
90
|
) -> tuple[str, ...]:
|
76
91
|
"""
|
77
92
|
Convert a sequence of joint indices to their corresponding names.
|
@@ -84,12 +99,7 @@ def idxs_to_names(
|
|
84
99
|
The names of the joints.
|
85
100
|
"""
|
86
101
|
|
87
|
-
|
88
|
-
j.index - 1: j.name
|
89
|
-
for j in model.physics_model.description.joints_dict.values()
|
90
|
-
}
|
91
|
-
|
92
|
-
return tuple(d[i] for i in joint_indices)
|
102
|
+
return tuple(idx_to_name(model=model, joint_index=idx) for idx in joint_indices)
|
93
103
|
|
94
104
|
|
95
105
|
# ============
|
@@ -99,25 +109,69 @@ def idxs_to_names(
|
|
99
109
|
|
100
110
|
@jax.jit
|
101
111
|
def position_limit(
|
102
|
-
model:
|
112
|
+
model: js.model.JaxSimModel, *, joint_index: jtp.IntLike
|
103
113
|
) -> tuple[jtp.Float, jtp.Float]:
|
104
|
-
"""
|
114
|
+
"""
|
115
|
+
Get the position limits of a joint.
|
105
116
|
|
106
|
-
|
107
|
-
|
117
|
+
Args:
|
118
|
+
model: The model to consider.
|
119
|
+
joint_index: The index of the joint.
|
108
120
|
|
109
|
-
|
121
|
+
Returns:
|
122
|
+
The position limits of the joint.
|
123
|
+
"""
|
124
|
+
|
125
|
+
if model.number_of_joints() == 0:
|
126
|
+
return jnp.empty(0).astype(float), jnp.empty(0).astype(float)
|
127
|
+
|
128
|
+
exceptions.raise_value_error_if(
|
129
|
+
condition=jnp.array(
|
130
|
+
[joint_index < 0, joint_index >= model.number_of_joints()]
|
131
|
+
).any(),
|
132
|
+
msg="Invalid joint index '{idx}'",
|
133
|
+
idx=joint_index,
|
134
|
+
)
|
135
|
+
|
136
|
+
s_min = jnp.atleast_1d(
|
137
|
+
model.kin_dyn_parameters.joint_parameters.position_limits_min
|
138
|
+
)[joint_index]
|
139
|
+
s_max = jnp.atleast_1d(
|
140
|
+
model.kin_dyn_parameters.joint_parameters.position_limits_max
|
141
|
+
)[joint_index]
|
142
|
+
|
143
|
+
return s_min.astype(float), s_max.astype(float)
|
110
144
|
|
111
145
|
|
112
146
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
147
|
+
@js.common.named_scope
|
113
148
|
def position_limits(
|
114
|
-
model:
|
149
|
+
model: js.model.JaxSimModel, *, joint_names: Sequence[str] | None = None
|
115
150
|
) -> tuple[jtp.Vector, jtp.Vector]:
|
151
|
+
"""
|
152
|
+
Get the position limits of a list of joint.
|
116
153
|
|
117
|
-
|
154
|
+
Args:
|
155
|
+
model: The model to consider.
|
156
|
+
joint_names: The names of the joints.
|
157
|
+
|
158
|
+
Returns:
|
159
|
+
The position limits of the joints.
|
160
|
+
"""
|
161
|
+
|
162
|
+
joint_idxs = (
|
163
|
+
names_to_idxs(joint_names=joint_names, model=model)
|
164
|
+
if joint_names is not None
|
165
|
+
else jnp.arange(model.number_of_joints())
|
166
|
+
)
|
167
|
+
|
168
|
+
if len(joint_idxs) == 0:
|
169
|
+
return jnp.empty(0).astype(float), jnp.empty(0).astype(float)
|
118
170
|
|
119
|
-
|
120
|
-
|
171
|
+
s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min[joint_idxs]
|
172
|
+
s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max[joint_idxs]
|
173
|
+
|
174
|
+
return s_min.astype(float), s_max.astype(float)
|
121
175
|
|
122
176
|
|
123
177
|
# ======================
|
@@ -126,18 +180,93 @@ def position_limits(
|
|
126
180
|
|
127
181
|
|
128
182
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
183
|
+
@js.common.named_scope
|
129
184
|
def random_joint_positions(
|
130
|
-
model:
|
185
|
+
model: js.model.JaxSimModel,
|
131
186
|
*,
|
132
187
|
joint_names: Sequence[str] | None = None,
|
133
188
|
key: jax.Array | None = None,
|
134
189
|
) -> jtp.Vector:
|
135
|
-
"""
|
190
|
+
"""
|
191
|
+
Generate random joint positions.
|
192
|
+
|
193
|
+
Args:
|
194
|
+
model: The model to consider.
|
195
|
+
joint_names: The names of the considered joints (all if None).
|
196
|
+
key: The random key (initialized from seed 0 if None).
|
136
197
|
|
198
|
+
Note:
|
199
|
+
If the joint range or revolute joints is larger than 2π, their joint positions
|
200
|
+
will be sampled from an interval of size 2π.
|
201
|
+
|
202
|
+
Returns:
|
203
|
+
The random joint positions.
|
204
|
+
"""
|
205
|
+
|
206
|
+
# Consider the key corresponding to a zero seed if it was not passed.
|
137
207
|
key = key if key is not None else jax.random.PRNGKey(seed=0)
|
138
208
|
|
209
|
+
# Get the joint limits parsed from the model description.
|
139
210
|
s_min, s_max = position_limits(model=model, joint_names=joint_names)
|
140
211
|
|
212
|
+
# Get the joint indices.
|
213
|
+
# Note that it will trigger an exception if the given `joint_names` are not valid.
|
214
|
+
joint_names = joint_names if joint_names is not None else model.joint_names()
|
215
|
+
joint_indices = (
|
216
|
+
names_to_idxs(model=model, joint_names=joint_names)
|
217
|
+
if joint_names is not None
|
218
|
+
else jnp.arange(model.number_of_joints())
|
219
|
+
)
|
220
|
+
|
221
|
+
from jaxsim.parsers.descriptions.joint import JointType
|
222
|
+
|
223
|
+
# Filter for revolute joints.
|
224
|
+
is_revolute = jnp.where(
|
225
|
+
jnp.array(model.kin_dyn_parameters.joint_model.joint_types[1:])[joint_indices]
|
226
|
+
== JointType.Revolute,
|
227
|
+
True,
|
228
|
+
False,
|
229
|
+
)
|
230
|
+
|
231
|
+
# Shorthand for π.
|
232
|
+
π = jnp.pi
|
233
|
+
|
234
|
+
# Filter for revolute with full range (or continuous).
|
235
|
+
is_revolute_full_range = jnp.logical_and(is_revolute, s_max - s_min >= 2 * π)
|
236
|
+
|
237
|
+
# Clip the lower limit to -π if the joint range is larger than [-π, π].
|
238
|
+
s_min = jnp.where(
|
239
|
+
jnp.logical_and(
|
240
|
+
is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
|
241
|
+
),
|
242
|
+
-π,
|
243
|
+
s_min,
|
244
|
+
)
|
245
|
+
|
246
|
+
# Clip the upper limit to +π if the joint range is larger than [-π, π].
|
247
|
+
s_max = jnp.where(
|
248
|
+
jnp.logical_and(
|
249
|
+
is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
|
250
|
+
),
|
251
|
+
π,
|
252
|
+
s_max,
|
253
|
+
)
|
254
|
+
|
255
|
+
# Shift the lower limit if the upper limit is smaller than +π.
|
256
|
+
s_min = jnp.where(
|
257
|
+
jnp.logical_and(is_revolute_full_range, s_max < π),
|
258
|
+
s_max - 2 * π,
|
259
|
+
s_min,
|
260
|
+
)
|
261
|
+
|
262
|
+
# Shift the upper limit if the lower limit is larger than -π.
|
263
|
+
s_max = jnp.where(
|
264
|
+
jnp.logical_and(is_revolute_full_range, s_min > -π),
|
265
|
+
s_min + 2 * π,
|
266
|
+
s_max,
|
267
|
+
)
|
268
|
+
|
269
|
+
# Sample the joint positions.
|
141
270
|
s_random = jax.random.uniform(
|
142
271
|
minval=s_min,
|
143
272
|
maxval=s_max,
|