jaxsim 0.2.dev191__py3-none-any.whl → 0.2.dev366__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. jaxsim/__init__.py +3 -4
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +13 -2
  6. jaxsim/api/contact.py +120 -43
  7. jaxsim/api/data.py +112 -71
  8. jaxsim/api/joint.py +77 -36
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +150 -75
  11. jaxsim/api/model.py +542 -269
  12. jaxsim/api/ode.py +86 -74
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +12 -11
  15. jaxsim/integrators/__init__.py +2 -2
  16. jaxsim/integrators/common.py +110 -24
  17. jaxsim/integrators/fixed_step.py +11 -67
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +93 -0
  25. jaxsim/parsers/descriptions/link.py +2 -2
  26. jaxsim/parsers/rod/utils.py +7 -8
  27. jaxsim/rbda/__init__.py +7 -0
  28. jaxsim/rbda/aba.py +295 -0
  29. jaxsim/rbda/collidable_points.py +142 -0
  30. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  31. jaxsim/rbda/forward_kinematics.py +113 -0
  32. jaxsim/rbda/jacobian.py +201 -0
  33. jaxsim/rbda/rnea.py +237 -0
  34. jaxsim/rbda/soft_contacts.py +296 -0
  35. jaxsim/rbda/utils.py +152 -0
  36. jaxsim/terrain/__init__.py +2 -0
  37. jaxsim/utils/__init__.py +1 -4
  38. jaxsim/utils/hashless.py +18 -0
  39. jaxsim/utils/jaxsim_dataclass.py +281 -30
  40. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/METADATA +4 -6
  41. jaxsim-0.2.dev366.dist-info/RECORD +64 -0
  42. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/WHEEL +1 -1
  43. jaxsim/high_level/__init__.py +0 -2
  44. jaxsim/high_level/common.py +0 -11
  45. jaxsim/high_level/joint.py +0 -148
  46. jaxsim/high_level/link.py +0 -259
  47. jaxsim/high_level/model.py +0 -1686
  48. jaxsim/math/conv.py +0 -114
  49. jaxsim/math/joint.py +0 -102
  50. jaxsim/math/plucker.py +0 -100
  51. jaxsim/physics/__init__.py +0 -12
  52. jaxsim/physics/algos/__init__.py +0 -0
  53. jaxsim/physics/algos/aba.py +0 -254
  54. jaxsim/physics/algos/aba_motors.py +0 -284
  55. jaxsim/physics/algos/forward_kinematics.py +0 -79
  56. jaxsim/physics/algos/jacobian.py +0 -98
  57. jaxsim/physics/algos/rnea.py +0 -180
  58. jaxsim/physics/algos/rnea_motors.py +0 -196
  59. jaxsim/physics/algos/soft_contacts.py +0 -523
  60. jaxsim/physics/algos/utils.py +0 -69
  61. jaxsim/physics/model/__init__.py +0 -0
  62. jaxsim/physics/model/ground_contact.py +0 -53
  63. jaxsim/physics/model/physics_model.py +0 -388
  64. jaxsim/physics/model/physics_model_state.py +0 -283
  65. jaxsim/simulation/__init__.py +0 -4
  66. jaxsim/simulation/integrators.py +0 -393
  67. jaxsim/simulation/ode.py +0 -290
  68. jaxsim/simulation/ode_data.py +0 -96
  69. jaxsim/simulation/ode_integration.py +0 -62
  70. jaxsim/simulation/simulator.py +0 -543
  71. jaxsim/simulation/simulator_callbacks.py +0 -79
  72. jaxsim/simulation/utils.py +0 -15
  73. jaxsim/sixd/__init__.py +0 -2
  74. jaxsim/utils/oop.py +0 -536
  75. jaxsim/utils/vmappable.py +0 -117
  76. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  77. jaxsim/{physics/algos → terrain}/terrain.py +0 -0
  78. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/LICENSE +0 -0
  79. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/top_level.txt +0 -0
