jaxsim 0.2.dev191__py3-none-any.whl → 0.2.dev366__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +3 -4
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +13 -2
- jaxsim/api/contact.py +120 -43
- jaxsim/api/data.py +112 -71
- jaxsim/api/joint.py +77 -36
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +150 -75
- jaxsim/api/model.py +542 -269
- jaxsim/api/ode.py +86 -74
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +12 -11
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +110 -24
- jaxsim/integrators/fixed_step.py +11 -67
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +93 -0
- jaxsim/parsers/descriptions/link.py +2 -2
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -30
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/METADATA +4 -6
- jaxsim-0.2.dev366.dist-info/RECORD +64 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- /jaxsim/{physics/algos → terrain}/terrain.py +0 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev366.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,296 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
|
5
|
+
import jax
|
6
|
+
import jax.numpy as jnp
|
7
|
+
import jax_dataclasses
|
8
|
+
|
9
|
+
import jaxsim.api as js
|
10
|
+
import jaxsim.typing as jtp
|
11
|
+
from jaxsim.math import Skew, StandardGravity
|
12
|
+
from jaxsim.terrain import FlatTerrain, Terrain
|
13
|
+
from jaxsim.utils import JaxsimDataclass
|
14
|
+
|
15
|
+
|
16
|
+
@jax_dataclasses.pytree_dataclass
class SoftContactsParams(JaxsimDataclass):
    """
    Parameters of the soft contacts model.

    Attributes:
        K: Stiffness of the non-linear spring-damper contact model.
        D: Damping of the non-linear spring-damper contact model.
        mu: Static friction coefficient between the model and the terrain.
    """

    K: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1_000_000.0, dtype=float)
    )

    D: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(2_000.0, dtype=float)
    )

    mu: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    @staticmethod
    def build(
        K: jtp.FloatLike = 1e6, D: jtp.FloatLike = 2_000, mu: jtp.FloatLike = 0.5
    ) -> SoftContactsParams:
        """
        Build a SoftContactsParams instance from the given values.

        Every value is converted to a float JAX array before being stored.

        Args:
            K: The stiffness parameter.
            D: The damping parameter of the soft contacts model.
            mu: The static friction coefficient.

        Returns:
            A SoftContactsParams instance holding the given parameters.
        """

        raw_values = {"K": K, "D": D, "mu": mu}

        return SoftContactsParams(
            **{
                field_name: jnp.array(field_value, dtype=float)
                for field_name, field_value in raw_values.items()
            }
        )

    @staticmethod
    def build_default_from_jaxsim_model(
        model: js.model.JaxSimModel,
        *,
        standard_gravity: jtp.FloatLike = StandardGravity,
        static_friction_coefficient: jtp.FloatLike = 0.5,
        max_penetration: jtp.FloatLike = 0.001,
        number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
        damping_ratio: jtp.FloatLike = 1.0,
    ) -> SoftContactsParams:
        """
        Build a SoftContactsParams instance tuned for the given model.

        The stiffness is chosen so that, at rest, the weight of the model shared
        among the active collidable points produces the requested penetration.
        The damping is the critical damping scaled by the damping ratio.

        Args:
            model: The target model.
            standard_gravity: The standard gravity constant.
            static_friction_coefficient:
                The static friction coefficient between the model and the terrain.
            max_penetration: The maximum penetration depth.
            number_of_active_collidable_points_steady_state:
                The number of contacts supporting the weight of the model
                in steady state.
            damping_ratio: The ratio controlling the damping behavior.

        Returns:
            A `SoftContactsParams` instance with the computed parameters.

        Note:
            The `damping_ratio` parameter (ξ) selects the damping regime:
            - ξ > 1.0: over-damped
            - ξ = 1.0: critically damped
            - ξ < 1.0: under-damped
        """

        # Total mass of the model: the sum of the masses of all its links.
        total_mass = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum()

        # Weight of the model shared evenly among the collidable points that are
        # expected to be in contact at steady state.
        support_force_per_point = (
            total_mass
            * standard_gravity
            / number_of_active_collidable_points_steady_state
        )

        # The soft-contact normal force is sqrt(δ)·(K·δ + D·δ̇), i.e. K·δ^(3/2) at
        # rest. This stiffness makes the rest penetration equal `max_penetration`.
        stiffness = support_force_per_point / jnp.power(max_penetration, 3 / 2)

        # Scale the critical damping 2·sqrt(K·m) by the requested damping ratio.
        damping = damping_ratio * (2 * jnp.sqrt(stiffness * total_mass))

        return SoftContactsParams.build(
            K=stiffness, D=damping, mu=static_friction_coefficient
        )
