jaxsim 0.2.dev191__py3-none-any.whl → 0.2.dev364__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. jaxsim/__init__.py +3 -4
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +13 -2
  6. jaxsim/api/contact.py +120 -43
  7. jaxsim/api/data.py +112 -71
  8. jaxsim/api/joint.py +77 -36
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +150 -75
  11. jaxsim/api/model.py +542 -269
  12. jaxsim/api/ode.py +86 -74
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +12 -11
  15. jaxsim/integrators/__init__.py +2 -2
  16. jaxsim/integrators/common.py +110 -24
  17. jaxsim/integrators/fixed_step.py +11 -67
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +93 -0
  25. jaxsim/parsers/descriptions/link.py +2 -2
  26. jaxsim/parsers/rod/utils.py +7 -8
  27. jaxsim/rbda/__init__.py +7 -0
  28. jaxsim/rbda/aba.py +295 -0
  29. jaxsim/rbda/collidable_points.py +142 -0
  30. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  31. jaxsim/rbda/forward_kinematics.py +113 -0
  32. jaxsim/rbda/jacobian.py +201 -0
  33. jaxsim/rbda/rnea.py +237 -0
  34. jaxsim/rbda/soft_contacts.py +296 -0
  35. jaxsim/rbda/utils.py +152 -0
  36. jaxsim/terrain/__init__.py +2 -0
  37. jaxsim/utils/__init__.py +1 -4
  38. jaxsim/utils/hashless.py +18 -0
  39. jaxsim/utils/jaxsim_dataclass.py +281 -30
  40. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
  41. jaxsim-0.2.dev364.dist-info/RECORD +64 -0
  42. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
  43. jaxsim/high_level/__init__.py +0 -2
  44. jaxsim/high_level/common.py +0 -11
  45. jaxsim/high_level/joint.py +0 -148
  46. jaxsim/high_level/link.py +0 -259
  47. jaxsim/high_level/model.py +0 -1686
  48. jaxsim/math/conv.py +0 -114
  49. jaxsim/math/joint.py +0 -102
  50. jaxsim/math/plucker.py +0 -100
  51. jaxsim/physics/__init__.py +0 -12
  52. jaxsim/physics/algos/__init__.py +0 -0
  53. jaxsim/physics/algos/aba.py +0 -254
  54. jaxsim/physics/algos/aba_motors.py +0 -284
  55. jaxsim/physics/algos/forward_kinematics.py +0 -79
  56. jaxsim/physics/algos/jacobian.py +0 -98
  57. jaxsim/physics/algos/rnea.py +0 -180
  58. jaxsim/physics/algos/rnea_motors.py +0 -196
  59. jaxsim/physics/algos/soft_contacts.py +0 -523
  60. jaxsim/physics/algos/utils.py +0 -69
  61. jaxsim/physics/model/__init__.py +0 -0
  62. jaxsim/physics/model/ground_contact.py +0 -53
  63. jaxsim/physics/model/physics_model.py +0 -388
  64. jaxsim/physics/model/physics_model_state.py +0 -283
  65. jaxsim/simulation/__init__.py +0 -4
  66. jaxsim/simulation/integrators.py +0 -393
  67. jaxsim/simulation/ode.py +0 -290
  68. jaxsim/simulation/ode_data.py +0 -96
  69. jaxsim/simulation/ode_integration.py +0 -62
  70. jaxsim/simulation/simulator.py +0 -543
  71. jaxsim/simulation/simulator_callbacks.py +0 -79
  72. jaxsim/simulation/utils.py +0 -15
  73. jaxsim/sixd/__init__.py +0 -2
  74. jaxsim/utils/oop.py +0 -536
  75. jaxsim/utils/vmappable.py +0 -117
  76. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  77. /jaxsim/{physics/algos → terrain}/terrain.py +0 -0
  78. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
  79. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
