jaxsim 0.4.1.tar.gz → 0.4.1.dev6.tar.gz

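This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries. Note that, under PEP 440 version ordering, 0.4.1.dev6 is a development pre-release that precedes 0.4.1, so this diff goes from the newer release to the older one: lines marked `-` come from 0.4.1 and lines marked `+` come from 0.4.1.dev6.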
Files changed (121)
  1. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/PKG-INFO +1 -2
  2. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/environment.yml +1 -2
  3. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/pyproject.toml +0 -1
  4. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/_version.py +2 -2
  5. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/integrators/common.py +33 -24
  6. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/mujoco/loaders.py +25 -134
  7. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/mujoco/model.py +16 -82
  8. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/mujoco/visualizer.py +6 -64
  9. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/rbda/utils.py +0 -8
  10. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/terrain/terrain.py +1 -1
  11. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim.egg-info/PKG-INFO +1 -2
  12. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim.egg-info/requires.txt +0 -1
  13. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/tests/test_automatic_differentiation.py +3 -3
  14. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/tests/test_simulations.py +30 -50
  15. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/.devcontainer/Dockerfile +0 -0
  16. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/.devcontainer/devcontainer.json +0 -0
  17. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/.gitattributes +0 -0
  18. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/.github/CODEOWNERS +0 -0
  19. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/.github/workflows/ci_cd.yml +0 -0
  20. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/.github/workflows/read_the_docs.yml +0 -0
  21. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/.gitignore +0 -0
  22. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/.pre-commit-config.yaml +0 -0
  23. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/.readthedocs.yaml +0 -0
  24. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/CONTRIBUTING.md +0 -0
  25. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/LICENSE +0 -0
  26. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/README.md +0 -0
  27. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/docs/Makefile +0 -0
  28. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/docs/conf.py +0 -0
  29. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/docs/examples.rst +0 -0
  30. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/docs/guide/install.rst +0 -0
  31. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/docs/index.rst +0 -0
  32. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/docs/make.bat +0 -0
  33. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/docs/modules/api.rst +0 -0
  34. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/docs/modules/integrators.rst +0 -0
  35. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/docs/modules/math.rst +0 -0
  36. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/docs/modules/mujoco.rst +0 -0
  37. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/docs/modules/parsers.rst +0 -0
  38. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/docs/modules/rbda.rst +0 -0
  39. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/docs/modules/typing.rst +0 -0
  40. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/docs/modules/utils.rst +0 -0
  41. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/examples/.gitattributes +0 -0
  42. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/examples/.gitignore +0 -0
  43. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/examples/PD_controller.ipynb +0 -0
  44. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/examples/Parallel_computing.ipynb +0 -0
  45. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/examples/README.md +0 -0
  46. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/examples/assets/cartpole.urdf +0 -0
  47. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/pixi.lock +0 -0
  48. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/setup.cfg +0 -0
  49. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/setup.py +0 -0
  50. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/__init__.py +0 -0
  51. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/api/__init__.py +0 -0
  52. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/api/com.py +0 -0
  53. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/api/common.py +0 -0
  54. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/api/contact.py +0 -0
  55. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/api/data.py +0 -0
  56. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/api/frame.py +0 -0
  57. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/api/joint.py +0 -0
  58. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/api/kin_dyn_parameters.py +0 -0
  59. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/api/link.py +0 -0
  60. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/api/model.py +0 -0
  61. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/api/ode.py +0 -0
  62. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/api/ode_data.py +0 -0
  63. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/api/references.py +0 -0
  64. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/exceptions.py +0 -0
  65. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/integrators/__init__.py +0 -0
  66. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/integrators/fixed_step.py +0 -0
  67. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/integrators/variable_step.py +0 -0
  68. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/logging.py +0 -0
  69. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/math/__init__.py +0 -0
  70. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/math/adjoint.py +0 -0
  71. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/math/cross.py +0 -0
  72. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/math/inertia.py +0 -0
  73. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/math/joint_model.py +0 -0
  74. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/math/quaternion.py +0 -0
  75. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/math/rotation.py +0 -0
  76. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/math/skew.py +0 -0
  77. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/math/transform.py +0 -0
  78. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/mujoco/__init__.py +0 -0
  79. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/mujoco/__main__.py +0 -0
  80. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/parsers/__init__.py +0 -0
  81. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/parsers/descriptions/__init__.py +0 -0
  82. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/parsers/descriptions/collision.py +0 -0
  83. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/parsers/descriptions/joint.py +0 -0
  84. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/parsers/descriptions/link.py +0 -0
  85. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/parsers/descriptions/model.py +0 -0
  86. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/parsers/kinematic_graph.py +0 -0
  87. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/parsers/rod/__init__.py +0 -0
  88. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/parsers/rod/parser.py +0 -0
  89. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/parsers/rod/utils.py +0 -0
  90. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/rbda/__init__.py +0 -0
  91. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/rbda/aba.py +0 -0
  92. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/rbda/collidable_points.py +0 -0
  93. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/rbda/contacts/__init__.py +0 -0
  94. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/rbda/contacts/common.py +0 -0
  95. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/rbda/contacts/soft.py +0 -0
  96. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/rbda/crba.py +0 -0
  97. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/rbda/forward_kinematics.py +0 -0
  98. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/rbda/jacobian.py +0 -0
  99. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/rbda/rnea.py +0 -0
  100. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/terrain/__init__.py +0 -0
  101. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/typing.py +0 -0
  102. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/utils/__init__.py +0 -0
  103. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/utils/jaxsim_dataclass.py +0 -0
  104. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/utils/tracing.py +0 -0
  105. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim/utils/wrappers.py +0 -0
  106. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim.egg-info/SOURCES.txt +0 -0
  107. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim.egg-info/dependency_links.txt +0 -0
  108. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/src/jaxsim.egg-info/top_level.txt +0 -0
  109. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/tests/__init__.py +0 -0
  110. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/tests/conftest.py +0 -0
  111. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/tests/test_api_com.py +0 -0
  112. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/tests/test_api_contact.py +0 -0
  113. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/tests/test_api_data.py +0 -0
  114. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/tests/test_api_frame.py +0 -0
  115. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/tests/test_api_joint.py +0 -0
  116. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/tests/test_api_link.py +0 -0
  117. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/tests/test_api_model.py +0 -0
  118. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/tests/test_contact.py +0 -0
  119. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/tests/test_exceptions.py +0 -0
  120. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/tests/test_pytree.py +0 -0
  121. {jaxsim-0.4.1 → jaxsim-0.4.1.dev6}/tests/utils_idyntree.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.1
3
+ Version: 0.4.1.dev6
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -80,7 +80,6 @@ Provides-Extra: viz
80
80
  Requires-Dist: lxml; extra == "viz"
81
81
  Requires-Dist: mediapy; extra == "viz"
82
82
  Requires-Dist: mujoco>=3.0.0; extra == "viz"
83
- Requires-Dist: scipy>=1.14.0; extra == "viz"
84
83
  Provides-Extra: all
85
84
  Requires-Dist: jaxsim[style,testing,viz]; extra == "all"
86
85
 
@@ -5,7 +5,7 @@ dependencies:
5
5
  # ===========================
6
6
  # Dependencies from setup.cfg
7
7
  # ===========================
8
- - python >= 3.12.0
8
+ - python=3.11
9
9
  - coloredlogs
10
10
  - jax >= 0.4.13
11
11
  - jaxlib >= 0.4.13
@@ -30,7 +30,6 @@ dependencies:
30
30
  - lxml
31
31
  - mediapy
32
32
  - mujoco >= 3.0.0
33
- - scipy >= 1.14.0
34
33
  # ==========================
35
34
  # Documentation dependencies
36
35
  # ==========================
@@ -69,7 +69,6 @@ viz = [
69
69
  "lxml",
70
70
  "mediapy",
71
71
  "mujoco >= 3.0.0",
72
- "scipy >= 1.14.0",
73
72
  ]
74
73
  all = [
75
74
  "jaxsim[style,testing,viz]",
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.1'
16
- __version_tuple__ = version_tuple = (0, 4, 1)
15
+ __version__ = version = '0.4.1.dev6'
16
+ __version_tuple__ = version_tuple = (0, 4, 1, 'dev6')
@@ -5,12 +5,11 @@ from typing import Any, ClassVar, Generic, Protocol, Type, TypeVar
5
5
  import jax
6
6
  import jax.numpy as jnp
7
7
  import jax_dataclasses
8
+ import jaxlie
8
9
  from jax_dataclasses import Static
9
10
 
10
11
  import jaxsim.api as js
11
- import jaxsim.math
12
12
  import jaxsim.typing as jtp
13
- from jaxsim import exceptions
14
13
  from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass, Mutability
15
14
 
16
15
  try:
@@ -540,38 +539,48 @@ class ExplicitRungeKuttaSO3Mixin:
540
539
  `PyTreeType = ODEState` to integrate the quaternion on SO(3).
541
540
  """
542
541
 
542
+ @classmethod
543
+ def integrate_rk_stage(
544
+ cls, x0: js.ode_data.ODEState, t0: Time, dt: TimeStep, k: js.ode_data.ODEState
545
+ ) -> js.ode_data.ODEState:
546
+
547
+ op = lambda x0_leaf, k_leaf: x0_leaf + dt * k_leaf
548
+ xf: js.ode_data.ODEState = jax.tree_util.tree_map(op, x0, k)
549
+
550
+ W_Q_B_tf = xf.physics_model.base_quaternion
551
+
552
+ return xf.replace(
553
+ physics_model=xf.physics_model.replace(
554
+ base_quaternion=W_Q_B_tf / jnp.linalg.norm(W_Q_B_tf)
555
+ )
556
+ )
557
+
543
558
  @classmethod
544
559
  def post_process_state(
545
560
  cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep
546
561
  ) -> js.ode_data.ODEState:
547
562
 
548
- # Extract the initial base quaternion.
549
- W_Q_B_t0 = x0.physics_model.base_quaternion
563
+ # Indices to convert quaternions between serializations.
564
+ to_xyzw = jnp.array([1, 2, 3, 0])
550
565
 
551
- # We assume that the initial quaternion is already unary.
552
- exceptions.raise_runtime_error_if(
553
- condition=jnp.logical_not(jnp.allclose(W_Q_B_t0.dot(W_Q_B_t0), 1.0)),
554
- msg="The SO(3) integrator received a quaternion at t0 that is not unary.",
566
+ # Get the initial rotation.
567
+ W_R_B_t0 = jaxlie.SO3.from_quaternion_xyzw(
568
+ xyzw=x0.physics_model.base_quaternion[to_xyzw]
555
569
  )
556
570
 
557
- # Get the angular velocity ω to integrate the quaternion.
558
- # This velocity ω[t0] is computed in the previous timestep by averaging the kᵢ
559
- # corresponding to the active RK-based scheme. Therefore, by using the ω[t0],
560
- # we obtain an explicit RK scheme operating on the SO(3) manifold.
561
- # Note that the current integrator is not a semi-implicit scheme, therefore
562
- # using the final ω[tf] would be not correct.
563
- W_ω_WB_t0 = x0.physics_model.base_angular_velocity
564
-
565
- # Integrate the quaternion on SO(3).
566
- W_Q_B_tf = jaxsim.math.Quaternion.integration(
567
- quaternion=W_Q_B_t0,
568
- dt=dt,
569
- omega=W_ω_WB_t0,
570
- omega_in_body_fixed=False,
571
- )
571
+ # Get the final angular velocity.
572
+ # This is already computed by averaging the kᵢ in RK-based schemes.
573
+ # Therefore, by using the ω at tf, we obtain a RK scheme operating
574
+ # on the SO(3) manifold.
575
+ W_ω_WB_tf = xf.physics_model.base_angular_velocity
576
+
577
+ # Integrate the orientation on SO(3).
578
+ # Note that we left-multiply with the exponential map since the angular
579
+ # velocity is expressed in the inertial frame.
580
+ W_R_B_tf = jaxlie.SO3.exp(tangent=dt * W_ω_WB_tf) @ W_R_B_t0
572
581
 
573
582
  # Replace the quaternion in the final state.
574
583
  return xf.replace(
575
- physics_model=xf.physics_model.replace(base_quaternion=W_Q_B_tf),
584
+ physics_model=xf.physics_model.replace(base_quaternion=W_R_B_tf.wxyz),
576
585
  validate=True,
577
586
  )
@@ -1,17 +1,12 @@
1
- from __future__ import annotations
2
-
3
1
  import dataclasses
4
2
  import pathlib
5
3
  import tempfile
6
4
  import warnings
7
- from typing import Any, Sequence
5
+ from typing import Any
8
6
 
9
7
  import mujoco as mj
10
- import numpy as np
11
- import numpy.typing as npt
12
8
  import rod.urdf.exporter
13
9
  from lxml import etree as ET
14
- from scipy.spatial.transform import Rotation
15
10
 
16
11
 
17
12
  def load_rod_model(
@@ -165,13 +160,7 @@ class RodModelToMjcf:
165
160
  considered_joints: list[str] | None = None,
166
161
  plane_normal: tuple[float, float, float] = (0, 0, 1),
167
162
  heightmap: bool | None = None,
168
- heightmap_samples_xy: tuple[int, int] = (101, 101),
169
- cameras: (
170
- MujocoCamera
171
- | Sequence[MujocoCamera]
172
- | dict[str, str]
173
- | Sequence[dict[str, str]]
174
- ) = (),
163
+ cameras: list[dict[str, str]] | dict[str, str] | None = None,
175
164
  ) -> tuple[str, dict[str, Any]]:
176
165
  """
177
166
  Converts a ROD model to a Mujoco MJCF string.
@@ -181,11 +170,10 @@ class RodModelToMjcf:
181
170
  considered_joints: The list of joint names to consider in the conversion.
182
171
  plane_normal: The normal vector of the plane.
183
172
  heightmap: Whether to generate a heightmap.
184
- heightmap_samples_xy: The number of points in the heightmap grid.
185
- cameras: The custom cameras to add to the scene.
173
+ cameras: The list of cameras to add to the scene.
186
174
 
187
175
  Returns:
188
- A tuple containing the MJCF string and the dictionary of assets.
176
+ tuple: A tuple containing the MJCF string and the assets dictionary.
189
177
  """
190
178
 
191
179
  # -------------------------------------
@@ -260,6 +248,7 @@ class RodModelToMjcf:
260
248
 
261
249
  parser = ET.XMLParser(remove_blank_text=True)
262
250
  root: ET._Element = ET.fromstring(text=urdf_string.encode(), parser=parser)
251
+ import numpy as np
263
252
 
264
253
  # Give a tiny radius to all dummy spheres
265
254
  for geometry in root.findall(".//visual/geometry[sphere]"):
@@ -415,11 +404,9 @@ class RodModelToMjcf:
415
404
  asset_element,
416
405
  "hfield",
417
406
  name="terrain",
418
- nrow=f"{int(heightmap_samples_xy[0])}",
419
- ncol=f"{int(heightmap_samples_xy[1])}",
420
- # The following 'size' is a placeholder, it is updated dynamically
421
- # when a hfield/heightmap is stored into MjData.
422
- size="1 1 1 1",
407
+ nrow="100",
408
+ ncol="100",
409
+ size="5 5 1 1",
423
410
  )
424
411
  if heightmap
425
412
  else None
@@ -487,17 +474,14 @@ class RodModelToMjcf:
487
474
  fovy="60",
488
475
  )
489
476
 
490
- # Add user-defined camera.
491
- for camera in cameras if isinstance(cameras, Sequence) else [cameras]:
492
-
493
- mj_camera = (
494
- camera
495
- if isinstance(camera, MujocoCamera)
496
- else MujocoCamera.build(**camera)
477
+ # Add user-defined camera
478
+ cameras = cameras if cameras is not None else []
479
+ for camera in cameras if isinstance(cameras, list) else [cameras]:
480
+ mj_camera = MujocoCamera.build(**camera)
481
+ _ = ET.SubElement(
482
+ worldbody_element, "camera", dataclasses.asdict(mj_camera)
497
483
  )
498
484
 
499
- _ = ET.SubElement(worldbody_element, "camera", mj_camera.asdict())
500
-
501
485
  # ------------------------------------------------
502
486
  # Add a light following the CoM of the first link
503
487
  # ------------------------------------------------
@@ -610,114 +594,21 @@ class SdfToMjcf:
610
594
 
611
595
  @dataclasses.dataclass
612
596
  class MujocoCamera:
613
- """
614
- Helper class storing parameters of a Mujoco camera.
615
-
616
- Refer to the official documentation for more details:
617
- https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-camera
618
- """
619
-
620
- mode: str = "fixed"
621
-
622
- target: str | None = None
623
- fovy: str = "45"
624
- pos: str = "0 0 0"
625
-
626
- quat: str | None = None
627
- axisangle: str | None = None
628
- xyaxes: str | None = None
629
- zaxis: str | None = None
630
- euler: str | None = None
631
-
632
- name: str | None = None
597
+ name: str
598
+ mode: str
599
+ pos: str
600
+ xyaxes: str
601
+ fovy: str
633
602
 
634
603
  @classmethod
635
- def build(cls, **kwargs) -> MujocoCamera:
636
-
604
+ def build(cls, **kwargs):
637
605
  if not all(isinstance(value, str) for value in kwargs.values()):
638
606
  raise ValueError("Values must be strings")
639
607
 
640
- return cls(**kwargs)
641
-
642
- @staticmethod
643
- def build_from_target_view(
644
- camera_name: str,
645
- lookat: Sequence[float | int] | npt.NDArray = (0, 0, 0),
646
- distance: float | int | npt.NDArray = 3,
647
- azimut: float | int | npt.NDArray = 90,
648
- elevation: float | int | npt.NDArray = -45,
649
- fovy: float | int | npt.NDArray = 45,
650
- degrees: bool = True,
651
- **kwargs,
652
- ) -> MujocoCamera:
653
- """
654
- Create a custom camera that looks at a target point.
655
-
656
- Note:
657
- The choice of the parameters is easier if we imagine to consider a target
658
- frame `T` whose origin is located over the lookat point and having the same
659
- orientation of the world frame `W`. We also introduce a camera frame `C`
660
- whose origin is located over the lower-left corner of the image, and having
661
- the x-axis pointing right and the y-axis pointing up in image coordinates.
662
- The camera renders what it sees in the -z direction of frame `C`.
663
-
664
- Args:
665
- camera_name: The name of the camera.
666
- lookat: The target point to look at (origin of `T`).
667
- distance:
668
- The distance from the target point (displacement between the origins
669
- of `T` and `C`).
670
- azimut:
671
- The rotation around z of the camera. With an angle of 0, the camera
672
- would loot at the target point towards the positive x-axis of `T`.
673
- elevation:
674
- The rotation around the x-axis of the camera frame `C`. Note that if
675
- you want to lift the view angle, the elevation is negative.
676
- fovy: The field of view of the camera.
677
- degrees: Whether the angles are in degrees or radians.
678
- **kwargs: Additional camera parameters.
608
+ if len(kwargs["pos"].split()) != 3:
609
+ raise ValueError("pos must have three values separated by space")
679
610
 
680
- Returns:
681
- The custom camera.
682
- """
611
+ if len(kwargs["xyaxes"].split()) != 6:
612
+ raise ValueError("xyaxes must have six values separated by space")
683
613
 
684
- # Start from a frame whose origin is located over the lookat point.
685
- # We initialize a -90 degrees rotation around the z-axis because due to
686
- # the default camera coordinate system (x pointing right, y pointing up).
687
- W_H_C = np.eye(4)
688
- W_H_C[0:3, 3] = np.array(lookat)
689
- W_H_C[0:3, 0:3] = Rotation.from_euler(
690
- seq="ZX", angles=[-90, 90], degrees=True
691
- ).as_matrix()
692
-
693
- # Process the azimut.
694
- R_az = Rotation.from_euler(seq="Y", angles=azimut, degrees=degrees).as_matrix()
695
- W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_az
696
-
697
- # Process elevation.
698
- R_el = Rotation.from_euler(
699
- seq="X", angles=elevation, degrees=degrees
700
- ).as_matrix()
701
- W_H_C[0:3, 0:3] = W_H_C[0:3, 0:3] @ R_el
702
-
703
- # Process distance.
704
- tf_distance = np.eye(4)
705
- tf_distance[2, 3] = distance
706
- W_H_C = W_H_C @ tf_distance
707
-
708
- # Extract the position and the quaternion.
709
- p = W_H_C[0:3, 3]
710
- Q = Rotation.from_matrix(W_H_C[0:3, 0:3]).as_quat(scalar_first=True)
711
-
712
- return MujocoCamera.build(
713
- name=camera_name,
714
- mode="fixed",
715
- fovy=f"{fovy if degrees else np.rad2deg(fovy)}",
716
- pos=" ".join(p.astype(str).tolist()),
717
- quat=" ".join(Q.astype(str).tolist()),
718
- **kwargs,
719
- )
720
-
721
- def asdict(self) -> dict[str, str]:
722
-
723
- return {k: v for k, v in dataclasses.asdict(self).items() if v is not None}
614
+ return cls(**kwargs)
@@ -7,7 +7,6 @@ from typing import Any, Callable
7
7
  import mujoco as mj
8
8
  import numpy as np
9
9
  import numpy.typing as npt
10
- import xmltodict
11
10
  from scipy.spatial.transform import Rotation
12
11
 
13
12
  import jaxsim.typing as jtp
@@ -43,27 +42,16 @@ class MujocoModelHelper:
43
42
  mjcf_description: str | pathlib.Path,
44
43
  assets: dict[str, Any] | None = None,
45
44
  heightmap: HeightmapCallable | None = None,
46
- heightmap_name: str = "terrain",
47
- heightmap_radius_xy: tuple[float, float] = (1.0, 1.0),
48
45
  ) -> MujocoModelHelper:
49
46
  """
50
- Build a Mujoco model from an MJCF description.
47
+ Build a Mujoco model from an XML description and an optional assets dictionary.
51
48
 
52
49
  Args:
53
- mjcf_description:
54
- A string containing the XML description of the Mujoco model
50
+ mjcf_description: A string containing the XML description of the Mujoco model
55
51
  or a path to a file containing the XML description.
56
52
  assets: An optional dictionary containing the assets of the model.
57
- heightmap:
58
- A function in two variables that returns the height of a terrain
53
+ heightmap: A function in two variables that returns the height of a terrain
59
54
  in the specified coordinate point.
60
- heightmap_name:
61
- The default name of the heightmap in the MJCF description
62
- to load the corresponding configuration.
63
- heightmap_radius_xy:
64
- The extension of the heightmap in the x-y surface corresponding to the
65
- plane over which the grid of the sampled heightmap is generated.
66
-
67
55
  Returns:
68
56
  A MujocoModelHelper object.
69
57
  """
@@ -75,61 +63,15 @@ class MujocoModelHelper:
75
63
  else mjcf_description
76
64
  )
77
65
 
78
- if heightmap is None:
79
- hfield = None
80
-
81
- else:
82
-
83
- mjcf_description_dict = xmltodict.parse(xml_input=mjcf_description)
84
-
85
- # Create a dictionary of all hfield configurations from the MJCF.
86
- hfields = mjcf_description_dict["mujoco"]["asset"].get("hfield", [])
87
- hfields = hfields if isinstance(hfields, list) else [hfields]
88
- hfields_dict = {hfield["@name"]: hfield for hfield in hfields}
89
-
90
- if heightmap_name not in hfields_dict:
91
- raise ValueError(f"Heightmap '{heightmap_name}' not found in MJCF")
92
-
93
- hfield_element = hfields_dict[heightmap_name]
94
-
95
- # Generate the hfield by sampling the heightmap function.
96
- hfield = generate_hfield(
97
- heightmap=heightmap,
98
- samples_xy=(int(hfield_element["@nrow"]), int(hfield_element["@ncol"])),
99
- radius_xy=heightmap_radius_xy,
100
- )
101
-
102
- # Update dynamically the '/asset/hfield[@name=heightmap_name]@size' attribute
103
- # with the information of the sampled points.
104
- # This is necessary for correctly rendering the heightmap over the
105
- # specified xy area with the correct z elevation.
106
- size = [float(el) for el in hfield_element["@size"].split(" ")]
107
- size[0], size[1] = heightmap_radius_xy
108
- size[2] = 1.0
109
- size[3] = max(0, -min(hfield))
110
-
111
- # Replace the 'size' attribute.
112
- hfields_dict[heightmap_name]["@size"] = " ".join(str(el) for el in size)
113
-
114
- # Update the hfield elements of the original MJCF.
115
- # Only the hfield corresponding to 'heightmap_name' was actually edited.
116
- mjcf_description_dict["mujoco"]["asset"]["hfield"] = list(
117
- hfields_dict.values()
118
- )
119
-
120
- # Serialize the updated MJCF to XML.
121
- mjcf_description = xmltodict.unparse(
122
- input_dict=mjcf_description_dict, pretty=True
123
- )
124
-
125
- # Create the Mujoco model from the XML and, optionally, the dictionary of assets.
66
+ # Create the Mujoco model from the XML and, optionally, the assets dictionary.
126
67
  model = mj.MjModel.from_xml_string(xml=mjcf_description, assets=assets)
127
68
  data = mj.MjData(model)
128
69
 
129
- # Store the sampled heightmap into the Mujoco model.
130
- if heightmap is not None:
131
- assert hfield is not None
132
- model.hfield_data = hfield
70
+ if heightmap:
71
+ nrow = model.hfield_nrow.item()
72
+ ncol = model.hfield_ncol.item()
73
+ new_hfield = generate_hfield(heightmap, (nrow, ncol))
74
+ model.hfield_data = new_hfield
133
75
 
134
76
  return MujocoModelHelper(model=model, data=data)
135
77
 
@@ -443,13 +385,10 @@ class MujocoModelHelper:
443
385
 
444
386
 
445
387
  def generate_hfield(
446
- heightmap: HeightmapCallable,
447
- samples_xy: tuple[int, int] = (11, 11),
448
- radius_xy: tuple[float, float] = (1.0, 1.0),
388
+ heightmap: HeightmapCallable, size: tuple[int, int] = (10, 10)
449
389
  ) -> npt.NDArray:
450
390
  """
451
- Generate an array with elevation points sampled from a heightmap function.
452
-
391
+ Generates a numpy array representing the heightmap of
453
392
  The map will have the following format:
454
393
  ```
455
394
  heightmap[0, 0] heightmap[0, 1] ... heightmap[0, size[1]-1]
@@ -459,22 +398,17 @@ def generate_hfield(
459
398
  ```
460
399
 
461
400
  Args:
462
- heightmap:
463
- A function that takes two arguments (x, y) and returns the height
401
+ heightmap: A function that takes two arguments (x, y) and returns the height
464
402
  at that point.
465
- samples_xy: A tuple of two integers representing the size of the grid.
466
- radius_xy:
467
- A tuple of two floats representing extension of the heightmap in the
468
- x-y surface corresponding to the area over which the grid of the sampled
469
- heightmap is generated.
403
+ size: A tuple of two integers representing the size of the grid.
470
404
 
471
405
  Returns:
472
- A flat array of the sampled terrain heightmap.
406
+ np.ndarray: The terrain heightmap
473
407
  """
474
408
 
475
409
  # Generate the grid.
476
- x = np.linspace(-radius_xy[0], radius_xy[0], samples_xy[0])
477
- y = np.linspace(-radius_xy[1], radius_xy[1], samples_xy[1])
410
+ x = np.linspace(0, 1, size[0])
411
+ y = np.linspace(0, 1, size[1])
478
412
 
479
413
  # Generate the heightmap.
480
414
  return np.array([[heightmap(xi, yi) for xi in x] for yi in y]).flatten()
@@ -1,11 +1,10 @@
1
1
  import contextlib
2
2
  import pathlib
3
- from typing import ContextManager, Sequence
3
+ from typing import ContextManager
4
4
 
5
5
  import mediapy as media
6
6
  import mujoco as mj
7
7
  import mujoco.viewer
8
- import numpy as np
9
8
  import numpy.typing as npt
10
9
 
11
10
 
@@ -63,16 +62,18 @@ class MujocoVideoRecorder:
63
62
  self.data = data if data is not None else self.data
64
63
  self.model = model if model is not None else self.model
65
64
 
66
- def render_frame(self, camera_name: str = "track") -> npt.NDArray:
65
+ def render_frame(self, camera_name: str | None = None) -> npt.NDArray:
67
66
  """Renders a frame."""
67
+ camera_name = camera_name or "track"
68
68
 
69
69
  mujoco.mj_forward(self.model, self.data)
70
70
  self.renderer.update_scene(data=self.data, camera=camera_name)
71
71
 
72
72
  return self.renderer.render()
73
73
 
74
- def record_frame(self, camera_name: str = "track") -> None:
74
+ def record_frame(self, camera_name: str | None = None) -> None:
75
75
  """Stores a frame in the buffer."""
76
+ camera_name = camera_name or "track"
76
77
 
77
78
  frame = self.render_frame(camera_name=camera_name)
78
79
  self.frames.append(frame)
@@ -166,72 +167,13 @@ class MujocoVisualizer:
166
167
  self,
167
168
  model: mj.MjModel | None = None,
168
169
  data: mj.MjData | None = None,
169
- *,
170
170
  close_on_exit: bool = True,
171
- lookat: Sequence[float | int] | npt.NDArray | None = None,
172
- distance: float | int | npt.NDArray | None = None,
173
- azimut: float | int | npt.NDArray | None = None,
174
- elevation: float | int | npt.NDArray | None = None,
175
171
  ) -> ContextManager[mujoco.viewer.Handle]:
176
- """
177
- Context manager to open the Mujoco passive viewer.
178
-
179
- Note:
180
- Refer to the Mujoco documentation for details of the camera options:
181
- https://mujoco.readthedocs.io/en/stable/XMLreference.html#visual-global
182
- """
172
+ """Context manager to open a viewer."""
183
173
 
184
174
  handle = self.open_viewer(model=model, data=data)
185
175
 
186
- handle = MujocoVisualizer.setup_viewer_camera(
187
- viewer=handle,
188
- lookat=lookat,
189
- distance=distance,
190
- azimut=azimut,
191
- elevation=elevation,
192
- )
193
-
194
176
  try:
195
177
  yield handle
196
178
  finally:
197
179
  _ = handle.close() if close_on_exit else None
198
-
199
- @staticmethod
200
- def setup_viewer_camera(
201
- viewer: mj.viewer.Handle,
202
- *,
203
- lookat: Sequence[float | int] | npt.NDArray | None,
204
- distance: float | int | npt.NDArray | None = None,
205
- azimut: float | int | npt.NDArray | None = None,
206
- elevation: float | int | npt.NDArray | None = None,
207
- ) -> mj.viewer.Handle:
208
- """
209
- Configure the initial viewpoint of the Mujoco passive viewer.
210
-
211
- Note:
212
- Refer to the Mujoco documentation for details of the camera options:
213
- https://mujoco.readthedocs.io/en/stable/XMLreference.html#visual-global
214
-
215
- Returns:
216
- The viewer with configured camera.
217
- """
218
-
219
- if lookat is not None:
220
-
221
- lookat_array = np.array(lookat, dtype=float).squeeze()
222
-
223
- if lookat_array.size != 3:
224
- raise ValueError(lookat)
225
-
226
- viewer.cam.lookat = lookat_array
227
-
228
- if distance is not None:
229
- viewer.cam.distance = float(distance)
230
-
231
- if azimut is not None:
232
- viewer.cam.azimuth = float(azimut) % 360
233
-
234
- if elevation is not None:
235
- viewer.cam.elevation = float(elevation)
236
-
237
- return viewer
@@ -2,7 +2,6 @@ import jax.numpy as jnp
2
2
 
3
3
  import jaxsim.api as js
4
4
  import jaxsim.typing as jtp
5
- from jaxsim import exceptions
6
5
  from jaxsim.math import StandardGravity
7
6
 
8
7
 
@@ -132,13 +131,6 @@ def process_inputs(
132
131
  if W_Q_B.shape != (4,):
133
132
  raise ValueError(W_Q_B.shape, (4,))
134
133
 
135
- # Check that the quaternion is unary since our RBDAs make this assumption in order
136
- # to prevent introducing additional normalizations that would affect AD.
137
- exceptions.raise_value_error_if(
138
- condition=jnp.logical_not(jnp.allclose(W_Q_B.dot(W_Q_B), 1.0)),
139
- msg="A RBDA received a quaternion that is not normalized.",
140
- )
141
-
142
134
  # Pack the 6D base velocity and acceleration.
143
135
  W_v_WB = jnp.hstack([W_vl_WB, W_ω_WB])
144
136
  W_v̇_WB = jnp.hstack([W_v̇l_WB, W_ω̇_WB])
@@ -73,7 +73,7 @@ class FlatTerrain(Terrain):
73
73
  class PlaneTerrain(FlatTerrain):
74
74
 
75
75
  plane_normal: tuple[float, float, float] = jax_dataclasses.field(
76
- default=(0.0, 0.0, 1.0), kw_only=True
76
+ default=(0.0, 0.0, 0.0), kw_only=True
77
77
  )
78
78
 
79
79
  @staticmethod
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.1
3
+ Version: 0.4.1.dev6
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -80,7 +80,6 @@ Provides-Extra: viz
80
80
  Requires-Dist: lxml; extra == "viz"
81
81
  Requires-Dist: mediapy; extra == "viz"
82
82
  Requires-Dist: mujoco>=3.0.0; extra == "viz"
83
- Requires-Dist: scipy>=1.14.0; extra == "viz"
84
83
  Provides-Extra: all
85
84
  Requires-Dist: jaxsim[style,testing,viz]; extra == "all"
86
85
 
@@ -27,4 +27,3 @@ robot-descriptions
27
27
  lxml
28
28
  mediapy
29
29
  mujoco>=3.0.0
30
- scipy>=1.14.0
@@ -93,7 +93,7 @@ def test_ad_aba(
93
93
  aba = lambda W_p_B, W_Q_B, s, W_v_WB, ṡ, τ, W_f_L, g: jaxsim.rbda.aba(
94
94
  model=model,
95
95
  base_position=W_p_B,
96
- base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B),
96
+ base_quaternion=W_Q_B,
97
97
  joint_positions=s,
98
98
  base_linear_velocity=W_v_WB[0:3],
99
99
  base_angular_velocity=W_v_WB[3:6],
@@ -150,7 +150,7 @@ def test_ad_rnea(
150
150
  rnea = lambda W_p_B, W_Q_B, s, W_v_WB, ṡ, W_v̇_WB, s̈, W_f_L, g: jaxsim.rbda.rnea(
151
151
  model=model,
152
152
  base_position=W_p_B,
153
- base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B),
153
+ base_quaternion=W_Q_B,
154
154
  joint_positions=s,
155
155
  base_linear_velocity=W_v_WB[0:3],
156
156
  base_angular_velocity=W_v_WB[3:6],
@@ -229,7 +229,7 @@ def test_ad_fk(
229
229
  fk = lambda W_p_B, W_Q_B, s: jaxsim.rbda.forward_kinematics_model(
230
230
  model=model,
231
231
  base_position=W_p_B,
232
- base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B),
232
+ base_quaternion=W_Q_B,
233
233
  joint_positions=s,
234
234
  )
235
235
 
@@ -6,7 +6,7 @@ import jaxsim.api as js
6
6
  import jaxsim.integrators
7
7
  import jaxsim.rbda
8
8
  from jaxsim import VelRepr
9
- from jaxsim.utils import Mutability
9
+ from jaxsim.rbda.contacts.soft import SoftContactsParams
10
10
 
11
11
 
12
12
  def test_box_with_external_forces(
@@ -102,19 +102,23 @@ def test_box_with_zero_gravity(
102
102
 
103
103
  model = jaxsim_model_box
104
104
 
105
- # Move the terrain (almost) infinitely far away from the box.
106
- with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
107
- model.terrain = jaxsim.terrain.FlatTerrain.build(height=-1e9)
108
-
109
105
  # Split the PRNG key.
110
- _, subkey = jax.random.split(prng_key, num=2)
106
+ _, subkey, subkey2 = jax.random.split(prng_key, num=3)
111
107
 
112
108
  # Build the data of the model.
113
109
  data0 = js.data.JaxSimModelData.build(
114
110
  model=model,
115
- base_position=jax.random.uniform(subkey, shape=(3,)),
111
+ base_position=jax.random.uniform(subkey2, shape=(3,)),
116
112
  velocity_representation=velocity_representation,
117
113
  standard_gravity=0.0,
114
+ contacts_params=SoftContactsParams.build(K=0.0, D=0.0, mu=0.0),
115
+ )
116
+
117
+ # Generate a random linear force.
118
+ L_f = (
119
+ jax.random.uniform(subkey, shape=(model.number_of_links(), 6))
120
+ .at[:, 3:]
121
+ .set(jnp.zeros(3))
118
122
  )
119
123
 
120
124
  # Initialize a references object that simplifies handling external forces.
@@ -125,29 +129,13 @@ def test_box_with_zero_gravity(
125
129
  )
126
130
 
127
131
  # Apply a link forces to the base link.
128
- with references.switch_velocity_representation(jaxsim.VelRepr.Mixed):
129
-
130
- # Generate a random linear force.
131
- # We enforce them to be the same for all velocity representations so that
132
- # we can compare their outcomes.
133
- LW_f = 10.0 * (
134
- jax.random.uniform(jax.random.key(0), shape=(model.number_of_links(), 6))
135
- .at[:, 3:]
136
- .set(jnp.zeros(3))
137
- )
138
-
139
- # Note that the context manager does not switch back the newly created
140
- # `references` (that is not the yielded object) to the original representation.
141
- # In the simulation loop below, we need to make sure that we switch both `data`
142
- # and `references` to the same representation before extracting the information
143
- # passed to the step function.
144
- references = references.apply_link_forces(
145
- forces=jnp.atleast_2d(LW_f),
146
- link_names=model.link_names(),
147
- model=model,
148
- data=data0,
149
- additive=False,
150
- )
132
+ references = references.apply_link_forces(
133
+ forces=jnp.atleast_2d(L_f),
134
+ link_names=model.link_names(),
135
+ model=model,
136
+ data=data0,
137
+ additive=False,
138
+ )
151
139
 
152
140
  # Create the integrator.
153
141
  integrator = jaxsim.integrators.fixed_step.RungeKutta4SO3.build(
@@ -157,7 +145,8 @@ def test_box_with_zero_gravity(
157
145
  )
158
146
 
159
147
  # Initialize the integrator.
160
- tf, dt = 1.0, 0.010
148
+ tf = 1.0
149
+ dt = 0.010
161
150
  T = jnp.arange(start=0, stop=tf * 1e9, step=dt * 1e9, dtype=int)
162
151
  integrator_state = integrator.init(x0=data0.state, t0=0.0, dt=dt)
163
152
 
@@ -167,28 +156,19 @@ def test_box_with_zero_gravity(
167
156
  # ... and step the simulation.
168
157
  for t_ns in T:
169
158
 
170
- assert data.time() == t_ns / 1e9
171
-
172
- with (
173
- data.switch_velocity_representation(velocity_representation),
174
- references.switch_velocity_representation(velocity_representation),
175
- ):
176
-
177
- data, integrator_state = js.model.step(
178
- model=model,
179
- data=data,
180
- dt=dt,
181
- integrator=integrator,
182
- integrator_state=integrator_state,
183
- link_forces=references.link_forces(model=model, data=data),
184
- )
185
-
186
- # Check the final simulation time.
187
- assert data.time() == T[-1] / 1e9 + dt
159
+ data, integrator_state = js.model.step(
160
+ model=model,
161
+ data=data,
162
+ dt=dt,
163
+ integrator=integrator,
164
+ integrator_state=integrator_state,
165
+ link_forces=references.link_forces(model=model, data=data),
166
+ )
188
167
 
189
168
  # Check that the box moved as expected.
169
+ assert data.time() == t_ns / 1e9 + dt
190
170
  assert data.base_position() == pytest.approx(
191
171
  data0.base_position()
192
- + 0.5 * LW_f[:, :3].squeeze() / js.model.total_mass(model=model) * tf**2,
172
+ + 0.5 * L_f[:, :3].squeeze() / js.model.total_mass(model=model) * tf**2,
193
173
  abs=1e-3,
194
174
  )
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes