jaxsim 0.4.3.dev64__py3-none-any.whl → 0.4.3.dev68__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +0 -5
- jaxsim/_version.py +2 -2
- jaxsim/api/contact.py +1 -27
- jaxsim/api/data.py +11 -40
- jaxsim/api/joint.py +2 -62
- jaxsim/api/model.py +1 -12
- jaxsim/api/ode.py +24 -19
- jaxsim/api/ode_data.py +1 -11
- jaxsim/integrators/common.py +1 -1
- jaxsim/math/inertia.py +1 -1
- jaxsim/mujoco/loaders.py +3 -3
- jaxsim/parsers/kinematic_graph.py +3 -3
- jaxsim/parsers/rod/parser.py +14 -18
- jaxsim/rbda/contacts/rigid.py +41 -11
- jaxsim/terrain/terrain.py +25 -41
- jaxsim/typing.py +1 -1
- jaxsim/utils/jaxsim_dataclass.py +9 -12
- jaxsim/utils/wrappers.py +1 -1
- {jaxsim-0.4.3.dev64.dist-info → jaxsim-0.4.3.dev68.dist-info}/METADATA +1 -2
- {jaxsim-0.4.3.dev64.dist-info → jaxsim-0.4.3.dev68.dist-info}/RECORD +23 -24
- jaxsim/rbda/contacts/relaxed_rigid.py +0 -384
- {jaxsim-0.4.3.dev64.dist-info → jaxsim-0.4.3.dev68.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev64.dist-info → jaxsim-0.4.3.dev68.dist-info}/WHEEL +0 -0
- {jaxsim-0.4.3.dev64.dist-info → jaxsim-0.4.3.dev68.dist-info}/top_level.txt +0 -0
jaxsim/__init__.py
CHANGED
@@ -20,11 +20,6 @@ def _jnp_options() -> None:
|
|
20
20
|
if jnp.empty(0, dtype=float).dtype != jnp.empty(0, dtype=np.float64).dtype:
|
21
21
|
logging.warning("Failed to enable 64bit precision in JAX")
|
22
22
|
|
23
|
-
else:
|
24
|
-
logging.warning(
|
25
|
-
"Using 32bit precision in JaxSim is still experimental, please avoid to use variable step integrators."
|
26
|
-
)
|
27
|
-
|
28
23
|
|
29
24
|
def _np_options() -> None:
|
30
25
|
import numpy as np
|
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.4.3.
|
16
|
-
__version_tuple__ = version_tuple = (0, 4, 3, '
|
15
|
+
__version__ = version = '0.4.3.dev68'
|
16
|
+
__version_tuple__ = version_tuple = (0, 4, 3, 'dev68')
|
jaxsim/api/contact.py
CHANGED
@@ -131,8 +131,7 @@ def collidable_point_dynamics(
|
|
131
131
|
Returns:
|
132
132
|
The 6D force applied to each collidable point and additional data based on the contact model configured:
|
133
133
|
- Soft: the material deformation rate.
|
134
|
-
- Rigid:
|
135
|
-
- QuasiRigid: no additional data.
|
134
|
+
- Rigid: nothing.
|
136
135
|
|
137
136
|
Note:
|
138
137
|
The material deformation rate is always returned in the mixed frame
|
@@ -145,10 +144,6 @@ def collidable_point_dynamics(
|
|
145
144
|
W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)
|
146
145
|
|
147
146
|
# Import privately the contacts classes.
|
148
|
-
from jaxsim.rbda.contacts.relaxed_rigid import (
|
149
|
-
RelaxedRigidContacts,
|
150
|
-
RelaxedRigidContactsState,
|
151
|
-
)
|
152
147
|
from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
|
153
148
|
from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
|
154
149
|
|
@@ -195,27 +190,6 @@ def collidable_point_dynamics(
|
|
195
190
|
|
196
191
|
aux_data = dict()
|
197
192
|
|
198
|
-
case RelaxedRigidContacts():
|
199
|
-
assert isinstance(model.contact_model, RelaxedRigidContacts)
|
200
|
-
assert isinstance(data.state.contact, RelaxedRigidContactsState)
|
201
|
-
|
202
|
-
# Build the contact model.
|
203
|
-
relaxed_rigid_contacts = RelaxedRigidContacts(
|
204
|
-
parameters=data.contacts_params, terrain=model.terrain
|
205
|
-
)
|
206
|
-
|
207
|
-
# Compute the 6D force expressed in the inertial frame and applied to each
|
208
|
-
# collidable point.
|
209
|
-
W_f_Ci, _ = relaxed_rigid_contacts.compute_contact_forces(
|
210
|
-
position=W_p_Ci,
|
211
|
-
velocity=W_ṗ_Ci,
|
212
|
-
model=model,
|
213
|
-
data=data,
|
214
|
-
link_forces=link_forces,
|
215
|
-
)
|
216
|
-
|
217
|
-
aux_data = dict()
|
218
|
-
|
219
193
|
case _:
|
220
194
|
raise ValueError(f"Invalid contact model {model.contact_model}")
|
221
195
|
|
jaxsim/api/data.py
CHANGED
@@ -39,9 +39,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
39
39
|
contacts_params: jaxsim.rbda.ContactsParams = dataclasses.field(repr=False)
|
40
40
|
|
41
41
|
time_ns: jtp.Int = dataclasses.field(
|
42
|
-
default_factory=lambda: jnp.array(
|
43
|
-
0, dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32
|
44
|
-
),
|
42
|
+
default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
|
45
43
|
)
|
46
44
|
|
47
45
|
def __hash__(self) -> int:
|
@@ -174,14 +172,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
174
172
|
)
|
175
173
|
|
176
174
|
time_ns = (
|
177
|
-
jnp.array(
|
178
|
-
time * 1e9,
|
179
|
-
dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32,
|
180
|
-
)
|
175
|
+
jnp.array(time * 1e9, dtype=jnp.uint64)
|
181
176
|
if time is not None
|
182
|
-
else jnp.array(
|
183
|
-
0, dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32
|
184
|
-
)
|
177
|
+
else jnp.array(0, dtype=jnp.uint64)
|
185
178
|
)
|
186
179
|
|
187
180
|
if isinstance(model.contact_model, SoftContacts):
|
@@ -593,18 +586,16 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
|
|
593
586
|
The updated `JaxSimModelData` object.
|
594
587
|
"""
|
595
588
|
|
596
|
-
|
597
|
-
|
598
|
-
W_Q_B = jax.lax.select(
|
599
|
-
pred=jnp.allclose(jnp.linalg.norm(W_Q_B), 1.0, atol=1e-6, rtol=0.0),
|
600
|
-
on_true=W_Q_B,
|
601
|
-
on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
|
602
|
-
)
|
589
|
+
base_quaternion = jnp.array(base_quaternion)
|
603
590
|
|
604
591
|
return self.replace(
|
605
592
|
validate=True,
|
606
593
|
state=self.state.replace(
|
607
|
-
physics_model=self.state.physics_model.replace(
|
594
|
+
physics_model=self.state.physics_model.replace(
|
595
|
+
base_quaternion=jnp.atleast_1d(base_quaternion.squeeze()).astype(
|
596
|
+
float
|
597
|
+
)
|
598
|
+
)
|
608
599
|
),
|
609
600
|
)
|
610
601
|
|
@@ -746,13 +737,6 @@ def random_model_data(
|
|
746
737
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
747
738
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
748
739
|
] = ((-1, -1, 0.5), 1.0),
|
749
|
-
joint_pos_bounds: (
|
750
|
-
tuple[
|
751
|
-
jtp.FloatLike | Sequence[jtp.FloatLike],
|
752
|
-
jtp.FloatLike | Sequence[jtp.FloatLike],
|
753
|
-
]
|
754
|
-
| None
|
755
|
-
) = None,
|
756
740
|
base_vel_lin_bounds: tuple[
|
757
741
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
758
742
|
jtp.FloatLike | Sequence[jtp.FloatLike],
|
@@ -778,8 +762,6 @@ def random_model_data(
|
|
778
762
|
key: The random key.
|
779
763
|
velocity_representation: The velocity representation to use.
|
780
764
|
base_pos_bounds: The bounds for the base position.
|
781
|
-
joint_pos_bounds:
|
782
|
-
The bounds for the joint positions (reading the joint limits if None).
|
783
765
|
base_vel_lin_bounds: The bounds for the base linear velocity.
|
784
766
|
base_vel_ang_bounds: The bounds for the base angular velocity.
|
785
767
|
joint_vel_bounds: The bounds for the joint velocities.
|
@@ -824,19 +806,8 @@ def random_model_data(
|
|
824
806
|
).wxyz
|
825
807
|
|
826
808
|
if model.number_of_joints() > 0:
|
827
|
-
|
828
|
-
|
829
|
-
jnp.array(joint_pos_bounds, dtype=float)
|
830
|
-
if joint_pos_bounds is not None
|
831
|
-
else (None, None)
|
832
|
-
)
|
833
|
-
|
834
|
-
physics_model_state.joint_positions = (
|
835
|
-
js.joint.random_joint_positions(model=model, key=k3)
|
836
|
-
if (s_min is None or s_max is None)
|
837
|
-
else jax.random.uniform(
|
838
|
-
key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max
|
839
|
-
)
|
809
|
+
physics_model_state.joint_positions = js.joint.random_joint_positions(
|
810
|
+
model=model, key=k3
|
840
811
|
)
|
841
812
|
|
842
813
|
physics_model_state.joint_velocities = jax.random.uniform(
|
jaxsim/api/joint.py
CHANGED
@@ -180,77 +180,17 @@ def random_joint_positions(
|
|
180
180
|
|
181
181
|
Args:
|
182
182
|
model: The model to consider.
|
183
|
-
joint_names: The names of the
|
184
|
-
key: The random key
|
185
|
-
|
186
|
-
Note:
|
187
|
-
If the joint range or revolute joints is larger than 2π, their joint positions
|
188
|
-
will be sampled from an interval of size 2π.
|
183
|
+
joint_names: The names of the joints.
|
184
|
+
key: The random key.
|
189
185
|
|
190
186
|
Returns:
|
191
187
|
The random joint positions.
|
192
188
|
"""
|
193
189
|
|
194
|
-
# Consider the key corresponding to a zero seed if it was not passed.
|
195
190
|
key = key if key is not None else jax.random.PRNGKey(seed=0)
|
196
191
|
|
197
|
-
# Get the joint limits parsed from the model description.
|
198
192
|
s_min, s_max = position_limits(model=model, joint_names=joint_names)
|
199
193
|
|
200
|
-
# Get the joint indices.
|
201
|
-
# Note that it will trigger an exception if the given `joint_names` are not valid.
|
202
|
-
joint_names = joint_names if joint_names is not None else model.joint_names()
|
203
|
-
joint_indices = names_to_idxs(model=model, joint_names=joint_names)
|
204
|
-
|
205
|
-
from jaxsim.parsers.descriptions.joint import JointType
|
206
|
-
|
207
|
-
# Filter for revolute joints.
|
208
|
-
is_revolute = jnp.where(
|
209
|
-
jnp.array(model.kin_dyn_parameters.joint_model.joint_types[1:])[joint_indices]
|
210
|
-
== JointType.Revolute,
|
211
|
-
True,
|
212
|
-
False,
|
213
|
-
)
|
214
|
-
|
215
|
-
# Shorthand for π.
|
216
|
-
π = jnp.pi
|
217
|
-
|
218
|
-
# Filter for revolute with full range (or continuous).
|
219
|
-
is_revolute_full_range = jnp.logical_and(is_revolute, s_max - s_min >= 2 * π)
|
220
|
-
|
221
|
-
# Clip the lower limit to -π if the joint range is larger than [-π, π].
|
222
|
-
s_min = jnp.where(
|
223
|
-
jnp.logical_and(
|
224
|
-
is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
|
225
|
-
),
|
226
|
-
-π,
|
227
|
-
s_min,
|
228
|
-
)
|
229
|
-
|
230
|
-
# Clip the upper limit to +π if the joint range is larger than [-π, π].
|
231
|
-
s_max = jnp.where(
|
232
|
-
jnp.logical_and(
|
233
|
-
is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
|
234
|
-
),
|
235
|
-
π,
|
236
|
-
s_max,
|
237
|
-
)
|
238
|
-
|
239
|
-
# Shift the lower limit if the upper limit is smaller than +π.
|
240
|
-
s_min = jnp.where(
|
241
|
-
jnp.logical_and(is_revolute_full_range, s_max < π),
|
242
|
-
s_max - 2 * π,
|
243
|
-
s_min,
|
244
|
-
)
|
245
|
-
|
246
|
-
# Shift the upper limit if the lower limit is larger than -π.
|
247
|
-
s_max = jnp.where(
|
248
|
-
jnp.logical_and(is_revolute_full_range, s_min > -π),
|
249
|
-
s_min + 2 * π,
|
250
|
-
s_max,
|
251
|
-
)
|
252
|
-
|
253
|
-
# Sample the joint positions.
|
254
194
|
s_random = jax.random.uniform(
|
255
195
|
minval=s_min,
|
256
196
|
maxval=s_max,
|
jaxsim/api/model.py
CHANGED
@@ -1931,22 +1931,11 @@ def step(
|
|
1931
1931
|
),
|
1932
1932
|
)
|
1933
1933
|
|
1934
|
-
tf_ns = t0_ns + jnp.array(dt * 1e9, dtype=t0_ns.dtype)
|
1935
|
-
tf_ns = jnp.where(tf_ns >= t0_ns, tf_ns, jnp.array(0, dtype=t0_ns.dtype))
|
1936
|
-
|
1937
|
-
jax.lax.cond(
|
1938
|
-
pred=tf_ns < t0_ns,
|
1939
|
-
true_fun=lambda: jax.debug.print(
|
1940
|
-
"The simulation time overflowed, resetting simulation time to 0."
|
1941
|
-
),
|
1942
|
-
false_fun=lambda: None,
|
1943
|
-
)
|
1944
|
-
|
1945
1934
|
data_tf = (
|
1946
1935
|
# Store the new state of the model and the new time.
|
1947
1936
|
data.replace(
|
1948
1937
|
state=state_tf,
|
1949
|
-
time_ns=
|
1938
|
+
time_ns=t0_ns + jnp.array(dt * 1e9).astype(jnp.uint64),
|
1950
1939
|
)
|
1951
1940
|
)
|
1952
1941
|
|
jaxsim/api/ode.py
CHANGED
@@ -175,15 +175,17 @@ def system_velocity_dynamics(
|
|
175
175
|
forces=W_f_Li_terrain,
|
176
176
|
additive=True,
|
177
177
|
)
|
178
|
-
|
179
|
-
|
178
|
+
# Get the link forces in the data representation
|
179
|
+
with references.switch_velocity_representation(data.velocity_representation):
|
180
180
|
f_L_total = references.link_forces(model=model, data=data)
|
181
181
|
|
182
|
-
|
183
|
-
|
184
|
-
|
182
|
+
# The following method always returns the inertial-fixed acceleration, and expects
|
183
|
+
# the link_forces expressed in the inertial frame.
|
184
|
+
W_v̇_WB, s̈ = system_acceleration(
|
185
|
+
model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total
|
186
|
+
)
|
185
187
|
|
186
|
-
return
|
188
|
+
return W_v̇_WB, s̈, aux_data
|
187
189
|
|
188
190
|
|
189
191
|
def system_acceleration(
|
@@ -194,7 +196,7 @@ def system_acceleration(
|
|
194
196
|
link_forces: jtp.MatrixLike | None = None,
|
195
197
|
) -> tuple[jtp.Vector, jtp.Vector]:
|
196
198
|
"""
|
197
|
-
Compute the system acceleration in
|
199
|
+
Compute the system acceleration in inertial-fixed representation.
|
198
200
|
|
199
201
|
Args:
|
200
202
|
model: The model to consider.
|
@@ -204,7 +206,7 @@ def system_acceleration(
|
|
204
206
|
The 6D forces to apply to the links expressed in the same representation of data.
|
205
207
|
|
206
208
|
Returns:
|
207
|
-
A tuple containing the base 6D acceleration in
|
209
|
+
A tuple containing the base 6D acceleration in inertial-fixed representation
|
208
210
|
and the joint accelerations.
|
209
211
|
"""
|
210
212
|
|
@@ -270,15 +272,18 @@ def system_acceleration(
|
|
270
272
|
)
|
271
273
|
|
272
274
|
# - Joint accelerations: s̈ ∈ ℝⁿ
|
273
|
-
# - Base acceleration:
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
275
|
+
# - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
|
276
|
+
with (
|
277
|
+
data.switch_velocity_representation(velocity_representation=VelRepr.Inertial),
|
278
|
+
references.switch_velocity_representation(VelRepr.Inertial),
|
279
|
+
):
|
280
|
+
W_v̇_WB, s̈ = js.model.forward_dynamics_aba(
|
281
|
+
model=model,
|
282
|
+
data=data,
|
283
|
+
joint_forces=references.joint_force_references(),
|
284
|
+
link_forces=references.link_forces(),
|
285
|
+
)
|
286
|
+
return W_v̇_WB, s̈
|
282
287
|
|
283
288
|
|
284
289
|
@jax.jit
|
@@ -348,7 +353,7 @@ def system_dynamics(
|
|
348
353
|
corresponding derivative, and the dictionary of auxiliary data returned
|
349
354
|
by the system dynamics evaluation.
|
350
355
|
"""
|
351
|
-
|
356
|
+
|
352
357
|
from jaxsim.rbda.contacts.rigid import RigidContacts
|
353
358
|
from jaxsim.rbda.contacts.soft import SoftContacts
|
354
359
|
|
@@ -366,7 +371,7 @@ def system_dynamics(
|
|
366
371
|
case SoftContacts():
|
367
372
|
ode_state_kwargs["tangential_deformation"] = aux_dict["m_dot"]
|
368
373
|
|
369
|
-
case RigidContacts()
|
374
|
+
case RigidContacts():
|
370
375
|
pass
|
371
376
|
|
372
377
|
case _:
|
jaxsim/api/ode_data.py
CHANGED
@@ -6,10 +6,6 @@ import jax_dataclasses
|
|
6
6
|
import jaxsim.api as js
|
7
7
|
import jaxsim.typing as jtp
|
8
8
|
from jaxsim.rbda import ContactsState
|
9
|
-
from jaxsim.rbda.contacts.relaxed_rigid import (
|
10
|
-
RelaxedRigidContacts,
|
11
|
-
RelaxedRigidContactsState,
|
12
|
-
)
|
13
9
|
from jaxsim.rbda.contacts.rigid import RigidContacts, RigidContactsState
|
14
10
|
from jaxsim.rbda.contacts.soft import SoftContacts, SoftContactsState
|
15
11
|
from jaxsim.utils import JaxsimDataclass
|
@@ -177,10 +173,6 @@ class ODEState(JaxsimDataclass):
|
|
177
173
|
)
|
178
174
|
case RigidContacts():
|
179
175
|
contact = RigidContactsState.build()
|
180
|
-
|
181
|
-
case RelaxedRigidContacts():
|
182
|
-
contact = RelaxedRigidContactsState.build()
|
183
|
-
|
184
176
|
case _:
|
185
177
|
raise ValueError("Unable to determine contact state class prefix.")
|
186
178
|
|
@@ -224,9 +216,7 @@ class ODEState(JaxsimDataclass):
|
|
224
216
|
|
225
217
|
# Get the contact model from the `JaxSimModel`.
|
226
218
|
match contact:
|
227
|
-
case (
|
228
|
-
SoftContactsState() | RigidContactsState() | RelaxedRigidContactsState()
|
229
|
-
):
|
219
|
+
case SoftContactsState() | RigidContactsState():
|
230
220
|
pass
|
231
221
|
case None:
|
232
222
|
contact = SoftContactsState.zero(model=model)
|
jaxsim/integrators/common.py
CHANGED
@@ -497,7 +497,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
497
497
|
b: jtp.Matrix,
|
498
498
|
c: jtp.Vector,
|
499
499
|
index_of_solution: jtp.IntLike = 0,
|
500
|
-
) ->
|
500
|
+
) -> [bool, int | None]:
|
501
501
|
"""
|
502
502
|
Check if the Butcher tableau supports the FSAL (first-same-as-last) property.
|
503
503
|
|
jaxsim/math/inertia.py
CHANGED
@@ -45,7 +45,7 @@ class Inertia:
|
|
45
45
|
M (jtp.Matrix): The 6x6 inertia matrix.
|
46
46
|
|
47
47
|
Returns:
|
48
|
-
|
48
|
+
Tuple[jtp.Float, jtp.Vector, jtp.Matrix]: A tuple containing mass, center of mass (3D), and inertia matrix (3x3).
|
49
49
|
|
50
50
|
Raises:
|
51
51
|
ValueError: If the input matrix M has an unexpected shape.
|
jaxsim/mujoco/loaders.py
CHANGED
@@ -211,7 +211,7 @@ class RodModelToMjcf:
|
|
211
211
|
joints_dict = {j.name: j for j in rod_model.joints()}
|
212
212
|
|
213
213
|
# Convert all the joints not considered to fixed joints.
|
214
|
-
for joint_name in
|
214
|
+
for joint_name in set(j.name for j in rod_model.joints()) - considered_joints:
|
215
215
|
joints_dict[joint_name].type = "fixed"
|
216
216
|
|
217
217
|
# Convert the ROD model to URDF.
|
@@ -289,10 +289,10 @@ class RodModelToMjcf:
|
|
289
289
|
mj_model = mj.MjModel.from_xml_string(xml=urdf_string, assets=assets)
|
290
290
|
|
291
291
|
# Get the joint names.
|
292
|
-
mj_joint_names =
|
292
|
+
mj_joint_names = set(
|
293
293
|
mj.mj_id2name(mj_model, mj.mjtObj.mjOBJ_JOINT, idx)
|
294
294
|
for idx in range(mj_model.njnt)
|
295
|
-
|
295
|
+
)
|
296
296
|
|
297
297
|
# Check that the Mujoco model only has the considered joints.
|
298
298
|
if mj_joint_names != considered_joints:
|
@@ -394,7 +394,7 @@ class KinematicGraph(Sequence[LinkDescription]):
|
|
394
394
|
return copy.deepcopy(self)
|
395
395
|
|
396
396
|
# Check if all considered joints are part of the full kinematic graph
|
397
|
-
if len(set(considered_joints) -
|
397
|
+
if len(set(considered_joints) - set(j.name for j in full_graph.joints)) != 0:
|
398
398
|
extra_j = set(considered_joints) - {j.name for j in full_graph.joints}
|
399
399
|
msg = f"Not all joints to consider are part of the graph ({{{extra_j}}})"
|
400
400
|
raise ValueError(msg)
|
@@ -536,8 +536,8 @@ class KinematicGraph(Sequence[LinkDescription]):
|
|
536
536
|
root_link_name=full_graph.root.name,
|
537
537
|
)
|
538
538
|
|
539
|
-
assert
|
540
|
-
|
539
|
+
assert set(f.name for f in self.frames).isdisjoint(
|
540
|
+
set(f.name for f in unconnected_frames + reduced_frames)
|
541
541
|
)
|
542
542
|
|
543
543
|
for link in unconnected_links:
|
jaxsim/parsers/rod/parser.py
CHANGED
@@ -223,7 +223,7 @@ def extract_model_data(
|
|
223
223
|
child=links_dict[j.child],
|
224
224
|
jtype=utils.joint_to_joint_type(joint=j),
|
225
225
|
axis=(
|
226
|
-
np.array(j.axis.xyz.xyz
|
226
|
+
np.array(j.axis.xyz.xyz)
|
227
227
|
if j.axis is not None
|
228
228
|
and j.axis.xyz is not None
|
229
229
|
and j.axis.xyz.xyz is not None
|
@@ -232,43 +232,39 @@ def extract_model_data(
|
|
232
232
|
pose=j.pose.transform() if j.pose is not None else np.eye(4),
|
233
233
|
initial_position=0.0,
|
234
234
|
position_limit=(
|
235
|
-
|
236
|
-
j.axis.limit.lower
|
237
|
-
if j.axis is not None
|
238
|
-
|
239
|
-
and j.axis.limit.lower is not None
|
240
|
-
else jnp.finfo(float).min
|
235
|
+
(
|
236
|
+
float(j.axis.limit.lower)
|
237
|
+
if j.axis is not None and j.axis.limit is not None
|
238
|
+
else np.finfo(float).min
|
241
239
|
),
|
242
|
-
|
243
|
-
j.axis.limit.upper
|
244
|
-
if j.axis is not None
|
245
|
-
|
246
|
-
and j.axis.limit.upper is not None
|
247
|
-
else jnp.finfo(float).max
|
240
|
+
(
|
241
|
+
float(j.axis.limit.upper)
|
242
|
+
if j.axis is not None and j.axis.limit is not None
|
243
|
+
else np.finfo(float).max
|
248
244
|
),
|
249
245
|
),
|
250
|
-
friction_static=
|
246
|
+
friction_static=(
|
251
247
|
j.axis.dynamics.friction
|
252
248
|
if j.axis is not None
|
253
249
|
and j.axis.dynamics is not None
|
254
250
|
and j.axis.dynamics.friction is not None
|
255
251
|
else 0.0
|
256
252
|
),
|
257
|
-
friction_viscous=
|
253
|
+
friction_viscous=(
|
258
254
|
j.axis.dynamics.damping
|
259
255
|
if j.axis is not None
|
260
256
|
and j.axis.dynamics is not None
|
261
257
|
and j.axis.dynamics.damping is not None
|
262
258
|
else 0.0
|
263
259
|
),
|
264
|
-
position_limit_damper=
|
260
|
+
position_limit_damper=(
|
265
261
|
j.axis.limit.dissipation
|
266
262
|
if j.axis is not None
|
267
263
|
and j.axis.limit is not None
|
268
264
|
and j.axis.limit.dissipation is not None
|
269
265
|
else 0.0
|
270
266
|
),
|
271
|
-
position_limit_spring=
|
267
|
+
position_limit_spring=(
|
272
268
|
j.axis.limit.stiffness
|
273
269
|
if j.axis is not None
|
274
270
|
and j.axis.limit is not None
|
@@ -277,7 +273,7 @@ def extract_model_data(
|
|
277
273
|
),
|
278
274
|
)
|
279
275
|
for j in sdf_model.joints()
|
280
|
-
if j.type in {"revolute", "
|
276
|
+
if j.type in {"revolute", "prismatic", "fixed"}
|
281
277
|
and j.parent != "world"
|
282
278
|
and j.child in links_dict.keys()
|
283
279
|
]
|
jaxsim/rbda/contacts/rigid.py
CHANGED
@@ -9,6 +9,7 @@ import jax_dataclasses
|
|
9
9
|
|
10
10
|
import jaxsim.api as js
|
11
11
|
import jaxsim.typing as jtp
|
12
|
+
from jaxsim import math
|
12
13
|
from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
|
13
14
|
from jaxsim.terrain import FlatTerrain, Terrain
|
14
15
|
|
@@ -271,17 +272,9 @@ class RigidContacts(ContactModel):
|
|
271
272
|
link_forces=link_forces,
|
272
273
|
)
|
273
274
|
|
274
|
-
with (
|
275
|
-
|
276
|
-
|
277
|
-
):
|
278
|
-
BW_ν̇_free = jnp.hstack(
|
279
|
-
js.ode.system_acceleration(
|
280
|
-
model=model,
|
281
|
-
data=data,
|
282
|
-
joint_forces=references.joint_force_references(model=model),
|
283
|
-
link_forces=references.link_forces(model=model, data=data),
|
284
|
-
)
|
275
|
+
with references.switch_velocity_representation(VelRepr.Mixed):
|
276
|
+
BW_ν̇_free = RigidContacts._compute_mixed_nu_dot_free(
|
277
|
+
model, data, references=references
|
285
278
|
)
|
286
279
|
|
287
280
|
free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points(
|
@@ -387,6 +380,43 @@ class RigidContacts(ContactModel):
|
|
387
380
|
n_constraints = 6 * n_collidable_points
|
388
381
|
return jnp.zeros(shape=(n_constraints,))
|
389
382
|
|
383
|
+
@staticmethod
|
384
|
+
def _compute_mixed_nu_dot_free(
|
385
|
+
model: js.model.JaxSimModel,
|
386
|
+
data: js.data.JaxSimModelData,
|
387
|
+
references: js.references.JaxSimModelReferences | None = None,
|
388
|
+
) -> jtp.Array:
|
389
|
+
references = (
|
390
|
+
references
|
391
|
+
if references is not None
|
392
|
+
else js.references.JaxSimModelReferences.zero(model=model, data=data)
|
393
|
+
)
|
394
|
+
|
395
|
+
with (
|
396
|
+
data.switch_velocity_representation(VelRepr.Mixed),
|
397
|
+
references.switch_velocity_representation(VelRepr.Mixed),
|
398
|
+
):
|
399
|
+
BW_v_WB = data.base_velocity()
|
400
|
+
W_ṗ_B, W_ω_WB = jnp.split(BW_v_WB, 2)
|
401
|
+
W_v̇_WB, s̈ = js.ode.system_acceleration(
|
402
|
+
model=model,
|
403
|
+
data=data,
|
404
|
+
joint_forces=references.joint_force_references(model=model),
|
405
|
+
link_forces=references.link_forces(model=model, data=data),
|
406
|
+
)
|
407
|
+
|
408
|
+
# Convert the inertial-fixed base acceleration to a mixed base acceleration.
|
409
|
+
W_H_B = data.base_transform()
|
410
|
+
W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
|
411
|
+
BW_X_W = math.Adjoint.from_transform(W_H_BW, inverse=True)
|
412
|
+
term1 = BW_X_W @ W_v̇_WB
|
413
|
+
term2 = jnp.zeros(6).at[0:3].set(jnp.cross(W_ṗ_B, W_ω_WB))
|
414
|
+
BW_v̇_WB = term1 - term2
|
415
|
+
|
416
|
+
BW_ν̇ = jnp.hstack([BW_v̇_WB, s̈])
|
417
|
+
|
418
|
+
return BW_ν̇
|
419
|
+
|
390
420
|
@staticmethod
|
391
421
|
def _linear_acceleration_of_collidable_points(
|
392
422
|
model: js.model.JaxSimModel,
|
jaxsim/terrain/terrain.py
CHANGED
@@ -46,82 +46,66 @@ class Terrain(abc.ABC):
|
|
46
46
|
@jax_dataclasses.pytree_dataclass
|
47
47
|
class FlatTerrain(Terrain):
|
48
48
|
|
49
|
-
|
49
|
+
z: float = dataclasses.field(default=0.0, kw_only=True)
|
50
50
|
|
51
51
|
@staticmethod
|
52
52
|
def build(height: jtp.FloatLike) -> FlatTerrain:
|
53
53
|
|
54
|
-
return FlatTerrain(
|
54
|
+
return FlatTerrain(z=float(height))
|
55
55
|
|
56
56
|
def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
|
57
57
|
|
58
|
-
return jnp.array(self.
|
59
|
-
|
60
|
-
def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
|
61
|
-
|
62
|
-
return jnp.array([0.0, 0.0, 1.0], dtype=float)
|
58
|
+
return jnp.array(self.z, dtype=float)
|
63
59
|
|
64
60
|
def __hash__(self) -> int:
|
65
61
|
|
66
|
-
return hash(self.
|
62
|
+
return hash(self.z)
|
67
63
|
|
68
64
|
def __eq__(self, other: FlatTerrain) -> bool:
|
69
65
|
|
70
66
|
if not isinstance(other, FlatTerrain):
|
71
67
|
return False
|
72
68
|
|
73
|
-
return self.
|
69
|
+
return self.z == other.z
|
74
70
|
|
75
71
|
|
76
72
|
@jax_dataclasses.pytree_dataclass
|
77
73
|
class PlaneTerrain(FlatTerrain):
|
78
74
|
|
79
|
-
|
75
|
+
plane_normal: tuple[float, float, float] = jax_dataclasses.field(
|
80
76
|
default=(0.0, 0.0, 1.0), kw_only=True
|
81
77
|
)
|
82
78
|
|
83
79
|
@staticmethod
|
84
|
-
def build(
|
80
|
+
def build(
|
81
|
+
plane_normal: jtp.VectorLike, plane_height_over_origin: jtp.FloatLike = 0.0
|
82
|
+
) -> PlaneTerrain:
|
85
83
|
"""
|
86
84
|
Create a PlaneTerrain instance with a specified plane normal vector.
|
87
85
|
|
88
86
|
Args:
|
89
|
-
|
90
|
-
|
87
|
+
plane_normal: The normal vector of the terrain plane.
|
88
|
+
plane_height_over_origin: The height of the plane over the origin.
|
91
89
|
|
92
90
|
Returns:
|
93
91
|
PlaneTerrain: A PlaneTerrain instance.
|
94
92
|
"""
|
95
93
|
|
96
|
-
|
97
|
-
|
94
|
+
plane_normal = jnp.array(plane_normal, dtype=float)
|
95
|
+
plane_height_over_origin = jnp.array(plane_height_over_origin, dtype=float)
|
98
96
|
|
99
|
-
if
|
97
|
+
if plane_normal.shape != (3,):
|
100
98
|
msg = "Expected a 3D vector for the plane normal, got '{}'."
|
101
|
-
raise ValueError(msg.format(
|
99
|
+
raise ValueError(msg.format(plane_normal.shape))
|
102
100
|
|
103
101
|
# Make sure that the plane normal is a unit vector.
|
104
|
-
|
102
|
+
plane_normal = plane_normal / jnp.linalg.norm(plane_normal)
|
105
103
|
|
106
104
|
return PlaneTerrain(
|
107
|
-
|
108
|
-
|
105
|
+
z=float(plane_height_over_origin),
|
106
|
+
plane_normal=tuple(plane_normal.tolist()),
|
109
107
|
)
|
110
108
|
|
111
|
-
def normal(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Vector:
|
112
|
-
"""
|
113
|
-
Compute the normal vector of the terrain at a specific (x, y) location.
|
114
|
-
|
115
|
-
Args:
|
116
|
-
x: The x-coordinate of the location.
|
117
|
-
y: The y-coordinate of the location.
|
118
|
-
|
119
|
-
Returns:
|
120
|
-
The normal vector of the terrain surface at the specified location.
|
121
|
-
"""
|
122
|
-
|
123
|
-
return jnp.array(self._normal, dtype=float)
|
124
|
-
|
125
109
|
def height(self, x: jtp.FloatLike, y: jtp.FloatLike) -> jtp.Float:
|
126
110
|
"""
|
127
111
|
Compute the height of the terrain at a specific (x, y) location on a plane.
|
@@ -139,10 +123,10 @@ class PlaneTerrain(FlatTerrain):
|
|
139
123
|
# The height over the origin: -D/C
|
140
124
|
|
141
125
|
# Get the plane equation coefficients from the terrain normal.
|
142
|
-
A, B, C = self.
|
126
|
+
A, B, C = self.plane_normal
|
143
127
|
|
144
128
|
# Compute the final coefficient D considering the terrain height.
|
145
|
-
D = -C * self.
|
129
|
+
D = -C * self.z
|
146
130
|
|
147
131
|
# Invert the plane equation to get the height at the given (x, y) coordinates.
|
148
132
|
return jnp.array(-(A * x + B * y + D) / C).astype(float)
|
@@ -153,9 +137,9 @@ class PlaneTerrain(FlatTerrain):
|
|
153
137
|
|
154
138
|
return hash(
|
155
139
|
(
|
156
|
-
hash(self.
|
140
|
+
hash(self.z),
|
157
141
|
HashedNumpyArray.hash_of_array(
|
158
|
-
array=jnp.array(self.
|
142
|
+
array=jnp.array(self.plane_normal, dtype=float)
|
159
143
|
),
|
160
144
|
)
|
161
145
|
)
|
@@ -166,10 +150,10 @@ class PlaneTerrain(FlatTerrain):
|
|
166
150
|
return False
|
167
151
|
|
168
152
|
if not (
|
169
|
-
np.allclose(self.
|
153
|
+
np.allclose(self.z, other.z)
|
170
154
|
and np.allclose(
|
171
|
-
np.array(self.
|
172
|
-
np.array(other.
|
155
|
+
np.array(self.plane_normal, dtype=float),
|
156
|
+
np.array(other.plane_normal, dtype=float),
|
173
157
|
)
|
174
158
|
):
|
175
159
|
return False
|
jaxsim/typing.py
CHANGED
jaxsim/utils/jaxsim_dataclass.py
CHANGED
@@ -135,10 +135,9 @@ class JaxsimDataclass(abc.ABC):
|
|
135
135
|
"""
|
136
136
|
|
137
137
|
return tuple(
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
)
|
138
|
+
leaf.shape if hasattr(leaf, "shape") else None
|
139
|
+
for leaf in jax.tree_util.tree_leaves(tree)
|
140
|
+
if hasattr(leaf, "shape")
|
142
141
|
)
|
143
142
|
|
144
143
|
@staticmethod
|
@@ -155,10 +154,9 @@ class JaxsimDataclass(abc.ABC):
|
|
155
154
|
"""
|
156
155
|
|
157
156
|
return tuple(
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
)
|
157
|
+
leaf.dtype if hasattr(leaf, "dtype") else None
|
158
|
+
for leaf in jax.tree_util.tree_leaves(tree)
|
159
|
+
if hasattr(leaf, "dtype")
|
162
160
|
)
|
163
161
|
|
164
162
|
@staticmethod
|
@@ -174,10 +172,9 @@ class JaxsimDataclass(abc.ABC):
|
|
174
172
|
"""
|
175
173
|
|
176
174
|
return tuple(
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
)
|
175
|
+
leaf.weak_type if hasattr(leaf, "weak_type") else False
|
176
|
+
for leaf in jax.tree_util.tree_leaves(tree)
|
177
|
+
if hasattr(leaf, "weak_type")
|
181
178
|
)
|
182
179
|
|
183
180
|
@staticmethod
|
jaxsim/utils/wrappers.py
CHANGED
@@ -110,7 +110,7 @@ class HashedNumpyArray:
|
|
110
110
|
return np.allclose(
|
111
111
|
self.array,
|
112
112
|
other.array,
|
113
|
-
**(dict(atol=self.precision) if self.precision is not None else {}),
|
113
|
+
**({dict(atol=self.precision)} if self.precision is not None else {}),
|
114
114
|
)
|
115
115
|
|
116
116
|
return hash(self) == hash(other)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: jaxsim
|
3
|
-
Version: 0.4.3.
|
3
|
+
Version: 0.4.3.dev68
|
4
4
|
Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
|
5
5
|
Author-email: Diego Ferigo <dgferigo@gmail.com>
|
6
6
|
Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
|
@@ -61,7 +61,6 @@ Description-Content-Type: text/markdown
|
|
61
61
|
License-File: LICENSE
|
62
62
|
Requires-Dist: coloredlogs
|
63
63
|
Requires-Dist: jax>=0.4.13
|
64
|
-
Requires-Dist: jaxopt>=0.8.0
|
65
64
|
Requires-Dist: jaxlib>=0.4.13
|
66
65
|
Requires-Dist: jaxlie>=1.3.0
|
67
66
|
Requires-Dist: jax-dataclasses>=1.4.0
|
@@ -1,29 +1,29 @@
|
|
1
|
-
jaxsim/__init__.py,sha256=
|
2
|
-
jaxsim/_version.py,sha256=
|
1
|
+
jaxsim/__init__.py,sha256=ixsS4dYMPex2wOUUp_rkPnwrPhYzkRh1xO_YuMj3Cr4,2626
|
2
|
+
jaxsim/_version.py,sha256=XDf5LPSlhAhH48AO29kysLP_4FTR5VWOpS0LrK5RSfo,426
|
3
3
|
jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
|
4
4
|
jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
|
5
|
-
jaxsim/typing.py,sha256=
|
5
|
+
jaxsim/typing.py,sha256=IbFx3UkEXi-cm7UBqMPi58rJAFV_HbZ9E_K4JwfNvVM,753
|
6
6
|
jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
|
7
7
|
jaxsim/api/com.py,sha256=m-p3EJDhpnMTlXKplfbZE_aH9NqX_VyLlAE3vUhc6l4,13642
|
8
8
|
jaxsim/api/common.py,sha256=SNgxq42r6eF_-aPszvOjUYkGwXOzz4hKmhDwEUkscFQ,6650
|
9
|
-
jaxsim/api/contact.py,sha256=
|
10
|
-
jaxsim/api/data.py,sha256=
|
9
|
+
jaxsim/api/contact.py,sha256=HyEAjF7BySDDOlRahN0l7V15IPB0HPXuoM0twamuEW0,20913
|
10
|
+
jaxsim/api/data.py,sha256=CUh9lvhVk3_clNQ26BUBGpjvFSsK_PrVWVMEWpMdHRM,27206
|
11
11
|
jaxsim/api/frame.py,sha256=KS8A5wRfjxhe9NgcVo2QA516iP5zky7UVnWxG7nTa7c,12911
|
12
|
-
jaxsim/api/joint.py,sha256=
|
12
|
+
jaxsim/api/joint.py,sha256=L81bQe-noPT6_54KOSF7KBjRmEPAS433ULn2EcXI8vI,5115
|
13
13
|
jaxsim/api/kin_dyn_parameters.py,sha256=CcfSg5Mc8qb1mZeMQ4AK_ffZIsK5yOl7tu397pFhcDA,29369
|
14
14
|
jaxsim/api/link.py,sha256=qPRtc8qqMRjZxUCZYXJMygbB6huDXBfIT1b1b8Durkw,18631
|
15
|
-
jaxsim/api/model.py,sha256=
|
16
|
-
jaxsim/api/ode.py,sha256=
|
17
|
-
jaxsim/api/ode_data.py,sha256=
|
15
|
+
jaxsim/api/model.py,sha256=HXoqCtQ3KStGoxhgvFm8P_Sc-lbEM4l5No2MoHzNlOk,65558
|
16
|
+
jaxsim/api/ode.py,sha256=Vb2sN4zwpXnaJDD9-ziz2qvfmfa4jvIQ0fONbBIRGmU,13368
|
17
|
+
jaxsim/api/ode_data.py,sha256=U7F6TL6bENAxpQQl4PupPoDG7d7VfTTFqDAs3xwu6Hs,20003
|
18
18
|
jaxsim/api/references.py,sha256=XOVKuQXRmjPoP-T5JWGSbqIGX5DzOkeGafqRpj0ZQEM,20771
|
19
19
|
jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
|
20
|
-
jaxsim/integrators/common.py,sha256=
|
20
|
+
jaxsim/integrators/common.py,sha256=ntjflaV3qWaFH_E65pAGZ6QipdnFsgQDasKtIKpxTe4,20432
|
21
21
|
jaxsim/integrators/fixed_step.py,sha256=KpjRd6hHtapxDoo6D1kyDrVDSHnke2TepI5grFH7_bM,2693
|
22
22
|
jaxsim/integrators/variable_step.py,sha256=5StkFh9oQba34zlkIoXG2fUN78gbxkHePWbrpQ-QZOI,21274
|
23
23
|
jaxsim/math/__init__.py,sha256=8oPITEoGwgRcOeG8KxtqxPQ8b5uku1HNRMokpCoi9Tc,352
|
24
24
|
jaxsim/math/adjoint.py,sha256=o1FCipkGwPtMbN2gFNIyUV8ADF3TX5fxElpTEXK0bIs,4377
|
25
25
|
jaxsim/math/cross.py,sha256=U7yEx_l75mSy5g6O-jsjBztApvxC3WaV4MpkS5tThu4,1330
|
26
|
-
jaxsim/math/inertia.py,sha256=
|
26
|
+
jaxsim/math/inertia.py,sha256=_hNpoeyEpAGr9ExDQJjckbjhk39luJFF-jv0SKqefnQ,1614
|
27
27
|
jaxsim/math/joint_model.py,sha256=EzAveaG5B6ZnCFNUzN30KEQUVesd83lfWXJarYR-kUw,9989
|
28
28
|
jaxsim/math/quaternion.py,sha256=_WA7W3iv7px83sWO1V1n0-J78hqAlO4SL1-jofE-UZ4,4754
|
29
29
|
jaxsim/math/rotation.py,sha256=k-nwT79zmWrys3NNAB-lGWxat7Kqm_6JnFRoimJ8rBg,2156
|
@@ -31,18 +31,18 @@ jaxsim/math/skew.py,sha256=oOGSSR8PUGROl6IJFlrmu6K3gPH-u16hUPfKIkcVv9o,1177
|
|
31
31
|
jaxsim/math/transform.py,sha256=KXzQgOnCfAtbXCwxhplpJ3F0JT3oEyeLVby1_uRAryQ,2892
|
32
32
|
jaxsim/mujoco/__init__.py,sha256=Zo5GAlN1DYKvX8s1hu1j6HntKIbBMLB9Puv9ouaNAZ8,158
|
33
33
|
jaxsim/mujoco/__main__.py,sha256=GBmB7J-zj75ZnFyuAAmpSOpbxi_HhHhWJeot3ljGDJY,5291
|
34
|
-
jaxsim/mujoco/loaders.py,sha256=
|
34
|
+
jaxsim/mujoco/loaders.py,sha256=XB-fgXuWMTFiaand5MZlLFQ5__Sh8MK5CJsxIU34MBk,25328
|
35
35
|
jaxsim/mujoco/model.py,sha256=AQksXemXWACJ3yvefV2G5HLwwBU9ISoJrOD1wlxdY5w,16386
|
36
36
|
jaxsim/mujoco/visualizer.py,sha256=T1vU-w4NKSmgEkZ0FqVcGmIvYrYO0len2UBSsU4MOZ0,6978
|
37
37
|
jaxsim/parsers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
38
|
-
jaxsim/parsers/kinematic_graph.py,sha256=
|
38
|
+
jaxsim/parsers/kinematic_graph.py,sha256=KijMWKyhTLKSNUmOOk4sYQMgPh_OkA_brncL7gBRHaY,34757
|
39
39
|
jaxsim/parsers/descriptions/__init__.py,sha256=PbIlunVfb59pB5jSX97YVpMAANRZPRkJ0X-hS14rzv4,221
|
40
40
|
jaxsim/parsers/descriptions/collision.py,sha256=BQeIG-TKi4SVny23w6riDrQ5itC6VRwEMBX6HgAXHxA,3973
|
41
41
|
jaxsim/parsers/descriptions/joint.py,sha256=VSb6C0FBBKMqwrHBKfc-Bbn4rl_J0RzUxMQlhIEvOPM,5185
|
42
42
|
jaxsim/parsers/descriptions/link.py,sha256=Eh0W5qL7_Uw0GV-BkNKXhm9Q2dRTfIWCX5D-87zQkxA,3711
|
43
43
|
jaxsim/parsers/descriptions/model.py,sha256=I2Vsbv8Josl4Le7b5rIvhqA2k9Bbv5JxMqwytayxds0,9833
|
44
44
|
jaxsim/parsers/rod/__init__.py,sha256=G2vqlLajBLUc4gyzXwsEI2Wsi4TMOIF9bLDFeT6KrGU,92
|
45
|
-
jaxsim/parsers/rod/parser.py,sha256=
|
45
|
+
jaxsim/parsers/rod/parser.py,sha256=HskeCqDsbtwH2BDk3vfxvx391wUTVGLaUXNvBrdNo-4,13486
|
46
46
|
jaxsim/parsers/rod/utils.py,sha256=5DsF3OeePZGidOJ5GiFSZx-51uIdnFvMW9EK6SgOW6Q,5698
|
47
47
|
jaxsim/rbda/__init__.py,sha256=H7DhXpxkPOi9lpUvg31IMHFfRafke1UoJLc5GQIdyhA,387
|
48
48
|
jaxsim/rbda/aba.py,sha256=w7ciyxB0IsmueatT0C7PcBQEl9dyiH9oqJgIi3xeTUE,8983
|
@@ -54,17 +54,16 @@ jaxsim/rbda/rnea.py,sha256=CLfqs9XFVaD-hvkLABshDAfdw5bm_AMV3UVAQ_IvURQ,7542
|
|
54
54
|
jaxsim/rbda/utils.py,sha256=eeT21Y4DiiyhrdF0lUE_VvRuwru5-rR7yOlOlWzCCWE,5381
|
55
55
|
jaxsim/rbda/contacts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
56
56
|
jaxsim/rbda/contacts/common.py,sha256=VwAs742futAmLnDgbaOuLzNDBFiKDfYItdEZ4UcFgzE,2467
|
57
|
-
jaxsim/rbda/contacts/
|
58
|
-
jaxsim/rbda/contacts/rigid.py,sha256=fbZk7sC6YOnTs_tzQRfsyBpHyT22XF-wB-EvOSZmhos,14746
|
57
|
+
jaxsim/rbda/contacts/rigid.py,sha256=8Vbnxng-ERZ5ka_eZGIBuhBDr2PNjc7m-Or255AfEw4,15862
|
59
58
|
jaxsim/rbda/contacts/soft.py,sha256=_wvb5iZDjGcVg6rNQelN4LZN7qSC2NIp0HdKvZmlGfk,15647
|
60
59
|
jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
|
61
|
-
jaxsim/terrain/terrain.py,sha256=
|
60
|
+
jaxsim/terrain/terrain.py,sha256=ctyNANIFSM3tZmamprjaEDcWgUSP0oNJbmT1zw9RjPs,4565
|
62
61
|
jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
|
63
|
-
jaxsim/utils/jaxsim_dataclass.py,sha256=
|
62
|
+
jaxsim/utils/jaxsim_dataclass.py,sha256=5xJbY0G8d7C0OTNIW9T4vQxiDak6TGZT9gpNOvRykFI,11373
|
64
63
|
jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
|
65
|
-
jaxsim/utils/wrappers.py,sha256=
|
66
|
-
jaxsim-0.4.3.
|
67
|
-
jaxsim-0.4.3.
|
68
|
-
jaxsim-0.4.3.
|
69
|
-
jaxsim-0.4.3.
|
70
|
-
jaxsim-0.4.3.
|
64
|
+
jaxsim/utils/wrappers.py,sha256=JhLUh1g8iU-lhjbuZRfkscPZhYlLCOorVM2Xl3ulRBI,4054
|
65
|
+
jaxsim-0.4.3.dev68.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
|
66
|
+
jaxsim-0.4.3.dev68.dist-info/METADATA,sha256=IrZMXHUptvvLA5YgloveNIge4OdEBjT-DxhdHBrn_WM,17247
|
67
|
+
jaxsim-0.4.3.dev68.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
68
|
+
jaxsim-0.4.3.dev68.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
|
69
|
+
jaxsim-0.4.3.dev68.dist-info/RECORD,,
|
@@ -1,384 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import dataclasses
|
4
|
-
from typing import Any
|
5
|
-
|
6
|
-
import jax
|
7
|
-
import jax.numpy as jnp
|
8
|
-
import jax_dataclasses
|
9
|
-
import jaxopt
|
10
|
-
|
11
|
-
import jaxsim.api as js
|
12
|
-
import jaxsim.typing as jtp
|
13
|
-
from jaxsim.api.common import VelRepr
|
14
|
-
from jaxsim.math import Adjoint
|
15
|
-
from jaxsim.terrain.terrain import FlatTerrain, Terrain
|
16
|
-
|
17
|
-
from .common import ContactModel, ContactsParams, ContactsState
|
18
|
-
|
19
|
-
|
20
|
-
@jax_dataclasses.pytree_dataclass
|
21
|
-
class RelaxedRigidContactsParams(ContactsParams):
|
22
|
-
"""Parameters of the relaxed rigid contacts model."""
|
23
|
-
|
24
|
-
# Time constant
|
25
|
-
time_constant: jtp.Float = dataclasses.field(
|
26
|
-
default_factory=lambda: jnp.array(0.01, dtype=float)
|
27
|
-
)
|
28
|
-
|
29
|
-
# Adimensional damping coefficient
|
30
|
-
damping_coefficient: jtp.Float = dataclasses.field(
|
31
|
-
default_factory=lambda: jnp.array(1.0, dtype=float)
|
32
|
-
)
|
33
|
-
|
34
|
-
# Minimum impedance
|
35
|
-
d_min: jtp.Float = dataclasses.field(
|
36
|
-
default_factory=lambda: jnp.array(0.9, dtype=float)
|
37
|
-
)
|
38
|
-
|
39
|
-
# Maximum impedance
|
40
|
-
d_max: jtp.Float = dataclasses.field(
|
41
|
-
default_factory=lambda: jnp.array(0.95, dtype=float)
|
42
|
-
)
|
43
|
-
|
44
|
-
# Width
|
45
|
-
width: jtp.Float = dataclasses.field(
|
46
|
-
default_factory=lambda: jnp.array(0.0001, dtype=float)
|
47
|
-
)
|
48
|
-
|
49
|
-
# Midpoint
|
50
|
-
midpoint: jtp.Float = dataclasses.field(
|
51
|
-
default_factory=lambda: jnp.array(0.1, dtype=float)
|
52
|
-
)
|
53
|
-
|
54
|
-
# Power exponent
|
55
|
-
power: jtp.Float = dataclasses.field(
|
56
|
-
default_factory=lambda: jnp.array(1.0, dtype=float)
|
57
|
-
)
|
58
|
-
|
59
|
-
# Stiffness
|
60
|
-
stiffness: jtp.Float = dataclasses.field(
|
61
|
-
default_factory=lambda: jnp.array(0.0, dtype=float)
|
62
|
-
)
|
63
|
-
|
64
|
-
# Damping
|
65
|
-
damping: jtp.Float = dataclasses.field(
|
66
|
-
default_factory=lambda: jnp.array(0.0, dtype=float)
|
67
|
-
)
|
68
|
-
|
69
|
-
# Friction coefficient
|
70
|
-
mu: jtp.Float = dataclasses.field(
|
71
|
-
default_factory=lambda: jnp.array(0.5, dtype=float)
|
72
|
-
)
|
73
|
-
|
74
|
-
# Maximum number of iterations
|
75
|
-
max_iterations: jtp.Int = dataclasses.field(
|
76
|
-
default_factory=lambda: jnp.array(50, dtype=int)
|
77
|
-
)
|
78
|
-
|
79
|
-
# Solver tolerance
|
80
|
-
tolerance: jtp.Float = dataclasses.field(
|
81
|
-
default_factory=lambda: jnp.array(1e-6, dtype=float)
|
82
|
-
)
|
83
|
-
|
84
|
-
def __hash__(self) -> int:
|
85
|
-
from jaxsim.utils.wrappers import HashedNumpyArray
|
86
|
-
|
87
|
-
return hash(
|
88
|
-
(
|
89
|
-
HashedNumpyArray(self.time_constant),
|
90
|
-
HashedNumpyArray(self.damping_coefficient),
|
91
|
-
HashedNumpyArray(self.d_min),
|
92
|
-
HashedNumpyArray(self.d_max),
|
93
|
-
HashedNumpyArray(self.width),
|
94
|
-
HashedNumpyArray(self.midpoint),
|
95
|
-
HashedNumpyArray(self.power),
|
96
|
-
HashedNumpyArray(self.stiffness),
|
97
|
-
HashedNumpyArray(self.damping),
|
98
|
-
HashedNumpyArray(self.mu),
|
99
|
-
HashedNumpyArray(self.max_iterations),
|
100
|
-
HashedNumpyArray(self.tolerance),
|
101
|
-
)
|
102
|
-
)
|
103
|
-
|
104
|
-
def __eq__(self, other: RelaxedRigidContactsParams) -> bool:
|
105
|
-
return hash(self) == hash(other)
|
106
|
-
|
107
|
-
@classmethod
|
108
|
-
def build(
|
109
|
-
cls,
|
110
|
-
time_constant: jtp.FloatLike | None = None,
|
111
|
-
damping_coefficient: jtp.FloatLike | None = None,
|
112
|
-
d_min: jtp.FloatLike | None = None,
|
113
|
-
d_max: jtp.FloatLike | None = None,
|
114
|
-
width: jtp.FloatLike | None = None,
|
115
|
-
midpoint: jtp.FloatLike | None = None,
|
116
|
-
power: jtp.FloatLike | None = None,
|
117
|
-
stiffness: jtp.FloatLike | None = None,
|
118
|
-
damping: jtp.FloatLike | None = None,
|
119
|
-
mu: jtp.FloatLike | None = None,
|
120
|
-
max_iterations: jtp.IntLike | None = None,
|
121
|
-
tolerance: jtp.FloatLike | None = None,
|
122
|
-
) -> RelaxedRigidContactsParams:
|
123
|
-
"""Create a `RelaxedRigidContactsParams` instance"""
|
124
|
-
|
125
|
-
return cls(
|
126
|
-
**{
|
127
|
-
field: jnp.array(locals().get(field, default), dtype=default.dtype)
|
128
|
-
for field, default in map(
|
129
|
-
lambda f: (f, cls.__dataclass_fields__[f].default),
|
130
|
-
filter(lambda f: f != "__mutability__", cls.__dataclass_fields__),
|
131
|
-
)
|
132
|
-
}
|
133
|
-
)
|
134
|
-
|
135
|
-
def valid(self) -> bool:
|
136
|
-
return bool(
|
137
|
-
jnp.all(self.time_constant >= 0.0)
|
138
|
-
and jnp.all(self.damping_coefficient > 0.0)
|
139
|
-
and jnp.all(self.d_min >= 0.0)
|
140
|
-
and jnp.all(self.d_max <= 1.0)
|
141
|
-
and jnp.all(self.d_min <= self.d_max)
|
142
|
-
and jnp.all(self.width >= 0.0)
|
143
|
-
and jnp.all(self.midpoint >= 0.0)
|
144
|
-
and jnp.all(self.power >= 0.0)
|
145
|
-
and jnp.all(self.mu >= 0.0)
|
146
|
-
and jnp.all(self.max_iterations > 0)
|
147
|
-
and jnp.all(self.tolerance > 0.0)
|
148
|
-
)
|
149
|
-
|
150
|
-
|
151
|
-
@jax_dataclasses.pytree_dataclass
|
152
|
-
class RelaxedRigidContactsState(ContactsState):
|
153
|
-
"""Class storing the state of the relaxed rigid contacts model."""
|
154
|
-
|
155
|
-
def __eq__(self, other: RelaxedRigidContactsState) -> bool:
|
156
|
-
return hash(self) == hash(other)
|
157
|
-
|
158
|
-
@staticmethod
|
159
|
-
def build() -> RelaxedRigidContactsState:
|
160
|
-
"""Create a `RelaxedRigidContactsState` instance"""
|
161
|
-
|
162
|
-
return RelaxedRigidContactsState()
|
163
|
-
|
164
|
-
@staticmethod
|
165
|
-
def zero(model: js.model.JaxSimModel) -> RelaxedRigidContactsState:
|
166
|
-
"""Build a zero `RelaxedRigidContactsState` instance from a `JaxSimModel`."""
|
167
|
-
return RelaxedRigidContactsState.build()
|
168
|
-
|
169
|
-
def valid(self, model: js.model.JaxSimModel) -> bool:
|
170
|
-
return True
|
171
|
-
|
172
|
-
|
173
|
-
@jax_dataclasses.pytree_dataclass
|
174
|
-
class RelaxedRigidContacts(ContactModel):
|
175
|
-
"""Relaxed rigid contacts model."""
|
176
|
-
|
177
|
-
parameters: RelaxedRigidContactsParams = dataclasses.field(
|
178
|
-
default_factory=RelaxedRigidContactsParams
|
179
|
-
)
|
180
|
-
|
181
|
-
terrain: jax_dataclasses.Static[Terrain] = dataclasses.field(
|
182
|
-
default_factory=FlatTerrain
|
183
|
-
)
|
184
|
-
|
185
|
-
def compute_contact_forces(
|
186
|
-
self,
|
187
|
-
position: jtp.Vector,
|
188
|
-
velocity: jtp.Vector,
|
189
|
-
model: js.model.JaxSimModel,
|
190
|
-
data: js.data.JaxSimModelData,
|
191
|
-
link_forces: jtp.MatrixLike | None = None,
|
192
|
-
) -> tuple[jtp.Vector, tuple[Any, ...]]:
|
193
|
-
|
194
|
-
link_forces = (
|
195
|
-
link_forces
|
196
|
-
if link_forces is not None
|
197
|
-
else jnp.zeros((model.number_of_links(), 6))
|
198
|
-
)
|
199
|
-
|
200
|
-
references = js.references.JaxSimModelReferences.build(
|
201
|
-
model=model,
|
202
|
-
data=data,
|
203
|
-
velocity_representation=data.velocity_representation,
|
204
|
-
link_forces=link_forces,
|
205
|
-
)
|
206
|
-
|
207
|
-
def _detect_contact(x: jtp.Array, y: jtp.Array, z: jtp.Array) -> jtp.Array:
|
208
|
-
x, y, z = jax.tree_map(jnp.squeeze, (x, y, z))
|
209
|
-
|
210
|
-
n̂ = self.terrain.normal(x=x, y=y).squeeze()
|
211
|
-
h = jnp.array([0, 0, z - model.terrain.height(x=x, y=y)])
|
212
|
-
|
213
|
-
return jnp.dot(h, n̂)
|
214
|
-
|
215
|
-
# Compute the activation state of the collidable points
|
216
|
-
δ = jax.vmap(_detect_contact)(*position.T)
|
217
|
-
|
218
|
-
with (
|
219
|
-
references.switch_velocity_representation(VelRepr.Mixed),
|
220
|
-
data.switch_velocity_representation(VelRepr.Mixed),
|
221
|
-
):
|
222
|
-
M = js.model.free_floating_mass_matrix(model=model, data=data)
|
223
|
-
Jl_WC = jnp.vstack(
|
224
|
-
jax.vmap(lambda J, height: J * (height < 0))(
|
225
|
-
js.contact.jacobian(model=model, data=data)[:, :3, :], δ
|
226
|
-
)
|
227
|
-
)
|
228
|
-
W_H_C = js.contact.transforms(model=model, data=data)
|
229
|
-
BW_ν̇_free = jnp.hstack(
|
230
|
-
js.ode.system_acceleration(
|
231
|
-
model=model,
|
232
|
-
data=data,
|
233
|
-
link_forces=references.link_forces(model=model, data=data),
|
234
|
-
)
|
235
|
-
)
|
236
|
-
BW_ν = data.generalized_velocity()
|
237
|
-
J̇_WC = jnp.vstack(
|
238
|
-
jax.vmap(lambda J̇, height: J̇ * (height < 0))(
|
239
|
-
js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ
|
240
|
-
),
|
241
|
-
)
|
242
|
-
|
243
|
-
a_ref, R, K, D = self._regularizers(
|
244
|
-
model=model,
|
245
|
-
penetration=δ,
|
246
|
-
velocity=velocity,
|
247
|
-
parameters=self.parameters,
|
248
|
-
)
|
249
|
-
|
250
|
-
G = Jl_WC @ jnp.linalg.lstsq(M, Jl_WC.T)[0]
|
251
|
-
CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇_WC @ BW_ν
|
252
|
-
|
253
|
-
# Calculate quantities for the linear optimization problem.
|
254
|
-
A = G + R
|
255
|
-
b = CW_al_free_WC - a_ref
|
256
|
-
|
257
|
-
objective = lambda x: jnp.sum(jnp.square(A @ x + b))
|
258
|
-
|
259
|
-
# Compute the 3D linear force in C[W] frame
|
260
|
-
opt = jaxopt.LBFGS(
|
261
|
-
fun=objective,
|
262
|
-
maxiter=self.parameters.max_iterations,
|
263
|
-
tol=self.parameters.tolerance,
|
264
|
-
maxls=30,
|
265
|
-
history_size=10,
|
266
|
-
max_stepsize=100.0,
|
267
|
-
)
|
268
|
-
|
269
|
-
init_params = (
|
270
|
-
K[:, jnp.newaxis] * jnp.zeros_like(position).at[:, 2].set(δ)
|
271
|
-
+ D[:, jnp.newaxis] * velocity
|
272
|
-
).flatten()
|
273
|
-
|
274
|
-
CW_f_Ci = opt.run(init_params=init_params).params.reshape(-1, 3)
|
275
|
-
|
276
|
-
def mixed_to_inertial(W_H_C: jax.Array, CW_fl: jax.Array) -> jax.Array:
|
277
|
-
W_Xf_CW = Adjoint.from_transform(
|
278
|
-
W_H_C.at[0:3, 0:3].set(jnp.eye(3)),
|
279
|
-
inverse=True,
|
280
|
-
).T
|
281
|
-
return W_Xf_CW @ jnp.hstack([CW_fl, jnp.zeros(3)])
|
282
|
-
|
283
|
-
W_f_C = jax.vmap(mixed_to_inertial)(W_H_C, CW_f_Ci)
|
284
|
-
|
285
|
-
return W_f_C, (None,)
|
286
|
-
|
287
|
-
@staticmethod
|
288
|
-
def _regularizers(
|
289
|
-
model: js.model.JaxSimModel,
|
290
|
-
penetration: jtp.Array,
|
291
|
-
velocity: jtp.Array,
|
292
|
-
parameters: RelaxedRigidContactsParams,
|
293
|
-
) -> tuple:
|
294
|
-
"""
|
295
|
-
Compute the contact jacobian and the reference acceleration.
|
296
|
-
|
297
|
-
Args:
|
298
|
-
model: The jaxsim model.
|
299
|
-
penetration: The penetration of the collidable points.
|
300
|
-
velocity: The velocity of the collidable points.
|
301
|
-
parameters: The parameters of the relaxed rigid contacts model.
|
302
|
-
|
303
|
-
Returns:
|
304
|
-
A tuple containing the reference acceleration, the regularization matrix, the stiffness, and the damping.
|
305
|
-
"""
|
306
|
-
|
307
|
-
Ω, ζ, ξ_min, ξ_max, width, mid, p, K, D, μ, *_ = jax_dataclasses.astuple(
|
308
|
-
parameters
|
309
|
-
)
|
310
|
-
|
311
|
-
def _imp_aref(
|
312
|
-
penetration: jtp.Array,
|
313
|
-
velocity: jtp.Array,
|
314
|
-
) -> tuple[jtp.Array, jtp.Array]:
|
315
|
-
"""
|
316
|
-
Calculates impedance and offset acceleration in constraint frame.
|
317
|
-
|
318
|
-
Args:
|
319
|
-
penetration: penetration in constraint frame
|
320
|
-
velocity: velocity in constraint frame
|
321
|
-
|
322
|
-
Returns:
|
323
|
-
a_ref: offset acceleration in constraint frame
|
324
|
-
R: regularization matrix
|
325
|
-
K: computed stiffness
|
326
|
-
D: computed damping
|
327
|
-
"""
|
328
|
-
position = jnp.zeros(shape=(3,)).at[2].set(penetration)
|
329
|
-
|
330
|
-
imp_x = jnp.abs(position) / width
|
331
|
-
imp_a = (1.0 / jnp.power(mid, p - 1)) * jnp.power(imp_x, p)
|
332
|
-
|
333
|
-
imp_b = 1 - (1.0 / jnp.power(1 - mid, p - 1)) * jnp.power(1 - imp_x, p)
|
334
|
-
|
335
|
-
imp_y = jnp.where(imp_x < mid, imp_a, imp_b)
|
336
|
-
|
337
|
-
imp = jnp.clip(ξ_min + imp_y * (ξ_max - ξ_min), ξ_min, ξ_max)
|
338
|
-
imp = jnp.atleast_1d(jnp.where(imp_x > 1.0, ξ_max, imp))
|
339
|
-
|
340
|
-
# When passing negative values, K and D represent a spring and damper, respectively.
|
341
|
-
K_f = jnp.where(K < 0, -K / ξ_max**2, 1 / (ξ_max * Ω * ζ) ** 2)
|
342
|
-
D_f = jnp.where(D < 0, -D / ξ_max, 2 / (ξ_max * Ω))
|
343
|
-
|
344
|
-
a_ref = -jnp.atleast_1d(D_f * velocity + K_f * imp * position)
|
345
|
-
|
346
|
-
return imp, a_ref, jnp.atleast_1d(K_f), jnp.atleast_1d(D_f)
|
347
|
-
|
348
|
-
def _compute_row(
|
349
|
-
*,
|
350
|
-
link_idx: jtp.Float,
|
351
|
-
penetration: jtp.Array,
|
352
|
-
velocity: jtp.Array,
|
353
|
-
) -> tuple[jtp.Array, jtp.Array]:
|
354
|
-
|
355
|
-
# Compute the reference acceleration.
|
356
|
-
ξ, a_ref, K, D = _imp_aref(
|
357
|
-
penetration=penetration,
|
358
|
-
velocity=velocity,
|
359
|
-
)
|
360
|
-
|
361
|
-
# Compute the regularization terms.
|
362
|
-
R = (
|
363
|
-
(2 * μ**2 * (1 - ξ) / (ξ + 1e-12))
|
364
|
-
* (1 + μ**2)
|
365
|
-
@ jnp.linalg.inv(M_L[link_idx, :3, :3])
|
366
|
-
)
|
367
|
-
|
368
|
-
return jax.tree.map(lambda x: x * (penetration < 0), (a_ref, R, K, D))
|
369
|
-
|
370
|
-
M_L = js.model.link_spatial_inertia_matrices(model=model)
|
371
|
-
|
372
|
-
a_ref, R, K, D = jax.tree.map(
|
373
|
-
jnp.concatenate,
|
374
|
-
(
|
375
|
-
*jax.vmap(_compute_row)(
|
376
|
-
link_idx=jnp.array(
|
377
|
-
model.kin_dyn_parameters.contact_parameters.body
|
378
|
-
),
|
379
|
-
penetration=penetration,
|
380
|
-
velocity=velocity,
|
381
|
-
),
|
382
|
-
),
|
383
|
-
)
|
384
|
-
return a_ref, jnp.diag(R), K, D
|
File without changes
|
File without changes
|
File without changes
|