jaxsim 0.4.2.dev45__py3-none-any.whl → 0.4.2.dev50__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.2.dev45'
16
- __version_tuple__ = version_tuple = (0, 4, 2, 'dev45')
15
+ __version__ = version = '0.4.2.dev50'
16
+ __version_tuple__ = version_tuple = (0, 4, 2, 'dev50')
jaxsim/api/references.py CHANGED
@@ -9,7 +9,6 @@ import jax_dataclasses
9
9
  import jaxsim.api as js
10
10
  import jaxsim.typing as jtp
11
11
  from jaxsim import exceptions
12
- from jaxsim.math import Adjoint
13
12
  from jaxsim.utils.tracing import not_tracing
14
13
 
15
14
  from .common import VelRepr
@@ -493,9 +492,9 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
493
492
 
494
493
  # Extract the frame indices.
495
494
  frame_idxs = js.frame.names_to_idxs(frame_names=frame_names, model=model)
496
- parent_link_idxs = jax.vmap(
497
- lambda frame_idx: js.frame.idx_of_parent_link, in_axes=(None,)
498
- )(model, frame_idx=frame_idxs)
495
+ parent_link_idxs = jax.vmap(js.frame.idx_of_parent_link, in_axes=(None,))(
496
+ model, frame_index=frame_idxs
497
+ )
499
498
 
