jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +64 -30
  24. jaxsim/math/cross.py +18 -9
  25. jaxsim/math/inertia.py +11 -9
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +59 -25
  28. jaxsim/math/rotation.py +30 -24
  29. jaxsim/math/skew.py +18 -7
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev5.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev5.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/top_level.txt +0 -0
jaxsim/api/references.py CHANGED
@@ -7,10 +7,11 @@ import jax.numpy as jnp
7
7
  import jax_dataclasses
8
8
 
9
9
  import jaxsim.api as js
10
- import jaxsim.physics.model.physics_model_state
11
10
  import jaxsim.typing as jtp
12
- from jaxsim import VelRepr
13
- from jaxsim.simulation.ode_data import ODEInput
11
+ from jaxsim import exceptions
12
+ from jaxsim.utils.tracing import not_tracing
13
+
14
+ from .common import VelRepr
14
15
 
15
16
  try:
16
17
  from typing import Self
@@ -22,13 +23,19 @@ except ImportError:
22
23
  class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
23
24
  """
24
25
  Class containing the references for a `JaxSimModel` object.
26
+
27
+ Attributes:
28
+ _link_forces: The link 6D forces in inertial-fixed representation.
29
+ _joint_force_references: The joint force references.
25
30
  """
26
31
 
27
- input: ODEInput
32
+ _link_forces: jtp.Matrix
33
+ _joint_force_references: jtp.Vector
28
34
 
29
35
  @staticmethod
30
36
  def zero(
31
37
  model: js.model.JaxSimModel,
38
+ data: js.data.JaxSimModelData | None = None,
32
39
  velocity_representation: VelRepr = VelRepr.Inertial,
33
40
  ) -> JaxSimModelReferences:
34
41
  """
@@ -36,6 +43,9 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
36
43
 
37
44
  Args:
38
45
  model: The model for which to create the zero references.
46
+ data:
47
+ The data of the model, only needed if the velocity representation is
48
+ not inertial-fixed.
39
49
  velocity_representation: The velocity representation to use.
40
50
 
41
51
  Returns:
@@ -43,14 +53,14 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
43
53
  """
44
54
 
45
55
  return JaxSimModelReferences.build(
46
- model=model, velocity_representation=velocity_representation
56
+ model=model, data=data, velocity_representation=velocity_representation
47
57
  )
48
58
 
49
59
  @staticmethod
50
60
  def build(
51
61
  model: js.model.JaxSimModel,
52
- joint_force_references: jtp.Vector | None = None,
53
- link_forces: jtp.Matrix | None = None,
62
+ joint_force_references: jtp.VectorLike | None = None,
63
+ link_forces: jtp.MatrixLike | None = None,
54
64
  data: js.data.JaxSimModelData | None = None,
55
65
  velocity_representation: VelRepr | None = None,
56
66
  ) -> JaxSimModelReferences:
@@ -72,14 +82,14 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
72
82
 
73
83
  # Create or adjust joint force references.
74
84
  joint_force_references = jnp.atleast_1d(
75
- joint_force_references.squeeze()
85
+ jnp.array(joint_force_references, dtype=float).squeeze()
76
86
  if joint_force_references is not None
77
87
  else jnp.zeros(model.dofs())
78
88
  ).astype(float)
79
89
 
80
90
  # Create or adjust link forces.
81
91
  f_L = jnp.atleast_2d(
82
- link_forces.squeeze()
92
+ jnp.array(link_forces, dtype=float).squeeze()
83
93
  if link_forces is not None
84
94
  else jnp.zeros((model.number_of_links(), 6))
85
95
  ).astype(float)
@@ -88,17 +98,21 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
88
98
  velocity_representation = (
89
99
  velocity_representation
90
100
  if velocity_representation is not None
91
- else (
92
- data.velocity_representation if data is not None else VelRepr.Inertial
93
- )
101
+ else getattr(data, "velocity_representation", VelRepr.Inertial)
94
102
  )
95
103
 
96
104
  # Create a zero references object.
