jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/mujoco/model.py CHANGED
@@ -1,12 +1,20 @@
1
+ from __future__ import annotations
2
+
1
3
  import functools
2
4
  import pathlib
5
+ from collections.abc import Callable, Sequence
3
6
  from typing import Any
4
7
 
5
8
  import mujoco as mj
6
9
  import numpy as np
7
10
  import numpy.typing as npt
11
+ import xmltodict
8
12
  from scipy.spatial.transform import Rotation
9
13
 
14
+ import jaxsim.typing as jtp
15
+
16
+ HeightmapCallable = Callable[[jtp.FloatLike, jtp.FloatLike], jtp.FloatLike]
17
+
10
18
 
11
19
  class MujocoModelHelper:
12
20
  """
@@ -14,34 +22,118 @@ class MujocoModelHelper:
14
22
  """
15
23
 
16
24
  def __init__(self, model: mj.MjModel, data: mj.MjData | None = None) -> None:
17
- """"""
25
+ """
26
+ Initialize the MujocoModelHelper object.
27
+
28
+ Args:
29
+ model: A Mujoco model object.
30
+ data: A Mujoco data object. If None, a new one will be created.
31
+ """
18
32
 
19
33
  self.model = model
20
34
  self.data = data if data is not None else mj.MjData(self.model)
21
35
 
22
- # Populate the data with kinematics
36
+ # Populate the data with kinematics.
23
37
  mj.mj_forward(self.model, self.data)
24
38
 
25
- # Keep the cache of this method local to improve GC
39
+ # Keep the cache of this method local to improve GC.
26
40
  self.mask_qpos = functools.cache(self._mask_qpos)
27
41
 
28
42
  @staticmethod
29
43
  def build_from_xml(
30
- mjcf_description: str | pathlib.Path, assets: dict[str, Any] = None
31
- ) -> "MujocoModelHelper":
32
- """"""
44
+ mjcf_description: str | pathlib.Path,
45
+ assets: dict[str, Any] | None = None,
46
+ heightmap: HeightmapCallable | None = None,
47
+ heightmap_name: str = "terrain",
48
+ heightmap_radius_xy: tuple[float, float] = (1.0, 1.0),
49
+ ) -> MujocoModelHelper:
50
+ """
51
+ Build a Mujoco model from an MJCF description.
33
52
 
34
- # Read the XML description if it's a path to file
53
+ Args:
54
+ mjcf_description:
55
+ A string containing the XML description of the Mujoco model
56
+ or a path to a file containing the XML description.
57
+ assets: An optional dictionary containing the assets of the model.
58
+ heightmap:
59
+ A function in two variables that returns the height of a terrain
60
+ in the specified coordinate point.
61
+ heightmap_name:
62
+ The default name of the heightmap in the MJCF description
63
+ to load the corresponding configuration.
64
+ heightmap_radius_xy:
65
+ The extension of the heightmap in the x-y surface corresponding to the
66
+ plane over which the grid of the sampled heightmap is generated.
67
+
68
+ Returns:
69
+ A MujocoModelHelper object.
70
+ """
71
+
72
+ # Read the XML description if it is a path to file.
35
73
  mjcf_description = (
36
74
  mjcf_description.read_text()
37
75
  if isinstance(mjcf_description, pathlib.Path)
38
76
  else mjcf_description
39
77
  )
40
78
 
41
- # Create the Mujoco model from the XML and, optionally, the assets dictionary
42
- model = mj.MjModel.from_xml_string(xml=mjcf_description, assets=assets) # noqa
79
+ if heightmap is None:
80
+ hfield = None
81
+
82
+ else:
83
+
84
+ mjcf_description_dict = xmltodict.parse(xml_input=mjcf_description)
85
+
86
+ # Create a dictionary of all hfield configurations from the MJCF.
87
+ hfields = mjcf_description_dict["mujoco"]["asset"].get("hfield", [])
88
+ hfields = hfields if isinstance(hfields, list) else [hfields]
89
+ hfields_dict = {hfield["@name"]: hfield for hfield in hfields}
90
+
91
+ if heightmap_name not in hfields_dict:
92
+ raise ValueError(f"Heightmap '{heightmap_name}' not found in MJCF")
43
93
 
