jaxsim 0.2.dev191__py3-none-any.whl → 0.2.dev364__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +3 -4
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +13 -2
- jaxsim/api/contact.py +120 -43
- jaxsim/api/data.py +112 -71
- jaxsim/api/joint.py +77 -36
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +150 -75
- jaxsim/api/model.py +542 -269
- jaxsim/api/ode.py +86 -74
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +12 -11
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +110 -24
- jaxsim/integrators/fixed_step.py +11 -67
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +93 -0
- jaxsim/parsers/descriptions/link.py +2 -2
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -30
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
- jaxsim-0.2.dev364.dist-info/RECORD +64 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- /jaxsim/{physics/algos → terrain}/terrain.py +0 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
jaxsim/api/ode.py
CHANGED
@@ -2,34 +2,30 @@ from typing import Any, Protocol
|
|
2
2
|
|
3
3
|
import jax
|
4
4
|
import jax.numpy as jnp
|
5
|
-
import jaxlie
|
6
5
|
|
7
|
-
import jaxsim.
|
6
|
+
import jaxsim.api as js
|
7
|
+
import jaxsim.rbda
|
8
8
|
import jaxsim.typing as jtp
|
9
|
-
from jaxsim import
|
10
|
-
from jaxsim.
|
11
|
-
from jaxsim.
|
12
|
-
from jaxsim.physics.algos.soft_contacts import SoftContactsState
|
13
|
-
from jaxsim.physics.model.physics_model_state import PhysicsModelState
|
14
|
-
from jaxsim.simulation.ode_data import ODEState
|
9
|
+
from jaxsim.integrators import Time
|
10
|
+
from jaxsim.math import Quaternion
|
11
|
+
from jaxsim.utils import Mutability
|
15
12
|
|
16
|
-
from . import
|
17
|
-
from . import
|
18
|
-
from . import model as Model
|
13
|
+
from .common import VelRepr
|
14
|
+
from .ode_data import ODEState
|
19
15
|
|
20
16
|
|
21
17
|
class SystemDynamicsFromModelAndData(Protocol):
|
22
18
|
def __call__(
|
23
19
|
self,
|
24
|
-
model:
|
25
|
-
data:
|
20
|
+
model: js.model.JaxSimModel,
|
21
|
+
data: js.data.JaxSimModelData,
|
26
22
|
**kwargs: dict[str, Any],
|
27
23
|
) -> tuple[ODEState, dict[str, Any]]: ...
|
28
24
|
|
29
25
|
|
30
26
|
def wrap_system_dynamics_for_integration(
|
31
|
-
model:
|
32
|
-
data:
|
27
|
+
model: js.model.JaxSimModel,
|
28
|
+
data: js.data.JaxSimModelData,
|
33
29
|
*,
|
34
30
|
system_dynamics: SystemDynamicsFromModelAndData,
|
35
31
|
**kwargs,
|
@@ -49,17 +45,33 @@ def wrap_system_dynamics_for_integration(
|
|
49
45
|
"""
|
50
46
|
|
51
47
|
# We allow to close `system_dynamics` over additional kwargs.
|
52
|
-
kwargs_closed = kwargs
|
48
|
+
kwargs_closed = kwargs.copy()
|
53
49
|
|
54
|
-
|
50
|
+
# Create a local copy of model and data.
|
51
|
+
# The wrapped dynamics will hold a reference of this object.
|
52
|
+
model_closed = model.copy()
|
53
|
+
data_closed = data.copy().replace(
|
54
|
+
state=js.ode_data.ODEState.zero(model=model_closed)
|
55
|
+
)
|
55
56
|
|
56
|
-
|
57
|
-
with data.editable(validate=True) as data_rw:
|
58
|
-
data_rw.state = x
|
59
|
-
data_rw.time_ns = jnp.array(t * 1e9).astype(jnp.uint64)
|
57
|
+
def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
|
60
58
|
|
61
|
-
#
|
62
|
-
|
59
|
+
# Allow caller to override the closed data and model objects.
|
60
|
+
data_f = kwargs_f.pop("data", data_closed)
|
61
|
+
model_f = kwargs_f.pop("model", model_closed)
|
62
|
+
|
63
|
+
# Update the state and time stored inside data.
|
64
|
+
with data_f.editable(validate=True) as data_rw:
|
65
|
+
data_rw.state = x
|
66
|
+
data_rw.time_ns = jnp.array(t * 1e9).astype(data_rw.time_ns.dtype)
|
67
|
+
|
68
|
+
# Evaluate the system dynamics, allowing to override the kwargs originally
|
69
|
+
# passed when the closure was created.
|
70
|
+
return system_dynamics(
|
71
|
+
model=model_f,
|
72
|
+
data=data_rw,
|
73
|
+
**(kwargs_closed | kwargs_f),
|
74
|
+
)
|
63
75
|
|
64
76
|
f: jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]
|
65
77
|
return f
|
@@ -72,11 +84,11 @@ def wrap_system_dynamics_for_integration(
|
|
72
84
|
|
73
85
|
@jax.jit
|
74
86
|
def system_velocity_dynamics(
|
75
|
-
model:
|
76
|
-
data:
|
87
|
+
model: js.model.JaxSimModel,
|
88
|
+
data: js.data.JaxSimModelData,
|
77
89
|
*,
|
78
90
|
joint_forces: jtp.Vector | None = None,
|
79
|
-
|
91
|
+
link_forces: jtp.Vector | None = None,
|
80
92
|
) -> tuple[jtp.Vector, jtp.Vector, jtp.Matrix, dict[str, Any]]:
|
81
93
|
"""
|
82
94
|
Compute the dynamics of the system velocity.
|
@@ -85,13 +97,13 @@ def system_velocity_dynamics(
|
|
85
97
|
model: The model to consider.
|
86
98
|
data: The data of the considered model.
|
87
99
|
joint_forces: The joint forces to apply.
|
88
|
-
|
100
|
+
link_forces: The 6D forces to apply to the links.
|
89
101
|
|
90
102
|
Returns:
|
91
103
|
A tuple containing the derivative of the base 6D velocity in inertial-fixed
|
92
104
|
representation, the derivative of the joint velocities, the derivative of
|
93
105
|
the material deformation, and the dictionary of auxiliary data returned by
|
94
|
-
the system dynamics
|
106
|
+
the system dynamics evaluation.
|
95
107
|
"""
|
96
108
|
|
97
109
|
# Build joint torques if not provided
|
@@ -101,10 +113,10 @@ def system_velocity_dynamics(
|
|
101
113
|
else jnp.zeros_like(data.joint_positions())
|
102
114
|
).astype(float)
|
103
115
|
|
104
|
-
# Build
|
105
|
-
|
106
|
-
jnp.atleast_2d(
|
107
|
-
if
|
116
|
+
# Build link forces if not provided
|
117
|
+
W_f_L = (
|
118
|
+
jnp.atleast_2d(link_forces.squeeze())
|
119
|
+
if link_forces is not None
|
108
120
|
else jnp.zeros((model.number_of_links(), 6))
|
109
121
|
).astype(float)
|
110
122
|
|
@@ -114,26 +126,27 @@ def system_velocity_dynamics(
|
|
114
126
|
|
115
127
|
# Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
|
116
128
|
# with the terrain.
|
117
|
-
W_f_Li_terrain = jnp.zeros_like(
|
129
|
+
W_f_Li_terrain = jnp.zeros_like(W_f_L).astype(float)
|
118
130
|
|
119
|
-
# Initialize the 6D contact forces W_f ∈ ℝ^{n_c ×
|
131
|
+
# Initialize the 6D contact forces W_f ∈ ℝ^{n_c × 6} applied to collidable points,
|
120
132
|
# expressed in the world frame.
|
121
133
|
W_f_Ci = None
|
122
134
|
|
123
135
|
# Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
|
124
136
|
ṁ = jnp.zeros_like(data.state.soft_contacts.tangential_deformation).astype(float)
|
125
137
|
|
126
|
-
if len(model.
|
127
|
-
# Compute the
|
128
|
-
#
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
138
|
+
if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
|
139
|
+
# Compute the 6D forces applied to each collidable point and the
|
140
|
+
# corresponding material deformation rates.
|
141
|
+
with data.switch_velocity_representation(VelRepr.Inertial):
|
142
|
+
W_f_Ci, ṁ = js.contact.collidable_point_dynamics(model=model, data=data)
|
143
|
+
|
144
|
+
# Construct the vector defining the parent link index of each collidable point.
|
145
|
+
# We use this vector to sum the 6D forces of all collidable points rigidly
|
146
|
+
# attached to the same link.
|
147
|
+
parent_link_index_of_collidable_points = jnp.array(
|
148
|
+
model.kin_dyn_parameters.contact_parameters.body, dtype=int
|
149
|
+
)
|
137
150
|
|
138
151
|
# Sum the forces of all collidable points rigidly attached to a body.
|
139
152
|
# Since the contact forces W_f_Ci are expressed in the world frame,
|
@@ -141,9 +154,7 @@ def system_velocity_dynamics(
|
|
141
154
|
W_f_Li_terrain = jax.vmap(
|
142
155
|
lambda nc: (
|
143
156
|
jnp.vstack(
|
144
|
-
jnp.equal(
|
145
|
-
np.array(model.physics_model.gc.body, dtype=int), nc
|
146
|
-
).astype(int)
|
157
|
+
jnp.equal(parent_link_index_of_collidable_points, nc).astype(int)
|
147
158
|
)
|
148
159
|
* W_f_Ci
|
149
160
|
).sum(axis=0)
|
@@ -164,8 +175,12 @@ def system_velocity_dynamics(
|
|
164
175
|
|
165
176
|
if model.dofs() > 0:
|
166
177
|
# Static and viscous joint friction parameters
|
167
|
-
kc = jnp.array(
|
168
|
-
|
178
|
+
kc = jnp.array(
|
179
|
+
model.kin_dyn_parameters.joint_parameters.friction_static
|
180
|
+
).astype(float)
|
181
|
+
kv = jnp.array(
|
182
|
+
model.kin_dyn_parameters.joint_parameters.friction_viscous
|
183
|
+
).astype(float)
|
169
184
|
|
170
185
|
# Compute the joint friction torque
|
171
186
|
τ_friction = -(
|
@@ -181,24 +196,24 @@ def system_velocity_dynamics(
|
|
181
196
|
τ_total = τ + τ_friction + τ_position_limit
|
182
197
|
|
183
198
|
# Compute the total external 6D forces applied to the links
|
184
|
-
W_f_L_total =
|
199
|
+
W_f_L_total = W_f_L + W_f_Li_terrain
|
185
200
|
|
186
201
|
# - Joint accelerations: s̈ ∈ ℝⁿ
|
187
202
|
# - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
|
188
203
|
with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
|
189
|
-
W_v̇_WB, s̈ =
|
204
|
+
W_v̇_WB, s̈ = js.model.forward_dynamics_aba(
|
190
205
|
model=model,
|
191
206
|
data=data,
|
192
207
|
joint_forces=τ_total,
|
193
|
-
|
208
|
+
link_forces=W_f_L_total,
|
194
209
|
)
|
195
210
|
|
196
|
-
return W_v̇_WB, s̈, m
|
211
|
+
return W_v̇_WB, s̈, ṁ, dict()
|
197
212
|
|
198
213
|
|
199
214
|
@jax.jit
|
200
215
|
def system_position_dynamics(
|
201
|
-
model:
|
216
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
202
217
|
) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
|
203
218
|
"""
|
204
219
|
Compute the dynamics of the system position.
|
@@ -212,8 +227,8 @@ def system_position_dynamics(
|
|
212
227
|
base quaternion, and the derivative of the joint positions.
|
213
228
|
"""
|
214
229
|
|
215
|
-
ṡ = data.
|
216
|
-
W_Q_B = data.
|
230
|
+
ṡ = data.joint_velocities(model=model)
|
231
|
+
W_Q_B = data.base_orientation(dcm=False)
|
217
232
|
|
218
233
|
with data.switch_velocity_representation(velocity_representation=VelRepr.Mixed):
|
219
234
|
W_ṗ_B = data.base_velocity()[0:3]
|
@@ -232,11 +247,11 @@ def system_position_dynamics(
|
|
232
247
|
|
233
248
|
@jax.jit
|
234
249
|
def system_dynamics(
|
235
|
-
model:
|
236
|
-
data:
|
250
|
+
model: js.model.JaxSimModel,
|
251
|
+
data: js.data.JaxSimModelData,
|
237
252
|
*,
|
238
253
|
joint_forces: jtp.Vector | None = None,
|
239
|
-
|
254
|
+
link_forces: jtp.Vector | None = None,
|
240
255
|
) -> tuple[ODEState, dict[str, Any]]:
|
241
256
|
"""
|
242
257
|
Compute the dynamics of the system.
|
@@ -245,7 +260,7 @@ def system_dynamics(
|
|
245
260
|
model: The model to consider.
|
246
261
|
data: The data of the considered model.
|
247
262
|
joint_forces: The joint forces to apply.
|
248
|
-
|
263
|
+
link_forces: The 6D forces to apply to the links.
|
249
264
|
|
250
265
|
Returns:
|
251
266
|
A tuple with an `ODEState` object storing in each of its attributes the
|
@@ -258,7 +273,7 @@ def system_dynamics(
|
|
258
273
|
model=model,
|
259
274
|
data=data,
|
260
275
|
joint_forces=joint_forces,
|
261
|
-
|
276
|
+
link_forces=link_forces,
|
262
277
|
)
|
263
278
|
|
264
279
|
# Extract the velocities.
|
@@ -267,18 +282,15 @@ def system_dynamics(
|
|
267
282
|
# Create an ODEState object populated with the derivative of each leaf.
|
268
283
|
# Our integrators, operating on generic pytrees, will be able to handle it
|
269
284
|
# automatically as state derivative.
|
270
|
-
ode_state_derivative = ODEState.
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
soft_contacts_state=SoftContactsState.build(
|
280
|
-
tangential_deformation=ṁ,
|
281
|
-
),
|
285
|
+
ode_state_derivative = ODEState.build_from_jaxsim_model(
|
286
|
+
model=model,
|
287
|
+
base_position=W_ṗ_B,
|
288
|
+
base_quaternion=W_Q̇_B,
|
289
|
+
joint_positions=ṡ,
|
290
|
+
base_linear_velocity=W_v̇_WB[0:3],
|
291
|
+
base_angular_velocity=W_v̇_WB[3:6],
|
292
|
+
joint_velocities=s̈,
|
293
|
+
tangential_deformation=ṁ,
|
282
294
|
)
|
283
295
|
|
284
296
|
return ode_state_derivative, aux_dict
|