97
105
  references = JaxSimModelReferences(
98
- input=ODEInput.zero(physics_model=model.physics_model),
106
+ _link_forces=f_L,
107
+ _joint_force_references=joint_force_references,
99
108
  velocity_representation=velocity_representation,
100
109
  )
101
110
 
111
+ # If the velocity representation is inertial-fixed, we can return
112
+ # the references directly, as we store the link forces in this frame.
113
+ if velocity_representation is VelRepr.Inertial:
114
+ return references
115
+
102
116
  # Store the joint force references.
103
117
  references = references.set_joint_force_references(
104
118
  forces=joint_force_references,
@@ -129,17 +143,27 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
129
143
  `False` otherwise.
130
144
  """
131
145
 
132
- valid = True
146
+ if model is None:
147
+ return True
148
+
149
+ shape = self._joint_force_references.shape
150
+ expected_shape = (model.dofs(),)
151
+
152
+ if shape != expected_shape:
153
+ return False
154
+
155
+ shape = self._link_forces.shape
156
+ expected_shape = (model.number_of_links(), 6)
133
157
 
134
- if model is not None:
135
- valid = valid and self.input.valid(physics_model=model.physics_model)
158
+ if shape != expected_shape:
159
+ return False
136
160
 
137
- return valid
161
+ return True
138
162
 
139
163
  # ==================
140
164
  # Extract quantities
141
165
  # ==================
142
-
166
+ @js.common.named_scope
143
167
  @functools.partial(jax.jit, static_argnames=["link_names"])
144
168
  def link_forces(
145
169
  self,
@@ -172,7 +196,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
172
196
  e.g. to the contact model and other kinematic constraints.
173
197
  """
174
198
 
175
- W_f_L = self.input.physics_model.f_ext
199
+ W_f_L = self._link_forces
176
200
 
177
201
  # Return all link forces in inertial-fixed representation using the implicit
178
202
  # serialization.
@@ -184,11 +208,14 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
184
208
  if link_names is not None:
185
209
  raise ValueError("Link names cannot be provided without a model")
186
210
 
187
- return self.input.physics_model.f_ext
211
+ return W_f_L
188
212
 
189
213
  # If we have the model, we can extract the link names, if not provided.
190
- link_names = link_names if link_names is not None else model.link_names()
191
- link_idxs = jaxsim.api.link.names_to_idxs(link_names=link_names, model=model)
214
+ link_idxs = (
215
+ js.link.names_to_idxs(link_names=link_names, model=model)
216
+ if link_names is not None
217
+ else jnp.arange(model.number_of_links())
218
+ )
192
219
 
193
220
  # In inertial-fixed representation, we already have the link forces.
194
221
  if self.velocity_representation is VelRepr.Inertial:
@@ -198,20 +225,25 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
198
225
  msg = "Missing model data to use a representation different from {}"
199
226
  raise ValueError(msg.format(VelRepr.Inertial.name))
200
227
 
201
- if not data.valid(model=model):
228
+ if not_tracing(self._link_forces) and not data.valid(model=model):
202
229
  raise ValueError("The provided data is not valid for the model")
203
230
 
204
- # Helper function to convert a single 6D force to the active representation.
205
- def convert(f_L: jtp.Vector) -> jtp.Vector:
206
- return JaxSimModelReferences.inertial_to_other_representation(
207
- array=f_L,
208
- other_representation=self.velocity_representation,
209
- transform=data.base_transform(),
210
- is_force=True,
211
- )
231
+ # Helper function to convert a single 6D force to the active representation
232
+ # considering as body the link (i.e. L_f_L and LW_f_L).
233
+ def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix:
212
234
 
213
- # Convert to the desired representation.
214
- f_L = jax.vmap(convert)(W_f_L[link_idxs, :])
235
+ return jax.vmap(
236
+ lambda W_f_L, W_H_L: JaxSimModelReferences.inertial_to_other_representation(
237
+ array=W_f_L,
238
+ other_representation=self.velocity_representation,
239
+ transform=W_H_L,
240
+ is_force=True,
241
+ )
242
+ )(W_f_L, W_H_L)
243
+
244
+ # The f_L output is either L_f_L or LW_f_L, depending on the representation.
245
+ W_H_L = js.model.forward_kinematics(model=model, data=data)
246
+ f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :])
215
247
 
