jaxsim 0.1.dev401__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +5 -6
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -0
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +216 -0
- jaxsim/api/contact.py +271 -0
- jaxsim/api/data.py +821 -0
- jaxsim/api/joint.py +189 -0
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +361 -0
- jaxsim/api/model.py +1633 -0
- jaxsim/api/ode.py +295 -0
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +421 -0
- jaxsim/integrators/__init__.py +2 -0
- jaxsim/integrators/common.py +594 -0
- jaxsim/integrators/fixed_step.py +102 -0
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +92 -0
- jaxsim/mujoco/__init__.py +3 -0
- jaxsim/mujoco/__main__.py +192 -0
- jaxsim/mujoco/loaders.py +615 -0
- jaxsim/mujoco/model.py +414 -0
- jaxsim/mujoco/visualizer.py +176 -0
- jaxsim/parsers/descriptions/collision.py +14 -0
- jaxsim/parsers/descriptions/link.py +13 -2
- jaxsim/parsers/kinematic_graph.py +8 -3
- jaxsim/parsers/rod/parser.py +54 -38
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/{physics/algos → terrain}/terrain.py +4 -6
- jaxsim/typing.py +30 -30
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -31
- {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
- jaxsim-0.2.0.dist-info/METADATA +237 -0
- jaxsim-0.2.0.dist-info/RECORD +64 -0
- {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1695
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -101
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -256
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -454
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -358
- jaxsim/physics/model/physics_model_state.py +0 -174
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -452
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -53
- jaxsim/simulation/ode_integration.py +0 -125
- jaxsim/simulation/simulator.py +0 -544
- jaxsim/simulation/simulator_callbacks.py +0 -53
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -532
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.1.dev401.dist-info/METADATA +0 -167
- jaxsim-0.1.dev401.dist-info/RECORD +0 -64
- {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,610 @@
|
|
1
|
+
import functools
|
2
|
+
from typing import Any, ClassVar, Generic, Type
|
3
|
+
|
4
|
+
try:
|
5
|
+
from typing import Self
|
6
|
+
except ImportError:
|
7
|
+
from typing_extensions import Self
|
8
|
+
|
9
|
+
import jax
|
10
|
+
import jax.flatten_util
|
11
|
+
import jax.numpy as jnp
|
12
|
+
import jax_dataclasses
|
13
|
+
from jax_dataclasses import Static
|
14
|
+
|
15
|
+
from jaxsim import typing as jtp
|
16
|
+
from jaxsim.utils import Mutability
|
17
|
+
|
18
|
+
from .common import (
|
19
|
+
ExplicitRungeKutta,
|
20
|
+
ExplicitRungeKuttaSO3Mixin,
|
21
|
+
NextState,
|
22
|
+
PyTreeType,
|
23
|
+
State,
|
24
|
+
StateDerivative,
|
25
|
+
SystemDynamics,
|
26
|
+
Time,
|
27
|
+
TimeStep,
|
28
|
+
)
|
29
|
+
|
30
|
+
# Default tolerances used to scale the states when estimating the local
# integration error. For robot dynamics they are already fairly accurate:
# lowering them costs a smaller Δt, raising them costs less accurate dynamics.
RTOL_DEFAULT = 1e-4  # 0.01%
ATOL_DEFAULT = 1e-5  # 10μ

# Default parameters of the adaptive-Δt logic of Embedded RK schemes.
SAFETY_DEFAULT = 0.9
BETA_MIN_DEFAULT = 0.1
BETA_MAX_DEFAULT = 2.5
MAX_STEP_REJECTIONS_DEFAULT = 5
|
41
|
+
|
42
|
+
|
43
|
+
# =================
|
44
|
+
# Utility functions
|
45
|
+
# =================
|
46
|
+
|
47
|
+
|
48
|
+
@functools.partial(jax.jit, static_argnames=["f"])
def estimate_step_size(
    x0: jtp.PyTree,
    t0: Time,
    f: SystemDynamics,
    order: jtp.IntLike,
    rtol: jtp.FloatLike = RTOL_DEFAULT,
    atol: jtp.FloatLike = ATOL_DEFAULT,
) -> tuple[jtp.Float, jtp.PyTree]:
    r"""
    Compute the initial step size to warm-start variable-step integrators.

    Args:
        x0: The initial state.
        t0: The initial time.
        f: The state derivative function :math:`f(x, t)`.
        order:
            The order :math:`p` of an integrator with truncation error
            :math:`\mathcal{O}(\Delta t^{p+1})`.
        rtol: The relative tolerance to scale the state.
        atol: The absolute tolerance to scale the state.

    Returns:
        A tuple containing the computed initial step size
        and the state derivative :math:`\dot{x} = f(x_0, t_0)`.

    Note:
        Interested readers could find implementation details in:

        Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
        E. Hairer, S. P. Nørsett, G. Wanner.

        The divisions whose result is discarded by ``jnp.where`` are guarded with
        the "double where" pattern, so that zero denominators never produce
        inf/NaN intermediates (which would otherwise leak NaN into gradients).
    """

    # Helper to flatten a pytree to a 1D vector.
    def flatten(pytree) -> jax.Array:
        return jax.flatten_util.ravel_pytree(pytree=pytree)[0]

    # Compute the state derivative at the initial state.
    ẋ0 = f(x0, t0)[0]

    # Compute the scaling factors of the initial state and its derivative.
    compute_scale = lambda x: atol + jnp.abs(x) * rtol
    scale0 = jax.tree_util.tree_map(compute_scale, x0)
    scale1 = jax.tree_util.tree_map(compute_scale, ẋ0)

    # Scale the initial state and its derivative.
    scale_pytree = lambda x, scale: jnp.abs(x) / scale
    x0_scaled = jax.tree_util.tree_map(scale_pytree, x0, scale0)
    ẋ0_scaled = jax.tree_util.tree_map(scale_pytree, ẋ0, scale1)

    # Get the maximum of the scaled pytrees.
    d0 = jnp.linalg.norm(flatten(x0_scaled), ord=jnp.inf)
    d1 = jnp.linalg.norm(flatten(ẋ0_scaled), ord=jnp.inf)

    # Compute the first guess of the initial step size.
    # When the fallback branch is selected, d1 may be zero: replace it with a
    # harmless value so that the unused branch stays finite.
    use_h0_fallback = jnp.minimum(d0, d1) <= 1e-5
    d1_safe = jnp.where(use_h0_fallback, 1.0, d1)
    h0 = jnp.where(use_h0_fallback, 1e-6, 0.01 * d0 / d1_safe)

    # Compute the next state (explicit Euler step) and its derivative.
    x1 = jax.tree_util.tree_map(lambda x0, ẋ0: x0 + h0 * ẋ0, x0, ẋ0)
    ẋ1 = f(x1, t0 + h0)[0]

    # Compute the scaling factor of the state derivatives.
    compute_scale_2 = lambda ẋ0, ẋ1: atol + jnp.maximum(jnp.abs(ẋ0), jnp.abs(ẋ1)) * rtol
    scale2 = jax.tree_util.tree_map(compute_scale_2, ẋ0, ẋ1)

    # Scale the difference of the state derivatives.
    scale_ẋ_difference = lambda ẋ0, ẋ1, scale: jnp.abs((ẋ0 - ẋ1) / scale)
    ẋ_difference_scaled = jax.tree_util.tree_map(scale_ẋ_difference, ẋ0, ẋ1, scale2)

    # Get the maximum of the scaled derivatives difference.
    # h0 is strictly positive here (either 1e-6 or a ratio with d0, d1 > 1e-5).
    d2 = jnp.linalg.norm(flatten(ẋ_difference_scaled), ord=jnp.inf) / h0

    # Compute the second guess of the initial step size.
    # Same safe-division pattern: max(d1, d2) may be zero in the fallback branch.
    d_max = jnp.maximum(d1, d2)
    use_h1_fallback = d_max <= 1e-15
    d_max_safe = jnp.where(use_h1_fallback, 1.0, d_max)
    h1 = jnp.where(
        use_h1_fallback,
        jnp.maximum(1e-6, h0 * 1e-3),
        (0.01 / d_max_safe) ** (1.0 / (order + 1.0)),
    )

    # Propose the final guess of the initial step size.
    # Also return the state derivative computed at the initial state since
    # likely it is a quantity that needs to be computed again later.
    return jnp.array(jnp.minimum(100.0 * h0, h1), dtype=float), ẋ0
|
131
|
+
|
132
|
+
|
133
|
+
@jax.jit
def compute_pytree_scale(
    x1: jtp.PyTree,
    x2: jtp.PyTree | None = None,
    rtol: jtp.FloatLike = RTOL_DEFAULT,
    atol: jtp.FloatLike = ATOL_DEFAULT,
) -> jtp.PyTree:
    """
    Build per-leaf scale factors ``atol + rtol * max(|x1|, |x2|)`` for dynamical states.

    Args:
        x1: The first state (often the initial state).
        x2:
            The optional second state (often the final state). When omitted,
            a zero pytree with the structure of ``x1`` is used in its place.
        rtol: The relative tolerance to scale the state.
        atol: The absolute tolerance to scale the state.

    Returns:
        A pytree with the same structure of the state containing the scaling factors.
    """

    if x2 is None:
        # Without a second state, only the magnitude of x1 contributes.
        x2 = jax.tree_util.tree_map(jnp.zeros_like, x1)

    def leaf_scale(leaf1, leaf2):
        magnitude = jnp.maximum(jnp.abs(leaf1), jnp.abs(leaf2))
        return atol + magnitude * rtol

    return jax.tree_util.tree_map(leaf_scale, x1, x2)
|
161
|
+
|
162
|
+
|
163
|
+
@functools.partial(jax.jit, static_argnames=["norm_ord"])
def local_error_estimation(
    xf: jtp.PyTree,
    xf_estimate: jtp.PyTree | None = None,
    x0: jtp.PyTree | None = None,
    rtol: jtp.FloatLike = RTOL_DEFAULT,
    atol: jtp.FloatLike = ATOL_DEFAULT,
    norm_ord: jtp.IntLike | jtp.FloatLike = jnp.inf,
) -> jtp.Float:
    """
    Estimate the local integration error, often used in Embedded RK schemes.

    Args:
        xf: The final state, often computed with the most accurate integrator.
        xf_estimate:
            The estimated final state, often computed with the less accurate integrator.
            If missing, it is initialized to zero.
        x0:
            The initial state to compute the scaling factors. If missing, it is
            initialized to zero.
        rtol: The relative tolerance to scale the state.
        atol: The absolute tolerance to scale the state.
        norm_ord:
            The norm to use to compute the error. Default is the infinity norm.
            It must be a hashable Python number (it is a static jit argument).

    Returns:
        The estimated local integration error.
    """

    # Note: `norm_ord` is static because `jnp.linalg.norm` selects the norm with
    # Python-level comparisons on `ord`, which fail on traced values.

    # Helper to flatten a pytree to a 1D vector.
    def flatten(pytree) -> jax.Array:
        return jax.flatten_util.ravel_pytree(pytree=pytree)[0]

    # Compute the scale considering the initial and final states.
    scale = compute_pytree_scale(x1=xf, x2=x0, rtol=rtol, atol=atol)

    # Consider a zero estimated final state, if not given.
    xf_estimate = (
        jax.tree_util.tree_map(lambda l: jnp.zeros_like(l), xf)
        if xf_estimate is None
        else xf_estimate
    )

    # Estimate the component-wise scaled error.
    estimate_error = lambda l, l̂, sc: jnp.abs(l - l̂) / sc
    error_estimate = jax.tree_util.tree_map(estimate_error, xf, xf_estimate, scale)

    # Reduce the error estimate with the selected norm
    # (the largest component for the default infinity norm).
    return jnp.linalg.norm(flatten(error_estimate), ord=norm_ord)
|
212
|
+
|
213
|
+
|
214
|
+
# ================================
|
215
|
+
# Embedded Runge-Kutta integrators
|
216
|
+
# ================================
|
217
|
+
|
218
|
+
|
219
|
+
@jax_dataclasses.pytree_dataclass
class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
    """
    Explicit Runge-Kutta integrator with an embedded error estimate and adaptive Δt.

    Subclasses define a Butcher tableau whose ``b.T`` has a row producing the
    solution (``row_index_of_solution``) and a row producing a solution estimate
    (``row_index_of_solution_estimate``). Calling the integrator advances the
    state from ``t0`` to ``t0 + dt`` through one or more internal steps, whose
    size is adapted from the scaled difference between the two rows.
    """

    # Define the row of the integration output corresponding to the solution estimate.
    # This is the row of b.T that produces the state used e.g. by embedded methods to
    # implement the adaptive timestep logic.
    row_index_of_solution_estimate: ClassVar[int | None] = None

    # Bounds of the adaptive Δt.
    dt_max: Static[jtp.FloatLike] = jnp.inf
    dt_min: Static[jtp.FloatLike] = -jnp.inf

    # Tolerances used to scale the two states corresponding to the high-order solution
    # and the low-order estimate during the computation of the local integration error.
    rtol: Static[jtp.FloatLike] = RTOL_DEFAULT
    atol: Static[jtp.FloatLike] = ATOL_DEFAULT

    # Parameters of the adaptive timestep logic.
    # Refer to Eq. (4.13) pag. 168 of Hairer93.
    safety: Static[jtp.FloatLike] = SAFETY_DEFAULT
    beta_max: Static[jtp.FloatLike] = BETA_MAX_DEFAULT
    beta_min: Static[jtp.FloatLike] = BETA_MIN_DEFAULT

    # Maximum number of rejected steps when the Δt needs to be reduced.
    max_step_rejections: Static[jtp.IntLike] = MAX_STEP_REJECTIONS_DEFAULT

    def init(
        self,
        x0: State,
        t0: Time,
        dt: TimeStep | None = None,
        *,
        include_dynamics_aux_dict: bool = False,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Initialize the dictionary of integrator parameters.

        The ``dt`` argument is ignored: a fixed dummy Δt is forwarded to the
        parent ``init`` only to run ``__call__`` once.
        """

        # In these type of integrators, it's not relevant picking a meaningful dt.
        # We just need to execute __call__ once to initialize the dictionary of params.
        return super().init(
            x0=x0,
            t0=t0,
            dt=0.001,
            include_dynamics_aux_dict=include_dynamics_aux_dict,
            **kwargs,
        )

    def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
        """
        Advance the state from ``t0`` to ``t0 + dt`` with adaptive internal steps.

        Args:
            x0: The initial state.
            t0: The initial time.
            dt:
                The length of the integration interval. The internal step size is
                adapted independently and does not need to divide ``dt``.
            **kwargs:
                Additional keyword arguments forwarded to the dynamics and to
                ``_compute_next_state``.

        Returns:
            The solution row of the last accepted internal step, i.e. the state at
            ``t0 + dt``. The updated parameters are stored in ``self.params``.
        """

        # This method is called differently in three stages:
        #
        # 1. During initialization, to allocate a dummy params dictionary.
        # 2. During the first step, to compute the initial valid params dictionary.
        # 3. After the first step, to compute the next state and the next valid params.
        #
        # Stage 1 produces a zero-filled dummy dictionary.
        # Stage 2 receives a dummy dictionary and produces valid parameters that can be
        # fed to later steps.
        # Stage 3 corresponds to any consecutive step after the first one. It can re-use
        # data (like for FSAL) from previous steps.
        #
        integrator_init = self.params.get(self.InitializingKey, jnp.array(False))
        integrator_first_step = self.params.get(self.AfterInitKey, jnp.array(False))

        # Close f over optional kwargs.
        f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)

        # Define the final time.
        tf = t0 + dt

        # Initialize solution orders.
        p = self.order_of_solution
        p̂ = self.order_of_solution_estimate
        q = jnp.minimum(p, p̂)

        # In Stage 1 and 2, estimate from scratch dt0 and dxdt0.
        # In Stage 3, dt0 is taken from the previous step. If the integrator supports
        # FSAL, dxdt0 is taken from the previous step. Otherwise, it is computed by
        # evaluating the dynamics.
        self.params["dt0"], self.params["dxdt0"] = jax.lax.cond(
            pred=jnp.logical_or("dt0" not in self.params, integrator_first_step),
            true_fun=lambda params: estimate_step_size(
                x0=x0, t0=t0, f=f, order=p, atol=self.atol, rtol=self.rtol
            ),
            false_fun=lambda params: (
                params.get("dt0", jnp.array(0).astype(float)),
                # Read from the cond operand (like dt0), and only trace the
                # dynamics when no previous derivative is available.
                params["dxdt0"] if "dxdt0" in params else f(x0, t0)[0],
            ),
            operand=self.params,
        )

        # If the integrator does not support FSAL, it is useless to store dxdt0.
        if not self.has_fsal:
            _ = self.params.pop("dxdt0")

        # Clip the estimated initial step size, if necessary.
        # NOTE(review): both bounds are min(·, dt0), so this only enforces dt_max;
        # an estimate smaller than dt_min is kept as-is. The accept/reject branches
        # below clip the updated Δt to [dt_min, dt_max]. Confirm this is intended.
        self.params["dt0"] = jnp.clip(
            a=self.params["dt0"],
            a_min=jnp.minimum(self.dt_min, self.params["dt0"]),
            a_max=jnp.minimum(self.dt_max, self.params["dt0"]),
        )

        # =========================================================
        # While loop to reach tf from t0 using an adaptive timestep
        # =========================================================

        # Initialize the carry of the while loop:
        # (state, time, params, consecutive rejected steps, break flag).
        Carry = tuple[Any, ...]
        carry0: Carry = (
            x0,
            jnp.array(t0).astype(float),
            self.params,
            jnp.array(0, dtype=int),
            jnp.array(False).astype(bool),
        )

        def while_loop_cond(carry: Carry) -> jtp.Bool:
            _, _, _, _, break_loop = carry
            return jnp.logical_not(break_loop)

        # Each loop is an integration step with variable Δt.
        # Depending on the integration error, the step could be discarded and the
        # while body ran again from the same (x0, t0) but with a smaller Δt.
        # We run these loops until the final time tf is reached.
        def while_loop_body(carry: Carry) -> Carry:

            # Unpack the carry.
            x0, t0, params, discarded_steps, _ = carry

            # Take care of the final adaptive step.
            # If the proposed Δt reaches (or overshoots) tf, shrink it to land on tf
            # and flag this as the last step. The flag is computed from the proposed
            # Δt rather than from t0 + (tf - t0), whose rounding could leave the loop
            # one ulp short of tf. The `not (... < tf)` form also makes a NaN Δt
            # terminate the loop.
            break_loop = jnp.logical_not(t0 + params["dt0"] < tf)
            Δt0 = jnp.where(break_loop, tf - t0, params["dt0"])

            # Run the underlying explicit RK integrator.
            # The output z contains multiple solutions (depending on the rows of b.T).
            with self.editable(validate=True) as integrator:
                integrator.params = params
                z = integrator._compute_next_state(x0=x0, t0=t0, dt=Δt0, **kwargs)
                params_next = integrator.params

            # Extract the high-order solution xf and the low-order estimate x̂f.
            xf = jax.tree_util.tree_map(lambda l: l[self.row_index_of_solution], z)
            x̂f = jax.tree_util.tree_map(
                lambda l: l[self.row_index_of_solution_estimate], z
            )

            # Calculate the local integration error.
            local_error = local_error_estimation(
                x0=x0, xf=xf, xf_estimate=x̂f, rtol=self.rtol, atol=self.atol
            )

            # Shrink the Δt every time by the safety factor (even when accepted).
            # The β parameters define the bounds of the timestep update factor.
            safety = jnp.clip(self.safety, a_min=0.0, a_max=1.0)
            β_min = jnp.maximum(0.0, self.beta_min)
            β_max = jnp.maximum(β_min, self.beta_max)

            # Compute the next Δt from the desired integration error.
            # The computed integration step is accepted if error <= 1.0,
            # otherwise it is rejected.
            #
            # In case of rejection, Δt_next is always smaller than Δt0.
            # In case of acceptance, Δt_next could either be larger than Δt0,
            # or slightly smaller than Δt0 depending on the safety factor.
            Δt_next = Δt0 * jnp.clip(
                a=safety * jnp.power(1 / local_error, 1 / (q + 1)),
                a_min=β_min,
                a_max=β_max,
            )

            def accept_step():
                # Use Δt_next in the next while loop.
                # If it is the last one, and Δt0 was clipped, return the initial Δt0.
                params_next_accepted = params_next | dict(
                    dt0=jnp.clip(
                        jax.lax.select(
                            pred=break_loop,
                            on_true=params["dt0"],
                            on_false=Δt_next,
                        ),
                        self.dt_min,
                        self.dt_max,
                    )
                )

                # Start the next while loop from the final state.
                x0_next = xf

                # Advance the starting time of the next adaptive step,
                # landing exactly on tf after the last step.
                t0_next = jnp.where(break_loop, tf, t0 + Δt0)

                return (
                    x0_next,
                    t0_next,
                    break_loop,
                    params_next_accepted,
                    jnp.array(0, dtype=int),
                )

            def reject_step():
                # Retry from the same (x0, t0) with the original params, this time
                # with a reduced Δt. A new dict is built instead of mutating
                # `params`: both cond branches are traced, and accept_step also
                # reads params["dt0"].
                params_next_rejected = params | dict(
                    dt0=jnp.clip(Δt_next, self.dt_min, self.dt_max)
                )

                return (
                    x0,
                    t0,
                    False,
                    params_next_rejected,
                    discarded_steps + 1,
                )

            # Decide whether to accept or reject the step.
            (
                x0_next,
                t0_next,
                break_loop,
                params_next,
                discarded_steps,
            ) = jax.lax.cond(
                pred=jnp.array(
                    [
                        discarded_steps >= self.max_step_rejections,
                        local_error <= 1.0,
                        Δt_next < self.dt_min,
                        integrator_init,
                    ]
                ).any(),
                true_fun=accept_step,
                false_fun=reject_step,
            )

            return (
                x0_next,
                t0_next,
                params_next,
                discarded_steps,
                break_loop,
            )

        # Integrate with adaptive step until tf is reached.
        (
            xf,
            _,
            params_tf,
            _,
            _,
        ) = jax.lax.while_loop(
            cond_fun=while_loop_cond,
            body_fun=while_loop_body,
            init_val=carry0,
        )

        # Store the parameters.
        # They will be returned to the caller in a functional way in the step method.
        with self.mutable_context(mutability=Mutability.MUTABLE):
            self.params = params_tf

        return xf

    @property
    def order_of_solution(self) -> int:
        """The order of the b.T row producing the solution."""
        return self.order_of_bT_rows[self.row_index_of_solution]

    @property
    def order_of_solution_estimate(self) -> int:
        """The order of the b.T row producing the solution estimate."""
        return self.order_of_bT_rows[self.row_index_of_solution_estimate]

    @classmethod
    def build(
        cls: Type[Self],
        *,
        dynamics: SystemDynamics[State, StateDerivative],
        fsal_enabled_if_supported: jtp.BoolLike = True,
        dt_max: jtp.FloatLike = jnp.inf,
        dt_min: jtp.FloatLike = -jnp.inf,
        rtol: jtp.FloatLike = RTOL_DEFAULT,
        atol: jtp.FloatLike = ATOL_DEFAULT,
        safety: jtp.FloatLike = SAFETY_DEFAULT,
        beta_max: jtp.FloatLike = BETA_MAX_DEFAULT,
        beta_min: jtp.FloatLike = BETA_MIN_DEFAULT,
        max_step_rejections: jtp.IntLike = MAX_STEP_REJECTIONS_DEFAULT,
        **kwargs,
    ) -> Self:
        """
        Build the integrator, converting the adaptive-Δt parameters to Python types.

        Raises:
            ValueError: If ``row_index_of_solution_estimate`` is not a valid row of ``b.T``.
        """

        # Check that b.T has enough rows based on the configured index of the
        # solution estimate. This is necessary for embedded methods.
        if (
            cls.row_index_of_solution_estimate is not None
            and cls.row_index_of_solution_estimate >= cls.b.T.shape[0]
        ):
            msg = "The index of the solution estimate ({}-th row of `b.T`) "
            msg += "is out of range ({})."
            raise ValueError(
                msg.format(cls.row_index_of_solution_estimate, cls.b.T.shape[0])
            )

        integrator = super().build(
            # Integrator:
            dynamics=dynamics,
            # ExplicitRungeKutta:
            fsal_enabled_if_supported=bool(fsal_enabled_if_supported),
            # EmbeddedRungeKutta:
            dt_max=float(dt_max),
            dt_min=float(dt_min),
            rtol=float(rtol),
            atol=float(atol),
            safety=float(safety),
            beta_max=float(beta_max),
            beta_min=float(beta_min),
            max_step_rejections=int(max_step_rejections),
            **kwargs,
        )

        return integrator
|
543
|
+
|
544
|
+
|
545
|
+
@jax_dataclasses.pytree_dataclass
class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
    """
    Heun-Euler embedded pair combined with ``ExplicitRungeKuttaSO3Mixin``.

    The solution row (order 2) uses Heun's weights; the estimate row
    (order 1) corresponds to an explicit Euler step.
    """

    # Runge-Kutta matrix.
    A: ClassVar[jax.typing.ArrayLike] = jnp.array(
        [
            [0, 0],
            [1, 0],
        ]
    ).astype(float)

    # Weights stored column-wise: column 0 is the solution (Heun),
    # column 1 is the estimate (Euler).
    b: ClassVar[jax.typing.ArrayLike] = jnp.array(
        [
            [1 / 2, 1],
            [1 / 2, 0],
        ]
    ).astype(float)

    # Nodes.
    c: ClassVar[jax.typing.ArrayLike] = jnp.array([0, 1]).astype(float)

    row_index_of_solution: ClassVar[int] = 0
    row_index_of_solution_estimate: ClassVar[int | None] = 1

    order_of_bT_rows: ClassVar[tuple[int, ...]] = (2, 1)
|
576
|
+
|
577
|
+
|
578
|
+
@jax_dataclasses.pytree_dataclass
class BogackiShampineSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
    """
    Bogacki-Shampine embedded pair combined with ``ExplicitRungeKuttaSO3Mixin``.

    The solution row has order 3 and the estimate row has order 2. The last
    row of ``A`` equals the solution weights.
    """

    # Runge-Kutta matrix.
    A: ClassVar[jax.typing.ArrayLike] = jnp.array(
        [
            [0, 0, 0, 0],
            [1 / 2, 0, 0, 0],
            [0, 3 / 4, 0, 0],
            [2 / 9, 1 / 3, 4 / 9, 0],
        ]
    ).astype(float)

    # Weights stored column-wise: column 0 is the 3rd-order solution,
    # column 1 is the 2nd-order estimate.
    b: ClassVar[jax.typing.ArrayLike] = jnp.array(
        [
            [2 / 9, 7 / 24],
            [1 / 3, 1 / 4],
            [4 / 9, 1 / 3],
            [0, 1 / 8],
        ]
    ).astype(float)

    # Nodes.
    c: ClassVar[jax.typing.ArrayLike] = jnp.array([0, 1 / 2, 3 / 4, 1]).astype(float)

    row_index_of_solution: ClassVar[int] = 0
    row_index_of_solution_estimate: ClassVar[int | None] = 1

    order_of_bT_rows: ClassVar[tuple[int, ...]] = (3, 2)
|
jaxsim/math/__init__.py
CHANGED
@@ -0,0 +1,11 @@
|
|
1
|
+
# Default value of the standard gravity constant.
StandardGravity: float = 9.81
|
3
|
+
|
4
|
+
from .adjoint import Adjoint
|
5
|
+
from .cross import Cross
|
6
|
+
from .inertia import Inertia
|
7
|
+
from .joint_model import JointModel, supported_joint_motion
|
8
|
+
from .quaternion import Quaternion
|
9
|
+
from .rotation import Rotation
|
10
|
+
from .skew import Skew
|
11
|
+
from .transform import Transform
|
jaxsim/math/adjoint.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import jax.numpy as jnp
|
2
|
+
import jaxlie
|
2
3
|
|
3
4
|
import jaxsim.typing as jtp
|
4
|
-
from jaxsim.sixd import so3
|
5
5
|
|
6
6
|
from .quaternion import Quaternion
|
7
7
|
from .skew import Skew
|
@@ -31,13 +31,35 @@ class Adjoint:
|
|
31
31
|
assert quaternion.size == 4
|
32
32
|
assert translation.size == 3
|
33
33
|
|
34
|
-
Q_sixd =
|
34
|
+
Q_sixd = jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(quaternion))
|
35
35
|
Q_sixd = Q_sixd if not normalize_quaternion else Q_sixd.normalize()
|
36
36
|
|
37
37
|
return Adjoint.from_rotation_and_translation(
|
38
38
|
rotation=Q_sixd.as_matrix(), translation=translation, inverse=inverse
|
39
39
|
)
|
40
40
|
|
41
|
+
@staticmethod
def from_transform(transform: jtp.MatrixLike, inverse: bool = False) -> jtp.Matrix:
    """
    Create an adjoint matrix from a transformation matrix.

    Args:
        transform: A 4x4 transformation matrix.
        inverse: Whether to compute the inverse adjoint.

    Returns:
        The 6x6 adjoint matrix.
    """

    A_H_B = jnp.array(transform).astype(float)

    # Check the converted array: `transform` is array-like and may not expose
    # a `.shape` attribute itself (e.g. when passed as a nested list).
    assert A_H_B.shape == (4, 4)

    # Build the SE(3) element once and invert it only if requested.
    A_T_B = jaxlie.SE3.from_matrix(matrix=A_H_B)

    return A_T_B.adjoint() if not inverse else A_T_B.inverse().adjoint()
|
62
|
+
|
41
63
|
@staticmethod
|
42
64
|
def from_rotation_and_translation(
|
43
65
|
rotation: jtp.Matrix = jnp.eye(3),
|