jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +64 -30
  24. jaxsim/math/cross.py +18 -9
  25. jaxsim/math/inertia.py +11 -9
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +59 -25
  28. jaxsim/math/rotation.py +30 -24
  29. jaxsim/math/skew.py +18 -7
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev5.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev5.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev5.dist-info}/top_level.txt +0 -0
@@ -1,543 +0,0 @@
1
- import dataclasses
2
- import functools
3
- import pathlib
4
- from typing import Dict, List, Optional, Union
5
-
6
- try:
7
- from typing import Self
8
- except ImportError:
9
- from typing_extensions import Self
10
-
11
- import jax
12
- import jax.numpy as jnp
13
- import jax_dataclasses
14
- import rod
15
- from jax_dataclasses import Static
16
-
17
- import jaxsim.high_level
18
- import jaxsim.physics
19
- import jaxsim.typing as jtp
20
- from jaxsim import logging
21
- from jaxsim.high_level.common import VelRepr
22
- from jaxsim.high_level.model import Model, StepData
23
- from jaxsim.parsers import descriptions
24
- from jaxsim.physics.algos.soft_contacts import SoftContactsParams
25
- from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
26
- from jaxsim.physics.model.physics_model import PhysicsModel
27
- from jaxsim.utils import Mutability, Vmappable, oop
28
-
29
- from . import simulator_callbacks as scb
30
- from .ode_integration import IntegratorType
31
-
32
-
33
@jax_dataclasses.pytree_dataclass
class SimulatorData(Vmappable):
    """
    State of the JaxSim simulator.

    Being a PyTree, it can be threaded through JAX transformations and
    updated in a functional programming style.
    """

    # Current simulation time, kept as an integer number of nanoseconds so that
    # repeated increments do not accumulate floating-point approximation errors.
    time_ns: jtp.Int = dataclasses.field(
        default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
    )

    # Terrain passed to every model integration, and the soft-contacts parameters.
    terrain: Terrain = dataclasses.field(default_factory=FlatTerrain)
    contact_parameters: SoftContactsParams = dataclasses.field(
        default_factory=SoftContactsParams
    )

    # Simulated models, keyed by their name.
    models: Dict[str, Model] = dataclasses.field(default_factory=dict)

    # Gravity vector used by default (individual models may override it).
    gravity: jtp.Vector = dataclasses.field(
        default_factory=jaxsim.physics.default_gravity
    )
59
-
60
-
61
@jax_dataclasses.pytree_dataclass
class JaxSim(Vmappable):
    """The JaxSim simulator."""

    # Step size stored in ns in order to prevent floats approximation
    step_size_ns: Static[jtp.Int] = dataclasses.field(
        default_factory=lambda: jnp.array(1_000_000, dtype=jnp.uint64)
    )

    # Number of sub-steps performed at each integration step.
    # Note: there is no collision detection performed in sub-steps.
    steps_per_run: Static[jtp.Int] = dataclasses.field(default=1)

    # Default velocity representation (could be overridden for individual models)
    velocity_representation: Static[VelRepr] = dataclasses.field(
        default=VelRepr.Inertial
    )

    # Integrator type.
    # NOTE(review): the dataclass default is EulerForward while `build()` defaults
    # to EulerSemiImplicit, so the two construction paths differ — confirm intent.
    integrator_type: Static[IntegratorType] = dataclasses.field(
        default=IntegratorType.EulerForward
    )

    # Simulator data
    data: SimulatorData = dataclasses.field(default_factory=lambda: SimulatorData())

    @staticmethod
    def build(
        step_size: jtp.Float,
        steps_per_run: jtp.Int = 1,
        velocity_representation: VelRepr = VelRepr.Inertial,
        integrator_type: IntegratorType = IntegratorType.EulerSemiImplicit,
        simulator_data: SimulatorData | None = None,
    ) -> "JaxSim":
        """
        Build a JaxSim simulator object.

        Args:
            step_size: The integration step size in seconds.
            steps_per_run: Number of sub-steps performed at each integration step.
            velocity_representation: Default velocity representation of simulated models.
            integrator_type: Type of integrator used for integrating the equations of motion.
            simulator_data: Optional simulator data to initialize the simulator state.

        Returns:
            The JaxSim simulator object.
        """

        # Round to the closest integer number of nanoseconds. A plain float-to-int
        # cast truncates, so a product evaluating to e.g. 1_499_999.9999999998
        # would silently become 1_499_999 ns instead of 1_500_000 ns.
        step_size_ns = round(float(step_size) * 1e9)

        return JaxSim(
            step_size_ns=jnp.array(step_size_ns, dtype=jnp.uint64),
            steps_per_run=int(steps_per_run),
            velocity_representation=velocity_representation,
            integrator_type=integrator_type,
            data=simulator_data if simulator_data is not None else SimulatorData(),
        )

    @functools.partial(
        oop.jax_tf.method_rw, static_argnames=["remove_models"], validate=False
    )
    def reset(self, remove_models: bool = True) -> None:
        """
        Reset the simulator.

        The simulation time is always set back to zero.

        Args:
            remove_models: Flag indicating whether to remove all models from the simulator.
                If False, the models are kept but their state is reset.
        """

        self.data.time_ns = jnp.zeros_like(self.data.time_ns)

        if remove_models:
            self.data.models = {}
        else:
            # `zero()` is called only for its side effect on each model.
            for model in self.models():
                model.zero()

    @functools.partial(oop.jax_tf.method_rw, jit=False)
    def set_step_size(self, step_size: float) -> None:
        """
        Set the integration step size.

        Args:
            step_size: The integration step size in seconds.
        """

        # Round rather than truncate when converting seconds to nanoseconds.
        self.step_size_ns = jnp.array(
            round(float(step_size) * 1e9), dtype=jnp.uint64
        )

    @functools.partial(oop.jax_tf.method_ro, jit=False)
    def step_size(self) -> jtp.Float:
        """
        Get the integration step size.

        Returns:
            The integration step size in seconds.
        """

        return jnp.array(self.step_size_ns / 1e9, dtype=float)

    @functools.partial(oop.jax_tf.method_ro)
    def dt(self) -> jtp.Float:
        """
        Return the time advanced by a single call to `step`, in seconds.

        Returns:
            The step size multiplied by the number of sub-steps per run, in seconds.
        """

        return jnp.array((self.step_size_ns * self.steps_per_run) / 1e9, dtype=float)

    @functools.partial(oop.jax_tf.method_ro)
    def time(self) -> jtp.Float:
        """
        Return the current simulation time in seconds.

        Returns:
            The current simulation time in seconds.
        """

        return jnp.array(self.data.time_ns / 1e9, dtype=float)

    @functools.partial(oop.jax_tf.method_ro)
    def gravity(self) -> jtp.Vector:
        """
        Return the 3D gravity vector.

        Returns:
            The 3D gravity vector.
        """

        return jnp.array(self.data.gravity, dtype=float)

    @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
    def model_names(self) -> tuple[str, ...]:
        """
        Return the list of model names.

        Returns:
            The names of the simulated models, in insertion order.
        """

        return tuple(self.data.models.keys())

    @functools.partial(
        oop.jax_tf.method_ro, static_argnames=["model_name"], jit=False, vmap=False
    )
    def get_model(self, model_name: str) -> Model:
        """
        Return the model with the given name.

        Args:
            model_name: The name of the model to return.

        Returns:
            The model with the given name.

        Raises:
            ValueError: If no model with the given name is part of the simulation.
        """

        if model_name not in self.data.models:
            raise ValueError(f"Failed to find model '{model_name}'")

        return self.data.models[model_name]

    @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False)
    def models(self, model_names: tuple[str, ...] | None = None) -> tuple[Model, ...]:
        """
        Return the simulated models.

        Args:
            model_names: Optional list of model names to return.
                If None, all models are returned.

        Returns:
            The list of simulated models.
        """

        model_names = model_names if model_names is not None else self.model_names()
        return tuple(self.data.models[name] for name in model_names)

    @functools.partial(oop.jax_tf.method_rw)
    def set_gravity(self, gravity: jtp.Vector) -> None:
        """
        Set the gravity vector to all the simulated models.

        Args:
            gravity: The 3D gravity vector.

        Raises:
            ValueError: If the gravity vector does not have exactly 3 elements.
        """

        gravity = jnp.array(gravity, dtype=float)

        if gravity.size != 3:
            raise ValueError(
                f"Expected a 3D gravity vector, got an array with shape {gravity.shape}"
            )

        self.data.gravity = gravity

        for model in self.data.models.values():
            model.physics_model.set_gravity(gravity=gravity)

    @functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False)
    def insert_model_from_description(
        self,
        model_description: Union[pathlib.Path, str, rod.Model],
        model_name: str | None = None,
        considered_joints: List[str] | None = None,
    ) -> Model:
        """
        Insert a model from a model description.

        Args:
            model_description: A path to an SDF/URDF file, a string containing its content, or a pre-parsed/pre-built rod model.
            model_name: The optional name of the model that overrides the one in the description.
            considered_joints: Optional list of joints to consider. It is also useful to specify the joint serialization.

        Returns:
            The newly inserted model.

        Raises:
            RuntimeError: If the simulator is vectorized.
            ValueError: If a model with the same name is already part of the simulation.
        """

        if self.vectorized:
            raise RuntimeError("Cannot insert a model in a vectorized simulation")

        # Build the model from the given model description
        model = jaxsim.high_level.model.Model.build_from_model_description(
            model_description=model_description,
            model_name=model_name,
            vel_repr=self.velocity_representation,
            considered_joints=considered_joints,
        )

        # Make sure the model is not already part of the simulation
        if model.name() in self.model_names():
            msg = f"Model '{model.name()}' is already part of the simulation"
            raise ValueError(msg)

        # Insert the model
        self.data.models[model.name()] = model

        # Return the newly inserted model
        return self.data.models[model.name()]

    @functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False)
    def insert_model_from_sdf(
        self,
        sdf: Union[pathlib.Path, str],
        model_name: str | None = None,
        considered_joints: List[str] | None = None,
    ) -> Model:
        """
        Insert a model from an SDF resource.

        Deprecated: logs a warning and forwards to `insert_model_from_description`.
        """

        msg = "JaxSim.{} is deprecated, use JaxSim.{} instead."
        logging.warning(
            msg=msg.format("insert_model_from_sdf", "insert_model_from_description")
        )

        return self.insert_model_from_description(
            model_description=sdf,
            model_name=model_name,
            considered_joints=considered_joints,
        )

    @functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False)
    def insert_model(
        self,
        model_description: descriptions.ModelDescription,
        model_name: str | None = None,
    ) -> Model:
        """
        Insert a model from a model description object.

        Args:
            model_description: The model description object.
            model_name: Optional name of the model to insert.

        Returns:
            The newly inserted model.

        Raises:
            RuntimeError: If the simulator is vectorized.
            ValueError: If a model with the same name is already part of the simulation.
        """

        if self.vectorized:
            raise RuntimeError("Cannot insert a model in a vectorized simulation")

        model_name = model_name if model_name is not None else model_description.name

        if model_name in self.model_names():
            msg = f"Model '{model_name}' is already part of the simulation"
            raise ValueError(msg)

        # Build the physics model the model description
        physics_model = PhysicsModel.build_from(
            model_description=model_description, gravity=self.gravity()
        )

        # Build the high-level model from the physics model
        model = jaxsim.high_level.model.Model.build(
            model_name=model_name,
            physics_model=physics_model,
            vel_repr=self.velocity_representation,
        )

        # Insert the model into the simulators
        self.data.models[model.name()] = model

        # Return the newly inserted model
        return self.data.models[model.name()]

    @functools.partial(
        oop.jax_tf.method_rw,
        jit=False,
        validate=False,
        static_argnames=["model_name"],
    )
    def remove_model(self, model_name: str) -> None:
        """
        Remove a model from the simulator.

        Args:
            model_name: The name of the model to remove.

        Raises:
            ValueError: If the model is not part of the simulation.
        """

        if model_name not in self.model_names():
            msg = f"Model '{model_name}' is not part of the simulation"
            raise ValueError(msg)

        _ = self.data.models.pop(model_name)

    @functools.partial(oop.jax_tf.method_rw, vmap_in_axes=(0, None))
    def step(self, clear_inputs: bool = False) -> Dict[str, StepData]:
        """
        Advance the simulation by one step.

        Args:
            clear_inputs: Zero the inputs of the models after the integration.

        Returns:
            A dictionary containing the StepData of all models.
        """

        # Compute the initial and final time of the integration as integers
        t0_ns = jnp.array(self.data.time_ns, dtype=jnp.uint64)
        dt_ns = jnp.array(self.step_size_ns * self.steps_per_run, dtype=jnp.uint64)

        # Compute the final time using integer arithmetics
        tf_ns = t0_ns + dt_ns

        # We collect the StepData of all models
        step_data = {}

        for model in self.models():
            # Integrate individually all models and collect their StepData.
            # We use the context manager to make sure that the PyTree of the models
            # never changes, so that it never triggers JIT recompilations.
            with model.editable(validate=True) as integrated_model:
                step_data[model.name()] = integrated_model.integrate(
                    t0=jnp.array(t0_ns, dtype=float) / 1e9,
                    tf=jnp.array(tf_ns, dtype=float) / 1e9,
                    sub_steps=self.steps_per_run,
                    integrator_type=self.integrator_type,
                    terrain=self.data.terrain,
                    contact_parameters=self.data.contact_parameters,
                    clear_inputs=clear_inputs,
                )

                self.data.models[model.name()].data = integrated_model.data

        # Store the final time
        self.data.time_ns += dt_ns

        return step_data

    @functools.partial(
        oop.jax_tf.method_ro,
        static_argnames=["horizon_steps"],
        vmap_in_axes=(0, None, 0, None),
    )
    def step_over_horizon(
        self,
        horizon_steps: jtp.Int,
        callback_handler: (
            Union["scb.SimulatorCallback", "scb.CallbackHandler"] | None
        ) = None,
        clear_inputs: jtp.Bool = False,
    ) -> Union[
        "JaxSim",
        tuple["JaxSim", tuple["scb.SimulatorCallback", tuple[jtp.PyTree, jtp.PyTree]]],
    ]:
        """
        Advance the simulation by a given number of steps.

        Args:
            horizon_steps: The number of steps to advance the simulation.
            callback_handler: A callback handler to inject custom logic in the simulation loop.
            clear_inputs: Zero the inputs of the models after the integration.

        Returns:
            The updated simulator if no callback handler is provided, otherwise a tuple
            containing the updated simulator and a tuple containing callback data.
            The optional callback data is a tuple containing the updated callback object,
            the produced pre-step output, and the produced post-step output.
        """

        # Process a mutable copy of the simulator
        original_mutability = self._mutability()
        sim = self.copy().mutable(validate=True)

        # Helper to get callbacks from the handler
        get_cb = lambda h, cb_name: (
            getattr(h, cb_name) if h is not None and hasattr(h, cb_name) else None
        )

        # Get the callbacks
        configure_cb: Optional[scb.ConfigureCallbackSignature] = get_cb(
            h=callback_handler, cb_name="configure_cb"
        )
        pre_step_cb: Optional[scb.PreStepCallbackSignature] = get_cb(
            h=callback_handler, cb_name="pre_step_cb"
        )
        post_step_cb: Optional[scb.PostStepCallbackSignature] = get_cb(
            h=callback_handler, cb_name="post_step_cb"
        )

        # Callback: configuration
        sim = configure_cb(sim) if configure_cb is not None else sim

        # Initialize the carry
        Carry = tuple[JaxSim, scb.CallbackHandler]
        carry_init: Carry = (sim, callback_handler)

        def body_fun(
            carry: Carry, xs: None
        ) -> tuple[Carry, tuple[jtp.PyTree, jtp.PyTree]]:
            sim, callback_handler = carry

            # Make sure to pass a mutable version of the simulator to the callbacks
            sim = sim.mutable(validate=True)

            # Callback: pre-step
            sim, out_pre_step = (
                pre_step_cb(sim) if pre_step_cb is not None else (sim, None)
            )

            # Integrate all models
            step_data = sim.step(clear_inputs=clear_inputs)

            # Callback: post-step
            sim, out_post_step = (
                post_step_cb(sim, step_data)
                if post_step_cb is not None
                else (sim, None)
            )

            # Pack the carry
            carry = (sim, callback_handler)

            return carry, (out_pre_step, out_post_step)

        # Integrate over the given horizon
        (sim, callback_handler), (
            out_pre_step_horizon,
            out_post_step_horizon,
        ) = jax.lax.scan(f=body_fun, init=carry_init, xs=None, length=horizon_steps)

        # Enforce original mutability of the entire simulator
        sim._set_mutability(original_mutability)

        return (
            sim
            if callback_handler is None
            else (
                sim,
                (callback_handler, (out_pre_step_horizon, out_post_step_horizon)),
            )
        )

    def vectorize(self: Self, batch_size: int) -> Self:
        """
        Inherit docs.
        """

        jaxsim_vec: JaxSim = super().vectorize(batch_size=batch_size)  # noqa

        # We need to manually specify the batch size of the handled models
        with jaxsim_vec.mutable_context(mutability=Mutability.MUTABLE):
            for model in jaxsim_vec.models():
                model.batch_size = batch_size

        return jaxsim_vec
