jaxsim 0.2.dev188__py3-none-any.whl → 0.2.dev364__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +3 -4
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +13 -2
- jaxsim/api/contact.py +120 -43
- jaxsim/api/data.py +112 -71
- jaxsim/api/joint.py +77 -36
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +150 -75
- jaxsim/api/model.py +542 -269
- jaxsim/api/ode.py +88 -72
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +12 -11
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +110 -24
- jaxsim/integrators/fixed_step.py +11 -67
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +93 -0
- jaxsim/parsers/descriptions/collision.py +14 -0
- jaxsim/parsers/descriptions/link.py +13 -2
- jaxsim/parsers/kinematic_graph.py +5 -0
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/{physics/algos → terrain}/terrain.py +4 -6
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -30
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
- jaxsim-0.2.dev364.dist-info/RECORD +64 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev188.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
jaxsim/api/ode.py
CHANGED
@@ -2,34 +2,30 @@ from typing import Any, Protocol
|
|
2
2
|
|
3
3
|
import jax
|
4
4
|
import jax.numpy as jnp
|
5
|
-
import jaxlie
|
6
5
|
|
7
|
-
import jaxsim.
|
6
|
+
import jaxsim.api as js
|
7
|
+
import jaxsim.rbda
|
8
8
|
import jaxsim.typing as jtp
|
9
|
-
from jaxsim import
|
10
|
-
from jaxsim.
|
11
|
-
from jaxsim.
|
12
|
-
from jaxsim.physics.algos.soft_contacts import SoftContactsState
|
13
|
-
from jaxsim.physics.model.physics_model_state import PhysicsModelState
|
14
|
-
from jaxsim.simulation.ode_data import ODEState
|
9
|
+
from jaxsim.integrators import Time
|
10
|
+
from jaxsim.math import Quaternion
|
11
|
+
from jaxsim.utils import Mutability
|
15
12
|
|
16
|
-
from . import
|
17
|
-
from . import
|
18
|
-
from . import model as Model
|
13
|
+
from .common import VelRepr
|
14
|
+
from .ode_data import ODEState
|
19
15
|
|
20
16
|
|
21
17
|
class SystemDynamicsFromModelAndData(Protocol):
|
22
18
|
def __call__(
|
23
19
|
self,
|
24
|
-
model:
|
25
|
-
data:
|
20
|
+
model: js.model.JaxSimModel,
|
21
|
+
data: js.data.JaxSimModelData,
|
26
22
|
**kwargs: dict[str, Any],
|
27
23
|
) -> tuple[ODEState, dict[str, Any]]: ...
|
28
24
|
|
29
25
|
|
30
26
|
def wrap_system_dynamics_for_integration(
|
31
|
-
model:
|
32
|
-
data:
|
27
|
+
model: js.model.JaxSimModel,
|
28
|
+
data: js.data.JaxSimModelData,
|
33
29
|
*,
|
34
30
|
system_dynamics: SystemDynamicsFromModelAndData,
|
35
31
|
**kwargs,
|
@@ -49,17 +45,33 @@ def wrap_system_dynamics_for_integration(
|
|
49
45
|
"""
|
50
46
|
|
51
47
|
# We allow to close `system_dynamics` over additional kwargs.
|
52
|
-
kwargs_closed = kwargs
|
48
|
+
kwargs_closed = kwargs.copy()
|
53
49
|
|
54
|
-
|
50
|
+
# Create a local copy of model and data.
|
51
|
+
# The wrapped dynamics will hold a reference of this object.
|
52
|
+
model_closed = model.copy()
|
53
|
+
data_closed = data.copy().replace(
|
54
|
+
state=js.ode_data.ODEState.zero(model=model_closed)
|
55
|
+
)
|
55
56
|
|
56
|
-
|
57
|
-
with data.editable(validate=True) as data_rw:
|
58
|
-
data_rw.state = x
|
59
|
-
data_rw.time_ns = jnp.array(t * 1e9).astype(jnp.uint64)
|
57
|
+
def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
|
60
58
|
|
61
|
-
#
|
62
|
-
|
59
|
+
# Allow caller to override the closed data and model objects.
|
60
|
+
data_f = kwargs_f.pop("data", data_closed)
|
61
|
+
model_f = kwargs_f.pop("model", model_closed)
|
62
|
+
|
63
|
+
# Update the state and time stored inside data.
|
64
|
+
with data_f.editable(validate=True) as data_rw:
|
65
|
+
data_rw.state = x
|
66
|
+
data_rw.time_ns = jnp.array(t * 1e9).astype(data_rw.time_ns.dtype)
|
67
|
+
|
68
|
+
# Evaluate the system dynamics, allowing to override the kwargs originally
|
69
|
+
# passed when the closure was created.
|
70
|
+
return system_dynamics(
|
71
|
+
model=model_f,
|
72
|
+
data=data_rw,
|
73
|
+
**(kwargs_closed | kwargs_f),
|
74
|
+
)
|
63
75
|
|
64
76
|
f: jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]
|
65
77
|
return f
|
@@ -72,11 +84,11 @@ def wrap_system_dynamics_for_integration(
|
|
72
84
|
|
73
85
|
@jax.jit
|
74
86
|
def system_velocity_dynamics(
|
75
|
-
model:
|
76
|
-
data:
|
87
|
+
model: js.model.JaxSimModel,
|
88
|
+
data: js.data.JaxSimModelData,
|
77
89
|
*,
|
78
90
|
joint_forces: jtp.Vector | None = None,
|
79
|
-
|
91
|
+
link_forces: jtp.Vector | None = None,
|
80
92
|
) -> tuple[jtp.Vector, jtp.Vector, jtp.Matrix, dict[str, Any]]:
|
81
93
|
"""
|
82
94
|
Compute the dynamics of the system velocity.
|
@@ -85,13 +97,13 @@ def system_velocity_dynamics(
|
|
85
97
|
model: The model to consider.
|
86
98
|
data: The data of the considered model.
|
87
99
|
joint_forces: The joint forces to apply.
|
88
|
-
|
100
|
+
link_forces: The 6D forces to apply to the links.
|
89
101
|
|
90
102
|
Returns:
|
91
103
|
A tuple containing the derivative of the base 6D velocity in inertial-fixed
|
92
104
|
representation, the derivative of the joint velocities, the derivative of
|
93
105
|
the material deformation, and the dictionary of auxiliary data returned by
|
94
|
-
the system dynamics
|
106
|
+
the system dynamics evaluation.
|
95
107
|
"""
|
96
108
|
|
97
109
|
# Build joint torques if not provided
|
@@ -101,10 +113,10 @@ def system_velocity_dynamics(
|
|
101
113
|
else jnp.zeros_like(data.joint_positions())
|
102
114
|
).astype(float)
|
103
115
|
|
104
|
-
# Build
|
105
|
-
|
106
|
-
jnp.atleast_2d(
|
107
|
-
if
|
116
|
+
# Build link forces if not provided
|
117
|
+
W_f_L = (
|
118
|
+
jnp.atleast_2d(link_forces.squeeze())
|
119
|
+
if link_forces is not None
|
108
120
|
else jnp.zeros((model.number_of_links(), 6))
|
109
121
|
).astype(float)
|
110
122
|
|
@@ -114,33 +126,36 @@ def system_velocity_dynamics(
|
|
114
126
|
|
115
127
|
# Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact
|
116
128
|
# with the terrain.
|
117
|
-
W_f_Li_terrain = jnp.zeros_like(
|
129
|
+
W_f_Li_terrain = jnp.zeros_like(W_f_L).astype(float)
|
118
130
|
|
119
|
-
# Initialize the 6D contact forces W_f ∈ ℝ^{n_c ×
|
131
|
+
# Initialize the 6D contact forces W_f ∈ ℝ^{n_c × 6} applied to collidable points,
|
120
132
|
# expressed in the world frame.
|
121
133
|
W_f_Ci = None
|
122
134
|
|
123
135
|
# Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
|
124
136
|
ṁ = jnp.zeros_like(data.state.soft_contacts.tangential_deformation).astype(float)
|
125
137
|
|
126
|
-
if model.
|
127
|
-
# Compute the
|
128
|
-
#
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
138
|
+
if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
|
139
|
+
# Compute the 6D forces applied to each collidable point and the
|
140
|
+
# corresponding material deformation rates.
|
141
|
+
with data.switch_velocity_representation(VelRepr.Inertial):
|
142
|
+
W_f_Ci, ṁ = js.contact.collidable_point_dynamics(model=model, data=data)
|
143
|
+
|
144
|
+
# Construct the vector defining the parent link index of each collidable point.
|
145
|
+
# We use this vector to sum the 6D forces of all collidable points rigidly
|
146
|
+
# attached to the same link.
|
147
|
+
parent_link_index_of_collidable_points = jnp.array(
|
148
|
+
model.kin_dyn_parameters.contact_parameters.body, dtype=int
|
149
|
+
)
|
137
150
|
|
138
151
|
# Sum the forces of all collidable points rigidly attached to a body.
|
139
152
|
# Since the contact forces W_f_Ci are expressed in the world frame,
|
140
153
|
# we don't need any coordinate transformation.
|
141
154
|
W_f_Li_terrain = jax.vmap(
|
142
155
|
lambda nc: (
|
143
|
-
jnp.vstack(
|
156
|
+
jnp.vstack(
|
157
|
+
jnp.equal(parent_link_index_of_collidable_points, nc).astype(int)
|
158
|
+
)
|
144
159
|
* W_f_Ci
|
145
160
|
).sum(axis=0)
|
146
161
|
)(jnp.arange(model.number_of_links()))
|
@@ -160,8 +175,12 @@ def system_velocity_dynamics(
|
|
160
175
|
|
161
176
|
if model.dofs() > 0:
|
162
177
|
# Static and viscous joint friction parameters
|
163
|
-
kc = jnp.array(
|
164
|
-
|
178
|
+
kc = jnp.array(
|
179
|
+
model.kin_dyn_parameters.joint_parameters.friction_static
|
180
|
+
).astype(float)
|
181
|
+
kv = jnp.array(
|
182
|
+
model.kin_dyn_parameters.joint_parameters.friction_viscous
|
183
|
+
).astype(float)
|
165
184
|
|
166
185
|
# Compute the joint friction torque
|
167
186
|
τ_friction = -(
|
@@ -177,24 +196,24 @@ def system_velocity_dynamics(
|
|
177
196
|
τ_total = τ + τ_friction + τ_position_limit
|
178
197
|
|
179
198
|
# Compute the total external 6D forces applied to the links
|
180
|
-
W_f_L_total =
|
199
|
+
W_f_L_total = W_f_L + W_f_Li_terrain
|
181
200
|
|
182
201
|
# - Joint accelerations: s̈ ∈ ℝⁿ
|
183
202
|
# - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶
|
184
203
|
with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
|
185
|
-
W_v̇_WB, s̈ =
|
204
|
+
W_v̇_WB, s̈ = js.model.forward_dynamics_aba(
|
186
205
|
model=model,
|
187
206
|
data=data,
|
188
207
|
joint_forces=τ_total,
|
189
|
-
|
208
|
+
link_forces=W_f_L_total,
|
190
209
|
)
|
191
210
|
|
192
|
-
return W_v̇_WB, s̈, m
|
211
|
+
return W_v̇_WB, s̈, ṁ, dict()
|
193
212
|
|
194
213
|
|
195
214
|
@jax.jit
|
196
215
|
def system_position_dynamics(
|
197
|
-
model:
|
216
|
+
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
|
198
217
|
) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]:
|
199
218
|
"""
|
200
219
|
Compute the dynamics of the system position.
|
@@ -208,8 +227,8 @@ def system_position_dynamics(
|
|
208
227
|
base quaternion, and the derivative of the joint positions.
|
209
228
|
"""
|
210
229
|
|
211
|
-
ṡ = data.
|
212
|
-
W_Q_B = data.
|
230
|
+
ṡ = data.joint_velocities(model=model)
|
231
|
+
W_Q_B = data.base_orientation(dcm=False)
|
213
232
|
|
214
233
|
with data.switch_velocity_representation(velocity_representation=VelRepr.Mixed):
|
215
234
|
W_ṗ_B = data.base_velocity()[0:3]
|
@@ -228,11 +247,11 @@ def system_position_dynamics(
|
|
228
247
|
|
229
248
|
@jax.jit
|
230
249
|
def system_dynamics(
|
231
|
-
model:
|
232
|
-
data:
|
250
|
+
model: js.model.JaxSimModel,
|
251
|
+
data: js.data.JaxSimModelData,
|
233
252
|
*,
|
234
253
|
joint_forces: jtp.Vector | None = None,
|
235
|
-
|
254
|
+
link_forces: jtp.Vector | None = None,
|
236
255
|
) -> tuple[ODEState, dict[str, Any]]:
|
237
256
|
"""
|
238
257
|
Compute the dynamics of the system.
|
@@ -241,7 +260,7 @@ def system_dynamics(
|
|
241
260
|
model: The model to consider.
|
242
261
|
data: The data of the considered model.
|
243
262
|
joint_forces: The joint forces to apply.
|
244
|
-
|
263
|
+
link_forces: The 6D forces to apply to the links.
|
245
264
|
|
246
265
|
Returns:
|
247
266
|
A tuple with an `ODEState` object storing in each of its attributes the
|
@@ -254,7 +273,7 @@ def system_dynamics(
|
|
254
273
|
model=model,
|
255
274
|
data=data,
|
256
275
|
joint_forces=joint_forces,
|
257
|
-
|
276
|
+
link_forces=link_forces,
|
258
277
|
)
|
259
278
|
|
260
279
|
# Extract the velocities.
|
@@ -263,18 +282,15 @@ def system_dynamics(
|
|
263
282
|
# Create an ODEState object populated with the derivative of each leaf.
|
264
283
|
# Our integrators, operating on generic pytrees, will be able to handle it
|
265
284
|
# automatically as state derivative.
|
266
|
-
ode_state_derivative = ODEState.
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
soft_contacts_state=SoftContactsState.build(
|
276
|
-
tangential_deformation=ṁ,
|
277
|
-
),
|
285
|
+
ode_state_derivative = ODEState.build_from_jaxsim_model(
|
286
|
+
model=model,
|
287
|
+
base_position=W_ṗ_B,
|
288
|
+
base_quaternion=W_Q̇_B,
|
289
|
+
joint_positions=ṡ,
|
290
|
+
base_linear_velocity=W_v̇_WB[0:3],
|
291
|
+
base_angular_velocity=W_v̇_WB[3:6],
|
292
|
+
joint_velocities=s̈,
|
293
|
+
tangential_deformation=ṁ,
|
278
294
|
)
|
279
295
|
|
280
296
|
return ode_state_derivative, aux_dict
|