jaxsim 0.4.3.dev12__py3-none-any.whl → 0.4.3.dev18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/common.py +2 -2
- jaxsim/api/contact.py +37 -9
- jaxsim/api/data.py +1 -1
- jaxsim/api/frame.py +1 -1
- jaxsim/api/joint.py +1 -1
- jaxsim/api/link.py +1 -1
- jaxsim/api/model.py +62 -8
- jaxsim/api/ode.py +114 -36
- jaxsim/api/ode_data.py +11 -7
- jaxsim/integrators/common.py +30 -21
- jaxsim/integrators/variable_step.py +2 -2
- jaxsim/logging.py +1 -2
- jaxsim/math/inertia.py +1 -3
- jaxsim/math/joint_model.py +1 -1
- jaxsim/math/rotation.py +1 -3
- jaxsim/mujoco/loaders.py +2 -1
- jaxsim/mujoco/model.py +2 -1
- jaxsim/mujoco/visualizer.py +2 -2
- jaxsim/parsers/descriptions/model.py +1 -1
- jaxsim/parsers/kinematic_graph.py +4 -3
- jaxsim/parsers/rod/parser.py +10 -10
- jaxsim/rbda/contacts/common.py +3 -2
- jaxsim/rbda/contacts/rigid.py +478 -0
- jaxsim/rbda/rnea.py +5 -7
- jaxsim/utils/jaxsim_dataclass.py +3 -3
- jaxsim/utils/wrappers.py +2 -1
- {jaxsim-0.4.3.dev12.dist-info → jaxsim-0.4.3.dev18.dist-info}/METADATA +2 -1
- {jaxsim-0.4.3.dev12.dist-info → jaxsim-0.4.3.dev18.dist-info}/RECORD +32 -31
- {jaxsim-0.4.3.dev12.dist-info → jaxsim-0.4.3.dev18.dist-info}/WHEEL +1 -1
- {jaxsim-0.4.3.dev12.dist-info → jaxsim-0.4.3.dev18.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev12.dist-info → jaxsim-0.4.3.dev18.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,478 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
import jax
|
7
|
+
import jax.numpy as jnp
|
8
|
+
import jax_dataclasses
|
9
|
+
|
10
|
+
import jaxsim.api as js
|
11
|
+
import jaxsim.typing as jtp
|
12
|
+
from jaxsim import math
|
13
|
+
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
|
14
|
+
from jaxsim.terrain import FlatTerrain, Terrain
|
15
|
+
|
16
|
+
from .common import ContactModel, ContactsParams, ContactsState
|
17
|
+
|
18
|
+
|
19
|
+
@jax_dataclasses.pytree_dataclass
class RigidContactsParams(ContactsParams):
    """Parameters of the rigid contacts model."""

    # Static friction coefficient
    mu: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    # Baumgarte proportional term
    K: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Baumgarte derivative term
    D: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    def __hash__(self) -> int:
        """Hash the parameters from the content of their arrays."""

        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                HashedNumpyArray.hash_of_array(self.mu),
                HashedNumpyArray.hash_of_array(self.K),
                HashedNumpyArray.hash_of_array(self.D),
            )
        )

    def __eq__(self, other: RigidContactsParams) -> bool:
        """Two parameter objects are considered equal when their hashes match."""

        return hash(self) == hash(other)

    @classmethod
    def build(
        cls,
        mu: jtp.FloatLike | None = None,
        K: jtp.FloatLike | None = None,
        D: jtp.FloatLike | None = None,
    ) -> RigidContactsParams:
        """
        Create a `RigidContactsParams` instance.

        Args:
            mu: The static friction coefficient.
            K: The proportional term of the Baumgarte stabilization.
            D: The derivative term of the Baumgarte stabilization.

        Any argument left to `None` takes the default value of the
        corresponding field. Explicit zeros are preserved.

        Returns:
            A `RigidContactsParams` instance.
        """

        def value_or_default(value: jtp.FloatLike | None, name: str) -> jtp.Float:
            # The fields are declared with a `default_factory`, therefore their
            # `default` attribute is `dataclasses.MISSING` and cannot be used.
            # `None` (not falsiness) marks a missing argument, so that a
            # legitimate 0.0 (e.g. a frictionless contact) is not discarded and
            # array/traced values are never coerced to `bool`.
            if value is None:
                return cls.__dataclass_fields__[name].default_factory()

            # Store the value as a float array, consistently with the defaults.
            return jnp.array(value, dtype=float)

        return RigidContactsParams(
            mu=value_or_default(mu, "mu"),
            K=value_or_default(K, "K"),
            D=value_or_default(D, "D"),
        )

    def valid(self) -> bool:
        """Return whether all the parameters are non-negative."""

        return bool(
            jnp.all(self.mu >= 0.0)
            and jnp.all(self.K >= 0.0)
            and jnp.all(self.D >= 0.0)
        )
