jaxsim 0.2.dev188__py3-none-any.whl → 0.2.dev364__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81):
  1. jaxsim/__init__.py +3 -4
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +13 -2
  6. jaxsim/api/contact.py +120 -43
  7. jaxsim/api/data.py +112 -71
  8. jaxsim/api/joint.py +77 -36
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +150 -75
  11. jaxsim/api/model.py +542 -269
  12. jaxsim/api/ode.py +88 -72
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +12 -11
  15. jaxsim/integrators/__init__.py +2 -2
  16. jaxsim/integrators/common.py +110 -24
  17. jaxsim/integrators/fixed_step.py +11 -67
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +93 -0
  25. jaxsim/parsers/descriptions/collision.py +14 -0
  26. jaxsim/parsers/descriptions/link.py +13 -2
  27. jaxsim/parsers/kinematic_graph.py +5 -0
  28. jaxsim/parsers/rod/utils.py +7 -8
  29. jaxsim/rbda/__init__.py +7 -0
  30. jaxsim/rbda/aba.py +295 -0
  31. jaxsim/rbda/collidable_points.py +142 -0
  32. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  33. jaxsim/rbda/forward_kinematics.py +113 -0
  34. jaxsim/rbda/jacobian.py +201 -0
  35. jaxsim/rbda/rnea.py +237 -0
  36. jaxsim/rbda/soft_contacts.py +296 -0
  37. jaxsim/rbda/utils.py +152 -0
  38. jaxsim/terrain/__init__.py +2 -0
  39. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  40. jaxsim/utils/__init__.py +1 -4
  41. jaxsim/utils/hashless.py +18 -0
  42. jaxsim/utils/jaxsim_dataclass.py +281 -30
  43. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
  44. jaxsim-0.2.dev364.dist-info/RECORD +64 -0
  45. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
  46. jaxsim/high_level/__init__.py +0 -2
  47. jaxsim/high_level/common.py +0 -11
  48. jaxsim/high_level/joint.py +0 -148
  49. jaxsim/high_level/link.py +0 -259
  50. jaxsim/high_level/model.py +0 -1686
  51. jaxsim/math/conv.py +0 -114
  52. jaxsim/math/joint.py +0 -102
  53. jaxsim/math/plucker.py +0 -100
  54. jaxsim/physics/__init__.py +0 -12
  55. jaxsim/physics/algos/__init__.py +0 -0
  56. jaxsim/physics/algos/aba.py +0 -254
  57. jaxsim/physics/algos/aba_motors.py +0 -284
  58. jaxsim/physics/algos/forward_kinematics.py +0 -79
  59. jaxsim/physics/algos/jacobian.py +0 -98
  60. jaxsim/physics/algos/rnea.py +0 -180
  61. jaxsim/physics/algos/rnea_motors.py +0 -196
  62. jaxsim/physics/algos/soft_contacts.py +0 -523
  63. jaxsim/physics/algos/utils.py +0 -69
  64. jaxsim/physics/model/__init__.py +0 -0
  65. jaxsim/physics/model/ground_contact.py +0 -55
  66. jaxsim/physics/model/physics_model.py +0 -388
  67. jaxsim/physics/model/physics_model_state.py +0 -283
  68. jaxsim/simulation/__init__.py +0 -4
  69. jaxsim/simulation/integrators.py +0 -393
  70. jaxsim/simulation/ode.py +0 -290
  71. jaxsim/simulation/ode_data.py +0 -96
  72. jaxsim/simulation/ode_integration.py +0 -62
  73. jaxsim/simulation/simulator.py +0 -543
  74. jaxsim/simulation/simulator_callbacks.py +0 -79
  75. jaxsim/simulation/utils.py +0 -15
  76. jaxsim/sixd/__init__.py +0 -2
  77. jaxsim/utils/oop.py +0 -536
  78. jaxsim/utils/vmappable.py +0 -117
  79. jaxsim-0.2.dev188.dist-info/RECORD +0 -81
  80. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
  81. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