44
- return MujocoModelHelper(model=model, data=mj.MjData(model))
94
+ hfield_element = hfields_dict[heightmap_name]
95
+
96
+ # Generate the hfield by sampling the heightmap function.
97
+ hfield = generate_hfield(
98
+ heightmap=heightmap,
99
+ samples_xy=(int(hfield_element["@nrow"]), int(hfield_element["@ncol"])),
100
+ radius_xy=heightmap_radius_xy,
101
+ )
102
+
103
+ # Update dynamically the '/asset/hfield[@name=heightmap_name]@size' attribute
104
+ # with the information of the sampled points.
105
+ # This is necessary for correctly rendering the heightmap over the
106
+ # specified xy area with the correct z elevation.
107
+ size = [float(el) for el in hfield_element["@size"].split(" ")]
108
+ size[0], size[1] = heightmap_radius_xy
109
+ size[2] = 1.0
110
+ # The following could be zero but Mujoco complains if it's exactly zero.
111
+ size[3] = max(0.000_001, -min(hfield))
112
+
113
+ # Replace the 'size' attribute.
114
+ hfields_dict[heightmap_name]["@size"] = " ".join(str(el) for el in size)
115
+
116
+ # Update the hfield elements of the original MJCF.
117
+ # Only the hfield corresponding to 'heightmap_name' was actually edited.
118
+ mjcf_description_dict["mujoco"]["asset"]["hfield"] = list(
119
+ hfields_dict.values()
120
+ )
121
+
122
+ # Serialize the updated MJCF to XML.
123
+ mjcf_description = xmltodict.unparse(
124
+ input_dict=mjcf_description_dict, pretty=True
125
+ )
126
+
127
+ # Create the Mujoco model from the XML and, optionally, the dictionary of assets.
128
+ model = mj.MjModel.from_xml_string(xml=mjcf_description, assets=assets)
129
+ data = mj.MjData(model)
130
+
131
+ # Store the sampled heightmap into the Mujoco model.
132
+ if heightmap is not None:
133
+ assert hfield is not None
134
+ model.hfield_data = hfield
135
+
136
+ return MujocoModelHelper(model=model, data=data)
45
137
 
46
138
  def time(self) -> float:
47
139
  """Return the simulation time."""
@@ -148,9 +240,9 @@ class MujocoModelHelper:
148
240
  raise ValueError("The orientation is not a valid element of SO(3)")
149
241
 
150
242
  W_Q_B = (
151
- Rotation.from_matrix(orientation).as_quat(canonical=True)[
152
- np.array([3, 0, 1, 2])
153
- ]
243
+ Rotation.from_matrix(orientation).as_quat(
244
+ canonical=True, scalar_first=False
245
+ )
154
246
  if dcm
155
247
  else orientation
156
248
  )
@@ -162,17 +254,17 @@ class MujocoModelHelper:
162
254
  # ==================
163
255
 
164
256
  def number_of_joints(self) -> int:
165
- """"""
257
+ """Return the number of joints in the model."""
166
258
 
167
259
  return self.model.njnt
168
260
 
169
261
  def number_of_dofs(self) -> int:
170
- """"""
262
+ """Return the number of DoFs in the model."""
171
263
 
172
264
  return self.model.nq
173
265
 
174
266
  def joint_names(self) -> list[str]:
175
- """"""
267
+ """Return the names of the joints in the model."""
176
268
 
177
269
  return [
178
270
  mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_JOINT, idx)
@@ -180,7 +272,7 @@ class MujocoModelHelper:
180
272
  ]
181
273
 
182
274
  def joint_dofs(self, joint_name: str) -> int:
183
- """"""
275
+ """Return the number of DoFs of a joint."""
184
276
 
185
277
  if joint_name not in self.joint_names():
186
278
  raise ValueError(f"Joint '{joint_name}' not found")