|
110
|
+
|
111
|
+
|
112
|
+
@jax_dataclasses.pytree_dataclass
class SoftContacts:
    """
    Soft contacts model.

    The normal force comes from a non-linear spring-damper model acting along
    the terrain normal. When the friction coefficient is non-zero, the tangential
    force is computed from the tangential deformation of the material. If that
    force lies outside the friction cone, it is projected onto the cone boundary.
    """

    # Stiffness (K), damping (D), and friction coefficient (mu) of the model.
    parameters: SoftContactsParams = dataclasses.field(
        default_factory=SoftContactsParams
    )

    # Terrain queried for the height and the normal at the contact location.
    terrain: Terrain = dataclasses.field(default_factory=FlatTerrain)

    def contact_model(
        self,
        position: jtp.Vector,
        velocity: jtp.Vector,
        tangential_deformation: jtp.Vector,
    ) -> tuple[jtp.Vector, jtp.Vector]:
        """
        Compute the contact forces and material deformation rate.

        Args:
            position: The 3D position of the collidable point in the world frame.
            velocity: The 3D linear velocity of the collidable point in the world frame.
            tangential_deformation: The 3D tangential deformation of the material.

        Returns:
            A tuple containing the 6D contact force (linear-angular components,
            inertial representation) and the 3D material deformation rate.
        """

        # Short name of parameters
        K = self.parameters.K
        D = self.parameters.D
        μ = self.parameters.mu

        # Material 3D tangential deformation and its derivative
        m = tangential_deformation.squeeze()
        ṁ = jnp.zeros_like(m)

        # Note: all the small hardcoded tolerances in this method have been introduced
        # to allow jax differentiating through this algorithm. They should not affect
        # the accuracy of the simulation, although they might make it less readable.

        # ========================
        # Normal force computation
        # ========================

        # Unpack the position of the collidable point.
        # The unpacking also requires both vectors to have exactly 3 elements.
        px, py, pz = W_p_C = position.squeeze()
        vx, vy, vz = W_ṗ_C = velocity.squeeze()

        # Compute the terrain normal and the contact depth
        n̂ = self.terrain.normal(x=px, y=py).squeeze()
        h = jnp.array([0, 0, self.terrain.height(x=px, y=py) - pz])

        # Compute the penetration depth normal to the terrain.
        # It is clamped to zero when the point is not penetrating.
        δ = jnp.maximum(0.0, jnp.dot(h, n̂))

        # Compute the penetration normal velocity.
        # It is positive when the point moves against the terrain normal.
        δ̇ = -jnp.dot(W_ṗ_C, n̂)

        # Non-linear spring-damper model.
        # This is the force magnitude along the direction normal to the terrain.
        force_normal_mag = jax.lax.select(
            pred=δ >= 1e-9,
            on_true=jnp.sqrt(δ + 1e-12) * (K * δ + D * δ̇),
            on_false=jnp.array(0.0),
        )

        # Prevent negative normal forces that might occur when δ̇ is largely negative
        force_normal_mag = jnp.maximum(0.0, force_normal_mag)

        # Compute the 3D linear force in C[W] frame
        force_normal = force_normal_mag * n̂

        # ====================================
        # No friction and no tangential forces
        # ====================================

        # Compute the adjoint C[W]->W for transforming 6D forces from mixed to inertial.
        # Note: this is equal to the 6D velocities transform: CW_X_W.transpose().
        W_Xf_CW = jnp.vstack(
            [
                jnp.block([jnp.eye(3), jnp.zeros(shape=(3, 3))]),
                jnp.block([Skew.wedge(W_p_C), jnp.eye(3)]),
            ]
        )

        def with_no_friction():
            # Frictionless case: only the normal force is applied and the
            # tangential deformation rate stays zero.

            # Compute 6D mixed force in C[W]
            CW_f_lin = force_normal
            CW_f = jnp.hstack([force_normal, jnp.zeros_like(CW_f_lin)])

            # Compute lin-ang 6D forces (inertial representation)
            W_f = W_Xf_CW @ CW_f

            return W_f, ṁ

        # =========================
        # Compute tangential forces
        # =========================

        def with_friction():
            # Initialize the tangential deformation rate ṁ.
            # For inactive contacts with m≠0, this is the dynamics of the material
            # relaxation converging exponentially to steady state.
            ṁ = (-K / D) * m

            # Check if the collidable point is below ground.
            # Note: when δ=0, we consider the point still not in contact such that
            # we prevent divisions by 0 in the computations below.
            active_contact = pz < self.terrain.height(x=px, y=py)

            def above_terrain():
                # No contact force: the material only relaxes.
                return jnp.zeros(6), ṁ

            def below_terrain():
                # Decompose the velocity in normal and tangential components
                v_normal = jnp.dot(W_ṗ_C, n̂) * n̂
                v_tangential = W_ṗ_C - v_normal

                # Compute the tangential force. If it lies inside the friction cone,
                # the contact is sticking and this force is applied unchanged.
                # Otherwise the contact is slipping and the force is projected
                # onto the cone boundary (see `slipping_contact`).
                f_tangential = -jnp.sqrt(δ + 1e-12) * (K * m + D * v_tangential)

                def sticking_contact():
                    # Sum the normal and tangential forces, and create the 6D force
                    CW_f_stick = force_normal + f_tangential
                    CW_f = jnp.hstack([CW_f_stick, jnp.zeros(3)])

                    # In this case the 3D material deformation is the tangential velocity
                    ṁ = v_tangential

                    # Return the 6D force in the contact frame and
                    # the deformation derivative
                    return CW_f, ṁ

                def slipping_contact():
                    # Project the force to the friction cone boundary
                    f_tangential_projected = (μ * force_normal_mag) * (
                        f_tangential / jnp.maximum(jnp.linalg.norm(f_tangential), 1e-9)
                    )

                    # Sum the normal and tangential forces, and create the 6D force
                    CW_f_slip = force_normal + f_tangential_projected
                    CW_f = jnp.hstack([CW_f_slip, jnp.zeros(3)])

                    # Correct the material deformation derivative for slipping contacts.
                    # Basically we compute ṁ such that we get `f_tangential` on the cone
                    # given the current (m, δ).
                    # The tangential model is f = α·m + β·ṁ, which is solved for ṁ.
                    ε = 1e-9
                    δε = jnp.maximum(δ, ε)
                    α = -K * jnp.sqrt(δε)
                    β = -D * jnp.sqrt(δε)
                    ṁ = (f_tangential_projected - α * m) / β

                    # Return the 6D force in the contact frame and
                    # the deformation derivative
                    return CW_f, ṁ

                # Slipping when |f_tangential|² > (μ·|f_normal|)², i.e. outside the cone.
                CW_f, ṁ = jax.lax.cond(
                    pred=f_tangential.dot(f_tangential) > (μ * force_normal_mag) ** 2,
                    true_fun=lambda _: slipping_contact(),
                    false_fun=lambda _: sticking_contact(),
                    operand=None,
                )

                # Express the 6D force in the world frame
                W_f = W_Xf_CW @ CW_f

                # Return the 6D force in the world frame and the deformation derivative
                return W_f, ṁ

            # (W_f, ṁ)
            return jax.lax.cond(
                pred=active_contact,
                true_fun=lambda _: below_terrain(),
                false_fun=lambda _: above_terrain(),
                operand=None,
            )

        # Dispatch on the friction coefficient: a zero μ disables tangential forces.
        # (W_f, ṁ)
        return jax.lax.cond(
            pred=(μ == 0.0),
            true_fun=lambda _: with_no_friction(),
            false_fun=lambda _: with_friction(),
            operand=None,
        )
