jaxsim 0.1.dev401__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +5 -6
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -0
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +216 -0
- jaxsim/api/contact.py +271 -0
- jaxsim/api/data.py +821 -0
- jaxsim/api/joint.py +189 -0
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +361 -0
- jaxsim/api/model.py +1633 -0
- jaxsim/api/ode.py +295 -0
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +421 -0
- jaxsim/integrators/__init__.py +2 -0
- jaxsim/integrators/common.py +594 -0
- jaxsim/integrators/fixed_step.py +102 -0
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +92 -0
- jaxsim/mujoco/__init__.py +3 -0
- jaxsim/mujoco/__main__.py +192 -0
- jaxsim/mujoco/loaders.py +615 -0
- jaxsim/mujoco/model.py +414 -0
- jaxsim/mujoco/visualizer.py +176 -0
- jaxsim/parsers/descriptions/collision.py +14 -0
- jaxsim/parsers/descriptions/link.py +13 -2
- jaxsim/parsers/kinematic_graph.py +8 -3
- jaxsim/parsers/rod/parser.py +54 -38
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/{physics/algos → terrain}/terrain.py +4 -6
- jaxsim/typing.py +30 -30
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -31
- {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
- jaxsim-0.2.0.dist-info/METADATA +237 -0
- jaxsim-0.2.0.dist-info/RECORD +64 -0
- {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1695
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -101
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -256
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -454
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -358
- jaxsim/physics/model/physics_model_state.py +0 -174
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -452
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -53
- jaxsim/simulation/ode_integration.py +0 -125
- jaxsim/simulation/simulator.py +0 -544
- jaxsim/simulation/simulator_callbacks.py +0 -53
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -532
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.1.dev401.dist-info/METADATA +0 -167
- jaxsim-0.1.dev401.dist-info/RECORD +0 -64
- {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/api/ode.py
ADDED
@@ -0,0 +1,295 @@
|
|
1
|
+
from typing import Any, Protocol
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import jax.numpy as jnp
|
5
|
+
|
6
|
+
import jaxsim.api as js
|
7
|
+
import jaxsim.rbda
|
8
|
+
import jaxsim.typing as jtp
|
9
|
+
from jaxsim.integrators import Time
|
10
|
+
from jaxsim.math import Quaternion
|
11
|
+
|
12
|
+
from .common import VelRepr
|
13
|
+
from .ode_data import ODEState
|
14
|
+
|
15
|
+
|
16
|
+
class SystemDynamicsFromModelAndData(Protocol):
    """
    Structural type of system dynamics operating on a model and its data.

    Any callable accepting a `JaxSimModel`, a `JaxSimModelData`, and arbitrary
    keyword arguments satisfies this protocol. It must return the derivative of
    the ODE state and a dictionary of auxiliary data.
    """

    def __call__(
        self,
        model: js.model.JaxSimModel,
        data: js.data.JaxSimModelData,
        # Each additional keyword argument can have any type (e.g. the joint
        # forces vector). `**kwargs: T` annotates the type of every single value.
        **kwargs: Any,
    ) -> tuple[ODEState, dict[str, Any]]: ...
|
23
|
+
|
24
|
+
|
25
|
+
def wrap_system_dynamics_for_integration(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    system_dynamics: SystemDynamicsFromModelAndData,
    **kwargs,
) -> jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]:
    """
    Adapt model/data system dynamics to the signature used by `jaxsim.integrators`.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        system_dynamics: The system dynamics to wrap.
        **kwargs:
            Additional keyword arguments bound to every evaluation of the system
            dynamics. Keyword arguments passed when calling the returned function
            take precedence over them.

    Returns:
        A function `f(x, t, **kwargs)` that stores the ODE state `x` and the time
        `t` into a data object and evaluates `system_dynamics` on it.
        - By default, `f` uses copies of `model` and `data` taken when this wrapper is created.
        - Callers can override these copies by passing `model=` or `data=` to `f`.
    """

    # Keyword arguments bound at wrapping time.
    default_kwargs = kwargs.copy()

    # Snapshot model and data; the returned closure keeps these copies alive.
    # The state of the data copy is zeroed since `f` overwrites it on every call.
    model_default = model.copy()
    data_default = data.copy().replace(
        state=js.ode_data.ODEState.zero(model=model_default)
    )

    def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:

        # The caller may provide its own model and data objects.
        model_f = kwargs_f.pop("model", model_default)
        data_f = kwargs_f.pop("data", data_default)

        # Store the integrator state and the time (converted to nanoseconds)
        # inside the data object.
        with data_f.editable(validate=True) as data_rw:
            data_rw.state = x
            data_rw.time_ns = jnp.array(t * 1e9).astype(data_rw.time_ns.dtype)

        # Call-time kwargs override those bound at wrapping time.
        merged_kwargs = default_kwargs | kwargs_f

        return system_dynamics(model=model_f, data=data_rw, **merged_kwargs)

    f: jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]
    return f
|
77
|
+
|
78
|
+
|
79
|
+
# ==================================
|
80
|
+
# Functions defining system dynamics
|
81
|
+
# ==================================
|
82
|
+
|
83
|
+
|
84
|
+
@jax.jit
def system_velocity_dynamics(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_forces: jtp.Vector | None = None,
    link_forces: jtp.Vector | None = None,
) -> tuple[jtp.Vector, jtp.Vector, jtp.Matrix, dict[str, Any]]:
    """
    Compute the dynamics of the system velocity.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_forces: The joint forces to apply. Zero if not provided.
        link_forces:
            The 6D forces to apply to the links, one row per link. They are summed
            with the terrain contact forces, which are expressed in the world frame.
            The sum is then passed to the forward dynamics computed in
            inertial-fixed representation. Zero if not provided.

    Returns:
        A tuple containing the derivative of the base 6D velocity in inertial-fixed
        representation, the derivative of the joint velocities, the derivative of
        the material deformation, and the dictionary of auxiliary data returned by
        the system dynamics evaluation.
    """

    # Build joint torques if not provided.
    τ = (
        jnp.atleast_1d(joint_forces.squeeze())
        if joint_forces is not None
        else jnp.zeros_like(data.joint_positions())
    ).astype(float)

    # Build link forces if not provided.
    W_f_L = (
        jnp.atleast_2d(link_forces.squeeze())
        if link_forces is not None
        else jnp.zeros((model.number_of_links(), 6))
    ).astype(float)

    # ======================
    # Compute contact forces
    # ======================

    # Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
    # with the terrain.
    W_f_Li_terrain = jnp.zeros_like(W_f_L).astype(float)

    # Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
    ṁ = jnp.zeros_like(data.state.soft_contacts.tangential_deformation).astype(float)

    if len(model.kin_dyn_parameters.contact_parameters.body) > 0:

        # Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
        # (expressed in the world frame) and the corresponding material
        # deformation rates.
        with data.switch_velocity_representation(VelRepr.Inertial):
            W_f_Ci, ṁ = js.contact.collidable_point_dynamics(model=model, data=data)

        # Index of the parent link of each collidable point.
        parent_link_index_of_collidable_points = jnp.array(
            model.kin_dyn_parameters.contact_parameters.body, dtype=int
        )

        # Sum the forces of all collidable points rigidly attached to the same link.
        # Since the contact forces W_f_Ci are expressed in the world frame,
        # no coordinate transformation is needed.
        # `num_segments` is static, so the output shape is known at trace time.
        W_f_Li_terrain = jax.ops.segment_sum(
            data=W_f_Ci,
            segment_ids=parent_link_index_of_collidable_points,
            num_segments=model.number_of_links(),
        )

    # ====================
    # Enforce joint limits
    # ====================

    # TODO: enforce joint limits
    τ_position_limit = jnp.zeros_like(τ).astype(float)

    # ====================
    # Joint friction model
    # ====================

    τ_friction = jnp.zeros_like(τ).astype(float)

    if model.dofs() > 0:

        # Static (Coulomb) and viscous joint friction parameters.
        kc = jnp.array(
            model.kin_dyn_parameters.joint_parameters.friction_static
        ).astype(float)
        kv = jnp.array(
            model.kin_dyn_parameters.joint_parameters.friction_viscous
        ).astype(float)

        ṡ = data.state.physics_model.joint_velocities

        # Compute the joint friction torque.
        # Both the Coulomb term (-kc·sign(ṡ)) and the viscous term (-kv·ṡ) oppose
        # the joint motion, so both depend on the joint velocities.
        τ_friction = -(kc * jnp.sign(ṡ) + kv * ṡ)

    # ========================
    # Compute forward dynamics
    # ========================

    # Compute the total joint forces.
    τ_total = τ + τ_friction + τ_position_limit

    # Compute the total external 6D forces applied to the links.
    W_f_L_total = W_f_L + W_f_Li_terrain

    # - Joint accelerations: s̈ ∈ ℝⁿ
    # - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
    with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
        W_v̇_WB, s̈ = js.model.forward_dynamics_aba(
            model=model,
            data=data,
            joint_forces=τ_total,
            link_forces=W_f_L_total,
        )

    return W_v̇_WB, s̈, ṁ, dict()
|
211
|
+
|
212
|
+
|
213
|
+
@jax.jit
def system_position_dynamics(
    model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
    """
    Compute the time derivative of the system position.

    Args:
        model: The model to consider.
        data: The data of the considered model.

    Returns:
        A tuple with three elements:
        - the derivative of the base position, i.e. the linear part of the base velocity in mixed representation;
        - the derivative of the base quaternion;
        - the derivative of the joint positions, i.e. the joint velocities.
    """

    # Base linear velocity: linear part of the mixed base velocity.
    with data.switch_velocity_representation(velocity_representation=VelRepr.Mixed):
        W_ṗ_B = data.base_velocity()[0:3]

    # Base angular velocity: angular part of the inertial-fixed base velocity.
    with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
        W_ω_WB = data.base_velocity()[3:6]

    # Quaternion derivative, with the angular velocity expressed in the world frame.
    W_Q̇_B = Quaternion.derivative(
        quaternion=data.base_orientation(dcm=False),
        omega=W_ω_WB,
        omega_in_body_fixed=False,
    ).squeeze()

    ṡ = data.joint_velocities(model=model)

    return W_ṗ_B, W_Q̇_B, ṡ
|
245
|
+
|
246
|
+
|
247
|
+
@jax.jit
def system_dynamics(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_forces: jtp.Vector | None = None,
    link_forces: jtp.Vector | None = None,
) -> tuple[ODEState, dict[str, Any]]:
    """
    Compute the full dynamics of the system (positions and velocities).

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_forces: The joint forces to apply.
        link_forces: The 6D forces to apply to the links.

    Returns:
        A tuple with two elements:
        - an `ODEState` whose leaves contain the time derivative of the
          corresponding state leaves;
        - the dictionary of auxiliary data returned by the velocity dynamics.
    """

    # Position dynamics: derivatives of base position, base quaternion and
    # joint positions.
    W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(model=model, data=data)

    # Velocity dynamics: base acceleration (inertial-fixed), joint accelerations,
    # and material deformation rate.
    W_v̇_WB, s̈, ṁ, aux_dict = system_velocity_dynamics(
        model=model,
        data=data,
        joint_forces=joint_forces,
        link_forces=link_forces,
    )

    W_v̇_lin = W_v̇_WB[0:3]
    W_v̇_ang = W_v̇_WB[3:6]

    # The integrators operate on generic pytrees, so the state derivative is
    # packed in an ODEState with the same structure as the state itself.
    return (
        ODEState.build_from_jaxsim_model(
            model=model,
            base_position=W_ṗ_B,
            base_quaternion=W_Q̇_B,
            joint_positions=ṡ,
            base_linear_velocity=W_v̇_lin,
            base_angular_velocity=W_v̇_ang,
            joint_velocities=s̈,
            tangential_deformation=ṁ,
        ),
        aux_dict,
    )
|