jaxsim 0.1.dev401__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1.dev401.dist-info/METADATA +0 -167
  88. jaxsim-0.1.dev401.dist-info/RECORD +0 -64
  89. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,610 @@
1
+ import functools
2
+ from typing import Any, ClassVar, Generic, Type
3
+
4
+ try:
5
+ from typing import Self
6
+ except ImportError:
7
+ from typing_extensions import Self
8
+
9
+ import jax
10
+ import jax.flatten_util
11
+ import jax.numpy as jnp
12
+ import jax_dataclasses
13
+ from jax_dataclasses import Static
14
+
15
+ from jaxsim import typing as jtp
16
+ from jaxsim.utils import Mutability
17
+
18
+ from .common import (
19
+ ExplicitRungeKutta,
20
+ ExplicitRungeKuttaSO3Mixin,
21
+ NextState,
22
+ PyTreeType,
23
+ State,
24
+ StateDerivative,
25
+ SystemDynamics,
26
+ Time,
27
+ TimeStep,
28
+ )
29
+
30
# Default tolerances of the adaptive-step logic. For robot dynamics these values
# already yield accurate results: tightening them reduces the integration error
# at the cost of smaller Δt, while loosening them allows larger Δt at the cost
# of less accurate dynamics.
RTOL_DEFAULT = 1e-4  # relative tolerance, i.e. 0.01%
ATOL_DEFAULT = 1e-5  # absolute tolerance, i.e. 10μ

# Default parameters of the step-size controller of Embedded RK schemes.
SAFETY_DEFAULT = 0.9
BETA_MIN_DEFAULT = 1.0 / 10
BETA_MAX_DEFAULT = 2.5
MAX_STEP_REJECTIONS_DEFAULT = 5
41
+
42
+
43
+ # =================
44
+ # Utility functions
45
+ # =================
46
+
47
+
48
@functools.partial(jax.jit, static_argnames=["f"])
def estimate_step_size(
    x0: jtp.PyTree,
    t0: Time,
    f: SystemDynamics,
    order: jtp.IntLike,
    rtol: jtp.FloatLike = RTOL_DEFAULT,
    atol: jtp.FloatLike = ATOL_DEFAULT,
) -> tuple[jtp.Float, jtp.PyTree]:
    r"""
    Compute the initial step size to warm-start variable-step integrators.

    Args:
        x0: The initial state.
        t0: The initial time.
        f: The state derivative function :math:`f(x, t)`.
        order:
            The order :math:`p` of an integrator with truncation error
            :math:`\mathcal{O}(\Delta t^{p+1})`.
        rtol: The relative tolerance to scale the state.
        atol: The absolute tolerance to scale the state.

    Returns:
        A tuple containing the computed initial step size
        and the state derivative :math:`\dot{x} = f(x_0, t_0)`.

    Note:
        Interested readers could find implementation details in:

        Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
        E. Hairer, S. P. Norsett, G. Wanner.

        As in the reference, all the norms (of the state, of its derivative, and
        of the difference of two derivatives) are computed with the same
        component-wise scale :math:`sc_i = atol + |x_{0,i}| \, rtol` built from
        the initial state. Therefore, the state derivative returned by ``f`` must
        have the same pytree structure of the state.
    """

    # Helper to flatten a pytree to a 1D vector.
    def flatten(pytree) -> jax.Array:
        return jax.flatten_util.ravel_pytree(pytree=pytree)[0]

    # Compute the state derivative at the initial state.
    ẋ0 = f(x0, t0)[0]

    # Compute the scaling factors from the initial state only (Hairer Sec. II.4).
    scale = jax.tree_util.tree_map(lambda l: atol + jnp.abs(l) * rtol, x0)

    # Scaled infinity norm of a pytree having the same structure of the state.
    def scaled_norm(pytree) -> jax.Array:
        scaled = jax.tree_util.tree_map(
            lambda l, sc: jnp.abs(l) / sc, pytree, scale
        )
        return jnp.linalg.norm(flatten(scaled), ord=jnp.inf)

    # Get the scaled norms of the initial state and its derivative.
    d0 = scaled_norm(x0)
    d1 = scaled_norm(ẋ0)

    # Compute the first guess of the initial step size.
    # The denominator of the discarded branch is made safe, so that jnp.where
    # does not carry inf/NaN values that would poison the gradients.
    small_d0_d1 = jnp.minimum(d0, d1) <= 1e-5
    h0 = jnp.where(
        small_d0_d1,
        1e-6,
        0.01 * d0 / jnp.where(small_d0_d1, 1.0, d1),
    )

    # Compute the next state (explicit Euler step) and its derivative.
    x1 = jax.tree_util.tree_map(lambda x0, ẋ0: x0 + h0 * ẋ0, x0, ẋ0)
    ẋ1 = f(x1, t0 + h0)[0]

    # Estimate the norm of the second derivative of the solution.
    d2 = scaled_norm(jax.tree_util.tree_map(lambda a, b: a - b, ẋ0, ẋ1)) / h0

    # Compute the second guess of the initial step size (safe denominator as above).
    d_max = jnp.maximum(d1, d2)
    small_d_max = d_max <= 1e-15
    h1 = jnp.where(
        small_d_max,
        jnp.maximum(1e-6, h0 * 1e-3),
        (0.01 / jnp.where(small_d_max, 1.0, d_max)) ** (1.0 / (order + 1.0)),
    )

    # Propose the final guess of the initial step size.
    # Also return the state derivative computed at the initial state since
    # likely it is a quantity that needs to be computed again later.
    return jnp.array(jnp.minimum(100.0 * h0, h1), dtype=float), ẋ0
