jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -133
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +83 -26
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +58 -31
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +606 -229
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -78
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -53
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev191.dist-info/METADATA +0 -184
- jaxsim-0.2.dev191.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/simulation/simulator.py
DELETED
@@ -1,543 +0,0 @@
|
|
1
|
-
import dataclasses
|
2
|
-
import functools
|
3
|
-
import pathlib
|
4
|
-
from typing import Dict, List, Optional, Union
|
5
|
-
|
6
|
-
try:
|
7
|
-
from typing import Self
|
8
|
-
except ImportError:
|
9
|
-
from typing_extensions import Self
|
10
|
-
|
11
|
-
import jax
|
12
|
-
import jax.numpy as jnp
|
13
|
-
import jax_dataclasses
|
14
|
-
import rod
|
15
|
-
from jax_dataclasses import Static
|
16
|
-
|
17
|
-
import jaxsim.high_level
|
18
|
-
import jaxsim.physics
|
19
|
-
import jaxsim.typing as jtp
|
20
|
-
from jaxsim import logging
|
21
|
-
from jaxsim.high_level.common import VelRepr
|
22
|
-
from jaxsim.high_level.model import Model, StepData
|
23
|
-
from jaxsim.parsers import descriptions
|
24
|
-
from jaxsim.physics.algos.soft_contacts import SoftContactsParams
|
25
|
-
from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
|
26
|
-
from jaxsim.physics.model.physics_model import PhysicsModel
|
27
|
-
from jaxsim.utils import Mutability, Vmappable, oop
|
28
|
-
|
29
|
-
from . import simulator_callbacks as scb
|
30
|
-
from .ode_integration import IntegratorType
|
31
|
-
|
32
|
-
|
33
|
-
@jax_dataclasses.pytree_dataclass
|
34
|
-
class SimulatorData(Vmappable):
|
35
|
-
"""
|
36
|
-
Data used by the simulator.
|
37
|
-
|
38
|
-
It can be used as JaxSim state in a functional programming style.
|
39
|
-
"""
|
40
|
-
|
41
|
-
# Simulation time stored in ns in order to prevent floats approximation
|
42
|
-
time_ns: jtp.Int = dataclasses.field(
|
43
|
-
default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
|
44
|
-
)
|
45
|
-
|
46
|
-
# Terrain and contact parameters
|
47
|
-
terrain: Terrain = dataclasses.field(default_factory=lambda: FlatTerrain())
|
48
|
-
contact_parameters: SoftContactsParams = dataclasses.field(
|
49
|
-
default_factory=lambda: SoftContactsParams()
|
50
|
-
)
|
51
|
-
|
52
|
-
# Dictionary containing all handled models
|
53
|
-
models: Dict[str, Model] = dataclasses.field(default_factory=dict)
|
54
|
-
|
55
|
-
# Default gravity vector (could be overridden for individual models)
|
56
|
-
gravity: jtp.Vector = dataclasses.field(
|
57
|
-
default_factory=lambda: jaxsim.physics.default_gravity()
|
58
|
-
)
|
59
|
-
|
60
|
-
|
61
|
-
@jax_dataclasses.pytree_dataclass
|
62
|
-
class JaxSim(Vmappable):
|
63
|
-
"""The JaxSim simulator."""
|
64
|
-
|
65
|
-
# Step size stored in ns in order to prevent floats approximation
|
66
|
-
step_size_ns: Static[jtp.Int] = dataclasses.field(
|
67
|
-
default_factory=lambda: jnp.array(1_000_000, dtype=jnp.uint64)
|
68
|
-
)
|
69
|
-
|
70
|
-
# Number of sub-steps performed at each integration step.
|
71
|
-
# Note: there is no collision detection performed in sub-steps.
|
72
|
-
steps_per_run: Static[jtp.Int] = dataclasses.field(default=1)
|
73
|
-
|
74
|
-
# Default velocity representation (could be overridden for individual models)
|
75
|
-
velocity_representation: Static[VelRepr] = dataclasses.field(
|
76
|
-
default=VelRepr.Inertial
|
77
|
-
)
|
78
|
-
|
79
|
-
# Integrator type
|
80
|
-
integrator_type: Static[IntegratorType] = dataclasses.field(
|
81
|
-
default=IntegratorType.EulerForward
|
82
|
-
)
|
83
|
-
|
84
|
-
# Simulator data
|
85
|
-
data: SimulatorData = dataclasses.field(default_factory=lambda: SimulatorData())
|
86
|
-
|
87
|
-
@staticmethod
|
88
|
-
def build(
|
89
|
-
step_size: jtp.Float,
|
90
|
-
steps_per_run: jtp.Int = 1,
|
91
|
-
velocity_representation: VelRepr = VelRepr.Inertial,
|
92
|
-
integrator_type: IntegratorType = IntegratorType.EulerSemiImplicit,
|
93
|
-
simulator_data: SimulatorData | None = None,
|
94
|
-
) -> "JaxSim":
|
95
|
-
"""
|
96
|
-
Build a JaxSim simulator object.
|
97
|
-
|
98
|
-
Args:
|
99
|
-
step_size: The integration step size in seconds.
|
100
|
-
steps_per_run: Number of sub-steps performed at each integration step.
|
101
|
-
velocity_representation: Default velocity representation of simulated models.
|
102
|
-
integrator_type: Type of integrator used for integrating the equations of motion.
|
103
|
-
simulator_data: Optional simulator data to initialize the simulator state.
|
104
|
-
|
105
|
-
Returns:
|
106
|
-
The JaxSim simulator object.
|
107
|
-
"""
|
108
|
-
|
109
|
-
return JaxSim(
|
110
|
-
step_size_ns=jnp.array(step_size * 1e9, dtype=jnp.uint64),
|
111
|
-
steps_per_run=int(steps_per_run),
|
112
|
-
velocity_representation=velocity_representation,
|
113
|
-
integrator_type=integrator_type,
|
114
|
-
data=simulator_data if simulator_data is not None else SimulatorData(),
|
115
|
-
)
|
116
|
-
|
117
|
-
@functools.partial(
|
118
|
-
oop.jax_tf.method_rw, static_argnames=["remove_models"], validate=False
|
119
|
-
)
|
120
|
-
def reset(self, remove_models: bool = True) -> None:
|
121
|
-
"""
|
122
|
-
Reset the simulator.
|
123
|
-
|
124
|
-
Args:
|
125
|
-
remove_models: Flag indicating whether to remove all models from the simulator.
|
126
|
-
If False, the models are kept but their state is reset.
|
127
|
-
"""
|
128
|
-
|
129
|
-
self.data.time_ns = jnp.zeros_like(self.data.time_ns)
|
130
|
-
|
131
|
-
if remove_models:
|
132
|
-
self.data.models = {}
|
133
|
-
else:
|
134
|
-
_ = [m.zero() for m in self.models()]
|
135
|
-
|
136
|
-
@functools.partial(oop.jax_tf.method_rw, jit=False)
|
137
|
-
def set_step_size(self, step_size: float) -> None:
|
138
|
-
"""
|
139
|
-
Set the integration step size.
|
140
|
-
|
141
|
-
Args:
|
142
|
-
step_size: The integration step size in seconds.
|
143
|
-
"""
|
144
|
-
|
145
|
-
self.step_size_ns = jnp.array(step_size * 1e9, dtype=jnp.uint64)
|
146
|
-
|
147
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False)
|
148
|
-
def step_size(self) -> jtp.Float:
|
149
|
-
"""
|
150
|
-
Get the integration step size.
|
151
|
-
|
152
|
-
Returns:
|
153
|
-
The integration step size in seconds.
|
154
|
-
"""
|
155
|
-
|
156
|
-
return jnp.array(self.step_size_ns / 1e9, dtype=float)
|
157
|
-
|
158
|
-
@functools.partial(oop.jax_tf.method_ro)
|
159
|
-
def dt(self) -> jtp.Float:
|
160
|
-
"""
|
161
|
-
Return the integration step size in seconds.
|
162
|
-
|
163
|
-
Returns:
|
164
|
-
The integration step size in seconds.
|
165
|
-
"""
|
166
|
-
|
167
|
-
return jnp.array((self.step_size_ns * self.steps_per_run) / 1e9, dtype=float)
|
168
|
-
|
169
|
-
@functools.partial(oop.jax_tf.method_ro)
|
170
|
-
def time(self) -> jtp.Float:
|
171
|
-
"""
|
172
|
-
Return the current simulation time in seconds.
|
173
|
-
|
174
|
-
Returns:
|
175
|
-
The current simulation time in seconds.
|
176
|
-
"""
|
177
|
-
|
178
|
-
return jnp.array(self.data.time_ns / 1e9, dtype=float)
|
179
|
-
|
180
|
-
@functools.partial(oop.jax_tf.method_ro)
|
181
|
-
def gravity(self) -> jtp.Vector:
|
182
|
-
"""
|
183
|
-
Return the 3D gravity vector.
|
184
|
-
|
185
|
-
Returns:
|
186
|
-
The 3D gravity vector.
|
187
|
-
"""
|
188
|
-
|
189
|
-
return jnp.array(self.data.gravity, dtype=float)
|
190
|
-
|
191
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
|
192
|
-
def model_names(self) -> tuple[str, ...]:
|
193
|
-
"""
|
194
|
-
Return the list of model names.
|
195
|
-
|
196
|
-
Returns:
|
197
|
-
The list of model names.
|
198
|
-
"""
|
199
|
-
|
200
|
-
return tuple(self.data.models.keys())
|
201
|
-
|
202
|
-
@functools.partial(
|
203
|
-
oop.jax_tf.method_ro, static_argnames=["model_name"], jit=False, vmap=False
|
204
|
-
)
|
205
|
-
def get_model(self, model_name: str) -> Model:
|
206
|
-
"""
|
207
|
-
Return the model with the given name.
|
208
|
-
|
209
|
-
Args:
|
210
|
-
model_name: The name of the model to return.
|
211
|
-
|
212
|
-
Returns:
|
213
|
-
The model with the given name.
|
214
|
-
"""
|
215
|
-
|
216
|
-
if model_name not in self.data.models:
|
217
|
-
raise ValueError(f"Failed to find model '{model_name}'")
|
218
|
-
|
219
|
-
return self.data.models[model_name]
|
220
|
-
|
221
|
-
@functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
|
222
|
-
def models(self, model_names: tuple[str, ...] | None = None) -> tuple[Model, ...]:
|
223
|
-
"""
|
224
|
-
Return the simulated models.
|
225
|
-
|
226
|
-
Args:
|
227
|
-
model_names: Optional list of model names to return.
|
228
|
-
If None, all models are returned.
|
229
|
-
|
230
|
-
Returns:
|
231
|
-
The list of simulated models.
|
232
|
-
"""
|
233
|
-
|
234
|
-
model_names = model_names if model_names is not None else self.model_names()
|
235
|
-
return tuple(self.data.models[name] for name in model_names)
|
236
|
-
|
237
|
-
@functools.partial(oop.jax_tf.method_rw)
|
238
|
-
def set_gravity(self, gravity: jtp.Vector) -> None:
|
239
|
-
"""
|
240
|
-
Set the gravity vector to all the simulated models.
|
241
|
-
|
242
|
-
Args:
|
243
|
-
gravity: The 3D gravity vector.
|
244
|
-
"""
|
245
|
-
|
246
|
-
gravity = jnp.array(gravity, dtype=float)
|
247
|
-
|
248
|
-
if gravity.size != 3:
|
249
|
-
raise ValueError(gravity)
|
250
|
-
|
251
|
-
self.data.gravity = gravity
|
252
|
-
|
253
|
-
for model in self.data.models.values():
|
254
|
-
model.physics_model.set_gravity(gravity=gravity)
|
255
|
-
|
256
|
-
@functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False)
|
257
|
-
def insert_model_from_description(
|
258
|
-
self,
|
259
|
-
model_description: Union[pathlib.Path, str, rod.Model],
|
260
|
-
model_name: str | None = None,
|
261
|
-
considered_joints: List[str] | None = None,
|
262
|
-
) -> Model:
|
263
|
-
"""
|
264
|
-
Insert a model from a model description.
|
265
|
-
|
266
|
-
Args:
|
267
|
-
model_description: A path to an SDF/URDF file, a string containing its content, or a pre-parsed/pre-built rod model.
|
268
|
-
model_name: The optional name of the model that overrides the one in the description.
|
269
|
-
considered_joints: Optional list of joints to consider. It is also useful to specify the joint serialization.
|
270
|
-
|
271
|
-
Returns:
|
272
|
-
The newly inserted model.
|
273
|
-
"""
|
274
|
-
|
275
|
-
if self.vectorized:
|
276
|
-
raise RuntimeError("Cannot insert a model in a vectorized simulation")
|
277
|
-
|
278
|
-
# Build the model from the given model description
|
279
|
-
model = jaxsim.high_level.model.Model.build_from_model_description(
|
280
|
-
model_description=model_description,
|
281
|
-
model_name=model_name,
|
282
|
-
vel_repr=self.velocity_representation,
|
283
|
-
considered_joints=considered_joints,
|
284
|
-
)
|
285
|
-
|
286
|
-
# Make sure the model is not already part of the simulation
|
287
|
-
if model.name() in self.model_names():
|
288
|
-
msg = f"Model '{model.name()}' is already part of the simulation"
|
289
|
-
raise ValueError(msg)
|
290
|
-
|
291
|
-
# Insert the model
|
292
|
-
self.data.models[model.name()] = model
|
293
|
-
|
294
|
-
# Return the newly inserted model
|
295
|
-
return self.data.models[model.name()]
|
296
|
-
|
297
|
-
@functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False)
|
298
|
-
def insert_model_from_sdf(
|
299
|
-
self,
|
300
|
-
sdf: Union[pathlib.Path, str],
|
301
|
-
model_name: str | None = None,
|
302
|
-
considered_joints: List[str] | None = None,
|
303
|
-
) -> Model:
|
304
|
-
"""
|
305
|
-
Insert a model from an SDF resource.
|
306
|
-
"""
|
307
|
-
|
308
|
-
msg = "JaxSim.{} is deprecated, use JaxSim.{} instead."
|
309
|
-
logging.warning(
|
310
|
-
msg=msg.format("insert_model_from_sdf", "insert_model_from_description")
|
311
|
-
)
|
312
|
-
|
313
|
-
return self.insert_model_from_description(
|
314
|
-
model_description=sdf,
|
315
|
-
model_name=model_name,
|
316
|
-
considered_joints=considered_joints,
|
317
|
-
)
|
318
|
-
|
319
|
-
@functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False)
|
320
|
-
def insert_model(
|
321
|
-
self,
|
322
|
-
model_description: descriptions.ModelDescription,
|
323
|
-
model_name: str | None = None,
|
324
|
-
) -> Model:
|
325
|
-
"""
|
326
|
-
Insert a model from a model description object.
|
327
|
-
|
328
|
-
Args:
|
329
|
-
model_description: The model description object.
|
330
|
-
model_name: Optional name of the model to insert.
|
331
|
-
|
332
|
-
Returns:
|
333
|
-
The newly inserted model.
|
334
|
-
"""
|
335
|
-
|
336
|
-
if self.vectorized:
|
337
|
-
raise RuntimeError("Cannot insert a model in a vectorized simulation")
|
338
|
-
|
339
|
-
model_name = model_name if model_name is not None else model_description.name
|
340
|
-
|
341
|
-
if model_name in self.model_names():
|
342
|
-
msg = f"Model '{model_name}' is already part of the simulation"
|
343
|
-
raise ValueError(msg)
|
344
|
-
|
345
|
-
# Build the physics model the model description
|
346
|
-
physics_model = PhysicsModel.build_from(
|
347
|
-
model_description=model_description, gravity=self.gravity()
|
348
|
-
)
|
349
|
-
|
350
|
-
# Build the high-level model from the physics model
|
351
|
-
model = jaxsim.high_level.model.Model.build(
|
352
|
-
model_name=model_name,
|
353
|
-
physics_model=physics_model,
|
354
|
-
vel_repr=self.velocity_representation,
|
355
|
-
)
|
356
|
-
|
357
|
-
# Insert the model into the simulators
|
358
|
-
self.data.models[model.name()] = model
|
359
|
-
|
360
|
-
# Return the newly inserted model
|
361
|
-
return self.data.models[model.name()]
|
362
|
-
|
363
|
-
@functools.partial(
|
364
|
-
oop.jax_tf.method_rw,
|
365
|
-
jit=False,
|
366
|
-
validate=False,
|
367
|
-
static_argnames=["model_name"],
|
368
|
-
)
|
369
|
-
def remove_model(self, model_name: str) -> None:
|
370
|
-
"""
|
371
|
-
Remove a model from the simulator.
|
372
|
-
|
373
|
-
Args:
|
374
|
-
model_name: The name of the model to remove.
|
375
|
-
"""
|
376
|
-
|
377
|
-
if model_name not in self.model_names():
|
378
|
-
msg = f"Model '{model_name}' is not part of the simulation"
|
379
|
-
raise ValueError(msg)
|
380
|
-
|
381
|
-
_ = self.data.models.pop(model_name)
|
382
|
-
|
383
|
-
@functools.partial(oop.jax_tf.method_rw, vmap_in_axes=(0, None))
|
384
|
-
def step(self, clear_inputs: bool = False) -> Dict[str, StepData]:
|
385
|
-
"""
|
386
|
-
Advance the simulation by one step.
|
387
|
-
|
388
|
-
Args:
|
389
|
-
clear_inputs: Zero the inputs of the models after the integration.
|
390
|
-
|
391
|
-
Returns:
|
392
|
-
A dictionary containing the StepData of all models.
|
393
|
-
"""
|
394
|
-
|
395
|
-
# Compute the initial and final time of the integration as integers
|
396
|
-
t0_ns = jnp.array(self.data.time_ns, dtype=jnp.uint64)
|
397
|
-
dt_ns = jnp.array(self.step_size_ns * self.steps_per_run, dtype=jnp.uint64)
|
398
|
-
|
399
|
-
# Compute the final time using integer arithmetics
|
400
|
-
tf_ns = t0_ns + dt_ns
|
401
|
-
|
402
|
-
# We collect the StepData of all models
|
403
|
-
step_data = {}
|
404
|
-
|
405
|
-
for model in self.models():
|
406
|
-
# Integrate individually all models and collect their StepData.
|
407
|
-
# We use the context manager to make sure that the PyTree of the models
|
408
|
-
# never changes, so that it never triggers JIT recompilations.
|
409
|
-
with model.editable(validate=True) as integrated_model:
|
410
|
-
step_data[model.name()] = integrated_model.integrate(
|
411
|
-
t0=jnp.array(t0_ns, dtype=float) / 1e9,
|
412
|
-
tf=jnp.array(tf_ns, dtype=float) / 1e9,
|
413
|
-
sub_steps=self.steps_per_run,
|
414
|
-
integrator_type=self.integrator_type,
|
415
|
-
terrain=self.data.terrain,
|
416
|
-
contact_parameters=self.data.contact_parameters,
|
417
|
-
clear_inputs=clear_inputs,
|
418
|
-
)
|
419
|
-
|
420
|
-
self.data.models[model.name()].data = integrated_model.data
|
421
|
-
|
422
|
-
# Store the final time
|
423
|
-
self.data.time_ns += dt_ns
|
424
|
-
|
425
|
-
return step_data
|
426
|
-
|
427
|
-
@functools.partial(
|
428
|
-
oop.jax_tf.method_ro,
|
429
|
-
static_argnames=["horizon_steps"],
|
430
|
-
vmap_in_axes=(0, None, 0, None),
|
431
|
-
)
|
432
|
-
def step_over_horizon(
|
433
|
-
self,
|
434
|
-
horizon_steps: jtp.Int,
|
435
|
-
callback_handler: (
|
436
|
-
Union["scb.SimulatorCallback", "scb.CallbackHandler"] | None
|
437
|
-
) = None,
|
438
|
-
clear_inputs: jtp.Bool = False,
|
439
|
-
) -> Union[
|
440
|
-
"JaxSim",
|
441
|
-
tuple["JaxSim", tuple["scb.SimulatorCallback", tuple[jtp.PyTree, jtp.PyTree]]],
|
442
|
-
]:
|
443
|
-
"""
|
444
|
-
Advance the simulation by a given number of steps.
|
445
|
-
|
446
|
-
Args:
|
447
|
-
horizon_steps: The number of steps to advance the simulation.
|
448
|
-
callback_handler: A callback handler to inject custom login in the simulation loop.
|
449
|
-
clear_inputs: Zero the inputs of the models after the integration.
|
450
|
-
|
451
|
-
Returns:
|
452
|
-
The updated simulator if no callback handler is provided, otherwise a tuple
|
453
|
-
containing the updated simulator and a tuple containing callback data.
|
454
|
-
The optional callback data is a tuple containing the updated callback object,
|
455
|
-
the produced pre-step output, and the produced post-step output.
|
456
|
-
"""
|
457
|
-
|
458
|
-
# Process a mutable copy of the simulator
|
459
|
-
original_mutability = self._mutability()
|
460
|
-
sim = self.copy().mutable(validate=True)
|
461
|
-
|
462
|
-
# Helper to get callbacks from the handler
|
463
|
-
get_cb = lambda h, cb_name: (
|
464
|
-
getattr(h, cb_name) if h is not None and hasattr(h, cb_name) else None
|
465
|
-
)
|
466
|
-
|
467
|
-
# Get the callbacks
|
468
|
-
configure_cb: Optional[scb.ConfigureCallbackSignature] = get_cb(
|
469
|
-
h=callback_handler, cb_name="configure_cb"
|
470
|
-
)
|
471
|
-
pre_step_cb: Optional[scb.PreStepCallbackSignature] = get_cb(
|
472
|
-
h=callback_handler, cb_name="pre_step_cb"
|
473
|
-
)
|
474
|
-
post_step_cb: Optional[scb.PostStepCallbackSignature] = get_cb(
|
475
|
-
h=callback_handler, cb_name="post_step_cb"
|
476
|
-
)
|
477
|
-
|
478
|
-
# Callback: configuration
|
479
|
-
sim = configure_cb(sim) if configure_cb is not None else sim
|
480
|
-
|
481
|
-
# Initialize the carry
|
482
|
-
Carry = tuple[JaxSim, scb.CallbackHandler]
|
483
|
-
carry_init: Carry = (sim, callback_handler)
|
484
|
-
|
485
|
-
def body_fun(
|
486
|
-
carry: Carry, xs: None
|
487
|
-
) -> tuple[Carry, tuple[jtp.PyTree, jtp.PyTree]]:
|
488
|
-
sim, callback_handler = carry
|
489
|
-
|
490
|
-
# Make sure to pass a mutable version of the simulator to the callbacks
|
491
|
-
sim = sim.mutable(validate=True)
|
492
|
-
|
493
|
-
# Callback: pre-step
|
494
|
-
sim, out_pre_step = (
|
495
|
-
pre_step_cb(sim) if pre_step_cb is not None else (sim, None)
|
496
|
-
)
|
497
|
-
|
498
|
-
# Integrate all models
|
499
|
-
step_data = sim.step(clear_inputs=clear_inputs)
|
500
|
-
|
501
|
-
# Callback: post-step
|
502
|
-
sim, out_post_step = (
|
503
|
-
post_step_cb(sim, step_data)
|
504
|
-
if post_step_cb is not None
|
505
|
-
else (sim, None)
|
506
|
-
)
|
507
|
-
|
508
|
-
# Pack the carry
|
509
|
-
carry = (sim, callback_handler)
|
510
|
-
|
511
|
-
return carry, (out_pre_step, out_post_step)
|
512
|
-
|
513
|
-
# Integrate over the given horizon
|
514
|
-
(sim, callback_handler), (
|
515
|
-
out_pre_step_horizon,
|
516
|
-
out_post_step_horizon,
|
517
|
-
) = jax.lax.scan(f=body_fun, init=carry_init, xs=None, length=horizon_steps)
|
518
|
-
|
519
|
-
# Enforce original mutability of the entire simulator
|
520
|
-
sim._set_mutability(original_mutability)
|
521
|
-
|
522
|
-
return (
|
523
|
-
sim
|
524
|
-
if callback_handler is None
|
525
|
-
else (
|
526
|
-
sim,
|
527
|
-
(callback_handler, (out_pre_step_horizon, out_post_step_horizon)),
|
528
|
-
)
|
529
|
-
)
|
530
|
-
|
531
|
-
def vectorize(self: Self, batch_size: int) -> Self:
|
532
|
-
"""
|
533
|
-
Inherit docs.
|
534
|
-
"""
|
535
|
-
|
536
|
-
jaxsim_vec: JaxSim = super().vectorize(batch_size=batch_size) # noqa
|
537
|
-
|
538
|
-
# We need to manually specify the batch size of the handled models
|
539
|
-
with jaxsim_vec.mutable_context(mutability=Mutability.MUTABLE):
|
540
|
-
for model in jaxsim_vec.models():
|
541
|
-
model.batch_size = batch_size
|
542
|
-
|
543
|
-
return jaxsim_vec
|
@@ -1,79 +0,0 @@
|
|
1
|
-
import abc
|
2
|
-
from typing import Callable, Dict, Tuple
|
3
|
-
|
4
|
-
import jaxsim.typing as jtp
|
5
|
-
from jaxsim.high_level.model import StepData
|
6
|
-
|
7
|
-
ConfigureCallbackSignature = Callable[["jaxsim.JaxSim"], "jaxsim.JaxSim"]
|
8
|
-
PreStepCallbackSignature = Callable[
|
9
|
-
["jaxsim.JaxSim"], Tuple["jaxsim.JaxSim", jtp.PyTree]
|
10
|
-
]
|
11
|
-
PostStepCallbackSignature = Callable[
|
12
|
-
["jaxsim.JaxSim", Dict[str, StepData]], Tuple["jaxsim.JaxSim", jtp.PyTree]
|
13
|
-
]
|
14
|
-
|
15
|
-
|
16
|
-
class SimulatorCallback(abc.ABC):
|
17
|
-
"""
|
18
|
-
A base class for simulator callbacks.
|
19
|
-
"""
|
20
|
-
|
21
|
-
pass
|
22
|
-
|
23
|
-
|
24
|
-
class ConfigureCallback(SimulatorCallback):
|
25
|
-
"""
|
26
|
-
A callback class to define logic for configuring the simulator before taking the first step.
|
27
|
-
"""
|
28
|
-
|
29
|
-
@property
|
30
|
-
def configure_cb(self) -> ConfigureCallbackSignature:
|
31
|
-
return lambda sim: self.configure(sim=sim)
|
32
|
-
|
33
|
-
@abc.abstractmethod
|
34
|
-
def configure(self, sim: "jaxsim.JaxSim") -> "jaxsim.JaxSim":
|
35
|
-
pass
|
36
|
-
|
37
|
-
|
38
|
-
class PreStepCallback(SimulatorCallback):
|
39
|
-
"""
|
40
|
-
A callback class for performing actions before each simulation step.
|
41
|
-
"""
|
42
|
-
|
43
|
-
@property
|
44
|
-
def pre_step_cb(self) -> PreStepCallbackSignature:
|
45
|
-
return lambda sim: self.pre_step(sim=sim)
|
46
|
-
|
47
|
-
@abc.abstractmethod
|
48
|
-
def pre_step(self, sim: "jaxsim.JaxSim") -> Tuple["jaxsim.JaxSim", jtp.PyTree]:
|
49
|
-
pass
|
50
|
-
|
51
|
-
|
52
|
-
class PostStepCallback(SimulatorCallback):
|
53
|
-
"""
|
54
|
-
A callback class for performing actions after each simulation step.
|
55
|
-
"""
|
56
|
-
|
57
|
-
@property
|
58
|
-
def post_step_cb(self) -> PostStepCallbackSignature:
|
59
|
-
return lambda sim, step_data: self.post_step(sim=sim, step_data=step_data)
|
60
|
-
|
61
|
-
@abc.abstractmethod
|
62
|
-
def post_step(
|
63
|
-
self, sim: "jaxsim.JaxSim", step_data: Dict[str, StepData]
|
64
|
-
) -> Tuple["jaxsim.JaxSim", jtp.PyTree]:
|
65
|
-
pass
|
66
|
-
|
67
|
-
|
68
|
-
class CallbackHandler(ConfigureCallback, PreStepCallback, PostStepCallback):
|
69
|
-
"""
|
70
|
-
A class that handles callbacks for the simulator.
|
71
|
-
|
72
|
-
Note:
|
73
|
-
The are different simulation stages with associated callbacks:
|
74
|
-
- `configure`: runs before the first step is taken.
|
75
|
-
- `pre_step`: runs at each step before integrating the dynamics and advancing the time.
|
76
|
-
- `post_step`: runs at each step after the integration of the dynamics.
|
77
|
-
"""
|
78
|
-
|
79
|
-
pass
|
jaxsim/simulation/utils.py
DELETED
@@ -1,15 +0,0 @@
|
|
1
|
-
from typing import Tuple
|
2
|
-
|
3
|
-
from jaxsim import logging
|
4
|
-
|
5
|
-
|
6
|
-
def check_valid_shape(
|
7
|
-
what: str, shape: Tuple, expected_shape: Tuple, valid: bool
|
8
|
-
) -> bool:
|
9
|
-
valid_shape = shape == expected_shape
|
10
|
-
|
11
|
-
if not valid_shape:
|
12
|
-
logging.debug(f"Shape of {what} differs: {shape}, {expected_shape}")
|
13
|
-
raise
|
14
|
-
|
15
|
-
return valid
|
jaxsim/sixd/__init__.py
DELETED