jaxsim 0.2.dev191__py3-none-any.whl → 0.2.dev364__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. jaxsim/__init__.py +3 -4
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +13 -2
  6. jaxsim/api/contact.py +120 -43
  7. jaxsim/api/data.py +112 -71
  8. jaxsim/api/joint.py +77 -36
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +150 -75
  11. jaxsim/api/model.py +542 -269
  12. jaxsim/api/ode.py +86 -74
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +12 -11
  15. jaxsim/integrators/__init__.py +2 -2
  16. jaxsim/integrators/common.py +110 -24
  17. jaxsim/integrators/fixed_step.py +11 -67
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +93 -0
  25. jaxsim/parsers/descriptions/link.py +2 -2
  26. jaxsim/parsers/rod/utils.py +7 -8
  27. jaxsim/rbda/__init__.py +7 -0
  28. jaxsim/rbda/aba.py +295 -0
  29. jaxsim/rbda/collidable_points.py +142 -0
  30. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  31. jaxsim/rbda/forward_kinematics.py +113 -0
  32. jaxsim/rbda/jacobian.py +201 -0
  33. jaxsim/rbda/rnea.py +237 -0
  34. jaxsim/rbda/soft_contacts.py +296 -0
  35. jaxsim/rbda/utils.py +152 -0
  36. jaxsim/terrain/__init__.py +2 -0
  37. jaxsim/utils/__init__.py +1 -4
  38. jaxsim/utils/hashless.py +18 -0
  39. jaxsim/utils/jaxsim_dataclass.py +281 -30
  40. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
  41. jaxsim-0.2.dev364.dist-info/RECORD +64 -0
  42. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
  43. jaxsim/high_level/__init__.py +0 -2
  44. jaxsim/high_level/common.py +0 -11
  45. jaxsim/high_level/joint.py +0 -148
  46. jaxsim/high_level/link.py +0 -259
  47. jaxsim/high_level/model.py +0 -1686
  48. jaxsim/math/conv.py +0 -114
  49. jaxsim/math/joint.py +0 -102
  50. jaxsim/math/plucker.py +0 -100
  51. jaxsim/physics/__init__.py +0 -12
  52. jaxsim/physics/algos/__init__.py +0 -0
  53. jaxsim/physics/algos/aba.py +0 -254
  54. jaxsim/physics/algos/aba_motors.py +0 -284
  55. jaxsim/physics/algos/forward_kinematics.py +0 -79
  56. jaxsim/physics/algos/jacobian.py +0 -98
  57. jaxsim/physics/algos/rnea.py +0 -180
  58. jaxsim/physics/algos/rnea_motors.py +0 -196
  59. jaxsim/physics/algos/soft_contacts.py +0 -523
  60. jaxsim/physics/algos/utils.py +0 -69
  61. jaxsim/physics/model/__init__.py +0 -0
  62. jaxsim/physics/model/ground_contact.py +0 -53
  63. jaxsim/physics/model/physics_model.py +0 -388
  64. jaxsim/physics/model/physics_model_state.py +0 -283
  65. jaxsim/simulation/__init__.py +0 -4
  66. jaxsim/simulation/integrators.py +0 -393
  67. jaxsim/simulation/ode.py +0 -290
  68. jaxsim/simulation/ode_data.py +0 -96
  69. jaxsim/simulation/ode_integration.py +0 -62
  70. jaxsim/simulation/simulator.py +0 -543
  71. jaxsim/simulation/simulator_callbacks.py +0 -79
  72. jaxsim/simulation/utils.py +0 -15
  73. jaxsim/sixd/__init__.py +0 -2
  74. jaxsim/utils/oop.py +0 -536
  75. jaxsim/utils/vmappable.py +0 -117
  76. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  77. jaxsim/{physics/algos → terrain}/terrain.py +0 -0
  78. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
  79. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
jaxsim/api/ode.py CHANGED
@@ -2,34 +2,30 @@ from typing import Any, Protocol
2
2
 
3
3
  import jax
4
4
  import jax.numpy as jnp
5
- import jaxlie
6
5
 
7
- import jaxsim.physics.algos.soft_contacts
6
+ import jaxsim.api as js
7
+ import jaxsim.rbda
8
8
  import jaxsim.typing as jtp
9
- from jaxsim import VelRepr, integrators
10
- from jaxsim.integrators.common import Time
11
- from jaxsim.math.quaternion import Quaternion
12
- from jaxsim.physics.algos.soft_contacts import SoftContactsState
13
- from jaxsim.physics.model.physics_model_state import PhysicsModelState
14
- from jaxsim.simulation.ode_data import ODEState
9
+ from jaxsim.integrators import Time
10
+ from jaxsim.math import Quaternion
11
+ from jaxsim.utils import Mutability
15
12
 