|
72
|
+
|
73
|
+
|
74
|
+
@jax_dataclasses.pytree_dataclass
class RigidContactsState(ContactsState):
    """
    State of the rigid contacts model.

    This class declares no fields of its own; its factory methods accept
    and ignore any keyword argument.
    """

    def __eq__(self, other: RigidContactsState) -> bool:
        # Equality is delegated to the hashes of the two objects.
        same_hash = hash(self) == hash(other)
        return same_hash

    @staticmethod
    def build(**_kwargs) -> RigidContactsState:
        """
        Create a `RigidContactsState` instance.

        Keyword arguments are accepted for interface compatibility and ignored.
        """

        instance = RigidContactsState()
        return instance

    @staticmethod
    def zero(**_kwargs) -> RigidContactsState:
        """Build a zero `RigidContactsState`; keyword arguments are ignored."""

        return RigidContactsState.build()

    def valid(self, **_kwargs) -> bool:
        """Return whether the state is valid; this state is always valid."""

        return True
|
94
|
+
|
95
|
+
|
96
|
+
@jax_dataclasses.pytree_dataclass
class RigidContacts(ContactModel):
    """
    Rigid contacts model.

    The contact forces are computed by solving, with `qpax`, a quadratic
    program built on the Delassus matrix of the collidable points, subject to
    a pyramidal friction cone, to a non-negative normal force, and to zero
    force on the collidable points that are not in contact.
    """

    parameters: RigidContactsParams = dataclasses.field(
        default_factory=RigidContactsParams
    )

    terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
        default_factory=FlatTerrain
    )

    @staticmethod
    def detect_contacts(
        W_p_C: jtp.ArrayLike,
        terrain_height: jtp.ArrayLike,
    ) -> tuple[jtp.Vector, jtp.Vector]:
        """
        Detect contacts between the collidable points and the terrain.

        Args:
            W_p_C: The position of the collidable points, one row per point.
            terrain_height: The height of the terrain at the collidable point position.

        Returns:
            A tuple containing:
            - a boolean vector that is `True` for the collidable points that are
              *not* in contact (i.e. strictly above the terrain);
            - the contact penetration depth `h ≥ 0` of each collidable point.
        """

        # TODO: reduce code duplication with js.contact.in_contact
        def detect_contact(
            W_p_C: jtp.ArrayLike,
            terrain_height: jtp.FloatLike,
        ) -> tuple[jtp.Bool, jtp.Float]:
            """
            Detect the contact between a single collidable point and the terrain.
            """

            # Unpack the position of the collidable point.
            _, _, pz = W_p_C.squeeze()

            # A point strictly above the terrain is not in contact.
            inactive = pz > terrain_height

            # Compute contact penetration depth (zero when not penetrating).
            h = jnp.maximum(0.0, terrain_height - pz)

            return inactive, h

        inactive_collidable_points, h = jax.vmap(detect_contact)(W_p_C, terrain_height)

        return inactive_collidable_points, h

    @staticmethod
    def compute_impact_velocity(
        inactive_collidable_points: jtp.ArrayLike,
        M: jtp.MatrixLike,
        J_WC: jtp.MatrixLike,
        data: js.data.JaxSimModelData,
    ) -> jtp.Vector:
        """
        Returns the new velocity of the system after a potential impact.

        The post-impact velocity ν⁺ is obtained from the linear system

            [ M   -Jₗᵀ ] [ ν⁺ ]   [ M ν⁻ ]
            [ Jₗ    0  ] [ λ  ] = [  0   ]

        where Jₗ stacks the linear rows of the Jacobians of the collidable
        points, with the rows of the inactive points zeroed out. This enforces
        zero linear velocity of the active points after the impact.

        Args:
            inactive_collidable_points:
                Boolean vector that is `True` for the collidable points not in contact.
            M:
                The mass matrix of the system. Expected in mixed representation,
                consistently with the pre-impact velocity read from `data`.
            J_WC:
                The Jacobian matrix of the collidable points, one `(6, n)` block
                per point. Expected in mixed representation.
            data: The `JaxSimModelData` instance.

        Returns:
            The generalized velocity after the impact, in mixed representation.
        """

        def impact_velocity(
            inactive_collidable_points: jtp.ArrayLike,
            nu_pre: jtp.ArrayLike,
            M: jtp.MatrixLike,
            J_WC: jtp.MatrixLike,
        ) -> jtp.Vector:
            # Select the linear part of the Jacobians of the collidable points.
            Jl_WC = J_WC[:, 0:3, :]

            # Zero out the Jacobian rows of the inactive points, then stack the
            # per-point blocks vertically into a (3 * n_points, n) matrix.
            Jl_WC = jnp.vstack(
                jnp.where(
                    inactive_collidable_points[:, jnp.newaxis, jnp.newaxis],
                    jnp.zeros_like(Jl_WC),
                    Jl_WC,
                )
            )

            n_rows = Jl_WC.shape[0]

            A = jnp.vstack(
                [
                    jnp.hstack([M, -Jl_WC.T]),
                    jnp.hstack([Jl_WC, jnp.zeros((n_rows, n_rows))]),
                ]
            )
            b = jnp.hstack([M @ nu_pre, jnp.zeros(n_rows)])

            # The zeroed rows make A singular, hence the least-squares solution.
            x = jnp.linalg.lstsq(A, b)[0]

            # Keep only the generalized velocity, dropping the multipliers.
            return x[0 : M.shape[0]]

        with data.switch_velocity_representation(VelRepr.Mixed):
            BW_ν_pre_impact = data.generalized_velocity()

        BW_ν_post_impact = impact_velocity(
            inactive_collidable_points=inactive_collidable_points,
            nu_pre=BW_ν_pre_impact,
            M=M,
            J_WC=J_WC,
        )

        return BW_ν_post_impact

    def compute_contact_forces(
        self,
        position: jtp.Vector,
        velocity: jtp.Vector,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        link_forces: jtp.MatrixLike | None = None,
        regularization_term: jtp.FloatLike = 1e-6,
    ) -> tuple[jtp.Vector, tuple[Any, ...]]:
        """
        Compute the contact forces.

        Args:
            position: The position of the collidable points, one row per point.
            velocity: The linear velocity of the collidable points, one row per point.
            model: The `JaxSimModel` instance.
            data: The `JaxSimModelData` instance.
            link_forces:
                Optional `(n_links, 6)` matrix of external forces acting on the links,
                expressed in the same representation of data.
            regularization_term:
                The regularization term to add to the diagonal of the Delassus
                matrix for better numerical conditioning.

        Returns:
            A tuple containing the 6D contact forces of the collidable points,
            transformed to the inertial-fixed representation, and an empty
            tuple of auxiliary data.
        """

        # Import qpax just in this method
        import qpax

        link_forces = (
            link_forces
            if link_forces is not None
            else jnp.zeros((model.number_of_links(), 6))
        )

        # Compute kin-dyn quantities used in the contact model
        with data.switch_velocity_representation(VelRepr.Mixed):
            M = js.model.free_floating_mass_matrix(model=model, data=data)
            J_WC = js.contact.jacobian(model=model, data=data)
            W_H_C = js.contact.transforms(model=model, data=data)

        terrain_height = jax.vmap(self.terrain.height)(position[:, 0], position[:, 1])
        n_collidable_points = model.kin_dyn_parameters.contact_parameters.point.shape[0]

        # Compute the activation state of the collidable points
        inactive_collidable_points, h = RigidContacts.detect_contacts(
            W_p_C=position,
            terrain_height=terrain_height,
        )

        delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC)

        # Add regularization for better numerical conditioning
        delassus_matrix = delassus_matrix + regularization_term * jnp.eye(
            delassus_matrix.shape[0]
        )

        references = js.references.JaxSimModelReferences.build(
            model=model,
            data=data,
            velocity_representation=data.velocity_representation,
            link_forces=link_forces,
        )

        with references.switch_velocity_representation(VelRepr.Mixed):
            BW_ν̇_free = RigidContacts._compute_mixed_nu_dot_free(
                model, data, references=references
            )

        free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points(
            model,
            data,
            BW_ν̇_free,
        ).flatten()

        # Compute stabilization term.
        # `h` is the penetration depth (positive when the point is below the
        # terrain), hence, neglecting the terrain slope, its rate of change is
        # the opposite of the vertical velocity of the collidable point.
        # No squeeze: keep a (n_points,) vector also when there is one point.
        ḣ = -velocity[:, 2]
        baumgarte_term = RigidContacts._compute_baumgarte_stabilization_term(
            inactive_collidable_points=inactive_collidable_points,
            h=h,
            ḣ=ḣ,
            K=self.parameters.K,
            D=self.parameters.D,
        ).flatten()

        free_contact_acc -= baumgarte_term

        # Setup optimization problem
        Q = delassus_matrix
        q = free_contact_acc
        G = RigidContacts._compute_ineq_constraint_matrix(
            inactive_collidable_points=inactive_collidable_points, mu=self.parameters.mu
        )
        h_bounds = RigidContacts._compute_ineq_bounds(
            n_collidable_points=n_collidable_points
        )
        A = jnp.zeros((0, 3 * n_collidable_points))
        b = jnp.zeros((0,))

        # Solve the optimization problem
        solution, *_ = qpax.solve_qp(Q=Q, q=q, A=A, b=b, G=G, h=h_bounds)

        f_C_lin = solution.reshape(-1, 3)

        # Transform linear contact forces to 6D
        CW_f_C = jnp.hstack(
            (
                f_C_lin,
                jnp.zeros((f_C_lin.shape[0], 3)),
            )
        )

        # Transform the contact forces to inertial-fixed representation
        W_f_C = jax.vmap(
            lambda CW_f_C, W_H_C: ModelDataWithVelocityRepresentation.other_representation_to_inertial(
                array=CW_f_C,
                transform=W_H_C,
                other_representation=VelRepr.Mixed,
                is_force=True,
            ),
        )(
            CW_f_C,
            W_H_C,
        )

        return W_f_C, ()

    @staticmethod
    def _delassus_matrix(
        M: jtp.MatrixLike,
        J_WC: jtp.MatrixLike,
    ) -> jtp.Matrix:
        """
        Compute the Delassus matrix Jₗ M⁻¹ Jₗᵀ, where Jₗ stacks the linear
        rows of the Jacobians of all the collidable points.
        """

        J_WC_lin = jnp.vstack(J_WC[:, 0:3, :])

        delassus_matrix = J_WC_lin @ jnp.linalg.pinv(M) @ J_WC_lin.T
        return delassus_matrix

    @staticmethod
    def _compute_ineq_constraint_matrix(
        inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike
    ) -> jtp.Matrix:
        """
        Compute the block-diagonal matrix G of the inequality constraints G f ≤ 0,
        with one (6, 3) block per collidable point.
        """

        def compute_G_single_point(mu: float, c: float) -> jtp.Matrix:
            """
            Compute the inequality constraint matrix for a single collidable point
            Rows 0-3: enforce the friction pyramid constraint,
            Row 4: last one is for the non negativity of the vertical force
            Row 5: contact complementarity condition (c is 1 for inactive
            points, forcing a zero vertical force together with row 4)
            """
            G_single_point = jnp.array(
                [
                    [1, 0, -mu],
                    [0, 1, -mu],
                    [-1, 0, -mu],
                    [0, -1, -mu],
                    [0, 0, -1],
                    [0, 0, c],
                ]
            )
            return G_single_point

        G = jax.vmap(compute_G_single_point, in_axes=(None, 0))(
            mu, inactive_collidable_points
        )
        G = jax.scipy.linalg.block_diag(*G)
        return G

    @staticmethod
    def _compute_ineq_bounds(n_collidable_points: int) -> jtp.Vector:
        """Return the zero right-hand side of the 6 inequalities of each point."""

        n_constraints = 6 * n_collidable_points
        return jnp.zeros(shape=(n_constraints,))

    @staticmethod
    def _compute_mixed_nu_dot_free(
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        references: js.references.JaxSimModelReferences | None = None,
    ) -> jtp.Array:
        """
        Compute the generalized acceleration of the system in mixed
        representation, without contact forces.
        """

        references = (
            references
            if references is not None
            else js.references.JaxSimModelReferences.zero(model=model, data=data)
        )

        with (
            data.switch_velocity_representation(VelRepr.Mixed),
            references.switch_velocity_representation(VelRepr.Mixed),
        ):
            BW_v_WB = data.base_velocity()
            W_ṗ_B, W_ω_WB = jnp.split(BW_v_WB, 2)
            W_v̇_WB, s̈ = js.ode.system_acceleration(
                model=model,
                data=data,
                joint_forces=references.joint_force_references(model=model),
                link_forces=references.link_forces(model=model, data=data),
            )

            # Convert the inertial-fixed base acceleration to a mixed base acceleration.
            W_H_B = data.base_transform()
            W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
            BW_X_W = math.Adjoint.from_transform(W_H_BW, inverse=True)
            term1 = BW_X_W @ W_v̇_WB
            term2 = jnp.zeros(6).at[0:3].set(jnp.cross(W_ṗ_B, W_ω_WB))
            BW_v̇_WB = term1 - term2

            BW_ν̇ = jnp.hstack([BW_v̇_WB, s̈])

        return BW_ν̇

    @staticmethod
    def _linear_acceleration_of_collidable_points(
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        mixed_nu_dot: jtp.ArrayLike,
    ) -> jtp.Matrix:
        """
        Compute the linear acceleration J̇ ν + J ν̇ of the collidable points,
        in mixed representation, given the mixed generalized acceleration.
        """

        with data.switch_velocity_representation(VelRepr.Mixed):
            CW_J_WC_BW = js.contact.jacobian(
                model=model,
                data=data,
                output_vel_repr=VelRepr.Mixed,
            )
            CW_J̇_WC_BW = js.contact.jacobian_derivative(
                model=model,
                data=data,
                output_vel_repr=VelRepr.Mixed,
            )

            BW_ν = data.generalized_velocity()
            BW_ν̇ = mixed_nu_dot

            CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇
            CW_a_WC = CW_a_WC.reshape(-1, 6)

        return CW_a_WC[:, 0:3].squeeze()

    @staticmethod
    def _compute_baumgarte_stabilization_term(
        inactive_collidable_points: jtp.ArrayLike,
        h: jtp.ArrayLike,
        ḣ: jtp.ArrayLike,
        K: jtp.FloatLike,
        D: jtp.FloatLike,
    ) -> jtp.Array:
        """
        Compute the Baumgarte stabilization term of each collidable point.

        Args:
            inactive_collidable_points: `True` for the points not in contact.
            h: The penetration depth of each collidable point.
            ḣ: The rate of change of the penetration depth.
            K: The proportional term.
            D: The derivative term.

        Returns:
            A `(n_points, 3)` matrix whose rows are `[0, 0, K h + D ḣ]` for the
            active points and zero for the inactive ones.
        """

        inactive = jnp.asarray(inactive_collidable_points, dtype=bool)
        normal_term = K * jnp.asarray(h) + D * jnp.asarray(ḣ)

        # Only the vertical component is stabilized, and only for active points.
        baumgarte_z = jnp.atleast_1d(jnp.where(inactive, 0.0, normal_term))

        return jnp.zeros(shape=(baumgarte_z.shape[0], 3)).at[:, 2].set(baumgarte_z)
