jaxsim 0.4.2.dev19__py3-none-any.whl → 0.4.2.dev28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.2.dev19'
16
- __version_tuple__ = version_tuple = (0, 4, 2, 'dev19')
15
+ __version__ = version = '0.4.2.dev28'
16
+ __version_tuple__ = version_tuple = (0, 4, 2, 'dev28')
jaxsim/api/contact.py CHANGED
@@ -8,6 +8,7 @@ import jax.numpy as jnp
8
8
  import jaxsim.api as js
9
9
  import jaxsim.terrain
10
10
  import jaxsim.typing as jtp
11
+ from jaxsim.math import Adjoint, Cross, Transform
11
12
  from jaxsim.rbda.contacts.soft import SoftContactsParams
12
13
 
13
14
  from .common import VelRepr
@@ -411,3 +412,168 @@ def jacobian(
411
412
  raise ValueError(output_vel_repr)
412
413
 
413
414
  return O_J_WC
415
+
416
+
417
+ @functools.partial(jax.jit, static_argnames=["output_vel_repr"])
418
+ def jacobian_derivative(
419
+ model: js.model.JaxSimModel,
420
+ data: js.data.JaxSimModelData,
421
+ *,
422
+ output_vel_repr: VelRepr | None = None,
423
+ ) -> jtp.Matrix:
424
+ r"""
425
+ Compute the derivative of the free-floating jacobian of the contact points.
426
+
427
+ Args:
428
+ model: The model to consider.
429
+ data: The data of the considered model.
430
+ output_vel_repr:
431
+ The output velocity representation of the free-floating jacobian derivative.
432
+
433
+ Returns:
434
+ The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the contact points.
435
+
436
+ Note:
437
+ The input representation of the free-floating jacobian derivative is the active
438
+ velocity representation.
439
+ """
440
+
441
+ output_vel_repr = (
442
+ output_vel_repr if output_vel_repr is not None else data.velocity_representation
443
+ )
444
+
445
+ # Get the index of the parent link and the position of the collidable point.
446
+ parent_link_idxs = jnp.array(model.kin_dyn_parameters.contact_parameters.body)
447
+ L_p_Ci = jnp.array(model.kin_dyn_parameters.contact_parameters.point)
448
+ contact_idxs = jnp.arange(L_p_Ci.shape[0])
449
+
450
+ # Get the transforms of all the parent links.
451
+ W_H_Li = js.model.forward_kinematics(model=model, data=data)
452
+
453
+ # =====================================================
454
+ # Compute quantities to adjust the input representation
455
+ # =====================================================
456
+
457
+ def compute_T(model: js.model.JaxSimModel, X: jtp.Matrix) -> jtp.Matrix:
458
+ In = jnp.eye(model.dofs())
459
+ T = jax.scipy.linalg.block_diag(X, In)
460
+ return T
461
+
462
+ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix:
463
+ On = jnp.zeros(shape=(model.dofs(), model.dofs()))
464
+ Ṫ = jax.scipy.linalg.block_diag(Ẋ, On)
465
+ return Ṫ
466
+
467
+ # Compute the operator to change the representation of ν, and its
468
+ # time derivative.
469
+ match data.velocity_representation:
470
+ case VelRepr.Inertial:
471
+ W_H_W = jnp.eye(4)
472
+ W_X_W = Adjoint.from_transform(transform=W_H_W)
473
+ W_Ẋ_W = jnp.zeros((6, 6))
474
+
475
+ T = compute_T(model=model, X=W_X_W)
476
+ Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W)
477
+
478
+ case VelRepr.Body:
479
+ W_H_B = data.base_transform()
480
+ W_X_B = Adjoint.from_transform(transform=W_H_B)
481
+ B_v_WB = data.base_velocity()
482
+ B_vx_WB = Cross.vx(B_v_WB)
483
+ W_Ẋ_B = W_X_B @ B_vx_WB
484
+
485
+ T = compute_T(model=model, X=W_X_B)
486
+ Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B)
487
+
488
+ case VelRepr.Mixed:
489
+ W_H_B = data.base_transform()
490
+ W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
491
+ W_X_BW = Adjoint.from_transform(transform=W_H_BW)
492
+ BW_v_WB = data.base_velocity()
493
+ BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
494
+ BW_vx_W_BW = Cross.vx(BW_v_W_BW)
495
+ W_Ẋ_BW = W_X_BW @ BW_vx_W_BW
496
+
497
+ T = compute_T(model=model, X=W_X_BW)
498
+ Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW)
499
+
500
+ case _:
501
+ raise ValueError(data.velocity_representation)
502
+
503
+ # =====================================================
504
+ # Compute quantities to adjust the output representation
505
+ # =====================================================
506
+
507
+ with data.switch_velocity_representation(VelRepr.Inertial):
508
+ # Compute the Jacobian of the parent link in inertial representation.
509
+ W_J_WL_W = js.model.generalized_free_floating_jacobian(
510
+ model=model,
511
+ data=data,
512
+ output_vel_repr=VelRepr.Inertial,
513
+ )
514
+ # Compute the Jacobian derivative of the parent link in inertial representation.
515
+ W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative(
516
+ model=model,
517
+ data=data,
518
+ output_vel_repr=VelRepr.Inertial,
519
+ )
520
+
521
+ # Get the Jacobian of the collidable points in the mixed representation.
522
+ with data.switch_velocity_representation(VelRepr.Mixed):
523
+ CW_J_WC_BW = jacobian(
524
+ model=model,
525
+ data=data,
526
+ output_vel_repr=VelRepr.Mixed,
527
+ )
528
+
529
+ def compute_O_J̇_WC_I(
530
+ L_p_C: jtp.Vector,
531
+ contact_idx: jtp.Int,
532
+ CW_J_WC_BW: jtp.Matrix,
533
+ W_H_L: jtp.Matrix,
534
+ ) -> jtp.Matrix:
535
+
536
+ parent_link_idx = parent_link_idxs[contact_idx]
537
+
538
+ match output_vel_repr:
539
+ case VelRepr.Inertial:
540
+ O_X_W = W_X_W = Adjoint.from_transform(transform=jnp.eye(4))
541
+ O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6))
542
+
543
+ case VelRepr.Body:
544
+ L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
545
+ W_H_C = W_H_L[parent_link_idx] @ L_H_C
546
+ O_X_W = C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
547
+ with data.switch_velocity_representation(VelRepr.Inertial):
548
+ W_nu = data.generalized_velocity()
549
+ W_v_WC = W_J_WL_W[parent_link_idx] @ W_nu
550
+ W_vx_WC = Cross.vx(W_v_WC)
551
+ O_Ẋ_W = C_Ẋ_W = -C_X_W @ W_vx_WC
552
+
553
+ case VelRepr.Mixed:
554
+ L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
555
+ W_H_C = W_H_L[parent_link_idx] @ L_H_C
556
+ W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))
557
+ CW_H_W = Transform.inverse(W_H_CW)
558
+ O_X_W = CW_X_W = Adjoint.from_transform(transform=CW_H_W)
559
+ with data.switch_velocity_representation(VelRepr.Mixed):
560
+ CW_v_WC = CW_J_WC_BW @ data.generalized_velocity()
561
+ W_v_W_CW = jnp.zeros(6).at[0:3].set(CW_v_WC[0:3])
562
+ W_vx_W_CW = Cross.vx(W_v_W_CW)
563
+ O_Ẋ_W = CW_Ẋ_W = -CW_X_W @ W_vx_W_CW
564
+
565
+ case _:
566
+ raise ValueError(output_vel_repr)
567
+
568
+ O_J̇_WC_I = jnp.zeros(shape=(6, 6 + model.dofs()))
569
+ O_J̇_WC_I += O_Ẋ_W @ W_J_WL_W[parent_link_idx] @ T
570
+ O_J̇_WC_I += O_X_W @ W_J̇_WL_W[parent_link_idx] @ T
571
+ O_J̇_WC_I += O_X_W @ W_J_WL_W[parent_link_idx] @ Ṫ
572
+
573
+ return O_J̇_WC_I
574
+
575
+ O_J̇_WC = jax.vmap(compute_O_J̇_WC_I, in_axes=(0, 0, 0, None))(
576
+ L_p_Ci, contact_idxs, CW_J_WC_BW, W_H_Li
577
+ )
578
+
579
+ return O_J̇_WC
jaxsim/api/model.py CHANGED
@@ -576,6 +576,41 @@ def generalized_free_floating_jacobian(
576
576
  return O_J_WL_I
577
577
 
578
578
 
579
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def generalized_free_floating_jacobian_derivative(
    model: JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
    """
    Compute the free-floating jacobian derivatives of all links.

    The per-link derivative is delegated to ``js.link.jacobian_derivative`` and
    vectorized over the link indices ``0 .. number_of_links() - 1``.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        output_vel_repr:
            The output velocity representation of the free-floating jacobian
            derivatives. Defaults to the active representation of ``data``.

    Returns:
        The `(nL, 6, 6+dofs)` array containing the stacked free-floating
        jacobian derivatives of the links. The first axis is the link index.
    """

    if output_vel_repr is None:
        output_vel_repr = data.velocity_representation

    def single_link(link_index: jtp.Int) -> jtp.Matrix:
        # Model, data and representation are shared by every link; only the
        # link index is batched by vmap.
        return js.link.jacobian_derivative(
            model, data, link_index=link_index, output_vel_repr=output_vel_repr
        )

    all_link_indices = jnp.arange(model.number_of_links())

    return jax.vmap(single_link)(all_link_indices)
612
+
613
+
579
614
  @functools.partial(jax.jit, static_argnames=["prefer_aba"])
580
615
  def forward_dynamics(
581
616
  model: JaxSimModel,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.2.dev19
3
+ Version: 0.4.2.dev28
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -1,18 +1,18 @@
1
1
  jaxsim/__init__.py,sha256=ixsS4dYMPex2wOUUp_rkPnwrPhYzkRh1xO_YuMj3Cr4,2626
2
- jaxsim/_version.py,sha256=fa--MfLuJkEPIj9KjXW-ToBHp9R1u3zovY8O3FGfsuY,426
2
+ jaxsim/_version.py,sha256=2qDCKMaCBeWCRwlxCB4UQbl_v3wIgT4nVMoTvL5ZPKA,426
3
3
  jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
4
4
  jaxsim/logging.py,sha256=c4zhwBKf9eAYAHVp62kTEllqdsZgh0K-kPKVy8L3elU,1584
5
5
  jaxsim/typing.py,sha256=IbFx3UkEXi-cm7UBqMPi58rJAFV_HbZ9E_K4JwfNvVM,753
6
6
  jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
7
7
  jaxsim/api/com.py,sha256=6TnYCvjmsJ2KLuw3NtZb0pay7ZwGKe9MKphYeQdjpQ0,13474
8
8
  jaxsim/api/common.py,sha256=Ubi6uAw3o6qbdU0TFGzUyHg98EnoMzrnlihrvrs95Sk,6653
9
- jaxsim/api/contact.py,sha256=EcOx_T94gZT3igtebmW9FJDpZYPEf-RwKfFN18JjOWM,13364
9
+ jaxsim/api/contact.py,sha256=RdxPdwYbCpD_Wz0oiG9N569EoXtLiDGUAvCuoOJFYxc,19761
10
10
  jaxsim/api/data.py,sha256=T6m7-NteWrm-K3491Yo76xvlWtFCKqTzEE7Ughcasi8,27197
11
11
  jaxsim/api/frame.py,sha256=yQmhh8fckXnqzs7dQvojOzbuSanNGLwUTWUQDXbVtF4,12874
12
12
  jaxsim/api/joint.py,sha256=Pvg_It2iYA-jAQ2nOlFZxwmITiozO_f46G13BdQtHQ0,5106
13
13
  jaxsim/api/kin_dyn_parameters.py,sha256=CcfSg5Mc8qb1mZeMQ4AK_ffZIsK5yOl7tu397pFhcDA,29369
14
14
  jaxsim/api/link.py,sha256=hn7fbxaebHeXnvwEG9jZiWwzRcfdS8m-18LVsIG3S24,18479
15
- jaxsim/api/model.py,sha256=PMTSz00AIVopwiJ3zGBoYPTtkLH_beJCcQsX9wBE38I,61502
15
+ jaxsim/api/model.py,sha256=8TTFVc7HuPLxIRq4h1pzXwuSlQ4sDVmL5KpiO-OnmnM,62662
16
16
  jaxsim/api/ode.py,sha256=NnLTBvpaT4kXnbjAghXIzLv9DTMJ8bele2iOlUQDv3Q,11028
17
17
  jaxsim/api/ode_data.py,sha256=9YZX-SK_KJtoIqG-zYWZsQInb2NA_LtxDn-jtLqm_3U,19759
18
18
  jaxsim/api/references.py,sha256=UA6kSQVBoq-bXSo99EOELf-_MD5MTy2zS0GtG3wQ410,16618
@@ -61,8 +61,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
61
61
  jaxsim/utils/jaxsim_dataclass.py,sha256=fLl1tY3DDb3lpIhG6BPqA5W34hM84oFzL-5cuz8k-68,11379
62
62
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
63
63
  jaxsim/utils/wrappers.py,sha256=GOJQCJc5zwzoEGZB62wnWWGvUUQlXvDxz_A2Q-hFv7c,4027
64
- jaxsim-0.4.2.dev19.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
65
- jaxsim-0.4.2.dev19.dist-info/METADATA,sha256=F_Rh5tV8N8VP-6NGTLxInNLuMTm1lRu9IRiu7r-jh7k,17250
66
- jaxsim-0.4.2.dev19.dist-info/WHEEL,sha256=Wyh-_nZ0DJYolHNn1_hMa4lM7uDedD_RGVwbmTjyItk,91
67
- jaxsim-0.4.2.dev19.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
68
- jaxsim-0.4.2.dev19.dist-info/RECORD,,
64
+ jaxsim-0.4.2.dev28.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
65
+ jaxsim-0.4.2.dev28.dist-info/METADATA,sha256=yw8qQ3mi7xmCuxJdZghz_SMShsV2mGyInvxo_yCS3j4,17250
66
+ jaxsim-0.4.2.dev28.dist-info/WHEEL,sha256=Wyh-_nZ0DJYolHNn1_hMa4lM7uDedD_RGVwbmTjyItk,91
67
+ jaxsim-0.4.2.dev28.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
68
+ jaxsim-0.4.2.dev28.dist-info/RECORD,,