jaxsim 0.2.1.dev101__py3-none-any.whl → 0.2.1.dev113__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.1.dev101'
16
- __version_tuple__ = version_tuple = (0, 2, 1, 'dev101')
15
+ __version__ = version = '0.2.1.dev113'
16
+ __version_tuple__ = version_tuple = (0, 2, 1, 'dev113')
jaxsim/api/common.py CHANGED
@@ -87,7 +87,8 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
87
87
  array: jtp.Array,
88
88
  other_representation: VelRepr,
89
89
  transform: jtp.Matrix,
90
- is_force: bool = False,
90
+ *,
91
+ is_force: bool,
91
92
  ) -> jtp.Array:
92
93
  r"""
93
94
  Convert a 6D quantity from inertial-fixed to another representation.
@@ -153,7 +154,8 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
153
154
  array: jtp.Array,
154
155
  other_representation: VelRepr,
155
156
  transform: jtp.Matrix,
156
- is_force: bool = False,
157
+ *,
158
+ is_force: bool,
157
159
  ) -> jtp.Array:
158
160
  r"""
159
161
  Convert a 6D quantity from another representation to inertial-fixed.
jaxsim/api/model.py CHANGED
@@ -451,27 +451,36 @@ def generalized_free_floating_jacobian(
451
451
  )
452
452
 
453
453
  # Compute the doubly-left free-floating full jacobian.