|
jaxsim/rbda/rnea.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1
|
-
from typing import Tuple
|
2
|
-
|
3
1
|
import jax
|
4
2
|
import jax.numpy as jnp
|
5
3
|
import jaxlie
|
@@ -25,7 +23,7 @@ def rnea(
|
|
25
23
|
joint_accelerations: jtp.Vector | None = None,
|
26
24
|
link_forces: jtp.Matrix | None = None,
|
27
25
|
standard_gravity: jtp.FloatLike = StandardGravity,
|
28
|
-
) ->
|
26
|
+
) -> tuple[jtp.Vector, jtp.Vector]:
|
29
27
|
"""
|
30
28
|
Compute inverse dynamics using the Recursive Newton-Euler Algorithm (RNEA).
|
31
29
|
|
@@ -132,12 +130,12 @@ def rnea(
|
|
132
130
|
# Pass 1
|
133
131
|
# ======
|
134
132
|
|
135
|
-
ForwardPassCarry =
|
133
|
+
ForwardPassCarry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
|
136
134
|
forward_pass_carry: ForwardPassCarry = (v, a, i_X_0, f)
|
137
135
|
|
138
136
|
def forward_pass(
|
139
137
|
carry: ForwardPassCarry, i: jtp.Int
|
140
|
-
) ->
|
138
|
+
) -> tuple[ForwardPassCarry, None]:
|
141
139
|
|
142
140
|
ii = i - 1
|
143
141
|
v, a, i_X_0, f = carry
|
@@ -186,12 +184,12 @@ def rnea(
|
|
186
184
|
|
187
185
|
τ = jnp.zeros_like(s)
|
188
186
|
|
189
|
-
BackwardPassCarry =
|
187
|
+
BackwardPassCarry = tuple[jtp.Vector, jtp.Matrix]
|
190
188
|
backward_pass_carry: BackwardPassCarry = (τ, f)
|
191
189
|
|
192
190
|
def backward_pass(
|
193
191
|
carry: BackwardPassCarry, i: jtp.Int
|
194
|
-
) ->
|
192
|
+
) -> tuple[BackwardPassCarry, None]:
|
195
193
|
|
196
194
|
ii = i - 1
|
197
195
|
τ, f = carry
|
jaxsim/utils/jaxsim_dataclass.py
CHANGED
@@ -2,8 +2,8 @@ import abc
|
|
2
2
|
import contextlib
|
3
3
|
import dataclasses
|
4
4
|
import functools
|
5
|
-
from collections.abc import Iterator
|
6
|
-
from typing import Any,
|
5
|
+
from collections.abc import Callable, Iterator, Sequence
|
6
|
+
from typing import Any, ClassVar
|
7
7
|
|
8
8
|
import jax.flatten_util
|
9
9
|
import jax_dataclasses
|
@@ -337,7 +337,7 @@ class JaxsimDataclass(abc.ABC):
|
|
337
337
|
return self.flatten_fn()(self)
|
338
338
|
|
339
339
|
@classmethod
|
340
|
-
def flatten_fn(cls:
|
340
|
+
def flatten_fn(cls: type[Self]) -> Callable[[Self], jtp.Vector]:
|
341
341
|
"""
|
342
342
|
Return a function to flatten the object into a 1D vector.
|
343
343
|
|
jaxsim/utils/wrappers.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: jaxsim
|
3
|
-
Version: 0.4.3.
|
3
|
+
Version: 0.4.3.dev18
|
4
4
|
Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
|
5
5
|
Author-email: Diego Ferigo <dgferigo@gmail.com>
|
6
6
|
Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
|
@@ -65,6 +65,7 @@ Requires-Dist: jaxlib>=0.4.13
|
|
65
65
|
Requires-Dist: jaxlie>=1.3.0
|
66
66
|
Requires-Dist: jax-dataclasses>=1.4.0
|
67
67
|
Requires-Dist: pptree
|
68
|
+
Requires-Dist: qpax
|
68
69
|
Requires-Dist: rod>=0.3.0
|
69
70
|
Requires-Dist: typing-extensions; python_version < "3.12"
|
70
71
|
Provides-Extra: all
|
@@ -1,48 +1,48 @@
|
|
1
1
|
jaxsim/__init__.py,sha256=ixsS4dYMPex2wOUUp_rkPnwrPhYzkRh1xO_YuMj3Cr4,2626
|
2
|
-
jaxsim/_version.py,sha256=
|
2
|
+
jaxsim/_version.py,sha256=SFJGfO84uy3oOc6jDmWWev6VuJIqHL3tI8_OvaYfdsA,426
|
3
3
|
jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
|
4
|
-
jaxsim/logging.py,sha256=
|
4
|
+
jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
|
5
5
|
jaxsim/typing.py,sha256=IbFx3UkEXi-cm7UBqMPi58rJAFV_HbZ9E_K4JwfNvVM,753
|
6
6
|
jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
|
7
7
|
jaxsim/api/com.py,sha256=m-p3EJDhpnMTlXKplfbZE_aH9NqX_VyLlAE3vUhc6l4,13642
|
8
|
-
jaxsim/api/common.py,sha256=
|
9
|
-
jaxsim/api/contact.py,sha256=
|
10
|
-
jaxsim/api/data.py,sha256=
|
11
|
-
jaxsim/api/frame.py,sha256=
|
12
|
-
jaxsim/api/joint.py,sha256=
|
8
|
+
jaxsim/api/common.py,sha256=SNgxq42r6eF_-aPszvOjUYkGwXOzz4hKmhDwEUkscFQ,6650
|
9
|
+
jaxsim/api/contact.py,sha256=HyEAjF7BySDDOlRahN0l7V15IPB0HPXuoM0twamuEW0,20913
|
10
|
+
jaxsim/api/data.py,sha256=CUh9lvhVk3_clNQ26BUBGpjvFSsK_PrVWVMEWpMdHRM,27206
|
11
|
+
jaxsim/api/frame.py,sha256=KS8A5wRfjxhe9NgcVo2QA516iP5zky7UVnWxG7nTa7c,12911
|
12
|
+
jaxsim/api/joint.py,sha256=L81bQe-noPT6_54KOSF7KBjRmEPAS433ULn2EcXI8vI,5115
|
13
13
|
jaxsim/api/kin_dyn_parameters.py,sha256=CcfSg5Mc8qb1mZeMQ4AK_ffZIsK5yOl7tu397pFhcDA,29369
|
14
|
-
jaxsim/api/link.py,sha256=
|
15
|
-
jaxsim/api/model.py,sha256=
|
16
|
-
jaxsim/api/ode.py,sha256=
|
17
|
-
jaxsim/api/ode_data.py,sha256=
|
14
|
+
jaxsim/api/link.py,sha256=qPRtc8qqMRjZxUCZYXJMygbB6huDXBfIT1b1b8Durkw,18631
|
15
|
+
jaxsim/api/model.py,sha256=HXoqCtQ3KStGoxhgvFm8P_Sc-lbEM4l5No2MoHzNlOk,65558
|
16
|
+
jaxsim/api/ode.py,sha256=Vb2sN4zwpXnaJDD9-ziz2qvfmfa4jvIQ0fONbBIRGmU,13368
|
17
|
+
jaxsim/api/ode_data.py,sha256=U7F6TL6bENAxpQQl4PupPoDG7d7VfTTFqDAs3xwu6Hs,20003
|
18
18
|
jaxsim/api/references.py,sha256=XOVKuQXRmjPoP-T5JWGSbqIGX5DzOkeGafqRpj0ZQEM,20771
|
19
19
|
jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
|
20
|
-
jaxsim/integrators/common.py,sha256=
|
20
|
+
jaxsim/integrators/common.py,sha256=ntjflaV3qWaFH_E65pAGZ6QipdnFsgQDasKtIKpxTe4,20432
|
21
21
|
jaxsim/integrators/fixed_step.py,sha256=KpjRd6hHtapxDoo6D1kyDrVDSHnke2TepI5grFH7_bM,2693
|
22
|
-
jaxsim/integrators/variable_step.py,sha256=
|
22
|
+
jaxsim/integrators/variable_step.py,sha256=5StkFh9oQba34zlkIoXG2fUN78gbxkHePWbrpQ-QZOI,21274
|
23
23
|
jaxsim/math/__init__.py,sha256=8oPITEoGwgRcOeG8KxtqxPQ8b5uku1HNRMokpCoi9Tc,352
|
24
24
|
jaxsim/math/adjoint.py,sha256=o1FCipkGwPtMbN2gFNIyUV8ADF3TX5fxElpTEXK0bIs,4377
|
25
25
|
jaxsim/math/cross.py,sha256=U7yEx_l75mSy5g6O-jsjBztApvxC3WaV4MpkS5tThu4,1330
|
26
|
-
jaxsim/math/inertia.py,sha256=
|
27
|
-
jaxsim/math/joint_model.py,sha256=
|
26
|
+
jaxsim/math/inertia.py,sha256=_hNpoeyEpAGr9ExDQJjckbjhk39luJFF-jv0SKqefnQ,1614
|
27
|
+
jaxsim/math/joint_model.py,sha256=EzAveaG5B6ZnCFNUzN30KEQUVesd83lfWXJarYR-kUw,9989
|
28
28
|
jaxsim/math/quaternion.py,sha256=_WA7W3iv7px83sWO1V1n0-J78hqAlO4SL1-jofE-UZ4,4754
|
29
|
-
jaxsim/math/rotation.py,sha256=
|
29
|
+
jaxsim/math/rotation.py,sha256=k-nwT79zmWrys3NNAB-lGWxat7Kqm_6JnFRoimJ8rBg,2156
|
30
30
|
jaxsim/math/skew.py,sha256=oOGSSR8PUGROl6IJFlrmu6K3gPH-u16hUPfKIkcVv9o,1177
|
31
31
|
jaxsim/math/transform.py,sha256=KXzQgOnCfAtbXCwxhplpJ3F0JT3oEyeLVby1_uRAryQ,2892
|
32
32
|
jaxsim/mujoco/__init__.py,sha256=Zo5GAlN1DYKvX8s1hu1j6HntKIbBMLB9Puv9ouaNAZ8,158
|
33
33
|
jaxsim/mujoco/__main__.py,sha256=GBmB7J-zj75ZnFyuAAmpSOpbxi_HhHhWJeot3ljGDJY,5291
|
34
|
-
jaxsim/mujoco/loaders.py,sha256=
|
35
|
-
jaxsim/mujoco/model.py,sha256=
|
36
|
-
jaxsim/mujoco/visualizer.py,sha256=
|
34
|
+
jaxsim/mujoco/loaders.py,sha256=XB-fgXuWMTFiaand5MZlLFQ5__Sh8MK5CJsxIU34MBk,25328
|
35
|
+
jaxsim/mujoco/model.py,sha256=AQksXemXWACJ3yvefV2G5HLwwBU9ISoJrOD1wlxdY5w,16386
|
36
|
+
jaxsim/mujoco/visualizer.py,sha256=T1vU-w4NKSmgEkZ0FqVcGmIvYrYO0len2UBSsU4MOZ0,6978
|
37
37
|
jaxsim/parsers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
38
|
-
jaxsim/parsers/kinematic_graph.py,sha256=
|
38
|
+
jaxsim/parsers/kinematic_graph.py,sha256=KijMWKyhTLKSNUmOOk4sYQMgPh_OkA_brncL7gBRHaY,34757
|
39
39
|
jaxsim/parsers/descriptions/__init__.py,sha256=PbIlunVfb59pB5jSX97YVpMAANRZPRkJ0X-hS14rzv4,221
|
40
40
|
jaxsim/parsers/descriptions/collision.py,sha256=BQeIG-TKi4SVny23w6riDrQ5itC6VRwEMBX6HgAXHxA,3973
|
41
41
|
jaxsim/parsers/descriptions/joint.py,sha256=VSb6C0FBBKMqwrHBKfc-Bbn4rl_J0RzUxMQlhIEvOPM,5185
|
42
42
|
jaxsim/parsers/descriptions/link.py,sha256=Eh0W5qL7_Uw0GV-BkNKXhm9Q2dRTfIWCX5D-87zQkxA,3711
|
43
|
-
jaxsim/parsers/descriptions/model.py,sha256=
|
43
|
+
jaxsim/parsers/descriptions/model.py,sha256=I2Vsbv8Josl4Le7b5rIvhqA2k9Bbv5JxMqwytayxds0,9833
|
44
44
|
jaxsim/parsers/rod/__init__.py,sha256=G2vqlLajBLUc4gyzXwsEI2Wsi4TMOIF9bLDFeT6KrGU,92
|
45
|
-
jaxsim/parsers/rod/parser.py,sha256=
|
45
|
+
jaxsim/parsers/rod/parser.py,sha256=HskeCqDsbtwH2BDk3vfxvx391wUTVGLaUXNvBrdNo-4,13486
|
46
46
|
jaxsim/parsers/rod/utils.py,sha256=5DsF3OeePZGidOJ5GiFSZx-51uIdnFvMW9EK6SgOW6Q,5698
|
47
47
|
jaxsim/rbda/__init__.py,sha256=H7DhXpxkPOi9lpUvg31IMHFfRafke1UoJLc5GQIdyhA,387
|
48
48
|
jaxsim/rbda/aba.py,sha256=w7ciyxB0IsmueatT0C7PcBQEl9dyiH9oqJgIi3xeTUE,8983
|
@@ -50,19 +50,20 @@ jaxsim/rbda/collidable_points.py,sha256=Rmf1DhflhOTYh9mDalv0agS0CGSbmfoOybwP2KzK
|
|
50
50
|
jaxsim/rbda/crba.py,sha256=zJSiHKRvNU98z2tT9prrWR4VU9wIZQWFwEut7mua6as,5044
|
51
51
|
jaxsim/rbda/forward_kinematics.py,sha256=2GmEoWsrioVl_SAbKRKfhOLz57pY4aR81PKRdulqStA,3458
|
52
52
|
jaxsim/rbda/jacobian.py,sha256=p0EV_8cLzLVV-93VKznT7VPuRj8W7h7rQWkPlWJXfCA,11023
|
53
|
-
jaxsim/rbda/rnea.py,sha256=
|
53
|
+
jaxsim/rbda/rnea.py,sha256=CLfqs9XFVaD-hvkLABshDAfdw5bm_AMV3UVAQ_IvURQ,7542
|
54
54
|
jaxsim/rbda/utils.py,sha256=eeT21Y4DiiyhrdF0lUE_VvRuwru5-rR7yOlOlWzCCWE,5381
|
55
55
|
jaxsim/rbda/contacts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
56
|
-
jaxsim/rbda/contacts/common.py,sha256=
|
56
|
+
jaxsim/rbda/contacts/common.py,sha256=VwAs742futAmLnDgbaOuLzNDBFiKDfYItdEZ4UcFgzE,2467
|
57
|
+
jaxsim/rbda/contacts/rigid.py,sha256=8Vbnxng-ERZ5ka_eZGIBuhBDr2PNjc7m-Or255AfEw4,15862
|
57
58
|
jaxsim/rbda/contacts/soft.py,sha256=_wvb5iZDjGcVg6rNQelN4LZN7qSC2NIp0HdKvZmlGfk,15647
|
58
59
|
jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
|
59
60
|
jaxsim/terrain/terrain.py,sha256=ctyNANIFSM3tZmamprjaEDcWgUSP0oNJbmT1zw9RjPs,4565
|
60
61
|
jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
|
61
|
-
jaxsim/utils/jaxsim_dataclass.py,sha256=
|
62
|
+
jaxsim/utils/jaxsim_dataclass.py,sha256=5xJbY0G8d7C0OTNIW9T4vQxiDak6TGZT9gpNOvRykFI,11373
|
62
63
|
jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
|
63
|
-
jaxsim/utils/wrappers.py,sha256=
|
64
|
-
jaxsim-0.4.3.
|
65
|
-
jaxsim-0.4.3.
|
66
|
-
jaxsim-0.4.3.
|
67
|
-
jaxsim-0.4.3.
|
68
|
-
jaxsim-0.4.3.
|
64
|
+
jaxsim/utils/wrappers.py,sha256=JhLUh1g8iU-lhjbuZRfkscPZhYlLCOorVM2Xl3ulRBI,4054
|
65
|
+
jaxsim-0.4.3.dev18.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
|
66
|
+
jaxsim-0.4.3.dev18.dist-info/METADATA,sha256=aLpRkfa9CC7GVzXMKX3LY5DkCHEmOr4CE-u3Vbt5fx8,17247
|
67
|
+
jaxsim-0.4.3.dev18.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
68
|
+
jaxsim-0.4.3.dev18.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
|
69
|
+
jaxsim-0.4.3.dev18.dist-info/RECORD,,
|
File without changes
|
File without changes
|