|
jaxsim/rbda/utils.py
ADDED
@@ -0,0 +1,152 @@
|
|
1
|
+
import jax.numpy as jnp
|
2
|
+
|
3
|
+
import jaxsim.api as js
|
4
|
+
import jaxsim.typing as jtp
|
5
|
+
from jaxsim.math import StandardGravity
|
6
|
+
|
7
|
+
|
8
|
+
def _squeeze_or_default(
    value: jtp.VectorLike | jtp.MatrixLike | None,
    *,
    default: jnp.ndarray,
    ndmin: int = 1,
) -> jnp.ndarray:
    """
    Convert an optional array-like input to a squeezed JAX array.

    Args:
        value: The input to convert. Any array-like is accepted, including
            plain Python tuples and lists.
        default: The array returned when `value` is None.
        ndmin: The minimum number of dimensions of the result (1 or 2).

    Returns:
        `default` if `value` is None. Otherwise `value` with its singleton
        dimensions removed and then promoted back to at least `ndmin` dimensions.
    """

    if value is None:
        return default

    # `jnp.asarray` makes plain Python sequences (e.g. tuples) behave like
    # arrays; calling `.squeeze()` on them directly would raise AttributeError.
    array = jnp.asarray(value).squeeze()

    return jnp.atleast_2d(array) if ndmin == 2 else jnp.atleast_1d(array)


