jaxsim 0.4.3.dev88__py3-none-any.whl → 0.4.3.dev92__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev88'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev88')
15
+ __version__ = version = '0.4.3.dev92'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev92')
jaxsim/api/ode.py CHANGED
@@ -145,7 +145,10 @@ def system_velocity_dynamics(
145
145
 
146
146
  # Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
147
147
  # along with contact-specific auxiliary states.
148
- with data.switch_velocity_representation(VelRepr.Inertial):
148
+ with (
149
+ data.switch_velocity_representation(VelRepr.Inertial),
150
+ references.switch_velocity_representation(VelRepr.Inertial),
151
+ ):
149
152
  W_f_Ci, aux_data = js.contact.collidable_point_dynamics(
150
153
  model=model,
151
154
  data=data,
jaxsim/rbda/crba.py CHANGED
@@ -94,48 +94,48 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
94
94
 
95
95
  j = i
96
96
 
97
- CarryInnerFn = tuple[jtp.Int, jtp.Matrix, jtp.Matrix]
98
- carry_inner_fn = (j, Fi, M)
97
+ FakeWhileCarry = tuple[jtp.Int, jtp.Vector, jtp.Matrix]
98
+ fake_while_carry = (j, Fi, M)
99
99
 
100
- def while_loop_body(carry: CarryInnerFn) -> CarryInnerFn:
101
- j, Fi, M = carry
100
+ # This internal for loop implements the while loop of the CRBA algorithm
101
+ # to compute off-diagonal blocks of the mass matrix M.
102
+ # In pseudocode it is implemented as a while loop. However, in order to enable
103
+ # applying reverse-mode AD, we implement it as a nested for loop with a fixed
104
+ # number of iterations and a branching model to skip for loop iterations.
105
+ def fake_while_loop(
106
+ carry: FakeWhileCarry, i: jtp.Int
107
+ ) -> tuple[FakeWhileCarry, None]:
102
108
 
103
- Fi = i_X_λi[j].T @ Fi
104
- j = λ[j]
105
- jj = j - 1
109
+ def compute(carry: FakeWhileCarry) -> FakeWhileCarry:
106
110
 
107
- M_ij = Fi.T @ S[j]
111
+ j, Fi, M = carry
108
112
 
109
- M = M.at[ii + 6, jj + 6].set(M_ij.squeeze())
110
- M = M.at[jj + 6, ii + 6].set(M_ij.squeeze())
113
+ Fi = i_X_λi[j].T @ Fi
114
+ j = λ[j]
111
115
 
112
- return j, Fi, M
116
+ M_ij = Fi.T @ S[j]
113
117
 
114
- # The following functions are part of a (rather messy) workaround for computing
115
- # a while loop using a for loop with fixed number of iterations.
116
- def inner_fn(carry: CarryInnerFn, k: jtp.Int) -> tuple[CarryInnerFn, None]:
117
- def compute_inner(carry: CarryInnerFn) -> tuple[CarryInnerFn, None]:
118
- j, _, _ = carry
119
- out = jax.lax.cond(
120
- pred=(λ[j] > 0),
121
- true_fun=while_loop_body,
122
- false_fun=lambda carry: carry,
123
- operand=carry,
124
- )
125
- return out, None
118
+ jj = j - 1
119
+ M = M.at[ii + 6, jj + 6].set(M_ij.squeeze())
120
+ M = M.at[jj + 6, ii + 6].set(M_ij.squeeze())
121
+
122
+ return j, Fi, M
126
123
 
127
124
  j, _, _ = carry
128
- return jax.lax.cond(
129
- pred=(k == j),
130
- true_fun=compute_inner,
131
- false_fun=lambda carry: (carry, None),
125
+
126
+ j, Fi, M = jax.lax.cond(
127
+ pred=jnp.logical_and(i == λ[j], λ[j] > 0),
128
+ true_fun=compute,
129
+ false_fun=lambda carry: carry,
132
130
  operand=carry,
133
131
  )
134
132
 
133
+ return (j, Fi, M), None
134
+
135
135
  (j, Fi, M), _ = (
136
136
  jax.lax.scan(
137
- f=inner_fn,
138
- init=carry_inner_fn,
137
+ f=fake_while_loop,
138
+ init=fake_while_carry,
139
139
  xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
140
140
  )
141
141
  if model.number_of_links() > 1
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev88
3
+ Version: 0.4.3.dev92
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -1,5 +1,5 @@
1
1
  jaxsim/__init__.py,sha256=bSbpggIz5aG6QuGZLa0V2EfHjAOeucMxi-vIYxzLmN8,2788
2
- jaxsim/_version.py,sha256=B5nbhwRMLkROhGufRX4ovaMmwptyxOtLB8YwoD_nlo8,426
2
+ jaxsim/_version.py,sha256=zAPSBvbq9I7HBdNXrtiufDT_Jj2UKtJOpjP0i_YFtmE,426
3
3
  jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
4
4
  jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
5
5
  jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
@@ -13,7 +13,7 @@ jaxsim/api/joint.py,sha256=lksT1Doxz2jknHyhb4ls20z6f6dofpZSzBJtVacZXAE,7129
13
13
  jaxsim/api/kin_dyn_parameters.py,sha256=ElahFk_RCcLvjTidH2qDOsY-m1gN1hXitCv4SvfgGYY,29260
14
14
  jaxsim/api/link.py,sha256=LAA6ZMQXkWomXeptURBtc7z3_xDZ2BBnBMhVrohh0bE,18621
15
15
  jaxsim/api/model.py,sha256=WL31JA2jK5L79TJ05ZouIYGe02rQvFItVDqizkzC1UE,66100
16
- jaxsim/api/ode.py,sha256=ZshGdHptftku0yoUwBiBdd1iOqntH0vVEOjRHfL7Fao,13518
16
+ jaxsim/api/ode.py,sha256=gYSbtHWGCDP-IkUzQlH3t0fBKnK8qmxwhIvsbLG9lwU,13616
17
17
  jaxsim/api/ode_data.py,sha256=7RSoBhfCJdP6P9InQbDwdBVpClPMMuetewI-6AWm-_0,20276
18
18
  jaxsim/api/references.py,sha256=XOVKuQXRmjPoP-T5JWGSbqIGX5DzOkeGafqRpj0ZQEM,20771
19
19
  jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
@@ -47,7 +47,7 @@ jaxsim/parsers/rod/utils.py,sha256=5DsF3OeePZGidOJ5GiFSZx-51uIdnFvMW9EK6SgOW6Q,5
47
47
  jaxsim/rbda/__init__.py,sha256=H7DhXpxkPOi9lpUvg31IMHFfRafke1UoJLc5GQIdyhA,387
48
48
  jaxsim/rbda/aba.py,sha256=w7ciyxB0IsmueatT0C7PcBQEl9dyiH9oqJgIi3xeTUE,8983
49
49
  jaxsim/rbda/collidable_points.py,sha256=Rmf1DhflhOTYh9mDalv0agS0CGSbmfoOybwP2KzKuJ0,4883
50
- jaxsim/rbda/crba.py,sha256=zJSiHKRvNU98z2tT9prrWR4VU9wIZQWFwEut7mua6as,5044
50
+ jaxsim/rbda/crba.py,sha256=bXkXESnVbv-lxhU1Y_i0rViEcQA4z2t2_jHwdVj5CBo,5049
51
51
  jaxsim/rbda/forward_kinematics.py,sha256=2GmEoWsrioVl_SAbKRKfhOLz57pY4aR81PKRdulqStA,3458
52
52
  jaxsim/rbda/jacobian.py,sha256=p0EV_8cLzLVV-93VKznT7VPuRj8W7h7rQWkPlWJXfCA,11023
53
53
  jaxsim/rbda/rnea.py,sha256=CLfqs9XFVaD-hvkLABshDAfdw5bm_AMV3UVAQ_IvURQ,7542
@@ -63,8 +63,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
63
63
  jaxsim/utils/jaxsim_dataclass.py,sha256=TGmTQV2Lq7Q-2nLoAEaeNtkPa_qj0IKkdBm4COj46Os,11312
64
64
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
65
65
  jaxsim/utils/wrappers.py,sha256=Fh82ZcaFi5fUnByyFLnmumaobsu1hJIvFdopUVzJ1ps,4052
66
- jaxsim-0.4.3.dev88.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
67
- jaxsim-0.4.3.dev88.dist-info/METADATA,sha256=Hm6Vv2DzMpviXdgHoSsUUuGWJAyYMKngkYXEVMHEh7U,17276
68
- jaxsim-0.4.3.dev88.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
69
- jaxsim-0.4.3.dev88.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
70
- jaxsim-0.4.3.dev88.dist-info/RECORD,,
66
+ jaxsim-0.4.3.dev92.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
67
+ jaxsim-0.4.3.dev92.dist-info/METADATA,sha256=ix-zV0RmlB_yVniPpcZ-IzY0EEVgi2iZ9d8GLyg7vxU,17276
68
+ jaxsim-0.4.3.dev92.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
69
+ jaxsim-0.4.3.dev92.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
70
+ jaxsim-0.4.3.dev92.dist-info/RECORD,,