jaxsim 0.1rc0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1rc0.dist-info/METADATA +0 -167
  88. jaxsim-0.1rc0.dist-info/RECORD +0 -64
  89. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/api/ode_data.py ADDED
@@ -0,0 +1,694 @@
1
+ from __future__ import annotations
2
+
3
+ import jax.numpy as jnp
4
+ import jax_dataclasses
5
+
6
+ import jaxsim.api as js
7
+ import jaxsim.typing as jtp
8
+ from jaxsim.utils import JaxsimDataclass
9
+
10
+ # =============================================================================
11
+ # Define the input and state of the ODE system defining the integrated dynamics
12
+ # =============================================================================
13
+
14
+ # Note: the ODE system is the combination of the floating-base dynamics and the
15
+ # soft-contacts dynamics.
16
+
17
+
18
@jax_dataclasses.pytree_dataclass
class ODEInput(JaxsimDataclass):
    """
    Container for the input of the integrated ODE system.

    Attributes:
        physics_model: The input of the floating-base physics model.
    """

    physics_model: PhysicsModelInput

    @staticmethod
    def build_from_jaxsim_model(
        model: js.model.JaxSimModel | None = None,
        joint_forces: jtp.VectorJax | None = None,
        link_forces: jtp.MatrixJax | None = None,
    ) -> ODEInput:
        """
        Create an `ODEInput` sized according to a `JaxSimModel`.

        Args:
            model: The `JaxSimModel` the ODE input refers to.
            joint_forces: The vector of joint forces.
            link_forces: The matrix of external forces applied to the links.

        Returns:
            The resulting `ODEInput`.

        Note:
            Components that are not passed are created from the `JaxSimModel`
            and filled with zeros.
        """

        # Build the physics-model part first, then wrap it.
        physics_input = PhysicsModelInput.build_from_jaxsim_model(
            model=model,
            joint_forces=joint_forces,
            link_forces=link_forces,
        )

        return ODEInput.build(physics_model_input=physics_input, model=model)

    @staticmethod
    def build(
        physics_model_input: PhysicsModelInput | None = None,
        model: js.model.JaxSimModel | None = None,
    ) -> ODEInput:
        """
        Create an `ODEInput` wrapping a `PhysicsModelInput`.

        Args:
            physics_model_input:
                The `PhysicsModelInput` to wrap. When omitted, a zero input is
                created from `model`.
            model: The `JaxSimModel` the ODE input refers to.

        Returns:
            The resulting `ODEInput`.
        """

        if physics_model_input is None:
            physics_model_input = PhysicsModelInput.zero(model=model)

        return ODEInput(physics_model=physics_model_input)

    @staticmethod
    def zero(model: js.model.JaxSimModel) -> ODEInput:
        """
        Create an `ODEInput` whose components are all zero.

        Args:
            model: The `JaxSimModel` the ODE input refers to.

        Returns:
            An `ODEInput` filled with zeros.
        """

        return ODEInput.build(physics_model_input=None, model=model)

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """
        Tell whether this `ODEInput` is consistent with a `JaxSimModel`.

        Args:
            model: The `JaxSimModel` to check the input against.

        Returns:
            `True` when the wrapped physics-model input matches the model,
            `False` otherwise.
        """

        physics_input = self.physics_model
        return physics_input.valid(model=model)
110
+
111
+
112
@jax_dataclasses.pytree_dataclass
class ODEState(JaxsimDataclass):
    """
    Container for the state of the integrated ODE system.

    Attributes:
        physics_model: The state of the floating-base physics model.
        soft_contacts: The state of the soft-contacts model.
    """

    physics_model: PhysicsModelState
    soft_contacts: SoftContactsState

    @staticmethod
    def build_from_jaxsim_model(
        model: js.model.JaxSimModel | None = None,
        joint_positions: jtp.Vector | None = None,
        joint_velocities: jtp.Vector | None = None,
        base_position: jtp.Vector | None = None,
        base_quaternion: jtp.Vector | None = None,
        base_linear_velocity: jtp.Vector | None = None,
        base_angular_velocity: jtp.Vector | None = None,
        tangential_deformation: jtp.Matrix | None = None,
    ) -> ODEState:
        """
        Create an `ODEState` sized according to a `JaxSimModel`.

        Args:
            model: The `JaxSimModel` the ODE state refers to.
            joint_positions: The vector of joint positions.
            joint_velocities: The vector of joint velocities.
            base_position: The 3D position of the base link.
            base_quaternion: The quaternion defining the orientation of the base link.
            base_linear_velocity:
                The linear velocity of the base link in inertial-fixed representation.
            base_angular_velocity:
                The angular velocity of the base link in inertial-fixed representation.
            tangential_deformation:
                The matrix of 3D tangential material deformations corresponding to
                each collidable point.

        Returns:
            The resulting `ODEState`.

        Note:
            Components that are not passed are created from the `JaxSimModel`
            and zero-initialized.
        """

        # Physics-model state is built before the soft-contacts state.
        physics_state = PhysicsModelState.build_from_jaxsim_model(
            model=model,
            joint_positions=joint_positions,
            joint_velocities=joint_velocities,
            base_position=base_position,
            base_quaternion=base_quaternion,
            base_linear_velocity=base_linear_velocity,
            base_angular_velocity=base_angular_velocity,
        )

        contacts_state = SoftContactsState.build_from_jaxsim_model(
            model=model,
            tangential_deformation=tangential_deformation,
        )

        return ODEState.build(
            model=model,
            physics_model_state=physics_state,
            soft_contacts_state=contacts_state,
        )

    @staticmethod
    def build(
        physics_model_state: PhysicsModelState | None = None,
        soft_contacts_state: SoftContactsState | None = None,
        model: js.model.JaxSimModel | None = None,
    ) -> ODEState:
        """
        Create an `ODEState` from its physics-model and soft-contacts parts.

        Args:
            physics_model_state:
                The physics-model state. When omitted, a zero state is created
                from `model`.
            soft_contacts_state:
                The soft-contacts state. When omitted, a zero state is created
                from `model`.
            model: The `JaxSimModel` the ODE state refers to.

        Returns:
            The resulting `ODEState`.
        """

        if physics_model_state is None:
            physics_model_state = PhysicsModelState.zero(model=model)

        if soft_contacts_state is None:
            soft_contacts_state = SoftContactsState.zero(model=model)

        return ODEState(
            physics_model=physics_model_state,
            soft_contacts=soft_contacts_state,
        )

    @staticmethod
    def zero(model: js.model.JaxSimModel) -> ODEState:
        """
        Create a zero-initialized `ODEState` for a `JaxSimModel`.

        Args:
            model: The `JaxSimModel` the ODE state refers to.

        Returns:
            A zero-initialized `ODEState`.
        """

        return ODEState.build(model=model)

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """
        Tell whether this `ODEState` is consistent with a `JaxSimModel`.

        Args:
            model: The `JaxSimModel` to check the state against.

        Returns:
            `True` when both the physics-model and the soft-contacts states
            match the model, `False` otherwise.
        """

        # The soft-contacts check is skipped when the physics-model check fails.
        if not self.physics_model.valid(model=model):
            return False

        return self.soft_contacts.valid(model=model)