16
- from . import contact as Contact
17
- from . import data as Data
18
- from . import model as Model
13
+ from .common import VelRepr
14
+ from .ode_data import ODEState
19
15
 
20
16
 
21
17
  class SystemDynamicsFromModelAndData(Protocol):
22
18
  def __call__(
23
19
  self,
24
- model: Model.JaxSimModel,
25
- data: Data.JaxSimModelData,
20
+ model: js.model.JaxSimModel,
21
+ data: js.data.JaxSimModelData,
26
22
  **kwargs: dict[str, Any],
27
23
  ) -> tuple[ODEState, dict[str, Any]]: ...
28
24
 
29
25
 
30
26
  def wrap_system_dynamics_for_integration(
31
- model: Model.JaxSimModel,
32
- data: Data.JaxSimModelData,
27
+ model: js.model.JaxSimModel,
28
+ data: js.data.JaxSimModelData,
33
29
  *,
34
30
  system_dynamics: SystemDynamicsFromModelAndData,
35
31
  **kwargs,
@@ -49,17 +45,33 @@ def wrap_system_dynamics_for_integration(
49
45
  """
50
46
 
51
47
  # We allow to close `system_dynamics` over additional kwargs.
52
- kwargs_closed = kwargs
48
+ kwargs_closed = kwargs.copy()
53
49
 
54
- def f(x: ODEState, t: Time, **kwargs) -> tuple[ODEState, dict[str, Any]]:
50
+ # Create a local copy of model and data.
51
+ # The wrapped dynamics will hold a reference of this object.
52
+ model_closed = model.copy()
53
+ data_closed = data.copy().replace(
54
+ state=js.ode_data.ODEState.zero(model=model_closed)
55
+ )
55
56
 
56
- # Close f over the `data` parameter.
57
- with data.editable(validate=True) as data_rw:
58
- data_rw.state = x
59
- data_rw.time_ns = jnp.array(t * 1e9).astype(jnp.uint64)
57
+ def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
60
58
 
61
- # Close f over the `model` parameter.
62
- return system_dynamics(model=model, data=data_rw, **kwargs_closed | kwargs)
59
+ # Allow caller to override the closed data and model objects.
60
+ data_f = kwargs_f.pop("data", data_closed)
61
+ model_f = kwargs_f.pop("model", model_closed)
62
+
63
+ # Update the state and time stored inside data.
64
+ with data_f.editable(validate=True) as data_rw:
65
+ data_rw.state = x
66
+ data_rw.time_ns = jnp.array(t * 1e9).astype(data_rw.time_ns.dtype)
67
+
68
+ # Evaluate the system dynamics, allowing to override the kwargs originally
69
+ # passed when the closure was created.
70
+ return system_dynamics(
71
+ model=model_f,
72
+ data=data_rw,
73
+ **(kwargs_closed | kwargs_f),
74
+ )
63
75
 
64
76
  f: jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]
65
77
  return f
@@ -72,11 +84,11 @@ def wrap_system_dynamics_for_integration(
72
84
 
73
85
  @jax.jit
74
86
  def system_velocity_dynamics(
75
- model: Model.JaxSimModel,
76
- data: Data.JaxSimModelData,
87
+ model: js.model.JaxSimModel,
88
+ data: js.data.JaxSimModelData,
77
89
  *,
78
90
  joint_forces: jtp.Vector | None = None,
79
- external_forces: jtp.Vector | None = None,
91
+ link_forces: jtp.Vector | None = None,
80
92
  ) -> tuple[jtp.Vector, jtp.Vector, jtp.Matrix, dict[str, Any]]:
81
93
  """
82
94
  Compute the dynamics of the system velocity.
@@ -85,13 +97,13 @@ def system_velocity_dynamics(
85
97
  model: The model to consider.
86
98
  data: The data of the considered model.
87
99
  joint_forces: The joint forces to apply.
88
- external_forces: The external forces to apply to the links.
100
+ link_forces: The 6D forces to apply to the links.
89
101
 
90
102
  Returns:
91
103
  A tuple containing the derivative of the base 6D velocity in inertial-fixed
92
104
  representation, the derivative of the joint velocities, the derivative of
93
105
  the material deformation, and the dictionary of auxiliary data returned by
94
- the system dynamics evalutation.
106
+ the system dynamics evaluation.
95
107
  """
96
108
 
97
109
  # Build joint torques if not provided
@@ -101,10 +113,10 @@ def system_velocity_dynamics(
101
113
  else jnp.zeros_like(data.joint_positions())
102
114
  ).astype(float)
103
115
 
104
- # Build external forces if not provided
105
- f_ext = (
106
- jnp.atleast_2d(external_forces.squeeze())
107
- if external_forces is not None
116
+ # Build link forces if not provided
117
+ W_f_L = (
118
+ jnp.atleast_2d(link_forces.squeeze())
119
+ if link_forces is not None
108
120
  else jnp.zeros((model.number_of_links(), 6))
109
121
  ).astype(float)
110
122
 
@@ -114,26 +126,27 @@ def system_velocity_dynamics(
114
126
 
115
127
  # Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
116
128
  # with the terrain.
117
- W_f_Li_terrain = jnp.zeros_like(f_ext).astype(float)
129
+ W_f_Li_terrain = jnp.zeros_like(W_f_L).astype(float)
118
130
 
119
- # Initialize the 6D contact forces W_f ∈ ℝ^{n_c × 3} applied to collidable points,
131
+ # Initialize the 6D contact forces W_f ∈ ℝ^{n_c × 6} applied to collidable points,
120
132
  # expressed in the world frame.
121
133
  W_f_Ci = None
122
134
 
123
135
  # Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
124
136
  ṁ = jnp.zeros_like(data.state.soft_contacts.tangential_deformation).astype(float)
125
137
 
126
- if len(model.physics_model.gc.body) > 0:
127
- # Compute the position and linear velocities (mixed representation) of
128
- # all collidable points belonging to the robot.
129
- W_p_Ci, W_ṗ_Ci = Contact.collidable_point_kinematics(model=model, data=data)
130
-
131
- # Compute the 3D forces applied to each collidable point.
132
- W_f_Ci, = jax.vmap(
133
- lambda p, ṗ, m: jaxsim.physics.algos.soft_contacts.SoftContacts(
134
- parameters=data.soft_contacts_params, terrain=model.terrain
135
- ).contact_model(position=p, velocity=ṗ, tangential_deformation=m)
136
- )(W_p_Ci, W_ṗ_Ci, data.state.soft_contacts.tangential_deformation.T)
138
+ if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
139
+ # Compute the 6D forces applied to each collidable point and the
140
+ # corresponding material deformation rates.
141
+ with data.switch_velocity_representation(VelRepr.Inertial):
142
+ W_f_Ci, ṁ = js.contact.collidable_point_dynamics(model=model, data=data)
143
+
144
+ # Construct the vector defining the parent link index of each collidable point.
145
+ # We use this vector to sum the 6D forces of all collidable points rigidly
146
+ # attached to the same link.
147
+ parent_link_index_of_collidable_points = jnp.array(
148
+ model.kin_dyn_parameters.contact_parameters.body, dtype=int
149
+ )
137
150
 
138
151
  # Sum the forces of all collidable points rigidly attached to a body.
139
152
  # Since the contact forces W_f_Ci are expressed in the world frame,
@@ -141,9 +154,7 @@ def system_velocity_dynamics(
141
154
  W_f_Li_terrain = jax.vmap(
142
155
  lambda nc: (
143
156
  jnp.vstack(
144
- jnp.equal(
145
- np.array(model.physics_model.gc.body, dtype=int), nc
146
- ).astype(int)
157
+ jnp.equal(parent_link_index_of_collidable_points, nc).astype(int)
147
158
  )
148
159
  * W_f_Ci
149
160
  ).sum(axis=0)
@@ -164,8 +175,12 @@ def system_velocity_dynamics(
164
175
 
165
176
  if model.dofs() > 0:
166
177
  # Static and viscous joint friction parameters
167
- kc = jnp.array(list(model.physics_model._joint_friction_static.values()))
168
- kv = jnp.array(list(model.physics_model._joint_friction_viscous.values()))
178
+ kc = jnp.array(
179
+ model.kin_dyn_parameters.joint_parameters.friction_static
180
+ ).astype(float)
181
+ kv = jnp.array(
182
+ model.kin_dyn_parameters.joint_parameters.friction_viscous
183
+ ).astype(float)
169
184
 
170
185
  # Compute the joint friction torque
171
186
  τ_friction = -(
@@ -181,24 +196,24 @@ def system_velocity_dynamics(
181
196
  τ_total = τ + τ_friction + τ_position_limit
182
197
 
183
198
  # Compute the total external 6D forces applied to the links
184
- W_f_L_total = f_ext + W_f_Li_terrain
199
+ W_f_L_total = W_f_L + W_f_Li_terrain
185
200
 
186
201
  # - Joint accelerations: s̈ ∈ ℝⁿ
187
202
  # - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
188
203
  with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
189
- W_v̇_WB, s̈ = Model.forward_dynamics_aba(
204
+ W_v̇_WB, s̈ = js.model.forward_dynamics_aba(
190
205
  model=model,
191
206
  data=data,
192
207
  joint_forces=τ_total,
193
- external_forces=W_f_L_total,
208
+ link_forces=W_f_L_total,
194
209
  )
195
210
 
196
- return W_v̇_WB, s̈, ṁ.T, dict()
211
+ return W_v̇_WB, s̈, ṁ, dict()
197
212
 
198
213
 
199
214
  @jax.jit
200
215
  def system_position_dynamics(
201
- model: Model.JaxSimModel, data: Data.JaxSimModelData
216
+ model: js.model.JaxSimModel, data: js.data.JaxSimModelData
202
217
  ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
203
218
  """
204
219
  Compute the dynamics of the system position.
@@ -212,8 +227,8 @@ def system_position_dynamics(
212
227
  base quaternion, and the derivative of the joint positions.
213
228
  """
214
229
 
215
- ṡ = data.state.physics_model.joint_velocities
216
- W_Q_B = data.state.physics_model.base_quaternion
230
+ ṡ = data.joint_velocities(model=model)
231
+ W_Q_B = data.base_orientation(dcm=False)
217
232
 
218
233
  with data.switch_velocity_representation(velocity_representation=VelRepr.Mixed):
219
234
  W_ṗ_B = data.base_velocity()[0:3]
@@ -232,11 +247,11 @@ def system_position_dynamics(
232
247
 
233
248
  @jax.jit
234
249
  def system_dynamics(
235
- model: Model.JaxSimModel,
236
- data: Data.JaxSimModelData,
250
+ model: js.model.JaxSimModel,
251
+ data: js.data.JaxSimModelData,
237
252
  *,
238
253
  joint_forces: jtp.Vector | None = None,
239
- external_forces: jtp.Vector | None = None,
254
+ link_forces: jtp.Vector | None = None,
240
255
  ) -> tuple[ODEState, dict[str, Any]]:
241
256
  """
242
257
  Compute the dynamics of the system.
@@ -245,7 +260,7 @@ def system_dynamics(
245
260
  model: The model to consider.
246
261
  data: The data of the considered model.
247
262
  joint_forces: The joint forces to apply.
248
- external_forces: The external forces to apply to the links.
263
+ link_forces: The 6D forces to apply to the links.
249
264
 
250
265
  Returns:
251
266
  A tuple with an `ODEState` object storing in each of its attributes the
@@ -258,7 +273,7 @@ def system_dynamics(
258
273
  model=model,
259
274
  data=data,
260
275
  joint_forces=joint_forces,
261
- external_forces=external_forces,
276
+ link_forces=link_forces,
262
277
  )
263
278
 
264
279
  # Extract the velocities.
@@ -267,18 +282,15 @@ def system_dynamics(
267
282
  # Create an ODEState object populated with the derivative of each leaf.
268
283
  # Our integrators, operating on generic pytrees, will be able to handle it
269
284
  # automatically as state derivative.
270
- ode_state_derivative = ODEState.build(
271
- physics_model_state=PhysicsModelState.build(
272
- base_position=W_ṗ_B,
273
- base_quaternion=W_Q̇_B,
274
- joint_positions=ṡ,
275
- base_linear_velocity=W_v̇_WB[0:3],
276
- base_angular_velocity=W_v̇_WB[3:6],
277
- joint_velocities=s̈,
278
- ),
279
- soft_contacts_state=SoftContactsState.build(
280
- tangential_deformation=ṁ,
281
- ),
285
+ ode_state_derivative = ODEState.build_from_jaxsim_model(
286
+ model=model,
287
+ base_position=W_ṗ_B,
288
+ base_quaternion=W_Q̇_B,
289
+ joint_positions=ṡ,
290
+ base_linear_velocity=W_v̇_WB[0:3],
291
+ base_angular_velocity=W_v̇_WB[3:6],
292
+ joint_velocities=s̈,
293
+ tangential_deformation=ṁ,
282
294
  )
283
295
 
284
296
  return ode_state_derivative, aux_dict