@@ -1,79 +0,0 @@
1
- import abc
2
- from typing import Callable, Dict, Tuple
3
-
4
- import jaxsim.typing as jtp
5
- from jaxsim.high_level.model import StepData
6
-
7
# Forward reference to the simulator class, kept as a string (presumably to
# avoid a circular import, since the simulator module imports this one).
_JaxSimRef = "jaxsim.JaxSim"

# configure(sim) -> sim
ConfigureCallbackSignature = Callable[[_JaxSimRef], _JaxSimRef]

# pre_step(sim) -> (sim, pre-step output)
PreStepCallbackSignature = Callable[[_JaxSimRef], Tuple[_JaxSimRef, jtp.PyTree]]

# post_step(sim, step data of all models) -> (sim, post-step output)
PostStepCallbackSignature = Callable[
    [_JaxSimRef, Dict[str, StepData]], Tuple[_JaxSimRef, jtp.PyTree]
]
14
-
15
-
16
class SimulatorCallback(abc.ABC):
    """
    Common base class of all the simulator callbacks.

    It defines no behavior on its own; the configure, pre-step and post-step
    hooks are provided by its subclasses.
    """
22
-
23
-
24
class ConfigureCallback(SimulatorCallback):
    """
    Callback configuring the simulator before the first simulation step is taken.
    """

    @property
    def configure_cb(self) -> ConfigureCallbackSignature:
        """Return a callable forwarding the simulator to `configure`."""

        def _invoke_configure(sim):
            return self.configure(sim=sim)

        return _invoke_configure

    @abc.abstractmethod
    def configure(self, sim: "jaxsim.JaxSim") -> "jaxsim.JaxSim":
        """Return the configured simulator; to be implemented by subclasses."""