242
+
243
+
244
+ # ==================================================
245
+ # Define the input and state of floating-base robots
246
+ # ==================================================
247
+
248
+
249
@jax_dataclasses.pytree_dataclass
class PhysicsModelState(JaxsimDataclass):
    """
    Class storing the state of the physics model dynamics.

    Attributes:
        joint_positions: The vector of joint positions.
        joint_velocities: The vector of joint velocities.
        base_position: The 3D position of the base link.
        base_quaternion: The quaternion defining the orientation of the base link.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
    """

    # Joint state
    joint_positions: jtp.Vector
    joint_velocities: jtp.Vector

    # Base state. The defaults are: origin position, identity orientation
    # ([1, 0, 0, 0], presumably scalar-first), and zero velocities.
    base_position: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.zeros(3)
    )
    base_quaternion: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.array([1.0, 0, 0, 0])
    )
    base_linear_velocity: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.zeros(3)
    )
    base_angular_velocity: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.zeros(3)
    )

    @staticmethod
    def build_from_jaxsim_model(
        model: js.model.JaxSimModel | None = None,
        joint_positions: jtp.Vector | None = None,
        joint_velocities: jtp.Vector | None = None,
        base_position: jtp.Vector | None = None,
        base_quaternion: jtp.Vector | None = None,
        base_linear_velocity: jtp.Vector | None = None,
        base_angular_velocity: jtp.Vector | None = None,
    ) -> PhysicsModelState:
        """
        Build a `PhysicsModelState` from a `JaxSimModel`.

        Args:
            model:
                The `JaxSimModel` associated with the state. It can be omitted
                when at least one of the joint vectors is provided, in which case
                the number of DoFs is inferred from it.
            joint_positions: The vector of joint positions.
            joint_velocities: The vector of joint velocities.
            base_position: The 3D position of the base link.
            base_quaternion: The quaternion defining the orientation of the base link.
            base_linear_velocity:
                The linear velocity of the base link in inertial-fixed representation.
            base_angular_velocity:
                The angular velocity of the base link in inertial-fixed representation.

        Note:
            If any of the state components are not provided, they are built from the
            `JaxSimModel` and initialized to zero (identity for the quaternion).

        Returns:
            A `PhysicsModelState` instance.

        Raises:
            ValueError:
                If neither `model` nor any joint vector is given, so the number
                of DoFs cannot be determined.
        """

        return PhysicsModelState.build(
            joint_positions=joint_positions,
            joint_velocities=joint_velocities,
            base_position=base_position,
            base_quaternion=base_quaternion,
            base_linear_velocity=base_linear_velocity,
            base_angular_velocity=base_angular_velocity,
            # `model` defaults to None: avoid dereferencing it in that case.
            number_of_dofs=model.dofs() if model is not None else None,
        )

    @staticmethod
    def build(
        joint_positions: jtp.Vector | None = None,
        joint_velocities: jtp.Vector | None = None,
        base_position: jtp.Vector | None = None,
        base_quaternion: jtp.Vector | None = None,
        base_linear_velocity: jtp.Vector | None = None,
        base_angular_velocity: jtp.Vector | None = None,
        number_of_dofs: jtp.Int | None = None,
    ) -> PhysicsModelState:
        """
        Build a `PhysicsModelState`.

        Args:
            joint_positions: The vector of joint positions.
            joint_velocities: The vector of joint velocities.
            base_position: The 3D position of the base link.
            base_quaternion: The quaternion defining the orientation of the base link.
            base_linear_velocity:
                The linear velocity of the base link in inertial-fixed representation.
            base_angular_velocity:
                The angular velocity of the base link in inertial-fixed representation.
            number_of_dofs:
                The number of degrees of freedom of the physics model. It is
                only needed to zero-initialize a missing joint vector. If it is
                omitted, it is inferred from the other joint vector.

        Returns:
            A `PhysicsModelState` instance.

        Raises:
            ValueError:
                If a joint vector must be zero-initialized but the number of
                DoFs is neither given nor inferable.
        """

        # The DoF count is only needed to zero-initialize a missing joint vector.
        # When it is not given, infer it from the joint vector that was provided.
        if number_of_dofs is None and (
            joint_positions is None or joint_velocities is None
        ):
            reference = (
                joint_positions if joint_positions is not None else joint_velocities
            )

            if reference is None:
                raise ValueError(
                    "Either `number_of_dofs` or at least one of the joint positions "
                    "and velocities must be provided to build a PhysicsModelState."
                )

            number_of_dofs = jnp.size(reference)

        joint_positions = (
            joint_positions
            if joint_positions is not None
            else jnp.zeros(number_of_dofs)
        )

        joint_velocities = (
            joint_velocities
            if joint_velocities is not None
            else jnp.zeros(number_of_dofs)
        )

        base_position = base_position if base_position is not None else jnp.zeros(3)

        # Identity orientation by default.
        base_quaternion = (
            base_quaternion
            if base_quaternion is not None
            else jnp.array([1.0, 0, 0, 0])
        )

        base_linear_velocity = (
            base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3)
        )

        base_angular_velocity = (
            base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3)
        )

        return PhysicsModelState(
            joint_positions=jnp.array(joint_positions, dtype=float),
            joint_velocities=jnp.array(joint_velocities, dtype=float),
            base_position=jnp.array(base_position, dtype=float),
            base_quaternion=jnp.array(base_quaternion, dtype=float),
            base_linear_velocity=jnp.array(base_linear_velocity, dtype=float),
            base_angular_velocity=jnp.array(base_angular_velocity, dtype=float),
        )

    @staticmethod
    def zero(model: js.model.JaxSimModel) -> PhysicsModelState:
        """
        Build a default `PhysicsModelState` for the given model.

        The default state has zero joint positions and velocities, the base at
        the origin with zero velocities, and the identity base quaternion
        `[1, 0, 0, 0]`.

        Args:
            model: The `JaxSimModel` associated with the state.

        Returns:
            A `PhysicsModelState` instance.
        """

        return PhysicsModelState.build_from_jaxsim_model(model=model)

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """
        Check if the `PhysicsModelState` is valid for a given `JaxSimModel`.

        Args:
            model: The `JaxSimModel` to validate the `PhysicsModelState` against.

        Returns:
            `True` if the `PhysicsModelState` is valid for the given model,
            `False` otherwise.
        """

        dofs = model.dofs()

        # Expected shape of every state component.
        expected_shapes = {
            "joint_positions": (dofs,),
            "joint_velocities": (dofs,),
            "base_position": (3,),
            "base_quaternion": (4,),
            "base_linear_velocity": (3,),
            "base_angular_velocity": (3,),
        }

        return all(
            getattr(self, name).shape == shape
            for name, shape in expected_shapes.items()
        )
