jaxsim 0.4.3.dev64__py3-none-any.whl → 0.4.3.dev70__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev64'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev64')
15
+ __version__ = version = '0.4.3.dev70'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev70')
jaxsim/api/data.py CHANGED
@@ -6,10 +6,11 @@ from collections.abc import Sequence
6
6
 
7
7
  import jax
8
8
  import jax.numpy as jnp
9
+ import jax.scipy.spatial.transform
9
10
  import jax_dataclasses
10
- import jaxlie
11
11
 
12
12
  import jaxsim.api as js
13
+ import jaxsim.math
13
14
  import jaxsim.rbda
14
15
  import jaxsim.typing as jtp
15
16
  from jaxsim.rbda.contacts.soft import SoftContacts
@@ -195,10 +196,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
195
196
  else:
196
197
  contacts_params = model.contact_model.parameters
197
198
 
198
- W_H_B = jaxlie.SE3.from_rotation_and_translation(
199
- translation=base_position,
200
- rotation=jaxlie.SO3(wxyz=base_quaternion),
201
- ).as_matrix()
199
+ W_H_B = jaxsim.math.Transform.from_quaternion_and_translation(
200
+ translation=base_position, quaternion=base_quaternion
201
+ )
202
202
 
203
203
  v_WB = JaxSimModelData.other_representation_to_inertial(
204
204
  array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
@@ -384,7 +384,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
384
384
  on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
385
385
  )
386
386
 
387
- return (W_Q_B if not dcm else jaxlie.SO3(wxyz=W_Q_B).as_matrix()).astype(float)
387
+ return (W_Q_B if not dcm else jaxsim.math.Quaternion.to_dcm(W_Q_B)).astype(
388
+ float
389
+ )
388
390
 
389
391
  @jax.jit
390
392
  def base_transform(self) -> jtp.Matrix:
@@ -737,6 +739,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
737
739
  )
738
740
 
739
741
 
