jaxsim 0.4.3.dev68__py3-none-any.whl → 0.4.3.dev70__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +5 -0
- jaxsim/_version.py +2 -2
- jaxsim/api/contact.py +27 -1
- jaxsim/api/data.py +68 -20
- jaxsim/api/joint.py +62 -2
- jaxsim/api/model.py +37 -23
- jaxsim/api/ode.py +26 -24
- jaxsim/api/ode_data.py +11 -1
- jaxsim/integrators/common.py +1 -1
- jaxsim/math/inertia.py +1 -1
- jaxsim/mujoco/loaders.py +3 -3
- jaxsim/parsers/kinematic_graph.py +3 -3
- jaxsim/parsers/rod/parser.py +18 -14
- jaxsim/rbda/contacts/relaxed_rigid.py +384 -0
- jaxsim/rbda/contacts/rigid.py +11 -41
- jaxsim/terrain/terrain.py +41 -25
- jaxsim/typing.py +1 -1
- jaxsim/utils/jaxsim_dataclass.py +12 -9
- jaxsim/utils/wrappers.py +1 -1
- {jaxsim-0.4.3.dev68.dist-info → jaxsim-0.4.3.dev70.dist-info}/METADATA +2 -1
- {jaxsim-0.4.3.dev68.dist-info → jaxsim-0.4.3.dev70.dist-info}/RECORD +24 -23
- {jaxsim-0.4.3.dev68.dist-info → jaxsim-0.4.3.dev70.dist-info}/WHEEL +1 -1
- {jaxsim-0.4.3.dev68.dist-info → jaxsim-0.4.3.dev70.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev68.dist-info → jaxsim-0.4.3.dev70.dist-info}/top_level.txt +0 -0
jaxsim/parsers/rod/parser.py
CHANGED
@@ -223,7 +223,7 @@ def extract_model_data(
|
|
223
223
|
child=links_dict[j.child],
|
224
224
|
jtype=utils.joint_to_joint_type(joint=j),
|
225
225
|
axis=(
|
226
|
-
np.array(j.axis.xyz.xyz)
|
226
|
+
np.array(j.axis.xyz.xyz, dtype=float)
|
227
227
|
if j.axis is not None
|
228
228
|
and j.axis.xyz is not None
|
229
229
|
and j.axis.xyz.xyz is not None
|
@@ -232,39 +232,43 @@ def extract_model_data(
|
|
232
232
|
pose=j.pose.transform() if j.pose is not None else np.eye(4),
|
233
233
|
initial_position=0.0,
|
234
234
|
position_limit=(
|
235
|
-
(
|
236
|
-
|
237
|
-
if j.axis is not None
|
238
|
-
|
235
|
+
float(
|
236
|
+
j.axis.limit.lower
|
237
|
+
if j.axis is not None
|
238
|
+
and j.axis.limit is not None
|
239
|
+
and j.axis.limit.lower is not None
|
240
|
+
else jnp.finfo(float).min
|
239
241
|
),
|
240
|
-
(
|
241
|
-
|
242
|
-
if j.axis is not None
|
243
|
-
|
242
|
+
float(
|
243
|
+
j.axis.limit.upper
|
244
|
+
if j.axis is not None
|
245
|
+
and j.axis.limit is not None
|
246
|
+
and j.axis.limit.upper is not None
|
247
|
+
else jnp.finfo(float).max
|
244
248
|
),
|
245
249
|
),
|
246
|
-
friction_static=(
|
250
|
+
friction_static=float(
|
247
251
|
j.axis.dynamics.friction
|
248
252
|
if j.axis is not None
|
249
253
|
and j.axis.dynamics is not None
|
250
254
|
and j.axis.dynamics.friction is not None
|
251
255
|
else 0.0
|
252
256
|
),
|
253
|
-
friction_viscous=(
|
257
|
+
friction_viscous=float(
|
254
258
|
j.axis.dynamics.damping
|
255
259
|
if j.axis is not None
|
256
260
|
and j.axis.dynamics is not None
|
257
261
|
and j.axis.dynamics.damping is not None
|
258
262
|
else 0.0
|
259
263
|
),
|
260
|
-
position_limit_damper=(
|
264
|
+
position_limit_damper=float(
|
261
265
|
j.axis.limit.dissipation
|
262
266
|
if j.axis is not None
|
263
267
|
and j.axis.limit is not None
|
264
268
|
and j.axis.limit.dissipation is not None
|
265
269
|
else 0.0
|
266
270
|
),
|
267
|
-
position_limit_spring=(
|
271
|
+
position_limit_spring=float(
|
268
272
|
j.axis.limit.stiffness
|
269
273
|
if j.axis is not None
|
270
274
|
and j.axis.limit is not None
|
@@ -273,7 +277,7 @@ def extract_model_data(
|
|
273
277
|
),
|
274
278
|
)
|
275
279
|
for j in sdf_model.joints()
|
276
|
-
if j.type in {"revolute", "prismatic", "fixed"}
|
280
|
+
if j.type in {"revolute", "continuous", "prismatic", "fixed"}
|
277
281
|
and j.parent != "world"
|
278
282
|
and j.child in links_dict.keys()
|
279
283
|
]
|
@@ -0,0 +1,384 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
import jax
|
7
|
+
import jax.numpy as jnp
|
8
|
+
import jax_dataclasses
|
9
|
+
import jaxopt
|
10
|
+
|
11
|
+
import jaxsim.api as js
|
12
|
+
import jaxsim.typing as jtp
|
13
|
+
from jaxsim.api.common import VelRepr
|
14
|
+
from jaxsim.math import Adjoint
|
15
|
+
from jaxsim.terrain.terrain import FlatTerrain, Terrain
|
16
|
+
|
17
|
+
from .common import ContactModel, ContactsParams, ContactsState
|
18
|
+
|
19
|
+
|
20
|
+
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContactsParams(ContactsParams):
    """Parameters of the relaxed rigid contacts model."""

    # Time constant, used to compute the default stiffness and damping
    # when `stiffness` and `damping` are not negative.
    time_constant: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.01, dtype=float)
    )

    # Adimensional damping coefficient
    damping_coefficient: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1.0, dtype=float)
    )

    # Minimum impedance
    d_min: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.9, dtype=float)
    )

    # Maximum impedance
    d_max: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.95, dtype=float)
    )

    # Width of the penetration range over which the impedance varies
    # between `d_min` and `d_max`
    width: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0001, dtype=float)
    )

    # Midpoint of the impedance transition (relative to `width`)
    midpoint: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.1, dtype=float)
    )

    # Power exponent of the impedance transition
    power: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1.0, dtype=float)
    )

    # Stiffness (when negative, its opposite is used as a spring coefficient)
    stiffness: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Damping (when negative, its opposite is used as a damper coefficient)
    damping: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Friction coefficient
    mu: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    # Maximum number of iterations of the solver
    max_iterations: jtp.Int = dataclasses.field(
        default_factory=lambda: jnp.array(50, dtype=int)
    )

    # Solver tolerance
    tolerance: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(1e-6, dtype=float)
    )

    def __hash__(self) -> int:

        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                HashedNumpyArray(self.time_constant),
                HashedNumpyArray(self.damping_coefficient),
                HashedNumpyArray(self.d_min),
                HashedNumpyArray(self.d_max),
                HashedNumpyArray(self.width),
                HashedNumpyArray(self.midpoint),
                HashedNumpyArray(self.power),
                HashedNumpyArray(self.stiffness),
                HashedNumpyArray(self.damping),
                HashedNumpyArray(self.mu),
                HashedNumpyArray(self.max_iterations),
                HashedNumpyArray(self.tolerance),
            )
        )

    def __eq__(self, other: RelaxedRigidContactsParams) -> bool:

        # Objects of a different type are never equal (and may be unhashable).
        if not isinstance(other, RelaxedRigidContactsParams):
            return False

        return hash(self) == hash(other)

    @classmethod
    def build(
        cls,
        time_constant: jtp.FloatLike | None = None,
        damping_coefficient: jtp.FloatLike | None = None,
        d_min: jtp.FloatLike | None = None,
        d_max: jtp.FloatLike | None = None,
        width: jtp.FloatLike | None = None,
        midpoint: jtp.FloatLike | None = None,
        power: jtp.FloatLike | None = None,
        stiffness: jtp.FloatLike | None = None,
        damping: jtp.FloatLike | None = None,
        mu: jtp.FloatLike | None = None,
        max_iterations: jtp.IntLike | None = None,
        tolerance: jtp.FloatLike | None = None,
    ) -> RelaxedRigidContactsParams:
        """
        Create a `RelaxedRigidContactsParams` instance.

        Every argument is optional: when it is `None`, the default value of
        the corresponding field is used.

        Args:
            time_constant: The time constant.
            damping_coefficient: The adimensional damping coefficient.
            d_min: The minimum impedance.
            d_max: The maximum impedance.
            width: The width of the impedance transition.
            midpoint: The midpoint of the impedance transition.
            power: The power exponent of the impedance transition.
            stiffness: The stiffness.
            damping: The damping.
            mu: The friction coefficient.
            max_iterations: The maximum number of iterations of the solver.
            tolerance: The solver tolerance.

        Returns:
            A `RelaxedRigidContactsParams` instance, with each value cast to
            the dtype of the corresponding default.
        """

        # Collect the arguments explicitly. Calling `locals()` inside a
        # comprehension does not see the enclosing function's arguments
        # before Python 3.12, so every override would be silently ignored.
        provided = {
            "time_constant": time_constant,
            "damping_coefficient": damping_coefficient,
            "d_min": d_min,
            "d_max": d_max,
            "width": width,
            "midpoint": midpoint,
            "power": power,
            "stiffness": stiffness,
            "damping": damping,
            "mu": mu,
            "max_iterations": max_iterations,
            "tolerance": tolerance,
        }

        # All fields use a `default_factory`, so `Field.default` is
        # `dataclasses.MISSING`. Materialize the defaults from an instance.
        defaults = cls()

        return cls(
            **{
                name: jnp.array(
                    value if value is not None else getattr(defaults, name),
                    dtype=getattr(defaults, name).dtype,
                )
                for name, value in provided.items()
            }
        )

    def valid(self) -> bool:
        """
        Check whether the parameters are within their admissible ranges.

        Returns:
            True if all the parameters are valid, False otherwise.
        """

        return bool(
            jnp.all(self.time_constant >= 0.0)
            and jnp.all(self.damping_coefficient > 0.0)
            and jnp.all(self.d_min >= 0.0)
            and jnp.all(self.d_max <= 1.0)
            and jnp.all(self.d_min <= self.d_max)
            and jnp.all(self.width >= 0.0)
            and jnp.all(self.midpoint >= 0.0)
            and jnp.all(self.power >= 0.0)
            and jnp.all(self.mu >= 0.0)
            and jnp.all(self.max_iterations > 0)
            and jnp.all(self.tolerance > 0.0)
        )