458
+
459
+
460
@jax_dataclasses.pytree_dataclass
class PhysicsModelInput(JaxsimDataclass):
    """
    Class storing the inputs of the physics model dynamics.

    Attributes:
        tau: The vector of joint forces.
        f_ext: The matrix of external forces applied to the links.
    """

    tau: jtp.VectorJax
    f_ext: jtp.MatrixJax

    @staticmethod
    def build_from_jaxsim_model(
        model: js.model.JaxSimModel | None = None,
        joint_forces: jtp.VectorJax | None = None,
        link_forces: jtp.MatrixJax | None = None,
    ) -> PhysicsModelInput:
        """
        Build a `PhysicsModelInput` from a `JaxSimModel`.

        Args:
            model:
                The `JaxSimModel` associated with the input. It can be omitted
                when both `joint_forces` and `link_forces` are provided.
            joint_forces: The vector of joint forces.
            link_forces: The matrix of external forces applied to the links.

        Returns:
            A `PhysicsModelInput` instance.

        Raises:
            ValueError:
                If `model` is omitted and some input component is missing.

        Note:
            If any of the input components are not provided, they are built from the
            `JaxSimModel` and initialized to zero.
        """

        # `model` defaults to None: avoid dereferencing it in that case.
        return PhysicsModelInput.build(
            joint_forces=joint_forces,
            link_forces=link_forces,
            number_of_dofs=model.dofs() if model is not None else None,
            number_of_links=model.number_of_links() if model is not None else None,
        )

    @staticmethod
    def build(
        joint_forces: jtp.VectorJax | None = None,
        link_forces: jtp.MatrixJax | None = None,
        number_of_dofs: jtp.Int | None = None,
        number_of_links: jtp.Int | None = None,
    ) -> PhysicsModelInput:
        """
        Build a `PhysicsModelInput`.

        Args:
            joint_forces: The vector of joint forces.
            link_forces: The matrix of external forces applied to the links.
            number_of_dofs:
                The number of degrees of freedom of the model. It is required
                only when `joint_forces` is not provided.
            number_of_links:
                The number of links of the model. It is required only when
                `link_forces` is not provided.

        Returns:
            A `PhysicsModelInput` instance.

        Raises:
            ValueError:
                If a component must be zero-initialized but its size is not given.
        """

        if joint_forces is None:
            if number_of_dofs is None:
                raise ValueError(
                    "`number_of_dofs` is required when `joint_forces` is not provided."
                )
            joint_forces = jnp.zeros(number_of_dofs)

        if link_forces is None:
            if number_of_links is None:
                raise ValueError(
                    "`number_of_links` is required when `link_forces` is not provided."
                )
            # One 6D force per link.
            link_forces = jnp.zeros(shape=(number_of_links, 6))

        return PhysicsModelInput(
            tau=jnp.array(joint_forces, dtype=float),
            f_ext=jnp.array(link_forces, dtype=float),
        )

    @staticmethod
    def zero(model: js.model.JaxSimModel) -> PhysicsModelInput:
        """
        Build a `PhysicsModelInput` with all components initialized to zero.

        Args:
            model: The `JaxSimModel` associated with the input.

        Returns:
            A `PhysicsModelInput` instance.
        """

        return PhysicsModelInput.build_from_jaxsim_model(model=model)

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """
        Check if the `PhysicsModelInput` is valid for a given `JaxSimModel`.

        Args:
            model: The `JaxSimModel` to validate the `PhysicsModelInput` against.

        Returns:
            `True` if the `PhysicsModelInput` is valid for the given model,
            `False` otherwise.
        """

        if self.tau.shape != (model.dofs(),):
            return False

        if self.f_ext.shape != (model.number_of_links(), 6):
            return False

        return True