500
499
  exceptions.raise_value_error_if(
501
500
  condition=jnp.logical_not(data.valid(model=model)),
@@ -527,25 +526,14 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
527
526
  case _:
528
527
  raise ValueError("Invalid velocity representation.")
529
528
 
530
- W_H_L = js.model.forward_kinematics(model=model, data=data)
531
-
532
- def convert_to_link_force(
533
- W_f_F: jtp.MatrixLike, W_H_F: jtp.MatrixLike, parent_link_idx: jtp.ArrayLike
534
- ) -> jtp.Matrix:
535
- L_Xf_W = Adjoint.from_transform(W_H_L[parent_link_idx]).T
536
-
537
- return L_Xf_W @ W_f_F
538
-
539
- W_f_L_i = jax.vmap(convert_to_link_force)(W_f_F, W_H_Fi, parent_link_idxs)
540
-
541
529
  # Sum the forces on the parent links.
542
530
  mask = parent_link_idxs[:, jnp.newaxis] == jnp.arange(model.number_of_links())
543
- W_f_L = mask.T @ W_f_L_i
531
+ W_f_L = mask.T @ W_f_F
544
532
 
545
533
  with self.switch_velocity_representation(
546
534
  velocity_representation=VelRepr.Inertial
547
535
  ):
548
- return self.apply_link_forces(
536
+ references = self.apply_link_forces(
549
537
  model=model,
550
538
  data=data,
551
539
  link_names=js.link.idxs_to_names(
@@ -554,3 +542,8 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
554
542
  forces=W_f_L,
555
543
  additive=additive,
556
544
  )
545
+
546
+ with references.switch_velocity_representation(
547
+ velocity_representation=self.velocity_representation
548
+ ):
549
+ return references
jaxsim/rbda/crba.py CHANGED
@@ -59,10 +59,14 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
59
59
 
60
60
  return (i_X_0,), None
61
61
 
62
- (i_X_0,), _ = jax.lax.scan(
63
- f=propagate_kinematics,
64
- init=forward_pass_carry,
65
- xs=jnp.arange(start=1, stop=model.number_of_links()),
62
+ (i_X_0,), _ = (
63
+ jax.lax.scan(
64
+ f=propagate_kinematics,
65
+ init=forward_pass_carry,
66
+ xs=jnp.arange(start=1, stop=model.number_of_links()),
67
+ )
68
+ if model.number_of_links() > 1
69
+ else [(i_X_0,), None]
66
70
  )
67
71
 
68
72
  # ===================
@@ -128,10 +132,14 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
128
132
  operand=carry,
129
133
  )
130
134
 
131
- (j, Fi, M), _ = jax.lax.scan(
132
- f=inner_fn,
133
- init=carry_inner_fn,
134
- xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
135
+ (j, Fi, M), _ = (
136
+ jax.lax.scan(
137
+ f=inner_fn,
138
+ init=carry_inner_fn,
139
+ xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
140
+ )
141
+ if model.number_of_links() > 1
142
+ else [(j, Fi, M), None]
135
143
  )
136
144
 
137
145
  Fi = i_X_0[j].T @ Fi
@@ -143,10 +151,14 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
143
151
 
144
152
  # This scan performs the backward pass to compute Mbj, Mjb and Mjj, that
145
153
  # also includes a fake while loop implemented with a scan and two cond.
146
- (Mc, M), _ = jax.lax.scan(
147
- f=backward_pass,
148
- init=backward_pass_carry,
149
- xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
154
+ (Mc, M), _ = (
155
+ jax.lax.scan(
156
+ f=backward_pass,
157
+ init=backward_pass_carry,
158
+ xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())),
159
+ )
160
+ if model.number_of_links() > 1
161
+ else [(Mc, M), None]
150
162
  )
151
163
 
152
164
  # Store the locked 6D rigid-body inertia matrix Mbb ∈ ℝ⁶ˣ⁶.
@@ -75,10 +75,14 @@ def forward_kinematics_model(
75
75
 
76
76
  return (W_X_i,), None
77
77
 
78
- (W_X_i,), _ = jax.lax.scan(
79
- f=propagate_kinematics,
80
- init=propagate_kinematics_carry,
81
- xs=jnp.arange(start=1, stop=model.number_of_links()),
78
+ (W_X_i,), _ = (
79
+ jax.lax.scan(
80
+ f=propagate_kinematics,
81
+ init=propagate_kinematics_carry,
82
+ xs=jnp.arange(start=1, stop=model.number_of_links()),
83
+ )
84
+ if model.number_of_links() > 1
85
+ else [(W_X_i,), None]
82
86
  )
83
87
 
84
88
  return jax.vmap(Adjoint.to_transform)(W_X_i)
jaxsim/rbda/jacobian.py CHANGED
@@ -67,10 +67,14 @@ def jacobian(
67
67
 
68
68
  return (i_X_0,), None
69
69
 
70
- (i_X_0,), _ = jax.lax.scan(
71
- f=propagate_kinematics,
72
- init=propagate_kinematics_carry,
73
- xs=np.arange(start=1, stop=model.number_of_links()),
70
+ (i_X_0,), _ = (
71
+ jax.lax.scan(
72
+ f=propagate_kinematics,
73
+ init=propagate_kinematics_carry,
74
+ xs=np.arange(start=1, stop=model.number_of_links()),
75
+ )
76
+ if model.number_of_links() > 1
77
+ else [(i_X_0,), None]
74
78
  )
75
79
 
76
80
  # ============================
@@ -105,10 +109,14 @@ def jacobian(
105
109
 
106
110
  return J, None
107
111
 
108
- L_J_WL_B, _ = jax.lax.scan(
109
- f=compute_jacobian,
110
- init=J,
111
- xs=np.arange(start=1, stop=model.number_of_links()),
112
+ L_J_WL_B, _ = (
113
+ jax.lax.scan(
114
+ f=compute_jacobian,
115
+ init=J,
116
+ xs=np.arange(start=1, stop=model.number_of_links()),
117
+ )
118
+ if model.number_of_links() > 1
119
+ else [J, None]
112
120
  )
113
121
 
114
122
  return L_J_WL_B
@@ -184,10 +192,14 @@ def jacobian_full_doubly_left(
184
192
 
185
193
  return (B_X_i, J), None
186
194
 
187
- (B_X_i, J), _ = jax.lax.scan(
188
- f=compute_full_jacobian,
189
- init=compute_full_jacobian_carry,
190
- xs=np.arange(start=1, stop=model.number_of_links()),
195
+ (B_X_i, J), _ = (
196
+ jax.lax.scan(
197
+ f=compute_full_jacobian,
198
+ init=compute_full_jacobian_carry,
199
+ xs=np.arange(start=1, stop=model.number_of_links()),
200
+ )
201
+ if model.number_of_links() > 1
202
+ else [(B_X_i, J), None]
191
203
  )
192
204
 
193
205
  # Convert adjoints to SE(3) transforms.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.2.dev45
3
+ Version: 0.4.2.dev50
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -60,29 +60,29 @@ Requires-Python: >=3.10
60
60
  Description-Content-Type: text/markdown
61
61
  License-File: LICENSE
62
62
  Requires-Dist: coloredlogs
63
- Requires-Dist: jax >=0.4.13
64
- Requires-Dist: jaxlib >=0.4.13
65
- Requires-Dist: jaxlie >=1.3.0
66
- Requires-Dist: jax-dataclasses >=1.4.0
63
+ Requires-Dist: jax>=0.4.13
64
+ Requires-Dist: jaxlib>=0.4.13
65
+ Requires-Dist: jaxlie>=1.3.0
66
+ Requires-Dist: jax-dataclasses>=1.4.0
67
67
  Requires-Dist: pptree
68
- Requires-Dist: rod >=0.3.0
69
- Requires-Dist: typing-extensions ; python_version < "3.12"
68
+ Requires-Dist: rod>=0.3.0
69
+ Requires-Dist: typing-extensions; python_version < "3.12"
70
70
  Provides-Extra: all
71
- Requires-Dist: jaxsim[style,testing,viz] ; extra == 'all'
71
+ Requires-Dist: jaxsim[style,testing,viz]; extra == "all"
72
72
  Provides-Extra: style
73
- Requires-Dist: black[jupyter] ~=24.0 ; extra == 'style'
74
- Requires-Dist: isort ; extra == 'style'
75
- Requires-Dist: pre-commit ; extra == 'style'
73
+ Requires-Dist: black[jupyter]~=24.0; extra == "style"
74
+ Requires-Dist: isort; extra == "style"
75
+ Requires-Dist: pre-commit; extra == "style"
76
76
  Provides-Extra: testing
77
- Requires-Dist: idyntree >=12.2.1 ; extra == 'testing'
78
- Requires-Dist: pytest >=6.0 ; extra == 'testing'
79
- Requires-Dist: pytest-icdiff ; extra == 'testing'
80
- Requires-Dist: robot-descriptions ; extra == 'testing'
77
+ Requires-Dist: idyntree>=12.2.1; extra == "testing"
78
+ Requires-Dist: pytest>=6.0; extra == "testing"
79
+ Requires-Dist: pytest-icdiff; extra == "testing"
80
+ Requires-Dist: robot-descriptions; extra == "testing"
81
81
  Provides-Extra: viz
82
- Requires-Dist: lxml ; extra == 'viz'
83
- Requires-Dist: mediapy ; extra == 'viz'
84
- Requires-Dist: mujoco >=3.0.0 ; extra == 'viz'
85
- Requires-Dist: scipy >=1.14.0 ; extra == 'viz'
82
+ Requires-Dist: lxml; extra == "viz"
83
+ Requires-Dist: mediapy; extra == "viz"
84
+ Requires-Dist: mujoco>=3.0.0; extra == "viz"
85
+ Requires-Dist: scipy>=1.14.0; extra == "viz"
86
86
 
87
87
  # JaxSim
88
88
 
@@ -1,5 +1,5 @@
1
1
  jaxsim/__init__.py,sha256=ixsS4dYMPex2wOUUp_rkPnwrPhYzkRh1xO_YuMj3Cr4,2626
2
- jaxsim/_version.py,sha256=mvgmTuGOMpNiwe4eR061Vc1ZiYsIU9hMLQKJGhaZkFI,426
2
+ jaxsim/_version.py,sha256=yzghig35pfSMhh0vIsIOBH-eqIGj9fpA6g-9yTfJMlQ,426
3
3
  jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
4
4
  jaxsim/logging.py,sha256=c4zhwBKf9eAYAHVp62kTEllqdsZgh0K-kPKVy8L3elU,1584
5
5
  jaxsim/typing.py,sha256=IbFx3UkEXi-cm7UBqMPi58rJAFV_HbZ9E_K4JwfNvVM,753
@@ -15,7 +15,7 @@ jaxsim/api/link.py,sha256=GlnY7LMne-siFyg9J49IZGhiPQzS9Uk6rzQ0jI8cD_E,18622
15
15
  jaxsim/api/model.py,sha256=EdSjpKXd4N72wYjg5o0wGKFxjVMyrXg6LnlPEi3JqnU,63094
16
16
  jaxsim/api/ode.py,sha256=NnLTBvpaT4kXnbjAghXIzLv9DTMJ8bele2iOlUQDv3Q,11028
17
17
  jaxsim/api/ode_data.py,sha256=9YZX-SK_KJtoIqG-zYWZsQInb2NA_LtxDn-jtLqm_3U,19759
18
- jaxsim/api/references.py,sha256=6DG1V7dYLf088vSOOsvYSY-KNkzirD0zDJGohJbdLqc,21060
18
+ jaxsim/api/references.py,sha256=XOVKuQXRmjPoP-T5JWGSbqIGX5DzOkeGafqRpj0ZQEM,20771
19
19
  jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
20
20
  jaxsim/integrators/common.py,sha256=GqiyKTrAozuR6RuvVWdPF7locZQAXSEDY2AjTKpFGYM,20149
21
21
  jaxsim/integrators/fixed_step.py,sha256=KpjRd6hHtapxDoo6D1kyDrVDSHnke2TepI5grFH7_bM,2693
@@ -47,9 +47,9 @@ jaxsim/parsers/rod/utils.py,sha256=5DsF3OeePZGidOJ5GiFSZx-51uIdnFvMW9EK6SgOW6Q,5
47
47
  jaxsim/rbda/__init__.py,sha256=H7DhXpxkPOi9lpUvg31IMHFfRafke1UoJLc5GQIdyhA,387
48
48
  jaxsim/rbda/aba.py,sha256=w7ciyxB0IsmueatT0C7PcBQEl9dyiH9oqJgIi3xeTUE,8983
49
49
  jaxsim/rbda/collidable_points.py,sha256=Rmf1DhflhOTYh9mDalv0agS0CGSbmfoOybwP2KzKuJ0,4883
50
- jaxsim/rbda/crba.py,sha256=NhtZO48OUKKor7ddY7mB7h7a6idrmOyf0Vy4p7UCCgI,4724
51
- jaxsim/rbda/forward_kinematics.py,sha256=FmqhZD0hQpIuUlmpzlA-5b7EYBaLZarrhiSZ782aC3E,3357
52
- jaxsim/rbda/jacobian.py,sha256=I6mrlkk7Cpq3CE7k_tajOHCbT6vf2pW6vMS0TKNCnng,10725
50
+ jaxsim/rbda/crba.py,sha256=zJSiHKRvNU98z2tT9prrWR4VU9wIZQWFwEut7mua6as,5044
51
+ jaxsim/rbda/forward_kinematics.py,sha256=2GmEoWsrioVl_SAbKRKfhOLz57pY4aR81PKRdulqStA,3458
52
+ jaxsim/rbda/jacobian.py,sha256=p0EV_8cLzLVV-93VKznT7VPuRj8W7h7rQWkPlWJXfCA,11023
53
53
  jaxsim/rbda/rnea.py,sha256=LGXD6s3NigaVy4-WxoROjnbKLZcUoyFmS9UNu_4ldjo,7568
54
54
  jaxsim/rbda/utils.py,sha256=eeT21Y4DiiyhrdF0lUE_VvRuwru5-rR7yOlOlWzCCWE,5381
55
55
  jaxsim/rbda/contacts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -61,8 +61,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
61
61
  jaxsim/utils/jaxsim_dataclass.py,sha256=fLl1tY3DDb3lpIhG6BPqA5W34hM84oFzL-5cuz8k-68,11379
62
62
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
63
63
  jaxsim/utils/wrappers.py,sha256=GOJQCJc5zwzoEGZB62wnWWGvUUQlXvDxz_A2Q-hFv7c,4027
64
- jaxsim-0.4.2.dev45.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
65
- jaxsim-0.4.2.dev45.dist-info/METADATA,sha256=gNR3PAFm00PuogCUoSG_ru548HM0k6IGQR6kSIIt2FI,17250
66
- jaxsim-0.4.2.dev45.dist-info/WHEEL,sha256=Rp8gFpivVLXx-k3U95ozHnQw8yDcPxmhOpn_Gx8d5nc,91
67
- jaxsim-0.4.2.dev45.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
68
- jaxsim-0.4.2.dev45.dist-info/RECORD,,
64
+ jaxsim-0.4.2.dev50.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
65
+ jaxsim-0.4.2.dev50.dist-info/METADATA,sha256=2Z9T8eA8kbRT6KsLPBUqmQLb6xvHBjLp_QI2yLMHWfk,17227
66
+ jaxsim-0.4.2.dev50.dist-info/WHEEL,sha256=nCVcAvsfA9TDtwGwhYaRrlPhTLV9m-Ga6mdyDtuwK18,91
67
+ jaxsim-0.4.2.dev50.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
68
+ jaxsim-0.4.2.dev50.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (72.0.0)
2
+ Generator: setuptools (73.0.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5