jaxsim 0.2.dev188__py3-none-any.whl → 0.2.dev364__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +3 -4
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +13 -2
- jaxsim/api/contact.py +120 -43
- jaxsim/api/data.py +112 -71
- jaxsim/api/joint.py +77 -36
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +150 -75
- jaxsim/api/model.py +542 -269
- jaxsim/api/ode.py +88 -72
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +12 -11
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +110 -24
- jaxsim/integrators/fixed_step.py +11 -67
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +93 -0
- jaxsim/parsers/descriptions/collision.py +14 -0
- jaxsim/parsers/descriptions/link.py +13 -2
- jaxsim/parsers/kinematic_graph.py +5 -0
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/{physics/algos → terrain}/terrain.py +4 -6
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -30
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
- jaxsim-0.2.dev364.dist-info/RECORD +64 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev188.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
jaxsim/api/data.py
CHANGED
@@ -10,20 +10,16 @@ import jax_dataclasses
|
|
10
10
|
import jaxlie
|
11
11
|
import numpy as np
|
12
12
|
|
13
|
-
import jaxsim.api
|
14
|
-
import jaxsim.
|
15
|
-
import jaxsim.physics.algos.crba
|
16
|
-
import jaxsim.physics.algos.forward_kinematics
|
17
|
-
import jaxsim.physics.algos.rnea
|
18
|
-
import jaxsim.physics.model.physics_model
|
19
|
-
import jaxsim.physics.model.physics_model_state
|
13
|
+
import jaxsim.api as js
|
14
|
+
import jaxsim.rbda
|
20
15
|
import jaxsim.typing as jtp
|
21
|
-
from jaxsim.
|
22
|
-
from jaxsim.physics.algos import soft_contacts
|
23
|
-
from jaxsim.simulation.ode_data import ODEState
|
16
|
+
from jaxsim.math import Quaternion
|
24
17
|
from jaxsim.utils import Mutability
|
18
|
+
from jaxsim.utils.tracing import not_tracing
|
25
19
|
|
26
20
|
from . import common
|
21
|
+
from .common import VelRepr
|
22
|
+
from .ode_data import ODEState
|
27
23
|
|
28
24
|
try:
|
29
25
|
from typing import Self
|
@@ -41,14 +37,13 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
41
37
|
|
42
38
|
gravity: jtp.Array
|
43
39
|
|
44
|
-
soft_contacts_params:
|
45
|
-
|
46
|
-
)
|
40
|
+
soft_contacts_params: jaxsim.rbda.SoftContactsParams = dataclasses.field(repr=False)
|
41
|
+
|
47
42
|
time_ns: jtp.Int = dataclasses.field(
|
48
43
|
default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
|
49
44
|
)
|
50
45
|
|
51
|
-
def valid(self, model:
|
46
|
+
def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
|
52
47
|
"""
|
53
48
|
Check if the current state is valid for the given model.
|
54
49
|
|
@@ -60,15 +55,16 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
60
55
|
"""
|
61
56
|
|
62
57
|
valid = True
|
58
|
+
valid = valid and self.standard_gravity() > 0
|
63
59
|
|
64
60
|
if model is not None:
|
65
|
-
valid = valid and self.state.valid(
|
61
|
+
valid = valid and self.state.valid(model=model)
|
66
62
|
|
67
63
|
return valid
|
68
64
|
|
69
65
|
@staticmethod
|
70
66
|
def zero(
|
71
|
-
model:
|
67
|
+
model: js.model.JaxSimModel,
|
72
68
|
velocity_representation: VelRepr = VelRepr.Inertial,
|
73
69
|
) -> JaxSimModelData:
|
74
70
|
"""
|
@@ -88,16 +84,16 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
88
84
|
|
89
85
|
@staticmethod
|
90
86
|
def build(
|
91
|
-
model:
|
87
|
+
model: js.model.JaxSimModel,
|
92
88
|
base_position: jtp.Vector | None = None,
|
93
89
|
base_quaternion: jtp.Vector | None = None,
|
94
90
|
joint_positions: jtp.Vector | None = None,
|
95
91
|
base_linear_velocity: jtp.Vector | None = None,
|
96
92
|
base_angular_velocity: jtp.Vector | None = None,
|
97
93
|
joint_velocities: jtp.Vector | None = None,
|
98
|
-
|
99
|
-
soft_contacts_state:
|
100
|
-
soft_contacts_params:
|
94
|
+
standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
|
95
|
+
soft_contacts_state: js.ode_data.SoftContactsState | None = None,
|
96
|
+
soft_contacts_params: jaxsim.rbda.SoftContactsParams | None = None,
|
101
97
|
velocity_representation: VelRepr = VelRepr.Inertial,
|
102
98
|
time: jtp.FloatLike | None = None,
|
103
99
|
) -> JaxSimModelData:
|
@@ -114,7 +110,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
114
110
|
base_angular_velocity:
|
115
111
|
The base angular velocity in the selected representation.
|
116
112
|
joint_velocities: The joint velocities.
|
117
|
-
|
113
|
+
standard_gravity: The standard gravity constant.
|
118
114
|
soft_contacts_state: The state of the soft contacts.
|
119
115
|
soft_contacts_params: The parameters of the soft contacts.
|
120
116
|
velocity_representation: The velocity representation to use.
|
@@ -142,9 +138,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
142
138
|
base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3)
|
143
139
|
).squeeze()
|
144
140
|
|
145
|
-
gravity = jnp.
|
146
|
-
gravity if gravity is not None else model.physics_model.gravity[0:3]
|
147
|
-
).squeeze()
|
141
|
+
gravity = jnp.zeros(3).at[2].set(-standard_gravity)
|
148
142
|
|
149
143
|
joint_positions = jnp.atleast_1d(
|
150
144
|
joint_positions.squeeze()
|
@@ -167,7 +161,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
167
161
|
soft_contacts_params = (
|
168
162
|
soft_contacts_params
|
169
163
|
if soft_contacts_params is not None
|
170
|
-
else
|
164
|
+
else js.contact.estimate_good_soft_contacts_parameters(
|
165
|
+
model=model, standard_gravity=standard_gravity
|
166
|
+
)
|
171
167
|
)
|
172
168
|
|
173
169
|
W_H_B = jaxlie.SE3.from_rotation_and_translation(
|
@@ -184,20 +180,22 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
184
180
|
is_force=False,
|
185
181
|
)
|
186
182
|
|
187
|
-
ode_state = ODEState.
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
183
|
+
ode_state = ODEState.build_from_jaxsim_model(
|
184
|
+
model=model,
|
185
|
+
base_position=base_position.astype(float),
|
186
|
+
base_quaternion=base_quaternion.astype(float),
|
187
|
+
joint_positions=joint_positions.astype(float),
|
188
|
+
base_linear_velocity=v_WB[0:3].astype(float),
|
189
|
+
base_angular_velocity=v_WB[3:6].astype(float),
|
190
|
+
joint_velocities=joint_velocities.astype(float),
|
191
|
+
tangential_deformation=(
|
192
|
+
soft_contacts_state.tangential_deformation
|
193
|
+
if soft_contacts_state is not None
|
194
|
+
else None
|
196
195
|
),
|
197
|
-
soft_contacts_state=soft_contacts_state,
|
198
196
|
)
|
199
197
|
|
200
|
-
if not ode_state.valid(
|
198
|
+
if not ode_state.valid(model=model):
|
201
199
|
raise ValueError(ode_state)
|
202
200
|
|
203
201
|
return JaxSimModelData(
|
@@ -222,10 +220,20 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
222
220
|
|
223
221
|
return self.time_ns.astype(float) / 1e9
|
224
222
|
|
223
|
+
def standard_gravity(self) -> jtp.Float:
|
224
|
+
"""
|
225
|
+
Get the standard gravity constant.
|
226
|
+
|
227
|
+
Returns:
|
228
|
+
The standard gravity constant.
|
229
|
+
"""
|
230
|
+
|
231
|
+
return -self.gravity[2]
|
232
|
+
|
225
233
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
226
234
|
def joint_positions(
|
227
235
|
self,
|
228
|
-
model:
|
236
|
+
model: js.model.JaxSimModel | None = None,
|
229
237
|
joint_names: tuple[str, ...] | None = None,
|
230
238
|
) -> jtp.Vector:
|
231
239
|
"""
|
@@ -250,22 +258,27 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
250
258
|
"""
|
251
259
|
|
252
260
|
if model is None:
|
261
|
+
if joint_names is not None:
|
262
|
+
raise ValueError("Joint names cannot be provided without a model")
|
263
|
+
|
253
264
|
return self.state.physics_model.joint_positions
|
254
265
|
|
255
|
-
if not self.valid(
|
266
|
+
if not_tracing(self.state.physics_model.joint_positions) and not self.valid(
|
267
|
+
model=model
|
268
|
+
):
|
256
269
|
msg = "The data object is not compatible with the provided model"
|
257
270
|
raise ValueError(msg)
|
258
271
|
|
259
272
|
joint_names = joint_names if joint_names is not None else model.joint_names()
|
260
273
|
|
261
274
|
return self.state.physics_model.joint_positions[
|
262
|
-
|
275
|
+
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
263
276
|
]
|
264
277
|
|
265
278
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
266
279
|
def joint_velocities(
|
267
280
|
self,
|
268
|
-
model:
|
281
|
+
model: js.model.JaxSimModel | None = None,
|
269
282
|
joint_names: tuple[str, ...] | None = None,
|
270
283
|
) -> jtp.Vector:
|
271
284
|
"""
|
@@ -290,16 +303,21 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
290
303
|
"""
|
291
304
|
|
292
305
|
if model is None:
|
306
|
+
if joint_names is not None:
|
307
|
+
raise ValueError("Joint names cannot be provided without a model")
|
308
|
+
|
293
309
|
return self.state.physics_model.joint_velocities
|
294
310
|
|
295
|
-
if not self.valid(
|
311
|
+
if not_tracing(self.state.physics_model.joint_velocities) and not self.valid(
|
312
|
+
model=model
|
313
|
+
):
|
296
314
|
msg = "The data object is not compatible with the provided model"
|
297
315
|
raise ValueError(msg)
|
298
316
|
|
299
317
|
joint_names = joint_names if joint_names is not None else model.joint_names()
|
300
318
|
|
301
319
|
return self.state.physics_model.joint_velocities[
|
302
|
-
|
320
|
+
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
303
321
|
]
|
304
322
|
|
305
323
|
@jax.jit
|
@@ -325,26 +343,27 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
325
343
|
The base orientation.
|
326
344
|
"""
|
327
345
|
|
346
|
+
# Extract the base quaternion.
|
347
|
+
W_Q_B = self.state.physics_model.base_quaternion.squeeze()
|
348
|
+
|
328
349
|
# Always normalize the quaternion to avoid numerical issues.
|
329
350
|
# If the active scheme does not integrate the quaternion on its manifold,
|
330
351
|
# we introduce a Baumgarte stabilization to let the quaternion converge to
|
331
352
|
# a unit quaternion. In this case, it is not guaranteed that the quaternion
|
332
353
|
# stored in the state is a unit quaternion.
|
333
|
-
|
334
|
-
|
335
|
-
|
354
|
+
W_Q_B = jax.lax.select(
|
355
|
+
pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0),
|
356
|
+
on_true=W_Q_B,
|
357
|
+
on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
|
336
358
|
)
|
337
359
|
|
338
|
-
# Slice to convert quaternion wxyz -> xyzw
|
339
|
-
to_xyzw = np.array([1, 2, 3, 0])
|
340
|
-
|
341
360
|
return (
|
342
|
-
|
361
|
+
W_Q_B
|
343
362
|
if not dcm
|
344
363
|
else jaxlie.SO3.from_quaternion_xyzw(
|
345
|
-
|
364
|
+
Quaternion.to_xyzw(wxyz=W_Q_B)
|
346
365
|
).as_matrix()
|
347
|
-
)
|
366
|
+
).astype(float)
|
348
367
|
|
349
368
|
@jax.jit
|
350
369
|
def base_transform(self) -> jtp.MatrixJax:
|
@@ -430,7 +449,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
430
449
|
def reset_joint_positions(
|
431
450
|
self,
|
432
451
|
positions: jtp.VectorLike,
|
433
|
-
model:
|
452
|
+
model: js.model.JaxSimModel | None = None,
|
434
453
|
joint_names: tuple[str, ...] | None = None,
|
435
454
|
) -> Self:
|
436
455
|
"""
|
@@ -460,7 +479,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
460
479
|
if model is None:
|
461
480
|
return replace(s=positions)
|
462
481
|
|
463
|
-
if not self.valid(model=model):
|
482
|
+
if not_tracing(positions) and not self.valid(model=model):
|
464
483
|
msg = "The data object is not compatible with the provided model"
|
465
484
|
raise ValueError(msg)
|
466
485
|
|
@@ -468,7 +487,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
468
487
|
|
469
488
|
return replace(
|
470
489
|
s=self.state.physics_model.joint_positions.at[
|
471
|
-
|
490
|
+
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
472
491
|
].set(positions)
|
473
492
|
)
|
474
493
|
|
@@ -476,7 +495,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
476
495
|
def reset_joint_velocities(
|
477
496
|
self,
|
478
497
|
velocities: jtp.VectorLike,
|
479
|
-
model:
|
498
|
+
model: js.model.JaxSimModel | None = None,
|
480
499
|
joint_names: tuple[str, ...] | None = None,
|
481
500
|
) -> Self:
|
482
501
|
"""
|
@@ -506,7 +525,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
506
525
|
if model is None:
|
507
526
|
return replace(ṡ=velocities)
|
508
527
|
|
509
|
-
if not self.valid(model=model):
|
528
|
+
if not_tracing(velocities) and not self.valid(model=model):
|
510
529
|
msg = "The data object is not compatible with the provided model"
|
511
530
|
raise ValueError(msg)
|
512
531
|
|
@@ -514,7 +533,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
514
533
|
|
515
534
|
return replace(
|
516
535
|
ṡ=self.state.physics_model.joint_velocities.at[
|
517
|
-
|
536
|
+
js.joint.names_to_idxs(joint_names=joint_names, model=model)
|
518
537
|
].set(velocities)
|
519
538
|
)
|
520
539
|
|
@@ -692,7 +711,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
692
711
|
|
693
712
|
|
694
713
|
def random_model_data(
|
695
|
-
model:
|
714
|
+
model: js.model.JaxSimModel,
|
696
715
|
*,
|
697
716
|
key: jax.Array | None = None,
|
698
717
|
velocity_representation: VelRepr | None = None,
|
@@ -712,6 +731,10 @@ def random_model_data(
|
|
712
731
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
713
732
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
714
733
|
] = (-1.0, 1.0),
|
734
|
+
standard_gravity_bounds: tuple[jtp.FloatLike, jtp.FloatLike] = (
|
735
|
+
jaxsim.math.StandardGravity,
|
736
|
+
jaxsim.math.StandardGravity,
|
737
|
+
),
|
715
738
|
) -> JaxSimModelData:
|
716
739
|
"""
|
717
740
|
Randomly generate a `JaxSimModelData` object.
|
@@ -724,13 +747,14 @@ def random_model_data(
|
|
724
747
|
base_vel_lin_bounds: The bounds for the base linear velocity.
|
725
748
|
base_vel_ang_bounds: The bounds for the base angular velocity.
|
726
749
|
joint_vel_bounds: The bounds for the joint velocities.
|
750
|
+
standard_gravity_bounds: The bounds for the standard gravity.
|
727
751
|
|
728
752
|
Returns:
|
729
753
|
A `JaxSimModelData` object with random data.
|
730
754
|
"""
|
731
755
|
|
732
756
|
key = key if key is not None else jax.random.PRNGKey(seed=0)
|
733
|
-
k1, k2, k3, k4, k5, k6 = jax.random.split(key, num=
|
757
|
+
k1, k2, k3, k4, k5, k6, k7 = jax.random.split(key, num=7)
|
734
758
|
|
735
759
|
p_min = jnp.array(base_pos_bounds[0], dtype=float)
|
736
760
|
p_max = jnp.array(base_pos_bounds[1], dtype=float)
|
@@ -749,7 +773,9 @@ def random_model_data(
|
|
749
773
|
),
|
750
774
|
)
|
751
775
|
|
752
|
-
with random_data.mutable_context(
|
776
|
+
with random_data.mutable_context(
|
777
|
+
mutability=Mutability.MUTABLE, restore_after_exception=False
|
778
|
+
):
|
753
779
|
|
754
780
|
physics_model_state = random_data.state.physics_model
|
755
781
|
|
@@ -761,20 +787,35 @@ def random_model_data(
|
|
761
787
|
*jax.random.uniform(key=k2, shape=(3,), minval=0, maxval=2 * jnp.pi)
|
762
788
|
).as_quaternion_xyzw()[np.array([3, 0, 1, 2])]
|
763
789
|
|
764
|
-
|
765
|
-
|
766
|
-
|
790
|
+
if model.number_of_joints() > 0:
|
791
|
+
physics_model_state.joint_positions = js.joint.random_joint_positions(
|
792
|
+
model=model, key=k3
|
793
|
+
)
|
767
794
|
|
768
|
-
|
769
|
-
|
770
|
-
|
795
|
+
physics_model_state.joint_velocities = jax.random.uniform(
|
796
|
+
key=k4, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max
|
797
|
+
)
|
771
798
|
|
772
|
-
|
773
|
-
|
774
|
-
|
799
|
+
if model.floating_base():
|
800
|
+
physics_model_state.base_linear_velocity = jax.random.uniform(
|
801
|
+
key=k5, shape=(3,), minval=v_min, maxval=v_max
|
802
|
+
)
|
803
|
+
|
804
|
+
physics_model_state.base_angular_velocity = jax.random.uniform(
|
805
|
+
key=k6, shape=(3,), minval=ω_min, maxval=ω_max
|
806
|
+
)
|
775
807
|
|
776
|
-
|
777
|
-
|
808
|
+
random_data.gravity = (
|
809
|
+
jnp.zeros(3, dtype=random_data.gravity.dtype)
|
810
|
+
.at[2]
|
811
|
+
.set(
|
812
|
+
-jax.random.uniform(
|
813
|
+
key=k7,
|
814
|
+
shape=(),
|
815
|
+
minval=standard_gravity_bounds[0],
|
816
|
+
maxval=standard_gravity_bounds[1],
|
817
|
+
)
|
818
|
+
)
|
778
819
|
)
|
779
820
|
|
780
821
|
return random_data
|
jaxsim/api/joint.py
CHANGED
@@ -3,17 +3,18 @@ from typing import Sequence
|
|
3
3
|
|
4
4
|
import jax
|
5
5
|
import jax.numpy as jnp
|
6
|
+
import numpy as np
|
6
7
|
|
8
|
+
import jaxsim.api as js
|
7
9
|
import jaxsim.typing as jtp
|
8
10
|
|
9
|
-
from . import model as Model
|
10
|
-
|
11
11
|
# =======================
|
12
12
|
# Index-related functions
|
13
13
|
# =======================
|
14
14
|
|
15
15
|
|
16
|
-
|
16
|
+
@functools.partial(jax.jit, static_argnames="joint_name")
|
17
|
+
def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int:
|
17
18
|
"""
|
18
19
|
Convert the name of a joint to its index.
|
19
20
|
|
@@ -25,12 +26,25 @@ def name_to_idx(model: Model.JaxSimModel, *, joint_name: str) -> jtp.Int:
|
|
25
26
|
The index of the joint.
|
26
27
|
"""
|
27
28
|
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
29
|
+
if joint_name in model.kin_dyn_parameters.joint_model.joint_names:
|
30
|
+
# Note: the index of the joint for RBDAs starts from 1, but
|
31
|
+
# the index for accessing the right element starts from 0.
|
32
|
+
# Therefore, there is a -1.
|
33
|
+
return (
|
34
|
+
jnp.array(
|
35
|
+
np.argwhere(
|
36
|
+
np.array(model.kin_dyn_parameters.joint_model.joint_names)
|
37
|
+
== joint_name
|
38
|
+
)
|
39
|
+
- 1
|
40
|
+
)
|
41
|
+
.squeeze()
|
42
|
+
.astype(int)
|
43
|
+
)
|
44
|
+
return jnp.array(-1).astype(int)
|
45
|
+
|
46
|
+
|
47
|
+
def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:
|
34
48
|
"""
|
35
49
|
Convert the index of a joint to its name.
|
36
50
|
|
@@ -42,11 +56,13 @@ def idx_to_name(model: Model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:
|
|
42
56
|
The name of the joint.
|
43
57
|
"""
|
44
58
|
|
45
|
-
|
46
|
-
return d[joint_index]
|
59
|
+
return model.kin_dyn_parameters.joint_model.joint_names[joint_index + 1]
|
47
60
|
|
48
61
|
|
49
|
-
|
62
|
+
@functools.partial(jax.jit, static_argnames="joint_names")
|
63
|
+
def names_to_idxs(
|
64
|
+
model: js.model.JaxSimModel, *, joint_names: Sequence[str]
|
65
|
+
) -> jax.Array:
|
50
66
|
"""
|
51
67
|
Convert a sequence of joint names to their corresponding indices.
|
52
68
|
|
@@ -59,19 +75,14 @@ def names_to_idxs(model: Model.JaxSimModel, *, joint_names: Sequence[str]) -> ja
|
|
59
75
|
"""
|
60
76
|
|
61
77
|
return jnp.array(
|
62
|
-
[
|
63
|
-
|
64
|
-
# the index for accessing the right element starts from 0.
|
65
|
-
# Therefore, there is a -1.
|
66
|
-
model.physics_model.description.joints_dict[name].index - 1
|
67
|
-
for name in joint_names
|
68
|
-
],
|
69
|
-
dtype=int,
|
70
|
-
)
|
78
|
+
[name_to_idx(model=model, joint_name=name) for name in joint_names],
|
79
|
+
).astype(int)
|
71
80
|
|
72
81
|
|
73
82
|
def idxs_to_names(
|
74
|
-
model:
|
83
|
+
model: js.model.JaxSimModel,
|
84
|
+
*,
|
85
|
+
joint_indices: Sequence[jtp.IntLike] | jtp.VectorLike,
|
75
86
|
) -> tuple[str, ...]:
|
76
87
|
"""
|
77
88
|
Convert a sequence of joint indices to their corresponding names.
|
@@ -84,12 +95,7 @@ def idxs_to_names(
|
|
84
95
|
The names of the joints.
|
85
96
|
"""
|
86
97
|
|
87
|
-
|
88
|
-
j.index - 1: j.name
|
89
|
-
for j in model.physics_model.description.joints_dict.values()
|
90
|
-
}
|
91
|
-
|
92
|
-
return tuple(d[i] for i in joint_indices)
|
98
|
+
return tuple(idx_to_name(model=model, joint_index=idx) for idx in joint_indices)
|
93
99
|
|
94
100
|
|
95
101
|
# ============
|
@@ -99,23 +105,48 @@ def idxs_to_names(
|
|
99
105
|
|
100
106
|
@jax.jit
|
101
107
|
def position_limit(
|
102
|
-
model:
|
108
|
+
model: js.model.JaxSimModel, *, joint_index: jtp.IntLike
|
103
109
|
) -> tuple[jtp.Float, jtp.Float]:
|
104
|
-
"""
|
110
|
+
"""
|
111
|
+
Get the position limits of a joint.
|
112
|
+
|
113
|
+
Args:
|
114
|
+
model: The model to consider.
|
115
|
+
joint_index: The index of the joint.
|
116
|
+
|
117
|
+
Returns:
|
118
|
+
The position limits of the joint.
|
119
|
+
"""
|
120
|
+
|
121
|
+
if model.number_of_joints() <= 1:
|
122
|
+
return jnp.empty(0).astype(float), jnp.empty(0).astype(float)
|
105
123
|
|
106
|
-
|
107
|
-
|
124
|
+
s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min[joint_index]
|
125
|
+
s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max[joint_index]
|
108
126
|
|
109
|
-
return
|
127
|
+
return s_min.astype(float), s_max.astype(float)
|
110
128
|
|
111
129
|
|
112
130
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
113
131
|
def position_limits(
|
114
|
-
model:
|
132
|
+
model: js.model.JaxSimModel, *, joint_names: Sequence[str] | None = None
|
115
133
|
) -> tuple[jtp.Vector, jtp.Vector]:
|
134
|
+
"""
|
135
|
+
Get the position limits of a list of joint.
|
136
|
+
|
137
|
+
Args:
|
138
|
+
model: The model to consider.
|
139
|
+
joint_names: The names of the joints.
|
140
|
+
|
141
|
+
Returns:
|
142
|
+
The position limits of the joints.
|
143
|
+
"""
|
116
144
|
|
117
145
|
joint_names = joint_names if joint_names is not None else model.joint_names()
|
118
146
|
|
147
|
+
if len(joint_names) == 0:
|
148
|
+
return jnp.empty(0).astype(float), jnp.empty(0).astype(float)
|
149
|
+
|
119
150
|
joint_idxs = names_to_idxs(joint_names=joint_names, model=model)
|
120
151
|
return jax.vmap(lambda i: position_limit(model=model, joint_index=i))(joint_idxs)
|
121
152
|
|
@@ -127,12 +158,22 @@ def position_limits(
|
|
127
158
|
|
128
159
|
@functools.partial(jax.jit, static_argnames=["joint_names"])
|
129
160
|
def random_joint_positions(
|
130
|
-
model:
|
161
|
+
model: js.model.JaxSimModel,
|
131
162
|
*,
|
132
163
|
joint_names: Sequence[str] | None = None,
|
133
164
|
key: jax.Array | None = None,
|
134
165
|
) -> jtp.Vector:
|
135
|
-
"""
|
166
|
+
"""
|
167
|
+
Generate random joint positions.
|
168
|
+
|
169
|
+
Args:
|
170
|
+
model: The model to consider.
|
171
|
+
joint_names: The names of the joints.
|
172
|
+
key: The random key.
|
173
|
+
|
174
|
+
Returns:
|
175
|
+
The random joint positions.
|
176
|
+
"""
|
136
177
|
|
137
178
|
key = key if key is not None else jax.random.PRNGKey(seed=0)
|
138
179
|
|