jaxsim 0.5.1.dev86__py3-none-any.whl → 0.5.1.dev91__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.5.1.dev86'
16
- __version_tuple__ = version_tuple = (0, 5, 1, 'dev86')
15
+ __version__ = version = '0.5.1.dev91'
16
+ __version_tuple__ = version_tuple = (0, 5, 1, 'dev91')
@@ -53,10 +53,6 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
53
53
  repr=False, hash=False, compare=False, kw_only=True
54
54
  )
55
55
 
56
- metadata: dict[str, Any] = dataclasses.field(
57
- default_factory=dict, repr=False, hash=False, compare=False, kw_only=True
58
- )
59
-
60
56
  @classmethod
61
57
  def build(
62
58
  cls: type[Self],
@@ -102,10 +98,7 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
102
98
 
103
99
  metadata = metadata if metadata is not None else {}
104
100
 
105
- with self.editable(validate=False) as integrator:
106
- integrator.metadata = metadata
107
-
108
- with integrator.mutable_context(mutability=Mutability.MUTABLE):
101
+ with self.mutable_context(mutability=Mutability.MUTABLE) as integrator:
109
102
  xf, metadata_step = integrator(x0, t0, dt, **kwargs)
110
103
 
111
104
  return (
@@ -315,6 +308,9 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
315
308
  b = self.b
316
309
  A = self.A
317
310
 
311
+ # Extract metadata from the kwargs.
312
+ metadata = kwargs.pop("metadata", {})
313
+
318
314
  # Close f over optional kwargs.
319
315
  f = lambda x, t: self.dynamics(x=x, t=t, **kwargs)
320
316
 
@@ -327,7 +323,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
327
323
  # or to use the previous state derivative (only integrators supporting FSAL).
328
324
  def get_ẋ0_and_aux_dict() -> tuple[StateDerivative, dict[str, Any]]:
329
325
  ẋ0, aux_dict = f(x0, t0)
330
- return self.metadata.get("dxdt0", ẋ0), aux_dict
326
+ return metadata.get("dxdt0", ẋ0), aux_dict
331
327
 
332
328
  # We use a `jax.lax.scan` to compile the `f` function only once.
333
329
  # Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code
@@ -381,7 +377,8 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
381
377
 
382
378
  # Update the FSAL property for the next iteration.
383
379
  if self.has_fsal:
384
- self.metadata["dxdt0"] = jax.tree.map(lambda l: l[self.index_of_fsal], K)
380
+ # Store the first derivative of the next step in the metadata.
381
+ metadata["dxdt0"] = jax.tree.map(lambda l: l[self.index_of_fsal], K)
385
382
 
386
383
  # Compute the output state.
387
384
  # Note that z contains as many new states as the rows of `b.T`.
@@ -394,7 +391,7 @@ class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]
394
391
  lambda xf: self.post_process_state(x0=x0, t0=t0, xf=xf, dt=dt)
395
392
  )(z)
396
393
 
397
- return z_transformed, aux_dict
394
+ return z_transformed, aux_dict | {"metadata": metadata}
398
395
 
399
396
  @staticmethod