|
149
|
+
|
150
|
+
|
151
|
+
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContactsState(ContactsState):
    """
    State of the relaxed rigid contacts model.

    The class declares no fields: the relaxed rigid contacts model does not
    store any state between calls.
    """

    def __eq__(self, other: RelaxedRigidContactsState) -> bool:
        # Equality is delegated to the hashes of the two objects.
        own_hash = hash(self)
        other_hash = hash(other)
        return own_hash == other_hash

    @staticmethod
    def build() -> RelaxedRigidContactsState:
        """
        Create a `RelaxedRigidContactsState` instance.

        Returns:
            A new (field-less) `RelaxedRigidContactsState`.
        """

        state = RelaxedRigidContactsState()
        return state

    @staticmethod
    def zero(model: js.model.JaxSimModel) -> RelaxedRigidContactsState:
        """
        Build a zero `RelaxedRigidContactsState` instance from a `JaxSimModel`.

        Args:
            model: The model; not needed since the state has no fields.

        Returns:
            A `RelaxedRigidContactsState` instance.
        """

        return RelaxedRigidContactsState.build()

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """
        Check whether the state is valid for the given model.

        Args:
            model: The model to check against.

        Returns:
            Always True, since there is nothing to validate.
        """

        return True
|
171
|
+
|
172
|
+
|
173
|
+
@jax_dataclasses.pytree_dataclass
class RelaxedRigidContacts(ContactModel):
    """
    Relaxed rigid contacts model.

    The linear contact forces of the collidable points are computed by
    minimizing, with L-BFGS, the squared residual of a regularized linear
    system expressed in the space of the linear accelerations of the
    collidable points.
    """

    parameters: RelaxedRigidContactsParams = dataclasses.field(
        default_factory=RelaxedRigidContactsParams
    )

    terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
        default_factory=FlatTerrain
    )

    def compute_contact_forces(
        self,
        position: jtp.Vector,
        velocity: jtp.Vector,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        link_forces: jtp.MatrixLike | None = None,
    ) -> tuple[jtp.Vector, tuple[Any, ...]]:
        """
        Compute the contact forces of the collidable points.

        Args:
            position: The positions of the collidable points, one (x, y, z) row per point.
            velocity: The linear velocities of the collidable points, one row per point.
            model: The model to consider.
            data: The data of the considered model.
            link_forces:
                Optional external forces applied to the links. Zeros of shape
                (number_of_links, 6) are used when omitted.

        Returns:
            A tuple whose first element contains the 6D forces of the
            collidable points expressed in the inertial frame, and whose
            second element is an auxiliary tuple (currently `(None,)`).
        """

        link_forces = (
            link_forces
            if link_forces is not None
            else jnp.zeros((model.number_of_links(), 6))
        )

        references = js.references.JaxSimModelReferences.build(
            model=model,
            data=data,
            velocity_representation=data.velocity_representation,
            link_forces=link_forces,
        )

        def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
            """Return the height of a point over the terrain, projected on the terrain normal."""

            x, y, z = jax.tree.map(jnp.squeeze, (x, y, z))

            # Take both the normal and the height from the same terrain, so
            # that the two quantities always describe the same surface.
            n̂ = self.terrain.normal(x=x, y=y).squeeze()
            h = jnp.array([0, 0, z - self.terrain.height(x=x, y=y)])

            return jnp.dot(h, n̂)

        # Compute the activation state of the collidable points.
        # Negative values mark the points below the terrain, which are the
        # only ones whose Jacobian rows are kept below.
        δ = jax.vmap(_detect_contact)(*position.T)

        with (
            references.switch_velocity_representation(VelRepr.Mixed),
            data.switch_velocity_representation(VelRepr.Mixed),
        ):
            M = js.model.free_floating_mass_matrix(model=model, data=data)

            # Linear part of the contact Jacobians, zeroing the rows of the
            # points that are not penetrating the terrain.
            Jl_WC = jnp.vstack(
                jax.vmap(lambda J, height: J * (height < 0))(
                    js.contact.jacobian(model=model, data=data)[:, :3, :], δ
                )
            )

            W_H_C = js.contact.transforms(model=model, data=data)

            # Generalized acceleration without contact forces.
            BW_ν̇_free = jnp.hstack(
                js.ode.system_acceleration(
                    model=model,
                    data=data,
                    link_forces=references.link_forces(model=model, data=data),
                )
            )

            BW_ν = data.generalized_velocity()

            # Linear part of the contact Jacobian derivatives, masked as above.
            J̇_WC = jnp.vstack(
                jax.vmap(lambda J̇, height: J̇ * (height < 0))(
                    js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ
                ),
            )

            a_ref, R, K, D = self._regularizers(
                model=model,
                penetration=δ,
                velocity=velocity,
                parameters=self.parameters,
            )

        # G = J M⁻¹ Jᵀ, and the linear contact accelerations without contact forces.
        G = Jl_WC @ jnp.linalg.lstsq(M, Jl_WC.T)[0]
        CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇_WC @ BW_ν

        # Calculate quantities for the linear optimization problem.
        A = G + R
        b = CW_al_free_WC - a_ref

        def objective(x: jtp.Array) -> jtp.Array:
            """Squared norm of the residual of the regularized linear system."""
            return jnp.sum(jnp.square(A @ x + b))

        # Compute the 3D linear force in C[W] frame.
        opt = jaxopt.LBFGS(
            fun=objective,
            maxiter=self.parameters.max_iterations,
            tol=self.parameters.tolerance,
            maxls=30,
            history_size=10,
            max_stepsize=100.0,
        )

        # Initial guess: spring-damper estimate built from the penetration
        # and the velocity of the collidable points.
        init_params = (
            K[:, jnp.newaxis] * jnp.zeros_like(position).at[:, 2].set(δ)
            + D[:, jnp.newaxis] * velocity
        ).flatten()

        CW_f_Ci = opt.run(init_params=init_params).params.reshape(-1, 3)

        def mixed_to_inertial(W_H_C: jax.Array, CW_fl: jax.Array) -> jax.Array:
            """Convert a linear force in the C[W] frame to a 6D force in the inertial frame."""

            W_Xf_CW = Adjoint.from_transform(
                W_H_C.at[0:3, 0:3].set(jnp.eye(3)),
                inverse=True,
            ).T

            return W_Xf_CW @ jnp.hstack([CW_fl, jnp.zeros(3)])

        W_f_C = jax.vmap(mixed_to_inertial)(W_H_C, CW_f_Ci)

        return W_f_C, (None,)

    @staticmethod
    def _regularizers(
        model: js.model.JaxSimModel,
        penetration: jtp.Array,
        velocity: jtp.Array,
        parameters: RelaxedRigidContactsParams,
    ) -> tuple[jtp.Array, jtp.Array, jtp.Array, jtp.Array]:
        """
        Compute the reference acceleration and the regularization terms.

        Args:
            model: The jaxsim model.
            penetration: The penetration of the collidable points.
            velocity: The velocity of the collidable points.
            parameters: The parameters of the relaxed rigid contacts model.

        Returns:
            A tuple containing the reference acceleration, the (diagonal)
            regularization matrix, the stiffness, and the damping. Entries
            of the points with non-negative penetration are zero.
        """

        # Read the parameters by name: positional unpacking would silently
        # break if the order of the fields ever changed.
        Ω = parameters.time_constant
        ζ = parameters.damping_coefficient
        ξ_min = parameters.d_min
        ξ_max = parameters.d_max
        width = parameters.width
        mid = parameters.midpoint
        p = parameters.power
        K = parameters.stiffness
        D = parameters.damping
        μ = parameters.mu

        def _imp_aref(
            penetration: jtp.Array,
            velocity: jtp.Array,
        ) -> tuple[jtp.Array, jtp.Array, jtp.Array, jtp.Array]:
            """
            Calculate impedance and offset acceleration in constraint frame.

            Args:
                penetration: penetration in constraint frame
                velocity: velocity in constraint frame

            Returns:
                imp: impedance of the constraint
                a_ref: offset acceleration in constraint frame
                K_f: computed stiffness
                D_f: computed damping
            """

            position = jnp.zeros(shape=(3,)).at[2].set(penetration)

            imp_x = jnp.abs(position) / width
            imp_a = (1.0 / jnp.power(mid, p - 1)) * jnp.power(imp_x, p)
            imp_b = 1 - (1.0 / jnp.power(1 - mid, p - 1)) * jnp.power(1 - imp_x, p)
            imp_y = jnp.where(imp_x < mid, imp_a, imp_b)

            imp = jnp.clip(ξ_min + imp_y * (ξ_max - ξ_min), ξ_min, ξ_max)

            # Beyond `width`, the impedance saturates at its maximum.
            imp = jnp.atleast_1d(jnp.where(imp_x > 1.0, ξ_max, imp))

            # When passing negative values, K and D represent a spring and damper, respectively.
            K_f = jnp.where(K < 0, -K / ξ_max**2, 1 / (ξ_max * Ω * ζ) ** 2)
            D_f = jnp.where(D < 0, -D / ξ_max, 2 / (ξ_max * Ω))

            a_ref = -jnp.atleast_1d(D_f * velocity + K_f * imp * position)

            return imp, a_ref, jnp.atleast_1d(K_f), jnp.atleast_1d(D_f)

        M_L = js.model.link_spatial_inertia_matrices(model=model)

        def _compute_row(
            *,
            link_idx: jtp.Int,
            penetration: jtp.Array,
            velocity: jtp.Array,
        ) -> tuple[jtp.Array, jtp.Array, jtp.Array, jtp.Array]:
            """Compute the reference acceleration and regularization of a single point."""

            # Compute the reference acceleration.
            ξ, a_ref, K_f, D_f = _imp_aref(
                penetration=penetration,
                velocity=velocity,
            )

            # Compute the regularization terms.
            R = (
                (2 * μ**2 * (1 - ξ) / (ξ + 1e-12))
                * (1 + μ**2)
                @ jnp.linalg.inv(M_L[link_idx, :3, :3])
            )

            # Points that are not penetrating the terrain do not contribute.
            return jax.tree.map(
                lambda x: x * (penetration < 0), (a_ref, R, K_f, D_f)
            )

        # Compute the terms of all the collidable points and flatten them
        # into single vectors.
        a_ref, R, stiffness, damping = jax.tree.map(
            jnp.concatenate,
            jax.vmap(_compute_row)(
                link_idx=jnp.array(model.kin_dyn_parameters.contact_parameters.body),
                penetration=penetration,
                velocity=velocity,
            ),
        )

        return a_ref, jnp.diag(R), stiffness, damping