@@ -188,7 +280,7 @@ class MujocoModelHelper:
188
280
  return self.data.joint(joint_name).qpos.size
189
281
 
190
282
  def joint_position(self, joint_name: str) -> npt.NDArray:
191
- """"""
283
+ """Return the position of a joint."""
192
284
 
193
285
  if joint_name not in self.joint_names():
194
286
  raise ValueError(f"Joint '{joint_name}' not found")
@@ -196,7 +288,7 @@ class MujocoModelHelper:
196
288
  return self.data.joint(joint_name).qpos
197
289
 
198
290
  def joint_positions(self, joint_names: list[str] | None = None) -> npt.NDArray:
199
- """"""
291
+ """Return the positions of the joints."""
200
292
 
201
293
  joint_names = joint_names if joint_names is not None else self.joint_names()
202
294
 
@@ -207,7 +299,7 @@ class MujocoModelHelper:
207
299
  def set_joint_position(
208
300
  self, joint_name: str, position: npt.NDArray | float
209
301
  ) -> None:
210
- """"""
302
+ """Set the position of a joint."""
211
303
 
212
304
  position = np.atleast_1d(np.array(position).squeeze())
213
305
 
@@ -224,9 +316,9 @@ class MujocoModelHelper:
224
316
  self.data.qpos[sl] = position
225
317
 
226
318
  def set_joint_positions(
227
- self, joint_names: list[str], positions: npt.NDArray | list[npt.NDArray]
319
+ self, joint_names: Sequence[str], positions: npt.NDArray | list[npt.NDArray]
228
320
  ) -> None:
229
- """"""
321
+ """Set the positions of multiple joints."""
230
322
 
231
323
  mask = self.mask_qpos(joint_names=tuple(joint_names))
232
324
  self.data.qpos[mask] = positions
@@ -236,12 +328,12 @@ class MujocoModelHelper:
236
328
  # ==================
237
329
 
238
330
  def number_of_bodies(self) -> int:
239
- """"""
331
+ """Return the number of bodies in the model."""
240
332
 
241
333
  return self.model.nbody
242
334
 
243
335
  def body_names(self) -> list[str]:
244
- """"""
336
+ """Return the names of the bodies in the model."""
245
337
 
246
338
  return [
247
339
  mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_BODY, idx)
@@ -249,7 +341,7 @@ class MujocoModelHelper:
249
341
  ]
250
342
 
251
343
  def body_position(self, body_name: str) -> npt.NDArray:
252
- """"""
344
+ """Return the position of a body."""
253
345
 
254
346
  if body_name not in self.body_names():
255
347
  raise ValueError(f"Body '{body_name}' not found")
@@ -257,7 +349,7 @@ class MujocoModelHelper:
257
349
  return self.data.body(body_name).xpos
258
350
 
259
351
  def body_orientation(self, body_name: str, dcm: bool = False) -> npt.NDArray:
260
- """"""
352
+ """Return the orientation of a body."""
261
353
 
262
354
  if body_name not in self.body_names():
263
355
  raise ValueError(f"Body '{body_name}' not found")
@@ -271,12 +363,12 @@ class MujocoModelHelper:
271
363
  # ======================
272
364
 
273
365
  def number_of_geometries(self) -> int:
274
- """"""
366
+ """Return the number of geometries in the model."""
275
367
 
276
368
  return self.model.ngeom
277
369
 
278
370
  def geometry_names(self) -> list[str]:
279
- """"""
371
+ """Return the names of the geometries in the model."""
280
372
 
281
373
  return [
282
374
  mj.mj_id2name(self.model, mj.mjtObj.mjOBJ_GEOM, idx)
@@ -284,7 +376,7 @@ class MujocoModelHelper:
284
376
  ]
285
377
 
286
378
  def geometry_position(self, geometry_name: str) -> npt.NDArray:
287
- """"""
379
+ """Return the position of a geometry."""
288
380
 
289
381
  if geometry_name not in self.geometry_names():
290
382
  raise ValueError(f"Geometry '{geometry_name}' not found")