216
248
  return f_L
217
249
 
@@ -250,23 +282,26 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
250
282
  if joint_names is not None:
251
283
  raise ValueError("Joint names cannot be provided without a model")
252
284
 
253
- return self.input.physics_model.tau
285
+ return self._joint_force_references
254
286
 
255
- if not self.valid(model=model):
287
+ if not_tracing(self._joint_force_references) and not self.valid(model=model):
256
288
  msg = "The actuation object is not compatible with the provided model"
257
289
  raise ValueError(msg)
258
290
 
259
- joint_names = joint_names if joint_names is not None else model.joint_names()
260
- joint_idxs = js.joint.names_to_idxs(joint_names=joint_names, model=model)
291
+ joint_idxs = (
292
+ js.joint.names_to_idxs(joint_names=joint_names, model=model)
293
+ if joint_names is not None
294
+ else jnp.arange(model.number_of_joints())
295
+ )
261
296
 
262
297
  return jnp.atleast_1d(
263
- self.input.physics_model.tau[joint_idxs].squeeze()
298
+ self._joint_force_references[joint_idxs].squeeze()
264
299
  ).astype(float)
265
300
 
266
301
  # ================
267
302
  # Store quantities
268
303
  # ================
269
-
304
+ @js.common.named_scope
270
305
  @functools.partial(jax.jit, static_argnames=["joint_names"])