@@ -1,96 +0,0 @@
1
- import jax.flatten_util
2
- import jax_dataclasses
3
-
4
- import jaxsim.typing as jtp
5
- from jaxsim.physics.algos.soft_contacts import SoftContactsState
6
- from jaxsim.physics.model.physics_model import PhysicsModel
7
- from jaxsim.physics.model.physics_model_state import (
8
- PhysicsModelInput,
9
- PhysicsModelState,
10
- )
11
- from jaxsim.utils import JaxsimDataclass
12
-
13
-
14
- @jax_dataclasses.pytree_dataclass
15
- class ODEInput(JaxsimDataclass):
16
- """"""
17
-
18
- physics_model: PhysicsModelInput
19
-
20
- @staticmethod
21
- def build(
22
- physics_model_input: PhysicsModelInput | None = None,
23
- physics_model: PhysicsModel | None = None,
24
- ) -> "ODEInput":
25
- """"""
26
-
27
- physics_model_input = (
28
- physics_model_input
29
- if physics_model_input is not None
30
- else PhysicsModelInput.zero(physics_model=physics_model)
31
- )
32
-
33
- return ODEInput(physics_model=physics_model_input)
34
-
35
- @staticmethod
36
- def zero(physics_model: PhysicsModel) -> "ODEInput":
37
- return ODEInput(
38
- physics_model=PhysicsModelInput.zero(physics_model=physics_model)
39
- )
40
-
41
- def valid(self, physics_model: PhysicsModel) -> bool:
42
- return self.physics_model.valid(physics_model=physics_model)
43
-
44
-
45
- @jax_dataclasses.pytree_dataclass
46
- class ODEState(JaxsimDataclass):
47
- """"""
48
-
49
- physics_model: PhysicsModelState
50
- soft_contacts: SoftContactsState
51
-
52
- @staticmethod
53
- def build(
54
- physics_model_state: PhysicsModelState | None = None,
55
- soft_contacts_state: SoftContactsState | None = None,
56
- physics_model: PhysicsModel | None = None,
57
- ) -> "ODEState":
58
- """"""
59
-
60
- physics_model_state = (
61
- physics_model_state
62
- if physics_model_state is not None
63
- else PhysicsModelState.zero(physics_model=physics_model)
64
- )
65
-
66
- soft_contacts_state = (
67
- soft_contacts_state
68
- if soft_contacts_state is not None
69
- else SoftContactsState.zero(physics_model=physics_model)
70
- )
71
-
72
- return ODEState(
73
- physics_model=physics_model_state, soft_contacts=soft_contacts_state
74
- )
75
-
76
- @staticmethod
77
- def deserialize(data: jtp.VectorJax, physics_model: PhysicsModel) -> "ODEState":
78
- dummy_object = ODEState.zero(physics_model=physics_model)
79
- _, unflatten_data = jax.flatten_util.ravel_pytree(dummy_object)
80
-
81
- return unflatten_data(data)
82
-
83
- @staticmethod
84
- def zero(physics_model: PhysicsModel) -> "ODEState":
85
- model_state = ODEState(
86
- physics_model=PhysicsModelState.zero(physics_model=physics_model),
87
- soft_contacts=SoftContactsState.zero(physics_model=physics_model),
88
- )
89
-
90
- assert model_state.valid(physics_model)
91
- return model_state
92
-
93
- def valid(self, physics_model: PhysicsModel) -> bool:
94
- return self.physics_model.valid(
95
- physics_model=physics_model
96
- ) and self.soft_contacts.valid(physics_model=physics_model)
@@ -1,62 +0,0 @@
1
- import enum
2
- import functools
3
- from typing import Any, Dict, Tuple, Union
4
-
5
- import jax.flatten_util
6
- from jax.experimental.ode import odeint
7
-
8
- import jaxsim.typing as jtp
9
- from jaxsim.physics.algos.soft_contacts import SoftContactsParams
10
- from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
11
- from jaxsim.physics.model.physics_model import PhysicsModel
12
- from jaxsim.simulation import integrators, ode
13
- from jaxsim.simulation.integrators import IntegratorType
14
-
15
-
16
- @jax.jit
17
- def ode_integration_rk4_adaptive(
18
- x0: jtp.Array,
19
- t: integrators.TimeHorizon,
20
- physics_model: PhysicsModel,
21
- *args,
22
- **kwargs,
23
- ) -> jtp.Array:
24
- # Close function over its inputs and parameters
25
- dx_dt_closure = lambda x, ts: ode.dx_dt(x, ts, physics_model, *args)
26
-
27
- return odeint(dx_dt_closure, x0, t, **kwargs)
28
-
29
-
30
- @functools.partial(
31
- jax.jit, static_argnames=["num_sub_steps", "integrator_type", "return_aux"]
32
- )
33
- def ode_integration_fixed_step(
34
- x0: ode.ode_data.ODEState,
35
- t: integrators.TimeHorizon,
36
- physics_model: PhysicsModel,
37
- integrator_type: IntegratorType,
38
- soft_contacts_params: SoftContactsParams = SoftContactsParams(),
39
- terrain: Terrain = FlatTerrain(),
40
- ode_input: ode.ode_data.ODEInput | None = None,
41
- *args,
42
- num_sub_steps: int = 1,
43
- return_aux: bool = False,
44
- ) -> Union[ode.ode_data.ODEState, Tuple[ode.ode_data.ODEState, Dict]]:
45
- # Close func over additional inputs and parameters
46
- dx_dt_closure = lambda x, ts: ode.dx_dt(
47
- x, ts, physics_model, soft_contacts_params, ode_input, terrain, *args
48
- )
49
-
50
- # Integrate over the horizon
51
- out = integrators.odeint(
52
- func=dx_dt_closure,
53
- y0=x0,
54
- t=t,
55
- num_sub_steps=num_sub_steps,
56
- return_aux=return_aux,
57
- integrator_type=integrator_type,
58
- )
59
-
60
- # Return output pytree and, optionally, the aux dict
61
- state = out if not return_aux else out[0]
62
- return (state, out[1]) if return_aux else state
@@ -1,543 +0,0 @@
1
- import dataclasses
2
- import functools
3
- import pathlib
4
- from typing import Dict, List, Optional, Union
5
-
6
- try:
7
- from typing import Self
8
- except ImportError:
9
- from typing_extensions import Self
10
-
11
- import jax
12
- import jax.numpy as jnp
13
- import jax_dataclasses
14
- import rod
15
- from jax_dataclasses import Static
16
-
17
- import jaxsim.high_level
18
- import jaxsim.physics
19
- import jaxsim.typing as jtp
20
- from jaxsim import logging
21
- from jaxsim.high_level.common import VelRepr
22
- from jaxsim.high_level.model import Model, StepData
23
- from jaxsim.parsers import descriptions
24
- from jaxsim.physics.algos.soft_contacts import SoftContactsParams
25
- from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
26
- from jaxsim.physics.model.physics_model import PhysicsModel
27
- from jaxsim.utils import Mutability, Vmappable, oop
28
-
29
- from . import simulator_callbacks as scb
30
- from .ode_integration import IntegratorType
31
-
32
-
33
- @jax_dataclasses.pytree_dataclass
34
- class SimulatorData(Vmappable):
35
- """
36
- Data used by the simulator.
37
-
38
- It can be used as JaxSim state in a functional programming style.
39
- """
40
-
41
- # Simulation time stored in ns in order to prevent floats approximation
42
- time_ns: jtp.Int = dataclasses.field(
43
- default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
44
- )
45
-
46
- # Terrain and contact parameters
47
- terrain: Terrain = dataclasses.field(default_factory=lambda: FlatTerrain())
48
- contact_parameters: SoftContactsParams = dataclasses.field(
49
- default_factory=lambda: SoftContactsParams()
50
- )
51
-
52
- # Dictionary containing all handled models
53
- models: Dict[str, Model] = dataclasses.field(default_factory=dict)
54
-
55
- # Default gravity vector (could be overridden for individual models)
56
- gravity: jtp.Vector = dataclasses.field(
57
- default_factory=lambda: jaxsim.physics.default_gravity()
58
- )
59
-
60
-
61
- @jax_dataclasses.pytree_dataclass
62
- class JaxSim(Vmappable):
63
- """The JaxSim simulator."""
64
-
65
- # Step size stored in ns in order to prevent floats approximation
66
- step_size_ns: Static[jtp.Int] = dataclasses.field(
67
- default_factory=lambda: jnp.array(1_000_000, dtype=jnp.uint64)
68
- )
69
-
70
- # Number of sub-steps performed at each integration step.
71
- # Note: there is no collision detection performed in sub-steps.
72
- steps_per_run: Static[jtp.Int] = dataclasses.field(default=1)
73
-
74
- # Default velocity representation (could be overridden for individual models)
75
- velocity_representation: Static[VelRepr] = dataclasses.field(
76
- default=VelRepr.Inertial
77
- )
78
-
79
- # Integrator type
80
- integrator_type: Static[IntegratorType] = dataclasses.field(
81
- default=IntegratorType.EulerForward
82
- )
83
-
84
- # Simulator data
85
- data: SimulatorData = dataclasses.field(default_factory=lambda: SimulatorData())
86
-
87
- @staticmethod
88
- def build(
89
- step_size: jtp.Float,
90
- steps_per_run: jtp.Int = 1,
91
- velocity_representation: VelRepr = VelRepr.Inertial,
92
- integrator_type: IntegratorType = IntegratorType.EulerSemiImplicit,
93
- simulator_data: SimulatorData | None = None,
94
- ) -> "JaxSim":
95
- """
96
- Build a JaxSim simulator object.
97
-
98
- Args:
99
- step_size: The integration step size in seconds.
100
- steps_per_run: Number of sub-steps performed at each integration step.
101
- velocity_representation: Default velocity representation of simulated models.
102
- integrator_type: Type of integrator used for integrating the equations of motion.
103
- simulator_data: Optional simulator data to initialize the simulator state.
104
-
105
- Returns:
106
- The JaxSim simulator object.
107
- """
108
-
109
- return JaxSim(
110
- step_size_ns=jnp.array(step_size * 1e9, dtype=jnp.uint64),
111
- steps_per_run=int(steps_per_run),
112
- velocity_representation=velocity_representation,
113
- integrator_type=integrator_type,
114
- data=simulator_data if simulator_data is not None else SimulatorData(),
115
- )
116
-
117
- @functools.partial(
118
- oop.jax_tf.method_rw, static_argnames=["remove_models"], validate=False
119
- )
120
- def reset(self, remove_models: bool = True) -> None:
121
- """
122
- Reset the simulator.
123
-
124
- Args:
125
- remove_models: Flag indicating whether to remove all models from the simulator.
126
- If False, the models are kept but their state is reset.
127
- """
128
-
129
- self.data.time_ns = jnp.zeros_like(self.data.time_ns)
130
-
131
- if remove_models:
132
- self.data.models = {}
133
- else:
134
- _ = [m.zero() for m in self.models()]
135
-
136
- @functools.partial(oop.jax_tf.method_rw, jit=False)
137
- def set_step_size(self, step_size: float) -> None:
138
- """
139
- Set the integration step size.
140
-
141
- Args:
142
- step_size: The integration step size in seconds.
143
- """
144
-
145
- self.step_size_ns = jnp.array(step_size * 1e9, dtype=jnp.uint64)
146
-
147
- @functools.partial(oop.jax_tf.method_ro, jit=False)
148
- def step_size(self) -> jtp.Float:
149
- """
150
- Get the integration step size.
151
-
152
- Returns:
153
- The integration step size in seconds.
154
- """
155
-
156
- return jnp.array(self.step_size_ns / 1e9, dtype=float)
157
-
158
- @functools.partial(oop.jax_tf.method_ro)
159
- def dt(self) -> jtp.Float:
160
- """
161
- Return the integration step size in seconds.
162
-
163
- Returns:
164
- The integration step size in seconds.
165
- """
166
-
167
- return jnp.array((self.step_size_ns * self.steps_per_run) / 1e9, dtype=float)
168
-
169
- @functools.partial(oop.jax_tf.method_ro)
170
- def time(self) -> jtp.Float:
171
- """
172
- Return the current simulation time in seconds.
173
-
174
- Returns:
175
- The current simulation time in seconds.
176
- """
177
-
178
- return jnp.array(self.data.time_ns / 1e9, dtype=float)
179
-
180
- @functools.partial(oop.jax_tf.method_ro)
181
- def gravity(self) -> jtp.Vector:
182
- """
183
- Return the 3D gravity vector.
184
-
185
- Returns:
186
- The 3D gravity vector.
187
- """
188
-
189
- return jnp.array(self.data.gravity, dtype=float)
190
-
191
- @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
192
- def model_names(self) -> tuple[str, ...]:
193
- """
194
- Return the list of model names.
195
-
196
- Returns:
197
- The list of model names.
198
- """
199
-
200
- return tuple(self.data.models.keys())
201
-
202
- @functools.partial(
203
- oop.jax_tf.method_ro, static_argnames=["model_name"], jit=False, vmap=False
204
- )
205
- def get_model(self, model_name: str) -> Model:
206
- """
207
- Return the model with the given name.
208
-
209
- Args:
210
- model_name: The name of the model to return.
211
-
212
- Returns:
213
- The model with the given name.
214
- """
215
-
216
- if model_name not in self.data.models:
217
- raise ValueError(f"Failed to find model '{model_name}'")
218
-
219
- return self.data.models[model_name]
220
-
221
- @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
222
- def models(self, model_names: tuple[str, ...] | None = None) -> tuple[Model, ...]:
223
- """
224
- Return the simulated models.
225
-
226
- Args:
227
- model_names: Optional list of model names to return.
228
- If None, all models are returned.
229
-
230
- Returns:
231
- The list of simulated models.
232
- """
233
-
234
- model_names = model_names if model_names is not None else self.model_names()
235
- return tuple(self.data.models[name] for name in model_names)
236
-
237
- @functools.partial(oop.jax_tf.method_rw)
238
- def set_gravity(self, gravity: jtp.Vector) -> None:
239
- """
240
- Set the gravity vector to all the simulated models.
241
-
242
- Args:
243
- gravity: The 3D gravity vector.
244
- """
245
-
246
- gravity = jnp.array(gravity, dtype=float)
247
-
248
- if gravity.size != 3:
249
- raise ValueError(gravity)
250
-
251
- self.data.gravity = gravity
252
-
253
- for model in self.data.models.values():
254
- model.physics_model.set_gravity(gravity=gravity)
255
-
256
- @functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False)
257
- def insert_model_from_description(
258
- self,
259
- model_description: Union[pathlib.Path, str, rod.Model],
260
- model_name: str | None = None,
261
- considered_joints: List[str] | None = None,
262
- ) -> Model:
263
- """
264
- Insert a model from a model description.
265
-
266
- Args:
267
- model_description: A path to an SDF/URDF file, a string containing its content, or a pre-parsed/pre-built rod model.
268
- model_name: The optional name of the model that overrides the one in the description.
269
- considered_joints: Optional list of joints to consider. It is also useful to specify the joint serialization.
270
-
271
- Returns:
272
- The newly inserted model.
273
- """
274
-
275
- if self.vectorized:
276
- raise RuntimeError("Cannot insert a model in a vectorized simulation")
277
-
278
- # Build the model from the given model description
279
- model = jaxsim.high_level.model.Model.build_from_model_description(
280
- model_description=model_description,
281
- model_name=model_name,
282
- vel_repr=self.velocity_representation,
283
- considered_joints=considered_joints,
284
- )
285
-
286
- # Make sure the model is not already part of the simulation
287
- if model.name() in self.model_names():
288
- msg = f"Model '{model.name()}' is already part of the simulation"
289
- raise ValueError(msg)
290
-
291
- # Insert the model
292
- self.data.models[model.name()] = model
293
-
294
- # Return the newly inserted model
295
- return self.data.models[model.name()]
296
-
297
- @functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False)
298
- def insert_model_from_sdf(
299
- self,
300
- sdf: Union[pathlib.Path, str],
301
- model_name: str | None = None,
302
- considered_joints: List[str] | None = None,
303
- ) -> Model:
304
- """
305
- Insert a model from an SDF resource.
306
- """
307
-
308
- msg = "JaxSim.{} is deprecated, use JaxSim.{} instead."
309
- logging.warning(
310
- msg=msg.format("insert_model_from_sdf", "insert_model_from_description")
311
- )
312
-
313
- return self.insert_model_from_description(
314
- model_description=sdf,
315
- model_name=model_name,
316
- considered_joints=considered_joints,
317
- )
318
-
319
- @functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False)
320
- def insert_model(
321
- self,
322
- model_description: descriptions.ModelDescription,
323
- model_name: str | None = None,
324
- ) -> Model:
325
- """
326
- Insert a model from a model description object.
327
-
328
- Args:
329
- model_description: The model description object.
330
- model_name: Optional name of the model to insert.
331
-
332
- Returns:
333
- The newly inserted model.
334
- """
335
-
336
- if self.vectorized:
337
- raise RuntimeError("Cannot insert a model in a vectorized simulation")
338
-
339
- model_name = model_name if model_name is not None else model_description.name
340
-
341
- if model_name in self.model_names():
342
- msg = f"Model '{model_name}' is already part of the simulation"
343
- raise ValueError(msg)
344
-
345
- # Build the physics model the model description
346
- physics_model = PhysicsModel.build_from(
347
- model_description=model_description, gravity=self.gravity()
348
- )
349
-
350
- # Build the high-level model from the physics model
351
- model = jaxsim.high_level.model.Model.build(
352
- model_name=model_name,
353
- physics_model=physics_model,
354
- vel_repr=self.velocity_representation,
355
- )
356
-
357
- # Insert the model into the simulators
358
- self.data.models[model.name()] = model
359
-
360
- # Return the newly inserted model
361
- return self.data.models[model.name()]
362
-
363
- @functools.partial(
364
- oop.jax_tf.method_rw,
365
- jit=False,
366
- validate=False,
367
- static_argnames=["model_name"],
368
- )
369
- def remove_model(self, model_name: str) -> None:
370
- """
371
- Remove a model from the simulator.
372
-
373
- Args:
374
- model_name: The name of the model to remove.
375
- """
376
-
377
- if model_name not in self.model_names():
378
- msg = f"Model '{model_name}' is not part of the simulation"
379
- raise ValueError(msg)
380
-
381
- _ = self.data.models.pop(model_name)
382
-
383
- @functools.partial(oop.jax_tf.method_rw, vmap_in_axes=(0, None))
384
- def step(self, clear_inputs: bool = False) -> Dict[str, StepData]:
385
- """
386
- Advance the simulation by one step.
387
-
388
- Args:
389
- clear_inputs: Zero the inputs of the models after the integration.
390
-
391
- Returns:
392
- A dictionary containing the StepData of all models.
393
- """
394
-
395
- # Compute the initial and final time of the integration as integers
396
- t0_ns = jnp.array(self.data.time_ns, dtype=jnp.uint64)
397
- dt_ns = jnp.array(self.step_size_ns * self.steps_per_run, dtype=jnp.uint64)
398
-
399
- # Compute the final time using integer arithmetics
400
- tf_ns = t0_ns + dt_ns
401
-
402
- # We collect the StepData of all models
403
- step_data = {}
404
-
405
- for model in self.models():
406
- # Integrate individually all models and collect their StepData.
407
- # We use the context manager to make sure that the PyTree of the models
408
- # never changes, so that it never triggers JIT recompilations.
409
- with model.editable(validate=True) as integrated_model:
410
- step_data[model.name()] = integrated_model.integrate(
411
- t0=jnp.array(t0_ns, dtype=float) / 1e9,
412
- tf=jnp.array(tf_ns, dtype=float) / 1e9,
413
- sub_steps=self.steps_per_run,
414
- integrator_type=self.integrator_type,
415
- terrain=self.data.terrain,
416
- contact_parameters=self.data.contact_parameters,
417
- clear_inputs=clear_inputs,
418
- )
419
-
420
- self.data.models[model.name()].data = integrated_model.data
421
-
422
- # Store the final time
423
- self.data.time_ns += dt_ns
424
-
425
- return step_data
426
-
427
- @functools.partial(
428
- oop.jax_tf.method_ro,
429
- static_argnames=["horizon_steps"],
430
- vmap_in_axes=(0, None, 0, None),
431
- )
432
- def step_over_horizon(
433
- self,
434
- horizon_steps: jtp.Int,
435
- callback_handler: (
436
- Union["scb.SimulatorCallback", "scb.CallbackHandler"] | None
437
- ) = None,
438
- clear_inputs: jtp.Bool = False,
439
- ) -> Union[
440
- "JaxSim",
441
- tuple["JaxSim", tuple["scb.SimulatorCallback", tuple[jtp.PyTree, jtp.PyTree]]],
442
- ]:
443
- """
444
- Advance the simulation by a given number of steps.
445
-
446
- Args:
447
- horizon_steps: The number of steps to advance the simulation.
448
- callback_handler: A callback handler to inject custom login in the simulation loop.
449
- clear_inputs: Zero the inputs of the models after the integration.
450
-
451
- Returns:
452
- The updated simulator if no callback handler is provided, otherwise a tuple
453
- containing the updated simulator and a tuple containing callback data.
454
- The optional callback data is a tuple containing the updated callback object,
455
- the produced pre-step output, and the produced post-step output.
456
- """
457
-
458
- # Process a mutable copy of the simulator
459
- original_mutability = self._mutability()
460
- sim = self.copy().mutable(validate=True)
461
-
462
- # Helper to get callbacks from the handler
463
- get_cb = lambda h, cb_name: (
464
- getattr(h, cb_name) if h is not None and hasattr(h, cb_name) else None
465
- )
466
-
467
- # Get the callbacks
468
- configure_cb: Optional[scb.ConfigureCallbackSignature] = get_cb(
469
- h=callback_handler, cb_name="configure_cb"
470
- )
471
- pre_step_cb: Optional[scb.PreStepCallbackSignature] = get_cb(
472
- h=callback_handler, cb_name="pre_step_cb"
473
- )
474
- post_step_cb: Optional[scb.PostStepCallbackSignature] = get_cb(
475
- h=callback_handler, cb_name="post_step_cb"
476
- )
477
-
478
- # Callback: configuration
479
- sim = configure_cb(sim) if configure_cb is not None else sim
480
-
481
- # Initialize the carry
482
- Carry = tuple[JaxSim, scb.CallbackHandler]
483
- carry_init: Carry = (sim, callback_handler)
484
-
485
- def body_fun(
486
- carry: Carry, xs: None
487
- ) -> tuple[Carry, tuple[jtp.PyTree, jtp.PyTree]]:
488
- sim, callback_handler = carry
489
-
490
- # Make sure to pass a mutable version of the simulator to the callbacks
491
- sim = sim.mutable(validate=True)
492
-
493
- # Callback: pre-step
494
- sim, out_pre_step = (
495
- pre_step_cb(sim) if pre_step_cb is not None else (sim, None)
496
- )
497
-
498
- # Integrate all models
499
- step_data = sim.step(clear_inputs=clear_inputs)
500
-
501
- # Callback: post-step
502
- sim, out_post_step = (
503
- post_step_cb(sim, step_data)
504
- if post_step_cb is not None
505
- else (sim, None)
506
- )
507
-
508
- # Pack the carry
509
- carry = (sim, callback_handler)
510
-
511
- return carry, (out_pre_step, out_post_step)
512
-
513
- # Integrate over the given horizon
514
- (sim, callback_handler), (
515
- out_pre_step_horizon,
516
- out_post_step_horizon,
517
- ) = jax.lax.scan(f=body_fun, init=carry_init, xs=None, length=horizon_steps)
518
-
519
- # Enforce original mutability of the entire simulator
520
- sim._set_mutability(original_mutability)
521
-
522
- return (
523
- sim
524
- if callback_handler is None
525
- else (
526
- sim,
527
- (callback_handler, (out_pre_step_horizon, out_post_step_horizon)),
528
- )
529
- )
530
-
531
- def vectorize(self: Self, batch_size: int) -> Self:
532
- """
533
- Inherit docs.
534
- """
535
-
536
- jaxsim_vec: JaxSim = super().vectorize(batch_size=batch_size) # noqa
537
-
538
- # We need to manually specify the batch size of the handled models
539
- with jaxsim_vec.mutable_context(mutability=Mutability.MUTABLE):
540
- for model in jaxsim_vec.models():
541
- model.batch_size = batch_size
542
-
543
- return jaxsim_vec