jaxsim 0.6.2.dev2__py3-none-any.whl → 0.6.2.dev102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +1 -1
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/actuation_model.py +96 -0
- jaxsim/api/com.py +8 -8
- jaxsim/api/contact.py +15 -255
- jaxsim/api/contact_model.py +101 -0
- jaxsim/api/data.py +258 -556
- jaxsim/api/frame.py +7 -7
- jaxsim/api/integrators.py +76 -0
- jaxsim/api/kin_dyn_parameters.py +41 -58
- jaxsim/api/link.py +7 -7
- jaxsim/api/model.py +190 -453
- jaxsim/api/ode.py +34 -338
- jaxsim/api/references.py +2 -2
- jaxsim/exceptions.py +2 -2
- jaxsim/math/__init__.py +4 -3
- jaxsim/math/joint_model.py +17 -107
- jaxsim/mujoco/model.py +1 -1
- jaxsim/mujoco/utils.py +2 -2
- jaxsim/parsers/kinematic_graph.py +1 -3
- jaxsim/rbda/aba.py +7 -4
- jaxsim/rbda/collidable_points.py +7 -98
- jaxsim/rbda/contacts/__init__.py +2 -10
- jaxsim/rbda/contacts/common.py +0 -138
- jaxsim/rbda/contacts/relaxed_rigid.py +154 -9
- jaxsim/rbda/crba.py +5 -2
- jaxsim/rbda/forward_kinematics.py +37 -12
- jaxsim/rbda/jacobian.py +15 -6
- jaxsim/rbda/rnea.py +7 -4
- jaxsim/rbda/utils.py +3 -3
- jaxsim/utils/jaxsim_dataclass.py +5 -1
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/METADATA +7 -9
- jaxsim-0.6.2.dev102.dist-info/RECORD +69 -0
- jaxsim/api/ode_data.py +0 -401
- jaxsim/integrators/__init__.py +0 -2
- jaxsim/integrators/common.py +0 -592
- jaxsim/integrators/fixed_step.py +0 -153
- jaxsim/integrators/variable_step.py +0 -706
- jaxsim/rbda/contacts/rigid.py +0 -462
- jaxsim/rbda/contacts/soft.py +0 -480
- jaxsim/rbda/contacts/visco_elastic.py +0 -1066
- jaxsim-0.6.2.dev2.dist-info/RECORD +0 -74
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/LICENSE +0 -0
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/WHEEL +0 -0
- {jaxsim-0.6.2.dev2.dist-info → jaxsim-0.6.2.dev102.dist-info}/top_level.txt +0 -0
jaxsim/rbda/contacts/rigid.py
DELETED
@@ -1,462 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import dataclasses
|
4
|
-
from typing import Any
|
5
|
-
|
6
|
-
import jax
|
7
|
-
import jax.numpy as jnp
|
8
|
-
import jax_dataclasses
|
9
|
-
|
10
|
-
import jaxsim.api as js
|
11
|
-
import jaxsim.typing as jtp
|
12
|
-
from jaxsim import logging
|
13
|
-
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
|
14
|
-
|
15
|
-
from . import common
|
16
|
-
from .common import ContactModel, ContactsParams
|
17
|
-
|
18
|
-
try:
|
19
|
-
from typing import Self
|
20
|
-
except ImportError:
|
21
|
-
from typing_extensions import Self
|
22
|
-
|
23
|
-
|
24
|
-
@jax_dataclasses.pytree_dataclass
class RigidContactsParams(ContactsParams):
    """Parameters of the rigid contacts model."""

    # Static friction coefficient
    mu: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.5, dtype=float)
    )

    # Baumgarte proportional term
    K: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    # Baumgarte derivative term
    D: jtp.Float = dataclasses.field(
        default_factory=lambda: jnp.array(0.0, dtype=float)
    )

    def __hash__(self) -> int:
        """Hash the parameters from the content of the `mu`, `K`, and `D` arrays."""

        from jaxsim.utils.wrappers import HashedNumpyArray

        return hash(
            (
                HashedNumpyArray.hash_of_array(self.mu),
                HashedNumpyArray.hash_of_array(self.K),
                HashedNumpyArray.hash_of_array(self.D),
            )
        )

    def __eq__(self, other: object) -> bool:
        """
        Compare two parameter sets through their content-based hashes.

        Objects that are not `RigidContactsParams` are not compared by hash.
        Returning `NotImplemented` lets Python fall back to its default
        comparison, instead of raising `TypeError` when `other` is unhashable.
        """

        if not isinstance(other, RigidContactsParams):
            return NotImplemented

        return hash(self) == hash(other)

    @classmethod
    def build(
        cls: type[Self],
        *,
        mu: jtp.FloatLike | None = None,
        K: jtp.FloatLike | None = None,
        D: jtp.FloatLike | None = None,
    ) -> Self:
        """
        Create a `RigidContactsParams` instance.

        Args:
            mu: The static friction coefficient.
            K: The Baumgarte proportional term.
            D: The Baumgarte derivative term.

        Returns:
            The parameters, stored as float arrays. Arguments left to `None` take
            the default value of the corresponding dataclass field.
        """

        return cls(
            mu=jnp.array(
                mu
                if mu is not None
                else cls.__dataclass_fields__["mu"].default_factory()
            ).astype(float),
            K=jnp.array(
                K if K is not None else cls.__dataclass_fields__["K"].default_factory()
            ).astype(float),
            D=jnp.array(
                D if D is not None else cls.__dataclass_fields__["D"].default_factory()
            ).astype(float),
        )

    def valid(self) -> jtp.BoolLike:
        """
        Check if the parameters are valid.

        Returns:
            `True` if all the entries of `mu`, `K`, and `D` are non-negative.
        """

        # Note: `bool(...)` requires concrete values, so this is meant to be
        # called outside of JIT-compiled code.
        return bool(
            jnp.all(self.mu >= 0.0)
            and jnp.all(self.K >= 0.0)
            and jnp.all(self.D >= 0.0)
        )