def _check_shape(name: str, array: jnp.ndarray, expected_shape: tuple) -> None:
    """Raise a ValueError naming `name` if `array` does not have `expected_shape`."""

    if array.shape != expected_shape:
        raise ValueError(
            f"Wrong shape of '{name}': expected {expected_shape}, got {array.shape}"
        )


def process_inputs(
    model: js.model.JaxSimModel,
    *,
    base_position: jtp.VectorLike | None = None,
    base_quaternion: jtp.VectorLike | None = None,
    joint_positions: jtp.VectorLike | None = None,
    base_linear_velocity: jtp.VectorLike | None = None,
    base_angular_velocity: jtp.VectorLike | None = None,
    joint_velocities: jtp.VectorLike | None = None,
    base_linear_acceleration: jtp.VectorLike | None = None,
    base_angular_acceleration: jtp.VectorLike | None = None,
    joint_accelerations: jtp.VectorLike | None = None,
    joint_forces: jtp.VectorLike | None = None,
    link_forces: jtp.MatrixLike | None = None,
    standard_gravity: jtp.FloatLike | None = None,
) -> tuple[
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Vector,
    jtp.Matrix,
    jtp.Vector,
]:
    """
    Adjust the inputs to rigid-body dynamics algorithms.

    Missing inputs are replaced by defaults: zeros, the identity quaternion
    [1, 0, 0, 0], and `StandardGravity`. Singleton dimensions are removed and
    every shape is validated against the model.

    Args:
        model: The model to consider.
        base_position: The position of the base link.
        base_quaternion: The quaternion of the base link.
        joint_positions: The positions of the joints.
        base_linear_velocity: The linear velocity of the base link.
        base_angular_velocity: The angular velocity of the base link.
        joint_velocities: The velocities of the joints.
        base_linear_acceleration: The linear acceleration of the base link.
        base_angular_acceleration: The angular acceleration of the base link.
        joint_accelerations: The accelerations of the joints.
        joint_forces: The forces applied to the joints.
        link_forces: The forces applied to the links, one 6D row per link.
        standard_gravity: The standard gravity constant (a scalar).

    Returns:
        The adjusted inputs, all cast to float, in this order:
        base position (3,), base quaternion (4,), joint positions (dofs,),
        6D base velocity [linear, angular] (6,), joint velocities (dofs,),
        6D base acceleration [linear, angular] (6,), joint accelerations (dofs,),
        joint forces (dofs,), link forces (nl, 6), and 6D gravity (6,).

    Raises:
        ValueError: If any input does not have the shape expected by the model.
    """

    dofs = model.dofs()
    nl = model.number_of_links()

    # Joint positions, velocities, accelerations, and forces.
    s = _squeeze_or_default(joint_positions, default=jnp.zeros(dofs))
    ṡ = _squeeze_or_default(joint_velocities, default=jnp.zeros(dofs))
    s̈ = _squeeze_or_default(joint_accelerations, default=jnp.zeros(dofs))
    τ = _squeeze_or_default(joint_forces, default=jnp.zeros(dofs))

    # Floating-base position and orientation.
    W_p_B = _squeeze_or_default(base_position, default=jnp.zeros(3))
    W_Q_B = _squeeze_or_default(base_quaternion, default=jnp.array([1.0, 0, 0, 0]))

    # Floating-base velocity and acceleration in inertial-fixed representation.
    W_vl_WB = _squeeze_or_default(base_linear_velocity, default=jnp.zeros(3))
    W_ω_WB = _squeeze_or_default(base_angular_velocity, default=jnp.zeros(3))
    W_v̇l_WB = _squeeze_or_default(base_linear_acceleration, default=jnp.zeros(3))
    W_ω̇_WB = _squeeze_or_default(base_angular_acceleration, default=jnp.zeros(3))

    # External 6D link forces. `ndmin=2` keeps the (1, 6) shape for single-link models.
    f = _squeeze_or_default(link_forces, default=jnp.zeros(shape=(nl, 6)), ndmin=2)

    # Gravity magnitude. It must be a scalar because it is written to a single
    # component of the 6D gravity vector below.
    g = (
        jnp.array(standard_gravity).squeeze()
        if standard_gravity is not None
        else jnp.array(StandardGravity)
    )

    # Validate all the shapes.
    _check_shape("joint_positions", s, (dofs,))
    _check_shape("joint_velocities", ṡ, (dofs,))
    _check_shape("joint_accelerations", s̈, (dofs,))
    _check_shape("joint_forces", τ, (dofs,))
    _check_shape("base_position", W_p_B, (3,))
    _check_shape("base_linear_velocity", W_vl_WB, (3,))
    _check_shape("base_angular_velocity", W_ω_WB, (3,))
    _check_shape("base_linear_acceleration", W_v̇l_WB, (3,))
    _check_shape("base_angular_acceleration", W_ω̇_WB, (3,))
    _check_shape("link_forces", f, (nl, 6))
    _check_shape("base_quaternion", W_Q_B, (4,))
    _check_shape("standard_gravity", g, ())

    # Pack the 6D base velocity and acceleration.
    W_v_WB = jnp.hstack([W_vl_WB, W_ω_WB])
    W_v̇_WB = jnp.hstack([W_v̇l_WB, W_ω̇_WB])

    # Create the 6D gravity acceleration, pointing along the negative z axis.
    W_g = jnp.zeros(6).at[2].set(-g)

    return (
        W_p_B.astype(float),
        W_Q_B.astype(float),
        s.astype(float),
        W_v_WB.astype(float),
        ṡ.astype(float),
        W_v̇_WB.astype(float),
        s̈.astype(float),
        τ.astype(float),
        f.astype(float),
        W_g.astype(float),
    )