36
-
37
-
38
class PreStepCallback(SimulatorCallback):
    """
    Callback executed at every simulation step, before the dynamics is integrated.
    """

    @property
    def pre_step_cb(self) -> PreStepCallbackSignature:
        """Return a callable forwarding the simulator to `pre_step`."""

        def _invoke_pre_step(sim):
            return self.pre_step(sim=sim)

        return _invoke_pre_step

    @abc.abstractmethod
    def pre_step(self, sim: "jaxsim.JaxSim") -> Tuple["jaxsim.JaxSim", jtp.PyTree]:
        """Return the (possibly updated) simulator and a pre-step output PyTree."""
50
-
51
-
52
class PostStepCallback(SimulatorCallback):
    """
    Callback executed at every simulation step, after the dynamics is integrated.
    """

    @property
    def post_step_cb(self) -> PostStepCallbackSignature:
        """Return a callable forwarding the simulator and step data to `post_step`."""

        def _invoke_post_step(sim, step_data):
            return self.post_step(sim=sim, step_data=step_data)

        return _invoke_post_step

    @abc.abstractmethod
    def post_step(
        self, sim: "jaxsim.JaxSim", step_data: Dict[str, StepData]
    ) -> Tuple["jaxsim.JaxSim", jtp.PyTree]:
        """Return the (possibly updated) simulator and a post-step output PyTree."""
