jaxsim 0.2.dev188__py3-none-any.whl → 0.2.dev364__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. jaxsim/__init__.py +3 -4
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +13 -2
  6. jaxsim/api/contact.py +120 -43
  7. jaxsim/api/data.py +112 -71
  8. jaxsim/api/joint.py +77 -36
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +150 -75
  11. jaxsim/api/model.py +542 -269
  12. jaxsim/api/ode.py +88 -72
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +12 -11
  15. jaxsim/integrators/__init__.py +2 -2
  16. jaxsim/integrators/common.py +110 -24
  17. jaxsim/integrators/fixed_step.py +11 -67
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +93 -0
  25. jaxsim/parsers/descriptions/collision.py +14 -0
  26. jaxsim/parsers/descriptions/link.py +13 -2
  27. jaxsim/parsers/kinematic_graph.py +5 -0
  28. jaxsim/parsers/rod/utils.py +7 -8
  29. jaxsim/rbda/__init__.py +7 -0
  30. jaxsim/rbda/aba.py +295 -0
  31. jaxsim/rbda/collidable_points.py +142 -0
  32. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  33. jaxsim/rbda/forward_kinematics.py +113 -0
  34. jaxsim/rbda/jacobian.py +201 -0
  35. jaxsim/rbda/rnea.py +237 -0
  36. jaxsim/rbda/soft_contacts.py +296 -0
  37. jaxsim/rbda/utils.py +152 -0
  38. jaxsim/terrain/__init__.py +2 -0
  39. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  40. jaxsim/utils/__init__.py +1 -4
  41. jaxsim/utils/hashless.py +18 -0
  42. jaxsim/utils/jaxsim_dataclass.py +281 -30
  43. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
  44. jaxsim-0.2.dev364.dist-info/RECORD +64 -0
  45. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
  46. jaxsim/high_level/__init__.py +0 -2
  47. jaxsim/high_level/common.py +0 -11
  48. jaxsim/high_level/joint.py +0 -148
  49. jaxsim/high_level/link.py +0 -259
  50. jaxsim/high_level/model.py +0 -1686
  51. jaxsim/math/conv.py +0 -114
  52. jaxsim/math/joint.py +0 -102
  53. jaxsim/math/plucker.py +0 -100
  54. jaxsim/physics/__init__.py +0 -12
  55. jaxsim/physics/algos/__init__.py +0 -0
  56. jaxsim/physics/algos/aba.py +0 -254
  57. jaxsim/physics/algos/aba_motors.py +0 -284
  58. jaxsim/physics/algos/forward_kinematics.py +0 -79
  59. jaxsim/physics/algos/jacobian.py +0 -98
  60. jaxsim/physics/algos/rnea.py +0 -180
  61. jaxsim/physics/algos/rnea_motors.py +0 -196
  62. jaxsim/physics/algos/soft_contacts.py +0 -523
  63. jaxsim/physics/algos/utils.py +0 -69
  64. jaxsim/physics/model/__init__.py +0 -0
  65. jaxsim/physics/model/ground_contact.py +0 -55
  66. jaxsim/physics/model/physics_model.py +0 -388
  67. jaxsim/physics/model/physics_model_state.py +0 -283
  68. jaxsim/simulation/__init__.py +0 -4
  69. jaxsim/simulation/integrators.py +0 -393
  70. jaxsim/simulation/ode.py +0 -290
  71. jaxsim/simulation/ode_data.py +0 -96
  72. jaxsim/simulation/ode_integration.py +0 -62
  73. jaxsim/simulation/simulator.py +0 -543
  74. jaxsim/simulation/simulator_callbacks.py +0 -79
  75. jaxsim/simulation/utils.py +0 -15
  76. jaxsim/sixd/__init__.py +0 -2
  77. jaxsim/utils/oop.py +0 -536
  78. jaxsim/utils/vmappable.py +0 -117
  79. jaxsim-0.2.dev188.dist-info/RECORD +0 -81
  80. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
  81. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
jaxsim/api/data.py CHANGED
@@ -10,20 +10,16 @@ import jax_dataclasses
10
10
  import jaxlie
11
11
  import numpy as np