@@ -294,7 +386,7 @@ class MujocoModelHelper:
294
386
  def geometry_orientation(
295
387
  self, geometry_name: str, dcm: bool = False
296
388
  ) -> npt.NDArray:
297
- """"""
389
+ """Return the orientation of a geometry."""
298
390
 
299
391
  if geometry_name not in self.geometry_names():
300
392
  raise ValueError(f"Geometry '{geometry_name}' not found")
@@ -304,8 +396,8 @@ class MujocoModelHelper:
304
396
  if dcm:
305
397
  return R
306
398
 
307
- q_xyzw = Rotation.from_matrix(R).as_quat(canonical=True)
308
- return q_xyzw[[3, 0, 1, 2]]
399
+ q_xyzw = Rotation.from_matrix(R).as_quat(canonical=True, scalar_first=False)
400
+ return q_xyzw
309
401
 
310
402
  # ===============
311
403
  # Private methods
@@ -346,7 +438,45 @@ class MujocoModelHelper:
346
438
  for i in range(self.joint_dofs(joint_name=joint_name))
347
439
  ]
348
440
  )
349
- for idx, joint_name in zip(idxs, joint_names)
441
+ for idx, joint_name in zip(idxs, joint_names, strict=True)
350
442
  ]
351
443
  ).squeeze()
352
444
  )
445
+
446
+
447
def generate_hfield(
    heightmap: HeightmapCallable,
    samples_xy: tuple[int, int] = (11, 11),
    radius_xy: tuple[float, float] = (1.0, 1.0),
) -> npt.NDArray:
    """
    Generate an array with elevation points sampled from a heightmap function.

    The heightmap is sampled over a regular grid made of ``nx = samples_xy[0]``
    points along x, uniformly spaced in ``[-radius_xy[0], radius_xy[0]]``, and
    ``ny = samples_xy[1]`` points along y, uniformly spaced in
    ``[-radius_xy[1], radius_xy[1]]``. Denoting with ``x[j]`` and ``y[i]`` the
    sampled coordinates, the returned array is the row-major flattening of:

    ```
    h(x[0], y[0])     h(x[1], y[0])     ... h(x[nx-1], y[0])
    h(x[0], y[1])     h(x[1], y[1])     ... h(x[nx-1], y[1])
    ...
    h(x[0], y[ny-1])  h(x[1], y[ny-1])  ... h(x[nx-1], y[ny-1])
    ```

    That is, each row corresponds to a fixed y and each column to a fixed x,
    which is the layout of Mujoco height fields (``nrow`` rows along y,
    ``ncol`` columns along x).

    Args:
        heightmap:
            A function that takes two arguments (x, y) and returns the height
            at that point.
        samples_xy:
            The number of samples of the grid along the x and y axes,
            respectively.
        radius_xy:
            A tuple of two floats representing the extension of the heightmap
            in the x-y surface corresponding to the area over which the grid of
            the sampled heightmap is generated.

    Returns:
        A flat array of ``samples_xy[0] * samples_xy[1]`` elements containing
        the sampled terrain heightmap.
    """

    # Generate the grid.
    x = np.linspace(-radius_xy[0], radius_xy[0], samples_xy[0])
    y = np.linspace(-radius_xy[1], radius_xy[1], samples_xy[1])

    # Sample the heightmap: the outer loop runs over y (rows) and the inner
    # loop over x (columns), so that the flattened array is row-major in y.
    return np.array([[heightmap(xi, yi) for xi in x] for yi in y]).flatten()
