jaxsim 0.3.1.dev46__py3-none-any.whl → 0.3.1.dev51__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.3.1.dev46'
16
- __version_tuple__ = version_tuple = (0, 3, 1, 'dev46')
15
+ __version__ = version = '0.3.1.dev51'
16
+ __version_tuple__ = version_tuple = (0, 3, 1, 'dev51')
jaxsim/api/model.py CHANGED
@@ -1841,7 +1841,13 @@ def step(
1841
1841
  and the new state of the integrator.
1842
1842
  """
1843
1843
 
1844
- integrator_kwargs = kwargs if kwargs is not None else dict()
1844
+ # Extract the integrator kwargs.
1845
+ # The following logic allows using integrators having kwargs colliding with the
1846
+ # kwargs of this step function.
1847
+ kwargs = kwargs if kwargs is not None else {}
1848
+ integrator_kwargs = kwargs.pop("integrator_kwargs", {})
1849
+ integrator_kwargs = kwargs | integrator_kwargs
1850
+
1845
1851
  integrator_state = integrator_state if integrator_state is not None else dict()
1846
1852
 
1847
1853
  # Extract the initial resources.
@@ -1855,8 +1861,21 @@ def step(
1855
1861
  t0=jnp.array(t0_ns / 1e9).astype(float),
1856
1862
  dt=dt,
1857
1863
  params=integrator_state_x0,
1864
+ # Always inject the current (model, data) pair into the system dynamics
1865
+ # considered by the integrator, and include the input variables represented
1866
+ # by the pair (joint_forces, link_forces).
1867
+ # Note that the wrapper of the system dynamics will override (state_x0, t0)
1868
+ # inside the passed data even if it is not strictly needed. This logic is
1869
+ # necessary to re-use the jit-compiled step function of compatible pytrees
1870
+ # of model and data produced e.g. by parameterized applications.
1858
1871
  **(
1859
- dict(joint_forces=joint_forces, link_forces=link_forces) | integrator_kwargs
1872
+ dict(
1873
+ model=model,
1874
+ data=data,
1875
+ joint_forces=joint_forces,
1876
+ link_forces=link_forces,
1877
+ )
1878
+ | integrator_kwargs
1860
1879
  ),
1861
1880
  )
1862
1881
 
@@ -422,7 +422,9 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
422
422
 
423
423
  # Update the FSAL property for the next iteration.
424
424
  if self.has_fsal:
425
- self.params["dxdt0"] = jax.tree_map(lambda l: l[self.index_of_fsal], K)
425
+ self.params["dxdt0"] = jax.tree_util.tree_map(
426
+ lambda l: l[self.index_of_fsal], K
427
+ )
426
428
 
427
429
  # Compute the output state.
428
430
  # Note that z contains as many new states as the rows of `b.T`.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.3.1.dev46
3
+ Version: 0.3.1.dev51
4
4
  Home-page: https://github.com/ami-iit/jaxsim
5
5
  Author: Diego Ferigo
6
6
  Author-email: diego.ferigo@iit.it
@@ -1,5 +1,5 @@
1
1
  jaxsim/__init__.py,sha256=xzuTuZrgKdWLqqDzbvqzm2cJrEtAbepOeUqDu7ByVek,2621
2
- jaxsim/_version.py,sha256=9MZZuPQZKCPbVj6_p_Ju-do62zKeUP7ohaQeRYNKN3w,426
2
+ jaxsim/_version.py,sha256=IGT0sukgd0adIU8Kj5QFc30r0PeT_NCk5YVNZH5UbRg,426
3
3
  jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
4
4
  jaxsim/logging.py,sha256=c4zhwBKf9eAYAHVp62kTEllqdsZgh0K-kPKVy8L3elU,1584
5
5
  jaxsim/typing.py,sha256=cl7HHQCeP3mHmtF6EuQZcCjGvDmc_AryMWntP_lRBGg,722
@@ -12,12 +12,12 @@ jaxsim/api/frame.py,sha256=m_waB9_0kgJq5miiZDXdRzIZii-BwQaN9bRt22JkJ1I,7212
12
12
  jaxsim/api/joint.py,sha256=Pvg_It2iYA-jAQ2nOlFZxwmITiozO_f46G13BdQtHQ0,5106
13
13
  jaxsim/api/kin_dyn_parameters.py,sha256=AEpDg9kihbKUN9PA8pNrAruSuWFUC-k_GGxtlcdcDiQ,29215
14
14
  jaxsim/api/link.py,sha256=edXaNO0TcqyFyMOlIlCnReQK_VP8p38crLEp0of7mWo,18404
15
- jaxsim/api/model.py,sha256=HAnrlgPDl5CCZQzQ84AfjC_DZjmrCzBKEDodE6hyLf8,60518
15
+ jaxsim/api/model.py,sha256=G4tlJ0mkpH_v31ZPObBSeTqaioYJpdfapmIG7-FyWe8,61424
16
16
  jaxsim/api/ode.py,sha256=xQL53ppnKweMQWRNm5gGR8FTjqRVzds8WKg9js9k5TA,10780
17
17
  jaxsim/api/ode_data.py,sha256=Sa2i1zZhqyQqIGv1jarTmmU-W9HhTw-DErs12kFA1GA,19737
18
18
  jaxsim/api/references.py,sha256=UA6kSQVBoq-bXSo99EOELf-_MD5MTy2zS0GtG3wQ410,16618
19
19
  jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
20
- jaxsim/integrators/common.py,sha256=9HXRVFo95Mpt6RcVhBrOfvOO7mDxqbkXeg_lKUibEFY,20693
20
+ jaxsim/integrators/common.py,sha256=cRXD0UPEFo2UkPQnOUi3u9Ph3kbfaYFaAvGnZWjSWwI,20733
21
21
  jaxsim/integrators/fixed_step.py,sha256=JXaEyEzfSiYea0GnPA7l27J3X0YPB0e25D4qfrxAvzQ,2766
22
22
  jaxsim/integrators/variable_step.py,sha256=jq3PStzFiMciU7lux6CTj4B3gVOfSpYgK2oz2yzIbdo,21380
23
23
  jaxsim/math/__init__.py,sha256=inJ9nRFkqstuGa8OyFkfWVudo5U9Ug4WgDBuKva8AIA,337
@@ -61,8 +61,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
61
61
  jaxsim/utils/jaxsim_dataclass.py,sha256=h26timZ_XrBL_Q_oymv-DkQd-EcUiHn8QexAaZXBY9c,11396
62
62
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
63
63
  jaxsim/utils/wrappers.py,sha256=QIJitSoljrKR_U4T3ewCJPT3DTh-tPZsRsg0t_MH93E,3896
64
- jaxsim-0.3.1.dev46.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
65
- jaxsim-0.3.1.dev46.dist-info/METADATA,sha256=KS14B2THB9Wao1MpRKEp345CTkC7PMLMfUOSbOkYbBA,9739
66
- jaxsim-0.3.1.dev46.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
67
- jaxsim-0.3.1.dev46.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
68
- jaxsim-0.3.1.dev46.dist-info/RECORD,,
64
+ jaxsim-0.3.1.dev51.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
65
+ jaxsim-0.3.1.dev51.dist-info/METADATA,sha256=zbBiTRoZ6XIJvJcq8UTkHhvGOc7krPxFpxDqz5_oJ_A,9739
66
+ jaxsim-0.3.1.dev51.dist-info/WHEEL,sha256=cpQTJ5IWu9CdaPViMhC9YzF8gZuS5-vlfoFihTBC86A,91
67
+ jaxsim-0.3.1.dev51.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
68
+ jaxsim-0.3.1.dev51.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.43.0)
2
+ Generator: setuptools (70.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5