jaxsim 0.4.3.dev245__py3-none-any.whl → 0.4.3.dev269__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.4.3.dev245'
16
- __version_tuple__ = version_tuple = (0, 4, 3, 'dev245')
15
+ __version__ = version = '0.4.3.dev269'
16
+ __version_tuple__ = version_tuple = (0, 4, 3, 'dev269')
jaxsim/api/model.py CHANGED
@@ -54,6 +54,10 @@ class JaxSimModel(JaxsimDataclass):
54
54
  default=None, repr=False
55
55
  )
56
56
 
57
+ integrator: Static[jaxsim.integrators.Integrator | None] = dataclasses.field(
58
+ default=None, repr=False
59
+ )
60
+
57
61
  _description: Static[wrappers.HashlessObject[ModelDescription | None]] = (
58
62
  dataclasses.field(default=None, repr=False)
59
63
  )
@@ -93,12 +97,16 @@ class JaxSimModel(JaxsimDataclass):
93
97
  # Initialization and state
94
98
  # ========================
95
99
 
96
- @staticmethod
100
+ @classmethod
97
101
  def build_from_model_description(
102
+ cls,
98
103
  model_description: str | pathlib.Path | rod.Model,
99
- model_name: str | None = None,
100
104
  *,
105
+ model_name: str | None = None,
101
106
  time_step: jtp.FloatLike | None = None,
107
+ integrator: (
108
+ jaxsim.integrators.Integrator | type[jaxsim.integrators.Integrator] | None
109
+ ) = None,
102
110
  terrain: jaxsim.terrain.Terrain | None = None,
103
111
  contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
104
112
  is_urdf: bool | None = None,
@@ -120,6 +128,10 @@ class JaxSimModel(JaxsimDataclass):
120
128
  contact_model:
121
129
  The contact model to consider.
122
130
  If not specified, a soft contacts model is used.
131
+ integrator:
132
+ The integrator to use. If not specified, a default one is used.
133
+ This argument can either be a pre-built integrator instance or one
134
+ of the integrator classes defined in JaxSim.
123
135
  is_urdf:
124
136
  The optional flag to force the model description to be parsed as a URDF.
125
137
  This is usually automatically inferred.
@@ -146,10 +158,11 @@ class JaxSimModel(JaxsimDataclass):
146
158
  )
147
159
 
148
160
  # Build the model.
