jaxsim 0.1rc0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1rc0.dist-info/METADATA +0 -167
  88. jaxsim-0.1rc0.dist-info/RECORD +0 -64
  89. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/api/contact.py ADDED
@@ -0,0 +1,271 @@
1
+ import functools
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+
6
+ import jaxsim.api as js
7
+ import jaxsim.rbda
8
+ import jaxsim.typing as jtp
9
+
10
+ from .common import VelRepr
11
+
12
+
13
@jax.jit
def collidable_point_kinematics(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> tuple[jtp.Matrix, jtp.Matrix]:
    """
    Get the world-frame position and 3D velocity of every collidable point.

    Args:
        model: The model whose collidable points are evaluated.
        data: The state of that model.

    Returns:
        A pair ``(W_p_Ci, W_ṗ_Ci)`` with the positions and velocities of the
        collidable points, both expressed in the world frame.

    Note:
        The velocity returned here is the plain time derivative of the position
        coordinates. Attaching a frame C = (p_C, [C]) to a collidable point, it
        coincides with the linear part of the mixed 6D velocity of that frame.
    """

    from jaxsim.rbda import collidable_points

    # The low-level routine consumes the base velocity in inertial-fixed
    # representation, so the active representation is switched temporarily.
    with data.switch_velocity_representation(VelRepr.Inertial):
        W_v_WB = data.base_velocity()

        kinematics = collidable_points.collidable_points_pos_vel(
            model=model,
            base_position=data.base_position(),
            base_quaternion=data.base_orientation(dcm=False),
            joint_positions=data.joint_positions(model=model),
            base_linear_velocity=W_v_WB[0:3],
            base_angular_velocity=W_v_WB[3:6],
            joint_velocities=data.joint_velocities(model=model),
        )

    W_p_Ci, W_ṗ_Ci = kinematics
    return W_p_Ci, W_ṗ_Ci
47
+
48
+
49
@jax.jit
def collidable_point_positions(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
    """
    Get the world-frame position of every collidable point.

    Args:
        model: The model whose collidable points are evaluated.
        data: The state of that model.

    Returns:
        The positions of the collidable points, expressed in the world frame.
    """

    # Only the position half of the kinematics pair is needed here.
    W_p_Ci, _ = collidable_point_kinematics(model=model, data=data)
    return W_p_Ci
65
+
66
+
67
@jax.jit
def collidable_point_velocities(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
    """
    Get the world-frame 3D velocity of every collidable point.

    Args:
        model: The model whose collidable points are evaluated.
        data: The state of that model.

    Returns:
        The 3D velocities of the collidable points.
    """

    # Only the velocity half of the kinematics pair is needed here.
    _, W_ṗ_Ci = collidable_point_kinematics(model=model, data=data)
    return W_ṗ_Ci
83
+
84
+
85
@jax.jit
def collidable_point_forces(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
    """
    Get the 6D force acting on every collidable point.

    Args:
        model: The model whose collidable points are evaluated.
        data: The state of that model.

    Returns:
        The 6D forces acting on the collidable points, expressed according to
        the velocity representation currently active in ``data``.
    """

    # The material deformation rate is also computed but not needed here.
    forces, _ = collidable_point_dynamics(model=model, data=data)
    return forces
104
+
105
+
106
@jax.jit
def collidable_point_dynamics(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> tuple[jtp.Matrix, jtp.Matrix]:
    r"""
    Evaluate the soft-contact model on every collidable point.

    Args:
        model: The model whose collidable points are evaluated.
        data: The state of that model.

    Returns:
        A pair with the 6D force acting on each collidable point and the
        corresponding material deformation rate.

    Note:
        The 6D forces follow the velocity representation active in ``data``,
        whereas the material deformation rate is always expressed in the mixed
        frame `C[W] = ({}^W \mathbf{p}_C, [W])`, which is convenient for
        integration purpose.
    """

    # World-frame positions and plain coordinate derivatives of the points.
    W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data)

    soft_contacts = jaxsim.rbda.SoftContacts(
        parameters=data.soft_contacts_params, terrain=model.terrain
    )

    # Evaluate the contact model point-wise: it yields inertial-frame 6D forces
    # and the deformation rate in the mixed frame C[W] = (W_p_C, [W]).
    tangential_deformation = data.state.soft_contacts.tangential_deformation
    W_f_Ci, CW_ṁ = jax.vmap(soft_contacts.contact_model)(
        W_p_Ci, W_ṗ_Ci, tangential_deformation
    )

    # These do not depend on the mapped force, so they are computed once.
    W_H_B = data.base_transform()
    active_representation = data.velocity_representation

    def to_active_representation(W_f_C: jtp.Vector) -> jtp.Vector:
        # Re-express a single inertial-frame 6D force.
        return data.inertial_to_other_representation(
            array=W_f_C,
            other_representation=active_representation,
            transform=W_H_B,
            is_force=True,
        )

    f_Ci = jax.vmap(to_active_representation)(W_f_Ci)

    return f_Ci, CW_ṁ
156
+
157
+
158
@functools.partial(jax.jit, static_argnames=["link_names"])
def in_contact(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    link_names: tuple[str, ...] | None = None,
) -> jtp.Vector:
    """
    Return whether the links are in contact with the terrain.

    A link is reported in contact when at least one of its collidable points
    has a height lower than or equal to the terrain height below it.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        link_names:
            The names of the links to consider. If None, all links are considered.

    Returns:
        A boolean vector indicating whether the links are in contact with the
        terrain, ordered as ``link_names``.

    Raises:
        ValueError: If one or more of the given link names are not part of the model.
    """

    link_names = link_names if link_names is not None else model.link_names()

    # Name the offending links explicitly, so that the error is actionable.
    unknown_link_names = set(link_names).difference(model.link_names())
    if unknown_link_names:
        raise ValueError(
            "One or more link names are not part of the model: "
            f"{sorted(unknown_link_names)}"
        )

    W_p_Ci = collidable_point_positions(model=model, data=data)

    # Terrain height below each collidable point.
    terrain_height = jax.vmap(lambda x, y: model.terrain.height(x=x, y=y))(
        W_p_Ci[:, 0], W_p_Ci[:, 1]
    )

    below_terrain = W_p_Ci[:, 2] <= terrain_height

    # Link index associated with each collidable point. It does not depend on
    # the mapped link index, so it is built once outside the vmapped function.
    collidable_point_link_idxs = jnp.array(
        model.kin_dyn_parameters.contact_parameters.body
    )

    links_in_contact = jax.vmap(
        lambda link_index: jnp.where(
            collidable_point_link_idxs == link_index,
            below_terrain,
            jnp.zeros_like(below_terrain, dtype=bool),
        ).any()
    )(js.link.names_to_idxs(link_names=link_names, model=model))

    return links_in_contact
200
+
201
+
202
@jax.jit
def estimate_good_soft_contacts_parameters(
    model: js.model.JaxSimModel,
    *,
    standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
    static_friction_coefficient: jtp.FloatLike = 0.5,
    number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
    damping_ratio: jtp.FloatLike = 1.0,
    max_penetration: jtp.FloatLike | None = None,
) -> jaxsim.rbda.soft_contacts.SoftContactsParams:
    """
    Build a reasonable initial set of soft-contact parameters for a model.

    Args:
        model: The model to consider.
        standard_gravity: The standard gravity constant.
        static_friction_coefficient: The static friction coefficient.
        number_of_active_collidable_points_steady_state:
            How many collidable points are assumed to carry the weight of the
            robot in steady state.
        damping_ratio: The damping ratio.
        max_penetration:
            The penetration tolerated in steady state when the robot rests on
            the configured number of active collidable points. When None, it
            is set to 0.5% of a rough estimate of the model height.

    Returns:
        The estimated soft-contact parameters.

    Note:
        The result is only a starting point: the parameters are expected to be
        fine-tuned for the specific application.
    """

    def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
        """
        Roughly estimate the height of the model.

        The estimate is computed on the data returned by
        ``JaxSimModelData.build`` without explicit state (presumably the zero
        configuration). Floating-base models use twice the vertical distance
        between the CoM and the lowest collidable point. Fixed-base models use
        twice the vertical CoM coordinate.
        """

        zero_data = js.data.JaxSimModelData.build(
            model=model,
            soft_contacts_params=jaxsim.rbda.soft_contacts.SoftContactsParams(),
        )

        W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]

        if not model.floating_base():
            return 2 * W_pz_CoM

        # Last coordinate of each collidable point position (its height).
        W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]
        return 2 * (W_pz_CoM - W_pz_C.min())

    if max_penetration is None:
        max_δ = 0.005 * estimate_model_height(model=model)
    else:
        max_δ = max_penetration

    build_params = (
        jaxsim.rbda.soft_contacts.SoftContactsParams.build_default_from_jaxsim_model
    )

    return build_params(
        model=model,
        standard_gravity=standard_gravity,
        static_friction_coefficient=static_friction_coefficient,
        max_penetration=max_δ,
        number_of_active_collidable_points_steady_state=(
            number_of_active_collidable_points_steady_state
        ),
        damping_ratio=damping_ratio,
    )