454
- B_J_full_WX_B, _ = jaxsim.rbda.jacobian_full_doubly_left(
454
+ B_J_full_WX_B, B_H_L = jaxsim.rbda.jacobian_full_doubly_left(
455
455
  model=model,
456
456
  joint_positions=data.joint_positions(),
457
457
  )
458
458
 
459
- # Update the input velocity representation such that `J_WL_I @ I_ν`.
459
+ # ======================================================================
460
+ # Update the input velocity representation such that v_WL = J_WL_I @ I_ν
461
+ # ======================================================================
462
+
460
463
  match data.velocity_representation:
464
+
461
465
  case VelRepr.Inertial:
466
+
462
467
  W_H_B = data.base_transform()
463
468
  B_X_W = jaxlie.SE3.from_matrix(W_H_B).inverse().adjoint()
469
+
464
470
  B_J_full_WX_I = B_J_full_WX_W = B_J_full_WX_B @ jax.scipy.linalg.block_diag(
465
471
  B_X_W, jnp.eye(model.dofs())
466
472
  )
467
473
 
468
474
  case VelRepr.Body:
475
+
469
476
  B_J_full_WX_I = B_J_full_WX_B
470
477
 
471
478
  case VelRepr.Mixed:
479
+
472
480
  W_R_B = data.base_orientation(dcm=True)
473
481
  BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
474
482
  B_X_BW = jaxlie.SE3.from_matrix(BW_H_B).inverse().adjoint()
483
+
475
484
  B_J_full_WX_I = B_J_full_WX_BW = (
476
485
  B_J_full_WX_B
477
486
  @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
@@ -480,32 +489,61 @@ def generalized_free_floating_jacobian(
480
489
  case _:
481
490
  raise ValueError(data.velocity_representation)
482
491
 
483
- # Update the output velocity representation such that `O_v_WL_I = O_J_WL_I @ I_ν`.
492
+ # ====================================================================
493
+ # Create stacked Jacobian for each link by filtering the full Jacobian
494
+ # ====================================================================
495
+
496
+ κ_bool = model.kin_dyn_parameters.support_body_array_bool
497
+
498
+ # Keep only the columns of the full Jacobian corresponding to the support
499
+ # body array of each link.
500
+ B_J_WL_I = jax.vmap(
501
+ lambda κ: jnp.where(
502
+ jnp.hstack([jnp.ones(5), κ]), B_J_full_WX_I, jnp.zeros_like(B_J_full_WX_I)
503
+ )
504
+ )(κ_bool)
505
+
506
+ # =======================================================================
507
+ # Update the output velocity representation such that O_v_WL = O_J_WL @ ν
508
+ # =======================================================================
509
+
484
510
  match output_vel_repr:
511
+
485
512
  case VelRepr.Inertial:
513
+
486
514
  W_H_B = data.base_transform()
487
- W_X_B = jaxlie.SE3.from_matrix(W_H_B).adjoint()
488
- O_J_full_WX_I = W_J_full_WX_I = W_X_B @ B_J_full_WX_I
515
+ W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B)
516
+
517
+ O_J_WL_I = W_J_WL_I = jax.vmap(lambda B_J_WL_I: W_X_B @ B_J_WL_I)(B_J_WL_I)
489
518
 
490
519
  case VelRepr.Body:
491
- O_J_full_WX_I = B_J_full_WX_I
520
+
521
+ O_J_WL_I = L_J_WL_I = jax.vmap(
522
+ lambda B_H_L, B_J_WL_I: jaxsim.math.Adjoint.from_transform(
523
+ B_H_L, inverse=True
524
+ )
525
+ @ B_J_WL_I
526
+ )(B_H_L, B_J_WL_I)
492
527
 
493
528
  case VelRepr.Mixed:
494
- W_R_B = data.base_orientation(dcm=True)
495
- BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B)
496
- BW_X_B = jaxlie.SE3.from_matrix(BW_H_B).adjoint()
497
- O_J_full_WX_I = BW_J_full_WX_I = BW_X_B @ B_J_full_WX_I
498
529
 
499
- case _:
500
- raise ValueError(output_vel_repr)
530
+ W_H_B = data.base_transform()
501
531
 
502
- κ_bool = model.kin_dyn_parameters.support_body_array_bool
532
+ LW_H_L = jax.vmap(
533
+ lambda B_H_L: (W_H_B @ B_H_L).at[0:3, 3].set(jnp.zeros(3))
534
+ )(B_H_L)
503
535
 
504
- O_J_WL_I = jax.vmap(
505
- lambda κ: jnp.where(
506
- jnp.hstack([jnp.ones(5), κ]), O_J_full_WX_I, jnp.zeros_like(O_J_full_WX_I)
507
- )
508
- )(κ_bool)
536
+ LW_H_B = jax.vmap(
537
+ lambda LW_H_L, B_H_L: LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
538
+ )(LW_H_L, B_H_L)
539
+
540
+ O_J_WL_I = LW_J_WL_I = jax.vmap(
541
+ lambda LW_H_B, B_J_WL_I: jaxsim.math.Adjoint.from_transform(LW_H_B)
542
+ @ B_J_WL_I
543
+ )(LW_H_B, B_J_WL_I)
544
+
545
+ case _:
546
+ raise ValueError(output_vel_repr)
509
547
 
510
548
  return O_J_WL_I
511
549
 
jaxsim/api/ode.py CHANGED
@@ -113,7 +113,7 @@ def system_velocity_dynamics(
113
113
  ).astype(float)
114
114
 
115
115
  # Build link forces if not provided
116
- W_f_L = (
116
+ O_f_L = (
117
117
  jnp.atleast_2d(link_forces.squeeze())
118
118
  if link_forces is not None
119
119
  else jnp.zeros((model.number_of_links(), 6))
@@ -125,7 +125,7 @@ def system_velocity_dynamics(
125
125
 
126
126
  # Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
127
127
  # with the terrain.
128
- W_f_Li_terrain = jnp.zeros_like(W_f_L).astype(float)
128
+ W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float)
129
129
 
130
130
  # Initialize the 6D contact forces W_f ∈ ℝ^{n_c × 6} applied to collidable points,
131
131
  # expressed in the world frame.
@@ -183,7 +183,7 @@ def system_velocity_dynamics(
183
183
 
184
184
  # Compute the joint friction torque
185
185
  τ_friction = -(
186
- jnp.diag(kc) @ jnp.sign(data.state.physics_model.joint_positions)
186
+ jnp.diag(kc) @ jnp.sign(data.state.physics_model.joint_velocities)
187
187
  + jnp.diag(kv) @ data.state.physics_model.joint_velocities
188
188
  )
189
189
 
@@ -194,6 +194,17 @@ def system_velocity_dynamics(
194
194
  # Compute the total joint forces
195
195
  τ_total = τ + τ_friction + τ_position_limit
196
196
 
197
+ references = js.references.JaxSimModelReferences.build(
198
+ model=model,
199
+ joint_force_references=τ_total,
200
+ link_forces=O_f_L,
201
+ data=data,
202
+ velocity_representation=data.velocity_representation,
203
+ )
204
+
205
+ with references.switch_velocity_representation(VelRepr.Inertial):
206
+ W_f_L = references.link_forces(model=model, data=data)
207
+
197
208
  # Compute the total external 6D forces applied to the links
198
209
  W_f_L_total = W_f_L + W_f_Li_terrain
199
210
 
jaxsim/api/references.py CHANGED
@@ -202,17 +202,22 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
202
202
  if not_tracing(self.input.physics_model.f_ext) and not data.valid(model=model):
203
203
  raise ValueError("The provided data is not valid for the model")
204
204
 
205
- # Helper function to convert a single 6D force to the active representation.
206
- def convert(f_L: jtp.Vector) -> jtp.Vector:
207
- return JaxSimModelReferences.inertial_to_other_representation(
208
- array=f_L,
209
- other_representation=self.velocity_representation,
210
- transform=data.base_transform(),
211
- is_force=True,
212
- )
205
+ # Helper function to convert a single 6D force to the active representation
206
+ # considering as body the link (i.e. L_f_L and LW_f_L).
207
+ def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix:
208
+
209
+ return jax.vmap(
210
+ lambda W_f_L, W_H_L: JaxSimModelReferences.inertial_to_other_representation(
211
+ array=W_f_L,
212
+ other_representation=self.velocity_representation,
213
+ transform=W_H_L,
214
+ is_force=True,
215
+ )
216
+ )(W_f_L, W_H_L)
213
217
 
214
- # Convert to the desired representation.
215
- f_L = jax.vmap(convert)(W_f_L[link_idxs, :])
218
+ # The f_L output is either L_f_L or LW_f_L, depending on the representation.
219
+ W_H_L = js.model.forward_kinematics(model=model, data=data)
220
+ f_L = convert(W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :])
216
221
 
217
222
  return f_L
218
223
 
@@ -319,7 +324,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
319
324
  forces: jtp.MatrixLike,
320
325
  model: js.model.JaxSimModel | None = None,
321
326
  data: js.data.JaxSimModelData | None = None,
322
- link_names: tuple[str, ...] | None = None,
327
+ link_names: tuple[str, ...] | str | None = None,
323
328
  additive: bool = False,
324
329
  ) -> Self:
325
330
  """
@@ -345,7 +350,7 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
345
350
  Then, we always convert and store forces in inertial-fixed representation.
346
351
  """
347
352
 
348
- f_L = jnp.array(forces)
353
+ f_L = jnp.atleast_2d(forces).astype(float)
349
354
 
350
355
  # Helper function to replace the link forces.
351
356
  def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences:
@@ -380,6 +385,15 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
380
385
 
381
386
  # If we have the model, we can extract the link names if not provided.
382
387
  link_names = link_names if link_names is not None else model.link_names()
388
+
389
+ # Make sure that the link names are a tuple if they are provided by the user.
390
+ link_names = (link_names,) if isinstance(link_names, str) else link_names
391
+
392
+ if len(link_names) != f_L.shape[0]:
393
+ msg = "The number of link names ({}) must match the number of forces ({})"
394
+ raise ValueError(msg.format(len(link_names), f_L.shape[0]))
395
+
396
+ # Extract the link indices.
383
397
  link_idxs = js.link.names_to_idxs(link_names=link_names, model=model)
384
398
 
385
399
  # Compute the bias depending on whether we either set or add the link forces.
@@ -405,16 +419,24 @@ class JaxSimModelReferences(js.common.ModelDataWithVelocityRepresentation):
405
419
  if not_tracing(forces) and not data.valid(model=model):
406
420
  raise ValueError("The provided data is not valid for the model")
407
421
 
408
- # Helper function to convert a single 6D force to the inertial representation.
409
- def convert(f_L: jtp.Vector) -> jtp.Vector:
410
- return JaxSimModelReferences.other_representation_to_inertial(
411
- array=f_L,
412
- other_representation=self.velocity_representation,
413
- transform=data.base_transform(),
414
- is_force=True,
415
- )
422
+ # Helper function to convert a single 6D force to the inertial representation
423
+ # considering as body the link (i.e. L_f_L and LW_f_L).
424
+ def convert_using_link_frame(
425
+ f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike
426
+ ) -> jtp.Matrix:
427
+
428
+ return jax.vmap(
429
+ lambda f_L, W_H_L: JaxSimModelReferences.other_representation_to_inertial(
430
+ array=f_L,
431
+ other_representation=self.velocity_representation,
432
+ transform=W_H_L,
433
+ is_force=True,
434
+ )
435
+ )(f_L, W_H_L)
416
436
 
417
- W_f_L = jax.vmap(convert)(f_L)
437
+ # The f_L input is either L_f_L or LW_f_L, depending on the representation.
438
+ W_H_L = js.model.forward_kinematics(model=model, data=data)
439
+ W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :])
418
440
 
419
441
  return replace(
420
442
  forces=self.input.physics_model.f_ext.at[link_idxs, :].set(W_f0_L + W_f_L)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.2.1.dev101
3
+ Version: 0.2.1.dev113
4
4
  Home-page: https://github.com/ami-iit/jaxsim
5
5
  Author: Diego Ferigo
6
6
  Author-email: diego.ferigo@iit.it
@@ -1,20 +1,20 @@
1
1
  jaxsim/__init__.py,sha256=OcrfoYS1DGcmAGqu2AqlCTiUVxcpi-IsVwcr_16x74Q,1789
2
- jaxsim/_version.py,sha256=G1D0JLzyL0ACAlwNnaPkzuJCqBQM59tregHl11EAaSo,428
2
+ jaxsim/_version.py,sha256=tWgdOeciml9_rSgehjRMlzuVy3GjT9rU5sEmQnyWbSM,428
3
3
  jaxsim/logging.py,sha256=c4zhwBKf9eAYAHVp62kTEllqdsZgh0K-kPKVy8L3elU,1584
4
4
  jaxsim/typing.py,sha256=MeuOCQtLAr-sPkvB_sU8FtwGNRirz1auCwIgRC-QZl8,646
5
5
  jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
6
6
  jaxsim/api/com.py,sha256=Yof6otFi-mLWAs1rqjmeNJTOWIH9gn7BdU5EIjiL6Ts,13481
7
- jaxsim/api/common.py,sha256=DV-WZG28sikXopNv458aYvpLjmiAtFr5LRscOwXusuk,6640
7
+ jaxsim/api/common.py,sha256=bqQ__pIQZbh-j8rkoHUkYHAgGiJnDzjHG-q4Ny0OOYQ,6646
8
8
  jaxsim/api/contact.py,sha256=Cvr-EfQtHP3nymtWdo-9WWU24Bkta-2Pp3nKsdjo6uc,12778
9
9
  jaxsim/api/data.py,sha256=xfKJz6Rw0YTk-EHCGiT8BFQrs_ggOz01lRi1Qh1mb28,27256
10
10
  jaxsim/api/frame.py,sha256=0YXOrGmx3cSQqa4_Ky-n6zyup3I3xvXNEgub-Bc5xUw,6222
11
11
  jaxsim/api/joint.py,sha256=-5DogPg4g4mmLckyVIVNjwv-Rxz0IWS7_md9nDlhPWA,4581
12
12
  jaxsim/api/kin_dyn_parameters.py,sha256=zMca7OmCsCWK_cavLTSZSeYh9Qu1-409cdsyWvWPAUQ,26090
13
13
  jaxsim/api/link.py,sha256=rypTwkMf9HJ5UuAtHRJh0LqqdJWcLKTtTjWcjduEsF0,9842
14
- jaxsim/api/model.py,sha256=Ii17tBzCkMHCa_G7plKPzNcW-bj3QGuGI9jCklPQStM,54294
15
- jaxsim/api/ode.py,sha256=6l-6i2YHagsQvR8Ac-_fmO6P0hBVT6NkHhwXnrdITEg,9785
14
+ jaxsim/api/model.py,sha256=1HlQ5FMzeJAk-cE1pmELgVjzMYUX9-iipw3N4WssAL4,55435
15
+ jaxsim/api/ode.py,sha256=BfvV_14uu0szWecoDiV8rTu-dvSFLK7eyrO38ZqHB_w,10157
16
16
  jaxsim/api/ode_data.py,sha256=D6FzMkvY_qNuoFEImyp7sxAk-0pJOd3oZeSr9bBTcLk,23089
17
- jaxsim/api/references.py,sha256=Lvskf17r619KKxwCJP7hAAty2kaXgDXJX1uKqoDIDgo,15483
17
+ jaxsim/api/references.py,sha256=UA6kSQVBoq-bXSo99EOELf-_MD5MTy2zS0GtG3wQ410,16618
18
18
  jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
19
19
  jaxsim/integrators/common.py,sha256=9HXRVFo95Mpt6RcVhBrOfvOO7mDxqbkXeg_lKUibEFY,20693
20
20
  jaxsim/integrators/fixed_step.py,sha256=JXaEyEzfSiYea0GnPA7l27J3X0YPB0e25D4qfrxAvzQ,2766
@@ -58,8 +58,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
58
58
  jaxsim/utils/jaxsim_dataclass.py,sha256=h26timZ_XrBL_Q_oymv-DkQd-EcUiHn8QexAaZXBY9c,11396
59
59
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
60
60
  jaxsim/utils/wrappers.py,sha256=EJMcblYKUjxw9HJShVf81Ig3pHUJno6Dx6h-RnY--wM,2040
61
- jaxsim-0.2.1.dev101.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
62
- jaxsim-0.2.1.dev101.dist-info/METADATA,sha256=e2Rvemht7C6PBWz899beW2a-92wHtkPzxKpaRelY-eg,9745
63
- jaxsim-0.2.1.dev101.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
64
- jaxsim-0.2.1.dev101.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
65
- jaxsim-0.2.1.dev101.dist-info/RECORD,,
61
+ jaxsim-0.2.1.dev113.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
62
+ jaxsim-0.2.1.dev113.dist-info/METADATA,sha256=NH5olEy-GNgTb5Ibe-q3FH2bu_WOuY0hbQBJZCcOiS4,9745
63
+ jaxsim-0.2.1.dev113.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
64
+ jaxsim-0.2.1.dev113.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
65
+ jaxsim-0.2.1.dev113.dist-info/RECORD,,