|
88
|
-
|
89
|
-
|
90
|
-
@jax_dataclasses.pytree_dataclass
class RigidContacts(ContactModel):
    """Rigid contacts model."""

    # Regularization added to the diagonal of the Delassus matrix.
    regularization_delassus: jax_dataclasses.Static[float] = dataclasses.field(
        default=1e-6, kw_only=True
    )

    # The QP solver options are stored as two parallel tuples (keys and values),
    # so that they can be kept as static fields of the dataclass.
    _solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field(
        default=("solver_tol",), kw_only=True
    )
    _solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field(
        default=(1e-3,), kw_only=True
    )

    @property
    def solver_options(self) -> dict[str, Any]:
        """Get the solver options as a dictionary."""

        return dict(
            zip(
                self._solver_options_keys,
                self._solver_options_values,
                strict=True,
            )
        )

    @classmethod
    def build(
        cls: type[Self],
        regularization_delassus: jtp.FloatLike | None = None,
        solver_options: dict[str, Any] | None = None,
        **kwargs,
    ) -> Self:
        """
        Create a `RigidContacts` instance with specified parameters.

        Args:
            regularization_delassus:
                The regularization term to add to the diagonal of the Delassus matrix.
            solver_options: The options to pass to the QP solver.
            **kwargs: Extra arguments which are ignored.

        Returns:
            The `RigidContacts` instance.

        Raises:
            ValueError: If the values of the solver options are not hashable.
        """

        if kwargs:
            logging.debug(msg=f"Ignoring extra arguments: {kwargs}")

        # Get the default solver options.
        default_solver_options = dict(
            zip(cls._solver_options_keys, cls._solver_options_values, strict=True)
        )

        # Create the solver options to set by combining the default solver options
        # with the user-provided solver options.
        solver_options = default_solver_options | (
            solver_options if solver_options is not None else {}
        )

        # Make sure that the solver options are hashable.
        # We need to check this because the solver options are static.
        try:
            hash(tuple(solver_options.values()))
        except TypeError as exc:
            raise ValueError(
                "The values of the solver options must be hashable."
            ) from exc

        # The extra keyword arguments are deliberately not forwarded to the
        # constructor: they are documented (and logged) as ignored, and forwarding
        # them would make the constructor reject unexpected or duplicated keywords.
        return cls(
            regularization_delassus=float(
                regularization_delassus
                if regularization_delassus is not None
                else cls.__dataclass_fields__["regularization_delassus"].default
            ),
            _solver_options_keys=tuple(solver_options.keys()),
            _solver_options_values=tuple(solver_options.values()),
        )

    @staticmethod
    def compute_impact_velocity(
        inactive_collidable_points: jtp.ArrayLike,
        M: jtp.MatrixLike,
        J_WC: jtp.MatrixLike,
        generalized_velocity: jtp.VectorLike,
    ) -> jtp.Vector:
        """
        Return the new velocity of the system after a potential impact.

        Args:
            inactive_collidable_points:
                Boolean mask over the collidable points; `True` marks an inactive
                point, whose Jacobian rows are excluded from the impact constraint.
            M: The mass matrix of the system (in mixed representation).
            J_WC: The Jacobian matrix of the collidable points (in mixed representation).
            generalized_velocity: The generalized velocity of the system.

        Returns:
            The generalized velocity after the impact, obtained as the velocity part
            of the least-squares solution of the impact KKT system.

        Note:
            The mass matrix `M`, the Jacobian `J_WC`, and the generalized velocity `generalized_velocity`
            must be expressed in the same velocity representation.
        """

        # Compute system velocity after impact maintaining zero linear velocity of active points.
        sl = jnp.s_[:, 0:3, :]
        Jl_WC = J_WC[sl]

        # Zero out the jacobian rows of inactive points.
        Jl_WC = jnp.vstack(
            jnp.where(
                inactive_collidable_points[:, jnp.newaxis, jnp.newaxis],
                jnp.zeros_like(Jl_WC),
                Jl_WC,
            )
        )

        # KKT system:  [M  -Jᵀ] [ν⁺]   [M ν⁻]
        #              [J   0 ] [λ ] = [ 0  ]
        A = jnp.vstack(
            [
                jnp.hstack([M, -Jl_WC.T]),
                jnp.hstack([Jl_WC, jnp.zeros((Jl_WC.shape[0], Jl_WC.shape[0]))]),
            ]
        )
        b = jnp.hstack([M @ generalized_velocity, jnp.zeros(Jl_WC.shape[0])])

        BW_ν_post_impact = jnp.linalg.lstsq(A, b)[0]

        return BW_ν_post_impact[0 : M.shape[0]]

    @jax.jit
    def compute_contact_forces(
        self,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        *,
        link_forces: jtp.MatrixLike | None = None,
        joint_force_references: jtp.VectorLike | None = None,
    ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]:
        """
        Compute the contact forces.

        Args:
            model: The model to consider.
            data: The data of the considered model.
            link_forces:
                Optional `(n_links, 6)` matrix of external forces acting on the links,
                expressed in the same representation of data.
            joint_force_references:
                Optional `(n_joints,)` vector of joint forces.

        Returns:
            A tuple whose first element is the matrix of 6D contact forces of the
            enabled collidable points, expressed in inertial-fixed representation,
            and whose second element is an empty dictionary.
        """

        # Import qpax privately just in this method.
        import qpax

        # Get the indices of the enabled collidable points.
        indices_of_enabled_collidable_points = (
            model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
        )

        n_collidable_points = len(indices_of_enabled_collidable_points)

        link_forces = jnp.atleast_2d(
            jnp.array(link_forces, dtype=float).squeeze()
            if link_forces is not None
            else jnp.zeros((model.number_of_links(), 6))
        )

        joint_force_references = jnp.atleast_1d(
            jnp.array(joint_force_references, dtype=float).squeeze()
            if joint_force_references is not None
            else jnp.zeros((model.number_of_joints(),))
        )

        # Compute kin-dyn quantities used in the contact model.
        with data.switch_velocity_representation(VelRepr.Mixed):
            BW_ν = data.generalized_velocity()

            M = js.model.free_floating_mass_matrix(model=model, data=data)

            J_WC = js.contact.jacobian(model=model, data=data)
            J̇_WC = js.contact.jacobian_derivative(model=model, data=data)

            W_H_C = js.contact.transforms(model=model, data=data)

        # Compute the position and linear velocities (mixed representation) of
        # all enabled collidable points belonging to the robot.
        position, velocity = js.contact.collidable_point_kinematics(
            model=model, data=data
        )

        # Compute the penetration depth and velocity of the collidable points.
        # Note that this function considers the penetration in the normal direction.
        δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))(
            position, velocity, model.terrain
        )

        # Points with non-positive penetration depth are treated as inactive,
        # both in the stabilization term and in the inequality constraints.
        inactive_collidable_points = δ <= 0

        # Build a references object to simplify converting link forces.
        references = js.references.JaxSimModelReferences.build(
            model=model,
            data=data,
            velocity_representation=data.velocity_representation,
            link_forces=link_forces,
            joint_force_references=joint_force_references,
        )

        # Compute the generalized free acceleration.
        with (
            references.switch_velocity_representation(VelRepr.Mixed),
            data.switch_velocity_representation(VelRepr.Mixed),
        ):

            BW_ν̇_free = jnp.hstack(
                js.ode.system_acceleration(
                    model=model,
                    data=data,
                    link_forces=references.link_forces(model=model, data=data),
                    joint_force_references=references.joint_force_references(
                        model=model
                    ),
                )
            )

        # Compute the free linear acceleration of the collidable points.
        # Since we use doubly-mixed jacobian, this corresponds to W_p̈_C.
        free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points(
            BW_nu=BW_ν,
            BW_nu_dot=BW_ν̇_free,
            CW_J_WC_BW=J_WC,
            CW_J_dot_WC_BW=J̇_WC,
        ).flatten()

        # Compute stabilization term.
        baumgarte_term = RigidContacts._compute_baumgarte_stabilization_term(
            inactive_collidable_points=inactive_collidable_points,
            δ=δ,
            δ_dot=δ_dot,
            n=n̂,
            K=data.contacts_params.K,
            D=data.contacts_params.D,
        ).flatten()

        # Compute the Delassus matrix.
        delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC)

        # Initialize regularization term of the Delassus matrix for
        # better numerical conditioning.
        Iε = self.regularization_delassus * jnp.eye(delassus_matrix.shape[0])

        # Construct the quadratic cost function.
        Q = delassus_matrix + Iε
        q = free_contact_acc - baumgarte_term

        # Construct the inequality constraints.
        G = RigidContacts._compute_ineq_constraint_matrix(
            inactive_collidable_points=inactive_collidable_points,
            mu=data.contacts_params.mu,
        )
        h_bounds = RigidContacts._compute_ineq_bounds(
            n_collidable_points=n_collidable_points
        )

        # Construct the equality constraints (none).
        A = jnp.zeros((0, 3 * n_collidable_points))
        b = jnp.zeros((0,))

        # Solve the following optimization problem with qpax:
        #
        # min_{x} 0.5 x⊤ Q x + q⊤ x
        #
        # s.t. A x = b
        #      G x ≤ h
        #
        # TODO: add possibility to notify if the QP problem did not converge.
        solution, _, _, _, converged, _ = qpax.solve_qp(  # noqa: F841
            Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, **self.solver_options
        )

        # Reshape the optimized solution to be a matrix of 3D contact forces.
        CW_fl_C = solution.reshape(-1, 3)

        # Convert the contact forces from mixed to inertial-fixed representation.
        W_f_C = jax.vmap(
            lambda CW_fl_C, W_H_C: (
                ModelDataWithVelocityRepresentation.other_representation_to_inertial(
                    array=jnp.zeros(6).at[0:3].set(CW_fl_C),
                    transform=W_H_C,
                    other_representation=VelRepr.Mixed,
                    is_force=True,
                )
            ),
        )(CW_fl_C, W_H_C)

        return W_f_C, {}

    @staticmethod
    def _delassus_matrix(
        M: jtp.MatrixLike,
        J_WC: jtp.MatrixLike,
    ) -> jtp.Matrix:
        """
        Compute the Delassus matrix of the linear part of the contact Jacobians.

        The linear rows (0:3) of each collidable point Jacobian are stacked into a
        single matrix `J`, and the result is `J M⁺ Jᵀ`, where `M⁺` is the
        pseudo-inverse of the mass matrix.
        """

        sl = jnp.s_[:, 0:3, :]
        J_WC_lin = jnp.vstack(J_WC[sl])

        delassus_matrix = J_WC_lin @ jnp.linalg.pinv(M) @ J_WC_lin.T
        return delassus_matrix

    @staticmethod
    def _compute_ineq_constraint_matrix(
        inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike
    ) -> jtp.Matrix:
        """
        Compute the inequality constraint matrix for all the collidable points.

        Each collidable point contributes a 6x3 block acting on its 3D linear
        contact force; the returned matrix is the block-diagonal of these blocks.

        Rows 0-3: enforce the friction pyramid constraint.
        Row 4: enforce the non-negativity of the vertical (z) force.
        Row 5: contact complementarity condition. It is `[0, 0, 1]` for inactive
               points (forcing, together with row 4 and zero bounds, a null
               vertical force) and all zeros for active points.
        """
        G_single_point = jnp.array(
            [
                [1, 0, -mu],
                [0, 1, -mu],
                [-1, 0, -mu],
                [0, -1, -mu],
                [0, 0, -1],
                [0, 0, 0],
            ]
        )
        G = jnp.tile(G_single_point, (len(inactive_collidable_points), 1, 1))
        G = G.at[:, 5, 2].set(inactive_collidable_points)

        G = jax.scipy.linalg.block_diag(*G)
        return G

    @staticmethod
    def _compute_ineq_bounds(n_collidable_points: int) -> jtp.Vector:
        """Return the (all-zero) upper bounds of the 6 inequalities of each point."""

        n_constraints = 6 * n_collidable_points
        return jnp.zeros(shape=(n_constraints,))

    @staticmethod
    def _linear_acceleration_of_collidable_points(
        BW_nu: jtp.ArrayLike,
        BW_nu_dot: jtp.ArrayLike,
        CW_J_WC_BW: jtp.MatrixLike,
        CW_J_dot_WC_BW: jtp.MatrixLike,
    ) -> jtp.Matrix:
        """
        Compute the linear acceleration of the collidable points as `J̇ ν + J ν̇`.

        Returns:
            The linear (first three) components of the per-point 6D accelerations,
            squeezed (a single point yields a 1D array).
        """

        BW_ν = BW_nu
        BW_ν̇ = BW_nu_dot
        CW_J̇_WC_BW = CW_J_dot_WC_BW

        # Compute the linear acceleration of the collidable points.
        # Since we use doubly-mixed jacobians, this corresponds to W_p̈_C.
        CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇

        CW_a_WC = CW_a_WC.reshape(-1, 6)
        return CW_a_WC[:, 0:3].squeeze()

    @staticmethod
    def _compute_baumgarte_stabilization_term(
        inactive_collidable_points: jtp.ArrayLike,
        δ: jtp.ArrayLike,
        δ_dot: jtp.ArrayLike,
        n: jtp.ArrayLike,
        K: jtp.FloatLike,
        D: jtp.FloatLike,
    ) -> jtp.Array:
        """
        Compute the per-point Baumgarte stabilization term `(K δ + D δ̇) n`.

        Rows corresponding to inactive collidable points are set to zero.
        """

        return jnp.where(
            inactive_collidable_points[:, jnp.newaxis],
            jnp.zeros_like(n),
            (K * δ + D * δ_dot)[:, jnp.newaxis] * n,
        )
|