jaxsim 0.6.2.dev233__py3-none-any.whl → 0.6.2.dev238__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.6.2.dev233'
21
- __version_tuple__ = version_tuple = (0, 6, 2, 'dev233')
20
+ __version__ = version = '0.6.2.dev238'
21
+ __version_tuple__ = version_tuple = (0, 6, 2, 'dev238')
jaxsim/api/data.py CHANGED
@@ -60,7 +60,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
60
60
  _link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None)
61
61
 
62
62
  # Extended state for soft and rigid contact models.
63
- contact_state: dict[str, jtp.Array] = dataclasses.field(default=None)
63
+ contact_state: dict[str, jtp.Array] = dataclasses.field(default_factory=dict)
64
64
 
65
65
  @staticmethod
66
66
  def build(
@@ -174,7 +174,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
174
174
  contact_state = contact_state or {}
175
175
 
176
176
  if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts):
177
- contact_state.setdefault(
177
+ contact_state["tangential_deformation"] = contact_state.get(
178
178
  "tangential_deformation",
179
179
  jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.point),
180
180
  )
@@ -420,11 +420,6 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
420
420
  Replace the attributes of the `JaxSimModelData` object.
421
421
  """
422
422
 
423
- # Extract the batch size.
424
- batch_size = (
425
- self._base_transform.shape[0] if self._base_transform.ndim > 2 else 1
426
- )
427
-
428
423
  if joint_positions is None:
429
424
  joint_positions = self.joint_positions
430
425
  if joint_velocities is None:
@@ -437,10 +432,11 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
437
432
  contact_state = self.contact_state
438
433
 
439
434
  if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts):
440
- contact_state.setdefault(
441
- "tangential_deformation",
442
- jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.point),
443
- )
435
+ contact_state = {
436
+ "tangential_deformation": jnp.zeros_like(
437
+ contact_state["tangential_deformation"]
438
+ )
439
+ }
444
440
 
445
441
  # Normalize the quaternion to avoid numerical issues.
446
442
  base_quaternion_norm = jaxsim.math.safe_norm(
@@ -450,20 +446,18 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
450
446
  base_quaternion_norm == 0, 1.0, base_quaternion_norm
451
447
  )
452
448
 
453
- joint_positions = jnp.atleast_1d(joint_positions.squeeze()).astype(float)
454
- joint_velocities = jnp.atleast_1d(joint_velocities.squeeze()).astype(float)
455
- base_quaternion = jnp.atleast_1d(base_quaternion.squeeze()).astype(float)
456
- base_position = jnp.atleast_1d(base_position.squeeze()).astype(float)
449
+ joint_positions = jnp.atleast_2d(joint_positions.squeeze()).astype(float)
450
+ joint_velocities = jnp.atleast_2d(joint_velocities.squeeze()).astype(float)
451
+ base_quaternion = jnp.atleast_2d(base_quaternion.squeeze()).astype(float)
452
+ base_position = jnp.atleast_2d(base_position.squeeze()).astype(float)
457
453
 
458
454
  base_transform = jaxsim.math.Transform.from_quaternion_and_translation(
459
455
  translation=base_position, quaternion=base_quaternion
460
- )
456
+ ).reshape((-1, 4, 4))
461
457
 
462
458
  joint_transforms = jax.vmap(model.kin_dyn_parameters.joint_transforms)(
463
- joint_positions=jnp.broadcast_to(
464
- joint_positions, (batch_size, model.dofs())
465
- ),
466
- base_transform=jnp.broadcast_to(base_transform, (batch_size, 4, 4)),
459
+ joint_positions=joint_positions,
460
+ base_transform=base_transform,
467
461
  )
468
462
 
469
463
  if base_linear_velocity is None and base_angular_velocity is None:
@@ -494,27 +488,31 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
494
488
  jaxsim.rbda.forward_kinematics_model, in_axes=(None,)
495
489
  )(
496
490
  model,
497
- base_position=jnp.broadcast_to(base_position, (batch_size, 3)),
498
- base_quaternion=jnp.broadcast_to(base_quaternion, (batch_size, 4)),
499
- joint_positions=jnp.broadcast_to(
500
- joint_positions, (batch_size, model.dofs())
501
- ),
502
- joint_velocities=jnp.broadcast_to(
503
- joint_velocities, (batch_size, model.dofs())
504
- ),
505
- base_linear_velocity_inertial=jnp.broadcast_to(
506
- base_linear_velocity_inertial, (batch_size, 3)
507
- ),
508
- base_angular_velocity_inertial=jnp.broadcast_to(
509
- base_angular_velocity_inertial, (batch_size, 3)
491
+ base_position=base_position,
492
+ base_quaternion=base_quaternion,
493
+ joint_positions=joint_positions,
494
+ joint_velocities=joint_velocities,
495
+ base_linear_velocity_inertial=jnp.atleast_2d(base_linear_velocity_inertial),
496
+ base_angular_velocity_inertial=jnp.atleast_2d(
497
+ base_angular_velocity_inertial
510
498
  ),
511
499
  )
512
500
 
513
501
  # Adjust the output shapes.
514
- if batch_size == 1:
515
- link_transforms = link_transforms.reshape(self._link_transforms.shape)
516
- link_velocities = link_velocities.reshape(self._link_velocities.shape)
517
- joint_transforms = joint_transforms.reshape(self._joint_transforms.shape)
502
+ joint_positions = joint_positions.reshape(self._joint_positions.shape)
503
+ joint_velocities = joint_velocities.reshape(self._joint_velocities.shape)
504
+ base_quaternion = base_quaternion.reshape(self._base_quaternion.shape)
505
+ base_linear_velocity_inertial = base_linear_velocity_inertial.reshape(
506
+ self._base_linear_velocity.shape
507
+ )
508
+ base_angular_velocity_inertial = base_angular_velocity_inertial.reshape(
509
+ self._base_angular_velocity.shape
510
+ )
511
+ base_position = base_position.reshape(self._base_position.shape)
512
+ base_transform = base_transform.reshape(self._base_transform.shape)
513
+ joint_transforms = joint_transforms.reshape(self._joint_transforms.shape)
514
+ link_transforms = link_transforms.reshape(self._link_transforms.shape)
515
+ link_velocities = link_velocities.reshape(self._link_velocities.shape)
518
516
 
519
517
  return super().replace(
520
518
  _joint_positions=joint_positions,
@@ -477,8 +477,7 @@ class RelaxedRigidContacts(common.ContactModel):
477
477
  tol = solver_options.pop("tol")
478
478
  maxiter = solver_options.pop("maxiter")
479
479
 
480
- # Compute the 3D linear force in C[W] frame.
481
- solution, _ = run_optimization(
480
+ solve_fn = lambda *_: run_optimization(
482
481
  init_params=init_params,
483
482
  fun=objective,
484
483
  opt=optax.lbfgs(**solver_options),
@@ -486,6 +485,15 @@ class RelaxedRigidContacts(common.ContactModel):
486
485
  maxiter=maxiter,
487
486
  )
488
487
 
488
+ # Compute the 3D linear force in C[W] frame.
489
+ solution, _ = jax.lax.custom_linear_solve(
490
+ lambda x: A @ x,
491
+ -b,
492
+ solve=solve_fn,
493
+ symmetric=True,
494
+ has_aux=True,
495
+ )
496
+
489
497
  # Reshape the optimized solution to be a matrix of 3D contact forces.
490
498
  CW_fl_C = solution.reshape(-1, 3)
491
499
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxsim
3
- Version: 0.6.2.dev233
3
+ Version: 0.6.2.dev238
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -1,5 +1,5 @@
1
1
  jaxsim/__init__.py,sha256=b8dYoVXqtHxHcF56iM2xgKk78lsvmGrfDlvdwaGasgs,3388
2
- jaxsim/_version.py,sha256=sPJq_dpGjKmFrDc3jN3zh6Jvj0t6QyFFNuDkHRcOzlE,528
2
+ jaxsim/_version.py,sha256=eVqBRpQ8sW7XIxtY3zbtRorKBdbdS2mv3DLVp8TmLjI,528
3
3
  jaxsim/exceptions.py,sha256=MQ3LRMfVMX2-g3qYj7mUVNV9OLlIA48TANJegbcQyXI,2641
4
4
  jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
5
5
  jaxsim/typing.py,sha256=7msl8t5Jt09RNYfKdPJtpjLfWurldcycDappb045Eso,761
@@ -8,7 +8,7 @@ jaxsim/api/actuation_model.py,sha256=L8AzxIiEquWeG8UGGJaYr2Alt4dkkOROlbsCn9hUYik
8
8
  jaxsim/api/com.py,sha256=47a9SSaXY540RCkVnHodwLNUrodIfJIkguIYdSEQVwQ,13697
9
9
  jaxsim/api/common.py,sha256=yTaRXDYkXmISBOhZ93f9TssR0p4wq7qj7B6OsvYzRME,6942
10
10
  jaxsim/api/contact.py,sha256=dlKKDQUG-KQ5qQaYBv2NmZLDb1OnJdltZv8MWXkD_W0,20969
11
- jaxsim/api/data.py,sha256=aW6buyH3YszzEBqBXqBgHf6uaSilZcqU7PyP60qgLco,23336
11
+ jaxsim/api/data.py,sha256=KBDN7zFrNDSkry3Pow1LdBZ6kmfwk7dR5jaLA6DNAVc,23468
12
12
  jaxsim/api/frame.py,sha256=4wg6GsyBQgYhSvc-ry_31JsaL66sZt3TtgwjB7NrHmk,14583
13
13
  jaxsim/api/integrators.py,sha256=DgOnzLepy45e-TM6Infk8qfPXn0r8GubCdJQZmNLP8w,5269
14
14
  jaxsim/api/joint.py,sha256=AnqlNWmBOay-gsoo0y4AbfFQ2OCJm-8T1E0IMhZeLoY,7457
@@ -54,7 +54,7 @@ jaxsim/rbda/rnea.py,sha256=lMU7xxdPqGGzk0QwteB-IYjL4auHOpd78C1YqAXlp9s,7588
54
54
  jaxsim/rbda/utils.py,sha256=6JwEDQqLMsBX7CUmPYEhdPEscXmGbWVYg6xEriPOgvE,5587
55
55
  jaxsim/rbda/contacts/__init__.py,sha256=resrBkTdOA-1YMdcdUH2RATEhAf_Ye6MQNtjG3ClMYQ,371
56
56
  jaxsim/rbda/contacts/common.py,sha256=qVm3Ghoytg1HAeykNrYw5-4rQJ4Mv7h0Pk75ETzGXyc,9045
57
- jaxsim/rbda/contacts/relaxed_rigid.py,sha256=zoMbc1edTsbUBf3hVCUoJb-xTCrtQT4RXKaILFup3KI,21152
57
+ jaxsim/rbda/contacts/relaxed_rigid.py,sha256=RjeLF06Pp19qio447U9z5EdhdM6nyMh-ISQX_2-vdaE,21349
58
58
  jaxsim/rbda/contacts/rigid.py,sha256=ctJ4_4qSHaqGxCBMTYhvWWm8D4hJ-YdDiorJuURiNWw,17601
59
59
  jaxsim/rbda/contacts/soft.py,sha256=Ac9aWDdjAm55Mv9LLnEs3nj7hX_NvJMnicV35SbFLSY,15282
60
60
  jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
@@ -63,8 +63,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
63
63
  jaxsim/utils/jaxsim_dataclass.py,sha256=XzmZeIibcaOzaxpprsGSxH3UrM66PAO456rFV91sNXg,11453
64
64
  jaxsim/utils/tracing.py,sha256=Btwxdfhb7fJLk3r5PlQkGYj60Y2KbFT1gANGIA697FU,530
65
65
  jaxsim/utils/wrappers.py,sha256=3IMwydqFgmSPqeuUQ3PRmdhDc1IoT6XC23jPC_LjWXs,4175
66
- jaxsim-0.6.2.dev233.dist-info/licenses/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
67
- jaxsim-0.6.2.dev233.dist-info/METADATA,sha256=9kHFCW__ENh9MIjH0wu2LxivnchhSt1vXL4Obu_BUD4,19650
68
- jaxsim-0.6.2.dev233.dist-info/WHEEL,sha256=DK49LOLCYiurdXXOXwGJm6U4DkHkg4lcxjhqwRa0CP4,91
69
- jaxsim-0.6.2.dev233.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
70
- jaxsim-0.6.2.dev233.dist-info/RECORD,,
66
+ jaxsim-0.6.2.dev238.dist-info/licenses/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
67
+ jaxsim-0.6.2.dev238.dist-info/METADATA,sha256=RqW1Edan0yEe43Bp8uljqRkFdaTqvSffYxCg1bGKEjY,19650
68
+ jaxsim-0.6.2.dev238.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
69
+ jaxsim-0.6.2.dev238.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
70
+ jaxsim-0.6.2.dev238.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.0.2)
2
+ Generator: setuptools (78.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5