jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -129
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +87 -16
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +62 -24
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +607 -225
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -80
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -55
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev188.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev188.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -1,1686 +0,0 @@
1
- import dataclasses
2
- import functools
3
- import pathlib
4
- from typing import Any, Dict, List, Optional, Tuple, Union
5
-
6
- import jax
7
- import jax.numpy as jnp
8
- import jax_dataclasses
9
- import numpy as np
10
- import rod
11
- from jax_dataclasses import Static
12
-
13
- import jaxsim.physics.algos.aba
14
- import jaxsim.physics.algos.crba
15
- import jaxsim.physics.algos.forward_kinematics
16
- import jaxsim.physics.algos.rnea
17
- import jaxsim.physics.model.physics_model
18
- import jaxsim.physics.model.physics_model_state
19
- import jaxsim.typing as jtp
20
- from jaxsim import high_level, logging, physics, sixd
21
- from jaxsim.physics.algos import soft_contacts
22
- from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
23
- from jaxsim.utils import JaxsimDataclass, Mutability, Vmappable, oop
24
-
25
- from .common import VelRepr
26
-
27
-
28
@jax_dataclasses.pytree_dataclass
class ModelData(JaxsimDataclass):
    """
    Snapshot of the model state, the model input and the contact state
    at a given time instant.
    """

    # Base pose/velocity and joint positions/velocities of the model.
    model_state: jaxsim.physics.model.physics_model_state.PhysicsModelState
    # Joint generalized forces (tau) and external link forces (f_ext).
    model_input: jaxsim.physics.model.physics_model_state.PhysicsModelInput
    # State of the soft-contacts model.
    contact_state: jaxsim.physics.algos.soft_contacts.SoftContactsState

    @staticmethod
    def zero(physics_model: physics.model.physics_model.PhysicsModel) -> "ModelData":
        """
        Create a ModelData object whose fields are zero-initialized.

        Each field is created by the ``zero`` constructor of its own type, so
        that all the arrays have the shapes required by ``physics_model``.

        Args:
            physics_model: The physics model defining the shapes of the fields.

        Returns:
            The zero ModelData object of the given physics model.
        """

        states = jaxsim.physics.model.physics_model_state
        contacts = jaxsim.physics.algos.soft_contacts

        return ModelData(
            model_state=states.PhysicsModelState.zero(physics_model=physics_model),
            model_input=states.PhysicsModelInput.zero(physics_model=physics_model),
            contact_state=contacts.SoftContactsState.zero(
                physics_model=physics_model
            ),
        )
61
-
62
-
63
@jax_dataclasses.pytree_dataclass
class StepData(JaxsimDataclass):
    """
    Record of the data computed during a single step of the simulation.

    All the array-valued fields are excluded from ``repr`` to keep it short.
    """

    # Time instants delimiting the step, and the step size.
    t0: float
    tf: float
    dt: float

    # Model data at the beginning of the step, together with the real input
    # (tau, f_ext) computed at t0.
    t0_model_data: ModelData = dataclasses.field(repr=False)
    t0_model_input_real: jaxsim.physics.model.physics_model_state.PhysicsModelInput = (
        dataclasses.field(repr=False)
    )

    # Output of the Articulated Body Algorithm (ABA) evaluated at t0.
    t0_base_acceleration: jtp.Vector = dataclasses.field(repr=False)
    t0_joint_acceleration: jtp.Vector = dataclasses.field(repr=False)

    # New ODE state at tf. It follows from t0_model_data by integrating the ABA
    # output and the tangential deformation derivative (a function of the ODE
    # state at t0).
    tf_model_state: jaxsim.physics.model.physics_model_state.PhysicsModelState = (
        dataclasses.field(repr=False)
    )
    tf_contact_state: jaxsim.physics.algos.soft_contacts.SoftContactsState = (
        dataclasses.field(repr=False)
    )

    # Free-form auxiliary data; each instance gets its own empty dict.
    aux: Dict[str, Any] = dataclasses.field(default_factory=dict)
94
-
95
-
96
- @jax_dataclasses.pytree_dataclass
97
- class Model(Vmappable):
98
- """
99
- High-level class to operate on a simulated model.
100
- """
101
-
102
- model_name: Static[str]
103
-
104
- physics_model: physics.model.physics_model.PhysicsModel = dataclasses.field(
105
- repr=False
106
- )
107
-
108
- velocity_representation: Static[VelRepr] = dataclasses.field(default=VelRepr.Mixed)
109
-
110
- data: ModelData = dataclasses.field(default=None, repr=False)
111
-
112
- # ========================
113
- # Initialization and state
114
- # ========================
115
-
116
@staticmethod
def build_from_model_description(
    model_description: Union[str, pathlib.Path, rod.Model],
    model_name: str | None = None,
    vel_repr: VelRepr = VelRepr.Mixed,
    gravity: jtp.Array = jaxsim.physics.default_gravity(),
    is_urdf: bool | None = None,
    considered_joints: List[str] | None = None,
) -> "Model":
    """
    Build a Model object from a model description.

    Args:
        model_description: A path to an SDF/URDF file, a string with its content,
            or a pre-parsed/pre-built rod model.
        model_name: Optional name overriding the one found in the description.
        vel_repr: The velocity representation of the resulting model.
        gravity: The 3D gravity vector.
        is_urdf: Whether the description is URDF (True) or SDF (False). It is
            inferred automatically when the description is a path to a file.
        considered_joints: The joints to keep. If None, all joints are kept.

    Returns:
        The built Model object.
    """

    import jaxsim.parsers.rod

    # Turn the input resource into the intermediate model description.
    intermediate_description = jaxsim.parsers.rod.build_model_description(
        model_description=model_description, is_urdf=is_urdf
    )

    # Links connected by joints that are not considered are lumped together;
    # the removed joints are frozen at zero position.
    if considered_joints is not None:
        intermediate_description = intermediate_description.reduce(
            considered_joints=considered_joints
        )

    return Model.build(
        physics_model=jaxsim.physics.model.physics_model.PhysicsModel.build_from(
            model_description=intermediate_description, gravity=gravity
        ),
        model_name=model_name,
        vel_repr=vel_repr,
    )
166
-
167
@staticmethod
def build_from_sdf(
    sdf: Union[str, pathlib.Path],
    model_name: str | None = None,
    vel_repr: VelRepr = VelRepr.Mixed,
    gravity: jtp.Array = jaxsim.physics.default_gravity(),
    is_urdf: bool | None = None,
    considered_joints: List[str] | None = None,
) -> "Model":
    """
    Build a Model object from an SDF description.

    Deprecated: it logs a warning and forwards all the arguments to
    ``Model.build_from_model_description``.
    """

    deprecation_message = "Model.{} is deprecated, use Model.{} instead.".format(
        "build_from_sdf", "build_from_model_description"
    )
    logging.warning(msg=deprecation_message)

    return Model.build_from_model_description(
        model_description=sdf,
        model_name=model_name,
        vel_repr=vel_repr,
        gravity=gravity,
        is_urdf=is_urdf,
        considered_joints=considered_joints,
    )
194
-
195
@staticmethod
def build(
    physics_model: jaxsim.physics.model.physics_model.PhysicsModel,
    model_name: str | None = None,
    vel_repr: VelRepr = VelRepr.Mixed,
) -> "Model":
    """
    Build a Model object from a physics model.

    Args:
        physics_model: The physics model.
        model_name: Optional name overriding the one of the physics model description.
        vel_repr: The velocity representation to use.

    Returns:
        The built Model object, with zeroed data.

    Raises:
        RuntimeError: If the resulting model is not valid.
    """

    if model_name is None:
        model_name = physics_model.description.name

    model = Model(
        physics_model=physics_model,
        model_name=model_name,
        velocity_representation=vel_repr,
    )

    # Populate the data with zeros. Validation is disabled, presumably because
    # the PyTree structure of `data` changes from its initial value.
    with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
        model.zero()

    if model.valid():
        return model

    raise RuntimeError("The model is not valid.")