|
jaxsim/rbda/contacts/rigid.py
CHANGED
@@ -9,7 +9,6 @@ import jax_dataclasses
|
|
9
9
|
|
10
10
|
import jaxsim.api as js
|
11
11
|
import jaxsim.typing as jtp
|
12
|
-
from jaxsim import math
|
13
12
|
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
|
14
13
|
from jaxsim.terrain import FlatTerrain, Terrain
|
15
14
|
|
@@ -272,9 +271,17 @@ class RigidContacts(ContactModel):
|
|
272
271
|
link_forces=link_forces,
|
273
272
|
)
|
274
273
|
|
275
|
-
with
|
276
|
-
|
277
|
-
|
274
|
+
with (
|
275
|
+
references.switch_velocity_representation(VelRepr.Mixed),
|
276
|
+
data.switch_velocity_representation(VelRepr.Mixed),
|
277
|
+
):
|
278
|
+
BW_ν̇_free = jnp.hstack(
|
279
|
+
js.ode.system_acceleration(
|
280
|
+
model=model,
|
281
|
+
data=data,
|
282
|
+
joint_forces=references.joint_force_references(model=model),
|
283
|
+
link_forces=references.link_forces(model=model, data=data),
|
284
|
+
)
|
278
285
|
)
|
279
286
|
|
280
287
|
free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points(
|
@@ -380,43 +387,6 @@ class RigidContacts(ContactModel):
|
|
380
387
|
n_constraints = 6 * n_collidable_points
|
381
388
|
return jnp.zeros(shape=(n_constraints,))
|
382
389
|
|
383
|
-
@staticmethod
|
384
|
-
def _compute_mixed_nu_dot_free(
|
385
|
-
model: js.model.JaxSimModel,
|
386
|
-
data: js.data.JaxSimModelData,
|
387
|
-
references: js.references.JaxSimModelReferences | None = None,
|
388
|
-
) -> jtp.Array:
|
389
|
-
references = (
|
390
|
-
references
|
391
|
-
if references is not None
|
392
|
-
else js.references.JaxSimModelReferences.zero(model=model, data=data)
|
393
|
-
)
|
394
|
-
|
395
|
-
with (
|
396
|
-
data.switch_velocity_representation(VelRepr.Mixed),
|
397
|
-
references.switch_velocity_representation(VelRepr.Mixed),
|
398
|
-
):
|
399
|
-
BW_v_WB = data.base_velocity()
|
400
|
-
W_ṗ_B, W_ω_WB = jnp.split(BW_v_WB, 2)
|
401
|
-
W_v̇_WB, s̈ = js.ode.system_acceleration(
|
402
|
-
model=model,
|
403
|
-
data=data,
|
404
|
-
joint_forces=references.joint_force_references(model=model),
|
405
|
-
link_forces=references.link_forces(model=model, data=data),
|
406
|
-
)
|
407
|
-
|
408
|
-
# Convert the inertial-fixed base acceleration to a mixed base acceleration.
|
409
|
-
W_H_B = data.base_transform()
|
410
|
-
W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
|
411
|
-
BW_X_W = math.Adjoint.from_transform(W_H_BW, inverse=True)
|
412
|
-
term1 = BW_X_W @ W_v̇_WB
|
413
|
-
term2 = jnp.zeros(6).at[0:3].set(jnp.cross(W_ṗ_B, W_ω_WB))
|
414
|
-
BW_v̇_WB = term1 - term2
|
415
|
-
|
416
|
-
BW_ν̇ = jnp.hstack([BW_v̇_WB, s̈])
|
417
|
-
|
418
|
-
return BW_ν̇
|
419
|
-
|
420
390
|
@staticmethod
|
421
391
|
def _linear_acceleration_of_collidable_points(
|
422
392
|
model: js.model.JaxSimModel,
|
jaxsim/terrain/terrain.py
CHANGED
@@ -46,66 +46,82 @@ class Terrain(abc.ABC):
|
|
46
46
|
@jax_dataclasses.pytree_dataclass
|
47
47
|
class FlatTerrain(Terrain):
|
48
48
|
|
49
|
-
|
49
|
+
_height: float = dataclasses.field(default=0.0, kw_only=True)
|
50
50
|
|
51
51
|
@staticmethod
|
52
52
|
def build(height: jtp.FloatLike) -> FlatTerrain:
|
53
53
|
|
54
|
-
return FlatTerrain(
|
54
|
+
return FlatTerrain(_height=float(height))
|
55
55
|
|
56
56
|
def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
|
57
57
|
|
58
|
-
return jnp.array(self.
|
58
|
+
return jnp.array(self._height, dtype=float)
|
59
|
+
|
60
|
+
def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
|
61
|
+
|
62
|
+
return jnp.array([0.0, 0.0, 1.0], dtype=float)
|
59
63
|
|
60
64
|
def __hash__(self) -> int:
|
61
65
|
|
62
|
-
return hash(self.
|
66
|
+
return hash(self._height)
|
63
67
|
|
64
68
|
def __eq__(self, other: FlatTerrain) -> bool:
|
65
69
|
|
66
70
|
if not isinstance(other, FlatTerrain):
|
67
71
|
return False
|
68
72
|
|
69
|
-
return self.
|
73
|
+
return self._height == other._height
|
70
74
|
|
71
75
|
|
72
76
|
@jax_dataclasses.pytree_dataclass
|
73
77
|
class PlaneTerrain(FlatTerrain):
|
74
78
|
|
75
|
-
|
79
|
+
_normal: tuple[float, float, float] = jax_dataclasses.field(
|
76
80
|
default=(0.0, 0.0, 1.0), kw_only=True
|
77
81
|
)
|
78
82
|
|
79
83
|
@staticmethod
|
80
|
-
def build(
|
81
|
-
plane_normal: jtp.VectorLike, plane_height_over_origin: jtp.FloatLike = 0.0
|
82
|
-
) -> PlaneTerrain:
|
84
|
+
def build(height: jtp.FloatLike = 0.0, *, normal: jtp.VectorLike) -> PlaneTerrain:
|
83
85
|
"""
|
84
86
|
Create a PlaneTerrain instance with a specified plane normal vector.
|
85
87
|
|
86
88
|
Args:
|
87
|
-
|
88
|
-
|
89
|
+
normal: The normal vector of the terrain plane.
|
90
|
+
height: The height of the plane over the origin.
|
89
91
|
|
90
92
|
Returns:
|
91
93
|
PlaneTerrain: A PlaneTerrain instance.
|
92
94
|
"""
|
93
95
|
|
94
|
-
|
95
|
-
|
96
|
+
normal = jnp.array(normal, dtype=float)
|
97
|
+
height = jnp.array(height, dtype=float)
|
96
98
|
|
97
|
-
if
|
99
|
+
if normal.shape != (3,):
|
98
100
|
msg = "Expected a 3D vector for the plane normal, got '{}'."
|
99
|
-
raise ValueError(msg.format(
|
101
|
+
raise ValueError(msg.format(normal.shape))
|
100
102
|
|
101
103
|
# Make sure that the plane normal is a unit vector.
|
102
|
-
|
104
|
+
normal = normal / jnp.linalg.norm(normal)
|
103
105
|
|
104
106
|
return PlaneTerrain(
|
105
|
-
|
106
|
-
|
107
|
+
_height=height.item(),
|
108
|
+
_normal=tuple(normal.tolist()),
|
107
109
|
)
|
108
110
|
|
111
|
+
def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
|
112
|
+
"""
|
113
|
+
Compute the normal vector of the terrain at a specific (x, y) location.
|
114
|
+
|
115
|
+
Args:
|
116
|
+
x: The x-coordinate of the location.
|
117
|
+
y: The y-coordinate of the location.
|
118
|
+
|
119
|
+
Returns:
|
120
|
+
The normal vector of the terrain surface at the specified location.
|
121
|
+
"""
|
122
|
+
|
123
|
+
return jnp.array(self._normal, dtype=float)
|
124
|
+
|
109
125
|
def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
|
110
126
|
"""
|
111
127
|
Compute the height of the terrain at a specific (x, y) location on a plane.
|
@@ -123,10 +139,10 @@ class PlaneTerrain(FlatTerrain):
|
|
123
139
|
# The height over the origin: -D/C
|
124
140
|
|
125
141
|
# Get the plane equation coefficients from the terrain normal.
|
126
|
-
A, B, C = self.
|
142
|
+
A, B, C = self._normal
|
127
143
|
|
128
144
|
# Compute the final coefficient D considering the terrain height.
|
129
|
-
D = -C * self.
|
145
|
+
D = -C * self._height
|
130
146
|
|
131
147
|
# Invert the plane equation to get the height at the given (x, y) coordinates.
|
132
148
|
return jnp.array(-(A * x + B * y + D) / C).astype(float)
|
@@ -137,9 +153,9 @@ class PlaneTerrain(FlatTerrain):
|
|
137
153
|
|
138
154
|
return hash(
|
139
155
|
(
|
140
|
-
hash(self.
|
156
|
+
hash(self._height),
|
141
157
|
HashedNumpyArray.hash_of_array(
|
142
|
-
array=jnp.array(self.
|
158
|
+
array=jnp.array(self._normal, dtype=float)
|
143
159
|
),
|
144
160
|
)
|
145
161
|
)
|
@@ -150,10 +166,10 @@ class PlaneTerrain(FlatTerrain):
|
|
150
166
|
return False
|
151
167
|
|
152
168
|
if not (
|
153
|
-
np.allclose(self.
|
169
|
+
np.allclose(self._height, other._height)
|
154
170
|
and np.allclose(
|
155
|
-
np.array(self.
|
156
|
-
np.array(other.
|
171
|
+
np.array(self._normal, dtype=float),
|
172
|
+
np.array(other._normal, dtype=float),
|
157
173
|
)
|
158
174
|
):
|
159
175
|
return False
|