576
+
577
+
578
+ # ===========================================
579
+ # Define the state of the soft-contacts model
580
+ # ===========================================
581
+
582
+
583
@jax_dataclasses.pytree_dataclass
class SoftContactsState(JaxsimDataclass):
    """
    Class storing the state of the soft contacts model.

    Attributes:
        tangential_deformation:
            The matrix of 3D tangential material deformations corresponding to
            each collidable point.
    """

    tangential_deformation: jtp.Matrix

    @staticmethod
    def build_from_jaxsim_model(
        model: js.model.JaxSimModel | None = None,
        tangential_deformation: jtp.Matrix | None = None,
    ) -> SoftContactsState:
        """
        Build a `SoftContactsState` from a `JaxSimModel`.

        Args:
            model:
                The `JaxSimModel` associated with the soft contacts state. It can
                be omitted when `tangential_deformation` is provided.
            tangential_deformation: The matrix of 3D tangential material deformations.

        Returns:
            The `SoftContactsState` built from the `JaxSimModel`.

        Raises:
            ValueError: If both `model` and `tangential_deformation` are omitted.
            RuntimeError:
                If `tangential_deformation` does not match the expected shape.

        Note:
            If any of the state components are not provided, they are built from the
            `JaxSimModel` and initialized to zero.
        """

        # `model` defaults to None: avoid dereferencing it in that case.
        number_of_collidable_points = (
            len(model.kin_dyn_parameters.contact_parameters.body)
            if model is not None
            else None
        )

        return SoftContactsState.build(
            tangential_deformation=tangential_deformation,
            number_of_collidable_points=number_of_collidable_points,
        )

    @staticmethod
    def build(
        tangential_deformation: jtp.Matrix | None = None,
        number_of_collidable_points: int | None = None,
    ) -> SoftContactsState:
        """
        Create a `SoftContactsState`.

        Args:
            tangential_deformation:
                The matrix of 3D tangential material deformations corresponding to
                each collidable point. Any array-like input (e.g. nested lists)
                is accepted.
            number_of_collidable_points: The number of collidable points.

        Returns:
            A `SoftContactsState` instance.

        Raises:
            ValueError:
                If both `tangential_deformation` and `number_of_collidable_points`
                are omitted.
            RuntimeError:
                If the deformation is not a matrix with 3 columns, or if its
                number of rows differs from `number_of_collidable_points`.
        """

        if tangential_deformation is None:
            if number_of_collidable_points is None:
                raise ValueError(
                    "Either `tangential_deformation` or `number_of_collidable_points` "
                    "must be provided to build a SoftContactsState."
                )
            tangential_deformation = jnp.zeros(shape=(number_of_collidable_points, 3))

        # Convert before inspecting the shape, so that plain sequences (which have
        # no `.shape`) are accepted.
        tangential_deformation = jnp.array(tangential_deformation, dtype=float)

        if tangential_deformation.ndim != 2:
            raise RuntimeError("The tangential deformation must be a 2D matrix.")

        if tangential_deformation.shape[1] != 3:
            raise RuntimeError("The tangential deformation matrix must have 3 columns.")

        if (
            number_of_collidable_points is not None
            and tangential_deformation.shape[0] != number_of_collidable_points
        ):
            msg = "The number of collidable points must match the number of rows "
            msg += "in the tangential deformation matrix."
            raise RuntimeError(msg)

        return SoftContactsState(tangential_deformation=tangential_deformation)

    @staticmethod
    def zero(model: js.model.JaxSimModel) -> SoftContactsState:
        """
        Build a zero `SoftContactsState` from a `JaxSimModel`.

        Args:
            model: The `JaxSimModel` associated with the soft contacts state.

        Returns:
            A zero `SoftContactsState` instance.
        """

        return SoftContactsState.build_from_jaxsim_model(model=model)

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """
        Check if the `SoftContactsState` is valid for a given `JaxSimModel`.

        Args:
            model: The `JaxSimModel` to validate the `SoftContactsState` against.

        Returns:
            `True` if the soft contacts state is valid for the given `JaxSimModel`,
            `False` otherwise.
        """

        # One 3D deformation per collidable point.
        expected = (len(model.kin_dyn_parameters.contact_parameters.body), 3)

        return self.tangential_deformation.shape == expected