235
-
236
@functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False)
def reduce(
    self, considered_joints: tuple[str, ...], keep_base_pose: bool = False
) -> None:
    """
    Reduce the model in place by lumping together the links connected by the
    joints that are not considered.

    The data of the reduced model is re-initialized to zero. Optionally, the
    base pose of the original model is restored afterwards.

    Args:
        considered_joints: The joints to keep.
        keep_base_pose: Whether to restore the base position and orientation.

    Raises:
        RuntimeError: If the model is vectorized.
    """

    if self.vectorized:
        raise RuntimeError("Cannot reduce a vectorized model.")

    # The description raises if considered_joints contains unknown joints.
    reduced_description = self.physics_model.description.reduce(
        considered_joints=list(considered_joints)
    )

    reduced_model = Model.build(
        physics_model=jaxsim.physics.model.physics_model.PhysicsModel.build_from(
            model_description=reduced_description,
            gravity=self.physics_model.gravity[0:3],
        ),
        model_name=self.name(),
        vel_repr=self.velocity_representation,
    )

    # Save the base pose before the data gets replaced.
    saved_position = self.base_position()
    saved_quaternion = self.base_orientation(dcm=False)

    # The PyTree structure changes here, which is why validation is disabled.
    self.physics_model = reduced_model.physics_model
    self.data = reduced_model.data

    if not keep_base_pose:
        return

    self.reset_base_position(position=saved_position)
    self.reset_base_orientation(orientation=saved_quaternion, dcm=False)
283
-
284
@functools.partial(oop.jax_tf.method_rw, jit=False)
def zero(self) -> None:
    """Replace the whole model data (state, input and contact state) with zeros."""

    zeroed_data = ModelData.zero(physics_model=self.physics_model)
    self.data = zeroed_data
289
-
290
@functools.partial(oop.jax_tf.method_rw, jit=False)
def zero_input(self) -> None:
    """Zero only the model input (tau and f_ext), leaving the state untouched."""

    zeroed_data = ModelData.zero(physics_model=self.physics_model)
    self.data.model_input = zeroed_data.model_input
297
-
298
@functools.partial(oop.jax_tf.method_rw, jit=False)
def zero_state(self) -> None:
    """Zero the model state and the contact state, leaving the input untouched."""

    zeroed_data = ModelData.zero(physics_model=self.physics_model)

    self.data.model_state = zeroed_data.model_state
    self.data.contact_state = zeroed_data.contact_state
305
-
306
@functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False)
def set_velocity_representation(self, vel_repr: VelRepr) -> None:
    """Switch the active velocity representation; no-op if already active."""

    if self.velocity_representation is not vel_repr:
        self.velocity_representation = vel_repr
314
-
315
- # ==========
316
- # Properties
317
- # ==========
318
-
319
@functools.partial(oop.jax_tf.method_ro, jit=False)
def valid(self) -> jtp.Bool:
    """
    Return whether all the links and all the joints of the model are valid.

    The joints are checked only if all the links are valid.
    """

    all_valid = all(link.valid() for link in self.links()) and all(
        joint.valid() for joint in self.joints()
    )

    return jnp.array(all_valid, dtype=bool)
327
-
328
@functools.partial(oop.jax_tf.method_ro, jit=False)
def floating_base(self) -> jtp.Bool:
    """Return whether the physics model has a floating base, as a boolean array."""

    is_floating = self.physics_model.is_floating_base
    return jnp.array(is_floating, dtype=bool)
333
-
334
@functools.partial(oop.jax_tf.method_ro, jit=False)
def dofs(self) -> jtp.Int:
    """Return the number of joint degrees of freedom (size of the joint positions)."""

    joint_positions = self.joint_positions()
    return joint_positions.size
339
-
340
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def name(self) -> str:
    """Return the name of the model."""

    model_name = self.model_name
    return model_name
345
-
346
@functools.partial(oop.jax_tf.method_ro, jit=False)
def nr_of_links(self) -> jtp.Int:
    """Return the number of links of the model as an integer array."""

    link_count = len(self.links())
    return jnp.array(link_count, dtype=int)
351
-
352
@functools.partial(oop.jax_tf.method_ro, jit=False)
def nr_of_joints(self) -> jtp.Int:
    """Return the number of joints of the model as an integer array."""

    joint_count = len(self.joints())
    return jnp.array(joint_count, dtype=int)
357
-
358
@functools.partial(oop.jax_tf.method_ro)
def total_mass(self) -> jtp.Float:
    """Return the sum of the masses of all the links of the model."""

    link_masses = jnp.array([link.mass() for link in self.links()])
    return jnp.sum(link_masses, dtype=float)
363
-
364
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def get_link(self, link_name: str) -> high_level.link.Link:
    """
    Return the link with the given name.

    Raises:
        ValueError: If the model has no link with that name.
    """

    if link_name in self.link_names():
        (link,) = self.links(link_names=(link_name,))
        return link

    raise ValueError(f"Link '{link_name}' is not part of model '{self.name()}'")
373
-
374
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def get_joint(self, joint_name: str) -> high_level.joint.Joint:
    """
    Return the joint with the given name.

    Raises:
        ValueError: If the model has no joint with that name.
    """

    if joint_name in self.joint_names():
        (joint,) = self.joints(joint_names=(joint_name,))
        return joint

    raise ValueError(f"Joint '{joint_name}' is not part of model '{self.name()}'")
383
-
384
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def link_names(self) -> tuple[str, ...]:
    """Return the link names, in the insertion order of the description's links dict."""

    links_dict = self.physics_model.description.links_dict
    return tuple(links_dict)
389
-
390
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def joint_names(self) -> tuple[str, ...]:
    """Return the joint names, in the insertion order of the description's joints dict."""

    joints_dict = self.physics_model.description.joints_dict
    return tuple(joints_dict)
395
-
396
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def links(
    self, link_names: tuple[str, ...] | None = None
) -> tuple[high_level.link.Link, ...]:
    """
    Return high-level Link objects bound to this model.

    New Link objects are created at every call and inherit the mutability of
    the model.

    Args:
        link_names: The links to return, in the desired order. If None, all the
            links are returned sorted by their index.

    Returns:
        A tuple of Link objects.

    Raises:
        KeyError: If one of the requested names is not a link of the model.
    """

    sorted_descriptions = sorted(
        self.physics_model.description.links_dict.values(),
        key=lambda description: description.index,
    )

    links_by_name = {}
    for description in sorted_descriptions:
        links_by_name[description.name] = high_level.link.Link(
            link_description=description,
            _parent_model=self,
            batch_size=self.batch_size,
        )

    for link in links_by_name.values():
        link._set_mutability(self._mutability())

    if link_names is None:
        return tuple(links_by_name.values())

    return tuple(map(links_by_name.__getitem__, link_names))
419
-
420
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def joints(
    self, joint_names: tuple[str, ...] | None = None
) -> tuple[high_level.joint.Joint, ...]:
    """
    Return high-level Joint objects bound to this model.

    New Joint objects are created at every call and inherit the mutability of
    the model.

    Args:
        joint_names: The joints to return, in the desired order. If None, all
            the joints are returned sorted by their index.

    Returns:
        A tuple of Joint objects.

    Raises:
        KeyError: If one of the requested names is not a joint of the model.
    """

    sorted_descriptions = sorted(
        self.physics_model.description.joints_dict.values(),
        key=lambda description: description.index,
    )

    joints_by_name = {}
    for description in sorted_descriptions:
        joints_by_name[description.name] = high_level.joint.Joint(
            joint_description=description,
            _parent_model=self,
            batch_size=self.batch_size,
        )

    for joint in joints_by_name.values():
        joint._set_mutability(self._mutability())

    if joint_names is None:
        return tuple(joints_by_name.values())

    return tuple(map(joints_by_name.__getitem__, joint_names))