271
306
  def set_joint_force_references(
272
307
  self,
@@ -288,37 +323,37 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
288
323
  A new `JaxSimModelReferences` object with the given joint force references.
289
324
  """
290
325
 
291
- forces = jnp.array(forces)
326
+ forces = jnp.atleast_1d(jnp.array(forces, dtype=float).squeeze())
292
327
 
293
- def replace(forces: jtp.VectorLike) -> JaxSimModelReferences:
328
+ def replace(forces: jtp.Vector) -> JaxSimModelReferences:
294
329
  return self.replace(
295
330
  validate=True,
296
- input=self.input.replace(
297
- physics_model=self.input.physics_model.replace(
298
- tau=jnp.atleast_1d(forces.squeeze()).astype(float)
299
- )
300
- ),
331
+ _joint_force_references=jnp.atleast_1d(forces.squeeze()).astype(float),
301
332
  )
302
333
 
303
334
  if model is None:
304
335
  return replace(forces=forces)
305
336
 
306
- if not self.valid(model=model):
337
+ if not_tracing(forces) and not self.valid(model=model):
307
338
  msg = "The references object is not compatible with the provided model"
308
339
  raise ValueError(msg)
309
340
 
310
- joint_names = joint_names if joint_names is not None else model.joint_names()
311
- joint_idxs = js.joint.names_to_idxs(joint_names=joint_names, model=model)
341
+ joint_idxs = (
342
+ js.joint.names_to_idxs(joint_names=joint_names, model=model)
343
+ if joint_names is not None
344
+ else jnp.arange(model.number_of_joints())
345
+ )
312
346
 
313
- return replace(forces=self.input.physics_model.tau.at[joint_idxs].set(forces))
347
+ return replace(forces=self._joint_force_references.at[joint_idxs].set(forces))
314
348
 
349
+ @js.common.named_scope
315
350
  @functools.partial(jax.jit, static_argnames=["link_names", "additive"])
316
351
  def apply_link_forces(
317
352
  self,
318
353
  forces: jtp.MatrixLike,
319
354
  model: js.model.JaxSimModel | None = None,
320
355
  data: js.data.JaxSimModelData | None = None,
321
- link_names: tuple[str, ...] | None = None,
356
+ link_names: tuple[str, ...] | str | None = None,
322
357
  additive: bool = False,
323
358
  ) -> Self:
324
359
  """
@@ -344,17 +379,13 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
344
379
  Then, we always convert and store forces in inertial-fixed representation.
345
380
  """
346
381
 
347
- f_L = jnp.array(forces)
382
+ f_L = jnp.atleast_2d(forces).astype(float)
348
383
 
349
384
  # Helper function to replace the link forces.
350
385
  def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences:
351
386
  return self.replace(
352
387
  validate=True,
353
- input=self.input.replace(
354
- physics_model=self.input.physics_model.replace(
355
- f_ext=jnp.atleast_2d(forces.squeeze()).astype(float)
356
- )
357
- ),
388
+ _link_forces=jnp.atleast_2d(forces.squeeze()).astype(float),
358
389
  )
359
390
 
360
391
  # In this case, we allow only to set the inertial 6D forces to all links
@@ -369,52 +400,157 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
369
400
 
370
401
  W_f_L = f_L
371
402
 
372
- W_f0_L = (
373
- jnp.zeros_like(W_f_L)
374
- if not additive
375
- else self.input.physics_model.f_ext
376
- )
403
+ W_f0_L = jnp.zeros_like(W_f_L) if not additive else self._link_forces
377
404
 
378
405
  return replace(forces=W_f0_L + W_f_L)
379
406
 
380
- # If we have the model, we can extract the link names if not provided.
381
- link_names = link_names if link_names is not None else model.link_names()
382
- link_idxs = jaxsim.api.link.names_to_idxs(link_names=link_names, model=model)
407
+ if link_names is not None and len(link_names) != f_L.shape[0]:
408
+ msg = "The number of link names ({}) must match the number of forces ({})"
409
+ raise ValueError(msg.format(len(link_names), f_L.shape[0]))
410
+
411
+ # Extract the link indices.
412
+ link_idxs = (
413
+ js.link.names_to_idxs(link_names=link_names, model=model)
414
+ if link_names is not None
415
+ else jnp.arange(model.number_of_links())
416
+ )
383
417
 
384
418
  # Compute the bias depending on whether we either set or add the link forces.
385
419
  W_f0_L = (
386
- jnp.zeros_like(f_L)
387
- if not additive
388
- else self.input.physics_model.f_ext[link_idxs, :]
420
+ jnp.zeros_like(f_L) if not additive else self._link_forces[link_idxs, :]
389
421
  )
390
422
 
391
423
  # If inertial-fixed representation, we can directly store the link forces.
392
424
  if self.velocity_representation is VelRepr.Inertial:
393
425
  W_f_L = f_L
394
426
  return replace(
395
- forces=self.input.physics_model.f_ext.at[link_idxs, :].set(
396
- W_f0_L + W_f_L
397
- )
427
+ forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L)
398
428
  )
399
429
 
400
430
  if data is None:
401
431
  msg = "Missing model data to use a representation different from {}"
402
432
  raise ValueError(msg.format(VelRepr.Inertial.name))
403
433
 
404
- if not data.valid(model=model):
434
+ if not_tracing(forces) and not data.valid(model=model):
405
435
  raise ValueError("The provided data is not valid for the model")
406
436
 
407
- # Helper function to convert a single 6D force to the inertial representation.
408
- def convert(f_L: jtp.Vector) -> jtp.Vector:
437
+ # Helper function to convert a single 6D force to the inertial representation
438
+ # considering as body the link (i.e. L_f_L and LW_f_L).
439
+ def convert_using_link_frame(
440
+ f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike
441
+ ) -> jtp.Matrix:
442
+
443
+ return jax.vmap(
444
+ lambda f_L, W_H_L: JaxSimModelReferences.other_representation_to_inertial(
445
+ array=f_L,
446
+ other_representation=self.velocity_representation,
447
+ transform=W_H_L,
448
+ is_force=True,
449
+ )
450
+ )(f_L, W_H_L)
451
+
452
+ # The f_L input is either L_f_L or LW_f_L, depending on the representation.
453
+ W_H_L = js.model.forward_kinematics(model=model, data=data)
454
+ W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :])
455
+
456
+ return replace(forces=self._link_forces.at[link_idxs, :].set(W_f0_L + W_f_L))
457
+
458
+ def apply_frame_forces(
459
+ self,
460
+ forces: jtp.MatrixLike,
461
+ model: js.model.JaxSimModel,
462
+ data: js.data.JaxSimModelData,
463
+ frame_names: tuple[str, ...] | str | None = None,
464
+ additive: bool = False,
465
+ ) -> Self:
466
+ """
467
+ Apply the frame forces.
468
+
469
+ Args:
470
+ forces: The frame 6D forces in the active representation.
471
+ model:
472
+ The model to consider, only needed if a frame serialization different
473
+ from the implicit one is used.
474
+ data:
475
+ The data of the considered model, only needed if the velocity
476
+ representation is not inertial-fixed.
477
+ frame_names: The names of the frames corresponding to the forces.
478
+ additive:
479
+ Whether to add the forces to the existing ones instead of replacing them.
480
+
481
+ Returns:
482
+ A new `JaxSimModelReferences` object with the given frame forces.
483
+
484
+ Note:
485
+ The frame forces must be expressed in the active representation.
486
+ Then, we always convert and store forces in inertial-fixed representation.
487
+ """
488
+
489
+ f_F = jnp.atleast_2d(forces).astype(float)
490
+
491
+ if len(frame_names) != f_F.shape[0]:
492
+ msg = "The number of frame names ({}) must match the number of forces ({})"
493
+ raise ValueError(msg.format(len(frame_names), f_F.shape[0]))
494
+
495
+ # Extract the frame indices.
496
+ frame_idxs = (
497
+ js.frame.names_to_idxs(frame_names=frame_names, model=model)
498
+ if frame_names is not None
499
+ else jnp.arange(len(model.frame_names()))
500
+ )
501
+
502
+ parent_link_idxs = jnp.array(model.kin_dyn_parameters.frame_parameters.body)[
503
+ frame_idxs - model.number_of_links()
504
+ ]
505
+
506
+ exceptions.raise_value_error_if(
507
+ condition=~data.valid(model=model),
508
+ msg="The provided data is not valid for the model",
509
+ )
510
+ W_H_Fi = jax.vmap(
511
+ lambda frame_idx: js.frame.transform(
512
+ model=model, data=data, frame_index=frame_idx
513
+ )
514
+ )(frame_idxs)
515
+
516
+ # Helper function to convert a single 6D force to the inertial representation
517
+ # considering as body the frame (i.e. L_f_F and LW_f_F).
518
+ def to_inertial(f_F: jtp.MatrixLike, W_H_F: jtp.MatrixLike) -> jtp.Matrix:
409
519
  return JaxSimModelReferences.other_representation_to_inertial(
410
- array=f_L,
520
+ array=f_F,
411
521
  other_representation=self.velocity_representation,
412
- transform=data.base_transform(),
522
+ transform=W_H_F,
413
523
  is_force=True,
414
524
  )
415
525
 
416
- W_f_L = jax.vmap(convert)(f_L)
526
+ match self.velocity_representation:
527
+ case VelRepr.Inertial:
528
+ W_f_F = f_F
417
529
 
418
- return replace(
419
- forces=self.input.physics_model.f_ext.at[link_idxs, :].set(W_f0_L + W_f_L)
420
- )
530
+ case VelRepr.Body | VelRepr.Mixed:
531
+ W_f_F = jax.vmap(to_inertial)(f_F, W_H_Fi)
532
+
533
+ case _:
534
+ raise ValueError("Invalid velocity representation.")
535
+
536
+ # Sum the forces on the parent links.
537
+ mask = parent_link_idxs[:, jnp.newaxis] == jnp.arange(model.number_of_links())
538
+ W_f_L = mask.T @ W_f_F
539
+
540
+ with self.switch_velocity_representation(
541
+ velocity_representation=VelRepr.Inertial
542
+ ):
543
+ references = self.apply_link_forces(
544
+ model=model,
545
+ data=data,
546
+ link_names=js.link.idxs_to_names(
547
+ model=model, link_indices=parent_link_idxs
548
+ ),
549
+ forces=W_f_L,
550
+ additive=additive,
551
+ )
552
+
553
+ with references.switch_velocity_representation(
554
+ velocity_representation=self.velocity_representation
555
+ ):
556
+ return references
jaxsim/exceptions.py ADDED
@@ -0,0 +1,80 @@
1
+ import os
2
+
3
+ import jax
4
+
5
+
6
+ def raise_if(
7
+ condition: bool | jax.Array, exception: type, msg: str, *args, **kwargs
8
+ ) -> None:
9
+ """
10
+ Raise a host-side exception if a condition is met. Useful in jit-compiled functions.
11
+
12
+ Args:
13
+ condition:
14
+ The boolean condition of the evaluated expression that triggers
15
+ the exception during runtime.
16
+ exception: The type of exception to raise.
17
+ msg:
18
+ The message to display when the exception is raised. The message can be a
19
+ format string (fmt), whose fields are filled with the args and kwargs.
20
+ *args: The arguments to fill the format string.
21
+ **kwargs: The keyword arguments to fill the format string
22
+ """
23
+
24
+ # Disable host callback if running on unsupported hardware or if the user
25
+ # explicitly disabled it.
26
+ if jax.devices()[0].platform in {"tpu", "METAL"} or os.environ.get(
27
+ "JAXSIM_DISABLE_EXCEPTIONS", 0
28
+ ):
29
+ return
30
+
31
+ # Check early that the format string is well-formed.
32
+ try:
33
+ _ = msg.format(*args, **kwargs)
34
+ except Exception as e:
35
+ msg = "Error in formatting exception message with args={} and kwargs={}"
36
+ raise ValueError(msg.format(args, kwargs)) from e
37
+
38
+ def _raise_exception(condition: bool, *args, **kwargs) -> None:
39
+ """The function called by the JAX callback."""
40
+
41
+ if condition:
42
+ raise exception(msg.format(*args, **kwargs))
43
+
44
+ def _callback(args, kwargs) -> None:
45
+ """The function that calls the JAX callback, executed only when needed."""
46
+
47
+ jax.debug.callback(_raise_exception, condition, *args, **kwargs)
48
+
49
+ # Since running a callable on the host is expensive, we prevent its execution
50
+ # if the condition is False with a low-level conditional expression.
51
+ def _run_callback_only_if_condition_is_true(*args, **kwargs) -> None:
52
+ return jax.lax.cond(
53
+ condition,
54
+ _callback,
55
+ lambda args, kwargs: None,
56
+ args,
57
+ kwargs,
58
+ )
59
+
60
+ return _run_callback_only_if_condition_is_true(*args, **kwargs)
61
+
62
+
63
+ def raise_runtime_error_if(
64
+ condition: bool | jax.Array, msg: str, *args, **kwargs
65
+ ) -> None:
66
+ """
67
+ Raise a RuntimeError if a condition is met. Useful in jit-compiled functions.
68
+ """
69
+
70
+ return raise_if(condition, RuntimeError, msg, *args, **kwargs)
71
+
72
+
73
+ def raise_value_error_if(
74
+ condition: bool | jax.Array, msg: str, *args, **kwargs
75
+ ) -> None:
76
+ """
77
+ Raise a ValueError if a condition is met. Useful in jit-compiled functions.
78
+ """
79
+
80
+ return raise_if(condition, ValueError, msg, *args, **kwargs)
jaxsim/integrators/__init__.py CHANGED
@@ -1,2 +1,2 @@
1
- from . import fixed_step
2
- from .common import Integrator, Time, TimeStep
1
+ from . import fixed_step, variable_step
2
+ from .common import Integrator, SystemDynamics, Time, TimeStep