jaxsim 0.3.1.dev17__py3-none-any.whl → 0.3.1.dev21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/_version.py +2 -2
- jaxsim/exceptions.py +63 -0
- {jaxsim-0.3.1.dev17.dist-info → jaxsim-0.3.1.dev21.dist-info}/METADATA +1 -1
- {jaxsim-0.3.1.dev17.dist-info → jaxsim-0.3.1.dev21.dist-info}/RECORD +7 -6
- {jaxsim-0.3.1.dev17.dist-info → jaxsim-0.3.1.dev21.dist-info}/LICENSE +0 -0
- {jaxsim-0.3.1.dev17.dist-info → jaxsim-0.3.1.dev21.dist-info}/WHEEL +0 -0
- {jaxsim-0.3.1.dev17.dist-info → jaxsim-0.3.1.dev21.dist-info}/top_level.txt +0 -0
jaxsim/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.3.1.
|
16
|
-
__version_tuple__ = version_tuple = (0, 3, 1, '
|
15
|
+
__version__ = version = '0.3.1.dev21'
|
16
|
+
__version_tuple__ = version_tuple = (0, 3, 1, 'dev21')
|
jaxsim/exceptions.py
ADDED
@@ -0,0 +1,63 @@
|
|
1
|
+
import jax
|
2
|
+
|
3
|
+
|
4
|
+
def raise_if(
|
5
|
+
condition: bool | jax.Array, exception: type, msg: str, *args, **kwargs
|
6
|
+
) -> None:
|
7
|
+
"""
|
8
|
+
Raise a host-side exception if a condition is met. Useful in jit-compiled functions.
|
9
|
+
|
10
|
+
Args:
|
11
|
+
condition:
|
12
|
+
The boolean condition of the evaluated expression that triggers
|
13
|
+
the exception during runtime.
|
14
|
+
exception: The type of exception to raise.
|
15
|
+
msg:
|
16
|
+
The message to display when the exception is raised. The message can be a
|
17
|
+
format string (fmt), whose fields are filled with the args and kwargs.
|
18
|
+
"""
|
19
|
+
|
20
|
+
# Check early that the format string is well-formed.
|
21
|
+
try:
|
22
|
+
_ = msg.format(*args, **kwargs)
|
23
|
+
except Exception as e:
|
24
|
+
msg = "Error in formatting exception message with args={} and kwargs={}"
|
25
|
+
raise ValueError(msg.format(args, kwargs)) from e
|
26
|
+
|
27
|
+
def _raise_exception(condition: bool, *args, **kwargs) -> None:
|
28
|
+
"""The function called by the JAX callback."""
|
29
|
+
|
30
|
+
if condition:
|
31
|
+
raise exception(msg.format(*args, **kwargs))
|
32
|
+
|
33
|
+
def _callback(args, kwargs) -> None:
|
34
|
+
"""The function that calls the JAX callback, executed only when needed."""
|
35
|
+
|
36
|
+
jax.debug.callback(_raise_exception, condition, *args, **kwargs)
|
37
|
+
|
38
|
+
# Since running a callable on the host is expensive, we prevent its execution
|
39
|
+
# if the condition is False with a low-level conditional expression.
|
40
|
+
def _run_callback_only_if_condition_is_true(*args, **kwargs) -> None:
|
41
|
+
return jax.lax.cond(
|
42
|
+
condition,
|
43
|
+
_callback,
|
44
|
+
lambda args, kwargs: None,
|
45
|
+
args,
|
46
|
+
kwargs,
|
47
|
+
)
|
48
|
+
|
49
|
+
return _run_callback_only_if_condition_is_true(*args, **kwargs)
|
50
|
+
|
51
|
+
|
52
|
+
def raise_runtime_error_if(
|
53
|
+
condition: bool | jax.Array, msg: str, *args, **kwargs
|
54
|
+
) -> None:
|
55
|
+
|
56
|
+
return raise_if(condition, RuntimeError, msg, *args, **kwargs)
|
57
|
+
|
58
|
+
|
59
|
+
def raise_value_error_if(
|
60
|
+
condition: bool | jax.Array, msg: str, *args, **kwargs
|
61
|
+
) -> None:
|
62
|
+
|
63
|
+
return raise_if(condition, ValueError, msg, *args, **kwargs)
|
@@ -1,5 +1,6 @@
|
|
1
1
|
jaxsim/__init__.py,sha256=xzuTuZrgKdWLqqDzbvqzm2cJrEtAbepOeUqDu7ByVek,2621
|
2
|
-
jaxsim/_version.py,sha256=
|
2
|
+
jaxsim/_version.py,sha256=fSekabX0ZEIHwkdp0Sa0iQ2H7hhPzvja3dhu7EFiX4I,426
|
3
|
+
jaxsim/exceptions.py,sha256=8_h8iqL8DgNR754dR8SZiQ7361GR5V1sUk3ZuZCHw1Q,2069
|
3
4
|
jaxsim/logging.py,sha256=c4zhwBKf9eAYAHVp62kTEllqdsZgh0K-kPKVy8L3elU,1584
|
4
5
|
jaxsim/typing.py,sha256=cl7HHQCeP3mHmtF6EuQZcCjGvDmc_AryMWntP_lRBGg,722
|
5
6
|
jaxsim/api/__init__.py,sha256=8eV22t2S3UwNyCg8karPetG1dmX1VDBXkyv28_FwNQA,210
|
@@ -58,8 +59,8 @@ jaxsim/utils/__init__.py,sha256=Y5zyoRevl3EMVQadhZ4EtSwTEkDt2vcnFoRhPJjKTZ0,215
|
|
58
59
|
jaxsim/utils/jaxsim_dataclass.py,sha256=h26timZ_XrBL_Q_oymv-DkQd-EcUiHn8QexAaZXBY9c,11396
|
59
60
|
jaxsim/utils/tracing.py,sha256=KDMoyVPlu2NJvFkhtZwq5AkqMMgajt3munvJom-vEjQ,650
|
60
61
|
jaxsim/utils/wrappers.py,sha256=QIJitSoljrKR_U4T3ewCJPT3DTh-tPZsRsg0t_MH93E,3896
|
61
|
-
jaxsim-0.3.1.
|
62
|
-
jaxsim-0.3.1.
|
63
|
-
jaxsim-0.3.1.
|
64
|
-
jaxsim-0.3.1.
|
65
|
-
jaxsim-0.3.1.
|
62
|
+
jaxsim-0.3.1.dev21.dist-info/LICENSE,sha256=eaYdFmdeMbiIoIiPzEK0MjP1S9wtFXjXNR5er49uLR0,1546
|
63
|
+
jaxsim-0.3.1.dev21.dist-info/METADATA,sha256=wtxQdWa5FFEqYdZx81i-VgNk7DKBY6YMQAXn5_1ctMY,9739
|
64
|
+
jaxsim-0.3.1.dev21.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
65
|
+
jaxsim-0.3.1.dev21.dist-info/top_level.txt,sha256=LxGMA8FLtXjQ6oI7N5gd_R_oSUHxpXxUEOfT1xS_ni0,7
|
66
|
+
jaxsim-0.3.1.dev21.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|