|
jaxsim/utils/__init__.py
CHANGED
@@ -1,8 +1,5 @@
|
|
1
1
|
from jax_dataclasses._copy_and_mutate import _Mutability as Mutability
|
2
2
|
|
3
|
+
from .hashless import HashlessObject
|
3
4
|
from .jaxsim_dataclass import JaxsimDataclass
|
4
5
|
from .tracing import not_tracing, tracing
|
5
|
-
from .vmappable import Vmappable
|
6
|
-
|
7
|
-
# Leave this below the others to prevent circular imports
|
8
|
-
from .oop import jax_tf # isort: skip
|
jaxsim/utils/hashless.py
ADDED
@@ -0,0 +1,18 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
from typing import Generic, TypeVar
|
5
|
+
|
6
|
+
T = TypeVar("T")


@dataclasses.dataclass
class HashlessObject(Generic[T]):
    """
    Container for an arbitrary object whose hash is always zero.

    The hash never looks at the wrapped object, so objects that are not
    hashable themselves can still be wrapped and hashed. Equality is the
    dataclass-generated comparison of the wrapped `obj`.
    """

    # The wrapped object.
    obj: T

    def __hash__(self) -> int:
        # Constant hash, independent of the wrapped object.
        constant_hash = 0
        return constant_hash

    def get(self: HashlessObject[T]) -> T:
        """Return the wrapped object."""

        wrapped = self.obj
        return wrapped
|