jaxsim 0.4.3.dev245__py3-none-any.whl → 0.4.3.dev271__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/api/model.py +92 -23
- jaxsim/api/ode.py +26 -22
- jaxsim/integrators/common.py +27 -76
- jaxsim/integrators/variable_step.py +96 -61
- {jaxsim-0.4.3.dev245.dist-info → jaxsim-0.4.3.dev271.dist-info}/METADATA +1 -1
- {jaxsim-0.4.3.dev245.dist-info → jaxsim-0.4.3.dev271.dist-info}/RECORD +10 -10
- {jaxsim-0.4.3.dev245.dist-info → jaxsim-0.4.3.dev271.dist-info}/WHEEL +1 -1
- {jaxsim-0.4.3.dev245.dist-info → jaxsim-0.4.3.dev271.dist-info}/LICENSE +0 -0
- {jaxsim-0.4.3.dev245.dist-info → jaxsim-0.4.3.dev271.dist-info}/top_level.txt +0 -0
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.4.3.dev245'
|
16
|
-
__version_tuple__ = version_tuple = (0, 4, 3, 'dev245')
|
15
|
+
__version__ = version = '0.4.3.dev271'
|
16
|
+
__version_tuple__ = version_tuple = (0, 4, 3, 'dev271')
|
jaxsim/api/model.py
CHANGED
@@ -54,6 +54,10 @@ class JaxSimModel(JaxsimDataclass):
|
|
54
54
|
default=None, repr=False
|
55
55
|
)
|
56
56
|
|
57
|
+
integrator: Static[jaxsim.integrators.Integrator | None] = dataclasses.field(
|
58
|
+
default=None, repr=False
|
59
|
+
)
|
60
|
+
|
57
61
|
_description: Static[wrappers.HashlessObject[ModelDescription | None]] = (
|
58
62
|
dataclasses.field(default=None, repr=False)
|
59
63
|
)
|
@@ -93,12 +97,16 @@ class JaxSimModel(JaxsimDataclass):
|
|
93
97
|
# Initialization and state
|
94
98
|
# ========================
|
95
99
|
|
96
|
-
@staticmethod
|
100
|
+
@classmethod
|
97
101
|
def build_from_model_description(
|
102
|
+
cls,
|
98
103
|
model_description: str | pathlib.Path | rod.Model,
|
99
|
-
model_name: str | None = None,
|
100
104
|
*,
|
105
|
+
model_name: str | None = None,
|
101
106
|
time_step: jtp.FloatLike | None = None,
|
107
|
+
integrator: (
|
108
|
+
jaxsim.integrators.Integrator | type[jaxsim.integrators.Integrator] | None
|
109
|
+
) = None,
|
102
110
|
terrain: jaxsim.terrain.Terrain | None = None,
|
103
111
|
contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
|
104
112
|
is_urdf: bool | None = None,
|
@@ -120,6 +128,10 @@ class JaxSimModel(JaxsimDataclass):
|
|
120
128
|
contact_model:
|
121
129
|
The contact model to consider.
|
122
130
|
If not specified, a soft contacts model is used.
|
131
|
+
integrator:
|
132
|
+
The integrator to use. If not specified, a default one is used.
|
133
|
+
This argument can either be a pre-built integrator instance or one
|
134
|
+
of the integrator classes defined in JaxSim.
|
123
135
|
is_urdf:
|
124
136
|
The optional flag to force the model description to be parsed as a URDF.
|
125
137
|
This is usually automatically inferred.
|
@@ -146,10 +158,11 @@ class JaxSimModel(JaxsimDataclass):
|
|
146
158
|
)
|
147
159
|
|
148
160
|
# Build the model.
|
149
|
-
model = JaxSimModel.build(
|
161
|
+
model = cls.build(
|
150
162
|
model_description=intermediate_description,
|
151
163
|
model_name=model_name,
|
152
164
|
time_step=time_step,
|
165
|
+
integrator=integrator,
|
153
166
|
terrain=terrain,
|
154
167
|
contact_model=contact_model,
|
155
168
|
)
|
@@ -160,12 +173,16 @@ class JaxSimModel(JaxsimDataclass):
|
|
160
173
|
|
161
174
|
return model
|
162
175
|
|
163
|
-
@staticmethod
|
176
|
+
@classmethod
|
164
177
|
def build(
|
178
|
+
cls,
|
165
179
|
model_description: ModelDescription,
|
166
|
-
model_name: str | None = None,
|
167
180
|
*,
|
181
|
+
model_name: str | None = None,
|
168
182
|
time_step: jtp.FloatLike | None = None,
|
183
|
+
integrator: (
|
184
|
+
jaxsim.integrators.Integrator | type[jaxsim.integrators.Integrator] | None
|
185
|
+
) = None,
|
169
186
|
terrain: jaxsim.terrain.Terrain | None = None,
|
170
187
|
contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
|
171
188
|
) -> JaxSimModel:
|
@@ -182,6 +199,11 @@ class JaxSimModel(JaxsimDataclass):
|
|
182
199
|
The default time step to consider for the simulation. It can be
|
183
200
|
manually overridden in the function that steps the simulation.
|
184
201
|
terrain: The terrain to consider (the default is a flat infinite plane).
|
202
|
+
The optional name of the model overriding the physics model name.
|
203
|
+
integrator:
|
204
|
+
The integrator to use. If not specified, a default one is used.
|
205
|
+
This argument can either be a pre-built integrator instance or one
|
206
|
+
of the integrator classes defined in JaxSim.
|
185
207
|
contact_model:
|
186
208
|
The contact model to consider.
|
187
209
|
If not specified, a soft contacts model is used.
|
@@ -195,23 +217,62 @@ class JaxSimModel(JaxsimDataclass):
|
|
195
217
|
|
196
218
|
# Consider the default terrain (a flat infinite plane) if not specified.
|
197
219
|
terrain = (
|
198
|
-
terrain
|
220
|
+
terrain
|
221
|
+
if terrain is not None
|
222
|
+
else JaxSimModel.__dataclass_fields__["terrain"].default_factory()
|
199
223
|
)
|
200
224
|
|
201
225
|
# Consider the default time step if not specified.
|
202
226
|
time_step = (
|
203
|
-
time_step
|
227
|
+
time_step
|
228
|
+
if time_step is not None
|
229
|
+
else JaxSimModel.__dataclass_fields__["time_step"].default_factory()
|
204
230
|
)
|
205
231
|
|
206
232
|
# Create the default contact model.
|
207
233
|
# It will be populated with an initial estimation of good parameters.
|
208
234
|
# While these might not be the best, they are a good starting point.
|
209
|
-
contact_model =
|
210
|
-
|
235
|
+
contact_model = (
|
236
|
+
contact_model
|
237
|
+
if contact_model is not None
|
238
|
+
else jaxsim.rbda.contacts.SoftContacts.build(
|
239
|
+
terrain=terrain, parameters=None
|
240
|
+
)
|
211
241
|
)
|
212
242
|
|
243
|
+
# Build the integrator if not provided.
|
244
|
+
match integrator:
|
245
|
+
|
246
|
+
# If None, build a default integrator.
|
247
|
+
case None:
|
248
|
+
|
249
|
+
integrator = jaxsim.integrators.fixed_step.Heun2SO3.build(
|
250
|
+
dynamics=js.ode.wrap_system_dynamics_for_integration(
|
251
|
+
system_dynamics=js.ode.system_dynamics
|
252
|
+
)
|
253
|
+
)
|
254
|
+
|
255
|
+
# If it's a pre-built integrator (also a custom one from the user)
|
256
|
+
# just use it as is.
|
257
|
+
case _ if isinstance(integrator, jaxsim.integrators.Integrator):
|
258
|
+
pass
|
259
|
+
|
260
|
+
# If an integrator class is passed, assume that it is a JaxSim integrator
|
261
|
+
# and build it with the default system dynamics.
|
262
|
+
case _ if issubclass(integrator, jaxsim.integrators.Integrator):
|
263
|
+
|
264
|
+
integrator_cls = integrator
|
265
|
+
integrator = integrator_cls.build(
|
266
|
+
dynamics=js.ode.wrap_system_dynamics_for_integration(
|
267
|
+
system_dynamics=js.ode.system_dynamics
|
268
|
+
)
|
269
|
+
)
|
270
|
+
|
271
|
+
case _:
|
272
|
+
raise ValueError(f"Invalid integrator: {integrator}")
|
273
|
+
|
213
274
|
# Build the model.
|
214
|
-
model = JaxSimModel(
|
275
|
+
model = cls(
|
215
276
|
model_name=model_name,
|
216
277
|
kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
|
217
278
|
model_description=model_description
|
@@ -219,6 +280,7 @@ class JaxSimModel(JaxsimDataclass):
|
|
219
280
|
time_step=time_step,
|
220
281
|
terrain=terrain,
|
221
282
|
contact_model=contact_model,
|
283
|
+
integrator=integrator,
|
222
284
|
# The following is wrapped as hashless since it's a static argument, and we
|
223
285
|
# don't want to trigger recompilation if it changes. All relevant parameters
|
224
286
|
# needed to compute kinematics and dynamics quantities are stored in the
|
@@ -404,6 +466,7 @@ def reduce(
|
|
404
466
|
reduced_model = JaxSimModel.build(
|
405
467
|
model_description=reduced_intermediate_description,
|
406
468
|
model_name=model.name(),
|
469
|
+
time_step=model.time_step,
|
407
470
|
terrain=model.terrain,
|
408
471
|
contact_model=model.contact_model,
|
409
472
|
)
|
@@ -1912,10 +1975,10 @@ def step(
|
|
1912
1975
|
model: JaxSimModel,
|
1913
1976
|
data: js.data.JaxSimModelData,
|
1914
1977
|
*,
|
1915
|
-
integrator: jaxsim.integrators.Integrator,
|
1916
1978
|
t0: jtp.FloatLike = 0.0,
|
1917
1979
|
dt: jtp.FloatLike | None = None,
|
1918
|
-
|
1980
|
+
integrator: jaxsim.integrators.Integrator | None = None,
|
1981
|
+
integrator_metadata: dict[str, Any] | None = None,
|
1919
1982
|
link_forces: jtp.MatrixLike | None = None,
|
1920
1983
|
joint_force_references: jtp.VectorLike | None = None,
|
1921
1984
|
**kwargs,
|
@@ -1927,7 +1990,7 @@ def step(
|
|
1927
1990
|
model: The model to consider.
|
1928
1991
|
data: The data of the considered model.
|
1929
1992
|
integrator: The integrator to use.
|
1930
|
-
|
1993
|
+
integrator_metadata: The metadata of the integrator, if needed.
|
1931
1994
|
t0: The initial time to consider. Only relevant for time-dependent dynamics.
|
1932
1995
|
dt: The time step to consider. If not specified, it is read from the model.
|
1933
1996
|
link_forces:
|
@@ -1937,8 +2000,9 @@ def step(
|
|
1937
2000
|
kwargs: Additional kwargs to pass to the integrator.
|
1938
2001
|
|
1939
2002
|
Returns:
|
1940
|
-
A tuple containing the new data of the model
|
1941
|
-
|
2003
|
+
A tuple containing the new data of the model and a dictionary of auxiliary
|
2004
|
+
data computed during the step. If the integrator has metadata, the dictionary
|
2005
|
+
will contain the new metadata stored in the `integrator_metadata` key.
|
1942
2006
|
|
1943
2007
|
Note:
|
1944
2008
|
In order to reduce the occurrences of frame conversions performed internally,
|
@@ -1953,8 +2017,9 @@ def step(
|
|
1953
2017
|
integrator_kwargs = kwargs.pop("integrator_kwargs", {})
|
1954
2018
|
integrator_kwargs = kwargs | integrator_kwargs
|
1955
2019
|
|
1956
|
-
#
|
1957
|
-
|
2020
|
+
# Extract the integrator and the optional metadata.
|
2021
|
+
integrator_metadata_t0 = integrator_metadata
|
2022
|
+
integrator = integrator if integrator is not None else model.integrator
|
1958
2023
|
|
1959
2024
|
# Initialize the time-related variables.
|
1960
2025
|
state_t0 = data.state
|
@@ -2010,11 +2075,11 @@ def step(
|
|
2010
2075
|
τ_references = references.joint_force_references(model=model)
|
2011
2076
|
|
2012
2077
|
# Step the dynamics forward.
|
2013
|
-
state_tf,
|
2078
|
+
state_tf, integrator_metadata_tf = integrator.step(
|
2014
2079
|
x0=state_t0,
|
2015
2080
|
t0=t0,
|
2016
2081
|
dt=dt,
|
2017
|
-
|
2082
|
+
metadata=integrator_metadata_t0,
|
2018
2083
|
# Always inject the current (model, data) pair into the system dynamics
|
2019
2084
|
# considered by the integrator, and include the input variables represented
|
2020
2085
|
# by the pair (f_L, τ_references).
|
@@ -2091,13 +2156,17 @@ def step(
|
|
2091
2156
|
)
|
2092
2157
|
)
|
2093
2158
|
|
2094
|
-
|
2095
|
-
|
2096
|
-
|
2159
|
+
# Reset the generalized velocity.
|
2160
|
+
data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6])
|
2161
|
+
data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:])
|
2097
2162
|
|
2098
2163
|
# Restore the input velocity representation.
|
2099
2164
|
data_tf = data_tf.replace(
|
2100
2165
|
velocity_representation=data.velocity_representation, validate=False
|
2101
2166
|
)
|
2102
2167
|
|
2103
|
-
return data_tf,
|
2168
|
+
return data_tf, {} | (
|
2169
|
+
dict(integrator_metadata=integrator_metadata_tf)
|
2170
|
+
if integrator_metadata is not None
|
2171
|
+
else {}
|
2172
|
+
)
|
jaxsim/api/ode.py
CHANGED
@@ -24,41 +24,45 @@ class SystemDynamicsFromModelAndData(Protocol):
|
|
24
24
|
|
25
25
|
|
26
26
|
def wrap_system_dynamics_for_integration(
|
27
|
-
model: js.model.JaxSimModel,
|
28
|
-
data: js.data.JaxSimModelData,
|
29
27
|
*,
|
30
28
|
system_dynamics: SystemDynamicsFromModelAndData,
|
31
|
-
**kwargs,
|
29
|
+
**kwargs: dict[str, Any],
|
32
30
|
) -> jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]:
|
33
31
|
"""
|
34
|
-
Wrap
|
35
|
-
|
32
|
+
Wrap the system dynamics considered by JaxSim integrators in a generic
|
33
|
+
`f(x, t, **u, **parameters)` function.
|
36
34
|
|
37
35
|
Args:
|
38
|
-
model: The model to consider.
|
39
|
-
data: The data of the considered model.
|
40
36
|
system_dynamics: The system dynamics to wrap.
|
41
37
|
**kwargs: Additional kwargs to close over the system dynamics.
|
42
38
|
|
43
39
|
Returns:
|
44
|
-
The system dynamics closed over the
|
40
|
+
The system dynamics closed over the additional kwargs to be used by
|
41
|
+
JaxSim integrators.
|
45
42
|
"""
|
46
43
|
|
47
|
-
#
|
48
|
-
|
49
|
-
|
50
|
-
#
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
44
|
+
# Close `system_dynamics` over additional kwargs.
|
45
|
+
# Similarly to what done in `jaxsim.api.model.step`, to be future-proof, we use the
|
46
|
+
# following logic to allow the caller to close over arguments having the same name
|
47
|
+
# of the ones used in the `wrap_system_dynamics_for_integration` function.
|
48
|
+
kwargs = kwargs.copy() if kwargs is not None else {}
|
49
|
+
colliding_system_dynamics_kwargs = kwargs.pop("system_dynamics_kwargs", {})
|
50
|
+
system_dynamics_kwargs = kwargs | colliding_system_dynamics_kwargs
|
51
|
+
|
52
|
+
# Remove `model` and `data` for backward compatibility.
|
53
|
+
# It's no longer necessary to close over them at this stage, as this is always
|
54
|
+
# done in `jaxsim.api.model.step`.
|
55
|
+
# We can remove the following lines in a few releases.
|
56
|
+
_ = system_dynamics_kwargs.pop("data", None)
|
57
|
+
_ = system_dynamics_kwargs.pop("model", None)
|
58
|
+
|
59
|
+
# Create the function with the signature expected by our generic integrators.
|
60
|
+
# Note that our system dynamics is time independent.
|
57
61
|
def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
|
58
62
|
|
59
|
-
#
|
60
|
-
data_f = kwargs_f.pop("data"
|
61
|
-
model_f = kwargs_f.pop("model"
|
63
|
+
# Get the data and model objects from the kwargs.
|
64
|
+
data_f = kwargs_f.pop("data")
|
65
|
+
model_f = kwargs_f.pop("model")
|
62
66
|
|
63
67
|
# Update the state and time stored inside data.
|
64
68
|
with data_f.editable(validate=True) as data_rw:
|
@@ -69,7 +73,7 @@ def wrap_system_dynamics_for_integration(
|
|
69
73
|
return system_dynamics(
|
70
74
|
model=model_f,
|
71
75
|
data=data_rw,
|
72
|
-
**(
|
76
|
+
**(system_dynamics_kwargs | kwargs_f),
|
73
77
|
)
|
74
78
|
|
75
79
|
f: jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]
|
jaxsim/integrators/common.py
CHANGED
@@ -10,7 +10,7 @@ from jax_dataclasses import Static
|
|
10
10
|
import jaxsim.api as js
|
11
11
|
import jaxsim.math
|
12
12
|
import jaxsim.typing as jtp
|
13
|
-
from jaxsim import exceptions
|
13
|
+
from jaxsim import exceptions, logging
|
14
14
|
from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass, Mutability
|
15
15
|
|
16
16
|
try:
|
@@ -49,16 +49,11 @@ class SystemDynamics(Protocol[State, StateDerivative]):
|
|
49
49
|
@jax_dataclasses.pytree_dataclass
|
50
50
|
class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
51
51
|
|
52
|
-
AfterInitKey: ClassVar[str] = "after_init"
|
53
|
-
InitializingKey: ClassVar[str] = "initializing"
|
54
|
-
|
55
|
-
AuxDictDynamicsKey: ClassVar[str] = "aux_dict_dynamics"
|
56
|
-
|
57
52
|
dynamics: Static[SystemDynamics[State, StateDerivative]] = dataclasses.field(
|
58
53
|
repr=False, hash=False, compare=False, kw_only=True
|
59
54
|
)
|
60
55
|
|
61
|
-
|
56
|
+
metadata: dict[str, Any] = dataclasses.field(
|
62
57
|
default_factory=dict, repr=False, hash=False, compare=False, kw_only=True
|
63
58
|
)
|
64
59
|
|
@@ -88,9 +83,9 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
88
83
|
t0: Time,
|
89
84
|
dt: TimeStep,
|
90
85
|
*,
|
91
|
-
|
86
|
+
metadata: dict[str, Any] | None = None,
|
92
87
|
**kwargs,
|
93
|
-
) -> tuple[
|
88
|
+
) -> tuple[NextState, dict[str, Any]]:
|
94
89
|
"""
|
95
90
|
Perform a single integration step.
|
96
91
|
|
@@ -98,28 +93,30 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
98
93
|
x0: The initial state of the system.
|
99
94
|
t0: The initial time of the system.
|
100
95
|
dt: The time step of the integration.
|
101
|
-
|
96
|
+
metadata: The state auxiliary dictionary of the integrator.
|
102
97
|
**kwargs: Additional keyword arguments.
|
103
98
|
|
104
99
|
Returns:
|
105
100
|
The final state of the system and the updated auxiliary dictionary.
|
106
101
|
"""
|
107
102
|
|
103
|
+
metadata = metadata if metadata is not None else {}
|
104
|
+
|
108
105
|
with self.editable(validate=False) as integrator:
|
109
|
-
integrator.
|
106
|
+
integrator.metadata = metadata
|
110
107
|
|
111
108
|
with integrator.mutable_context(mutability=Mutability.MUTABLE):
|
112
|
-
xf,
|
109
|
+
xf, metadata_step = integrator(x0, t0, dt, **kwargs)
|
113
110
|
|
114
111
|
return (
|
115
112
|
xf,
|
116
|
-
|
117
|
-
| {Integrator.AfterInitKey: jnp.array(False).astype(bool)}
|
118
|
-
| aux_dict,
|
113
|
+
metadata | metadata_step,
|
119
114
|
)
|
120
115
|
|
121
116
|
@abc.abstractmethod
|
122
|
-
def __call__(
|
117
|
+
def __call__(
|
118
|
+
self, x0: State, t0: Time, dt: TimeStep, **kwargs
|
119
|
+
) -> tuple[NextState, dict[str, Any]]:
|
123
120
|
pass
|
124
121
|
|
125
122
|
def init(
|
@@ -131,62 +128,12 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
|
|
131
128
|
include_dynamics_aux_dict: bool = False,
|
132
129
|
**kwargs,
|
133
130
|
) -> dict[str, Any]:
|
134
|
-
"""
|
135
|
-
Initialize the integrator.
|
136
|
-
|
137
|
-
Args:
|
138
|
-
x0: The initial state of the system.
|
139
|
-
t0: The initial time of the system.
|
140
|
-
dt: The time step of the integration.
|
141
|
-
|
142
|
-
Returns:
|
143
|
-
The auxiliary dictionary of the integrator.
|
144
131
|
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
Note:
|
150
|
-
If the integrator supports FSAL, the pair `(x0, t0)` must match the real
|
151
|
-
initial state and time of the system, otherwise the initial derivative of
|
152
|
-
the first step will be wrong.
|
153
|
-
"""
|
154
|
-
|
155
|
-
with self.editable(validate=False) as integrator:
|
156
|
-
|
157
|
-
# Initialize the integrator parameters.
|
158
|
-
# For initialization purpose, the integrators can check if the
|
159
|
-
# `Integrator.InitializingKey` is present in their parameters.
|
160
|
-
# The AfterInitKey is used in the first step after initialization.
|
161
|
-
integrator.params = {
|
162
|
-
Integrator.InitializingKey: jnp.array(True),
|
163
|
-
Integrator.AfterInitKey: jnp.array(False),
|
164
|
-
}
|
165
|
-
|
166
|
-
# Run a dummy call of the integrator.
|
167
|
-
# It is used only to get the params so that we know the structure
|
168
|
-
# of the corresponding pytree.
|
169
|
-
_ = integrator(x0, t0, dt, **kwargs)
|
170
|
-
|
171
|
-
# Remove the injected key.
|
172
|
-
_ = integrator.params.pop(Integrator.InitializingKey)
|
173
|
-
|
174
|
-
# Make sure that all leafs of the dictionary are JAX arrays.
|
175
|
-
# Also, since these are dummy parameters, set them all to zero.
|
176
|
-
params_after_init = jax.tree.map(lambda l: jnp.zeros_like(l), integrator.params)
|
177
|
-
|
178
|
-
# Mark the next step as first step after initialization.
|
179
|
-
params_after_init = params_after_init | {
|
180
|
-
Integrator.AfterInitKey: jnp.array(True)
|
181
|
-
}
|
182
|
-
|
183
|
-
# Store the zero parameters in the integrator.
|
184
|
-
# When the integrator is stepped, this is used to check if the passed
|
185
|
-
# parameters are valid.
|
186
|
-
with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
|
187
|
-
self.params = params_after_init
|
132
|
+
logging.warning(
|
133
|
+
"The 'init' method has been deprecated. There is no need to call it."
|
134
|
+
)
|
188
135
|
|
189
|
-
return
|
136
|
+
return {}
|
190
137
|
|
191
138
|
|
192
139
|
@jax_dataclasses.pytree_dataclass
|
@@ -377,8 +324,11 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
377
324
|
x0,
|
378
325
|
)
|
379
326
|
|
380
|
-
#
|
381
|
-
|
327
|
+
# Closure on metadata to either evaluate the dynamics at the initial state
|
328
|
+
# or to use the previous state derivative (only integrators supporting FSAL).
|
329
|
+
def get_ẋ0_and_aux_dict() -> tuple[StateDerivative, dict[str, Any]]:
|
330
|
+
ẋ0, aux_dict = f(x0, t0)
|
331
|
+
return self.metadata.get("dxdt0", ẋ0), aux_dict
|
382
332
|
|
383
333
|
# We use a `jax.lax.scan` to compile the `f` function only once.
|
384
334
|
# Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code
|
@@ -405,8 +355,9 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
405
355
|
# Compute the next time for the kᵢ evaluation.
|
406
356
|
ti = t0 + c[i] * Δt
|
407
357
|
|
408
|
-
#
|
409
|
-
|
358
|
+
# Evaluate the dynamics.
|
359
|
+
ki, aux_dict = f(xi, ti)
|
360
|
+
return ki, aux_dict
|
410
361
|
|
411
362
|
# This selector enables FSAL property in the first iteration (i=0).
|
412
363
|
ki, aux_dict = jax.lax.cond(
|
@@ -431,7 +382,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
431
382
|
|
432
383
|
# Update the FSAL property for the next iteration.
|
433
384
|
if self.has_fsal:
|
434
|
-
self.
|
385
|
+
self.metadata["dxdt0"] = jax.tree.map(lambda l: l[self.index_of_fsal], K)
|
435
386
|
|
436
387
|
# Compute the output state.
|
437
388
|
# Note that z contains as many new states as the rows of `b.T`.
|
@@ -514,7 +465,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
|
|
514
465
|
raise ValueError("The Butcher tableau is not valid.")
|
515
466
|
|
516
467
|
if not ExplicitRungeKutta.butcher_tableau_is_explicit(A=A):
|
517
|
-
return False
|
468
|
+
return False, None
|
518
469
|
|
519
470
|
if index_of_solution >= b.T.shape[0]:
|
520
471
|
msg = "The index of the solution (i-th row of `b.T`) is out of range."
|
@@ -12,6 +12,7 @@ import jax.numpy as jnp
|
|
12
12
|
import jax_dataclasses
|
13
13
|
from jax_dataclasses import Static
|
14
14
|
|
15
|
+
import jaxsim.utils.tracing
|
15
16
|
from jaxsim import typing as jtp
|
16
17
|
from jaxsim.utils import Mutability
|
17
18
|
|
@@ -219,6 +220,9 @@ def local_error_estimation(
|
|
219
220
|
@jax_dataclasses.pytree_dataclass
|
220
221
|
class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
221
222
|
|
223
|
+
AfterInitKey: ClassVar[str] = "after_init"
|
224
|
+
InitializingKey: ClassVar[str] = "initializing"
|
225
|
+
|
222
226
|
# Define the row of the integration output corresponding to the solution estimate.
|
223
227
|
# This is the row of b.T that produces the state used e.g. by embedded methods to
|
224
228
|
# implement the adaptive timestep logic.
|
@@ -246,40 +250,79 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
246
250
|
self,
|
247
251
|
x0: State,
|
248
252
|
t0: Time,
|
249
|
-
dt: TimeStep
|
250
|
-
*,
|
251
|
-
include_dynamics_aux_dict: bool = False,
|
253
|
+
dt: TimeStep,
|
252
254
|
**kwargs,
|
253
255
|
) -> dict[str, Any]:
|
256
|
+
"""
|
257
|
+
Initialize the integrator and get the metadata.
|
254
258
|
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
259
|
+
Args:
|
260
|
+
x0: The initial state of the system.
|
261
|
+
t0: The initial time of the system.
|
262
|
+
dt: The time step of the integration.
|
263
|
+
|
264
|
+
Returns:
|
265
|
+
The metadata of the integrator to be passed to the first step.
|
266
|
+
"""
|
267
|
+
|
268
|
+
if jaxsim.utils.tracing(var=jnp.zeros(0)):
|
269
|
+
raise RuntimeError("This method cannot be used within a JIT context")
|
270
|
+
|
271
|
+
with self.editable(validate=False) as integrator:
|
272
|
+
|
273
|
+
# Inject this key to signal that the integrator is initializing.
|
274
|
+
# This is used to allocate the arrays of the metadata dictionary,
|
275
|
+
# that are then filled with NaNs.
|
276
|
+
integrator.metadata = {EmbeddedRungeKutta.InitializingKey: jnp.array(True)}
|
277
|
+
|
278
|
+
# Run a dummy call of the integrator.
|
279
|
+
# It is used only to get the metadata so that we know the structure
|
280
|
+
# of the corresponding pytree.
|
281
|
+
_ = integrator(
|
282
|
+
x0, jnp.array(t0, dtype=float), jnp.array(dt, dtype=float), **kwargs
|
283
|
+
)
|
284
|
+
|
285
|
+
# Remove the injected key.
|
286
|
+
_ = integrator.metadata.pop(EmbeddedRungeKutta.InitializingKey)
|
287
|
+
|
288
|
+
# Make sure that all leafs of the dictionary are JAX arrays.
|
289
|
+
# Also, since these are dummy parameters, set them all to NaN.
|
290
|
+
metadata_after_init = jax.tree.map(
|
291
|
+
lambda l: jnp.nan * jnp.zeros_like(l), integrator.metadata
|
263
292
|
)
|
264
293
|
|
294
|
+
# Store the zero parameters in the integrator.
|
295
|
+
# When the integrator is stepped, this is used to check if the passed
|
296
|
+
# parameters are valid.
|
297
|
+
with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
|
298
|
+
self.metadata = metadata_after_init
|
299
|
+
|
300
|
+
return metadata_after_init
|
301
|
+
|
265
302
|
def __call__(
|
266
303
|
self, x0: State, t0: Time, dt: TimeStep, **kwargs
|
267
304
|
) -> tuple[NextState, dict[str, Any]]:
|
268
305
|
|
269
306
|
# This method is called differently in three stages:
|
270
307
|
#
|
271
|
-
# 1. During initialization, to allocate a dummy
|
272
|
-
#
|
273
|
-
#
|
308
|
+
# 1. During initialization, to allocate a dummy metadata dictionary.
|
309
|
+
# The metadata is a dictionary of float JAX arrays, that are initialized
|
310
|
+
# with the right shape and filled with NaNs.
|
311
|
+
# 2. During the first step, this method operates on the Nan-filled
|
312
|
+
# `self.metadata` attribute, and it populates with the actual metadata.
|
313
|
+
# 3. After the first step, this method operates on the actual metadata.
|
274
314
|
#
|
275
|
-
#
|
276
|
-
#
|
277
|
-
#
|
278
|
-
#
|
279
|
-
#
|
315
|
+
# In particular, we store the following information in the metadata:
|
316
|
+
# - The first attempt of the step size, `dt0`. This is either estimated during
|
317
|
+
# phase 2, or taken from the previous step during phase 3.
|
318
|
+
# - For integrators that support FSAL, the derivative at the initial state
|
319
|
+
# computed during the previous step. This can be done because FSAL integrators
|
320
|
+
# evaluate the dynamics at the final state of the previous step, that matches
|
321
|
+
# the initial state of the current step.
|
280
322
|
#
|
281
|
-
integrator_init =
|
282
|
-
|
323
|
+
integrator_init = jnp.array(
|
324
|
+
self.metadata.get(self.InitializingKey, False), dtype=bool
|
325
|
+
)
|
283
326
|
|
284
327
|
# Close f over optional kwargs.
|
285
328
|
f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)
|
@@ -292,34 +335,26 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
292
335
|
p̂ = self.order_of_solution_estimate
|
293
336
|
q = jnp.minimum(p, p̂)
|
294
337
|
|
295
|
-
#
|
296
|
-
#
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
x0=x0, t0=t0, f=f, order=p, atol=self.atol, rtol=self.rtol
|
304
|
-
),
|
305
|
-
self.params.get("dxdt0", f(x0, t0))[1],
|
338
|
+
# The value of dt0 is NaN (or, at least, it should be) only after initialization
|
339
|
+
# and before the first step.
|
340
|
+
self.metadata["dt0"], self.metadata["dxdt0"] = jax.lax.cond(
|
341
|
+
pred=("dt0" in self.metadata)
|
342
|
+
& ~jnp.isnan(self.metadata.get("dt0", 0.0)).any(),
|
343
|
+
true_fun=lambda metadata: (
|
344
|
+
metadata.get("dt0", jnp.array(0.0, dtype=float)),
|
345
|
+
self.metadata.get("dxdt0", f(x0, t0)[0]),
|
306
346
|
),
|
307
|
-
false_fun=lambda
|
308
|
-
|
309
|
-
*self.params.get("dxdt0", f(x0, t0)),
|
347
|
+
false_fun=lambda aux: estimate_step_size(
|
348
|
+
x0=x0, t0=t0, f=f, order=p, atol=self.atol, rtol=self.rtol
|
310
349
|
),
|
311
|
-
operand=self.
|
350
|
+
operand=self.metadata,
|
312
351
|
)
|
313
352
|
|
314
|
-
# If the integrator does not support FSAL, it is useless to store dxdt0.
|
315
|
-
if not self.has_fsal:
|
316
|
-
_ = self.params.pop("dxdt0")
|
317
|
-
|
318
353
|
# Clip the estimated initial step size to the given bounds, if necessary.
|
319
|
-
self.
|
320
|
-
self.
|
321
|
-
jnp.minimum(self.dt_min, self.
|
322
|
-
jnp.minimum(self.dt_max, self.
|
354
|
+
self.metadata["dt0"] = jnp.clip(
|
355
|
+
self.metadata["dt0"],
|
356
|
+
jnp.minimum(self.dt_min, self.metadata["dt0"]),
|
357
|
+
jnp.minimum(self.dt_max, self.metadata["dt0"]),
|
323
358
|
)
|
324
359
|
|
325
360
|
# =========================================================
|
@@ -331,7 +366,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
331
366
|
carry0: Carry = (
|
332
367
|
x0,
|
333
368
|
jnp.array(t0).astype(float),
|
334
|
-
self.
|
369
|
+
self.metadata,
|
335
370
|
jnp.array(0, dtype=int),
|
336
371
|
jnp.array(False).astype(bool),
|
337
372
|
)
|
@@ -347,21 +382,21 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
347
382
|
def while_loop_body(carry: Carry) -> Carry:
|
348
383
|
|
349
384
|
# Unpack the carry.
|
350
|
-
x0, t0,
|
385
|
+
x0, t0, metadata, discarded_steps, _ = carry
|
351
386
|
|
352
387
|
# Take care of the final adaptive step.
|
353
388
|
# We want the final Δt to let us reach tf exactly.
|
354
389
|
# Then we can exit the while loop.
|
355
|
-
Δt0 =
|
390
|
+
Δt0 = metadata["dt0"]
|
356
391
|
Δt0 = jnp.where(t0 + Δt0 < tf, Δt0, tf - t0)
|
357
392
|
break_loop = jnp.where(t0 + Δt0 < tf, False, True)
|
358
393
|
|
359
394
|
# Run the underlying explicit RK integrator.
|
360
395
|
# The output z contains multiple solutions (depending on the rows of b.T).
|
361
396
|
with self.editable(validate=True) as integrator:
|
362
|
-
integrator.
|
397
|
+
integrator.metadata = metadata
|
363
398
|
z, _ = integrator._compute_next_state(x0=x0, t0=t0, dt=Δt0, **kwargs)
|
364
|
-
|
399
|
+
metadata_next = integrator.metadata
|
365
400
|
|
366
401
|
# Extract the high-order solution xf and the low-order estimate x̂f.
|
367
402
|
xf = jax.tree.map(lambda l: l[self.row_index_of_solution], z)
|
@@ -394,11 +429,11 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
394
429
|
def accept_step():
|
395
430
|
# Use Δt_next in the next while loop.
|
396
431
|
# If it is the last one, and Δt0 was clipped, return the initial Δt0.
|
397
|
-
|
432
|
+
metadata_next_accepted = metadata_next | dict(
|
398
433
|
dt0=jnp.clip(
|
399
434
|
jax.lax.select(
|
400
435
|
pred=break_loop,
|
401
|
-
on_true=
|
436
|
+
on_true=metadata["dt0"],
|
402
437
|
on_false=Δt_next,
|
403
438
|
),
|
404
439
|
self.dt_min,
|
@@ -419,16 +454,16 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
419
454
|
x0_next,
|
420
455
|
t0_next,
|
421
456
|
break_loop_next,
|
422
|
-
|
457
|
+
metadata_next_accepted,
|
423
458
|
jnp.array(0, dtype=int),
|
424
459
|
)
|
425
460
|
|
426
461
|
def reject_step():
|
427
|
-
# Get back the original
|
428
|
-
|
462
|
+
# Get back the original metadata.
|
463
|
+
metadata_next_rejected = metadata
|
429
464
|
|
430
465
|
# This time, with a reduced Δt.
|
431
|
-
|
466
|
+
metadata_next_rejected["dt0"] = jnp.clip(
|
432
467
|
Δt_next, self.dt_min, self.dt_max
|
433
468
|
)
|
434
469
|
|
@@ -436,7 +471,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
436
471
|
x0,
|
437
472
|
t0,
|
438
473
|
False,
|
439
|
-
|
474
|
+
metadata_next_rejected,
|
440
475
|
discarded_steps + 1,
|
441
476
|
)
|
442
477
|
|
@@ -445,7 +480,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
445
480
|
x0_next,
|
446
481
|
t0_next,
|
447
482
|
break_loop,
|
448
|
-
|
483
|
+
metadata_next,
|
449
484
|
discarded_steps,
|
450
485
|
) = jax.lax.cond(
|
451
486
|
pred=jnp.array(
|
@@ -463,7 +498,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
463
498
|
return (
|
464
499
|
x0_next,
|
465
500
|
t0_next,
|
466
|
-
|
501
|
+
metadata_next,
|
467
502
|
discarded_steps,
|
468
503
|
break_loop,
|
469
504
|
)
|
@@ -472,7 +507,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
472
507
|
(
|
473
508
|
xf,
|
474
509
|
tf,
|
475
|
-
|
510
|
+
metadata_tf,
|
476
511
|
_,
|
477
512
|
_,
|
478
513
|
) = jax.lax.while_loop(
|
@@ -484,9 +519,9 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
|
|
484
519
|
# Store the parameters.
|
485
520
|
# They will be returned to the caller in a functional way in the step method.
|
486
521
|
with self.mutable_context(mutability=Mutability.MUTABLE):
|
487
|
-
self.
|
522
|
+
self.metadata = metadata_tf
|
488
523
|
|
489
|
-
return xf,
|
524
|
+
return xf, {}
|
490
525
|
|
491
526
|
@property
|
492
527
|
def order_of_solution(self) -> int:
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: jaxsim
|
3
|
-
Version: 0.4.3.
|
3
|
+
Version: 0.4.3.dev271
|
4
4
|
Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
|
5
5
|
Author-email: Diego Ferigo <dgferigo@gmail.com>
|
6
6
|
Maintainer-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippo.ferretti@iit.it>
|
@@ -1,5 +1,5 @@
|
|
1
1
|
jaxsim/__init__.py,sha256=opgtbhhd1kDsHI4H1vOd3loMPDRi884yQ3tohfFGfNc,3382
|
2
|
-
jaxsim/_version.py,sha256=
|
2
|
+
jaxsim/_version.py,sha256=WDllAgWs3R6C27z-j1tgLt8CxLvz_Ys39X00J18KcPI,428
|
3
3
|
jaxsim/exceptions.py,sha256=vSoScaRD4nvh6jltgK9Ry5pKnE0O5hb4_yI_pk_fvR8,2175
|
4
4
|
jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
|
5
5
|
jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
|
@@ -12,14 +12,14 @@ jaxsim/api/frame.py,sha256=yPSgNygHkvWlln4wShNt7vZm_fFobVEm7phsklNNyH8,12922
|
|
12
12
|
jaxsim/api/joint.py,sha256=lksT1Doxz2jknHyhb4ls20z6f6dofpZSzBJtVacZXAE,7129
|
13
13
|
jaxsim/api/kin_dyn_parameters.py,sha256=kbDN5n9uj8CamVJXk1U5oYLbxyjaWDIeUG0V68DCEFs,29578
|
14
14
|
jaxsim/api/link.py,sha256=LAA6ZMQXkWomXeptURBtc7z3_xDZ2BBnBMhVrohh0bE,18621
|
15
|
-
jaxsim/api/model.py,sha256=
|
16
|
-
jaxsim/api/ode.py,sha256=
|
15
|
+
jaxsim/api/model.py,sha256=dpQZDT0BodMfK1wmpG-STFh-rFsJStobQ1fhrWILK9o,73410
|
16
|
+
jaxsim/api/ode.py,sha256=jFE4yk5lHSNk_SynbgA4tHcPdWq17cB-qUUW8KhcknQ,14289
|
17
17
|
jaxsim/api/ode_data.py,sha256=1SD-x-lYk_YSEnVpxTLd69uOKC0mFUj44ZqpSmEDOxw,20190
|
18
18
|
jaxsim/api/references.py,sha256=fW77LitZ8DYgT6ZmUInJfm5luBV1mTcqcNRiC_i79og,20862
|
19
19
|
jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
|
20
|
-
jaxsim/integrators/common.py,sha256=
|
20
|
+
jaxsim/integrators/common.py,sha256=_FZs7E0EazERGA3K0tGC1baUrs8sBDzYTf2U2mFYh9s,18329
|
21
21
|
jaxsim/integrators/fixed_step.py,sha256=KpjRd6hHtapxDoo6D1kyDrVDSHnke2TepI5grFH7_bM,2693
|
22
|
-
jaxsim/integrators/variable_step.py,sha256=
|
22
|
+
jaxsim/integrators/variable_step.py,sha256=hGYKG3Sq3QITgzIePmCVCrrirwagqsKnB3aYifAcKR4,22848
|
23
23
|
jaxsim/math/__init__.py,sha256=8oPITEoGwgRcOeG8KxtqxPQ8b5uku1HNRMokpCoi9Tc,352
|
24
24
|
jaxsim/math/adjoint.py,sha256=o1FCipkGwPtMbN2gFNIyUV8ADF3TX5fxElpTEXK0bIs,4377
|
25
25
|
jaxsim/math/cross.py,sha256=U7yEx_l75mSy5g6O-jsjBztApvxC3WaV4MpkS5tThu4,1330
|
@@ -65,8 +65,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
|
|
65
65
|
jaxsim/utils/jaxsim_dataclass.py,sha256=TGmTQV2Lq7Q-2nLoAEaeNtkPa_qj0IKkdBm4COj46Os,11312
|
66
66
|
jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
|
67
67
|
jaxsim/utils/wrappers.py,sha256=Fh82ZcaFi5fUnByyFLnmumaobsu1hJIvFdopUVzJ1ps,4052
|
68
|
-
jaxsim-0.4.3.
|
69
|
-
jaxsim-0.4.3.
|
70
|
-
jaxsim-0.4.3.
|
71
|
-
jaxsim-0.4.3.
|
72
|
-
jaxsim-0.4.3.
|
68
|
+
jaxsim-0.4.3.dev271.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
|
69
|
+
jaxsim-0.4.3.dev271.dist-info/METADATA,sha256=3B9vB5QnQXwr70wS3SjMkXOMgqPauA_bSIOinzvb7xU,17276
|
70
|
+
jaxsim-0.4.3.dev271.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
71
|
+
jaxsim-0.4.3.dev271.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
|
72
|
+
jaxsim-0.4.3.dev271.dist-info/RECORD,,
|
File without changes
|
File without changes
|