66
-
67
-
68
class CallbackHandler(ConfigureCallback, PreStepCallback, PostStepCallback):
    """
    Handler bundling all the simulator callbacks in a single object.

    Note:
        There are different simulation stages, each with an associated callback:
        - `configure`: runs before the first step is taken.
        - `pre_step`: runs at each step before integrating the dynamics and advancing the time.
        - `post_step`: runs at each step after the integration of the dynamics.
    """
@@ -1,15 +0,0 @@
1
- from typing import Tuple
2
-
3
- from jaxsim import logging
4
-
5
-
6
def check_valid_shape(
    what: str, shape: Tuple, expected_shape: Tuple, valid: bool
) -> bool:
    """
    Check that a shape matches the expected one.

    Args:
        what: Human-readable name of the checked quantity, used in messages.
        shape: The actual shape.
        expected_shape: The expected shape.
        valid: The validity flag accumulated so far by the caller.

    Returns:
        The ``valid`` argument, unchanged, when ``shape`` equals ``expected_shape``.

    Raises:
        RuntimeError: If ``shape`` differs from ``expected_shape``.
    """

    if shape != expected_shape:
        msg = f"Shape of {what} differs: {shape}, {expected_shape}"
        logging.debug(msg)
        # A bare `raise` outside an except block surfaces as a RuntimeError with
        # the unhelpful message "No active exception to reraise". Keep the same
        # exception type for callers, but report the actual mismatch.
        raise RuntimeError(msg)

    return valid
jaxsim/sixd/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from jaxlie import _se3 as se3
2
- from jaxlie import _so3 as so3