131
+
132
+
133
@jax.jit
def compute_pytree_scale(
    x1: jtp.PyTree,
    x2: jtp.PyTree | None = None,
    rtol: jtp.FloatLike = RTOL_DEFAULT,
    atol: jtp.FloatLike = ATOL_DEFAULT,
) -> jtp.PyTree:
    """
    Build the component-wise factors ``atol + max(|x1|, |x2|) * rtol``.

    Args:
        x1: The first state (often the initial state).
        x2:
            The optional second state (often the final state). When omitted,
            a zero state with the structure of ``x1`` is used instead.
        rtol: The relative tolerance to scale the state.
        atol: The absolute tolerance to scale the state.

    Returns:
        A pytree with the same structure of ``x1`` containing the scale factors.
    """

    # Fall back to an all-zero second state.
    if x2 is None:
        x2 = jax.tree_util.tree_map(jnp.zeros_like, x1)

    def leaf_scale(leaf1, leaf2):
        largest_magnitude = jnp.maximum(jnp.abs(leaf1), jnp.abs(leaf2))
        return atol + largest_magnitude * rtol

    return jax.tree_util.tree_map(leaf_scale, x1, x2)
161
+
162
+
163
@functools.partial(jax.jit, static_argnames=["norm_ord"])
def local_error_estimation(
    xf: jtp.PyTree,
    xf_estimate: jtp.PyTree | None = None,
    x0: jtp.PyTree | None = None,
    rtol: jtp.FloatLike = RTOL_DEFAULT,
    atol: jtp.FloatLike = ATOL_DEFAULT,
    norm_ord: jtp.IntLike | jtp.FloatLike = jnp.inf,
) -> jtp.Float:
    """
    Estimate the local integration error, often used in Embedded RK schemes.

    Args:
        xf: The final state, often computed with the most accurate integrator.
        xf_estimate:
            The estimated final state, often computed with the less accurate integrator.
            If missing, it is initialized to zero.
        x0:
            The initial state to compute the scaling factors. If missing, it is
            initialized to zero.
        rtol: The relative tolerance to scale the state.
        atol: The absolute tolerance to scale the state.
        norm_ord:
            The norm to use to compute the error. Default is the infinity norm.
            It is a static (hashable Python) argument because ``jnp.linalg.norm``
            requires a concrete ``ord``; passing a traced value would fail.

    Returns:
        The estimated local integration error, i.e. the ``norm_ord``-norm of the
        scaled component-wise difference between ``xf`` and ``xf_estimate``.
    """

    # Helper to flatten a pytree to a 1D vector.
    def flatten(pytree) -> jax.Array:
        return jax.flatten_util.ravel_pytree(pytree=pytree)[0]

    # Compute the scale considering the initial and final states.
    scale = compute_pytree_scale(x1=xf, x2=x0, rtol=rtol, atol=atol)

    # Consider a zero estimated final state, if not given.
    if xf_estimate is None:
        xf_estimate = jax.tree_util.tree_map(jnp.zeros_like, xf)

    # Estimate the scaled component-wise error.
    estimate_error = lambda l, l̂, sc: jnp.abs(l - l̂) / sc
    error_estimate = jax.tree_util.tree_map(estimate_error, xf, xf_estimate, scale)

    # Reduce the error to a scalar (the largest component with the default inf-norm).
    return jnp.linalg.norm(flatten(error_estimate), ord=norm_ord)