12
12
 
13
- import jaxsim.api
14
- import jaxsim.physics.algos.aba
15
- import jaxsim.physics.algos.crba
16
- import jaxsim.physics.algos.forward_kinematics
17
- import jaxsim.physics.algos.rnea
18
- import jaxsim.physics.model.physics_model
19
- import jaxsim.physics.model.physics_model_state
13
+ import jaxsim.api as js
14
+ import jaxsim.rbda
20
15
  import jaxsim.typing as jtp
21
- from jaxsim.high_level.common import VelRepr
22
- from jaxsim.physics.algos import soft_contacts
23
- from jaxsim.simulation.ode_data import ODEState
16
+ from jaxsim.math import Quaternion
24
17
  from jaxsim.utils import Mutability
18
+ from jaxsim.utils.tracing import not_tracing
25
19
 
26
20
  from . import common
21
+ from .common import VelRepr
22
+ from .ode_data import ODEState
27
23
 
28
24
  try:
29
25
  from typing import Self
@@ -41,14 +37,13 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
41
37
 
42
38
  gravity: jtp.Array
43
39
 
44
- soft_contacts_params: soft_contacts.SoftContactsParams = dataclasses.field(
45
- repr=False
46
- )
40
+ soft_contacts_params: jaxsim.rbda.SoftContactsParams = dataclasses.field(repr=False)
41
+
47
42
  time_ns: jtp.Int = dataclasses.field(
48
43
  default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
49
44
  )
50
45
 
51
- def valid(self, model: jaxsim.api.model.JaxSimModel | None = None) -> bool:
46
+ def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
52
47
  """
53
48
  Check if the current state is valid for the given model.
54
49
 
@@ -60,15 +55,16 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
60
55
  """
61
56
 
62
57
  valid = True
58
+ valid = valid and self.standard_gravity() > 0
63
59
 
64
60
  if model is not None:
65
- valid = valid and self.state.valid(physics_model=model.physics_model)
61
+ valid = valid and self.state.valid(model=model)
66
62
 
67
63
  return valid
68
64
 
69
65
  @staticmethod
70
66
  def zero(
71
- model: jaxsim.api.model.JaxSimModel,
67
+ model: js.model.JaxSimModel,
72
68
  velocity_representation: VelRepr = VelRepr.Inertial,
73
69
  ) -> JaxSimModelData:
74
70
  """
@@ -88,16 +84,16 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
88
84
 
89
85
  @staticmethod
90
86
  def build(
91
- model: jaxsim.api.model.JaxSimModel,
87
+ model: js.model.JaxSimModel,
92
88
  base_position: jtp.Vector | None = None,
93
89
  base_quaternion: jtp.Vector | None = None,
94
90
  joint_positions: jtp.Vector | None = None,
95
91
  base_linear_velocity: jtp.Vector | None = None,
96
92
  base_angular_velocity: jtp.Vector | None = None,
97
93
  joint_velocities: jtp.Vector | None = None,
98
- gravity: jtp.Vector | None = None,
99
- soft_contacts_state: soft_contacts.SoftContactsState | None = None,
100
- soft_contacts_params: soft_contacts.SoftContactsParams | None = None,
94
+ standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
95
+ soft_contacts_state: js.ode_data.SoftContactsState | None = None,
96
+ soft_contacts_params: jaxsim.rbda.SoftContactsParams | None = None,
101
97
  velocity_representation: VelRepr = VelRepr.Inertial,
102
98
  time: jtp.FloatLike | None = None,
103
99
  ) -> JaxSimModelData:
@@ -114,7 +110,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
114
110
  base_angular_velocity:
115
111
  The base angular velocity in the selected representation.
116
112
  joint_velocities: The joint velocities.
117
- gravity: The gravity 3D vector.
113
+ standard_gravity: The standard gravity constant.
118
114
  soft_contacts_state: The state of the soft contacts.
119
115
  soft_contacts_params: The parameters of the soft contacts.
120
116
  velocity_representation: The velocity representation to use.
@@ -142,9 +138,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
142
138
  base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3)
143
139
  ).squeeze()
144
140
 
145
- gravity = jnp.array(
146
- gravity if gravity is not None else model.physics_model.gravity[0:3]
147
- ).squeeze()
141
+ gravity = jnp.zeros(3).at[2].set(-standard_gravity)
148
142
 
149
143
  joint_positions = jnp.atleast_1d(
150
144
  joint_positions.squeeze()
@@ -167,7 +161,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
167
161
  soft_contacts_params = (
168
162
  soft_contacts_params
169
163
  if soft_contacts_params is not None
170
- else jaxsim.api.contact.estimate_good_soft_contacts_parameters(model=model)
164
+ else js.contact.estimate_good_soft_contacts_parameters(
165
+ model=model, standard_gravity=standard_gravity
166
+ )
171
167
  )
172
168
 
173
169
  W_H_B = jaxlie.SE3.from_rotation_and_translation(
@@ -184,20 +180,22 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
184
180
  is_force=False,
185
181
  )
186
182
 
187
- ode_state = ODEState.build(
188
- physics_model=model.physics_model,
189
- physics_model_state=jaxsim.physics.model.physics_model_state.PhysicsModelState(
190
- base_position=base_position.astype(float),
191
- base_quaternion=base_quaternion.astype(float),
192
- joint_positions=joint_positions.astype(float),
193
- base_linear_velocity=v_WB[0:3].astype(float),
194
- base_angular_velocity=v_WB[3:6].astype(float),
195
- joint_velocities=joint_velocities.astype(float),
183
+ ode_state = ODEState.build_from_jaxsim_model(
184
+ model=model,
185
+ base_position=base_position.astype(float),
186
+ base_quaternion=base_quaternion.astype(float),
187
+ joint_positions=joint_positions.astype(float),
188
+ base_linear_velocity=v_WB[0:3].astype(float),
189
+ base_angular_velocity=v_WB[3:6].astype(float),
190
+ joint_velocities=joint_velocities.astype(float),
191
+ tangential_deformation=(
192
+ soft_contacts_state.tangential_deformation
193
+ if soft_contacts_state is not None
194
+ else None
196
195
  ),
197
- soft_contacts_state=soft_contacts_state,
198
196
  )
199
197
 
200
- if not ode_state.valid(physics_model=model.physics_model):
198
+ if not ode_state.valid(model=model):
201
199
  raise ValueError(ode_state)
202
200
 
203
201
  return JaxSimModelData(
@@ -222,10 +220,20 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
222
220
 
223
221
  return self.time_ns.astype(float) / 1e9
224
222
 
223
+ def standard_gravity(self) -> jtp.Float:
224
+ """
225
+ Get the standard gravity constant.
226
+
227
+ Returns:
228
+ The standard gravity constant.
229
+ """
230
+
231
+ return -self.gravity[2]
232
+
225
233
  @functools.partial(jax.jit, static_argnames=["joint_names"])
226
234
  def joint_positions(
227
235
  self,
228
- model: jaxsim.api.model.JaxSimModel | None = None,
236
+ model: js.model.JaxSimModel | None = None,
229
237
  joint_names: tuple[str, ...] | None = None,
230
238
  ) -> jtp.Vector:
231
239
  """
@@ -250,22 +258,27 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
250
258
  """
251
259
 
252
260
  if model is None:
261
+ if joint_names is not None:
262
+ raise ValueError("Joint names cannot be provided without a model")
263
+
253
264
  return self.state.physics_model.joint_positions
254
265
 
255
- if not self.valid(model=model):
266
+ if not_tracing(self.state.physics_model.joint_positions) and not self.valid(
267
+ model=model
268
+ ):
256
269
  msg = "The data object is not compatible with the provided model"
257
270
  raise ValueError(msg)
258
271
 
259
272
  joint_names = joint_names if joint_names is not None else model.joint_names()
260
273
 
261
274
  return self.state.physics_model.joint_positions[
262
- jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
275
+ js.joint.names_to_idxs(joint_names=joint_names, model=model)
263
276
  ]
264
277
 
265
278
  @functools.partial(jax.jit, static_argnames=["joint_names"])
266
279
  def joint_velocities(
267
280
  self,
268
- model: jaxsim.api.model.JaxSimModel | None = None,
281
+ model: js.model.JaxSimModel | None = None,
269
282
  joint_names: tuple[str, ...] | None = None,
270
283
  ) -> jtp.Vector:
271
284
  """