443
-
444
@functools.partial(oop.jax_tf.method_ro, static_argnames=["link_names", "terrain"])
def in_contact(
    self,
    link_names: tuple[str, ...] | None = None,
    terrain: Terrain = FlatTerrain(),
) -> jtp.Vector:
    """
    Return, for each selected link, whether any of its collidable points lies
    at or below the terrain surface.

    Args:
        link_names: The links to check. If None, all the links are checked.
        terrain: The terrain providing the height at each (x, y) point.

    Returns:
        A boolean vector with one entry per selected link.

    Raises:
        ValueError: If some link names are not part of the model.
    """

    link_names = link_names if link_names is not None else self.link_names()

    unknown_links = set(link_names) - set(self.link_names())
    if unknown_links:
        raise ValueError("One or more link names are not part of the model")

    from jaxsim.physics.algos.soft_contacts import collidable_points_pos_vel

    state = self.data.model_state

    # Positions of all the collidable points: rows are x, y, z.
    W_p_Ci, _ = collidable_points_pos_vel(
        model=self.physics_model,
        q=state.joint_positions,
        qd=state.joint_velocities,
        xfb=state.xfb(),
    )

    heights = jax.vmap(terrain.height)(W_p_Ci[0, :], W_p_Ci[1, :])
    penetrating = W_p_Ci[2, :] <= heights

    def link_touches_terrain(link_index):
        # Keep only the penetration flags of the points attached to this link.
        of_this_link = self.physics_model.gc.body == link_index
        return jnp.where(
            of_this_link, penetrating, jnp.zeros_like(penetrating, dtype=bool)
        ).any()

    link_indices = jnp.array(
        [link.index() for link in self.links(link_names=link_names)]
    )

    return jax.vmap(link_touches_terrain)(link_indices)
479
-
480
- # =================
481
- # Multi-DoF methods
482
- # =================
483
-
484
@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_positions(self, joint_names: tuple[str, ...] | None = None) -> jtp.Vector:
    """Return the positions of the selected joints (all the joints if None)."""

    all_positions = self.data.model_state.joint_positions
    return all_positions[self._joint_indices(joint_names=joint_names)]
491
-
492
@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_random_positions(
    self,
    joint_names: tuple[str, ...] | None = None,
    key: jax.Array | None = None,
) -> jtp.Vector:
    """
    Sample joint positions uniformly within the joint position limits.

    Args:
        joint_names: The joints to consider. If None, all the joints.
        key: The PRNG key. If None, a fixed key with seed 0 is used, so the
            result is the same at every call.

    Returns:
        A vector with one random position per selected joint.
    """

    rng_key = jax.random.PRNGKey(seed=0) if key is None else key

    lower, upper = self.joint_limits(joint_names=joint_names)

    return jax.random.uniform(
        key=rng_key,
        shape=lower.shape,
        minval=lower,
        maxval=upper,
    )
513
-
514
@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_velocities(
    self, joint_names: tuple[str, ...] | None = None
) -> jtp.Vector:
    """Return the velocities of the selected joints (all the joints if None)."""

    all_velocities = self.data.model_state.joint_velocities
    return all_velocities[self._joint_indices(joint_names=joint_names)]
523
-
524
@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_generalized_forces_targets(
    self, joint_names: tuple[str, ...] | None = None
) -> jtp.Vector:
    """Return the joint generalized forces (tau) stored in the model input."""

    tau = self.data.model_input.tau
    return tau[self._joint_indices(joint_names=joint_names)]
