jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -129
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +87 -16
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +62 -24
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +607 -225
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -80
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -55
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev188.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev188.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -1,180 +0,0 @@
1
- from typing import Tuple
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- import numpy as np
6
-
7
- import jaxsim.typing as jtp
8
- from jaxsim.math.adjoint import Adjoint
9
- from jaxsim.math.cross import Cross
10
- from jaxsim.physics.model.physics_model import PhysicsModel
11
-
12
- from . import utils
13
-
14
-
15
def rnea(
    model: PhysicsModel,
    xfb: jtp.Vector,
    q: jtp.Vector,
    qd: jtp.Vector,
    qdd: jtp.Vector,
    a0fb: jtp.Vector = jnp.zeros(6),
    f_ext: jtp.Matrix | None = None,
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Perform inverse dynamics using the Recursive Newton-Euler Algorithm (RNEA).

    Computes the joint torques (forces) and, for floating-base models, the base
    6D force required to realize the given joint and base accelerations, taking
    into account gravity and external forces.

    Args:
        model (PhysicsModel): The robot's physics model containing dynamic parameters.
        xfb (jtp.Vector): The floating base state, sliced as follows:
            xfb[0:4] is the base orientation quaternion (normalized before use),
            xfb[4:7] is the base position, and xfb[7:13] holds the base velocity.
            The 6D base velocity W_v_WB is assembled as [xfb[10:13], xfb[7:10]]
            and is only used for floating-base models.
        q (jtp.Vector): Joint positions.
        qd (jtp.Vector): Joint velocities.
        qdd (jtp.Vector): Joint accelerations.
        a0fb (jtp.Vector, optional): Base 6D acceleration. It must have exactly
            6 elements after squeezing, and it only affects floating-base models.
            Defaults to zeros.
        f_ext (jtp.Matrix, optional): External 6D forces, one per body. They are
            mapped into body frames with a world-to-body force transform, so they
            are presumably expressed in the world frame. When None, the value is
            handed to ``utils.process_inputs`` (presumably replaced by zeros).

    Returns:
        A tuple ``(W_f0, tau)`` where:
        W_f0 (jtp.Vector): The base 6D force expressed in the world frame, as a
            (6, 1) column. It remains zero for fixed-base models.
        tau (jtp.Vector): Joint torques (forces) as an (n, 1) column, or an
            empty (0, 1) array when the model has no joints.

    Raises:
        ValueError: If ``a0fb`` does not have 6 elements.
    """

    xfb, q, qd, qdd, _, f_ext = utils.process_inputs(
        physics_model=model, xfb=xfb, q=q, qd=qd, qdd=qdd, f_ext=f_ext
    )

    a0fb = a0fb.squeeze()
    gravity = model.gravity.squeeze()

    if a0fb.shape[0] != 6:
        raise ValueError(a0fb.shape)

    # Per-body model quantities: spatial inertias, tree transforms, and the
    # joint transforms / motion subspaces evaluated at the configuration q.
    M = model.spatial_inertias
    pre_X_λi = model.tree_transforms
    i_X_pre = model.joint_transforms(q=q)
    S = model.motion_subspaces(q=q)
    i_X_λi = jnp.zeros_like(pre_X_λi)

    # i_X_0[i]: motion transform from the base (body 0) to body i.
    i_X_0 = jnp.zeros_like(pre_X_λi)
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # Parent array mapping: i -> λ(i).
    # Exception: λ(0) must not be used, it's initialized to -1.
    λ = model.parent_array()

    # Per-body 6D velocities, accelerations and forces, stored as (6, 1) columns.
    v = jnp.array([jnp.zeros([6, 1])] * model.NB)
    a = jnp.array([jnp.zeros([6, 1])] * model.NB)
    f = jnp.array([jnp.zeros([6, 1])] * model.NB)

    # 6D transform of base velocity
    B_X_W = Adjoint.from_quaternion_and_translation(
        quaternion=xfb[0:4],
        translation=xfb[4:7],
        inverse=True,
        normalize_quaternion=True,
    )
    i_X_λi = i_X_λi.at[0].set(B_X_W)

    # Gravity is accounted for by giving the base a fictitious acceleration
    # equal to -g, which the forward pass then propagates to every body.
    a_0 = -B_X_W @ jnp.vstack(gravity)
    a = a.at[0].set(a_0)

    if model.is_floating_base:
        # 6D base velocity assembled from the two 3D slices of the base state.
        W_v_WB = jnp.vstack(jnp.hstack([xfb[10:13], xfb[7:10]]))

        v_0 = B_X_W @ W_v_WB
        v = v.at[0].set(v_0)

        a_0 = B_X_W @ (jnp.vstack(a0fb) - jnp.vstack(gravity))
        a = a.at[0].set(a_0)

        # Net base force: inertial and velocity-product terms minus the external
        # force, which is mapped with the force transform (B_X_W)^-T.
        f_0 = (
            M[0] @ a[0]
            + Cross.vx_star(v[0]) @ M[0] @ v[0]
            - Adjoint.inverse(B_X_W).T @ jnp.vstack(f_ext[0])
        )
        f = f.at[0].set(f_0)

    ForwardPassCarry = Tuple[
        jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
    ]
    forward_pass_carry = (i_X_λi, v, a, i_X_0, f)

    def forward_pass(
        carry: ForwardPassCarry, i: jtp.Int
    ) -> Tuple[ForwardPassCarry, None]:
        # Body i is moved by joint ii = i - 1 (joint-indexed arrays skip the base).
        ii = i - 1
        i_X_λi, v, a, i_X_0, f = carry

        # Velocity contributed by joint ii.
        vJ = S[i] * qd[ii]
        i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
        i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

        # Propagate velocity and acceleration from the parent λ(i). This relies on
        # λ(i) having been processed already, i.e. presumably λ(i) < i.
        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        a_i = i_X_λi[i] @ a[λ[i]] + S[i] * qdd[ii] + Cross.vx(v[i]) @ vJ
        a = a.at[i].set(a_i)

        i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
        i_X_0 = i_X_0.at[i].set(i_X_0_i)

        # Force transform from the world frame to body i, used for f_ext[i].
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        # Net force on body i: inertial and velocity-product terms minus f_ext.
        f_i = (
            M[i] @ a[i]
            + Cross.vx_star(v[i]) @ M[i] @ v[i]
            - i_Xf_W @ jnp.vstack(f_ext[i])
        )
        f = f.at[i].set(f_i)

        return (i_X_λi, v, a, i_X_0, f), None

    # Forward pass over bodies 1..NB-1 in increasing order.
    (i_X_λi, v, a, i_X_0, f), _ = jax.lax.scan(
        f=forward_pass,
        init=forward_pass_carry,
        xs=np.arange(start=1, stop=model.NB),
    )

    tau = jnp.zeros_like(q)

    BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax]
    backward_pass_carry = (tau, f)

    def backward_pass(
        carry: BackwardPassCarry, i: jtp.Int
    ) -> Tuple[BackwardPassCarry, None]:
        ii = i - 1
        tau, f = carry

        # Joint torque: projection of the body force onto the joint motion subspace.
        value = S[i].T @ f[i]
        tau = tau.at[ii].set(value.squeeze())

        # Accumulate the force of body i onto its parent.
        def update_f(f: jtp.MatrixJax) -> jtp.MatrixJax:
            f_λi = f[λ[i]] + i_X_λi[i].T @ f[i]
            f = f.at[λ[i]].set(f_λi)
            return f

        # Forces are propagated onto the base (λ(i) == 0) only for floating-base
        # models; for fixed-base models f[0] is therefore never updated.
        f = jax.lax.cond(
            pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(),
            true_fun=update_f,
            false_fun=lambda f: f,
            operand=f,
        )

        return (tau, f), None

    # Backward pass over bodies NB-1..1 (leaves towards the root).
    (tau, f), _ = jax.lax.scan(
        f=backward_pass,
        init=backward_pass_carry,
        xs=np.flip(np.arange(start=1, stop=model.NB)),
    )

    # Handle 1 DoF models
    tau = jnp.atleast_1d(tau.squeeze())
    tau = jnp.vstack(tau) if tau.size > 0 else jnp.empty(shape=(0, 1))

    # Express the base 6D force in the world frame
    W_f0 = B_X_W.T @ jnp.vstack(f[0])

    return W_f0, tau
@@ -1,196 +0,0 @@
1
- from typing import Tuple
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- import numpy as np
6
-
7
- import jaxsim.typing as jtp
8
- from jaxsim.math.adjoint import Adjoint
9
- from jaxsim.math.cross import Cross
10
- from jaxsim.physics.model.physics_model import PhysicsModel
11
-
12
- from . import utils
13
-
14
-
15
def rnea(
    model: PhysicsModel,
    xfb: jtp.Vector,
    q: jtp.Vector,
    qd: jtp.Vector,
    qdd: jtp.Vector,
    a0fb: jtp.Vector = jnp.zeros(6),
    f_ext: jtp.Matrix | None = None,
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Recursive Newton-Euler Algorithm (RNEA) for inverse dynamics, with motor terms.

    In addition to the rigid-body quantities, a second set of "motor" 6D
    velocities, accelerations and forces is propagated along the kinematic tree.
    It uses motion subspaces scaled by the joint gear ratios (``m_S``) and the
    joint motor inertias. Each joint torque is the sum of the rigid-body
    projection ``S.T @ f`` and the motor projection ``m_S.T @ f_m``.
    Motor viscous friction is not included in the returned torques.

    Args:
        model (PhysicsModel): The robot's physics model, including the per-joint
            motor parameters (gear ratios and inertias).
        xfb (jtp.Vector): The floating base state. xfb[0:4] is the base
            quaternion, xfb[4:7] the base position and xfb[7:13] the base
            velocity, assembled as [xfb[10:13], xfb[7:10]] into W_v_WB.
        q (jtp.Vector): Joint positions.
        qd (jtp.Vector): Joint velocities.
        qdd (jtp.Vector): Joint accelerations.
        a0fb (jtp.Vector, optional): Base 6D acceleration (6 elements after
            squeezing); only used for floating-base models. Defaults to zeros.
        f_ext (jtp.Matrix, optional): External 6D forces, one per body,
            presumably expressed in the world frame. Defaults to None.

    Returns:
        A tuple ``(W_f0, tau)``. W_f0 is the (6, 1) base force expressed in the
        world frame, which is zero for fixed-base models. tau is an (n, 1) column
        of joint torques, or an empty (0, 1) array when there are no joints.

    Raises:
        ValueError: If ``a0fb`` does not have 6 elements.
    """

    xfb, q, qd, qdd, _, f_ext = utils.process_inputs(
        physics_model=model, xfb=xfb, q=q, qd=qd, qdd=qdd, f_ext=f_ext
    )

    a0fb = a0fb.squeeze()
    gravity = model.gravity.squeeze()

    if a0fb.shape[0] != 6:
        raise ValueError(a0fb.shape)

    M = model.spatial_inertias
    pre_X_λi = model.tree_transforms
    i_X_pre = model.joint_transforms(q=q)
    S = model.motion_subspaces(q=q)
    i_X_λi = jnp.zeros_like(pre_X_λi)

    # Motor parameters, one entry per joint. Entry k belongs to the joint that
    # drives body k + 1: this is the alignment used below when scaling S[1:] by Γ.
    # Every per-joint array must therefore be indexed with the joint index
    # ii = i - 1, never with the body index i.
    Γ = jnp.array([*model._joint_motor_gear_ratio.values()])
    IM = jnp.array([*model._joint_motor_inertia.values()])

    # Motion subspaces seen by the motors: joint subspaces scaled by the gear
    # ratio. The base entry S[0] is kept unchanged.
    m_S = jnp.concatenate([S[:1], S[1:] * Γ[:, None, None]], axis=0)

    # i_X_0[i]: motion transform from the base (body 0) to body i.
    i_X_0 = jnp.zeros_like(pre_X_λi)
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # Parent array mapping: i -> λ(i).
    # Exception: λ(0) must not be used, it's initialized to -1.
    λ = model.parent_array()

    # Rigid-body 6D velocities, accelerations and forces, as (6, 1) columns.
    v = jnp.array([jnp.zeros([6, 1])] * model.NB)
    a = jnp.array([jnp.zeros([6, 1])] * model.NB)
    f = jnp.array([jnp.zeros([6, 1])] * model.NB)

    # Motor 6D velocities, accelerations and forces, as (6, 1) columns.
    v_m = jnp.array([jnp.zeros([6, 1])] * model.NB)
    a_m = jnp.array([jnp.zeros([6, 1])] * model.NB)
    f_m = jnp.array([jnp.zeros([6, 1])] * model.NB)

    # 6D transform of base velocity
    B_X_W = Adjoint.from_quaternion_and_translation(
        quaternion=xfb[0:4],
        translation=xfb[4:7],
        inverse=True,
        normalize_quaternion=True,
    )
    i_X_λi = i_X_λi.at[0].set(B_X_W)

    # Gravity enters as a fictitious base acceleration equal to -g.
    a_0 = -B_X_W @ jnp.vstack(gravity)
    a = a.at[0].set(a_0)

    if model.is_floating_base:
        W_v_WB = jnp.vstack(jnp.hstack([xfb[10:13], xfb[7:10]]))

        v_0 = B_X_W @ W_v_WB
        v = v.at[0].set(v_0)

        a_0 = B_X_W @ (jnp.vstack(a0fb) - jnp.vstack(gravity))
        a = a.at[0].set(a_0)

        f_0 = (
            M[0] @ a[0]
            + Cross.vx_star(v[0]) @ M[0] @ v[0]
            - Adjoint.inverse(B_X_W).T @ jnp.vstack(f_ext[0])
        )
        f = f.at[0].set(f_0)

    ForwardPassCarry = Tuple[
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
    ]
    forward_pass_carry = (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m)

    def forward_pass(
        carry: ForwardPassCarry, i: jtp.Int
    ) -> Tuple[ForwardPassCarry, None]:
        # Body i is moved by joint ii = i - 1 (joint-indexed arrays skip the base).
        ii = i - 1
        i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m = carry

        vJ = S[i] * qd[ii]
        vJ_m = m_S[i] * qd[ii]

        i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
        i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

        # Propagate rigid-body and motor velocities from the parent λ(i).
        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        v_i_m = i_X_λi[i] @ v_m[λ[i]] + vJ_m
        v_m = v_m.at[i].set(v_i_m)

        # Propagate rigid-body and motor accelerations from the parent λ(i).
        a_i = i_X_λi[i] @ a[λ[i]] + S[i] * qdd[ii] + Cross.vx(v[i]) @ vJ
        a = a.at[i].set(a_i)

        a_i_m = i_X_λi[i] @ a_m[λ[i]] + m_S[i] * qdd[ii] + Cross.vx(v_m[i]) @ vJ_m
        a_m = a_m.at[i].set(a_i_m)

        i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
        i_X_0 = i_X_0.at[i].set(i_X_0_i)

        # Force transform from the world frame to body i, used for f_ext[i].
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        f_i = (
            M[i] @ a[i]
            + Cross.vx_star(v[i]) @ M[i] @ v[i]
            - i_Xf_W @ jnp.vstack(f_ext[i])
        )
        f = f.at[i].set(f_i)

        # Motor force of the joint driving body i. The motor inertia is a per-joint
        # quantity, so it is indexed with ii (like Γ in m_S), not with i.
        f_i_m = IM[ii] * a_m[i] + Cross.vx_star(v_m[i]) * IM[ii] @ v_m[i]
        f_m = f_m.at[i].set(f_i_m)

        return (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), None

    (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), _ = jax.lax.scan(
        f=forward_pass,
        init=forward_pass_carry,
        xs=np.arange(start=1, stop=model.NB),
    )

    tau = jnp.zeros_like(q)

    BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax]
    backward_pass_carry = (tau, f, f_m)

    def backward_pass(
        carry: BackwardPassCarry, i: jtp.Int
    ) -> Tuple[BackwardPassCarry, None]:
        ii = i - 1
        tau, f, f_m = carry

        # Joint torque: rigid-body and motor force projections.
        # Motor viscous friction is not modeled here.
        value = S[i].T @ f[i] + m_S[i].T @ f_m[i]
        tau = tau.at[ii].set(value.squeeze())

        # Accumulate the rigid-body and motor forces of body i onto its parent.
        def update_f(
            ffm: Tuple[jtp.MatrixJax, jtp.MatrixJax]
        ) -> Tuple[jtp.MatrixJax, jtp.MatrixJax]:
            f, f_m = ffm
            f_λi = f[λ[i]] + i_X_λi[i].T @ f[i]
            f = f.at[λ[i]].set(f_λi)

            f_m_λi = f_m[λ[i]] + i_X_λi[i].T @ f_m[i]
            f_m = f_m.at[λ[i]].set(f_m_λi)
            return f, f_m

        # Forces reach the base (λ(i) == 0) only for floating-base models.
        f, f_m = jax.lax.cond(
            pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(),
            true_fun=update_f,
            false_fun=lambda ffm: ffm,
            operand=(f, f_m),
        )

        return (tau, f, f_m), None

    (tau, f, f_m), _ = jax.lax.scan(
        f=backward_pass,
        init=backward_pass_carry,
        xs=np.flip(np.arange(start=1, stop=model.NB)),
    )

    # Handle 1 DoF models
    tau = jnp.atleast_1d(tau.squeeze())
    tau = jnp.vstack(tau) if tau.size > 0 else jnp.empty(shape=(0, 1))

    # Express the base 6D force in the world frame
    W_f0 = B_X_W.T @ jnp.vstack(f[0])

    return W_f0, tau