jaxsim/mujoco/utils.py ADDED
@@ -0,0 +1,228 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ from collections.abc import Sequence
5
+
6
+ import mujoco as mj
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+ from scipy.spatial.transform import Rotation
10
+
11
+ from .model import MujocoModelHelper
12
+
13
+
14
def mujoco_data_from_jaxsim(
    mujoco_model: mj.MjModel,
    jaxsim_model,
    jaxsim_data,
    mujoco_data: mj.MjData | None = None,
    update_removed_joints: bool = True,
) -> mj.MjData:
    """
    Create a Mujoco data object from a JaxSim model and data objects.

    Args:
        mujoco_model: The Mujoco model object corresponding to the JaxSim model.
        jaxsim_model: The JaxSim model object from which the Mujoco model was created.
        jaxsim_data: The JaxSim data object containing the state of the model.
        mujoco_data: An optional Mujoco data object. If None, a new one will be created.
        update_removed_joints:
            If True, the positions of the joints that have been removed during the
            model reduction process will be set to their initial values.

    Returns:
        The Mujoco data object containing the state of the JaxSim model.

    Raises:
        ValueError:
            If ``jaxsim_model`` is not a ``JaxSimModel`` or ``jaxsim_data`` is
            not a ``JaxSimModelData``.

    Note:
        This method is useful to initialize a Mujoco data object used for visualization
        with the state of a JaxSim model. In particular, this function takes care of
        initializing the positions of the joints that have been removed during the
        model reduction process. After the initial creation of the Mujoco data object,
        it's faster to update the state using an external MujocoModelHelper object.
    """

    # The package `jaxsim.mujoco` is supposed to be jax-independent.
    # We import all the JaxSim resources privately.
    import jaxsim.api as js

    if not isinstance(jaxsim_model, js.model.JaxSimModel):
        raise ValueError("The `jaxsim_model` argument must be a JaxSimModel object.")

    if not isinstance(jaxsim_data, js.data.JaxSimModelData):
        raise ValueError("The `jaxsim_data` argument must be a JaxSimModelData object.")

    # Create the helper to operate on the Mujoco model and data.
    model_helper = MujocoModelHelper(model=mujoco_model, data=mujoco_data)

    # If the model is fixed-base, the Mujoco model won't have the joint corresponding
    # to the floating base, and the helper would raise an exception.
    if jaxsim_model.floating_base():

        # Set the model position.
        model_helper.set_base_position(position=np.array(jaxsim_data.base_position()))

        # Set the model orientation.
        model_helper.set_base_orientation(
            orientation=np.array(jaxsim_data.base_orientation())
        )

    # Set the joint positions.
    if jaxsim_model.dofs() > 0:

        model_helper.set_joint_positions(
            joint_names=list(jaxsim_model.joint_names()),
            positions=np.array(
                jaxsim_data.joint_positions(
                    model=jaxsim_model, joint_names=jaxsim_model.joint_names()
                )
            ),
        )

    # Updating these joints is not necessary after the first time.
    # Users can disable this update after initialization.
    if update_removed_joints:

        # Build the lookup sets once instead of once per removed joint.
        jaxsim_joint_names = set(jaxsim_model.joint_names())
        mujoco_joint_names = set(model_helper.joint_names())

        # Create a dictionary with the joints that have been removed for various reasons
        # (like link lumping due to model reduction).
        joints_removed_dict = {
            j.name: j
            for j in jaxsim_model.description._joints_removed
            if j.name not in jaxsim_joint_names
        }

        # Set the positions of all the original joints that have been removed
        # from the JaxSim model but are still present in the Mujoco model.
        for joint_name, joint in joints_removed_dict.items():

            if joint_name not in mujoco_joint_names:
                continue

            model_helper.set_joint_position(
                position=joint.initial_position,
                joint_name=joint_name,
            )

    # Return the mujoco data with updated kinematics.
    mj.mj_forward(mujoco_model, model_helper.data)

    return model_helper.data