531
-
532
@functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"])
def joint_limits(
    self, joint_names: tuple[str, ...] | None = None
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Return the position limits of the selected joints.

    Args:
        joint_names: The joints to consider. If None, all the joints.

    Returns:
        A tuple (s_low, s_high) with one entry per selected joint. Each pair of
        limits is sorted, so s_low <= s_high even if a joint stores them in
        reverse order. An empty selection (e.g. a model without joints)
        returns two empty vectors.
    """

    # Consider all joints if not specified otherwise
    joint_names = joint_names if joint_names is not None else self.joint_names()

    # Build a (n_joints, 2) matrix of limits. Reshaping (instead of vstack)
    # keeps two columns also when there are no joints: vstack would produce a
    # (1, 0) array, on which the min/max reductions below raise.
    limits = jnp.array(
        [joint.position_limit() for joint in self.joints(joint_names)]
    ).reshape(-1, 2)

    # Get the limits, reordering them in case low > high
    s_low = jnp.min(limits, axis=1)
    s_high = jnp.max(limits, axis=1)

    return s_low, s_high
551
-
552
- # =========
553
- # Base link
554
- # =========
555
-
556
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
def base_frame(self) -> str:
    """Return the name of the root link of the model description."""

    root_link = self.physics_model.description.root
    return root_link.name
561
-
562
@functools.partial(oop.jax_tf.method_ro)
def base_position(self) -> jtp.Vector:
    """Return the base position stored in the model state, with singleton axes squeezed."""

    stored_position = self.data.model_state.base_position
    return stored_position.squeeze()
567
-
568
@functools.partial(oop.jax_tf.method_ro, static_argnames=["dcm"])
def base_orientation(self, dcm: bool = False) -> jtp.Vector:
    """
    Return the orientation of the base.

    Args:
        dcm: If True, return the rotation matrix (direction cosine matrix);
            otherwise return the unit quaternion in wxyz order.

    Returns:
        The normalized base quaternion, or the corresponding rotation matrix.
    """

    stored_quaternion = self.data.model_state.base_quaternion

    # The Baumgarte stabilization in the integration only drives the quaternion
    # norm towards 1; it is not exactly 1 at every instant, so normalize here.
    unit_quaternion = stored_quaternion.squeeze() / jnp.linalg.norm(
        stored_quaternion
    )

    if not dcm:
        return unit_quaternion

    # Reorder wxyz -> xyzw, the convention expected by from_quaternion_xyzw.
    quaternion_xyzw = unit_quaternion[np.array([1, 2, 3, 0])]

    return sixd.so3.SO3.from_quaternion_xyzw(quaternion_xyzw).as_matrix()
590
-
591
@functools.partial(oop.jax_tf.method_ro)
def base_transform(self) -> jtp.MatrixJax:
    """Return the 4x4 homogeneous transform built from the base orientation and position."""

    rotation = self.base_orientation(dcm=True)
    position_column = self.base_position()[:, jnp.newaxis]

    top_rows = jnp.hstack([rotation, position_column])
    bottom_row = jnp.array([0, 0, 0, 1])

    return jnp.vstack([top_rows, bottom_row])
604
-
605
@functools.partial(oop.jax_tf.method_ro)
def base_velocity(self) -> jtp.Vector:
    """
    Return the 6D base velocity (linear part first, then angular part),
    converted from the stored inertial representation to the active one.
    """

    state = self.data.model_state
    W_v_WB = jnp.hstack([state.base_linear_velocity, state.base_angular_velocity])

    return self.inertial_to_active_representation(array=W_v_WB)
617
-
618
@functools.partial(oop.jax_tf.method_ro)
def external_forces(self) -> jtp.Matrix:
    """
    Return the active external forces acting on the robot.

    The external forces are a user input and are not computed by the physics
    engine. During the simulation they are summed to other terms, such as the
    forces due to the contact with the environment.

    Returns:
        A matrix of shape (n_links, 6) with the external 6D force of each link,
        expressed in the active representation.
    """

    # The forces are always stored internally in inertial representation.
    stored_forces = self.data.model_input.f_ext

    def to_active(W_f):
        return self.inertial_to_active_representation(W_f, is_force=True)

    return jax.vmap(to_active, in_axes=0)(stored_forces)
641
-
642
- # =======================
643
- # Single link r/w methods
644
- # =======================
645
-
646
@functools.partial(
    oop.jax_tf.method_rw, jit=True, static_argnames=["link_name", "additive"]
)
def apply_external_force_to_link(
    self,
    link_name: str,
    force: jtp.Array | None = None,
    torque: jtp.Array | None = None,
    additive: bool = True,
) -> None:
    """
    Apply an external 6D force to a link, expressed in the active representation.

    Args:
        link_name: The name of the target link.
        force: The linear force (zero if None).
        torque: The torque (zero if None).
        additive: Whether to sum to the current force of the link or to replace it.

    Raises:
        ValueError: If the active velocity representation is not supported.
    """

    link = self.get_link(link_name=link_name)
    link._set_mutability(mutability=self._mutability())

    f_ext = jnp.hstack(
        [
            force if force is not None else jnp.zeros(3),
            torque if torque is not None else jnp.zeros(3),
        ]
    )

    vel_repr = self.velocity_representation

    # Forces transform with the transpose of the inverse velocity adjoint:
    # W_f = (F_X_W)^T F_f, where F is the frame of the active representation.
    if vel_repr is VelRepr.Inertial:
        W_f_ext = f_ext

    elif vel_repr is VelRepr.Body:
        # F = L, the link frame.
        W_H_L = link.transform()
        L_X_W = sixd.se3.SE3.from_matrix(W_H_L).inverse().adjoint()
        W_f_ext = L_X_W.transpose() @ f_ext

    elif vel_repr is VelRepr.Mixed:
        # F = L[W], origin of the link frame with the world orientation.
        W_p_L = link.transform()[0:3, 3]
        W_H_LW = jnp.eye(4).at[0:3, 3].set(W_p_L)
        LW_X_W = sixd.se3.SE3.from_matrix(W_H_LW).inverse().adjoint()
        W_f_ext = LW_X_W.transpose() @ f_ext

    else:
        raise ValueError(self.velocity_representation)

    link_row = link.index()
    stored_forces = self.data.model_input.f_ext

    new_force = stored_forces[link_row, :] + W_f_ext if additive else W_f_ext

    self.data.model_input.f_ext = stored_forces.at[link_row, :].set(new_force)
700
-
701
@functools.partial(
    oop.jax_tf.method_rw, jit=True, static_argnames=["link_name", "additive"]
)
def apply_external_force_to_link_com(
    self,
    link_name: str,
    force: jtp.Array | None = None,
    torque: jtp.Array | None = None,
    additive: bool = True,
) -> None:
    """
    Apply an external 6D force to the CoM of a link, expressed in the active
    representation.

    Args:
        link_name: The name of the target link.
        force: The linear force (zero if None).
        torque: The torque (zero if None).
        additive: Whether to sum to the current force of the link or to replace it.

    Raises:
        ValueError: If the active velocity representation is not supported.
    """

    link = self.get_link(link_name=link_name)
    link._set_mutability(mutability=self._mutability())

    f_ext = jnp.hstack(
        [
            force if force is not None else jnp.zeros(3),
            torque if torque is not None else jnp.zeros(3),
        ]
    )

    vel_repr = self.velocity_representation

    # Forces transform with the transpose of the inverse velocity adjoint:
    # W_f = (F_X_W)^T F_f, where F is the frame in which f_ext is expressed.
    if vel_repr is VelRepr.Inertial:
        W_f_ext = f_ext

    elif vel_repr is VelRepr.Body:
        # F = G[L]: origin at the link CoM, orientation of the link frame.
        W_H_L = link.transform()
        L_p_CoM = link.com_position(in_link_frame=True)
        W_H_GL = W_H_L @ jnp.eye(4).at[0:3, 3].set(L_p_CoM)
        GL_X_W = sixd.se3.SE3.from_matrix(W_H_GL).inverse().adjoint()
        W_f_ext = GL_X_W.transpose() @ f_ext

    elif vel_repr is VelRepr.Mixed:
        # F = G[W]: origin at the link CoM, orientation of the world frame.
        W_p_CoM = link.com_position(in_link_frame=False)
        W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
        GW_X_W = sixd.se3.SE3.from_matrix(W_H_GW).inverse().adjoint()
        W_f_ext = GW_X_W.transpose() @ f_ext

    else:
        raise ValueError(self.velocity_representation)

    link_row = link.index()
    stored_forces = self.data.model_input.f_ext

    new_force = stored_forces[link_row, :] + W_f_ext if additive else W_f_ext

    self.data.model_input.f_ext = stored_forces.at[link_row, :].set(new_force)
759
-
760
- # ================================================
761
- # Generalized methods and free-floating quantities
762
- # ================================================
763
-
764
@functools.partial(oop.jax_tf.method_ro)
def generalized_position(self) -> Tuple[jtp.Matrix, jtp.Vector]:
    """Return the pair (base transform, joint positions)."""

    W_H_B = self.base_transform()
    s = self.joint_positions()
    return W_H_B, s
769
-
770
@functools.partial(oop.jax_tf.method_ro)
def generalized_velocity(self) -> jtp.Vector:
    """Return the base velocity stacked on top of the joint velocities."""

    base_part = self.base_velocity()
    joint_part = self.joint_velocities()
    return jnp.hstack([base_part, joint_part])
775
-
776
@functools.partial(oop.jax_tf.method_ro, static_argnames=["output_vel_repr"])
def generalized_free_floating_jacobian(
    self, output_vel_repr: VelRepr | None = None
) -> jtp.Matrix:
    """
    Return the free-floating Jacobians of all the links, stacked vertically.

    Args:
        output_vel_repr: The output velocity representation. If None, the active
            velocity representation of the model is used.

    Returns:
        The vertical stack of the per-link 6-row Jacobians, ordered as
        ``link_names()``, using the base link B as body frame.

    Raises:
        ValueError: If the output velocity representation is not supported.
    """

    if output_vel_repr is None:
        output_vel_repr = self.velocity_representation

    # The body frame of Link.jacobian is the link frame L, while here we want
    # the base link B. Therefore, we always get the link Jacobians in Inertial
    # representation and convert them with a single adjoint O_X_W that is
    # computed once for all the links (None means identity).
    if output_vel_repr is VelRepr.Inertial:
        O_X_W = None

    elif output_vel_repr is VelRepr.Body:
        W_H_B = self.base_transform()
        O_X_W = sixd.se3.SE3.from_matrix(W_H_B).inverse().adjoint()

    elif output_vel_repr is VelRepr.Mixed:
        # B[W]: origin of the base frame with the orientation of the world frame.
        W_H_BW = jnp.array(self.base_transform()).at[0:3, 0:3].set(jnp.eye(3))
        O_X_W = sixd.se3.SE3.from_matrix(W_H_BW).inverse().adjoint()

    else:
        raise ValueError(output_vel_repr)

    # Build all the Link objects in a single call, preserving the link_names()
    # order. Calling get_link per link would rebuild every Link each time (O(n^2)).
    links = self.links(link_names=self.link_names())

    W_J_list = [link.jacobian(output_vel_repr=VelRepr.Inertial) for link in links]

    if O_X_W is None:
        return jnp.vstack(W_J_list)

    return jnp.vstack([O_X_W @ W_J for W_J in W_J_list])
824
-
825
- @functools.partial(oop.jax_tf.method_ro)
826
- def free_floating_mass_matrix(self) -> jtp.Matrix:
827
- """"""
828
-
829
- M_body = jaxsim.physics.algos.crba.crba(
830
- model=self.physics_model,
831
- q=self.data.model_state.joint_positions,
832
- )
833
-
834
- if self.velocity_representation is VelRepr.Body:
835
- return M_body
836
-
837
- elif self.velocity_representation is VelRepr.Inertial:
838
- zero_6n = jnp.zeros(shape=(6, self.dofs()))
839
- B_X_W = sixd.se3.SE3.from_matrix(self.base_transform()).inverse().adjoint()
840
-
841
- invT = jnp.vstack(
842
- [
843
- jnp.block([B_X_W, zero_6n]),
844
- jnp.block([zero_6n.T, jnp.eye(self.dofs())]),
845
- ]
846
- )
847
-
848
- return invT.T @ M_body @ invT
849
-
850
- elif self.velocity_representation is VelRepr.Mixed:
851
- zero_6n = jnp.zeros(shape=(6, self.dofs()))
852
- W_H_BW = self.base_transform().at[0:3, 3].set(jnp.zeros(3))
853
- BW_X_W = sixd.se3.SE3.from_matrix(W_H_BW).inverse().adjoint()
854
-
855
- invT = jnp.vstack(
856
- [
857
- jnp.block([BW_X_W, zero_6n]),
858
- jnp.block([zero_6n.T, jnp.eye(self.dofs())]),
859
- ]
860
- )
861
-
862
- return invT.T @ M_body @ invT
863
-
864
- else:
865
- raise ValueError(self.velocity_representation)
866
-
867
- @functools.partial(oop.jax_tf.method_ro)
868
- def free_floating_bias_forces(self) -> jtp.Vector:
869
- """"""
870
-
871
- with self.editable(validate=True) as model:
872
- model.zero_input()
873
-
874
- return jnp.hstack(
875
- model.inverse_dynamics(
876
- base_acceleration=jnp.zeros(6), joint_accelerations=None
877
- )
878
- )
879
-
880
- @functools.partial(oop.jax_tf.method_ro)
881
- def free_floating_gravity_forces(self) -> jtp.Vector:
882
- """"""
883
-
884
- with self.editable(validate=True) as model:
885
- model.zero_input()
886
- model.data.model_state.joint_velocities = jnp.zeros_like(
887
- model.data.model_state.joint_velocities
888
- )
889
- model.data.model_state.base_linear_velocity = jnp.zeros_like(
890
- model.data.model_state.base_linear_velocity
891
- )
892
- model.data.model_state.base_angular_velocity = jnp.zeros_like(
893
- model.data.model_state.base_angular_velocity
894
- )
895
-
896
- return jnp.hstack(
897
- model.inverse_dynamics(
898
- base_acceleration=jnp.zeros(6), joint_accelerations=None
899
- )
900
- )
901
-
902
- @functools.partial(oop.jax_tf.method_ro)
903
- def momentum(self) -> jtp.Vector:
904
- """"""
905
-
906
- with self.editable(validate=True) as m:
907
- m.set_velocity_representation(vel_repr=VelRepr.Body)
908
-
909
- # Compute the momentum in body-fixed velocity representation.
910
- # Note: the first 6 rows of the mass matrix define the jacobian of the
911
- # floating-base momentum.
912
- B_h = m.free_floating_mass_matrix()[0:6, :] @ m.generalized_velocity()
913
-
914
- W_H_B = self.base_transform()
915
- B_X_W: jtp.Array = sixd.se3.SE3.from_matrix(W_H_B).inverse().adjoint()
916
-
917
- W_h = B_X_W.T @ B_h
918
- return self.inertial_to_active_representation(array=W_h, is_force=True)
919
-
920
- # ===========
921
- # CoM methods
922
- # ===========
923
-
924
- @functools.partial(oop.jax_tf.method_ro)
925
- def com_position(self) -> jtp.Vector:
926
- """"""
927
-
928
- m = self.total_mass()
929
-
930
- W_H_L = self.forward_kinematics()
931
- W_H_B = self.base_transform()
932
- B_H_W = sixd.se3.SE3.from_matrix(W_H_B).inverse().as_matrix()
933
-
934
- com_links = [
935
- (
936
- l.mass()
937
- * B_H_W
938
- @ W_H_L[l.index()]
939
- @ jnp.hstack([l.com_position(in_link_frame=True), 1])
940
- )
941
- for l in self.links()
942
- ]
943
-
944
- B_ph_CoM = (1 / m) * jnp.sum(jnp.array(com_links), axis=0)
945
-
946
- return (W_H_B @ B_ph_CoM)[0:3]
947
-
948
- # ==========
949
- # Algorithms
950
- # ==========
951
-
952
- @functools.partial(oop.jax_tf.method_ro)
953
- def forward_kinematics(self) -> jtp.Array:
954
- """"""
955
-
956
- W_H_i = jaxsim.physics.algos.forward_kinematics.forward_kinematics_model(
957
- model=self.physics_model,
958
- q=self.data.model_state.joint_positions,
959
- xfb=self.data.model_state.xfb(),
960
- )
961
-
962
- return W_H_i
963
-
964
- @functools.partial(oop.jax_tf.method_ro)
965
- def inverse_dynamics(
966
- self,
967
- joint_accelerations: jtp.Vector | None = None,
968
- base_acceleration: jtp.Vector | None = None,
969
- ) -> Tuple[jtp.Vector, jtp.Vector]:
970
- """
971
- Compute inverse dynamics with the RNEA algorithm.
972
-
973
- Args:
974
- joint_accelerations: the joint accelerations to consider.
975
- base_acceleration: the base acceleration in the active representation to consider.
976
-
977
- Returns:
978
- A tuple containing the 6D force in active representation applied to the base
979
- to obtain the considered base acceleration, and the joint torques to apply
980
- to obtain the considered joint accelerations.
981
- """
982
-
983
- # Build joint accelerations if not provided
984
- joint_accelerations = (
985
- joint_accelerations
986
- if joint_accelerations is not None
987
- else jnp.zeros_like(self.joint_positions())
988
- )
989
-
990
- # Build base acceleration if not provided
991
- base_acceleration = (
992
- base_acceleration if base_acceleration is not None else jnp.zeros(6)
993
- )
994
-
995
- if base_acceleration.size != 6:
996
- raise ValueError(base_acceleration.size)
997
-
998
- def to_inertial(C_vd_WB, W_H_C, C_v_WB, W_vl_WC):
999
- W_X_C = sixd.se3.SE3.from_matrix(W_H_C).adjoint()
1000
- C_X_W = sixd.se3.SE3.from_matrix(W_H_C).inverse().adjoint()
1001
-
1002
- if self.velocity_representation != VelRepr.Mixed:
1003
- return W_X_C @ C_vd_WB
1004
- else:
1005
- from jaxsim.math.cross import Cross
1006
-
1007
- C_v_WC = C_X_W @ jnp.hstack([W_vl_WC, jnp.zeros(3)])
1008
- return W_X_C @ (C_vd_WB + Cross.vx(C_v_WC) @ C_v_WB)
1009
-
1010
- if self.velocity_representation is VelRepr.Inertial:
1011
- W_H_C = W_H_W = jnp.eye(4)
1012
- W_vl_WC = W_vl_WW = jnp.zeros(3)
1013
-
1014
- elif self.velocity_representation is VelRepr.Body:
1015
- W_H_C = W_H_B = self.base_transform()
1016
- W_vl_WC = W_vl_WB = self.base_velocity()[0:3]
1017
-
1018
- elif self.velocity_representation is VelRepr.Mixed:
1019
- W_H_B = self.base_transform()
1020
- W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
1021
- W_vl_WC = W_vl_W_BW = self.base_velocity()[0:3]
1022
-
1023
- else:
1024
- raise ValueError(self.velocity_representation)
1025
-
1026
- # We need to convert the derivative of the base acceleration to the Inertial
1027
- # representation. In Mixed representation, this conversion is not a plain
1028
- # transformation with just X, but it also involves a cross product in ℝ⁶.
1029
- W_v̇_WB = to_inertial(
1030
- C_vd_WB=base_acceleration,
1031
- W_H_C=W_H_C,
1032
- C_v_WB=self.base_velocity(),
1033
- W_vl_WC=W_vl_WC,
1034
- )
1035
-
1036
- # Compute RNEA
1037
- W_f_B, tau = jaxsim.physics.algos.rnea.rnea(
1038
- model=self.physics_model,
1039
- xfb=self.data.model_state.xfb(),
1040
- q=self.data.model_state.joint_positions,
1041
- qd=self.data.model_state.joint_velocities,
1042
- qdd=joint_accelerations,
1043
- a0fb=W_v̇_WB,
1044
- f_ext=self.data.model_input.f_ext,
1045
- )
1046
-
1047
- # Adjust shape
1048
- tau = jnp.atleast_1d(tau.squeeze())
1049
-
1050
- # Express W_f_B in the active representation
1051
- f_B = self.inertial_to_active_representation(array=W_f_B, is_force=True)
1052
-
1053
- return f_B, tau
1054
-
1055
- @functools.partial(oop.jax_tf.method_ro, static_argnames=["prefer_aba"])
1056
- def forward_dynamics(
1057
- self, tau: jtp.Vector | None = None, prefer_aba: float = True
1058
- ) -> Tuple[jtp.Vector, jtp.Vector]:
1059
- """"""
1060
-
1061
- return (
1062
- self.forward_dynamics_aba(tau=tau)
1063
- if prefer_aba
1064
- else self.forward_dynamics_crb(tau=tau)
1065
- )
1066
-
1067
- @functools.partial(oop.jax_tf.method_ro)
1068
- def forward_dynamics_aba(
1069
- self, tau: jtp.Vector | None = None
1070
- ) -> Tuple[jtp.Vector, jtp.Vector]:
1071
- """"""
1072
-
1073
- # Build joint torques if not provided
1074
- tau = tau if tau is not None else jnp.zeros_like(self.joint_positions())
1075
-
1076
- # Compute ABA
1077
- W_v̇_WB, s̈ = jaxsim.physics.algos.aba.aba(
1078
- model=self.physics_model,
1079
- xfb=self.data.model_state.xfb(),
1080
- q=self.data.model_state.joint_positions,
1081
- qd=self.data.model_state.joint_velocities,
1082
- tau=tau,
1083
- f_ext=self.data.model_input.f_ext,
1084
- )
1085
-
1086
- def to_active(W_vd_WB, W_H_C, W_v_WB, W_vl_WC):
1087
- C_X_W = sixd.se3.SE3.from_matrix(W_H_C).inverse().adjoint()
1088
-
1089
- if self.velocity_representation != VelRepr.Mixed:
1090
- return C_X_W @ W_vd_WB
1091
- else:
1092
- from jaxsim.math.cross import Cross
1093
-
1094
- W_v_WC = jnp.hstack([W_vl_WC, jnp.zeros(3)])
1095
- return C_X_W @ (W_vd_WB - Cross.vx(W_v_WC) @ W_v_WB)
1096
-
1097
- if self.velocity_representation is VelRepr.Inertial:
1098
- W_H_C = W_H_W = jnp.eye(4)
1099
- W_vl_WC = W_vl_WW = jnp.zeros(3)
1100
-
1101
- elif self.velocity_representation is VelRepr.Body:
1102
- W_H_C = W_H_B = self.base_transform()
1103
- W_vl_WC = W_vl_WB = self.base_velocity()[0:3]
1104
-
1105
- elif self.velocity_representation is VelRepr.Mixed:
1106
- W_H_B = self.base_transform()
1107
- W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
1108
- W_vl_WC = W_vl_W_BW = self.base_velocity()[0:3]
1109
-
1110
- else:
1111
- raise ValueError(self.velocity_representation)
1112
-
1113
- # We need to convert the derivative of the base acceleration to the active
1114
- # representation. In Mixed representation, this conversion is not a plain
1115
- # transformation with just X, but it also involves a cross product in ℝ⁶.
1116
- C_v̇_WB = to_active(
1117
- W_vd_WB=W_v̇_WB.squeeze(),
1118
- W_H_C=W_H_C,
1119
- W_v_WB=jnp.hstack(
1120
- [
1121
- self.data.model_state.base_linear_velocity,
1122
- self.data.model_state.base_angular_velocity,
1123
- ]
1124
- ),
1125
- W_vl_WC=W_vl_WC,
1126
- )
1127
-
1128
- # Adjust shape
1129
- s̈ = jnp.atleast_1d(s̈.squeeze())
1130
-
1131
- return C_v̇_WB, s̈
1132
-
1133
- @functools.partial(oop.jax_tf.method_ro)
1134
- def forward_dynamics_crb(
1135
- self, tau: jtp.Vector | None = None
1136
- ) -> Tuple[jtp.Vector, jtp.Vector]:
1137
- """"""
1138
-
1139
- # Build joint torques if not provided
1140
- τ = tau if tau is not None else jnp.zeros(shape=(self.dofs(),))
1141
- τ = jnp.atleast_1d(τ.squeeze())
1142
- τ = jnp.vstack(τ) if τ.size > 0 else jnp.empty(shape=(0, 1))
1143
-
1144
- # Extract motor parameters from the physics model
1145
- GR = self.motor_gear_ratios()
1146
- IM = self.motor_inertias()
1147
- KV = jnp.diag(self.motor_viscous_frictions())
1148
-
1149
- # Compute auxiliary quantities
1150
- Γ = jnp.diag(GR)
1151
- K̅ᵥ = Γ.T @ KV @ Γ
1152
-
1153
- # Compute terms of the floating-base EoM
1154
- M = self.free_floating_mass_matrix()
1155
- h = jnp.vstack(self.free_floating_bias_forces())
1156
- J = self.generalized_free_floating_jacobian()
1157
- f_ext = jnp.vstack(self.external_forces().flatten())
1158
- S = jnp.block([jnp.zeros(shape=(self.dofs(), 6)), jnp.eye(self.dofs())]).T
1159
-
1160
- # Configure the slice for motors
1161
- sl_m = np.s_[M.shape[0] - self.dofs() :]
1162
-
1163
- # Add the motor related terms to the EoM
1164
- M = M.at[sl_m, sl_m].set(M[sl_m, sl_m] + jnp.diag(Γ.T @ IM @ Γ))
1165
- h = h.at[sl_m].set(h[sl_m] + K̅ᵥ @ self.joint_velocities()[:, None])
1166
- S = S.at[sl_m].set(S[sl_m])
1167
-
1168
- # Compute the generalized acceleration by inverting the EoM
1169
- ν̇ = jax.lax.select(
1170
- pred=self.floating_base(),
1171
- on_true=jnp.linalg.inv(M) @ ((S @ τ) - h + J.T @ f_ext),
1172
- on_false=jnp.vstack(
1173
- [
1174
- jnp.zeros(shape=(6, 1)),
1175
- jnp.linalg.inv(M[6:, 6:])
1176
- @ ((S @ τ)[6:] - h[6:] + J[:, 6:].T @ f_ext),
1177
- ]
1178
- ),
1179
- ).squeeze()
1180
-
1181
- # Extract the base acceleration in the active representation.
1182
- # Note that this is an apparent acceleration (relevant in Mixed representation),
1183
- # therefore it cannot be always expressed in different frames with just a
1184
- # 6D transformation X.
1185
- v̇_WB = ν̇[0:6]
1186
-
1187
- # Extract the joint accelerations
1188
- s̈ = jnp.atleast_1d(ν̇[6:])
1189
-
1190
- return v̇_WB, s̈
1191
-
1192
- # ======
1193
- # Energy
1194
- # ======
1195
-
1196
- @functools.partial(oop.jax_tf.method_ro)
1197
- def mechanical_energy(self) -> jtp.Float:
1198
- """"""
1199
-
1200
- K = self.kinetic_energy()
1201
- U = self.potential_energy()
1202
-
1203
- return K + U
1204
-
1205
- @functools.partial(oop.jax_tf.method_ro)
1206
- def kinetic_energy(self) -> jtp.Float:
1207
- """"""
1208
-
1209
- with self.editable(validate=True) as m:
1210
- m.set_velocity_representation(vel_repr=VelRepr.Body)
1211
-
1212
- nu = m.generalized_velocity()
1213
- M = m.free_floating_mass_matrix()
1214
-
1215
- return 0.5 * nu.T @ M @ nu
1216
-
1217
- @functools.partial(oop.jax_tf.method_ro)
1218
- def potential_energy(self) -> jtp.Float:
1219
- """"""
1220
-
1221
- m = self.total_mass()
1222
- W_p_CoM = jnp.hstack([self.com_position(), 1])
1223
- gravity = self.physics_model.gravity[3:6].squeeze()
1224
-
1225
- return -(m * jnp.hstack([gravity, 0]) @ W_p_CoM)
1226
-
1227
- # ===========
1228
- # Set targets
1229
- # ===========
1230
-
1231
- @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
1232
- def set_joint_generalized_force_targets(
1233
- self, forces: jtp.Vector, joint_names: tuple[str, ...] | None = None
1234
- ) -> None:
1235
- """"""
1236
-
1237
- if joint_names is None:
1238
- joint_names = self.joint_names()
1239
-
1240
- if forces.size != len(joint_names):
1241
- raise ValueError("Wrong arguments size", forces.size, len(joint_names))
1242
-
1243
- self.data.model_input.tau = self.data.model_input.tau.at[
1244
- self._joint_indices(joint_names=joint_names)
1245
- ].set(forces)
1246
-
1247
- # ==========
1248
- # Reset data
1249
- # ==========
1250
-
1251
- @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
1252
- def reset_joint_positions(
1253
- self, positions: jtp.Vector, joint_names: tuple[str, ...] | None = None
1254
- ) -> None:
1255
- """"""
1256
-
1257
- if joint_names is None:
1258
- joint_names = self.joint_names()
1259
-
1260
- if positions.size != len(joint_names):
1261
- raise ValueError("Wrong arguments size", positions.size, len(joint_names))
1262
-
1263
- if positions.size == 0:
1264
- return
1265
-
1266
- # TODO: joint position limits
1267
-
1268
- self.data.model_state.joint_positions = jnp.atleast_1d(
1269
- jnp.array(
1270
- self.data.model_state.joint_positions.at[
1271
- self._joint_indices(joint_names=joint_names)
1272
- ].set(positions),
1273
- dtype=float,
1274
- )
1275
- )
1276
-
1277
- @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
1278
- def reset_joint_velocities(
1279
- self, velocities: jtp.Vector, joint_names: tuple[str, ...] | None = None
1280
- ) -> None:
1281
- """"""
1282
-
1283
- if joint_names is None:
1284
- joint_names = self.joint_names()
1285
-
1286
- if velocities.size != len(joint_names):
1287
- raise ValueError("Wrong arguments size", velocities.size, len(joint_names))
1288
-
1289
- if velocities.size == 0:
1290
- return
1291
-
1292
- # TODO: joint velocity limits
1293
-
1294
- self.data.model_state.joint_velocities = jnp.atleast_1d(
1295
- jnp.array(
1296
- self.data.model_state.joint_velocities.at[
1297
- self._joint_indices(joint_names=joint_names)
1298
- ].set(velocities),
1299
- dtype=float,
1300
- )
1301
- )
1302
-
1303
- @functools.partial(oop.jax_tf.method_rw)
1304
- def reset_base_position(self, position: jtp.Vector) -> None:
1305
- """"""
1306
-
1307
- self.data.model_state.base_position = jnp.array(position, dtype=float)
1308
-
1309
- @functools.partial(oop.jax_tf.method_rw, static_argnames=["dcm"])
1310
- def reset_base_orientation(self, orientation: jtp.Array, dcm: bool = False) -> None:
1311
- """"""
1312
-
1313
- if dcm:
1314
- to_wxyz = np.array([3, 0, 1, 2])
1315
- orientation_xyzw = sixd.so3.SO3.from_matrix(
1316
- orientation
1317
- ).as_quaternion_xyzw()
1318
- orientation = orientation_xyzw[to_wxyz]
1319
-
1320
- unit_quaternion = orientation / jnp.linalg.norm(orientation)
1321
- self.data.model_state.base_quaternion = jnp.array(unit_quaternion, dtype=float)
1322
-
1323
- @functools.partial(oop.jax_tf.method_rw)
1324
- def reset_base_transform(self, transform: jtp.Matrix) -> None:
1325
- """"""
1326
-
1327
- if transform.shape != (4, 4):
1328
- raise ValueError(transform.shape)
1329
-
1330
- self.reset_base_position(position=transform[0:3, 3])
1331
- self.reset_base_orientation(orientation=transform[0:3, 0:3], dcm=True)
1332
-
1333
- @functools.partial(oop.jax_tf.method_rw)
1334
- def reset_base_velocity(self, base_velocity: jtp.VectorJax) -> None:
1335
- """"""
1336
-
1337
- if not self.physics_model.is_floating_base:
1338
- msg = "Changing the base velocity of a fixed-based model is not allowed"
1339
- raise RuntimeError(msg)
1340
-
1341
- # Remove extra dimensions
1342
- base_velocity = base_velocity.squeeze()
1343
-
1344
- # Check for a valid shape
1345
- if base_velocity.shape != (6,):
1346
- raise ValueError(base_velocity.shape)
1347
-
1348
- # Convert, if needed, to the representation used internally (VelRepr.Inertial)
1349
- if self.velocity_representation is VelRepr.Inertial:
1350
- base_velocity_inertial = base_velocity
1351
-
1352
- elif self.velocity_representation is VelRepr.Body:
1353
- w_X_b = sixd.se3.SE3.from_rotation_and_translation(
1354
- rotation=sixd.so3.SO3.from_matrix(self.base_orientation(dcm=True)),
1355
- translation=self.base_position(),
1356
- ).adjoint()
1357
-
1358
- base_velocity_inertial = w_X_b @ base_velocity
1359
-
1360
- elif self.velocity_representation is VelRepr.Mixed:
1361
- w_X_bw = sixd.se3.SE3.from_rotation_and_translation(
1362
- rotation=sixd.so3.SO3.identity(),
1363
- translation=self.base_position(),
1364
- ).adjoint()
1365
-
1366
- base_velocity_inertial = w_X_bw @ base_velocity
1367
-
1368
- else:
1369
- raise ValueError(self.velocity_representation)
1370
-
1371
- self.data.model_state.base_linear_velocity = jnp.array(
1372
- base_velocity_inertial[0:3], dtype=float
1373
- )
1374
-
1375
- self.data.model_state.base_angular_velocity = jnp.array(
1376
- base_velocity_inertial[3:6], dtype=float
1377
- )
1378
-
1379
- # ===========
1380
- # Integration
1381
- # ===========
1382
-
1383
- @functools.partial(
1384
- oop.jax_tf.method_rw,
1385
- static_argnames=["sub_steps", "integrator_type", "terrain"],
1386
- vmap_in_axes=(0, 0, 0, None, None, None, 0, None),
1387
- )
1388
- def integrate(
1389
- self,
1390
- t0: jtp.Float,
1391
- tf: jtp.Float,
1392
- sub_steps: int = 1,
1393
- integrator_type: Optional[
1394
- "jaxsim.simulation.ode_integration.IntegratorType"
1395
- ] = None,
1396
- terrain: soft_contacts.Terrain = soft_contacts.FlatTerrain(),
1397
- contact_parameters: soft_contacts.SoftContactsParams = soft_contacts.SoftContactsParams(),
1398
- clear_inputs: bool = False,
1399
- ) -> StepData:
1400
- """"""
1401
-
1402
- from jaxsim.simulation import ode_data, ode_integration
1403
- from jaxsim.simulation.ode_integration import IntegratorType
1404
-
1405
- if integrator_type is None:
1406
- integrator_type = IntegratorType.EulerForward
1407
-
1408
- x0 = ode_integration.ode.ode_data.ODEState(
1409
- physics_model=self.data.model_state,
1410
- soft_contacts=self.data.contact_state,
1411
- )
1412
-
1413
- ode_input = ode_integration.ode.ode_data.ODEInput(
1414
- physics_model=self.data.model_input
1415
- )
1416
-
1417
- assert isinstance(integrator_type, IntegratorType)
1418
-
1419
- # Integrate the model dynamics
1420
- ode_states, aux = ode_integration.ode_integration_fixed_step(
1421
- x0=x0,
1422
- t=jnp.array([t0, tf], dtype=float),
1423
- ode_input=ode_input,
1424
- physics_model=self.physics_model,
1425
- soft_contacts_params=contact_parameters,
1426
- num_sub_steps=sub_steps,
1427
- terrain=terrain,
1428
- integrator_type=integrator_type,
1429
- return_aux=True,
1430
- )
1431
-
1432
- # Get quantities at t0
1433
- t0_model_data = self.data
1434
- t0_model_input = jax.tree_util.tree_map(
1435
- lambda l: l[0],
1436
- aux["ode_input"],
1437
- )
1438
- t0_model_input_real = jax.tree_util.tree_map(
1439
- lambda l: l[0],
1440
- aux["ode_input_real"],
1441
- )
1442
- t0_model_acceleration = jax.tree_util.tree_map(
1443
- lambda l: l[0],
1444
- aux["model_acceleration"],
1445
- )
1446
-
1447
- # Get quantities at tf
1448
- ode_states: ode_data.ODEState
1449
- tf_model_state = jax.tree_util.tree_map(
1450
- lambda l: l[-1], ode_states.physics_model
1451
- )
1452
- tf_contact_state = jax.tree_util.tree_map(
1453
- lambda l: l[-1], ode_states.soft_contacts
1454
- )
1455
-
1456
- # Clear user inputs (joint torques and external forces) if asked
1457
- model_input = jax.lax.cond(
1458
- pred=clear_inputs,
1459
- false_fun=lambda: t0_model_input.physics_model,
1460
- true_fun=lambda: jaxsim.physics.model.physics_model_state.PhysicsModelInput.zero(
1461
- physics_model=self.physics_model
1462
- ),
1463
- )
1464
-
1465
- # Update model state
1466
- self.data = ModelData(
1467
- model_state=tf_model_state,
1468
- contact_state=tf_contact_state,
1469
- model_input=model_input,
1470
- )
1471
-
1472
- return StepData(
1473
- t0=t0,
1474
- tf=tf,
1475
- dt=(tf - t0),
1476
- t0_model_data=t0_model_data,
1477
- t0_model_input_real=t0_model_input_real.physics_model,
1478
- t0_base_acceleration=t0_model_acceleration[0:6],
1479
- t0_joint_acceleration=t0_model_acceleration[6:],
1480
- tf_model_state=tf_model_state,
1481
- tf_contact_state=tf_contact_state,
1482
- aux={
1483
- "t0": jax.tree_util.tree_map(
1484
- lambda l: l[0],
1485
- aux,
1486
- ),
1487
- "tf": jax.tree_util.tree_map(
1488
- lambda l: l[-1],
1489
- aux,
1490
- ),
1491
- },
1492
- )
1493
-
1494
- # ==============
1495
- # Motor dynamics
1496
- # ==============
1497
-
1498
- @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
1499
- def set_motor_inertias(
1500
- self, inertias: jtp.Vector, joint_names: tuple[str, ...] | None = None
1501
- ) -> None:
1502
- joint_names = joint_names or self.joint_names()
1503
-
1504
- if inertias.size != len(joint_names):
1505
- raise ValueError("Wrong arguments size", inertias.size, len(joint_names))
1506
-
1507
- self.physics_model._joint_motor_inertia.update(
1508
- dict(zip(self.physics_model._joint_motor_inertia, inertias))
1509
- )
1510
-
1511
- logging.info("Setting attribute `motor_inertias`")
1512
-
1513
- @functools.partial(oop.jax_tf.method_rw, jit=False)
1514
- def set_motor_gear_ratios(
1515
- self, gear_ratios: jtp.Vector, joint_names: tuple[str, ...] | None = None
1516
- ) -> None:
1517
- joint_names = joint_names or self.joint_names()
1518
-
1519
- if gear_ratios.size != len(joint_names):
1520
- raise ValueError("Wrong arguments size", gear_ratios.size, len(joint_names))
1521
-
1522
- # Check on gear ratios if motor_inertias are not zero
1523
- for idx, gr in enumerate(gear_ratios):
1524
- if gr != 0 and self.motor_inertias()[idx] == 0:
1525
- raise ValueError(
1526
- f"Zero motor inertia with non-zero gear ratio found in position {idx}"
1527
- )
1528
-
1529
- self.physics_model._joint_motor_gear_ratio.update(
1530
- dict(zip(self.physics_model._joint_motor_gear_ratio, gear_ratios))
1531
- )
1532
-
1533
- logging.info("Setting attribute `motor_gear_ratios`")
1534
-
1535
- @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"])
1536
- def set_motor_viscous_frictions(
1537
- self,
1538
- viscous_frictions: jtp.Vector,
1539
- joint_names: tuple[str, ...] | None = None,
1540
- ) -> None:
1541
- joint_names = joint_names or self.joint_names()
1542
-
1543
- if viscous_frictions.size != len(joint_names):
1544
- raise ValueError(
1545
- "Wrong arguments size", viscous_frictions.size, len(joint_names)
1546
- )
1547
-
1548
- self.physics_model._joint_motor_viscous_friction.update(
1549
- dict(
1550
- zip(
1551
- self.physics_model._joint_motor_viscous_friction,
1552
- viscous_frictions,
1553
- )
1554
- )
1555
- )
1556
-
1557
- logging.info("Setting attribute `motor_viscous_frictions`")
1558
-
1559
- @functools.partial(oop.jax_tf.method_ro, jit=False)
1560
- def motor_inertias(self) -> jtp.Vector:
1561
- return jnp.array(
1562
- [*self.physics_model._joint_motor_inertia.values()], dtype=float
1563
- )
1564
-
1565
- @functools.partial(oop.jax_tf.method_ro, jit=False)
1566
- def motor_gear_ratios(self) -> jtp.Vector:
1567
- return jnp.array(
1568
- [*self.physics_model._joint_motor_gear_ratio.values()], dtype=float
1569
- )
1570
-
1571
- @functools.partial(oop.jax_tf.method_ro, jit=False)
1572
- def motor_viscous_frictions(self) -> jtp.Vector:
1573
- return jnp.array(
1574
- [*self.physics_model._joint_motor_viscous_friction.values()], dtype=float
1575
- )
1576
-
1577
- # ===============
1578
- # Private methods
1579
- # ===============
1580
-
1581
- @functools.partial(oop.jax_tf.method_ro, static_argnames=["is_force"])
1582
- def inertial_to_active_representation(
1583
- self, array: jtp.Array, is_force: bool = False
1584
- ) -> jtp.Array:
1585
- """"""
1586
-
1587
- W_array = array.squeeze()
1588
-
1589
- if W_array.size != 6:
1590
- raise ValueError(W_array.size)
1591
-
1592
- if self.velocity_representation is VelRepr.Inertial:
1593
- return W_array
1594
-
1595
- elif self.velocity_representation is VelRepr.Body:
1596
- W_H_B = self.base_transform()
1597
-
1598
- if not is_force:
1599
- B_Xv_W = sixd.se3.SE3.from_matrix(W_H_B).inverse().adjoint()
1600
- B_array = B_Xv_W @ W_array
1601
-
1602
- else:
1603
- B_Xf_W = sixd.se3.SE3.from_matrix(W_H_B).adjoint().T
1604
- B_array = B_Xf_W @ W_array
1605
-
1606
- return B_array
1607
-
1608
- elif self.velocity_representation is VelRepr.Mixed:
1609
- W_H_BW = jnp.eye(4).at[0:3, 3].set(self.base_position())
1610
-
1611
- if not is_force:
1612
- BW_Xv_W = sixd.se3.SE3.from_matrix(W_H_BW).inverse().adjoint()
1613
- BW_array = BW_Xv_W @ W_array
1614
-
1615
- else:
1616
- BW_Xf_W = sixd.se3.SE3.from_matrix(W_H_BW).adjoint().T
1617
- BW_array = BW_Xf_W @ W_array
1618
-
1619
- return BW_array
1620
-
1621
- else:
1622
- raise ValueError(self.velocity_representation)
1623
-
1624
- @functools.partial(oop.jax_tf.method_ro, static_argnames=["is_force"])
1625
- def active_to_inertial_representation(
1626
- self, array: jtp.Array, is_force: bool = False
1627
- ) -> jtp.Array:
1628
- """"""
1629
-
1630
- array = array.squeeze()
1631
-
1632
- if array.size != 6:
1633
- raise ValueError(array.size)
1634
-
1635
- if self.velocity_representation is VelRepr.Inertial:
1636
- W_array = array
1637
- return W_array
1638
-
1639
- elif self.velocity_representation is VelRepr.Body:
1640
- B_array = array
1641
- W_H_B = self.base_transform()
1642
-
1643
- if not is_force:
1644
- W_Xv_B: jtp.Array = sixd.se3.SE3.from_matrix(W_H_B).adjoint()
1645
- W_array = W_Xv_B @ B_array
1646
-
1647
- else:
1648
- W_Xf_B = sixd.se3.SE3.from_matrix(W_H_B).inverse().adjoint().T
1649
- W_array = W_Xf_B @ B_array
1650
-
1651
- return W_array
1652
-
1653
- elif self.velocity_representation is VelRepr.Mixed:
1654
- BW_array = array
1655
- W_H_BW = jnp.eye(4).at[0:3, 3].set(self.base_position())
1656
-
1657
- if not is_force:
1658
- W_Xv_BW: jtp.Array = sixd.se3.SE3.from_matrix(W_H_BW).adjoint()
1659
- W_array = W_Xv_BW @ BW_array
1660
-
1661
- else:
1662
- W_Xf_BW = sixd.se3.SE3.from_matrix(W_H_BW).inverse().adjoint().T
1663
- W_array = W_Xf_BW @ BW_array
1664
-
1665
- return W_array
1666
-
1667
- else:
1668
- raise ValueError(self.velocity_representation)
1669
-
1670
- def _joint_indices(self, joint_names: tuple[str, ...] | None = None) -> jtp.Vector:
1671
- """"""
1672
-
1673
- if joint_names is None:
1674
- joint_names = self.joint_names()
1675
-
1676
- if set(joint_names) - set(self.joint_names()) != set():
1677
- raise ValueError("One or more joint names are not part of the model")
1678
-
1679
- # Note: joints share the same index as their child link, therefore the first
1680
- # joint has index=1. We need to subtract one to get the right entry of
1681
- # data stored in the PhysicsModelState arrays.
1682
- joint_indices = [
1683
- j.joint_description.index - 1 for j in self.joints(joint_names=joint_names)
1684
- ]
1685
-
1686
- return np.array(joint_indices, dtype=int)