@@ -290,16 +303,21 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
290
303
  """
291
304
 
292
305
  if model is None:
306
+ if joint_names is not None:
307
+ raise ValueError("Joint names cannot be provided without a model")
308
+
293
309
  return self.state.physics_model.joint_velocities
294
310
 
295
- if not self.valid(model=model):
311
+ if not_tracing(self.state.physics_model.joint_velocities) and not self.valid(
312
+ model=model
313
+ ):
296
314
  msg = "The data object is not compatible with the provided model"
297
315
  raise ValueError(msg)
298
316
 
299
317
  joint_names = joint_names if joint_names is not None else model.joint_names()
300
318
 
301
319
  return self.state.physics_model.joint_velocities[
302
- jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
320
+ js.joint.names_to_idxs(joint_names=joint_names, model=model)
303
321
  ]
304
322
 
305
323
  @jax.jit
@@ -325,26 +343,27 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
325
343
  The base orientation.
326
344
  """
327
345
 
346
+ # Extract the base quaternion.
347
+ W_Q_B = self.state.physics_model.base_quaternion.squeeze()
348
+
328
349
  # Always normalize the quaternion to avoid numerical issues.
329
350
  # If the active scheme does not integrate the quaternion on its manifold,
330
351
  # we introduce a Baumgarte stabilization to let the quaternion converge to
331
352
  # a unit quaternion. In this case, it is not guaranteed that the quaternion
332
353
  # stored in the state is a unit quaternion.
333
- base_unit_quaternion = (
334
- self.state.physics_model.base_quaternion.squeeze()
335
- / jnp.linalg.norm(self.state.physics_model.base_quaternion)
354
+ W_Q_B = jax.lax.select(
355
+ pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0),
356
+ on_true=W_Q_B,
357
+ on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
336
358
  )
337
359
 
338
- # Slice to convert quaternion wxyz -> xyzw
339
- to_xyzw = np.array([1, 2, 3, 0])
340
-
341
360
  return (
342
- base_unit_quaternion
361
+ W_Q_B
343
362
  if not dcm
344
363
  else jaxlie.SO3.from_quaternion_xyzw(
345
- base_unit_quaternion[to_xyzw]
364
+ Quaternion.to_xyzw(wxyz=W_Q_B)
346
365
  ).as_matrix()
347
- )
366
+ ).astype(float)
348
367
 
349
368
  @jax.jit
350
369
  def base_transform(self) -> jtp.MatrixJax:
@@ -430,7 +449,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
430
449
  def reset_joint_positions(
431
450
  self,
432
451
  positions: jtp.VectorLike,
433
- model: jaxsim.api.model.JaxSimModel | None = None,
452
+ model: js.model.JaxSimModel | None = None,
434
453
  joint_names: tuple[str, ...] | None = None,
435
454
  ) -> Self:
436
455
  """
@@ -460,7 +479,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
460
479
  if model is None:
461
480
  return replace(s=positions)
462
481
 
463
- if not self.valid(model=model):
482
+ if not_tracing(positions) and not self.valid(model=model):
464
483
  msg = "The data object is not compatible with the provided model"
465
484
  raise ValueError(msg)
466
485
 
@@ -468,7 +487,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
468
487
 
469
488
  return replace(
470
489
  s=self.state.physics_model.joint_positions.at[
471
- jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
490
+ js.joint.names_to_idxs(joint_names=joint_names, model=model)
472
491
  ].set(positions)
473
492
  )
474
493
 
@@ -476,7 +495,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
476
495
  def reset_joint_velocities(
477
496
  self,
478
497
  velocities: jtp.VectorLike,
479
- model: jaxsim.api.model.JaxSimModel | None = None,
498
+ model: js.model.JaxSimModel | None = None,
480
499
  joint_names: tuple[str, ...] | None = None,
481
500
  ) -> Self:
482
501
  """
@@ -506,7 +525,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
506
525
  if model is None:
507
526
  return replace(ṡ=velocities)
508
527
 
509
- if not self.valid(model=model):
528
+ if not_tracing(velocities) and not self.valid(model=model):
510
529
  msg = "The data object is not compatible with the provided model"
511
530
  raise ValueError(msg)
512
531
 
@@ -514,7 +533,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
514
533
 
515
534
  return replace(
516
535
  ṡ=self.state.physics_model.joint_velocities.at[
517
- jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
536
+ js.joint.names_to_idxs(joint_names=joint_names, model=model)
518
537
  ].set(velocities)
519
538
  )
520
539
 
@@ -692,7 +711,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
692
711
 
693
712
 
694
713
  def random_model_data(
695
- model: jaxsim.api.model.JaxSimModel,
714
+ model: js.model.JaxSimModel,
696
715
  *,
697
716
  key: jax.Array | None = None,
698
717
  velocity_representation: VelRepr | None = None,
@@ -712,6 +731,10 @@ def random_model_data(
712
731
  jtp.FloatLike | Sequence[jtp.FloatLike],
713
732
  jtp.FloatLike | Sequence[jtp.FloatLike],
714
733
  ] = (-1.0, 1.0),
734
+ standard_gravity_bounds: tuple[jtp.FloatLike, jtp.FloatLike] = (
735
+ jaxsim.math.StandardGravity,
736
+ jaxsim.math.StandardGravity,
737
+ ),
715
738
  ) -> JaxSimModelData:
716
739
  """
717
740
  Randomly generate a `JaxSimModelData` object.
@@ -724,13 +747,14 @@ def random_model_data(
724
747
  base_vel_lin_bounds: The bounds for the base linear velocity.
725
748
  base_vel_ang_bounds: The bounds for the base angular velocity.
726
749
  joint_vel_bounds: The bounds for the joint velocities.
750
+ standard_gravity_bounds: The bounds for the standard gravity.
727
751
 
728
752
  Returns:
729
753
  A `JaxSimModelData` object with random data.
730
754
  """
731
755
 
732
756
  key = key if key is not None else jax.random.PRNGKey(seed=0)
733
- k1, k2, k3, k4, k5, k6 = jax.random.split(key, num=6)
757
+ k1, k2, k3, k4, k5, k6, k7 = jax.random.split(key, num=7)
734
758
 
735
759
  p_min = jnp.array(base_pos_bounds[0], dtype=float)
736
760
  p_max = jnp.array(base_pos_bounds[1], dtype=float)
@@ -749,7 +773,9 @@ def random_model_data(
749
773
  ),
750
774
  )
751
775
 
752
- with random_data.mutable_context(mutability=Mutability.MUTABLE):
776
+ with random_data.mutable_context(
777
+ mutability=Mutability.MUTABLE, restore_after_exception=False
778
+ ):
753
779
 
754
780
  physics_model_state = random_data.state.physics_model
755
781
 
@@ -761,20 +787,35 @@ def random_model_data(
761
787
  *jax.random.uniform(key=k2, shape=(3,), minval=0, maxval=2 * jnp.pi)
762
788
  ).as_quaternion_xyzw()[np.array([3, 0, 1, 2])]
763
789
 
764
- physics_model_state.joint_positions = jaxsim.api.joint.random_joint_positions(
765
- model=model, key=k3
766
- )
790
+ if model.number_of_joints() > 0:
791
+ physics_model_state.joint_positions = js.joint.random_joint_positions(
792
+ model=model, key=k3
793
+ )
767
794
 
768
- physics_model_state.base_linear_velocity = jax.random.uniform(
769
- key=k4, shape=(3,), minval=v_min, maxval=v_max
770
- )
795
+ physics_model_state.joint_velocities = jax.random.uniform(
796
+ key=k4, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max
797
+ )
771
798
 
772
- physics_model_state.base_angular_velocity = jax.random.uniform(
773
- key=k5, shape=(3,), minval=ω_min, maxval=ω_max
774
- )
799
+ if model.floating_base():
800
+ physics_model_state.base_linear_velocity = jax.random.uniform(
801
+ key=k5, shape=(3,), minval=v_min, maxval=v_max
802
+ )
803
+
804
+ physics_model_state.base_angular_velocity = jax.random.uniform(
805
+ key=k6, shape=(3,), minval=ω_min, maxval=ω_max
806
+ )
775
807
 
