jaxsim 0.3.1.dev17__py3-none-any.whl → 0.3.1.dev21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaxsim/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.3.1.dev17'
16
- __version_tuple__ = version_tuple = (0, 3, 1, 'dev17')
15
+ __version__ = version = '0.3.1.dev21'
16
+ __version_tuple__ = version_tuple = (0, 3, 1, 'dev21')
jaxsim/exceptions.py ADDED
@@ -0,0 +1,63 @@
1
+ import jax
2
+
3
+
4
+ def raise_if(
5
+ condition: bool | jax.Array, exception: type, msg: str, *args, **kwargs
6
+ ) -> None:
7
+ """
8
+ Raise a host-side exception if a condition is met. Useful in jit-compiled functions.
9
+
10
+ Args:
11
+ condition:
12
+ The boolean condition of the evaluated expression that triggers
13
+ the exception during runtime.
14
+ exception: The type of exception to raise.
15
+ msg:
16
+ The message to display when the exception is raised. The message can be a
17
+ format string (fmt), whose fields are filled with the args and kwargs.
18
+ """
19
+
20
+ # Check early that the format string is well-formed.
21
+ try:
22
+ _ = msg.format(*args, **kwargs)
23
+ except Exception as e:
24
+ msg = "Error in formatting exception message with args={} and kwargs={}"
25
+ raise ValueError(msg.format(args, kwargs)) from e
26
+
27
+ def _raise_exception(condition: bool, *args, **kwargs) -> None:
28
+ """The function called by the JAX callback."""
29
+
30
+ if condition:
31
+ raise exception(msg.format(*args, **kwargs))
32
+
33
+ def _callback(args, kwargs) -> None:
34
+ """The function that calls the JAX callback, executed only when needed."""
35
+
36
+ jax.debug.callback(_raise_exception, condition, *args, **kwargs)
37
+
38
+ # Since running a callable on the host is expensive, we prevent its execution
39
+ # if the condition is False with a low-level conditional expression.
40
+ def _run_callback_only_if_condition_is_true(*args, **kwargs) -> None:
41
+ return jax.lax.cond(
42
+ condition,
43
+ _callback,
44
+ lambda args, kwargs: None,
45
+ args,
46
+ kwargs,
47
+ )
48
+
49
+ return _run_callback_only_if_condition_is_true(*args, **kwargs)
50
+
51
+
def raise_runtime_error_if(
    condition: bool | jax.Array, msg: str, *args, **kwargs
) -> None:
    """
    Raise a host-side ``RuntimeError`` if a condition is met.

    Convenience wrapper around ``raise_if`` with the exception type fixed to
    ``RuntimeError``.

    Args:
        condition: The boolean condition that triggers the exception.
        msg: The exception message, optionally a format string.
        *args: Positional values used to fill the fields of ``msg``.
        **kwargs: Keyword values used to fill the fields of ``msg``.
    """

    error_type = RuntimeError
    return raise_if(condition, error_type, msg, *args, **kwargs)
57
+
58
+
def raise_value_error_if(
    condition: bool | jax.Array, msg: str, *args, **kwargs
) -> None:
    """
    Raise a host-side ``ValueError`` if a condition is met.

    Convenience wrapper around ``raise_if`` with the exception type fixed to
    ``ValueError``.

    Args:
        condition: The boolean condition that triggers the exception.
        msg: The exception message, optionally a format string.
        *args: Positional values used to fill the fields of ``msg``.
        **kwargs: Keyword values used to fill the fields of ``msg``.
    """

    error_type = ValueError
    return raise_if(condition, error_type, msg, *args, **kwargs)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: jaxsim
3
- Version: 0.3.1.dev17
3
+ Version: 0.3.1.dev21
4
4
  Home-page: https://github.com/ami-iit/jaxsim
5
5
  Author: Diego Ferigo
6
6
  Author-email: diego.ferigo@iit.it
@@ -1,5 +1,6 @@
1
1
  jaxsim/__init__.py,sha256=xzuTuZrgKdWLqqDzbvqzm2cJrEtAbepOeUqDu7ByVek,2621
2
- jaxsim/_version.py,sha256=EQQfkY5WXMHFjdRnYAQqABGWC0VK4dlpuNh_wr1KxYA,426
2
+ jaxsim/_version.py,sha256=fSekabX0ZEIHwkdp0Sa0iQ2H7hhPzvja3dhu7EFiX4I,426
3
+ jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
3
4
  jaxsim/logging.py,sha256=c4zhwBKf9eAYAHVp62kTEllqdsZgh0K-kPKVy8L3elU,1584
4
5
  jaxsim/typing.py,sha256=cl7HHQCeP3mHmtF6EuQZcCjGvDmc_AryMWntP_lRBGg,722
5
6
  jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
@@ -58,8 +59,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
58
59
  jaxsim/utils/jaxsim_dataclass.py,sha256=h26timZ_XrBL_Q_oymv-DkQd-EcUiHn8QexAaZXBY9c,11396
59
60
  jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
60
61
  jaxsim/utils/wrappers.py,sha256=QIJitSoljrKR_U4T3ewCJPT3DTh-tPZsRsg0t_MH93E,3896
61
- jaxsim-0.3.1.dev17.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
62
- jaxsim-0.3.1.dev17.dist-info/METADATA,sha256=zRsMl96hDJt919NgrEuxkhye1S8X20bi_nWdPzJiptU,9739
63
- jaxsim-0.3.1.dev17.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
64
- jaxsim-0.3.1.dev17.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
65
- jaxsim-0.3.1.dev17.dist-info/RECORD,,
62
+ jaxsim-0.3.1.dev21.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
63
+ jaxsim-0.3.1.dev21.dist-info/METADATA,sha256=wtxQdWa5FFEqYdZx81i-VgNk7DKBY6YMQAXn5_1ctMY,9739
64
+ jaxsim-0.3.1.dev21.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
65
+ jaxsim-0.3.1.dev21.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
66
+ jaxsim-0.3.1.dev21.dist-info/RECORD,,