109
+
110
+
111
+ @dataclasses.dataclass
112
+ class MujocoCamera:
113
+ """
114
+ Helper class storing parameters of a Mujoco camera.
115
+
116
+ Refer to the official documentation for more details:
117
+ https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-camera
118
+ """
119
+
120
+ mode: str = "fixed"
121
+
122
+ target: str | None = None
123
+ fovy: str = "45"
124
+ pos: str = "0 0 0"
125
+
126
+ quat: str | None = None
127
+ axisangle: str | None = None
128
+ xyaxes: str | None = None
129
+ zaxis: str | None = None
130
+ euler: str | None = None
131
+
132
+ name: str | None = None
133
+
134
+ @classmethod
135
+ def build(cls, **kwargs) -> MujocoCamera:
136
+ """
137
+ Build a Mujoco camera from a dictionary.
138
+ """
139
+
140
+ if not all(isinstance(value, str) for value in kwargs.values()):
141
+ raise ValueError(f"Values must be strings: {kwargs}")
142
+
143
+ return cls(**kwargs)
144
+
145
+ @staticmethod
146
+ def build_from_target_view(
147
+ camera_name: str,
148
+ lookat: Sequence[float | int] | npt.NDArray = (0, 0, 0),
149
+ distance: float | int | npt.NDArray = 3,
150
+ azimuth: float | int | npt.NDArray = 90,
151
+ elevation: float | int | npt.NDArray = -45,
152
+ fovy: float | int | npt.NDArray = 45,
153
+ degrees: bool = True,
154
+ **kwargs,
155
+ ) -> MujocoCamera:
156
+ """
157
+ Create a custom camera that looks at a target point.
158
+
159
+ Note:
160
+ The choice of the parameters is easier if we imagine to consider a target
161
+ frame `T` whose origin is located over the lookat point and having the same
162
+ orientation of the world frame `W`. We also introduce a camera frame `C`
163
+ whose origin is located over the lower-left corner of the image, and having
164
+ the x-axis pointing right and the y-axis pointing up in image coordinates.
165
+ The camera renders what it sees in the -z direction of frame `C`.
166
+
167
+ Args:
168
+ camera_name: The name of the camera.
169
+ lookat: The target point to look at (origin of `T`).
170
+ distance:
171
+ The distance from the target point (displacement between the origins
172
+ of `T` and `C`).
173
+ azimuth:
174
+ The rotation around z of the camera. With an angle of 0, the camera
175
+ would loot at the target point towards the positive x-axis of `T`.
176
+ elevation:
177
+ The rotation around the x-axis of the camera frame `C`. Note that if
178
+ you want to lift the view angle, the elevation is negative.
179
+ fovy: The field of view of the camera.
180
+ degrees: Whether the angles are in degrees or radians.
181
+ **kwargs: Additional camera parameters.
182
+
183
+ Returns:
184
+ The custom camera.
185
+ """
186
+
187
+ # Start from a frame whose origin is located over the lookat point.
188
+ # We initialize a -90 degrees rotation around the z-axis because due to
189
+ # the default camera coordinate system (x pointing right, y pointing up).
190
+ W_H_C = np.eye(4)
191
+ W_H_C[0:3, 3] = np.array(lookat)
192
+ W_H_C[0:3, 0:3] = Rotation.from_euler(
193
+ seq="ZX", angles=[-90, 90], degrees=True
194
+ ).as_matrix()
195
+
196
+ # Process the azimuth.
197
+ R_az = Rotation.from_euler(seq="Y", angles=azimuth, degrees=degrees).as_matrix()
198
+ W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_az
199
+
200
+ # Process elevation.
201
+ R_el = Rotation.from_euler(
202
+ seq="X", angles=elevation, degrees=degrees
203
+ ).as_matrix()
204
+ W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_el
205
+
206
+ # Process distance.
207
+ tf_distance = np.eye(4)
208
+ tf_distance[2, 3] = distance
209
+ W_H_C = W_H_C @ tf_distance
210
+
211
+ # Extract the position and the quaternion.
212
+ p = W_H_C[0:3, 3]
213
+ Q = Rotation.from_matrix(W_H_C[0:3, 0:3]).as_quat(scalar_first=True)
214
+
215
+ return MujocoCamera.build(
216
+ name=camera_name,
217
+ mode="fixed",
218
+ fovy=f"{fovy if degrees else np.rad2deg(fovy)}",
219
+ pos=" ".join(p.astype(str).tolist()),
220
+ quat=" ".join(Q.astype(str).tolist()),
221
+ **kwargs,
222
+ )
223
+
224
+ def asdict(self) -> dict[str, str]:
225
+ """
226
+ Convert the camera to a dictionary.
227
+ """
228
+ return {k: v for k, v in dataclasses.asdict(self).items() if v is not None}