jaxsim 0.7.1.dev49__py3-none-any.whl → 0.7.1.dev53__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.7.1.dev49'
21
- __version_tuple__ = version_tuple = (0, 7, 1, 'dev49')
20
+ __version__ = version = '0.7.1.dev53'
21
+ __version_tuple__ = version_tuple = (0, 7, 1, 'dev53')
jaxsim/api/contact.py CHANGED
@@ -191,39 +191,22 @@ def estimate_good_contact_parameters(
191
191
  The user is encouraged to fine-tune the parameters based on the
192
192
  specific application.
193
193
  """
194
-
195
- def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
196
- """
197
- Displacement between the CoM and the lowest collidable point using zero
198
- joint positions.
199
- """
200
-
201
- zero_data = js.data.JaxSimModelData.build(
202
- model=model,
203
- )
204
-
194
+ if max_penetration is None:
195
+ zero_data = js.data.JaxSimModelData.build(model=model)
205
196
  W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]
206
-
207
197
  if model.floating_base():
208
198
  W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]
209
- return 2 * (W_pz_CoM - W_pz_C.min())
210
-
211
- return 2 * W_pz_CoM
199
+ W_pz_CoM = W_pz_CoM - W_pz_C.min()
212
200
 
213
- max_δ = (
214
- max_penetration
215
- if max_penetration is not None
216
- # Consider as default a 0.5% of the model height.
217
- else 0.005 * estimate_model_height(model=model)
218
- )
201
+ # Use 1% of the model's center-of-mass height as the default.
202
+ max_penetration = 0.01 * W_pz_CoM
219
203
 
220
204
  nc = number_of_active_collidable_points_steady_state
221
-
222
205
  return model.contact_model._parameters_class().build_default_from_jaxsim_model(
223
206
  model=model,
224
207
  standard_gravity=standard_gravity,
225
208
  static_friction_coefficient=static_friction_coefficient,
226
- max_penetration=max_δ,
209
+ max_penetration=max_penetration,
227
210
  number_of_active_collidable_points_steady_state=nc,
228
211
  damping_ratio=damping_ratio,
229
212
  )
@@ -569,6 +552,39 @@ def link_contact_forces(
569
552
  # to the frames associated to the collidable points.
570
553
  W_f_L = link_forces_from_contact_forces(model=model, contact_forces=W_f_C)
571
554
 
555
+ # Process constraint wrenches if present.
556
+ if "constr_wrenches_inertial" in aux_dict:
557
+ wrench_pair_constr_inertial = aux_dict["constr_wrenches_inertial"]
558
+
559
+ # Retrieve the constraint map from the model's kinematic parameters.
560
+ constraint_map = model.kin_dyn_parameters.constraints
561
+
562
+ # Extract the frame indices of the constraints.
563
+ frame_idxs_1 = constraint_map.frame_idxs_1
564
+ frame_idxs_2 = constraint_map.frame_idxs_2
565
+
566
+ n_kin_constraints = frame_idxs_1.shape[0]
567
+
568
+ if n_kin_constraints > 0:
569
+ parent_link_indices = jax.vmap(
570
+ lambda frame_idx_1, frame_idx_2: jnp.array(
571
+ (
572
+ js.frame.idx_of_parent_link(model, frame_index=frame_idx_1),
573
+ js.frame.idx_of_parent_link(model, frame_index=frame_idx_2),
574
+ )
575
+ )
576
+ )(frame_idxs_1, frame_idxs_2)
577
+
578
+ # Apply each constraint wrench to its corresponding parent link in W_f_L.
579
+ def apply_wrench(W_f_L, parent_indices, wrench_pair):
580
+ W_f_L = W_f_L.at[parent_indices[0]].add(wrench_pair[0])
581
+ W_f_L = W_f_L.at[parent_indices[1]].add(wrench_pair[1])
582
+ return W_f_L
583
+
584
+ W_f_L = jax.vmap(apply_wrench, in_axes=(None, 0, 0))(
585
+ W_f_L, parent_link_indices, wrench_pair_constr_inertial
586
+ ).sum(axis=0)
587
+
572
588
  return W_f_L, aux_dict
573
589
 
574
590
 
@@ -34,6 +34,7 @@ class KinDynParameters(JaxsimDataclass):
34
34
  joint_model: The joint model of the model.
35
35
  joint_parameters: The parameters of the joints.
36
36
  hw_link_metadata: The hardware parameters of the model links.
37
+ constraints: The kinematic constraints of the model. They can be used only with the Relaxed-Rigid contact model.
37
38
  """
38
39
 
39
40
  # Static
@@ -58,6 +59,9 @@ class KinDynParameters(JaxsimDataclass):
58
59
  # Model hardware parameters
59
60
  hw_link_metadata: HwLinkMetadata | None = dataclasses.field(default=None)
60
61
 
62
+ # Kinematic constraints
63
+ constraints: ConstraintMap | None = dataclasses.field(default=None)
64
+
61
65
  @property
62
66
  def motion_subspaces(self) -> jtp.Matrix:
63
67
  r"""
@@ -80,12 +84,15 @@ class KinDynParameters(JaxsimDataclass):
80
84
  return self._support_body_array_bool.get()
81
85
 
82
86
  @staticmethod
83
- def build(model_description: ModelDescription) -> KinDynParameters:
87
+ def build(
88
+ model_description: ModelDescription, constraints: ConstraintMap | None
89
+ ) -> KinDynParameters:
84
90
  """
85
91
  Construct the kinematic and dynamic parameters of the model.
86
92
 
87
93
  Args:
88
94
  model_description: The parsed model description to consider.
95
+ constraints: An object of type ConstraintMap specifying the kinematic constraints of the model. If None, an empty ConstraintMap is used.
89
96
 
90
97
  Returns:
91
98
  The kinematic and dynamic parameters of the model.
@@ -253,6 +260,12 @@ class KinDynParameters(JaxsimDataclass):
253
260
 
254
261
  motion_subspaces = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J])
255
262
 
263
+ # ===========
264
+ # Constraints
265
+ # ===========
266
+
267
+ constraints = ConstraintMap() if constraints is None else constraints
268
+
256
269
  # =================================
257
270
  # Build and return KinDynParameters
258
271
  # =================================
@@ -267,6 +280,7 @@ class KinDynParameters(JaxsimDataclass):
267
280
  joint_parameters=joint_parameters,
268
281
  contact_parameters=contact_parameters,
269
282
  frame_parameters=frame_parameters,
283
+ constraints=constraints,
270
284
  )
271
285
 
272
286
  def __eq__(self, other: KinDynParameters) -> bool:
@@ -1168,3 +1182,85 @@ class ScalingFactors(JaxsimDataclass):
1168
1182
 
1169
1183
  dims: jtp.Vector
1170
1184
  density: jtp.Float
1185
+
1186
+
1187
+ @dataclasses.dataclass(frozen=True)
1188
+ class ConstraintType:
1189
+ """
1190
+ Enumeration of all supported constraint types.
1191
+ """
1192
+
1193
+ Weld: ClassVar[int] = 0
1194
+ # TODO: handle Connect constraint
1195
+ # Connect: ClassVar[int] = 1
1196
+
1197
+
1198
+ @jax_dataclasses.pytree_dataclass
1199
+ class ConstraintMap(JaxsimDataclass):
1200
+ """
1201
+ Class storing the kinematic constraints of a model.
1202
+ """
1203
+
1204
+ frame_idxs_1: jtp.Int = dataclasses.field(
1205
+ default_factory=lambda: jnp.array([], dtype=int)
1206
+ )
1207
+ frame_idxs_2: jtp.Int = dataclasses.field(
1208
+ default_factory=lambda: jnp.array([], dtype=int)
1209
+ )
1210
+ constraint_types: jtp.Int = dataclasses.field(
1211
+ default_factory=lambda: jnp.array([], dtype=int)
1212
+ )
1213
+ K_P: jtp.Float = dataclasses.field(
1214
+ default_factory=lambda: jnp.array([], dtype=float)
1215
+ )
1216
+ K_D: jtp.Float = dataclasses.field(
1217
+ default_factory=lambda: jnp.array([], dtype=float)
1218
+ )
1219
+
1220
+ def add_constraint(
1221
+ self,
1222
+ frame_idx_1: int,
1223
+ frame_idx_2: int,
1224
+ constraint_type: int,
1225
+ K_P: float | None = None,
1226
+ K_D: float | None = None,
1227
+ ) -> ConstraintMap:
1228
+ """
1229
+ Add a constraint to the constraint map.
1230
+
1231
+ Args:
1232
+ frame_idx_1: The index of the first frame.
1233
+ frame_idx_2: The index of the second frame.
1234
+ constraint_type: The type of constraint.
1235
+ K_P: The proportional gain for Baumgarte stabilization (default: 1000).
1236
+ K_D: The derivative gain for Baumgarte stabilization (default: 2 * sqrt(K_P)).
1237
+
1238
+ Returns:
1239
+ A new ConstraintMap instance with the added constraint.
1240
+
1241
+ Note:
1242
+ Since this method returns a new instance of ConstraintMap with the new constraint,
1243
+ it will trigger recompilations in JIT-compiled functions.
1244
+ """
1245
+
1246
+ # Set default values for Baumgarte coefficients if not provided
1247
+ if K_P is None:
1248
+ K_P = 1000
1249
+ if K_D is None:
1250
+ K_D = 2 * np.sqrt(K_P)
1251
+
1252
+ # Create new arrays with the input elements appended
1253
+ new_frame_idxs_1 = jnp.append(self.frame_idxs_1, frame_idx_1)
1254
+ new_frame_idxs2 = jnp.append(self.frame_idxs_2, frame_idx_2)
1255
+ new_constraint_types = jnp.append(self.constraint_types, constraint_type)
1256
+ new_K_P = jnp.append(self.K_P, K_P)
1257
+ new_K_D = jnp.append(self.K_D, K_D)
1258
+
1259
+ # Return a new ConstraintMap object with updated attributes
1260
+ return ConstraintMap(
1261
+ frame_idxs_1=new_frame_idxs_1,
1262
+ frame_idxs_2=new_frame_idxs2,
1263
+ constraint_types=new_constraint_types,
1264
+ K_P=new_K_P,
1265
+ K_D=new_K_D,
1266
+ )
jaxsim/api/model.py CHANGED
@@ -141,6 +141,7 @@ class JaxSimModel(JaxsimDataclass):
141
141
  is_urdf: bool | None = None,
142
142
  considered_joints: Sequence[str] | None = None,
143
143
  gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
144
+ constraints: jaxsim.rbda.kinematic_constraints.ConstraintMap | None = None,
144
145
  ) -> JaxSimModel:
145
146
  """
146
147
  Build a Model object from a model description.
@@ -167,6 +168,9 @@ class JaxSimModel(JaxsimDataclass):
167
168
  considered_joints:
168
169
  The list of joints to consider. If None, all joints are considered.
169
170
  gravity: The gravity constant. Normally passed as a positive value.
171
+ constraints:
172
+ An object of type ConstraintMap containing the kinematic constraints to consider. If None, no constraints are considered.
173
+ Note that constraints can be used only with RelaxedRigidContacts.
170
174
 
171
175
  Returns:
172
176
  The built Model object.
@@ -198,6 +202,7 @@ class JaxSimModel(JaxsimDataclass):
198
202
  contact_params=contact_params,
199
203
  integrator=integrator,
200
204
  gravity=-gravity,
205
+ constraints=constraints,
201
206
  )
202
207
 
203
208
  # Store the origin of the model, in case downstream logic needs it.
@@ -225,6 +230,7 @@ class JaxSimModel(JaxsimDataclass):
225
230
  actuation_params: jaxsim.rbda.actuation.ActuationParams | None = None,
226
231
  integrator: IntegratorType | None = None,
227
232
  gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
233
+ constraints: jaxsim.rbda.kinematic_constraints.ConstraintMap | None = None,
228
234
  ) -> JaxSimModel:
229
235
  """
230
236
  Build a Model object from an intermediate model description.
@@ -247,6 +253,8 @@ class JaxSimModel(JaxsimDataclass):
247
253
  actuation_params: The parameters of the actuation model.
248
254
  integrator: The integrator to use for the simulation.
249
255
  gravity: The gravity constant.
256
+ constraints:
257
+ An object of type ConstraintMap containing the kinematic constraints to consider. If None, no constraints are considered.
250
258
 
251
259
  Returns:
252
260
  The built Model object.
@@ -278,6 +286,14 @@ class JaxSimModel(JaxsimDataclass):
278
286
  else jaxsim.rbda.contacts.RelaxedRigidContacts.build()
279
287
  )
280
288
 
289
+ if constraints is not None and not isinstance(
290
+ contact_model, jaxsim.rbda.contacts.RelaxedRigidContacts
291
+ ):
292
+ constraints = None
293
+ logging.warning(
294
+ f"Contact model {contact_model.__class__.__name__} does not support kinematic constraints. Use RelaxedRigidContacts instead."
295
+ )
296
+
281
297
  if contact_params is None:
282
298
  contact_params = contact_model._parameters_class()
283
299
 
@@ -295,7 +311,7 @@ class JaxSimModel(JaxsimDataclass):
295
311
  model = cls(
296
312
  model_name=model_name,
297
313
  kin_dyn_parameters=js.kin_dyn_parameters.KinDynParameters.build(
298
- model_description=model_description
314
+ model_description=model_description, constraints=constraints
299
315
  ),
300
316
  time_step=time_step,
301
317
  terrain=terrain,
@@ -777,6 +793,7 @@ def reduce(
777
793
  actuation_params=model.actuation_params,
778
794
  gravity=model.gravity,
779
795
  integrator=model.integrator,
796
+ constraints=model.kin_dyn_parameters.constraints,
780
797
  )
781
798
 
782
799
  with reduced_model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
@@ -2330,7 +2347,11 @@ def update_hw_parameters(
2330
2347
 
2331
2348
  Returns:
2332
2349
  The updated JaxSimModel object with modified hardware parameters.
2350
+
2351
+ Note:
2352
+ This function can be used only with models using the Relaxed-Rigid contact model.
2333
2353
  """
2354
+
2334
2355
  kin_dyn_params: KinDynParameters = model.kin_dyn_parameters
2335
2356
  link_parameters: LinkParameters = kin_dyn_params.link_parameters
2336
2357
  hw_link_metadata: HwLinkMetadata = kin_dyn_params.hw_link_metadata
jaxsim/math/rotation.py CHANGED
@@ -82,3 +82,17 @@ class Rotation:
82
82
  R = c * jnp.eye(3) - s * Skew.wedge(u) + c1 * u @ u.T
83
83
 
84
84
  return R.transpose()
85
+
86
+ @staticmethod
87
+ def log_vee(R: jnp.ndarray) -> jtp.Vector:
88
+ """
89
+ Compute the logarithm map of an SO(3) rotation matrix.
90
+
91
+ Args:
92
+ R: The SO(3) rotation matrix.
93
+
94
+ Returns:
95
+ The corresponding so(3) tangent vector.
96
+ """
97
+
98
+ return jaxlie.SO3.from_matrix(R).log()
jaxsim/mujoco/loaders.py CHANGED
@@ -5,6 +5,7 @@ import warnings
5
5
  from collections.abc import Sequence
6
6
  from typing import Any
7
7
 
8
+ import jaxlie
8
9
  import mujoco as mj
9
10
  import numpy as np
10
11
  import rod.urdf.exporter
@@ -279,11 +280,13 @@ class RodModelToMjcf:
279
280
  # Add a floating joint if floating-base
280
281
  # -------------------------------------
281
282
 
283
+ base_link_name = rod_model.get_canonical_link()
284
+
282
285
  if not rod_model.is_fixed_base():
283
286
  considered_joints |= {"world_to_base"}
284
287
  urdf_string = RodModelToMjcf.add_floating_joint(
285
288
  urdf_string=urdf_string,
286
- base_link_name=rod_model.get_canonical_link(),
289
+ base_link_name=base_link_name,
287
290
  floating_joint_name="world_to_base",
288
291
  )
289
292
 
@@ -379,6 +382,30 @@ class RodModelToMjcf:
379
382
  # Find the <mujoco> element (might be the root itself).
380
383
  mujoco_element: ET._Element = next(iter(root.iter("mujoco")))
381
384
 
385
+ # --------------
386
+ # Add the frames
387
+ # --------------
388
+
389
+ for frame in rod_model.frames():
390
+ frame: rod.Frame
391
+ parent_name = frame.attached_to
392
+ parent_element = mujoco_element.find(f".//body[@name='{parent_name}']")
393
+
394
+ if parent_element is None and parent_name == base_link_name:
395
+ parent_element = mujoco_element.find(".//worldbody")
396
+
397
+ if parent_element is not None:
398
+ quat = jaxlie.SO3.from_rpy_radians(*frame.pose.rpy).wxyz
399
+ _ = ET.SubElement(
400
+ parent_element,
401
+ "site",
402
+ name=frame.name,
403
+ pos=" ".join(map(str, frame.pose.xyz)),
404
+ quat=" ".join(map(str, quat)),
405
+ )
406
+ else:
407
+ warnings.warn(f"Parent link '{parent_name}' not found", stacklevel=2)
408
+
382
409
  # --------------
383
410
  # Add the motors
384
411
  # --------------
jaxsim/rbda/__init__.py CHANGED
@@ -8,4 +8,10 @@ from .jacobian import (
8
8
  jacobian_derivative_full_doubly_left,
9
9
  jacobian_full_doubly_left,
10
10
  )
11
+ from .kinematic_constraints import (
12
+ compute_constraint_baumgarte_term,
13
+ compute_constraint_jacobians,
14
+ compute_constraint_jacobians_derivative,
15
+ compute_constraint_transforms,
16
+ )
11
17
  from .rnea import rnea
@@ -18,6 +18,10 @@ except ImportError:
18
18
  from typing_extensions import Self
19
19
 
20
20
 
21
+ MAX_STIFFNESS = 1e6
22
+ MAX_DAMPING = 1e4
23
+
24
+
21
25
  @functools.partial(jax.jit, static_argnames=("terrain",))
22
26
  def compute_penetration_data(
23
27
  p: jtp.VectorLike,
@@ -133,28 +137,30 @@ class ContactsParams(JaxsimDataclass):
133
137
  ξ = damping_ratio
134
138
  δ_max = max_penetration
135
139
  μc = static_friction_coefficient
140
+ nc = number_of_active_collidable_points_steady_state
136
141
 
137
142
  # Compute the total mass of the model.
138
143
  m = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum()
139
144
 
140
- # Rename the standard gravity.
141
- g = standard_gravity
142
-
143
- # Compute the average support force on each collidable point.
144
- f_average = m * g / number_of_active_collidable_points_steady_state
145
-
146
145
  # Compute the stiffness to get the desired steady-state penetration.
147
146
  # Note that this is dependent on the non-linear exponent used in
148
147
  # the damping term of the Hunt/Crossley model.
149
- K = f_average / jnp.power(δ_max, 1 + p) if stiffness is None else stiffness
148
+ if stiffness is None:
149
+ # Compute the average support force on each collidable point.
150
+ f_average = m * standard_gravity / nc
151
+
152
+ stiffness = f_average / jnp.power(δ_max, 1 + p)
153
+ stiffness = jnp.clip(stiffness, 0, MAX_STIFFNESS)
150
154
 
151
155
  # Compute the damping using the damping ratio.
152
- critical_damping = 2 * jnp.sqrt(K * m)
153
- D = ξ * critical_damping if damping is None else damping
156
+ critical_damping = 2 * jnp.sqrt(stiffness * m)
157
+ if damping is None:
158
+ damping = ξ * critical_damping
159
+ damping = jnp.clip(damping, 0, MAX_DAMPING)
154
160
 
155
161
  return self.build(
156
- K=K,
157
- D=D,
162
+ K=stiffness,
163
+ D=damping,
158
164
  mu=μc,
159
165
  p=p,
160
166
  q=q,
@@ -12,6 +12,13 @@ import optax
12
12
  import jaxsim.api as js
13
13
  import jaxsim.typing as jtp
14
14
  from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr
15
+ from jaxsim.rbda.kinematic_constraints import (
16
+ ConstraintMap,
17
+ compute_constraint_baumgarte_term,
18
+ compute_constraint_jacobians,
19
+ compute_constraint_jacobians_derivative,
20
+ compute_constraint_transforms,
21
+ )
15
22
 
16
23
  from . import common, soft
17
24
 
@@ -333,10 +340,22 @@ class RelaxedRigidContacts(common.ContactModel):
333
340
  # Compute the position in the constraint frame.
334
341
  position_constraint = jax.vmap(lambda δ, n̂: -δ * n̂)(δ, n̂)
335
342
 
343
+ # Compute the regularization terms.
344
+ a_ref, r, *_ = self._regularizers(
345
+ model=model,
346
+ position_constraint=position_constraint,
347
+ velocity_constraint=velocity,
348
+ parameters=model.contact_params,
349
+ )
350
+
336
351
  # Compute the transforms of the implicit frames corresponding to the
337
352
  # collidable points.
338
353
  W_H_C = js.contact.transforms(model=model, data=data)
339
354
 
355
+ # Retrieve the kinematic constraints, if any.
356
+ kin_constraints: ConstraintMap = model.kin_dyn_parameters.constraints
357
+ n_kin_constraints: int = 6 * kin_constraints.frame_idxs_1.shape[0]
358
+
340
359
  with (
341
360
  data.switch_velocity_representation(VelRepr.Mixed),
342
361
  references.switch_velocity_representation(VelRepr.Mixed),
@@ -354,30 +373,103 @@ class RelaxedRigidContacts(common.ContactModel):
354
373
 
355
374
  M = js.model.free_floating_mass_matrix(model=model, data=data)
356
375
 
376
+ # Compute the linear part of the Jacobian of the collidable points
357
377
  Jl_WC = jnp.vstack(
358
378
  jax.vmap(lambda J, δ: J * (δ > 0))(
359
379
  js.contact.jacobian(model=model, data=data)[:, :3, :], δ
360
380
  )
361
381
  )
362
382
 
363
- J̇_WC = jnp.vstack(
383
+ # Compute the linear part of the Jacobian derivative of the collidable points
384
+ J̇l_WC = jnp.vstack(
364
385
  jax.vmap(lambda J̇, δ: J̇ * (δ > 0))(
365
386
  js.contact.jacobian_derivative(model=model, data=data)[:, :3], δ
366
387
  ),
367
388
  )
368
389
 
369
- # Compute the regularization terms.
370
- a_ref, R, *_ = self._regularizers(
371
- model=model,
372
- position_constraint=position_constraint,
373
- velocity_constraint=velocity,
374
- parameters=model.contact_params,
375
- )
390
+ # Check if there are any kinematic constraints
391
+ if n_kin_constraints > 0:
392
+ with (
393
+ data.switch_velocity_representation(VelRepr.Mixed),
394
+ references.switch_velocity_representation(VelRepr.Mixed),
395
+ ):
396
+ J_constr = jax.vmap(
397
+ compute_constraint_jacobians, in_axes=(None, None, 0)
398
+ )(model, data, kin_constraints)
399
+
400
+ J̇_constr = jax.vmap(
401
+ compute_constraint_jacobians_derivative, in_axes=(None, None, 0)
402
+ )(model, data, kin_constraints)
403
+
404
+ W_H_constr_pairs = jax.vmap(
405
+ compute_constraint_transforms, in_axes=(None, None, 0)
406
+ )(model, data, kin_constraints)
407
+
408
+ constr_baumgarte_term = jnp.ravel(
409
+ jax.vmap(
410
+ compute_constraint_baumgarte_term,
411
+ in_axes=(0, None, 0, 0),
412
+ )(
413
+ J_constr,
414
+ BW_ν,
415
+ W_H_constr_pairs,
416
+ kin_constraints,
417
+ ),
418
+ )
419
+
420
+ J_constr = jnp.vstack(J_constr)
421
+ J̇_constr = jnp.vstack(J̇_constr)
422
+
423
+ R = jnp.diag(jnp.hstack([r, jnp.zeros(n_kin_constraints)]))
424
+ a_ref = jnp.hstack([a_ref, -constr_baumgarte_term])
425
+
426
+ J = jnp.vstack([Jl_WC, J_constr])
427
+ J̇ = jnp.vstack([J̇l_WC, J̇_constr])
428
+
429
+ else:
430
+ R = jnp.diag(r)
431
+
432
+ J = Jl_WC
433
+ J̇ = J̇l_WC
434
+
435
+ # Compute the Delassus matrix for contacts (mixed representation).
436
+ G_contacts = Jl_WC @ jnp.linalg.pinv(M) @ Jl_WC.T
376
437
 
377
- # Compute the Delassus matrix and the free mixed linear acceleration of
378
- # the collidable points.
379
- G = Jl_WC @ jnp.linalg.pinv(M) @ Jl_WC.T
380
- CW_al_free_WC = Jl_WC @ BW_ν̇_free + J̇_WC @ BW_ν
438
+ # Compute the Delassus matrix for constraints (inertial representation) if constraints exist.
439
+ with data.switch_velocity_representation(VelRepr.Inertial):
440
+ if n_kin_constraints > 0:
441
+ G_constraints = (
442
+ J_constr
443
+ @ jnp.linalg.pinv(
444
+ js.model.free_floating_mass_matrix(
445
+ model=model,
446
+ data=data,
447
+ )
448
+ )
449
+ @ J_constr.T
450
+ )
451
+ else:
452
+ G_constraints = jnp.zeros((0, 0))
453
+
454
+ # Combine the Delassus matrices for contacts and constraints if constraints exist.
455
+ if G_constraints.shape[0] > 0:
456
+ G = jnp.block(
457
+ [
458
+ [
459
+ G_contacts,
460
+ jnp.zeros((G_contacts.shape[0], G_constraints.shape[1])),
461
+ ],
462
+ [
463
+ jnp.zeros((G_constraints.shape[0], G_contacts.shape[1])),
464
+ G_constraints,
465
+ ],
466
+ ]
467
+ )
468
+ else:
469
+ G = G_contacts
470
+
471
+ # Compute the free mixed linear acceleration of the collidable points, stacked with the free acceleration of the kinematic constraints (if any).
472
+ CW_al_free_WC = J @ BW_ν̇_free + J̇ @ BW_ν
381
473
 
382
474
  # Calculate quantities for the linear optimization problem.
383
475
  A = G + R
@@ -469,6 +561,9 @@ class RelaxedRigidContacts(common.ContactModel):
469
561
  )[0]
470
562
  )(position, velocity).flatten()
471
563
 
564
+ if n_kin_constraints > 0:
565
+ init_params = jnp.hstack([init_params, jnp.zeros(n_kin_constraints)])
566
+
472
567
  # Get the solver options.
473
568
  solver_options = self.solver_options
474
569
 
@@ -494,8 +589,20 @@ class RelaxedRigidContacts(common.ContactModel):
494
589
  has_aux=True,
495
590
  )
496
591
 
497
- # Reshape the optimized solution to be a matrix of 3D contact forces.
498
- CW_fl_C = solution.reshape(-1, 3)
592
+ if n_kin_constraints > 0:
593
+ # Extract the last n_kin_constraints values from the solution and split them into 6D wrenches
594
+ kin_constr_wrench_inertial = solution[-n_kin_constraints:].reshape(-1, 6)
595
+
596
+ # Form an array of tuples with each wrench and its opposite using jax constructs
597
+ kin_constr_wrench_pairs_inertial = jnp.stack(
598
+ (kin_constr_wrench_inertial, -kin_constr_wrench_inertial), axis=1
599
+ )
600
+
601
+ # Reshape the optimized solution to be a matrix of 3D contact forces.
602
+ CW_fl_C = solution[0:-n_kin_constraints].reshape(-1, 3)
603
+ else:
604
+ kin_constr_wrench_pairs_inertial = jnp.zeros((0, 2, 6))
605
+ CW_fl_C = solution.reshape(-1, 3)
499
606
 
500
607
  # Convert the contact forces from mixed to inertial-fixed representation.
501
608
  W_f_C = ModelDataWithVelocityRepresentation.other_representation_to_inertial(
@@ -505,7 +612,9 @@ class RelaxedRigidContacts(common.ContactModel):
505
612
  is_force=True,
506
613
  )
507
614
 
508
- return W_f_C, {}
615
+ return W_f_C, {
616
+ "constr_wrenches_inertial": kin_constr_wrench_pairs_inertial,
617
+ }
509
618
 
510
619
  @staticmethod
511
620
  def _regularizers(
@@ -635,4 +744,4 @@ class RelaxedRigidContacts(common.ContactModel):
635
744
  ),
636
745
  )
637
746
 
638
- return a_ref, jnp.diag(R), K, D
747
+ return a_ref, R, K, D
@@ -0,0 +1,148 @@
1
+ from __future__ import annotations
2
+
3
+ import jax.numpy as jnp
4
+
5
+ import jaxsim.api as js
6
+ import jaxsim.typing as jtp
7
+ from jaxsim.api.common import VelRepr
8
+ from jaxsim.api.kin_dyn_parameters import ConstraintMap
9
+ from jaxsim.math.rotation import Rotation
10
+
11
+
12
+ def compute_constraint_jacobians(
13
+ model: js.model.JaxSimModel,
14
+ data: js.data.JaxSimModelData,
15
+ constraint: ConstraintMap,
16
+ ) -> jtp.Matrix:
17
+ """
18
+ Compute the constraint Jacobian matrix representing the kinematic constraints between two frames.
19
+
20
+ Args:
21
+ model: The JaxSim model.
22
+ data: The data of the considered model.
23
+ constraint: The considered constraint.
24
+
25
+ Returns:
26
+ The resulting constraint Jacobian matrix representing the kinematic constraint
27
+ between the two specified frames, in inertial representation.
28
+ """
29
+
30
+ J_WF1 = js.frame.jacobian(
31
+ model=model,
32
+ data=data,
33
+ frame_index=constraint.frame_idxs_1,
34
+ output_vel_repr=VelRepr.Inertial,
35
+ )
36
+ J_WF2 = js.frame.jacobian(
37
+ model=model,
38
+ data=data,
39
+ frame_index=constraint.frame_idxs_2,
40
+ output_vel_repr=VelRepr.Inertial,
41
+ )
42
+
43
+ return J_WF1 - J_WF2
44
+
45
+
46
+ def compute_constraint_jacobians_derivative(
47
+ model: js.model.JaxSimModel,
48
+ data: js.data.JaxSimModelData,
49
+ constraint: ConstraintMap,
50
+ ) -> jtp.Matrix:
51
+ """
52
+ Compute the derivative of the constraint Jacobian matrix representing the kinematic constraints between two frames.
53
+
54
+ Args:
55
+ model: The JaxSim model.
56
+ data: The data of the considered model.
57
+ constraint: The considered constraint.
58
+
59
+ Returns:
60
+ The resulting constraint Jacobian derivative matrix representing the kinematic constraint
61
+ between the two specified frames, in inertial representation.
62
+ """
63
+
64
+ J̇_WF1 = js.frame.jacobian_derivative(
65
+ model=model,
66
+ data=data,
67
+ frame_index=constraint.frame_idxs_1,
68
+ output_vel_repr=VelRepr.Inertial,
69
+ )
70
+ J̇_WF2 = js.frame.jacobian_derivative(
71
+ model=model,
72
+ data=data,
73
+ frame_index=constraint.frame_idxs_2,
74
+ output_vel_repr=VelRepr.Inertial,
75
+ )
76
+
77
+ return J̇_WF1 - J̇_WF2
78
+
79
+
80
+ def compute_constraint_baumgarte_term(
81
+ J_constr: jtp.Matrix,
82
+ nu: jtp.Vector,
83
+ W_H_F_constr: tuple[jtp.Matrix, jtp.Matrix],
84
+ constraint: ConstraintMap,
85
+ ) -> jtp.Vector:
86
+ """
87
+ Compute the Baumgarte stabilization term for kinematic constraints.
88
+
89
+ Args:
90
+ J_constr: The constraint Jacobian matrix.
91
+ nu: The generalized velocity vector.
92
+ W_H_F_constr: A tuple containing the homogeneous transformation matrices
93
+ of two frames (W_H_F1 and W_H_F2) with respect to the world frame.
94
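+ K_P: (Not an argument; read from ``constraint.K_P``.) The proportional gain for position and orientation error correction.
95
+ K_D: (Not an argument; read from ``constraint.K_D``.) The derivative gain for velocity error correction.
96
+ constraint: The considered constraint, which provides the K_P and K_D gains.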
97
+
98
+ Returns:
99
+ The computed Baumgarte stabilization term.
100
+ """
101
+ W_H_F1, W_H_F2 = W_H_F_constr
102
+
103
+ W_p_F1 = W_H_F1[0:3, 3]
104
+ W_p_F2 = W_H_F2[0:3, 3]
105
+
106
+ W_R_F1 = W_H_F1[0:3, 0:3]
107
+ W_R_F2 = W_H_F2[0:3, 0:3]
108
+
109
+ K_P = constraint.K_P
110
+ K_D = constraint.K_D
111
+
112
+ vel_error = J_constr @ nu
113
+ position_error = W_p_F1 - W_p_F2
114
+ R_error = W_R_F2.T @ W_R_F1
115
+ orientation_error = Rotation.log_vee(R_error)
116
+
117
+ baumgarte_term = (
118
+ K_P * jnp.concatenate([position_error, orientation_error]) + K_D * vel_error
119
+ )
120
+
121
+ return baumgarte_term
122
+
123
+
124
+ def compute_constraint_transforms(
125
+ model: js.model.JaxSimModel,
126
+ data: js.data.JaxSimModelData,
127
+ constraint: ConstraintMap,
128
+ ) -> jtp.Matrix:
129
+ """
130
+ Compute the transformation matrices for a given kinematic constraint between two frames.
131
+
132
+ Args:
133
+ model: The JaxSim model.
134
+ data: The data of the considered model.
135
+ constraint: The considered constraint.
136
+
137
+ Returns:
138
+ An array stacking the transformation matrices of the two frames (W_H_F1 first, then W_H_F2).
139
+ """
140
+
141
+ W_H_F1 = js.frame.transform(
142
+ model=model, data=data, frame_index=constraint.frame_idxs_1
143
+ )
144
+ W_H_F2 = js.frame.transform(
145
+ model=model, data=data, frame_index=constraint.frame_idxs_2
146
+ )
147
+
148
+ return jnp.array((W_H_F1, W_H_F2))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaxsim
3
- Version: 0.7.1.dev49
3
+ Version: 0.7.1.dev53
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -1,5 +1,5 @@
1
1
  jaxsim/__init__.py,sha256=EKeysKN-7UswwJLCl7n6qIBKQIVUtYsCMYu_tCoFn7g,3628
2
- jaxsim/_version.py,sha256=YV4uYC9h68LubyOgYdRGhoUvNTSB-Q4AvQqllU5Yg2o,526
2
+ jaxsim/_version.py,sha256=LFxKSa9ekR80WFGFfy_25o7U4fTH7Li5aXFdTjAxVI0,526
3
3
  jaxsim/exceptions.py,sha256=MQ3LRMfVMX2-g3qYj7mUVNV9OLlIA48TANJegbcQyXI,2641
4
4
  jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
5
5
  jaxsim/typing.py,sha256=7msl8t5Jt09RNYfKdPJtpjLfWurldcycDappb045Eso,761
@@ -7,14 +7,14 @@ jaxsim/api/__init__.py,sha256=4skzcTTejLFfZ_JE6yEEyNxObpXnR5u-bYsn2lBEx-4,234
7
7
  jaxsim/api/actuation_model.py,sha256=GgLi-yhEpsczwhMNYIlMvet8tirmmED6S7AumbSbk4U,3705
8
8
  jaxsim/api/com.py,sha256=47a9SSaXY540RCkVnHodwLNUrodIfJIkguIYdSEQVwQ,13697
9
9
  jaxsim/api/common.py,sha256=yTaRXDYkXmISBOhZ93f9TssR0p4wq7qj7B6OsvYzRME,6942
10
- jaxsim/api/contact.py,sha256=dlKKDQUG-KQ5qQaYBv2NmZLDb1OnJdltZv8MWXkD_W0,20969
10
+ jaxsim/api/contact.py,sha256=H4ltV0RYv3AbaKPtu0pfyWhcsm4FiEBUEfBtpusEpHw,22087
11
11
  jaxsim/api/data.py,sha256=9pxug2gFIZPwqi9kNYXhEziA5IQBB9KNNwIfPfc_kAU,23249
12
12
  jaxsim/api/frame.py,sha256=4wg6GsyBQgYhSvc-ry_31JsaL66sZt3TtgwjB7NrHmk,14583
13
13
  jaxsim/api/integrators.py,sha256=sHdTWw2Z-Q7jggA8zRkA1KYYd4HNIozXPwNvFwt0YBs,9011
14
14
  jaxsim/api/joint.py,sha256=AnqlNWmBOay-gsoo0y4AbfFQ2OCJm-8T1E0IMhZeLoY,7457
15
- jaxsim/api/kin_dyn_parameters.py,sha256=VrQvH5MXxyiZRkUtXrm99bAZ8bs9Ry02uWMYc-5ZL9Q,39263
15
+ jaxsim/api/kin_dyn_parameters.py,sha256=73Bixrqm481d058xcEBMj8C5xA9ndLDQdn1YE55JR8M,42488
16
16
  jaxsim/api/link.py,sha256=bSZOYJDY9HJMgY25VzevTTsdFZTJc6yRRpslc5FhGHE,12782
17
- jaxsim/api/model.py,sha256=QmLwVogrYQXP97F_6Sz04Wu_gnavuXDA4CNIgDfve8o,86158
17
+ jaxsim/api/model.py,sha256=HaP9Yql_wz2iUGFdMVTtUuI2-CNr5wyFJLbL6rh1Jyo,87302
18
18
  jaxsim/api/ode.py,sha256=fp20_LK9lXw2DfNkQgrfJmtd_iBXDNzZkOn0u5Pm8Qw,6190
19
19
  jaxsim/api/references.py,sha256=-vd50y3v-jkXAsILS432etIKV6e2EUE2oOeLHuUrfuQ,20399
20
20
  jaxsim/math/__init__.py,sha256=dNozvtm8WsB7nxw4uK29yQQKPcDUEczr2zcHoZfMItc,383
@@ -23,13 +23,13 @@ jaxsim/math/cross.py,sha256=AM4HauuuT09q2TN42qvdXhJ9LvtCh0e7ZyLjP-7sANs,1498
23
23
  jaxsim/math/inertia.py,sha256=T-iAjPYSD_72R0ZG8GDJhe5i3Jc3ojhlbBRSscTdCKg,1577
24
24
  jaxsim/math/joint_model.py,sha256=vBnwXSsw2LCb2Tr5wl2iCo0KvLqcibBbeKcsoH5r9tk,6990
25
25
  jaxsim/math/quaternion.py,sha256=MSaZywzJDxs2te1ZELeIcupKSFIA9q_pdXy7fDAEqM4,4539
26
- jaxsim/math/rotation.py,sha256=TEUtT3X2tFieNxdlccup1pfaTgCTtfX-hTNotd8-nNk,1892
26
+ jaxsim/math/rotation.py,sha256=XMBEnyyWF7jAjRvcFPT1SNj4l-ls0nhxZImp-VGMnrM,2220
27
27
  jaxsim/math/skew.py,sha256=z_9YN-NDHL3n4KXWNbzTSMkFDZ0SDpz4RUcwwYFOaao,1402
28
28
  jaxsim/math/transform.py,sha256=d0_m_obmUOmnI8Bte0ktvibR9Hv9M9qpg8tVuLON2g0,3192
29
29
  jaxsim/math/utils.py,sha256=JgJrBPeuCvi0969VqoNsyk3CflQiLzopngKDjl6RfiE,1898
30
30
  jaxsim/mujoco/__init__.py,sha256=1kAWzYOS7nP29S5FGyWPqiAnPf4yPSoaPW-WBGBjVV0,214
31
31
  jaxsim/mujoco/__main__.py,sha256=GBmB7J-zj75ZnFyuAAmpSOpbxi_HhHhWJeot3ljGDJY,5291
32
- jaxsim/mujoco/loaders.py,sha256=OCk1T11iIm3qZUibNpo_bxxLgaGSkCpLt7ae_ND0ExA,23272
32
+ jaxsim/mujoco/loaders.py,sha256=CY-AANwrKRKOhM5E9EDi_ofQYb5jiN4YOleEu8Z0S9A,24231
33
33
  jaxsim/mujoco/model.py,sha256=bRXo1uhWDN37sP9qdejr_2vq_PXHQ7p6eyBlFff_JcE,16492
34
34
  jaxsim/mujoco/utils.py,sha256=q75OSjxLU2BromVUejt0DVnSbrV5D177YW6LkOdu78g,8823
35
35
  jaxsim/mujoco/visualizer.py,sha256=cmI6DhFb1XC7oEtg_wl-s-U56dWHA-F7GlBD6EDYTyA,7744
@@ -44,19 +44,20 @@ jaxsim/parsers/rod/__init__.py,sha256=G2vqlLajBLUc4gyzXwsEI2Wsi4TMOIF9bLDFeT6KrG
44
44
  jaxsim/parsers/rod/meshes.py,sha256=yAXefG73_zqbVKRUdlcz9yFmypjDIpiP9cO96PeAozE,2842
45
45
  jaxsim/parsers/rod/parser.py,sha256=rmj-W5ekdcIe_Lu66nxTKkgwxF7vGDZdNkfcYYlU-Yc,14398
46
46
  jaxsim/parsers/rod/utils.py,sha256=wmD-wCF1lLO8pknX7A3a8CGt9wDlTS_xCqQulcZ_XlM,8242
47
- jaxsim/rbda/__init__.py,sha256=ksfupKZzeJyysxrbyMyEfszUdBH6LfCfkSz3KLfODhY,328
47
+ jaxsim/rbda/__init__.py,sha256=1gjy9Q7uwZSK4Cqe4gqRHlJ_tYLtF1GvfQEoYx27rsk,520
48
48
  jaxsim/rbda/aba.py,sha256=jxvss3XL8pBaT40bWG5pHcH9f1DDl3LoRqdnpFSlCWo,9030
49
49
  jaxsim/rbda/collidable_points.py,sha256=XyeV1I43GL22j03rkNVocaIPOGYirt3PiDHrFMndziQ,2070
50
50
  jaxsim/rbda/crba.py,sha256=DC9kBXMG1qXaoAdo8K7OCnVHT_YUaL_t6Li56sRf8ro,5093
51
51
  jaxsim/rbda/forward_kinematics.py,sha256=qem7Yp-B2oNVOsU3Q2CWV2tbfZKJOCAdDozFgaPB8tg,3838
52
52
  jaxsim/rbda/jacobian.py,sha256=EaMvf073UnLWJGXm4UZIlYd4erulFAGgj_pp89k6xic,11113
53
+ jaxsim/rbda/kinematic_constraints.py,sha256=qXxjenDBV67oMaK4dl5MOfDOXKLMZ6K5zbm6qC1Lvz8,4234
53
54
  jaxsim/rbda/rnea.py,sha256=lMU7xxdPqGGzk0QwteB-IYjL4auHOpd78C1YqAXlp9s,7588
54
55
  jaxsim/rbda/utils.py,sha256=6JwEDQqLMsBX7CUmPYEhdPEscXmGbWVYg6xEriPOgvE,5587
55
56
  jaxsim/rbda/actuation/__init__.py,sha256=zWqB8VBHadbyf8FAuhQtcfWdetGjfVxuNDwIeUqNOS4,36
56
57
  jaxsim/rbda/actuation/common.py,sha256=aGFqO4VTgQLsTJyOtVuoa_otT_RbkckmG3rq7wjOyB4,462
57
58
  jaxsim/rbda/contacts/__init__.py,sha256=resrBkTdOA-1YMdcdUH2RATEhAf_Ye6MQNtjG3ClMYQ,371
58
- jaxsim/rbda/contacts/common.py,sha256=qVm3Ghoytg1HAeykNrYw5-4rQJ4Mv7h0Pk75ETzGXyc,9045
59
- jaxsim/rbda/contacts/relaxed_rigid.py,sha256=RjeLF06Pp19qio447U9z5EdhdM6nyMh-ISQX_2-vdaE,21349
59
+ jaxsim/rbda/contacts/common.py,sha256=e-NDoC1N9PK5B0e1fQkYtxS495GSpRCGZRzB6-53WS4,9207
60
+ jaxsim/rbda/contacts/relaxed_rigid.py,sha256=oBQV-oAbx6_D8xhddCRLbV5VV7Lz75levRS5mxBnYVg,25621
60
61
  jaxsim/rbda/contacts/rigid.py,sha256=I_TjsT-84ywou2idYioqekeQbIOHTdHI54vot8ijMPk,17605
61
62
  jaxsim/rbda/contacts/soft.py,sha256=a7NYMknPfWKfCdbVu83ttDu1u_gssIRvxe9L1622tM0,15284
62
63
  jaxsim/terrain/__init__.py,sha256=f7lVX-iNpH_wkkjef9Qpjh19TTAUOQw76EiLYJDVizc,78
@@ -65,8 +66,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
65
66
  jaxsim/utils/jaxsim_dataclass.py,sha256=XzmZeIibcaOzaxpprsGSxH3UrM66PAO456rFV91sNXg,11453
66
67
  jaxsim/utils/tracing.py,sha256=Btwxdfhb7fJLk3r5PlQkGYj60Y2KbFT1gANGIA697FU,530
67
68
  jaxsim/utils/wrappers.py,sha256=3IMwydqFgmSPqeuUQ3PRmdhDc1IoT6XC23jPC_LjWXs,4175
68
- jaxsim-0.7.1.dev49.dist-info/licenses/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
69
- jaxsim-0.7.1.dev49.dist-info/METADATA,sha256=c1M6f3ufhWSqOwFuveQj65wnlXdNdHZ24weT_KaKJFY,17851
70
- jaxsim-0.7.1.dev49.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
71
- jaxsim-0.7.1.dev49.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
72
- jaxsim-0.7.1.dev49.dist-info/RECORD,,
69
+ jaxsim-0.7.1.dev53.dist-info/licenses/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
70
+ jaxsim-0.7.1.dev53.dist-info/METADATA,sha256=qaL63IhImfWKV-9NalfteDwSmo6w7f-LOH73phZr_do,17851
71
+ jaxsim-0.7.1.dev53.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
72
+ jaxsim-0.7.1.dev53.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
73
+ jaxsim-0.7.1.dev53.dist-info/RECORD,,