jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -1,16 +1,15 @@
1
+ import dataclasses
1
2
  from typing import ClassVar, Generic
2
3
 
3
- import jax
4
4
  import jax.numpy as jnp
5
5
  import jax_dataclasses
6
- import jaxlie
7
6
 
8
- from jaxsim.simulation.ode_data import ODEState
7
+ import jaxsim.api as js
8
+ import jaxsim.typing as jtp
9
9
 
10
- from .common import ExplicitRungeKutta, PyTreeType, Time, TimeStep
11
-
12
- ODEStateDerivative = ODEState
10
+ from .common import ExplicitRungeKutta, ExplicitRungeKuttaSO3Mixin, PyTreeType
13
11
 
12
+ ODEStateDerivative = js.ode_data.ODEState
14
13
 
15
14
  # =====================================================
16
15
  # Explicit Runge-Kutta integrators operating on PyTrees
@@ -19,83 +18,107 @@ ODEStateDerivative = ODEState
19
18
 
20
19
  @jax_dataclasses.pytree_dataclass
21
20
  class ForwardEuler(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
21
+ """
22
+ Forward Euler integrator.
23
+ """
22
24
 
23
- A: ClassVar[jax.typing.ArrayLike] = jnp.array(
24
- [
25
- [0],
26
- ]
27
- ).astype(float)
28
-
29
- b: ClassVar[jax.typing.ArrayLike] = (
30
- jnp.array(
31
- [
32
- [1],
33
- ]
34
- )
35
- .astype(float)
36
- .transpose()
25
+ A: jtp.Matrix = dataclasses.field(
26
+ default_factory=lambda: jnp.atleast_2d(0).astype(float), compare=False
27
+ )
28
+ b: jtp.Matrix = dataclasses.field(
29
+ default_factory=lambda: jnp.atleast_2d(1).astype(float), compare=False
37
30
  )
38
31
 
39
- c: ClassVar[jax.typing.ArrayLike] = jnp.array(
40
- [0],
41
- ).astype(float)
32
+ c: jtp.Vector = dataclasses.field(
33
+ default_factory=lambda: jnp.atleast_1d(0).astype(float), compare=False
34
+ )
42
35
 
43
- row_index_of_solution: ClassVar[int] = 0
44
- order_of_bT_rows: ClassVar[tuple[int, ...]] = (1,)
36
+ row_index_of_solution: int = 0
37
+ order_of_bT_rows: tuple[int, ...] = (1,)
38
+ index_of_fsal: jtp.IntLike | None = None
39
+ fsal_enabled_if_supported: bool = False
45
40
 
46
41
 
47
42
  @jax_dataclasses.pytree_dataclass
48
- class Heun(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
49
-
50
- A: ClassVar[jax.typing.ArrayLike] = jnp.array(
51
- [
52
- [0, 0],
53
- [1 / 2, 0],
54
- ]
55
- ).astype(float)
56
-
57
- b: ClassVar[jax.typing.ArrayLike] = (
58
- jnp.atleast_2d(
59
- jnp.array([1 / 2, 1 / 2]),
60
- )
61
- .astype(float)
62
- .transpose()
43
+ class Heun2(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
44
+ """
45
+ Heun's second-order integrator.
46
+ """
47
+
48
+ A: jtp.Matrix = dataclasses.field(
49
+ default_factory=lambda: jnp.array(
50
+ [
51
+ [0, 0],
52
+ [1, 0],
53
+ ]
54
+ ).astype(float),
55
+ compare=False,
56
+ )
57
+
58
+ b: jtp.Matrix = dataclasses.field(
59
+ default_factory=lambda: (
60
+ jnp.atleast_2d(
61
+ jnp.array([1 / 2, 1 / 2]),
62
+ )
63
+ .astype(float)
64
+ .transpose()
65
+ ),
66
+ compare=False,
63
67
  )
64
68
 
65
- c: ClassVar[jax.typing.ArrayLike] = jnp.array(
66
- [0, 1],
67
- ).astype(float)
69
+ c: jtp.Vector = dataclasses.field(
70
+ default_factory=lambda: jnp.array(
71
+ [0, 1],
72
+ ).astype(float),
73
+ compare=False,
74
+ )
68
75
 
69
76
  row_index_of_solution: ClassVar[int] = 0
70
77
  order_of_bT_rows: ClassVar[tuple[int, ...]] = (2,)
78
+ index_of_fsal: jtp.IntLike | None = None
79
+ fsal_enabled_if_supported: bool = False
71
80
 
72
81
 
73
82
  @jax_dataclasses.pytree_dataclass
74
83
  class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
84
+ """
85
+ Fourth-order Runge-Kutta integrator.
86
+ """
75
87
 
76
- A: ClassVar[jax.typing.ArrayLike] = jnp.array(
77
- [
78
- [0, 0, 0, 0],
79
- [1 / 2, 0, 0, 0],
80
- [0, 1 / 2, 0, 0],
81
- [0, 0, 1, 0],
82
- ]
83
- ).astype(float)
84
-
85
- b: ClassVar[jax.typing.ArrayLike] = (
86
- jnp.atleast_2d(
87
- jnp.array([1 / 6, 1 / 3, 1 / 3, 1 / 6]),
88
- )
89
- .astype(float)
90
- .transpose()
88
+ A: jtp.Matrix = dataclasses.field(
89
+ default_factory=lambda: jnp.array(
90
+ [
91
+ [0, 0, 0, 0],
92
+ [1 / 2, 0, 0, 0],
93
+ [0, 1 / 2, 0, 0],
94
+ [0, 0, 1, 0],
95
+ ]
96
+ ).astype(float),
97
+ compare=False,
91
98
  )
92
99
 
93
- c: ClassVar[jax.typing.ArrayLike] = jnp.array(
94
- [0, 1 / 2, 1 / 2, 1],
95
- ).astype(float)
100
+ b: jtp.Matrix = dataclasses.field(
101
+ default_factory=lambda: (
102
+ jnp.atleast_2d(
103
+ jnp.array([1 / 6, 1 / 3, 1 / 3, 1 / 6]),
104
+ )
105
+ .astype(float)
106
+ .transpose()
107
+ ),
108
+ compare=False,
109
+ )
110
+
111
+ c: jtp.Vector = dataclasses.field(
112
+ default_factory=lambda: jnp.array(
113
+ [0, 1 / 2, 1 / 2, 1],
114
+ ).astype(float),
115
+ compare=False,
116
+ )
96
117
 
97
118
  row_index_of_solution: ClassVar[int] = 0
98
119
  order_of_bT_rows: ClassVar[tuple[int, ...]] = (4,)
120
+ index_of_fsal: jtp.IntLike | None = None
121
+ fsal_enabled_if_supported: bool = False
99
122
 
100
123
 
101
124
  # ===============================================================================
@@ -103,56 +126,28 @@ class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
103
126
  # ===============================================================================
104
127
 
105
128
 
106
- class ExplicitRungeKuttaSO3Mixin:
129
+ @jax_dataclasses.pytree_dataclass
130
+ class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[js.ode_data.ODEState]):
107
131
  """
108
- Mixin class to apply over explicit RK integrators defined on
109
- `PyTreeType = ODEState` to integrate the quaternion on SO(3).
132
+ Forward Euler integrator for SO(3) states.
110
133
  """
111
134
 
112
- @classmethod
113
- def post_process_state(
114
- cls, x0: ODEState, t0: Time, xf: ODEState, dt: TimeStep
115
- ) -> ODEState:
116
-
117
- # Indices to convert quaternions between serializations.
118
- to_xyzw = jnp.array([1, 2, 3, 0])
119
- to_wxyz = jnp.array([3, 0, 1, 2])
120
-
121
- # Get the initial quaternion.
122
- W_Q_B_t0 = jaxlie.SO3.from_quaternion_xyzw(
123
- xyzw=x0.physics_model.base_quaternion[to_xyzw]
124
- )
125
-
126
- # Get the final angular velocity.
127
- # This is already computed by averaging the kᵢ in RK-based schemes.
128
- # Therefore, by using the ω at tf, we obtain a RK scheme operating
129
- # on the SO(3) manifold.
130
- W_ω_WB_tf = xf.physics_model.base_angular_velocity
131
-
132
- # Integrate the quaternion on SO(3).
133
- # Note that we left-multiply with the exponential map since the angular
134
- # velocity is expressed in the inertial frame.
135
- W_Q_B_tf = jaxlie.SO3.exp(tangent=dt * W_ω_WB_tf) @ W_Q_B_t0
136
-
137
- # Replace the quaternion in the final state.
138
- return xf.replace(
139
- physics_model=xf.physics_model.replace(
140
- base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz]
141
- ),
142
- validate=True,
143
- )
144
-
145
-
146
- @jax_dataclasses.pytree_dataclass
147
- class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, Heun[ODEState]):
148
135
  pass
149
136
 
150
137
 
151
138
  @jax_dataclasses.pytree_dataclass
152
- class HeunSO3(ExplicitRungeKuttaSO3Mixin, Heun[ODEState]):
139
+ class Heun2SO3(ExplicitRungeKuttaSO3Mixin, Heun2[js.ode_data.ODEState]):
140
+ """
141
+ Heun's second-order integrator for SO(3) states.
142
+ """
143
+
153
144
  pass
154
145
 
155
146
 
156
147
  @jax_dataclasses.pytree_dataclass
157
- class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[ODEState]):
148
+ class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[js.ode_data.ODEState]):
149
+ """
150
+ Fourth-order Runge-Kutta integrator for SO(3) states.
151
+ """
152
+
158
153
  pass