212
+
213
+
214
+ # ================================
215
+ # Embedded Runge-Kutta integrators
216
+ # ================================
217
+
218
+
219
@jax_dataclasses.pytree_dataclass
class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
    """
    Explicit Runge-Kutta integrator with an embedded error estimate and adaptive Δt.

    Subclasses define a Butcher tableau whose ``b.T`` has (at least) two rows:
    the row ``row_index_of_solution`` produces the high-order solution, and the
    row ``row_index_of_solution_estimate`` produces a lower-order estimate. The
    scaled difference of the two is the local integration error used to accept
    or reject each internal step and to adapt its Δt. A single call integrates
    from ``t0`` to ``t0 + dt`` with as many internal steps as needed.
    """

    # Define the row of the integration output corresponding to the solution estimate.
    # This is the row of b.T that produces the state used e.g. by embedded methods to
    # implement the adaptive timestep logic.
    row_index_of_solution_estimate: ClassVar[int | None] = None

    # Bounds of the adaptive Δt.
    dt_max: Static[jtp.FloatLike] = jnp.inf
    dt_min: Static[jtp.FloatLike] = -jnp.inf

    # Tolerances used to scale the two states corresponding to the high-order solution
    # and the low-order estimate during the computation of the local integration error.
    rtol: Static[jtp.FloatLike] = RTOL_DEFAULT
    atol: Static[jtp.FloatLike] = ATOL_DEFAULT

    # Parameters of the adaptive timestep logic.
    # Refer to Eq. (4.13) pag. 168 of Hairer93.
    safety: Static[jtp.FloatLike] = SAFETY_DEFAULT
    beta_max: Static[jtp.FloatLike] = BETA_MAX_DEFAULT
    beta_min: Static[jtp.FloatLike] = BETA_MIN_DEFAULT

    # Maximum number of rejected steps when the Δt needs to be reduced.
    max_step_rejections: Static[jtp.IntLike] = MAX_STEP_REJECTIONS_DEFAULT

    def init(
        self,
        x0: State,
        t0: Time,
        dt: TimeStep | None = None,
        *,
        include_dynamics_aux_dict: bool = False,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Initialize the dictionary of integrator parameters.

        Note:
            The given ``dt`` is ignored. The adaptive logic only needs to execute
            ``__call__`` once to allocate the parameters, so a fixed dummy Δt is used.
        """

        return super().init(
            x0=x0,
            t0=t0,
            dt=0.001,
            include_dynamics_aux_dict=include_dynamics_aux_dict,
            **kwargs,
        )

    def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
        """
        Integrate from ``t0`` to ``t0 + dt`` with adaptive internal steps.

        Args:
            x0: The initial state.
            t0: The initial time.
            dt: The overall integration interval.
            **kwargs: Additional arguments forwarded to the system dynamics.

        Returns:
            The state at ``t0 + dt``. The updated parameters (e.g. the Δt to
            use in the next call) are stored in ``self.params``.
        """

        # This method is called differently in three stages:
        #
        # 1. During initialization, to allocate a dummy params dictionary.
        # 2. During the first step, to compute the initial valid params dictionary.
        # 3. After the first step, to compute the next state and the next valid params.
        #
        # Stage 1 produces a zero-filled dummy dictionary.
        # Stage 2 receives a dummy dictionary and produces valid parameters that can be
        # fed to later steps.
        # Stage 3 corresponds to any consecutive step after the first one. It can re-use
        # data (like for FSAL) from previous steps.
        #
        integrator_init = self.params.get(self.InitializingKey, jnp.array(False))
        integrator_first_step = self.params.get(self.AfterInitKey, jnp.array(False))

        # Close f over optional kwargs.
        f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)

        # Define the final time.
        tf = t0 + dt

        # Initialize solution orders.
        p = self.order_of_solution
        p̂ = self.order_of_solution_estimate
        q = jnp.minimum(p, p̂)

        def estimate_dt0_and_dxdt0(params: dict[str, Any]):
            return estimate_step_size(
                x0=x0, t0=t0, f=f, order=p, atol=self.atol, rtol=self.rtol
            )

        def reuse_dt0_and_dxdt0(params: dict[str, Any]):
            dt0 = params.get("dt0", jnp.array(0).astype(float))
            # Evaluate the dynamics only if no (FSAL) derivative was stored.
            dxdt0 = params["dxdt0"] if "dxdt0" in params else f(x0, t0)[0]
            return dt0, dxdt0

        # In Stage 1 and 2, estimate from scratch dt0 and dxdt0.
        # In Stage 3, dt0 is taken from the previous step. If the integrator supports
        # FSAL, dxdt0 is taken from the previous step. Otherwise, it is computed by
        # evaluating the dynamics.
        has_dt0 = "dt0" in self.params
        self.params["dt0"], self.params["dxdt0"] = jax.lax.cond(
            pred=jnp.logical_or(not has_dt0, integrator_first_step),
            true_fun=estimate_dt0_and_dxdt0,
            false_fun=reuse_dt0_and_dxdt0,
            operand=self.params,
        )

        # If the integrator does not support FSAL, it is useless to store dxdt0.
        if not self.has_fsal:
            self.params.pop("dxdt0")

        # Clip the initial step size to the given bounds, consistently with the
        # clipping applied to the Δt of accepted and rejected steps.
        self.params["dt0"] = jnp.clip(self.params["dt0"], self.dt_min, self.dt_max)

        # =========================================================
        # While loop to reach tf from t0 using an adaptive timestep
        # =========================================================

        # Carry: (state, time, params, number of discarded steps, break flag).
        Carry = tuple[Any, ...]
        carry0: Carry = (
            x0,
            jnp.array(t0).astype(float),
            self.params,
            jnp.array(0, dtype=int),
            jnp.array(False).astype(bool),
        )

        def while_loop_cond(carry: Carry) -> jtp.Bool:
            _, _, _, _, break_loop = carry
            return jnp.logical_not(break_loop)

        # Each loop is an integration step with variable Δt.
        # Depending on the integration error, the step could be discarded and the
        # while body ran again from the same (x0, t0) but with a smaller Δt.
        # We run these loops until the final time tf is reached.
        def while_loop_body(carry: Carry) -> Carry:

            # Unpack the carry.
            x0, t0, params, discarded_steps, _ = carry

            # Take care of the final adaptive step.
            # Decide whether this is the last step *before* truncating Δt0, since
            # t0 + (tf - t0) < tf can hold in floating point and would trigger
            # spurious extra steps. The negated comparison keeps a NaN Δt0 mapped
            # to the final step.
            Δt0 = params["dt0"]
            is_last_step = jnp.logical_not(t0 + Δt0 < tf)
            Δt0 = jnp.where(is_last_step, tf - t0, Δt0)

            # Run the underlying explicit RK integrator.
            # The output z contains multiple solutions (depending on the rows of b.T).
            with self.editable(validate=True) as integrator:
                integrator.params = params
                z = integrator._compute_next_state(x0=x0, t0=t0, dt=Δt0, **kwargs)
                params_next = integrator.params

            # Extract the high-order solution xf and the low-order estimate x̂f.
            xf = jax.tree_util.tree_map(lambda l: l[self.row_index_of_solution], z)
            x̂f = jax.tree_util.tree_map(
                lambda l: l[self.row_index_of_solution_estimate], z
            )

            # Calculate the local integration error.
            local_error = local_error_estimation(
                x0=x0, xf=xf, xf_estimate=x̂f, rtol=self.rtol, atol=self.atol
            )

            # Shrink the Δt every time by the safety factor (even when accepted).
            # The β parameters define the bounds of the timestep update factor.
            safety = jnp.clip(self.safety, 0.0, 1.0)
            β_min = jnp.maximum(0.0, self.beta_min)
            β_max = jnp.maximum(β_min, self.beta_max)

            # Compute the next Δt from the desired integration error.
            # The computed integration step is accepted if error <= 1.0,
            # otherwise it is rejected.
            #
            # In case of rejection, Δt_next is always smaller than Δt0.
            # In case of acceptance, Δt_next could either be larger than Δt0,
            # or slightly smaller than Δt0 depending on the safety factor.
            Δt_next = Δt0 * jnp.clip(
                safety * jnp.power(1 / local_error, 1 / (q + 1)),
                β_min,
                β_max,
            )

            def accept_step():
                # Use Δt_next in the next while loop.
                # If it is the last one, Δt0 was truncated to reach tf: store the
                # untruncated Δt0 so the next call does not start from a tiny Δt.
                params_next_accepted = params_next | dict(
                    dt0=jnp.clip(
                        jax.lax.select(
                            pred=is_last_step,
                            on_true=params["dt0"],
                            on_false=Δt_next,
                        ),
                        self.dt_min,
                        self.dt_max,
                    )
                )

                # Advance the starting time, landing exactly on tf in the last step.
                t0_next = jnp.where(is_last_step, tf, t0 + Δt0).astype(t0.dtype)

                return (
                    xf,
                    t0_next,
                    is_last_step,
                    params_next_accepted,
                    jnp.array(0, dtype=int),
                )

            def reject_step():
                # Restart from the original params with a reduced Δt. A new dict is
                # built so that the carried params are never mutated in-place.
                params_next_rejected = params | dict(
                    dt0=jnp.clip(Δt_next, self.dt_min, self.dt_max)
                )

                return (
                    x0,
                    t0,
                    jnp.array(False),
                    params_next_rejected,
                    discarded_steps + 1,
                )

            # Decide whether to accept or reject the step.
            (
                x0_next,
                t0_next,
                break_loop,
                params_next,
                discarded_steps,
            ) = jax.lax.cond(
                pred=jnp.array(
                    [
                        discarded_steps >= self.max_step_rejections,
                        local_error <= 1.0,
                        Δt_next < self.dt_min,
                        integrator_init,
                    ]
                ).any(),
                true_fun=accept_step,
                false_fun=reject_step,
            )

            return (
                x0_next,
                t0_next,
                params_next,
                discarded_steps,
                break_loop,
            )

        # Integrate with adaptive step until tf is reached.
        xf, _, params_tf, _, _ = jax.lax.while_loop(
            cond_fun=while_loop_cond,
            body_fun=while_loop_body,
            init_val=carry0,
        )

        # Store the parameters.
        # They will be returned to the caller in a functional way in the step method.
        with self.mutable_context(mutability=Mutability.MUTABLE):
            self.params = params_tf

        return xf

    @property
    def order_of_solution(self) -> int:
        """The order of the row of ``b.T`` producing the high-order solution."""
        return self.order_of_bT_rows[self.row_index_of_solution]

    @property
    def order_of_solution_estimate(self) -> int:
        """The order of the row of ``b.T`` producing the low-order estimate."""
        return self.order_of_bT_rows[self.row_index_of_solution_estimate]

    @classmethod
    def build(
        cls: Type[Self],
        *,
        dynamics: SystemDynamics[State, StateDerivative],
        fsal_enabled_if_supported: jtp.BoolLike = True,
        dt_max: jtp.FloatLike = jnp.inf,
        dt_min: jtp.FloatLike = -jnp.inf,
        rtol: jtp.FloatLike = RTOL_DEFAULT,
        atol: jtp.FloatLike = ATOL_DEFAULT,
        safety: jtp.FloatLike = SAFETY_DEFAULT,
        beta_max: jtp.FloatLike = BETA_MAX_DEFAULT,
        beta_min: jtp.FloatLike = BETA_MIN_DEFAULT,
        max_step_rejections: jtp.IntLike = MAX_STEP_REJECTIONS_DEFAULT,
        **kwargs,
    ) -> Self:
        """
        Build the embedded integrator.

        Raises:
            ValueError: If ``row_index_of_solution_estimate`` is not a valid row of ``b.T``.
        """

        # Check that b.T has enough rows based on the configured index of the
        # solution estimate. This is necessary for embedded methods.
        if (
            cls.row_index_of_solution_estimate is not None
            and cls.row_index_of_solution_estimate >= cls.b.T.shape[0]
        ):
            msg = "The index of the solution estimate ({}-th row of `b.T`) "
            msg += "is out of range ({})."
            raise ValueError(
                msg.format(cls.row_index_of_solution_estimate, cls.b.T.shape[0])
            )

        integrator = super().build(
            # Integrator:
            dynamics=dynamics,
            # ExplicitRungeKutta:
            fsal_enabled_if_supported=bool(fsal_enabled_if_supported),
            # EmbeddedRungeKutta:
            dt_max=float(dt_max),
            dt_min=float(dt_min),
            rtol=float(rtol),
            atol=float(atol),
            safety=float(safety),
            beta_max=float(beta_max),
            beta_min=float(beta_min),
            max_step_rejections=int(max_step_rejections),
            **kwargs,
        )

        return integrator
543
+
544
+
545
@jax_dataclasses.pytree_dataclass
class HeunEulerSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
    """
    Embedded Heun-Euler scheme: a 2nd-order solution paired with a 1st-order estimate.

    The SO(3) behavior comes from ``ExplicitRungeKuttaSO3Mixin`` (defined
    elsewhere; presumably it handles orientations on SO(3) — confirm there).
    """

    # Butcher tableau: stage coefficients.
    A: ClassVar[jax.typing.ArrayLike] = jnp.array(
        [[0.0, 0.0], [1.0, 0.0]], dtype=float
    )

    # Weights, one column per solution: Heun (order 2) and explicit Euler (order 1).
    b: ClassVar[jax.typing.ArrayLike] = jnp.array(
        [[1 / 2, 1 / 2], [1.0, 0.0]], dtype=float
    ).T

    # Stage times, as fractions of Δt.
    c: ClassVar[jax.typing.ArrayLike] = jnp.array([0.0, 1.0], dtype=float)

    # Rows of b.T producing the solution and its embedded estimate.
    row_index_of_solution: ClassVar[int] = 0
    row_index_of_solution_estimate: ClassVar[int | None] = 1

    order_of_bT_rows: ClassVar[tuple[int, ...]] = (2, 1)
576
+
577
+
578
@jax_dataclasses.pytree_dataclass
class BogackiShampineSO3(EmbeddedRungeKutta[PyTreeType], ExplicitRungeKuttaSO3Mixin):
    """
    Embedded Bogacki-Shampine scheme: a 3rd-order solution with a 2nd-order estimate.

    The last stage evaluates the dynamics at the 3rd-order solution (the last row
    of ``A`` equals the first row of ``b.T`` and ``c[-1] == 1``). This is what allows
    first-same-as-last reuse, when supported by the base class. The SO(3) behavior
    comes from ``ExplicitRungeKuttaSO3Mixin`` (defined elsewhere; presumably it
    handles orientations on SO(3) — confirm there).
    """

    # Butcher tableau: stage coefficients.
    A: ClassVar[jax.typing.ArrayLike] = jnp.array(
        [
            [0.0, 0.0, 0.0, 0.0],
            [1 / 2, 0.0, 0.0, 0.0],
            [0.0, 3 / 4, 0.0, 0.0],
            [2 / 9, 1 / 3, 4 / 9, 0.0],
        ],
        dtype=float,
    )

    # Weights, one column per solution: order 3 and order 2.
    b: ClassVar[jax.typing.ArrayLike] = jnp.array(
        [
            [2 / 9, 1 / 3, 4 / 9, 0.0],
            [7 / 24, 1 / 4, 1 / 3, 1 / 8],
        ],
        dtype=float,
    ).T

    # Stage times, as fractions of Δt.
    c: ClassVar[jax.typing.ArrayLike] = jnp.array(
        [0.0, 1 / 2, 3 / 4, 1.0], dtype=float
    )

    # Rows of b.T producing the solution and its embedded estimate.
    row_index_of_solution: ClassVar[int] = 0
    row_index_of_solution_estimate: ClassVar[int | None] = 1

    order_of_bT_rows: ClassVar[tuple[int, ...]] = (3, 2)
jaxsim/math/__init__.py CHANGED
@@ -0,0 +1,11 @@
1
# Default magnitude of the gravitational acceleration: the conventional standard
# gravity 9.80665 rounded to two decimals (units presumably m/s^2 — confirm).
StandardGravity = 9.81
3
+
4
+ from .adjoint import Adjoint
5
+ from .cross import Cross
6
+ from .inertia import Inertia
7
+ from .joint_model import JointModel, supported_joint_motion
8
+ from .quaternion import Quaternion
9
+ from .rotation import Rotation
10
+ from .skew import Skew
11
+ from .transform import Transform
jaxsim/math/adjoint.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import jax.numpy as jnp
2
+ import jaxlie
2
3
 
3
4
  import jaxsim.typing as jtp
4
- from jaxsim.sixd import so3
5
5
 
6
6
  from .quaternion import Quaternion
7
7
  from .skew import Skew
@@ -31,13 +31,35 @@ class Adjoint:
31
31
  assert quaternion.size == 4
32
32
  assert translation.size == 3
33
33
 
34
- Q_sixd = so3.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(quaternion))
34
+ Q_sixd = jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(quaternion))
35
35
  Q_sixd = Q_sixd if not normalize_quaternion else Q_sixd.normalize()
36
36
 
37
37
  return Adjoint.from_rotation_and_translation(
38
38
  rotation=Q_sixd.as_matrix(), translation=translation, inverse=inverse
39
39
  )
40
40
 
41
@staticmethod
def from_transform(transform: jtp.MatrixLike, inverse: bool = False) -> jtp.Matrix:
    """
    Create an adjoint matrix from a transformation matrix.

    Args:
        transform: A 4x4 transformation matrix (any array-like, e.g. nested lists).
        inverse: Whether to compute the inverse adjoint.

    Returns:
        The 6x6 adjoint matrix.
    """

    # Convert first and validate the converted array: array-likes such as
    # nested lists have no `.shape` attribute.
    A_H_B = jnp.array(transform).astype(float)
    assert A_H_B.shape == (4, 4)

    # Build the SE(3) element once and reuse it for both branches.
    A_H_B_se3 = jaxlie.SE3.from_matrix(matrix=A_H_B)

    return A_H_B_se3.inverse().adjoint() if inverse else A_H_B_se3.adjoint()
62
+
41
63
  @staticmethod
42
64
  def from_rotation_and_translation(
43
65
  rotation: jtp.Matrix = jnp.eye(3),