jaxsim 0.2.dev188__py3-none-any.whl → 0.2.dev364__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. jaxsim/__init__.py +3 -4
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +13 -2
  6. jaxsim/api/contact.py +120 -43
  7. jaxsim/api/data.py +112 -71
  8. jaxsim/api/joint.py +77 -36
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +150 -75
  11. jaxsim/api/model.py +542 -269
  12. jaxsim/api/ode.py +88 -72
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +12 -11
  15. jaxsim/integrators/__init__.py +2 -2
  16. jaxsim/integrators/common.py +110 -24
  17. jaxsim/integrators/fixed_step.py +11 -67
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +93 -0
  25. jaxsim/parsers/descriptions/collision.py +14 -0
  26. jaxsim/parsers/descriptions/link.py +13 -2
  27. jaxsim/parsers/kinematic_graph.py +5 -0
  28. jaxsim/parsers/rod/utils.py +7 -8
  29. jaxsim/rbda/__init__.py +7 -0
  30. jaxsim/rbda/aba.py +295 -0
  31. jaxsim/rbda/collidable_points.py +142 -0
  32. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  33. jaxsim/rbda/forward_kinematics.py +113 -0
  34. jaxsim/rbda/jacobian.py +201 -0
  35. jaxsim/rbda/rnea.py +237 -0
  36. jaxsim/rbda/soft_contacts.py +296 -0
  37. jaxsim/rbda/utils.py +152 -0
  38. jaxsim/terrain/__init__.py +2 -0
  39. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  40. jaxsim/utils/__init__.py +1 -4
  41. jaxsim/utils/hashless.py +18 -0
  42. jaxsim/utils/jaxsim_dataclass.py +281 -30
  43. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
  44. jaxsim-0.2.dev364.dist-info/RECORD +64 -0
  45. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
  46. jaxsim/high_level/__init__.py +0 -2
  47. jaxsim/high_level/common.py +0 -11
  48. jaxsim/high_level/joint.py +0 -148
  49. jaxsim/high_level/link.py +0 -259
  50. jaxsim/high_level/model.py +0 -1686
  51. jaxsim/math/conv.py +0 -114
  52. jaxsim/math/joint.py +0 -102
  53. jaxsim/math/plucker.py +0 -100
  54. jaxsim/physics/__init__.py +0 -12
  55. jaxsim/physics/algos/__init__.py +0 -0
  56. jaxsim/physics/algos/aba.py +0 -254
  57. jaxsim/physics/algos/aba_motors.py +0 -284
  58. jaxsim/physics/algos/forward_kinematics.py +0 -79
  59. jaxsim/physics/algos/jacobian.py +0 -98
  60. jaxsim/physics/algos/rnea.py +0 -180
  61. jaxsim/physics/algos/rnea_motors.py +0 -196
  62. jaxsim/physics/algos/soft_contacts.py +0 -523
  63. jaxsim/physics/algos/utils.py +0 -69
  64. jaxsim/physics/model/__init__.py +0 -0
  65. jaxsim/physics/model/ground_contact.py +0 -55
  66. jaxsim/physics/model/physics_model.py +0 -388
  67. jaxsim/physics/model/physics_model_state.py +0 -283
  68. jaxsim/simulation/__init__.py +0 -4
  69. jaxsim/simulation/integrators.py +0 -393
  70. jaxsim/simulation/ode.py +0 -290
  71. jaxsim/simulation/ode_data.py +0 -96
  72. jaxsim/simulation/ode_integration.py +0 -62
  73. jaxsim/simulation/simulator.py +0 -543
  74. jaxsim/simulation/simulator_callbacks.py +0 -79
  75. jaxsim/simulation/utils.py +0 -15
  76. jaxsim/sixd/__init__.py +0 -2
  77. jaxsim/utils/oop.py +0 -536
  78. jaxsim/utils/vmappable.py +0 -117
  79. jaxsim-0.2.dev188.dist-info/RECORD +0 -81
  80. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
  81. {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
jaxsim/api/ode.py CHANGED
@@ -2,34 +2,30 @@ from typing import Any, Protocol
2
2
 
3
3
  import jax
4
4
  import jax.numpy as jnp
5
- import jaxlie
6
5
 
7
- import jaxsim.physics.algos.soft_contacts
6
+ import jaxsim.api as js
7
+ import jaxsim.rbda
8
8
  import jaxsim.typing as jtp
9
- from jaxsim import VelRepr, integrators
10
- from jaxsim.integrators.common import Time
11
- from jaxsim.math.quaternion import Quaternion
12
- from jaxsim.physics.algos.soft_contacts import SoftContactsState
13
- from jaxsim.physics.model.physics_model_state import PhysicsModelState
14
- from jaxsim.simulation.ode_data import ODEState
9
+ from jaxsim.integrators import Time
10
+ from jaxsim.math import Quaternion
11
+ from jaxsim.utils import Mutability
15
12
 
16
- from . import contact as Contact
17
- from . import data as Data
18
- from . import model as Model
13
+ from .common import VelRepr
14
+ from .ode_data import ODEState
19
15
 
20
16
 
21
17
  class SystemDynamicsFromModelAndData(Protocol):
22
18
  def __call__(
23
19
  self,
24
- model: Model.JaxSimModel,
25
- data: Data.JaxSimModelData,
20
+ model: js.model.JaxSimModel,
21
+ data: js.data.JaxSimModelData,
26
22
  **kwargs: dict[str, Any],
27
23
  ) -> tuple[ODEState, dict[str, Any]]: ...
28
24
 
29
25
 
30
26
  def wrap_system_dynamics_for_integration(
31
- model: Model.JaxSimModel,
32
- data: Data.JaxSimModelData,
27
+ model: js.model.JaxSimModel,
28
+ data: js.data.JaxSimModelData,
33
29
  *,
34
30
  system_dynamics: SystemDynamicsFromModelAndData,
35
31
  **kwargs,
@@ -49,17 +45,33 @@ def wrap_system_dynamics_for_integration(
49
45
  """
50
46
 
51
47
  # We allow to close `system_dynamics` over additional kwargs.
52
- kwargs_closed = kwargs
48
+ kwargs_closed = kwargs.copy()
53
49
 
54
- def f(x: ODEState, t: Time, **kwargs) -> tuple[ODEState, dict[str, Any]]:
50
+ # Create a local copy of model and data.
51
+ # The wrapped dynamics will hold a reference of this object.
52
+ model_closed = model.copy()
53
+ data_closed = data.copy().replace(
54
+ state=js.ode_data.ODEState.zero(model=model_closed)
55
+ )
55
56
 
56
- # Close f over the `data` parameter.
57
- with data.editable(validate=True) as data_rw:
58
- data_rw.state = x
59
- data_rw.time_ns = jnp.array(t * 1e9).astype(jnp.uint64)
57
+ def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
60
58
 
61
- # Close f over the `model` parameter.
62
- return system_dynamics(model=model, data=data_rw, **kwargs_closed | kwargs)
59
+ # Allow caller to override the closed data and model objects.
60
+ data_f = kwargs_f.pop("data", data_closed)
61
+ model_f = kwargs_f.pop("model", model_closed)
62
+
63
+ # Update the state and time stored inside data.
64
+ with data_f.editable(validate=True) as data_rw:
65
+ data_rw.state = x
66
+ data_rw.time_ns = jnp.array(t * 1e9).astype(data_rw.time_ns.dtype)
67
+
68
+ # Evaluate the system dynamics, allowing to override the kwargs originally
69
+ # passed when the closure was created.
70
+ return system_dynamics(
71
+ model=model_f,
72
+ data=data_rw,
73
+ **(kwargs_closed | kwargs_f),
74
+ )
63
75
 
64
76
  f: jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]
65
77
  return f
@@ -72,11 +84,11 @@ def wrap_system_dynamics_for_integration(
72
84
 
73
85
  @jax.jit
74
86
  def system_velocity_dynamics(
75
- model: Model.JaxSimModel,
76
- data: Data.JaxSimModelData,
87
+ model: js.model.JaxSimModel,
88
+ data: js.data.JaxSimModelData,
77
89
  *,
78
90
  joint_forces: jtp.Vector | None = None,
79
- external_forces: jtp.Vector | None = None,
91
+ link_forces: jtp.Vector | None = None,
80
92
  ) -> tuple[jtp.Vector, jtp.Vector, jtp.Matrix, dict[str, Any]]:
81
93
  """
82
94
  Compute the dynamics of the system velocity.
@@ -85,13 +97,13 @@ def system_velocity_dynamics(
85
97
  model: The model to consider.
86
98
  data: The data of the considered model.
87
99
  joint_forces: The joint forces to apply.
88
- external_forces: The external forces to apply to the links.
100
+ link_forces: The 6D forces to apply to the links.
89
101
 
90
102
  Returns:
91
103
  A tuple containing the derivative of the base 6D velocity in inertial-fixed
92
104
  representation, the derivative of the joint velocities, the derivative of
93
105
  the material deformation, and the dictionary of auxiliary data returned by
94
- the system dynamics evalutation.
106
+ the system dynamics evaluation.
95
107
  """
96
108
 
97
109
  # Build joint torques if not provided
@@ -101,10 +113,10 @@ def system_velocity_dynamics(
101
113
  else jnp.zeros_like(data.joint_positions())
102
114
  ).astype(float)
103
115
 
104
- # Build external forces if not provided
105
- f_ext = (
106
- jnp.atleast_2d(external_forces.squeeze())
107
- if external_forces is not None
116
+ # Build link forces if not provided
117
+ W_f_L = (
118
+ jnp.atleast_2d(link_forces.squeeze())
119
+ if link_forces is not None
108
120
  else jnp.zeros((model.number_of_links(), 6))
109
121
  ).astype(float)
110
122
 
@@ -114,33 +126,36 @@ def system_velocity_dynamics(
114
126
 
115
127
  # Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
116
128
  # with the terrain.
117
- W_f_Li_terrain = jnp.zeros_like(f_ext).astype(float)
129
+ W_f_Li_terrain = jnp.zeros_like(W_f_L).astype(float)
118
130
 
119
- # Initialize the 6D contact forces W_f ∈ ℝ^{n_c × 3} applied to collidable points,
131
+ # Initialize the 6D contact forces W_f ∈ ℝ^{n_c × 6} applied to collidable points,
120
132
  # expressed in the world frame.
121
133
  W_f_Ci = None
122
134
 
123
135
  # Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
124
136
  ṁ = jnp.zeros_like(data.state.soft_contacts.tangential_deformation).astype(float)
125
137
 
126
- if model.physics_model.gc.body.size > 0:
127
- # Compute the position and linear velocities (mixed representation) of
128
- # all collidable points belonging to the robot.
129
- W_p_Ci, W_ṗ_Ci = Contact.collidable_point_kinematics(model=model, data=data)
130
-
131
- # Compute the 3D forces applied to each collidable point.
132
- W_f_Ci, = jax.vmap(
133
- lambda p, ṗ, m: jaxsim.physics.algos.soft_contacts.SoftContacts(
134
- parameters=data.soft_contacts_params, terrain=model.terrain
135
- ).contact_model(position=p, velocity=ṗ, tangential_deformation=m)
136
- )(W_p_Ci, W_ṗ_Ci, data.state.soft_contacts.tangential_deformation.T)
138
+ if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
139
+ # Compute the 6D forces applied to each collidable point and the
140
+ # corresponding material deformation rates.
141
+ with data.switch_velocity_representation(VelRepr.Inertial):
142
+ W_f_Ci, ṁ = js.contact.collidable_point_dynamics(model=model, data=data)
143
+
144
+ # Construct the vector defining the parent link index of each collidable point.
145
+ # We use this vector to sum the 6D forces of all collidable points rigidly
146
+ # attached to the same link.
147
+ parent_link_index_of_collidable_points = jnp.array(
148
+ model.kin_dyn_parameters.contact_parameters.body, dtype=int
149
+ )
137
150
 
138
151
  # Sum the forces of all collidable points rigidly attached to a body.
139
152
  # Since the contact forces W_f_Ci are expressed in the world frame,
140
153
  # we don't need any coordinate transformation.
141
154
  W_f_Li_terrain = jax.vmap(
142
155
  lambda nc: (
143
- jnp.vstack(jnp.equal(model.physics_model.gc.body, nc).astype(int))
156
+ jnp.vstack(
157
+ jnp.equal(parent_link_index_of_collidable_points, nc).astype(int)
158
+ )
144
159
  * W_f_Ci
145
160
  ).sum(axis=0)
146
161
  )(jnp.arange(model.number_of_links()))
@@ -160,8 +175,12 @@ def system_velocity_dynamics(
160
175
 
161
176
  if model.dofs() > 0:
162
177
  # Static and viscous joint friction parameters
163
- kc = jnp.array(list(model.physics_model._joint_friction_static.values()))
164
- kv = jnp.array(list(model.physics_model._joint_friction_viscous.values()))
178
+ kc = jnp.array(
179
+ model.kin_dyn_parameters.joint_parameters.friction_static
180
+ ).astype(float)
181
+ kv = jnp.array(
182
+ model.kin_dyn_parameters.joint_parameters.friction_viscous
183
+ ).astype(float)
165
184
 
166
185
  # Compute the joint friction torque
167
186
  τ_friction = -(
@@ -177,24 +196,24 @@ def system_velocity_dynamics(
177
196
  τ_total = τ + τ_friction + τ_position_limit
178
197
 
179
198
  # Compute the total external 6D forces applied to the links
180
- W_f_L_total = f_ext + W_f_Li_terrain
199
+ W_f_L_total = W_f_L + W_f_Li_terrain
181
200
 
182
201
  # - Joint accelerations: s̈ ∈ ℝⁿ
183
202
  # - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
184
203
  with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
185
- W_v̇_WB, s̈ = Model.forward_dynamics_aba(
204
+ W_v̇_WB, s̈ = js.model.forward_dynamics_aba(
186
205
  model=model,
187
206
  data=data,
188
207
  joint_forces=τ_total,
189
- external_forces=W_f_L_total,
208
+ link_forces=W_f_L_total,
190
209
  )
191
210
 
192
- return W_v̇_WB, s̈, ṁ.T, dict()
211
+ return W_v̇_WB, s̈, ṁ, dict()
193
212
 
194
213
 
195
214
  @jax.jit
196
215
  def system_position_dynamics(
197
- model: Model.JaxSimModel, data: Data.JaxSimModelData
216
+ model: js.model.JaxSimModel, data: js.data.JaxSimModelData
198
217
  ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
199
218
  """
200
219
  Compute the dynamics of the system position.
@@ -208,8 +227,8 @@ def system_position_dynamics(
208
227
  base quaternion, and the derivative of the joint positions.
209
228
  """
210
229
 
211
- ṡ = data.state.physics_model.joint_velocities
212
- W_Q_B = data.state.physics_model.base_quaternion
230
+ ṡ = data.joint_velocities(model=model)
231
+ W_Q_B = data.base_orientation(dcm=False)
213
232
 
214
233
  with data.switch_velocity_representation(velocity_representation=VelRepr.Mixed):
215
234
  W_ṗ_B = data.base_velocity()[0:3]
@@ -228,11 +247,11 @@ def system_position_dynamics(
228
247
 
229
248
  @jax.jit
230
249
  def system_dynamics(
231
- model: Model.JaxSimModel,
232
- data: Data.JaxSimModelData,
250
+ model: js.model.JaxSimModel,
251
+ data: js.data.JaxSimModelData,
233
252
  *,
234
253
  joint_forces: jtp.Vector | None = None,
235
- external_forces: jtp.Vector | None = None,
254
+ link_forces: jtp.Vector | None = None,
236
255
  ) -> tuple[ODEState, dict[str, Any]]:
237
256
  """
238
257
  Compute the dynamics of the system.
@@ -241,7 +260,7 @@ def system_dynamics(
241
260
  model: The model to consider.
242
261
  data: The data of the considered model.
243
262
  joint_forces: The joint forces to apply.
244
- external_forces: The external forces to apply to the links.
263
+ link_forces: The 6D forces to apply to the links.
245
264
 
246
265
  Returns:
247
266
  A tuple with an `ODEState` object storing in each of its attributes the
@@ -254,7 +273,7 @@ def system_dynamics(
254
273
  model=model,
255
274
  data=data,
256
275
  joint_forces=joint_forces,
257
- external_forces=external_forces,
276
+ link_forces=link_forces,
258
277
  )
259
278
 
260
279
  # Extract the velocities.
@@ -263,18 +282,15 @@ def system_dynamics(
263
282
  # Create an ODEState object populated with the derivative of each leaf.
264
283
  # Our integrators, operating on generic pytrees, will be able to handle it
265
284
  # automatically as state derivative.
266
- ode_state_derivative = ODEState.build(
267
- physics_model_state=PhysicsModelState.build(
268
- base_position=W_ṗ_B,
269
- base_quaternion=W_Q̇_B,
270
- joint_positions=ṡ,
271
- base_linear_velocity=W_v̇_WB[0:3],
272
- base_angular_velocity=W_v̇_WB[3:6],
273
- joint_velocities=s̈,
274
- ),
275
- soft_contacts_state=SoftContactsState.build(
276
- tangential_deformation=ṁ,
277
- ),
285
+ ode_state_derivative = ODEState.build_from_jaxsim_model(
286
+ model=model,
287
+ base_position=W_ṗ_B,
288
+ base_quaternion=W_Q̇_B,
289
+ joint_positions=ṡ,
290
+ base_linear_velocity=W_v̇_WB[0:3],
291
+ base_angular_velocity=W_v̇_WB[3:6],
292
+ joint_velocities=s̈,
293
+ tangential_deformation=ṁ,
278
294
  )
279
295
 
280
296
  return ode_state_derivative, aux_dict