742
+ @functools.partial(jax.jit, static_argnames=["velocity_representation", "base_rpy_seq"])
740
743
  def random_model_data(
741
744
  model: js.model.JaxSimModel,
742
745
  *,
@@ -746,6 +749,11 @@ def random_model_data(
746
749
  jtp.FloatLike | Sequence[jtp.FloatLike],
747
750
  jtp.FloatLike | Sequence[jtp.FloatLike],
748
751
  ] = ((-1, -1, 0.5), 1.0),
752
+ base_rpy_bounds: tuple[
753
+ jtp.FloatLike | Sequence[jtp.FloatLike],
754
+ jtp.FloatLike | Sequence[jtp.FloatLike],
755
+ ] = (-jnp.pi, jnp.pi),
756
+ base_rpy_seq: str = "XYZ",
749
757
  joint_pos_bounds: (
750
758
  tuple[
751
759
  jtp.FloatLike | Sequence[jtp.FloatLike],
@@ -778,6 +786,10 @@ def random_model_data(
778
786
  key: The random key.
779
787
  velocity_representation: The velocity representation to use.
780
788
  base_pos_bounds: The bounds for the base position.
789
+ base_rpy_bounds:
790
+ The bounds for the Euler angles used to build the base orientation.
791
+ base_rpy_seq:
792
+ The sequence of axes for rotation (using `Rotation` from scipy).
781
793
  joint_pos_bounds:
782
794
  The bounds for the joint positions (reading the joint limits if None).
783
795
  base_vel_lin_bounds: The bounds for the base linear velocity.
@@ -794,6 +806,8 @@ def random_model_data(
794
806
 
795
807
  p_min = jnp.array(base_pos_bounds[0], dtype=float)
796
808
  p_max = jnp.array(base_pos_bounds[1], dtype=float)
809
+ rpy_min = jnp.array(base_rpy_bounds[0], dtype=float)
810
+ rpy_max = jnp.array(base_rpy_bounds[1], dtype=float)
797
811
  v_min = jnp.array(base_vel_lin_bounds[0], dtype=float)
798
812
  v_max = jnp.array(base_vel_lin_bounds[1], dtype=float)
799
813
  ω_min = jnp.array(base_vel_ang_bounds[0], dtype=float)
@@ -819,9 +833,14 @@ def random_model_data(
819
833
  key=k1, shape=(3,), minval=p_min, maxval=p_max
820
834
  )
821
835
 
822
- physics_model_state.base_quaternion = jaxlie.SO3.from_rpy_radians(
823
- *jax.random.uniform(key=k2, shape=(3,), minval=0, maxval=2 * jnp.pi)
824
- ).wxyz
836
+ physics_model_state.base_quaternion = jaxsim.math.Quaternion.to_wxyz(
837
+ xyzw=jax.scipy.spatial.transform.Rotation.from_euler(
838
+ seq=base_rpy_seq,
839
+ angles=jax.random.uniform(
840
+ key=k2, shape=(3,), minval=rpy_min, maxval=rpy_max
841
+ ),
842
+ ).as_quat()
843
+ )
825
844
 
826
845
  if model.number_of_joints() > 0:
827
846
 
jaxsim/api/model.py CHANGED
@@ -1747,14 +1747,18 @@ def link_contact_forces(
1747
1747
  data: The data of the considered model.
1748
1748
 
1749
1749
  Returns:
1750
- A (nL, 6) array containing the stacked 6D contact forces of the links,
1750
+ A `(nL, 6)` array containing the stacked 6D contact forces of the links,
1751
1751
  expressed in the frame corresponding to the active representation.
1752
1752
  """
1753
1753
 
1754
+ # Note: the following code should be kept in sync with the function
1755
+ # `jaxsim.api.ode.system_velocity_dynamics`. We cannot merge them since
1756
+ # there we also need to get aux_data.
1757
+
1754
1758
  # Compute the 6D forces applied to each collidable point expressed in the
1755
1759
  # inertial frame.
1756
1760
  with data.switch_velocity_representation(VelRepr.Inertial):
1757
- W_f_Ci = js.contact.collidable_point_forces(model=model, data=data)
1761
+ W_f_C = js.contact.collidable_point_forces(model=model, data=data)
1758
1762
 
1759
1763
  # Construct the vector defining the parent link index of each collidable point.
1760
1764
  # We use this vector to sum the 6D forces of all collidable points rigidly
@@ -1763,29 +1767,28 @@ def link_contact_forces(
1763
1767
  model.kin_dyn_parameters.contact_parameters.body, dtype=int
1764
1768
  )
1765
1769
 
1770
+ # Create the mask that associates each collidable point with its parent link.
1771
+ # We use this mask to sum the forces of the collidable points into the right link.
1772
+ mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
1773
+ model.number_of_links()
1774
+ )
1775
+
1766
1776
  # Sum the forces of all collidable points rigidly attached to a body.
1767
- # Since the contact forces W_f_Ci are expressed in the world frame,
1777
+ # Since the contact forces W_f_C are expressed in the world frame,
1768
1778
  # we don't need any coordinate transformation.
1769
- W_f_Li = jax.vmap(
1770
- lambda nc: (
1771
- jnp.vstack(
1772
- jnp.equal(parent_link_index_of_collidable_points, nc).astype(int)
1773
- )
1774
- * W_f_Ci
1775
- ).sum(axis=0)
1776
- )(jnp.arange(model.number_of_links()))
1777
-
1778
- # Convert the 6D forces to the active representation.
1779
- f_Li = jax.vmap(
1780
- lambda W_f_L: data.inertial_to_other_representation(
1781
- array=W_f_L,
1782
- other_representation=data.velocity_representation,
1783
- transform=data.base_transform(),
1784
- is_force=True,
1785
- )
1786
- )(W_f_Li)
1779
+ W_f_L = mask.T @ W_f_C
1780
+
1781
+ # Create a references object to store the link forces.
1782
+ references = js.references.JaxSimModelReferences.build(
1783
+ model=model, link_forces=W_f_L, velocity_representation=VelRepr.Inertial
1784
+ )
1785
+
1786
+ # Use the references object to convert the link forces to the velocity
1787
+ # representation of data.
1788
+ with references.switch_velocity_representation(data.velocity_representation):
1789
+ f_L = references.link_forces(model=model, data=data)
1787
1790
 
1788
- return f_Li
1791
+ return f_L
1789
1792
 
1790
1793
 
1791
1794
  # ======
jaxsim/api/ode.py CHANGED
@@ -132,9 +132,16 @@ def system_velocity_dynamics(
132
132
  # with the terrain.
133
133
  W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float)
134
134
 
135
+ # Initialize a dictionary of auxiliary data.
136
+ # This dictionary is used to store additional data computed by the contact model.
135
137
  aux_data = {}
138
+
136
139
  if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
137
140
 
141
+ # Note: the following code should be kept in sync with the function
142
+ # `jaxsim.api.model.link_contact_forces`. We cannot merge them since
143
+ # here we also need to get aux_data.
144
+
138
145
  # Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
139
146
  # along with contact-specific auxiliary states.
140
147
  with data.switch_velocity_representation(VelRepr.Inertial):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev64
3
+ Version: 0.4.3.dev70
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -1,5 +1,5 @@
1
1
  jaxsim/__init__.py,sha256=bSbpggIz5aG6QuGZLa0V2EfHjAOeucMxi-vIYxzLmN8,2788
2
- jaxsim/_version.py,sha256=lLNskxtfHW1HqvnLRuhux3LlK89fMiZFUWknSYopw7k,426
2
+ jaxsim/_version.py,sha256=TimgvJoa-WOzfWYDSSYt46dywfeE4QnnNq6VHB1jyaQ,426
3
3
  jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
4
4
  jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
5
5
  jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
@@ -7,13 +7,13 @@ jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
7
7
  jaxsim/api/com.py,sha256=m-p3EJDhpnMTlXKplfbZE_aH9NqX_VyLlAE3vUhc6l4,13642
8
8
  jaxsim/api/common.py,sha256=SNgxq42r6eF_-aPszvOjUYkGwXOzz4hKmhDwEUkscFQ,6650
9
9
  jaxsim/api/contact.py,sha256=C_PgMjWYYiqpA7Oz3IxHeFgrp855-xG6AQr6Ze98CtI,21863
10
- jaxsim/api/data.py,sha256=mFUw2mj8AIXduW6HnkGN7eooZHfJhwnWbtYZfLF6gk4,28206
10
+ jaxsim/api/data.py,sha256=QldUHniJqKrdNtAcXuRaS9UyeslJ0Rjvb17UA0Ca5Tw,29008
11
11
  jaxsim/api/frame.py,sha256=KS8A5wRfjxhe9NgcVo2QA516iP5zky7UVnWxG7nTa7c,12911
12
12
  jaxsim/api/joint.py,sha256=lksT1Doxz2jknHyhb4ls20z6f6dofpZSzBJtVacZXAE,7129
13
13
  jaxsim/api/kin_dyn_parameters.py,sha256=CcfSg5Mc8qb1mZeMQ4AK_ffZIsK5yOl7tu397pFhcDA,29369
14
14
  jaxsim/api/link.py,sha256=qPRtc8qqMRjZxUCZYXJMygbB6huDXBfIT1b1b8Durkw,18631
15
- jaxsim/api/model.py,sha256=K0q8-j-04f6B3MEXsctDGtWiuWlN3HbDrsS7zoPYStk,65871
16
- jaxsim/api/ode.py,sha256=VuOLvCFoyGLmhNf2vFP5BI9BAPz78V_RW5tJ4hrizsw,13041
15
+ jaxsim/api/model.py,sha256=TLjgacgTXm-2YRGDA0Id9pe9nxIem28KoAls6Tdk9WM,66241
16
+ jaxsim/api/ode.py,sha256=AE_obhb1u_xMGs8b4OlOXTvvQrbXOODqIDUWV5VGrJI,13376
17
17
  jaxsim/api/ode_data.py,sha256=7RSoBhfCJdP6P9InQbDwdBVpClPMMuetewI-6AWm-_0,20276
18
18
  jaxsim/api/references.py,sha256=XOVKuQXRmjPoP-T5JWGSbqIGX5DzOkeGafqRpj0ZQEM,20771
19
19
  jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
@@ -63,8 +63,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
63
63
  jaxsim/utils/jaxsim_dataclass.py,sha256=FSiUvdnq4Y1T9Jaa_mw4ZBQJe8H7deLr3Kupxtlh4iI,11322
64
64
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
65
65
  jaxsim/utils/wrappers.py,sha256=Fh82ZcaFi5fUnByyFLnmumaobsu1hJIvFdopUVzJ1ps,4052
66
- jaxsim-0.4.3.dev64.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
67
- jaxsim-0.4.3.dev64.dist-info/METADATA,sha256=0-JS1eJjFMSaMzwqbCSpWYU2GcrZkxT1LBDo7lhWICo,17276
68
- jaxsim-0.4.3.dev64.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
69
- jaxsim-0.4.3.dev64.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
70
- jaxsim-0.4.3.dev64.dist-info/RECORD,,
66
+ jaxsim-0.4.3.dev70.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
67
+ jaxsim-0.4.3.dev70.dist-info/METADATA,sha256=G6d-f4k63c6A1nozzlueaBz-fSgvz4zfpQIBecmGAiA,17276
68
+ jaxsim-0.4.3.dev70.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
69
+ jaxsim-0.4.3.dev70.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
70
+ jaxsim-0.4.3.dev70.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (74.1.2)
2
+ Generator: setuptools (75.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5