400
397
  def butcher_tableau_is_valid(
@@ -14,7 +14,6 @@ from jax_dataclasses import Static
14
14
 
15
15
  import jaxsim.utils.tracing
16
16
  from jaxsim import typing as jtp
17
- from jaxsim.utils import Mutability
18
17
 
19
18
  from .common import (
20
19
  ExplicitRungeKutta,
@@ -271,30 +270,27 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
271
270
  # Inject this key to signal that the integrator is initializing.
272
271
  # This is used to allocate the arrays of the metadata dictionary,
273
272
  # that are then filled with NaNs.
274
- integrator.metadata = {EmbeddedRungeKutta.InitializingKey: jnp.array(True)}
273
+ metadata = {EmbeddedRungeKutta.InitializingKey: jnp.array(True)}
275
274
 
276
275
  # Run a dummy call of the integrator.
277
276
  # It is used only to get the metadata so that we know the structure
278
277
  # of the corresponding pytree.
279
278
  _ = integrator(
280
- x0, jnp.array(t0, dtype=float), jnp.array(dt, dtype=float), **kwargs
279
+ x0,
280
+ jnp.array(t0, dtype=float),
281
+ jnp.array(dt, dtype=float),
282
+ **(kwargs | {"metadata": metadata}),
281
283
  )
282
284
 
283
285
  # Remove the injected key.
284
- _ = integrator.metadata.pop(EmbeddedRungeKutta.InitializingKey)
286
+ _ = metadata.pop(EmbeddedRungeKutta.InitializingKey)
285
287
 
286
288
  # Make sure that all leafs of the dictionary are JAX arrays.
287
289
  # Also, since these are dummy parameters, set them all to NaN.
288
290
  metadata_after_init = jax.tree.map(
289
- lambda l: jnp.nan * jnp.zeros_like(l), integrator.metadata
291
+ lambda l: jnp.nan * jnp.zeros_like(l), metadata
290
292
  )
291
293
 
292
- # Store the zero parameters in the integrator.
293
- # When the integrator is stepped, this is used to check if the passed
294
- # parameters are valid.
295
- with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
296
- self.metadata = metadata_after_init
297
-
298
294
  return metadata_after_init
299
295
 
300
296
  def __call__(
@@ -307,7 +303,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
307
303
  # The metadata is a dictionary of float JAX arrays, that are initialized
308
304
  # with the right shape and filled with NaNs.
309
305
  # 2. During the first step, this method operates on the Nan-filled
310
- # `self.metadata` attribute, and it populates with the actual metadata.
306
+ # `metadata` argument, and it populates with the actual metadata.
311
307
  # 3. After the first step, this method operates on the actual metadata.
312
308
  #
313
309
  # In particular, we store the following information in the metadata:
@@ -318,8 +314,10 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
318
314
  # evaluate the dynamics at the final state of the previous step, that matches
319
315
  # the initial state of the current step.
320
316
  #
317
+ metadata = kwargs.pop("metadata", {})
318
+
321
319
  integrator_init = jnp.array(
322
- self.metadata.get(self.InitializingKey, False), dtype=bool
320
+ metadata.get(self.InitializingKey, False), dtype=bool
323
321
  )
324
322
 
325
323
  # Close f over optional kwargs.
@@ -335,24 +333,23 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
335
333
 
336
334
  # The value of dt0 is NaN (or, at least, it should be) only after initialization
337
335
  # and before the first step.
338
- self.metadata["dt0"], self.metadata["dxdt0"] = jax.lax.cond(
339
- pred=("dt0" in self.metadata)
340
- & ~jnp.isnan(self.metadata.get("dt0", 0.0)).any(),
336
+ metadata["dt0"], metadata["dxdt0"] = jax.lax.cond(
337
+ pred=("dt0" in metadata) & ~jnp.isnan(metadata.get("dt0", 0.0)).any(),
341
338
  true_fun=lambda metadata: (
342
339
  metadata.get("dt0", jnp.array(0.0, dtype=float)),
343
- self.metadata.get("dxdt0", f(x0, t0)[0]),
340
+ metadata.get("dxdt0", f(x0, t0)[0]),
344
341
  ),
345
342
  false_fun=lambda aux: estimate_step_size(
346
343
  x0=x0, t0=t0, f=f, order=p, atol=self.atol, rtol=self.rtol
347
344
  ),
348
- operand=self.metadata,
345
+ operand=metadata,
349
346
  )
350
347
 
351
348
  # Clip the estimated initial step size to the given bounds, if necessary.
352
- self.metadata["dt0"] = jnp.clip(
353
- self.metadata["dt0"],
354
- jnp.minimum(self.dt_min, self.metadata["dt0"]),
355
- jnp.minimum(self.dt_max, self.metadata["dt0"]),
349
+ metadata["dt0"] = jnp.clip(
350
+ metadata["dt0"],
351
+ jnp.minimum(self.dt_min, metadata["dt0"]),
352
+ jnp.minimum(self.dt_max, metadata["dt0"]),
356
353
  )
357
354
 
358
355
  # =========================================================
@@ -364,7 +361,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
364
361
  carry0: Carry = (
365
362
  x0,
366
363
  jnp.array(t0).astype(float),
367
- self.metadata,
364
+ metadata,
368
365
  jnp.array(0, dtype=int),
369
366
  jnp.array(False).astype(bool),
370
367
  )
@@ -392,9 +389,10 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
392
389
  # Run the underlying explicit RK integrator.
393
390
  # The output z contains multiple solutions (depending on the rows of b.T).
394
391
  with self.editable(validate=True) as integrator:
395
- integrator.metadata = metadata
396
- z, _ = integrator._compute_next_state(x0=x0, t0=t0, dt=Δt0, **kwargs)
397
- metadata_next = integrator.metadata
392
+ z, aux_dict = integrator._compute_next_state(
393
+ x0=x0, t0=t0, dt=Δt0, **kwargs
394
+ )
395
+ metadata_next = aux_dict["metadata"]
398
396
 
399
397
  # Extract the high-order solution xf and the low-order estimate x̂f.
400
398
  xf = jax.tree.map(lambda l: l[self.row_index_of_solution], z)
@@ -481,10 +479,10 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
481
479
  metadata_next,
482
480
  discarded_steps,
483
481
  ) = jax.lax.cond(
484
- pred=discarded_steps
485
- >= self.max_step_rejections | local_error
486
- <= 1.0 | Δt_next
487
- < self.dt_min | integrator_init,
482
+ pred=(discarded_steps >= self.max_step_rejections)
483
+ | (local_error <= 1.0)
484
+ | (Δt_next < self.dt_min)
485
+ | integrator_init,
488
486
  true_fun=accept_step,
489
487
  false_fun=reject_step,
490
488
  )
@@ -510,12 +508,7 @@ class EmbeddedRungeKutta(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
510
508
  init_val=carry0,
511
509
  )
512
510
 
513
- # Store the parameters.
514
- # They will be returned to the caller in a functional way in the step method.
515
- with self.mutable_context(mutability=Mutability.MUTABLE):
516
- self.metadata = metadata_tf
517
-
518
- return xf, {}
511
+ return xf, {"metadata": metadata_tf}
519
512
 
520
513
  @property
521
514
  def order_of_solution(self) -> int:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.5.1.dev86
3
+ Version: 0.5.1.dev91
4
4
  Summary: A differentiable physics engine and multibody dynamics library for control and robot learning.
5
5
  Author-email: Diego Ferigo <dgferigo@gmail.com>, Filippo Luca Ferretti <filippoluca.ferretti@outlook.com>
6
6
  Maintainer-email: Filippo Luca Ferretti <filippo.ferretti@iit.it>, Alessandro Croci <alessandro.croci@iit.it>
@@ -1,5 +1,5 @@
1
1
  jaxsim/__init__.py,sha256=opgtbhhd1kDsHI4H1vOd3loMPDRi884yQ3tohfFGfNc,3382
2
- jaxsim/_version.py,sha256=2nWfd54AgYYlO27pKe_wwyCtkKok4970qswX81HrrTc,426
2
+ jaxsim/_version.py,sha256=HHdXV3EXu0rha3QfUW2g4pSsGNWqjfhD2e1Qwed6NGk,426
3
3
  jaxsim/exceptions.py,sha256=Sq3qtqeiy-CK76new_W2KKQ-4MAzyOUK5j5pBLr4RPQ,2250
4
4
  jaxsim/logging.py,sha256=STI-D_upXZYX-ZezLrlJJ0UlD5YspST0vZ_DcIwkzO4,1553
5
5
  jaxsim/typing.py,sha256=2HXy9hgazPXjofi1vLQ09ZubPtgVmg80U9NKmZ6NYiI,761
@@ -17,9 +17,9 @@ jaxsim/api/ode.py,sha256=SIg7UKDkJxhS0FlMH6iqipn7WoQWjpQP6EFdvEBkdts,15429
17
17
  jaxsim/api/ode_data.py,sha256=ggF1AVaLW5QuXrfpNsFs-voVcW6gZkxK2Xe9GiDmou0,13755
18
18
  jaxsim/api/references.py,sha256=YkdZhRv8NoBC94qvpwn1w9_alVuxrfiZV5w5NHQIt-g,20737
19
19
  jaxsim/integrators/__init__.py,sha256=hxvOD-VK_mmd6v31wtC-nb28AYve1gLuZCNLV9wS-Kg,103
20
- jaxsim/integrators/common.py,sha256=ohISUnUWTaNHt2kweg1JyzwYGZgIH_wc-01qJWJsO80,18281
20
+ jaxsim/integrators/common.py,sha256=fnDqVIIXMYe2aiT_qnEhJSAeFuYRGhmElVCl7zPTrN8,18229
21
21
  jaxsim/integrators/fixed_step.py,sha256=KpjRd6hHtapxDoo6D1kyDrVDSHnke2TepI5grFH7_bM,2693
22
- jaxsim/integrators/variable_step.py,sha256=Tqz5ySSgyKak_k6cTXpmtqdPNaFlO7N6zj7jBIlChyM,22681
22
+ jaxsim/integrators/variable_step.py,sha256=HuUKudeFj0W7dvVATVNZK3uk1Nh_qKlGO_CDqXJFV14,22166
23
23
  jaxsim/math/__init__.py,sha256=8oPITEoGwgRcOeG8KxtqxPQ8b5uku1HNRMokpCoi9Tc,352
24
24
  jaxsim/math/adjoint.py,sha256=V7r5VrTCKPLEL5gavNSx9U7xSsrb11a5e4gWqJ2MuRo,4375
25
25
  jaxsim/math/cross.py,sha256=U7yEx_l75mSy5g6O-jsjBztApvxC3WaV4MpkS5tThu4,1330
@@ -66,8 +66,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
66
66
  jaxsim/utils/jaxsim_dataclass.py,sha256=TGmTQV2Lq7Q-2nLoAEaeNtkPa_qj0IKkdBm4COj46Os,11312
67
67
  jaxsim/utils/tracing.py,sha256=eEY28MZW0Lm_jJNt1NkFqZz0ek01tvhR46OXZYCo7tc,532
68
68
  jaxsim/utils/wrappers.py,sha256=ZY7olSORzZRvSzkdeNLj8yjwUIAt9L0Douwl7wItjpk,4008
69
- jaxsim-0.5.1.dev86.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
70
- jaxsim-0.5.1.dev86.dist-info/METADATA,sha256=GQtBCi6r0B5xji-AMKVBniNryixi5LbnvE4xvfJs45Q,17937
71
- jaxsim-0.5.1.dev86.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
72
- jaxsim-0.5.1.dev86.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
73
- jaxsim-0.5.1.dev86.dist-info/RECORD,,
69
+ jaxsim-0.5.1.dev91.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
70
+ jaxsim-0.5.1.dev91.dist-info/METADATA,sha256=PPDDHpeFVoVidQMnmZcNT_Spo8ryctjuFISZWYM_ZgI,17937
71
+ jaxsim-0.5.1.dev91.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
72
+ jaxsim-0.5.1.dev91.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
73
+ jaxsim-0.5.1.dev91.dist-info/RECORD,,