jaxsim/physics/algos/rnea_motors.py (file 58 above; deleted in 0.2.dev366, all 196 lines removed)
@@ -1,196 +0,0 @@
1
- from typing import Tuple
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- import numpy as np
6
-
7
- import jaxsim.typing as jtp
8
- from jaxsim.math.adjoint import Adjoint
9
- from jaxsim.math.cross import Cross
10
- from jaxsim.physics.model.physics_model import PhysicsModel
11
-
12
- from . import utils
13
-
14
-
15
- def rnea(
16
- model: PhysicsModel,
17
- xfb: jtp.Vector,
18
- q: jtp.Vector,
19
- qd: jtp.Vector,
20
- qdd: jtp.Vector,
21
- a0fb: jtp.Vector = jnp.zeros(6),
22
- f_ext: jtp.Matrix | None = None,
23
- ) -> Tuple[jtp.Vector, jtp.Vector]:
24
- """
25
- Recursive Newton-Euler Algorithm (RNEA) algorithm for inverse dynamics.
26
- """
27
-
28
- xfb, q, qd, qdd, _, f_ext = utils.process_inputs(
29
- physics_model=model, xfb=xfb, q=q, qd=qd, qdd=qdd, f_ext=f_ext
30
- )
31
-
32
- a0fb = a0fb.squeeze()
33
- gravity = model.gravity.squeeze()
34
-
35
- if a0fb.shape[0] != 6:
36
- raise ValueError(a0fb.shape)
37
-
38
- M = model.spatial_inertias
39
- pre_X_λi = model.tree_transforms
40
- i_X_pre = model.joint_transforms(q=q)
41
- S = model.motion_subspaces(q=q)
42
- i_X_λi = jnp.zeros_like(pre_X_λi)
43
-
44
- Γ = jnp.array([*model._joint_motor_gear_ratio.values()])
45
- IM = jnp.array([*model._joint_motor_inertia.values()])
46
- K_v = jnp.array([*model._joint_motor_viscous_friction.values()])
47
- K̅ᵥ = jnp.diag(Γ.T * jnp.diag(K_v) * Γ)
48
- m_S = jnp.concatenate([S[:1], S[1:] * Γ[:, None, None]], axis=0)
49
-
50
- i_X_0 = jnp.zeros_like(pre_X_λi)
51
- i_X_0 = i_X_0.at[0].set(jnp.eye(6))
52
-
53
- # Parent array mapping: i -> λ(i).
54
- # Exception: λ(0) must not be used, it's initialized to -1.
55
- λ = model.parent_array()
56
-
57
- v = jnp.array([jnp.zeros([6, 1])] * model.NB)
58
- a = jnp.array([jnp.zeros([6, 1])] * model.NB)
59
- f = jnp.array([jnp.zeros([6, 1])] * model.NB)
60
-
61
- v_m = jnp.array([jnp.zeros([6, 1])] * model.NB)
62
- a_m = jnp.array([jnp.zeros([6, 1])] * model.NB)
63
- f_m = jnp.array([jnp.zeros([6, 1])] * model.NB)
64
-
65
- # 6D transform of base velocity
66
- B_X_W = Adjoint.from_quaternion_and_translation(
67
- quaternion=xfb[0:4],
68
- translation=xfb[4:7],
69
- inverse=True,
70
- normalize_quaternion=True,
71
- )
72
- i_X_λi = i_X_λi.at[0].set(B_X_W)
73
-
74
- a_0 = -B_X_W @ jnp.vstack(gravity)
75
- a = a.at[0].set(a_0)
76
-
77
- if model.is_floating_base:
78
- W_v_WB = jnp.vstack(jnp.hstack([xfb[10:13], xfb[7:10]]))
79
-
80
- v_0 = B_X_W @ W_v_WB
81
- v = v.at[0].set(v_0)
82
-
83
- a_0 = B_X_W @ (jnp.vstack(a0fb) - jnp.vstack(gravity))
84
- a = a.at[0].set(a_0)
85
-
86
- f_0 = (
87
- M[0] @ a[0]
88
- + Cross.vx_star(v[0]) @ M[0] @ v[0]
89
- - Adjoint.inverse(B_X_W).T @ jnp.vstack(f_ext[0])
90
- )
91
- f = f.at[0].set(f_0)
92
-
93
- ForwardPassCarry = Tuple[
94
- jtp.MatrixJax,
95
- jtp.MatrixJax,
96
- jtp.MatrixJax,
97
- jtp.MatrixJax,
98
- jtp.MatrixJax,
99
- jtp.MatrixJax,
100
- jtp.MatrixJax,
101
- jtp.MatrixJax,
102
- ]
103
- forward_pass_carry = (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m)
104
-
105
- def forward_pass(
106
- carry: ForwardPassCarry, i: jtp.Int
107
- ) -> Tuple[ForwardPassCarry, None]:
108
- ii = i - 1
109
- i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m = carry
110
-
111
- vJ = S[i] * qd[ii]
112
- vJ_m = m_S[i] * qd[ii]
113
-
114
- i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
115
- i_X_λi = i_X_λi.at[i].set(i_X_λi_i)
116
-
117
- v_i = i_X_λi[i] @ v[λ[i]] + vJ
118
- v = v.at[i].set(v_i)
119
-
120
- v_i_m = i_X_λi[i] @ v_m[λ[i]] + vJ_m
121
- v_m = v_m.at[i].set(v_i_m)
122
-
123
- a_i = i_X_λi[i] @ a[λ[i]] + S[i] * qdd[ii] + Cross.vx(v[i]) @ vJ
124
- a = a.at[i].set(a_i)
125
-
126
- a_i_m = i_X_λi[i] @ a_m[λ[i]] + m_S[i] * qdd[ii] + Cross.vx(v_m[i]) @ vJ_m
127
- a_m = a_m.at[i].set(a_i_m)
128
-
129
- i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
130
- i_X_0 = i_X_0.at[i].set(i_X_0_i)
131
- i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T
132
-
133
- f_i = (
134
- M[i] @ a[i]
135
- + Cross.vx_star(v[i]) @ M[i] @ v[i]
136
- - i_Xf_W @ jnp.vstack(f_ext[i])
137
- )
138
- f = f.at[i].set(f_i)
139
-
140
- f_i_m = IM[i] * a_m[i] + Cross.vx_star(v_m[i]) * IM[i] @ v_m[i]
141
- f_m = f_m.at[i].set(f_i_m)
142
-
143
- return (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), None
144
-
145
- (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), _ = jax.lax.scan(
146
- f=forward_pass,
147
- init=forward_pass_carry,
148
- xs=np.arange(start=1, stop=model.NB),
149
- )
150
-
151
- tau = jnp.zeros_like(q)
152
-
153
- BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax]
154
- backward_pass_carry = (tau, f, f_m)
155
-
156
- def backward_pass(
157
- carry: BackwardPassCarry, i: jtp.Int
158
- ) -> Tuple[BackwardPassCarry, None]:
159
- ii = i - 1
160
- tau, f, f_m = carry
161
-
162
- value = S[i].T @ f[i] + m_S[i].T @ f_m[i] # + K̅ᵥ[i] * qd[ii]
163
- tau = tau.at[ii].set(value.squeeze())
164
-
165
- def update_f(ffm: Tuple[jtp.MatrixJax, jtp.MatrixJax]) -> jtp.MatrixJax:
166
- f, f_m = ffm
167
- f_λi = f[λ[i]] + i_X_λi[i].T @ f[i]
168
- f = f.at[λ[i]].set(f_λi)
169
-
170
- f_m_λi = f_m[λ[i]] + i_X_λi[i].T @ f_m[i]
171
- f_m = f_m.at[λ[i]].set(f_m_λi)
172
- return f, f_m
173
-
174
- f, f_m = jax.lax.cond(
175
- pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(),
176
- true_fun=update_f,
177
- false_fun=lambda f: f,
178
- operand=(f, f_m),
179
- )
180
-
181
- return (tau, f, f_m), None
182
-
183
- (tau, f, f_m), _ = jax.lax.scan(
184
- f=backward_pass,
185
- init=backward_pass_carry,
186
- xs=np.flip(np.arange(start=1, stop=model.NB)),
187
- )
188
-
189
- # Handle 1 DoF models
190
- tau = jnp.atleast_1d(tau.squeeze())
191
- tau = jnp.vstack(tau) if tau.size > 0 else jnp.empty(shape=(0, 1))
192
-
193
- # Express the base 6D force in the world frame
194
- W_f0 = B_X_W.T @ jnp.vstack(f[0])
195
-
196
- return W_f0, tau