776
- physics_model_state.joint_velocities = jax.random.uniform(
777
- key=k6, shape=(model.dofs(),), minval=ṡ_min, maxval=ṡ_max
808
+ random_data.gravity = (
809
+ jnp.zeros(3, dtype=random_data.gravity.dtype)
810
+ .at[2]
811
+ .set(
812
+ -jax.random.uniform(
813
+ key=k7,
814
+ shape=(),
815
+ minval=standard_gravity_bounds[0],
816
+ maxval=standard_gravity_bounds[1],
817
+ )
818
+ )
778
819
  )
779
820
 
780
821
  return random_data
jaxsim/api/joint.py CHANGED
@@ -3,17 +3,18 @@ from typing import Sequence
3
3
 
4
4
  import jax
5
5
  import jax.numpy as jnp
6
+ import numpy as np
6
7
 
8
+ import jaxsim.api as js
7
9
  import jaxsim.typing as jtp
8
10
 
9
- from . import model as Model
10
-
11
11
  # =======================
12
12
  # Index-related functions
13
13
  # =======================
14
14
 
15
15
 
16
- def name_to_idx(model: Model.JaxSimModel, *, joint_name: str) -> jtp.Int:
16
+ @functools.partial(jax.jit, static_argnames="joint_name")
17
+ def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int:
17
18
  """
18
19
  Convert the name of a joint to its index.
19
20
 
@@ -25,12 +26,25 @@ def name_to_idx(model: Model.JaxSimModel, *, joint_name: str) -> jtp.Int:
25
26
  The index of the joint.
26
27
  """
27
28
 
28
- return jnp.array(
29
- model.physics_model.description.joints_dict[joint_name].index, dtype=int
30
- )
31
-
32
-
33
- def idx_to_name(model: Model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:
29
+ if joint_name in model.kin_dyn_parameters.joint_model.joint_names:
30
+ # Note: the index of the joint for RBDAs starts from 1, but
31
+ # the index for accessing the right element starts from 0.
32
+ # Therefore, there is a -1.
33
+ return (
34
+ jnp.array(
35
+ np.argwhere(
36
+ np.array(model.kin_dyn_parameters.joint_model.joint_names)
37
+ == joint_name
38
+ )
39
+ - 1
40
+ )
41
+ .squeeze()
42
+ .astype(int)
43
+ )
44
+ return jnp.array(-1).astype(int)
45
+
46
+
47
+ def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:
34
48
  """
35
49
  Convert the index of a joint to its name.
36
50
 
@@ -42,11 +56,13 @@ def idx_to_name(model: Model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:
42
56
  The name of the joint.
43
57
  """
44
58
 
45
- d = {j.index: j.name for j in model.physics_model.description.joints_dict.values()}
46
- return d[joint_index]
59
+ return model.kin_dyn_parameters.joint_model.joint_names[joint_index + 1]
47
60
 
48
61
 
49
- def names_to_idxs(model: Model.JaxSimModel, *, joint_names: Sequence[str]) -> jax.Array:
62
+ @functools.partial(jax.jit, static_argnames="joint_names")
63
+ def names_to_idxs(
64
+ model: js.model.JaxSimModel, *, joint_names: Sequence[str]
65
+ ) -> jax.Array:
50
66
  """
51
67
  Convert a sequence of joint names to their corresponding indices.
52
68
 
@@ -59,19 +75,14 @@ def names_to_idxs(model: Model.JaxSimModel, *, joint_names: Sequence[str]) -> ja
59
75
  """
60
76
 
61
77
  return jnp.array(
62
- [
63
- # Note: the index of the joint for RBDAs starts from 1, but
64
- # the index for accessing the right element starts from 0.
65
- # Therefore, there is a -1.
66
- model.physics_model.description.joints_dict[name].index - 1
67
- for name in joint_names
68
- ],
69
- dtype=int,
70
- )
78
+ [name_to_idx(model=model, joint_name=name) for name in joint_names],
79
+ ).astype(int)
71
80
 
72
81
 
73
82
  def idxs_to_names(
74
- model: Model.JaxSimModel, *, joint_indices: Sequence[jtp.IntLike] | jtp.VectorLike
83
+ model: js.model.JaxSimModel,
84
+ *,
85
+ joint_indices: Sequence[jtp.IntLike] | jtp.VectorLike,
75
86
  ) -> tuple[str, ...]:
76
87
  """
77
88
  Convert a sequence of joint indices to their corresponding names.
@@ -84,12 +95,7 @@ def idxs_to_names(
84
95
  The names of the joints.
85
96
  """
86
97
 
87
- d = {
88
- j.index - 1: j.name
89
- for j in model.physics_model.description.joints_dict.values()
90
- }
91
-
92
- return tuple(d[i] for i in joint_indices)
98
+ return tuple(idx_to_name(model=model, joint_index=idx) for idx in joint_indices)
93
99
 
94
100
 
95
101
  # ============
@@ -99,23 +105,48 @@ def idxs_to_names(
99
105
 
100
106
  @jax.jit
101
107
  def position_limit(
102
- model: Model.JaxSimModel, *, joint_index: jtp.IntLike
108
+ model: js.model.JaxSimModel, *, joint_index: jtp.IntLike
103
109
  ) -> tuple[jtp.Float, jtp.Float]:
104
- """"""
110
+ """
111
+ Get the position limits of a joint.
112
+
113
+ Args:
114
+ model: The model to consider.
115
+ joint_index: The index of the joint.
116
+
117
+ Returns:
118
+ The position limits of the joint.
119
+ """
120
+
121
+ if model.number_of_joints() <= 1:
122
+ return jnp.empty(0).astype(float), jnp.empty(0).astype(float)
105
123
 
106
- min = model.physics_model._joint_position_limits_min[joint_index]
107
- max = model.physics_model._joint_position_limits_max[joint_index]
124
+ s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min[joint_index]
125
+ s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max[joint_index]
108
126
 
109
- return min.astype(float), max.astype(float)
127
+ return s_min.astype(float), s_max.astype(float)
110
128
 
111
129
 
112
130
  @functools.partial(jax.jit, static_argnames=["joint_names"])
113
131
  def position_limits(
114
- model: Model.JaxSimModel, *, joint_names: Sequence[str] | None = None
132
+ model: js.model.JaxSimModel, *, joint_names: Sequence[str] | None = None
115
133
  ) -> tuple[jtp.Vector, jtp.Vector]:
134
+ """
135
+ Get the position limits of a list of joint.
136
+
137
+ Args:
138
+ model: The model to consider.
139
+ joint_names: The names of the joints.
140
+
141
+ Returns:
142
+ The position limits of the joints.
143
+ """
116
144
 
117
145
  joint_names = joint_names if joint_names is not None else model.joint_names()
118
146
 
147
+ if len(joint_names) == 0:
148
+ return jnp.empty(0).astype(float), jnp.empty(0).astype(float)
149
+
119
150
  joint_idxs = names_to_idxs(joint_names=joint_names, model=model)
120
151
  return jax.vmap(lambda i: position_limit(model=model, joint_index=i))(joint_idxs)
121
152
 
@@ -127,12 +158,22 @@ def position_limits(
127
158
 
128
159
  @functools.partial(jax.jit, static_argnames=["joint_names"])
129
160
  def random_joint_positions(
130
- model: Model.JaxSimModel,
161
+ model: js.model.JaxSimModel,
131
162
  *,
132
163
  joint_names: Sequence[str] | None = None,
133
164
  key: jax.Array | None = None,
134
165
  ) -> jtp.Vector:
135
- """"""
166
+ """
167
+ Generate random joint positions.
168
+
169
+ Args:
170
+ model: The model to consider.
171
+ joint_names: The names of the joints.
172
+ key: The random key.
173
+
174
+ Returns:
175
+ The random joint positions.
176
+ """
136
177
 
137
178
  key = key if key is not None else jax.random.PRNGKey(seed=0)
138
179