@@ -1,196 +0,0 @@
1
- from typing import Tuple
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- import numpy as np
6
-
7
- import jaxsim.typing as jtp
8
- from jaxsim.math.adjoint import Adjoint
9
- from jaxsim.math.cross import Cross
10
- from jaxsim.physics.model.physics_model import PhysicsModel
11
-
12
- from . import utils
13
-
14
-
15
def rnea(
    model: PhysicsModel,
    xfb: jtp.Vector,
    q: jtp.Vector,
    qd: jtp.Vector,
    qdd: jtp.Vector,
    a0fb: jtp.Vector = jnp.zeros(6),
    f_ext: jtp.Matrix | None = None,
) -> Tuple[jtp.Vector, jtp.Vector]:
    """
    Recursive Newton-Euler Algorithm (RNEA) for inverse dynamics, extended with
    a simple joint-motor model (gear ratios and motor inertias).

    Args:
        model: The physics model.
        xfb: The base state. Entries [0:4] are used as the base quaternion,
            [4:7] as the base position, and [7:13] as the base velocity, which
            is reordered as ``[xfb[10:13], xfb[7:10]]`` to build the 6D base
            velocity (presumably angular-first; TODO confirm the convention).
            The velocity is only used for floating-base models.
        q: The joint positions.
        qd: The joint velocities.
        qdd: The joint accelerations.
        a0fb: The 6D base acceleration, only used for floating-base models.
            It must contain exactly 6 elements after squeezing.
        f_ext: Optional external 6D forces, one per link, processed by
            ``utils.process_inputs`` (presumably zero-filled when None).

    Returns:
        A tuple ``(W_f0, tau)``: the 6D force acting on the base expressed in
        the world frame, and the joint torques as a column vector (with shape
        ``(0, 1)`` for models without joints).

    Raises:
        ValueError: If ``a0fb`` does not contain exactly 6 elements.
    """

    xfb, q, qd, qdd, _, f_ext = utils.process_inputs(
        physics_model=model, xfb=xfb, q=q, qd=qd, qdd=qdd, f_ext=f_ext
    )

    a0fb = a0fb.squeeze()
    gravity = model.gravity.squeeze()

    # Validate the whole shape: checking only shape[0] raised IndexError on
    # scalars and silently accepted arrays such as (6, k).
    if a0fb.shape != (6,):
        raise ValueError(a0fb.shape)

    M = model.spatial_inertias
    pre_X_λi = model.tree_transforms
    i_X_pre = model.joint_transforms(q=q)
    S = model.motion_subspaces(q=q)
    i_X_λi = jnp.zeros_like(pre_X_λi)

    # Motor parameters have one entry per joint (Γ is broadcast against S[1:]),
    # so they must be indexed with the joint index ii = i - 1, not with the
    # link index i. Out-of-bounds JAX gathers are clamped, so a wrong index
    # does not raise: it silently reads the parameters of another joint.
    Γ = jnp.array([*model._joint_motor_gear_ratio.values()])
    IM = jnp.array([*model._joint_motor_inertia.values()])

    # NOTE: the motor viscous friction (model._joint_motor_viscous_friction)
    # is not part of this model and does not contribute to the torques.

    # Motor motion subspaces scaled by the gear ratios. The base entry S[0] is
    # kept so that m_S can be indexed by link index, exactly like S.
    m_S = jnp.concatenate([S[:1], S[1:] * Γ[:, None, None]], axis=0)

    i_X_0 = jnp.zeros_like(pre_X_λi)
    i_X_0 = i_X_0.at[0].set(jnp.eye(6))

    # Parent array mapping: i -> λ(i).
    # Exception: λ(0) must not be used, it's initialized to -1.
    λ = model.parent_array()

    # Per-link 6D velocities, accelerations and forces (column vectors) of the
    # links and of the motors.
    v = jnp.zeros(shape=(model.NB, 6, 1))
    a = jnp.zeros(shape=(model.NB, 6, 1))
    f = jnp.zeros(shape=(model.NB, 6, 1))

    v_m = jnp.zeros(shape=(model.NB, 6, 1))
    a_m = jnp.zeros(shape=(model.NB, 6, 1))
    f_m = jnp.zeros(shape=(model.NB, 6, 1))

    # 6D transform of base velocity
    B_X_W = Adjoint.from_quaternion_and_translation(
        quaternion=xfb[0:4],
        translation=xfb[4:7],
        inverse=True,
        normalize_quaternion=True,
    )
    i_X_λi = i_X_λi.at[0].set(B_X_W)

    # Gravity enters the recursion as a fictitious base acceleration.
    a_0 = -B_X_W @ jnp.vstack(gravity)
    a = a.at[0].set(a_0)

    if model.is_floating_base:
        W_v_WB = jnp.vstack(jnp.hstack([xfb[10:13], xfb[7:10]]))

        v_0 = B_X_W @ W_v_WB
        v = v.at[0].set(v_0)

        a_0 = B_X_W @ (jnp.vstack(a0fb) - jnp.vstack(gravity))
        a = a.at[0].set(a_0)

        f_0 = (
            M[0] @ a[0]
            + Cross.vx_star(v[0]) @ M[0] @ v[0]
            - Adjoint.inverse(B_X_W).T @ jnp.vstack(f_ext[0])
        )
        f = f.at[0].set(f_0)

    ForwardPassCarry = Tuple[
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
        jtp.MatrixJax,
    ]
    forward_pass_carry = (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m)

    def forward_pass(
        carry: ForwardPassCarry, i: jtp.Int
    ) -> Tuple[ForwardPassCarry, None]:
        # i is the link index, ii the index of the joint connecting it to its
        # parent (used for q/qd/qdd and the per-joint motor parameters).
        ii = i - 1
        i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m = carry

        vJ = S[i] * qd[ii]
        vJ_m = m_S[i] * qd[ii]

        i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
        i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

        v_i = i_X_λi[i] @ v[λ[i]] + vJ
        v = v.at[i].set(v_i)

        v_i_m = i_X_λi[i] @ v_m[λ[i]] + vJ_m
        v_m = v_m.at[i].set(v_i_m)

        a_i = i_X_λi[i] @ a[λ[i]] + S[i] * qdd[ii] + Cross.vx(v[i]) @ vJ
        a = a.at[i].set(a_i)

        a_i_m = i_X_λi[i] @ a_m[λ[i]] + m_S[i] * qdd[ii] + Cross.vx(v_m[i]) @ vJ_m
        a_m = a_m.at[i].set(a_i_m)

        i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
        i_X_0 = i_X_0.at[i].set(i_X_0_i)
        i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T

        f_i = (
            M[i] @ a[i]
            + Cross.vx_star(v[i]) @ M[i] @ v[i]
            - i_Xf_W @ jnp.vstack(f_ext[i])
        )
        f = f.at[i].set(f_i)

        # Motor inertias are per joint: index with ii (see the note above).
        f_i_m = IM[ii] * a_m[i] + (Cross.vx_star(v_m[i]) * IM[ii]) @ v_m[i]
        f_m = f_m.at[i].set(f_i_m)

        return (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), None

    (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), _ = jax.lax.scan(
        f=forward_pass,
        init=forward_pass_carry,
        xs=np.arange(start=1, stop=model.NB),
    )

    tau = jnp.zeros_like(q)

    BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax]
    backward_pass_carry = (tau, f, f_m)

    def backward_pass(
        carry: BackwardPassCarry, i: jtp.Int
    ) -> Tuple[BackwardPassCarry, None]:
        ii = i - 1
        tau, f, f_m = carry

        # Joint torque: link force projected on S plus motor force projected
        # on the gear-scaled subspace m_S.
        value = S[i].T @ f[i] + m_S[i].T @ f_m[i]
        tau = tau.at[ii].set(value.squeeze())

        def update_f(
            ffm: Tuple[jtp.MatrixJax, jtp.MatrixJax]
        ) -> Tuple[jtp.MatrixJax, jtp.MatrixJax]:
            f, f_m = ffm
            f_λi = f[λ[i]] + i_X_λi[i].T @ f[i]
            f = f.at[λ[i]].set(f_λi)

            f_m_λi = f_m[λ[i]] + i_X_λi[i].T @ f_m[i]
            f_m = f_m.at[λ[i]].set(f_m_λi)
            return f, f_m

        # Accumulate the forces into the parent link, except into the base of
        # a fixed-base model.
        f, f_m = jax.lax.cond(
            pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(),
            true_fun=update_f,
            false_fun=lambda ffm: ffm,
            operand=(f, f_m),
        )

        return (tau, f, f_m), None

    (tau, f, f_m), _ = jax.lax.scan(
        f=backward_pass,
        init=backward_pass_carry,
        xs=np.flip(np.arange(start=1, stop=model.NB)),
    )

    # Handle 1 DoF models
    tau = jnp.atleast_1d(tau.squeeze())
    tau = jnp.vstack(tau) if tau.size > 0 else jnp.empty(shape=(0, 1))

    # Express the base 6D force in the world frame
    W_f0 = B_X_W.T @ jnp.vstack(f[0])

    return W_f0, tau