149
- model = JaxSimModel.build(
161
+ model = cls.build(
150
162
  model_description=intermediate_description,
151
163
  model_name=model_name,
152
164
  time_step=time_step,
165
+ integrator=integrator,
153
166
  terrain=terrain,
154
167
  contact_model=contact_model,
155
168
  )
@@ -160,12 +173,16 @@ class JaxSimModel(JaxsimDataclass):
160
173
 
161
174
  return model
162
175
 
163
- @staticmethod
176
+ @classmethod
164
177
  def build(
178
+ cls,
165
179
  model_description: ModelDescription,
166
- model_name: str | None = None,
167
180
  *,
181
+ model_name: str | None = None,
168
182
  time_step: jtp.FloatLike | None = None,
183
+ integrator: (
184
+ jaxsim.integrators.Integrator | type[jaxsim.integrators.Integrator] | None
185
+ ) = None,
169
186
  terrain: jaxsim.terrain.Terrain | None = None,
170
187
  contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
171
188
  ) -> JaxSimModel:
@@ -182,6 +199,11 @@ class JaxSimModel(JaxsimDataclass):
182
199
  The default time step to consider for the simulation. It can be
183
200
  manually overridden in the function that steps the simulation.
184
201
  terrain: The terrain to consider (the default is a flat infinite plane).
202
+ The optional name of the model overriding the physics model name.
203
+ integrator:
204
+ The integrator to use. If not specified, a default one is used.
205
+ This argument can either be a pre-built integrator instance or one
206
+ of the integrator classes defined in JaxSim.
185
207
  contact_model:
186
208
  The contact model to consider.
187
209
  If not specified, a soft contacts model is used.
@@ -195,23 +217,62 @@ class JaxSimModel(JaxsimDataclass):
195
217
 
196
218
  # Consider the default terrain (a flat infinite plane) if not specified.
197
219
  terrain = (
198
- terrain or JaxSimModel.__dataclass_fields__["terrain"].default_factory()
220
+ terrain
221
+ if terrain is not None
222
+ else JaxSimModel.__dataclass_fields__["terrain"].default_factory()
199
223
  )
200
224
 
201
225
  # Consider the default time step if not specified.
202
226
  time_step = (
203
- time_step or JaxSimModel.__dataclass_fields__["time_step"].default_factory()
227
+ time_step
228
+ if time_step is not None
229
+ else JaxSimModel.__dataclass_fields__["time_step"].default_factory()
204
230
  )
205
231
 
206
232
  # Create the default contact model.
207
233
  # It will be populated with an initial estimation of good parameters.
208
234
  # While these might not be the best, they are a good starting point.
209
- contact_model = contact_model or jaxsim.rbda.contacts.SoftContacts.build(
210
- terrain=terrain, parameters=None
235
+ contact_model = (
236
+ contact_model
237
+ if contact_model is not None
238
+ else jaxsim.rbda.contacts.SoftContacts.build(
239
+ terrain=terrain, parameters=None
240
+ )
211
241
  )
212
242
 
243
+ # Build the integrator if not provided.
244
+ match integrator:
245
+
246
+ # If None, build a default integrator.
247
+ case None:
248
+
249
+ integrator = jaxsim.integrators.fixed_step.Heun2SO3.build(
250
+ dynamics=js.ode.wrap_system_dynamics_for_integration(
251
+ system_dynamics=js.ode.system_dynamics
252
+ )
253
+ )
254
+
255
+ # If it's a pre-built integrator (also a custom one from the user)
256
+ # just use it as is.
257
+ case _ if isinstance(integrator, jaxsim.integrators.Integrator):
258
+ pass
259
+
260
+ # If an integrator class is passed, assume that it is a JaxSim integrator
261
+ # and build it with the default system dynamics.
262
+ case _ if issubclass(integrator, jaxsim.integrators.Integrator):
263
+
264
+ integrator_cls = integrator
265
+ integrator = integrator_cls.build(
266
+ dynamics=js.ode.wrap_system_dynamics_for_integration(
267
+ system_dynamics=js.ode.system_dynamics
268
+ )
269
+ )
270
+
271
+ case _:
272
+ raise ValueError(f"Invalid integrator: {integrator}")
273
+
213
274
  # Build the model.
214
- model = JaxSimModel(
275
+ model = cls(
215
276
  model_name=model_name,
216
277
  kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
217
278
  model_description=model_description
@@ -219,6 +280,7 @@ class JaxSimModel(JaxsimDataclass):
219
280
  time_step=time_step,
220
281
  terrain=terrain,
221
282
  contact_model=contact_model,
283
+ integrator=integrator,
222
284
  # The following is wrapped as hashless since it's a static argument, and we
223
285
  # don't want to trigger recompilation if it changes. All relevant parameters
224
286
  # needed to compute kinematics and dynamics quantities are stored in the
@@ -404,6 +466,7 @@ def reduce(
404
466
  reduced_model = JaxSimModel.build(
405
467
  model_description=reduced_intermediate_description,
406
468
  model_name=model.name(),
469
+ time_step=model.time_step,
407
470
  terrain=model.terrain,
408
471
  contact_model=model.contact_model,
409
472
  )
@@ -1912,10 +1975,10 @@ def step(
1912
1975
  model: JaxSimModel,
1913
1976
  data: js.data.JaxSimModelData,
1914
1977
  *,
1915
- integrator: jaxsim.integrators.Integrator,
1916
1978
  t0: jtp.FloatLike = 0.0,
1917
1979
  dt: jtp.FloatLike | None = None,
1918
- integrator_state: dict[str, Any] | None = None,
1980
+ integrator: jaxsim.integrators.Integrator | None = None,
1981
+ integrator_metadata: dict[str, Any] | None = None,
1919
1982
  link_forces: jtp.MatrixLike | None = None,
1920
1983
  joint_force_references: jtp.VectorLike | None = None,
1921
1984
  **kwargs,
@@ -1927,7 +1990,7 @@ def step(
1927
1990
  model: The model to consider.
1928
1991
  data: The data of the considered model.
1929
1992
  integrator: The integrator to use.
1930
- integrator_state: The state of the integrator.
1993
+ integrator_metadata: The metadata of the integrator, if needed.
1931
1994
  t0: The initial time to consider. Only relevant for time-dependent dynamics.
1932
1995
  dt: The time step to consider. If not specified, it is read from the model.
1933
1996
  link_forces:
@@ -1937,8 +2000,9 @@ def step(
1937
2000
  kwargs: Additional kwargs to pass to the integrator.
1938
2001
 
1939
2002
  Returns:
1940
- A tuple containing the new data of the model
1941
- and the new state of the integrator.
2003
+ A tuple containing the new data of the model and a dictionary of auxiliary
2004
+ data computed during the step. If the integrator has metadata, the dictionary
2005
+ will contain the new metadata stored in the `integrator_metadata` key.
1942
2006
 
1943
2007
  Note:
1944
2008
  In order to reduce the occurrences of frame conversions performed internally,
@@ -1953,8 +2017,9 @@ def step(
1953
2017
  integrator_kwargs = kwargs.pop("integrator_kwargs", {})
1954
2018
  integrator_kwargs = kwargs | integrator_kwargs
1955
2019
 
1956
- # Initialize the integrator state.
1957
- integrator_state_t0 = integrator_state if integrator_state is not None else dict()
2020
+ # Extract the integrator and the optional metadata.
2021
+ integrator_metadata_t0 = integrator_metadata
2022
+ integrator = integrator if integrator is not None else model.integrator
1958
2023
 
1959
2024
  # Initialize the time-related variables.
1960
2025
  state_t0 = data.state
@@ -2010,11 +2075,11 @@ def step(
2010
2075
  τ_references = references.joint_force_references(model=model)
2011
2076
 
2012
2077
  # Step the dynamics forward.
2013
- state_tf, integrator_state_tf = integrator.step(
2078
+ state_tf, integrator_metadata_tf = integrator.step(
2014
2079
  x0=state_t0,
2015
2080
  t0=t0,
2016
2081
  dt=dt,
2017
- params=integrator_state_t0,
2082
+ metadata=integrator_metadata_t0,
2018
2083
  # Always inject the current (model, data) pair into the system dynamics
2019
2084
  # considered by the integrator, and include the input variables represented
2020
2085
  # by the pair (f_L, τ_references).
@@ -2100,4 +2165,8 @@ def step(
2100
2165
  velocity_representation=data.velocity_representation, validate=False
2101
2166
  )
2102
2167
 
2103
- return data_tf, integrator_state_tf
2168
+ return data_tf, {} | (
2169
+ dict(integrator_metadata=integrator_metadata_tf)
2170
+ if integrator_metadata is not None
2171
+ else {}
2172
+ )
jaxsim/api/ode.py CHANGED
@@ -24,41 +24,45 @@ class SystemDynamicsFromModelAndData(Protocol):
24
24
 
25
25
 
26
26
  def wrap_system_dynamics_for_integration(
27
- model: js.model.JaxSimModel,
28
- data: js.data.JaxSimModelData,
29
27
  *,
30
28
  system_dynamics: SystemDynamicsFromModelAndData,
31
- **kwargs,
29
+ **kwargs: dict[str, Any],
32
30
  ) -> jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]:
33
31
  """
34
- Wrap generic system dynamics operating on `JaxSimModel` and `JaxSimModelData`
35
- for integration with `jaxsim.integrators`.
32
+ Wrap the system dynamics considered by JaxSim integrators in a generic
33
+ `f(x, t, **u, **parameters)` function.
36
34
 
37
35
  Args:
38
- model: The model to consider.
39
- data: The data of the considered model.
40
36
  system_dynamics: The system dynamics to wrap.
41
37
  **kwargs: Additional kwargs to close over the system dynamics.
42
38
 
43
39
  Returns:
44
- The system dynamics closed over the model, the data, and the additional kwargs.
40
+ The system dynamics closed over the additional kwargs to be used by
41
+ JaxSim integrators.
45
42
  """
46
43
 
47
- # We allow to close `system_dynamics` over additional kwargs.
48
- kwargs_closed = kwargs.copy()
49
-
50
- # Create a local copy of model and data.
51
- # The wrapped dynamics will hold a reference of this object.
52
- model_closed = model.copy()
53
- data_closed = data.copy().replace(
54
- state=js.ode_data.ODEState.zero(model=model_closed, data=data)
55
- )
56
-
44
+ # Close `system_dynamics` over additional kwargs.
45
+ # Similarly to what done in `jaxsim.api.model.step`, to be future-proof, we use the
46
+ # following logic to allow the caller to close over arguments having the same name
47
+ # of the ones used in the `wrap_system_dynamics_for_integration` function.
48
+ kwargs = kwargs.copy() if kwargs is not None else {}
49
+ colliding_system_dynamics_kwargs = kwargs.pop("system_dynamics_kwargs", {})
50
+ system_dynamics_kwargs = kwargs | colliding_system_dynamics_kwargs
51
+
52
+ # Remove `model` and `data` for backward compatibility.
53
+ # It's no longer necessary to close over them at this stage, as this is always
54
+ # done in `jaxsim.api.model.step`.
55
+ # We can remove the following lines in a few releases.
56
+ _ = system_dynamics_kwargs.pop("data", None)
57
+ _ = system_dynamics_kwargs.pop("model", None)
58
+
59
+ # Create the function with the signature expected by our generic integrators.
60
+ # Note that our system dynamics is time independent.
57
61
  def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
58
62
 
59
- # Allow caller to override the closed data and model objects.
60
- data_f = kwargs_f.pop("data", data_closed)
61
- model_f = kwargs_f.pop("model", model_closed)
63
+ # Get the data and model objects from the kwargs.
64
+ data_f = kwargs_f.pop("data")
65
+ model_f = kwargs_f.pop("model")
62
66
 
63
67
  # Update the state and time stored inside data.
64
68
  with data_f.editable(validate=True) as data_rw:
@@ -69,7 +73,7 @@ def wrap_system_dynamics_for_integration(
69
73
  return system_dynamics(
70
74
  model=model_f,
71
75
  data=data_rw,
72
- **(kwargs_closed | kwargs_f),
76
+ **(system_dynamics_kwargs | kwargs_f),
73
77
  )
74
78
 
75
79
  f: jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]
@@ -10,7 +10,7 @@ from jax_dataclasses import Static
10
10
  import jaxsim.api as js
11
11
  import jaxsim.math
12
12
  import jaxsim.typing as jtp
13
- from jaxsim import exceptions
13
+ from jaxsim import exceptions, logging
14
14
  from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass, Mutability
15
15
 
16
16
  try:
@@ -49,16 +49,11 @@ class SystemDynamics(Protocol[State, StateDerivative]):
49
49
  @jax_dataclasses.pytree_dataclass
50
50
  class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
51
51
 
52
- AfterInitKey: ClassVar[str] = "after_init"
53
- InitializingKey: ClassVar[str] = "initializing"
54
-
55
- AuxDictDynamicsKey: ClassVar[str] = "aux_dict_dynamics"
56
-
57
52
  dynamics: Static[SystemDynamics[State, StateDerivative]] = dataclasses.field(
58
53
  repr=False, hash=False, compare=False, kw_only=True
59
54
  )
60
55
 
61
- params: dict[str, Any] = dataclasses.field(
56
+ metadata: dict[str, Any] = dataclasses.field(
62
57
  default_factory=dict, repr=False, hash=False, compare=False, kw_only=True
63
58
  )
64
59
 
@@ -88,9 +83,9 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
88
83
  t0: Time,
89
84
  dt: TimeStep,
90
85
  *,
91
- params: dict[str, Any],
86
+ metadata: dict[str, Any] | None = None,
92
87
  **kwargs,
93
- ) -> tuple[State, dict[str, Any]]:
88
+ ) -> tuple[NextState, dict[str, Any]]:
94
89
  """
95
90
  Perform a single integration step.
96
91
 
@@ -98,28 +93,30 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
98
93
  x0: The initial state of the system.
99
94
  t0: The initial time of the system.
100
95
  dt: The time step of the integration.
101
- params: The auxiliary dictionary of the integrator.
96
+ metadata: The state auxiliary dictionary of the integrator.
102
97
  **kwargs: Additional keyword arguments.
103
98
 
104
99
  Returns:
105
100
  The final state of the system and the updated auxiliary dictionary.
106
101
  """
107
102
 
103
+ metadata = metadata if metadata is not None else {}
104
+
108
105
  with self.editable(validate=False) as integrator:
109
- integrator.params = params
106
+ integrator.metadata = metadata
110
107
 
111
108
  with integrator.mutable_context(mutability=Mutability.MUTABLE):
112
- xf, aux_dict = integrator(x0, t0, dt, **kwargs)
109
+ xf, metadata_step = integrator(x0, t0, dt, **kwargs)
113
110
 
114
111
  return (
115
112
  xf,
116
- integrator.params
117
- | {Integrator.AfterInitKey: jnp.array(False).astype(bool)}
118
- | aux_dict,
113
+ metadata | metadata_step,
119
114
  )
120
115
 
121
116
  @abc.abstractmethod
122
- def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
117
+ def __call__(
118
+ self, x0: State, t0: Time, dt: TimeStep, **kwargs
119
+ ) -> tuple[NextState, dict[str, Any]]:
123
120
  pass
124
121
 
125
122
  def init(
@@ -131,62 +128,12 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
131
128
  include_dynamics_aux_dict: bool = False,
132
129
  **kwargs,
133
130
  ) -> dict[str, Any]:
134
- """
135
- Initialize the integrator.
136
-
137
- Args:
138
- x0: The initial state of the system.
139
- t0: The initial time of the system.
140
- dt: The time step of the integration.
141
-
142
- Returns:
143
- The auxiliary dictionary of the integrator.
144
131
 
145
- Note:
146
- This method should have the same signature as the inherited `__call__`
147
- method, including additional kwargs.
148
-
149
- Note:
150
- If the integrator supports FSAL, the pair `(x0, t0)` must match the real
151
- initial state and time of the system, otherwise the initial derivative of
152
- the first step will be wrong.
153
- """
154
-
155
- with self.editable(validate=False) as integrator:
156
-
157
- # Initialize the integrator parameters.
158
- # For initialization purpose, the integrators can check if the
159
- # `Integrator.InitializingKey` is present in their parameters.
160
- # The AfterInitKey is used in the first step after initialization.
161
- integrator.params = {
162
- Integrator.InitializingKey: jnp.array(True),
163
- Integrator.AfterInitKey: jnp.array(False),
164
- }
165
-
166
- # Run a dummy call of the integrator.
167
- # It is used only to get the params so that we know the structure
168
- # of the corresponding pytree.
169
- _ = integrator(x0, t0, dt, **kwargs)
170
-
171
- # Remove the injected key.
172
- _ = integrator.params.pop(Integrator.InitializingKey)
173
-
174
- # Make sure that all leafs of the dictionary are JAX arrays.
175
- # Also, since these are dummy parameters, set them all to zero.
176
- params_after_init = jax.tree.map(lambda l: jnp.zeros_like(l), integrator.params)
177
-
178
- # Mark the next step as first step after initialization.
179
- params_after_init = params_after_init | {
180
- Integrator.AfterInitKey: jnp.array(True)
181
- }
182
-
183
- # Store the zero parameters in the integrator.
184
- # When the integrator is stepped, this is used to check if the passed
185
- # parameters are valid.
186
- with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
187
- self.params = params_after_init
132
+ logging.warning(
133
+ "The 'init' method has been deprecated. There is no need to call it."
134
+ )
188
135
 
189
- return params_after_init
136
+ return {}
190
137
 
191
138
 
192
139
  @jax_dataclasses.pytree_dataclass
@@ -377,8 +324,11 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
377
324
  x0,
378
325
  )
379
326
 
380
- # Apply FSAL property by passing ẋ0 = f(x0, t0) from the previous iteration.
381
- get_ẋ0_and_aux_dict = lambda: self.params.get("dxdt0", f(x0, t0))
327
+ # Closure on metadata to either evaluate the dynamics at the initial state
328
+ # or to use the previous state derivative (only integrators supporting FSAL).
329
+ def get_ẋ0_and_aux_dict() -> tuple[StateDerivative, dict[str, Any]]:
330
+ ẋ0, aux_dict = f(x0, t0)
331
+ return self.metadata.get("dxdt0", ẋ0), aux_dict
382
332
 
383
333
  # We use a `jax.lax.scan` to compile the `f` function only once.
384
334
  # Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code
@@ -405,8 +355,9 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
405
355
  # Compute the next time for the kᵢ evaluation.
406
356
  ti = t0 + c[i] * Δt
407
357
 
408
- # This is kᵢ, aux_dict = f(xᵢ, tᵢ).
409
- return f(xi, ti)
358
+ # Evaluate the dynamics.
359
+ ki, aux_dict = f(xi, ti)
360
+ return ki, aux_dict
410
361
 
411
362
  # This selector enables FSAL property in the first iteration (i=0).
412
363
  ki, aux_dict = jax.lax.cond(
@@ -431,7 +382,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
431
382
 
432
383
  # Update the FSAL property for the next iteration.
433
384
  if self.has_fsal:
434
- self.params["dxdt0"] = jax.tree.map(lambda l: l[self.index_of_fsal], K)
385
+ self.metadata["dxdt0"] = jax.tree.map(lambda l: l[self.index_of_fsal], K)
435
386
 
436
387
  # Compute the output state.
437
388
  # Note that z contains as many new states as the rows of `b.T`.
@@ -514,7 +465,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
514
465
  raise ValueError("The Butcher tableau is not valid.")
515
466
 
516
467
  if not ExplicitRungeKutta.butcher_tableau_is_explicit(A=A):
517
- return False
468
+ return False, None
518
469
 
519
470
  if index_of_solution >= b.T.shape[0]:
520
471
  msg = "The index of the solution (i-th row of `b.T`) is out of range."
@@ -12,6 +12,7 @@ import jax.numpy as jnp
12
12
  import jax_dataclasses
13
13
  from jax_dataclasses import Static
14
14
 
15
+ import jaxsim.utils.tracing
15
16
  from jaxsim import typing as jtp
16
17
  from jaxsim.utils import Mutability
17
18
 
@@ -219,6 +220,9 @@ def local_error_estimation(
219
220
  @jax_dataclasses.pytree_dataclass
220
221
  class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
221
222
 
223
+ AfterInitKey: ClassVar[str] = "after_init"
224
+ InitializingKey: ClassVar[str] = "initializing"
225
+
222
226
  # Define the row of the integration output corresponding to the solution estimate.
223
227
  # This is the row of b.T that produces the state used e.g. by embedded methods to
224
228
  # implement the adaptive timestep logic.
@@ -246,40 +250,79 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
246
250
  self,
247
251
  x0: State,
248
252
  t0: Time,
249
- dt: TimeStep | None = None,
250
- *,
251
- include_dynamics_aux_dict: bool = False,
253
+ dt: TimeStep,
252
254
  **kwargs,
253
255
  ) -> dict[str, Any]:
256
+ """
257
+ Initialize the integrator and get the metadata.
254
258
 
255
- # In these type of integrators, it's not relevant picking a meaningful dt.
256
- # We just need to execute __call__ once to initialize the dictionary of params.
257
- return super().init(
258
- x0=x0,
259
- t0=t0,
260
- dt=0.001,
261
- include_dynamics_aux_dict=include_dynamics_aux_dict,
262
- **kwargs,
259
+ Args:
260
+ x0: The initial state of the system.
261
+ t0: The initial time of the system.
262
+ dt: The time step of the integration.
263
+
264
+ Returns:
265
+ The metadata of the integrator to be passed to the first step.
266
+ """
267
+
268
+ if jaxsim.utils.tracing(var=jnp.zeros(0)):
269
+ raise RuntimeError("This method cannot be used within a JIT context")
270
+
271
+ with self.editable(validate=False) as integrator:
272
+
273
+ # Inject this key to signal that the integrator is initializing.
274
+ # This is used to allocate the arrays of the metadata dictionary,
275
+ # that are then filled with NaNs.
276
+ integrator.metadata = {EmbeddedRungeKutta.InitializingKey: jnp.array(True)}
277
+
278
+ # Run a dummy call of the integrator.
279
+ # It is used only to get the metadata so that we know the structure
280
+ # of the corresponding pytree.
281
+ _ = integrator(
282
+ x0, jnp.array(t0, dtype=float), jnp.array(dt, dtype=float), **kwargs
283
+ )
284
+
285
+ # Remove the injected key.
286
+ _ = integrator.metadata.pop(EmbeddedRungeKutta.InitializingKey)
287
+
288
+ # Make sure that all leafs of the dictionary are JAX arrays.
289
+ # Also, since these are dummy parameters, set them all to NaN.
290
+ metadata_after_init = jax.tree.map(
291
+ lambda l: jnp.nan * jnp.zeros_like(l), integrator.metadata
263
292
  )
264
293
 
294
+ # Store the zero parameters in the integrator.
295
+ # When the integrator is stepped, this is used to check if the passed
296
+ # parameters are valid.
297
+ with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
298
+ self.metadata = metadata_after_init
299
+
300
+ return metadata_after_init
301
+
265
302
  def __call__(
266
303
  self, x0: State, t0: Time, dt: TimeStep, **kwargs
267
304
  ) -> tuple[NextState, dict[str, Any]]:
268
305
 
269
306
  # This method is called differently in three stages:
270
307
  #
271
- # 1. During initialization, to allocate a dummy params dictionary.
272
- # 2. During the first step, to compute the initial valid params dictionary.
273
- # 3. After the first step, to compute the next state and the next valid params.
308
+ # 1. During initialization, to allocate a dummy metadata dictionary.
309
+ # The metadata is a dictionary of float JAX arrays, that are initialized
310
+ # with the right shape and filled with NaNs.
311
+ # 2. During the first step, this method operates on the Nan-filled
312
+ # `self.metadata` attribute, and it populates with the actual metadata.
313
+ # 3. After the first step, this method operates on the actual metadata.
274
314
  #
275
- # Stage 1 produces a zero-filled dummy dictionary.
276
- # Stage 2 receives a dummy dictionary and produces valid parameters that can be
277
- # fed to later steps.
278
- # Stage 3 corresponds to any consecutive step after the first one. It can re-use
279
- # data (like for FSAL) from previous steps.
315
+ # In particular, we store the following information in the metadata:
316
+ # - The first attempt of the step size, `dt0`. This is either estimated during
317
+ # phase 2, or taken from the previous step during phase 3.
318
+ # - For integrators that support FSAL, the derivative at the initial state
319
+ # computed during the previous step. This can be done because FSAL integrators
320
+ # evaluate the dynamics at the final state of the previous step, that matches
321
+ # the initial state of the current step.
280
322
  #
281
- integrator_init = self.params.get(self.InitializingKey, jnp.array(False))
282
- integrator_first_step = self.params.get(self.AfterInitKey, jnp.array(False))
323
+ integrator_init = jnp.array(
324
+ self.metadata.get(self.InitializingKey, False), dtype=bool
325
+ )
283
326
 
284
327
  # Close f over optional kwargs.
285
328
  f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)
@@ -292,34 +335,26 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
292
335
  p̂ = self.order_of_solution_estimate
293
336
  q = jnp.minimum(p, p̂)
294
337
 
295
- # In Stage 1 and 2, estimate from scratch dt0 and dxdt0.
296
- # In Stage 3, dt0 is taken from the previous step. If the integrator supports
297
- # FSAL, dxdt0 is taken from the previous step. Otherwise, it is computed by
298
- # evaluating the dynamics.
299
- self.params["dt0"], self.params["dxdt0"], aux_dict = jax.lax.cond(
300
- pred=jnp.logical_or("dt0" not in self.params, integrator_first_step),
301
- true_fun=lambda params: (
302
- *estimate_step_size(
303
- x0=x0, t0=t0, f=f, order=p, atol=self.atol, rtol=self.rtol
304
- ),
305
- self.params.get("dxdt0", f(x0, t0))[1],
338
+ # The value of dt0 is NaN (or, at least, it should be) only after initialization
339
+ # and before the first step.
340
+ self.metadata["dt0"], self.metadata["dxdt0"] = jax.lax.cond(
341
+ pred=("dt0" in self.metadata)
342
+ & ~jnp.isnan(self.metadata.get("dt0", 0.0)).any(),
343
+ true_fun=lambda metadata: (
344
+ metadata.get("dt0", jnp.array(0.0, dtype=float)),
345
+ self.metadata.get("dxdt0", f(x0, t0)[0]),
306
346
  ),
307
- false_fun=lambda params: (
308
- params.get("dt0", jnp.array(0).astype(float)),
309
- *self.params.get("dxdt0", f(x0, t0)),
347
+ false_fun=lambda aux: estimate_step_size(
348
+ x0=x0, t0=t0, f=f, order=p, atol=self.atol, rtol=self.rtol
310
349
  ),
311
- operand=self.params,
350
+ operand=self.metadata,
312
351
  )
313
352
 
314
- # If the integrator does not support FSAL, it is useless to store dxdt0.
315
- if not self.has_fsal:
316
- _ = self.params.pop("dxdt0")
317
-
318
353
  # Clip the estimated initial step size to the given bounds, if necessary.
319
- self.params["dt0"] = jnp.clip(
320
- self.params["dt0"],
321
- jnp.minimum(self.dt_min, self.params["dt0"]),
322
- jnp.minimum(self.dt_max, self.params["dt0"]),
354
+ self.metadata["dt0"] = jnp.clip(
355
+ self.metadata["dt0"],
356
+ jnp.minimum(self.dt_min, self.metadata["dt0"]),
357
+ jnp.minimum(self.dt_max, self.metadata["dt0"]),
323
358
  )
324
359
 
325
360
  # =========================================================
@@ -331,7 +366,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
331
366
  carry0: Carry = (
332
367
  x0,
333
368
  jnp.array(t0).astype(float),
334
- self.params,
369
+ self.metadata,
335
370
  jnp.array(0, dtype=int),
336
371
  jnp.array(False).astype(bool),
337
372
  )
@@ -347,21 +382,21 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
347
382
  def while_loop_body(carry: Carry) -> Carry:
348
383
 
349
384
  # Unpack the carry.
350
- x0, t0, params, discarded_steps, _ = carry
385
+ x0, t0, metadata, discarded_steps, _ = carry
351
386
 
352
387
  # Take care of the final adaptive step.
353
388
  # We want the final Δt to let us reach tf exactly.
354
389
  # Then we can exit the while loop.
355
- Δt0 = params["dt0"]
390
+ Δt0 = metadata["dt0"]
356
391
  Δt0 = jnp.where(t0 + Δt0 < tf, Δt0, tf - t0)
357
392
  break_loop = jnp.where(t0 + Δt0 < tf, False, True)
358
393
 
359
394
  # Run the underlying explicit RK integrator.
360
395
  # The output z contains multiple solutions (depending on the rows of b.T).
361
396
  with self.editable(validate=True) as integrator:
362
- integrator.params = params
397
+ integrator.metadata = metadata
363
398
  z, _ = integrator._compute_next_state(x0=x0, t0=t0, dt=Δt0, **kwargs)
364
- params_next = integrator.params
399
+ metadata_next = integrator.metadata
365
400
 
366
401
  # Extract the high-order solution xf and the low-order estimate x̂f.
367
402
  xf = jax.tree.map(lambda l: l[self.row_index_of_solution], z)
@@ -394,11 +429,11 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
394
429
  def accept_step():
395
430
  # Use Δt_next in the next while loop.
396
431
  # If it is the last one, and Δt0 was clipped, return the initial Δt0.
397
- params_next_accepted = params_next | dict(
432
+ metadata_next_accepted = metadata_next | dict(
398
433
  dt0=jnp.clip(
399
434
  jax.lax.select(
400
435
  pred=break_loop,
401
- on_true=params["dt0"],
436
+ on_true=metadata["dt0"],
402
437
  on_false=Δt_next,
403
438
  ),
404
439
  self.dt_min,
@@ -419,16 +454,16 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
419
454
  x0_next,
420
455
  t0_next,
421
456
  break_loop_next,
422
- params_next_accepted,
457
+ metadata_next_accepted,
423
458
  jnp.array(0, dtype=int),
424
459
  )
425
460
 
426
461
  def reject_step():
427
- # Get back the original params.
428
- params_next_rejected = params
462
+ # Get back the original metadata.
463
+ metadata_next_rejected = metadata
429
464
 
430
465
  # This time, with a reduced Δt.
431
- params_next_rejected["dt0"] = jnp.clip(
466
+ metadata_next_rejected["dt0"] = jnp.clip(
432
467
  Δt_next, self.dt_min, self.dt_max
433
468
  )
434
469
 
@@ -436,7 +471,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
436
471
  x0,
437
472
  t0,
438
473
  False,
439
- params_next_rejected,
474
+ metadata_next_rejected,
440
475
  discarded_steps + 1,
441
476
  )
442
477
 
@@ -445,7 +480,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
445
480
  x0_next,
446
481
  t0_next,
447
482
  break_loop,
448
- params_next,
483
+ metadata_next,
449
484
  discarded_steps,
450
485
  ) = jax.lax.cond(
451
486
  pred=jnp.array(
@@ -463,7 +498,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
463
498
  return (
464
499
  x0_next,
465
500
  t0_next,
466
- params_next,
501
+ metadata_next,
467
502
  discarded_steps,
468
503
  break_loop,
469
504
  )
@@ -472,7 +507,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
472
507
  (
473
508
  xf,
474
509
  tf,
475
- params_tf,
510
+ metadata_tf,
476
511
  _,
477
512
  _,
478
513
  ) = jax.lax.while_loop(
@@ -484,9 +519,9 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
484
519
  # Store the parameters.
485
520
  # They will be returned to the caller in a functional way in the step method.
486
521
  with self.mutable_context(mutability=Mutability.MUTABLE):
487
- self.params = params_tf
522
+ self.metadata = metadata_tf
488
523
 
489
- return xf, aux_dict
524
+ return xf, {}
490
525
 
491
526
  @property
492
527
  def order_of_solution(self) -> int:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.4.3.dev245
3
+ Version: 0.4.3.dev269
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>
6
6
  Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
@@ -1,5 +1,5 @@
1
1
  jaxsim/__init__.py,sha256=opgtbhhd1kDsHI4H1vOd3loMPDRi884yQ3tohfFGfNc,3382
2
- jaxsim/_version.py,sha256=M-BxL76M-utEC1M8k4OEPThRT-EPO5mWYFUyfr6637A,428
2
+ jaxsim/_version.py,sha256=OawRAQdHlu_afqGSl1ctP75uaAlAxnxKssKlI9hT4tQ,428
3
3
  jaxsim/exceptions.py,sha256=vSoScaRD4nvh6jltgK9Ry5pKnE0O5hb4_yI_pk_fvR8,2175
4
4
  jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
5
5
  jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
@@ -12,14 +12,14 @@ jaxsim/api/frame.py,sha256=yPSgNygHkvWlln4wShNt7vZm_fFobVEm7phsklNNyH8,12922
12
12
  jaxsim/api/joint.py,sha256=lksT1Doxz2jknHyhb4ls20z6f6dofpZSzBJtVacZXAE,7129
13
13
  jaxsim/api/kin_dyn_parameters.py,sha256=kbDN5n9uj8CamVJXk1U5oYLbxyjaWDIeUG0V68DCEFs,29578
14
14
  jaxsim/api/link.py,sha256=LAA6ZMQXkWomXeptURBtc7z3_xDZ2BBnBMhVrohh0bE,18621
15
- jaxsim/api/model.py,sha256=FSK2KvH2oNoGIt67gxQD7ScbClqXCn1Oy5JLtc4QHYg,70585
16
- jaxsim/api/ode.py,sha256=2KvGT3WW1eWEme4fH-5LlOXdLF6JUP5Z-IGY93ashUc,13815
15
+ jaxsim/api/model.py,sha256=bWMCE-tWyF1Ijf2dq3GOe3s0GK53E1Eh-YFVAgv7vNU,73398
16
+ jaxsim/api/ode.py,sha256=jFE4yk5lHSNk_SynbgA4tHcPdWq17cB-qUUW8KhcknQ,14289
17
17
  jaxsim/api/ode_data.py,sha256=1SD-x-lYk_YSEnVpxTLd69uOKC0mFUj44ZqpSmEDOxw,20190
18
18
  jaxsim/api/references.py,sha256=fW77LitZ8DYgT6ZmUInJfm5luBV1mTcqcNRiC_i79og,20862
19
19
  jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
20
- jaxsim/integrators/common.py,sha256=78MBs89GxsL0wU2yAexjvBZt3HEtfZoGVIN9f0a8yTc,20305
20
+ jaxsim/integrators/common.py,sha256=_FZs7E0EazERGA3K0tGC1baUrs8sBDzYTf2U2mFYh9s,18329
21
21
  jaxsim/integrators/fixed_step.py,sha256=KpjRd6hHtapxDoo6D1kyDrVDSHnke2TepI5grFH7_bM,2693
22
- jaxsim/integrators/variable_step.py,sha256=1VoSU3GeFcGEuP2dgZQ83sTkI5Xe-IThqKlRoVtwGSE,21270
22
+ jaxsim/integrators/variable_step.py,sha256=hGYKG3Sq3QITgzIePmCVCrrirwagqsKnB3aYifAcKR4,22848
23
23
  jaxsim/math/__init__.py,sha256=8oPITEoGwgRcOeG8KxtqxPQ8b5uku1HNRMokpCoi9Tc,352
24
24
  jaxsim/math/adjoint.py,sha256=o1FCipkGwPtMbN2gFNIyUV8ADF3TX5fxElpTEXK0bIs,4377
25
25
  jaxsim/math/cross.py,sha256=U7yEx_l75mSy5g6O-jsjBztApvxC3WaV4MpkS5tThu4,1330
@@ -65,8 +65,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
65
65
  jaxsim/utils/jaxsim_dataclass.py,sha256=TGmTQV2Lq7Q-2nLoAEaeNtkPa_qj0IKkdBm4COj46Os,11312
66
66
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
67
67
  jaxsim/utils/wrappers.py,sha256=Fh82ZcaFi5fUnByyFLnmumaobsu1hJIvFdopUVzJ1ps,4052
68
- jaxsim-0.4.3.dev245.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
69
- jaxsim-0.4.3.dev245.dist-info/METADATA,sha256=KkzN5EWZkz0Oj6awfhPKRHWrXTZNNTl5vVlfNK64ZY8,17276
70
- jaxsim-0.4.3.dev245.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
71
- jaxsim-0.4.3.dev245.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
72
- jaxsim-0.4.3.dev245.dist-info/RECORD,,
68
+ jaxsim-0.4.3.dev269.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
69
+ jaxsim-0.4.3.dev269.dist-info/METADATA,sha256=duM0e5oZnIH-TNDLRoAeBN0JXf8M4gbLlnrTw9DppiU,17276
70
+ jaxsim-0.4.3.dev269.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
71
+ jaxsim-0.4.3.dev269.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
72
+ jaxsim-0.4.3.dev269.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.2.0)
2
+ Generator: setuptools (75.3.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5