jaxsim 0.2.dev188__py3-none-any.whl → 0.2.dev364__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +3 -4
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -1
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +13 -2
- jaxsim/api/contact.py +120 -43
- jaxsim/api/data.py +112 -71
- jaxsim/api/joint.py +77 -36
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +150 -75
- jaxsim/api/model.py +542 -269
- jaxsim/api/ode.py +88 -72
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +12 -11
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +110 -24
- jaxsim/integrators/fixed_step.py +11 -67
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +93 -0
- jaxsim/parsers/descriptions/collision.py +14 -0
- jaxsim/parsers/descriptions/link.py +13 -2
- jaxsim/parsers/kinematic_graph.py +5 -0
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/{physics/algos → terrain}/terrain.py +4 -6
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -30
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
- jaxsim-0.2.dev364.dist-info/RECORD +64 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev188.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
@@ -1,79 +0,0 @@
|
|
1
|
-
import abc
|
2
|
-
from typing import Callable, Dict, Tuple
|
3
|
-
|
4
|
-
import jaxsim.typing as jtp
|
5
|
-
from jaxsim.high_level.model import StepData
|
6
|
-
|
7
|
-
# Callable run once before the first step: it receives the simulator and
# returns the (possibly modified) simulator.
ConfigureCallbackSignature = Callable[["jaxsim.JaxSim"], "jaxsim.JaxSim"]

# Callable run before every step: it receives the simulator and returns the
# (possibly modified) simulator paired with an arbitrary user pytree.
PreStepCallbackSignature = Callable[
    ["jaxsim.JaxSim"],
    Tuple["jaxsim.JaxSim", jtp.PyTree],
]

# Callable run after every step: in addition to the simulator it receives a
# dictionary of StepData objects keyed by string, and returns the (possibly
# modified) simulator paired with an arbitrary user pytree.
PostStepCallbackSignature = Callable[
    ["jaxsim.JaxSim", Dict[str, StepData]],
    Tuple["jaxsim.JaxSim", jtp.PyTree],
]
|
14
|
-
|
15
|
-
|
16
|
-
class SimulatorCallback(abc.ABC):
    """
    Common abstract root of every simulator callback.

    It declares no abstract method of its own and can therefore be instantiated;
    it only provides a shared parent type for the concrete callback classes.
    """
|
22
|
-
|
23
|
-
|
24
|
-
class ConfigureCallback(SimulatorCallback):
    """
    Callback holding the logic that configures the simulator before the first step.

    Subclasses implement `configure`; the `configure_cb` property exposes it as a
    plain callable matching `ConfigureCallbackSignature`.
    """

    @property
    def configure_cb(self) -> ConfigureCallbackSignature:
        """Return a callable forwarding its argument to `self.configure(sim=...)`."""
        return lambda sim: self.configure(sim=sim)

    @abc.abstractmethod
    def configure(self, sim: "jaxsim.JaxSim") -> "jaxsim.JaxSim":
        """
        Configure the simulator.

        Args:
            sim: The simulator to configure.

        Returns:
            The configured simulator.
        """
|
36
|
-
|
37
|
-
|
38
|
-
class PreStepCallback(SimulatorCallback):
    """
    Callback holding the logic executed before every simulation step.

    Subclasses implement `pre_step`; the `pre_step_cb` property exposes it as a
    plain callable matching `PreStepCallbackSignature`.
    """

    @property
    def pre_step_cb(self) -> PreStepCallbackSignature:
        """Return a callable forwarding its argument to `self.pre_step(sim=...)`."""
        return lambda sim: self.pre_step(sim=sim)

    @abc.abstractmethod
    def pre_step(self, sim: "jaxsim.JaxSim") -> Tuple["jaxsim.JaxSim", jtp.PyTree]:
        """
        Act on the simulator before a step is taken.

        Args:
            sim: The simulator.

        Returns:
            A tuple with the (possibly modified) simulator and a user pytree.
        """
|
50
|
-
|
51
|
-
|
52
|
-
class PostStepCallback(SimulatorCallback):
    """
    Callback holding the logic executed after every simulation step.

    Subclasses implement `post_step`; the `post_step_cb` property exposes it as a
    plain callable matching `PostStepCallbackSignature`.
    """

    @property
    def post_step_cb(self) -> PostStepCallbackSignature:
        """Return a callable forwarding both arguments to `self.post_step`."""
        return lambda sim, step_data: self.post_step(sim=sim, step_data=step_data)

    @abc.abstractmethod
    def post_step(
        self, sim: "jaxsim.JaxSim", step_data: Dict[str, StepData]
    ) -> Tuple["jaxsim.JaxSim", jtp.PyTree]:
        """
        Act on the simulator after a step has been taken.

        Args:
            sim: The simulator.
            step_data: A dictionary of StepData objects keyed by string.

        Returns:
            A tuple with the (possibly modified) simulator and a user pytree.
        """
|
66
|
-
|
67
|
-
|
68
|
-
class CallbackHandler(ConfigureCallback, PreStepCallback, PostStepCallback):
    """
    A class that handles callbacks for the simulator.

    It combines the configure, pre-step, and post-step callbacks. Concrete
    subclasses must implement all of `configure`, `pre_step`, and `post_step`,
    which are declared abstract by the parent classes.

    Note:
        There are different simulation stages with associated callbacks:
        - `configure`: runs before the first step is taken.
        - `pre_step`: runs at each step before integrating the dynamics and advancing the time.
        - `post_step`: runs at each step after the integration of the dynamics.
    """
|
jaxsim/simulation/utils.py
DELETED
@@ -1,15 +0,0 @@
|
|
1
|
-
from typing import Tuple
|
2
|
-
|
3
|
-
from jaxsim import logging
|
4
|
-
|
5
|
-
|
6
|
-
def check_valid_shape(
    what: str, shape: Tuple, expected_shape: Tuple, valid: bool
) -> bool:
    """
    Check that a shape matches the expected one.

    Args:
        what: A description of the checked quantity, used in the messages.
        shape: The shape to check.
        expected_shape: The expected shape.
        valid: The validity flag accumulated so far.

    Returns:
        The input `valid` flag, unchanged, when `shape == expected_shape`.

    Raises:
        RuntimeError: If `shape` differs from `expected_shape`.
    """

    if shape == expected_shape:
        return valid

    msg = f"Shape of {what} differs: {shape}, {expected_shape}"
    logging.debug(msg)

    # A bare `raise` outside an `except` block only reports "No active exception
    # to reraise". Raise an explicit error carrying the mismatch details instead,
    # keeping the RuntimeError type so that existing handlers still catch it.
    raise RuntimeError(msg)
|
jaxsim/sixd/__init__.py
DELETED
jaxsim/utils/oop.py
DELETED
@@ -1,536 +0,0 @@
|
|
1
|
-
import contextlib
|
2
|
-
import dataclasses
|
3
|
-
import functools
|
4
|
-
import inspect
|
5
|
-
import os
|
6
|
-
from typing import Any, Callable, Generator, TypeVar
|
7
|
-
|
8
|
-
import jax
|
9
|
-
import jax.flatten_util
|
10
|
-
from typing_extensions import ParamSpec
|
11
|
-
|
12
|
-
from jaxsim import logging
|
13
|
-
from jaxsim.utils import tracing
|
14
|
-
|
15
|
-
from . import Mutability, Vmappable
|
16
|
-
|
17
|
-
_P = ParamSpec("_P")
|
18
|
-
_R = TypeVar("_R")
|
19
|
-
|
20
|
-
|
21
|
-
class jax_tf:
|
22
|
-
"""
|
23
|
-
Class containing decorators applicable to methods of Vmappable objects.
|
24
|
-
"""
|
25
|
-
|
26
|
-
# Environment variables that can be used to disable the transformations
|
27
|
-
EnvVarOOP: str = "JAXSIM_OOP_DECORATORS"
|
28
|
-
EnvVarJitOOP: str = "JAXSIM_OOP_DECORATORS_JIT"
|
29
|
-
EnvVarVmapOOP: str = "JAXSIM_OOP_DECORATORS_VMAP"
|
30
|
-
EnvVarCacheOOP: str = "JAXSIM_OOP_DECORATORS_CACHE"
|
31
|
-
|
32
|
-
@staticmethod
|
33
|
-
def method_ro(
|
34
|
-
fn: Callable[_P, _R],
|
35
|
-
jit: bool = True,
|
36
|
-
static_argnames: tuple[str, ...] | list[str] = (),
|
37
|
-
vmap: bool | None = None,
|
38
|
-
vmap_in_axes: tuple[int, ...] | int | None = None,
|
39
|
-
vmap_out_axes: tuple[int, ...] | int | None = None,
|
40
|
-
) -> Callable[_P, _R]:
|
41
|
-
"""
|
42
|
-
Decorator for r/o methods of classes inheriting from Vmappable.
|
43
|
-
"""
|
44
|
-
|
45
|
-
return jax_tf.method(
|
46
|
-
fn=fn,
|
47
|
-
read_only=True,
|
48
|
-
validate=True,
|
49
|
-
jit_enabled=jit,
|
50
|
-
static_argnames=static_argnames,
|
51
|
-
vmap_enabled=vmap,
|
52
|
-
vmap_in_axes=vmap_in_axes,
|
53
|
-
vmap_out_axes=vmap_out_axes,
|
54
|
-
)
|
55
|
-
|
56
|
-
@staticmethod
|
57
|
-
def method_rw(
|
58
|
-
fn: Callable[_P, _R],
|
59
|
-
validate: bool = True,
|
60
|
-
jit: bool = True,
|
61
|
-
static_argnames: tuple[str, ...] | list[str] = (),
|
62
|
-
vmap: bool | None = None,
|
63
|
-
vmap_in_axes: tuple[int, ...] | int | None = None,
|
64
|
-
vmap_out_axes: tuple[int, ...] | int | None = None,
|
65
|
-
) -> Callable[_P, _R]:
|
66
|
-
"""
|
67
|
-
Decorator for r/w methods of classes inheriting from Vmappable.
|
68
|
-
"""
|
69
|
-
|
70
|
-
return jax_tf.method(
|
71
|
-
fn=fn,
|
72
|
-
read_only=False,
|
73
|
-
validate=validate,
|
74
|
-
jit_enabled=jit,
|
75
|
-
static_argnames=static_argnames,
|
76
|
-
vmap_enabled=vmap,
|
77
|
-
vmap_in_axes=vmap_in_axes,
|
78
|
-
vmap_out_axes=vmap_out_axes,
|
79
|
-
)
|
80
|
-
|
81
|
-
@staticmethod
|
82
|
-
def method(
|
83
|
-
fn: Callable[_P, _R],
|
84
|
-
read_only: bool = True,
|
85
|
-
validate: bool = True,
|
86
|
-
jit_enabled: bool = True,
|
87
|
-
static_argnames: tuple[str, ...] | list[str] = (),
|
88
|
-
vmap_enabled: bool | None = None,
|
89
|
-
vmap_in_axes: tuple[int, ...] | int | None = None,
|
90
|
-
vmap_out_axes: tuple[int, ...] | int | None = None,
|
91
|
-
):
|
92
|
-
"""
|
93
|
-
Decorator for methods of classes inheriting from Vmappable.
|
94
|
-
|
95
|
-
This decorator enables executing the methods on an object characterized by a
|
96
|
-
desired mutability, that is selected considering the r/o and validation flags.
|
97
|
-
It also allows to transform the method with the jit/vmap transformations.
|
98
|
-
If the Vmappable object is vectorized, the method is automatically vmapped, and
|
99
|
-
the in_axes are properly post-processed to simplify the combination with jit.
|
100
|
-
|
101
|
-
Args:
|
102
|
-
fn: The method to decorate.
|
103
|
-
read_only: Whether the method operates on a read-only object.
|
104
|
-
validate: Whether r/w methods should preserve the pytree structure.
|
105
|
-
jit_enabled: Whether to apply the jit transformation.
|
106
|
-
static_argnames: The names of the arguments that should be static.
|
107
|
-
vmap_enabled: Whether to apply the vmap transformation.
|
108
|
-
vmap_in_axes: The in_axes to use for the vmap transformation.
|
109
|
-
vmap_out_axes: The out_axes to use for the vmap transformation.
|
110
|
-
|
111
|
-
Returns:
|
112
|
-
The decorated method.
|
113
|
-
"""
|
114
|
-
|
115
|
-
@functools.wraps(fn)
|
116
|
-
def wrapper(*args: _P.args, **kwargs: _P.kwargs):
|
117
|
-
"""The wrapper function that is returned by this decorator."""
|
118
|
-
|
119
|
-
# Methods of classes inheriting from Vmappable decorated by this wrapper
|
120
|
-
# automatically support jit/vmap/mutability features when called standalone.
|
121
|
-
# However, when objects are arguments of plain functions transformed with
|
122
|
-
# jit/vmap, and decorated methods are called inside those functions, we need
|
123
|
-
# to disable this decorator to avoid double wrapping and execution errors.
|
124
|
-
# We do so by iterating over the arguments, and checking whether they are
|
125
|
-
# being traced by JAX.
|
126
|
-
for argument in list(args) + list(kwargs.values()):
|
127
|
-
try:
|
128
|
-
argument_flat, _ = jax.flatten_util.ravel_pytree(argument)
|
129
|
-
|
130
|
-
if tracing(argument_flat):
|
131
|
-
return fn(*args, **kwargs)
|
132
|
-
except:
|
133
|
-
continue
|
134
|
-
|
135
|
-
# ===============================================================
|
136
|
-
# Wrap fn so that jit/vmap/mutability transformations are applied
|
137
|
-
# ===============================================================
|
138
|
-
|
139
|
-
# Initialize the mutability of the instance over which the method is running.
|
140
|
-
# * In r/o methods, this approach prevents any type of mutation.
|
141
|
-
# * In r/w methods, this approach allows to catch early JIT recompilations
|
142
|
-
# caused by unwanted changes in the pytree structure.
|
143
|
-
if read_only:
|
144
|
-
mutability = Mutability.FROZEN
|
145
|
-
else:
|
146
|
-
mutability = (
|
147
|
-
Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
|
148
|
-
)
|
149
|
-
|
150
|
-
# Extract the class instance over which fn is called
|
151
|
-
instance: Vmappable = args[0]
|
152
|
-
assert isinstance(instance, Vmappable)
|
153
|
-
|
154
|
-
# Save the original mutability
|
155
|
-
original_mutability = instance._mutability()
|
156
|
-
|
157
|
-
# Inspect the environment to detect whether to enforce disabling jit/vmap
|
158
|
-
deco_on = jax_tf.env_var_on(jax_tf.EnvVarOOP)
|
159
|
-
jit_enabled_env = jax_tf.env_var_on(jax_tf.EnvVarJitOOP) and deco_on
|
160
|
-
vmap_enabled_env = jax_tf.env_var_on(jax_tf.EnvVarVmapOOP) and deco_on
|
161
|
-
|
162
|
-
# Allow disabling the cache of jit-compiled functions.
|
163
|
-
# It can be useful for debugging or testing purposes.
|
164
|
-
wrap_fn = (
|
165
|
-
jax_tf.wrap_fn
|
166
|
-
if jax_tf.env_var_on(jax_tf.EnvVarCacheOOP) and deco_on
|
167
|
-
else jax_tf.wrap_fn.__wrapped__
|
168
|
-
)
|
169
|
-
|
170
|
-
# Get the transformed function (possibly cached by functools.cache).
|
171
|
-
# Note that all the arguments of the following methods, when hashed, should
|
172
|
-
# uniquely identify the returned function so that a new function is built
|
173
|
-
# when arguments change and either jit or vmap have to be called again.
|
174
|
-
fn_db = wrap_fn(
|
175
|
-
fn=fn, # noqa
|
176
|
-
mutability=mutability,
|
177
|
-
jit=jit_enabled_env and jit_enabled,
|
178
|
-
static_argnames=tuple(static_argnames),
|
179
|
-
vmap=vmap_enabled_env
|
180
|
-
and (
|
181
|
-
vmap_enabled is True
|
182
|
-
or (vmap_enabled is None and instance.vectorized)
|
183
|
-
),
|
184
|
-
in_axes=vmap_in_axes,
|
185
|
-
out_axes=vmap_out_axes,
|
186
|
-
)
|
187
|
-
|
188
|
-
# Call the transformed (mutable/jit/vmap) method
|
189
|
-
out, obj = fn_db(*args, **kwargs)
|
190
|
-
|
191
|
-
if read_only:
|
192
|
-
# Restore the original mutability
|
193
|
-
instance._set_mutability(mutability=original_mutability)
|
194
|
-
|
195
|
-
return out
|
196
|
-
|
197
|
-
# =================================================================
|
198
|
-
# From here we assume that the wrapper is operating on a r/w method
|
199
|
-
# =================================================================
|
200
|
-
|
201
|
-
from jax_dataclasses._dataclasses import JDC_STATIC_MARKER
|
202
|
-
|
203
|
-
# Select the right runtime mutability. The only difference here is when a r/w
|
204
|
-
# method is called on a frozen object. In this case, we enable updating the
|
205
|
-
# pytree data and preserve its structure only if validation is enabled.
|
206
|
-
mutability_dict = {
|
207
|
-
Mutability.MUTABLE_NO_VALIDATION: Mutability.MUTABLE_NO_VALIDATION,
|
208
|
-
Mutability.MUTABLE: Mutability.MUTABLE,
|
209
|
-
Mutability.FROZEN: (
|
210
|
-
Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
|
211
|
-
),
|
212
|
-
}
|
213
|
-
|
214
|
-
# We need to replace all the dynamic leafs of the original instance with those
|
215
|
-
# computed by the functional transformation.
|
216
|
-
# We do so by iterating over the fields of the jax_dataclasses and ignoring
|
217
|
-
# all the fields that are marked as static.
|
218
|
-
# Caveats: https://github.com/ami-iit/jaxsim/pull/48#issuecomment-1746635121.
|
219
|
-
with instance.mutable_context(
|
220
|
-
mutability=mutability_dict[instance._mutability()]
|
221
|
-
):
|
222
|
-
for f in dataclasses.fields(instance): # noqa
|
223
|
-
if (
|
224
|
-
hasattr(f, "type")
|
225
|
-
and hasattr(f.type, "__metadata__")
|
226
|
-
and JDC_STATIC_MARKER in f.type.__metadata__
|
227
|
-
):
|
228
|
-
continue
|
229
|
-
|
230
|
-
try:
|
231
|
-
setattr(instance, f.name, getattr(obj, f.name))
|
232
|
-
except AssertionError as exc:
|
233
|
-
logging.debug(f"Old object:\n{getattr(instance, f.name)}")
|
234
|
-
logging.debug(f"New object:\n{getattr(obj, f.name)}")
|
235
|
-
raise RuntimeError(
|
236
|
-
f"Failed to update field '{f.name}'"
|
237
|
-
) from exc
|
238
|
-
|
239
|
-
return out
|
240
|
-
|
241
|
-
return wrapper
|
242
|
-
|
243
|
-
@staticmethod
|
244
|
-
@functools.cache
|
245
|
-
def wrap_fn(
|
246
|
-
fn: Callable,
|
247
|
-
mutability: Mutability,
|
248
|
-
jit: bool,
|
249
|
-
static_argnames: tuple[str, ...] | list[str],
|
250
|
-
vmap: bool,
|
251
|
-
in_axes: tuple[int, ...] | int | None,
|
252
|
-
out_axes: tuple[int, ...] | int | None,
|
253
|
-
) -> Callable:
|
254
|
-
"""
|
255
|
-
Transform a method with jit/vmap and execute it on an object characterized
|
256
|
-
by the desired mutability.
|
257
|
-
|
258
|
-
Note:
|
259
|
-
The method should take the object (self) as first argument.
|
260
|
-
|
261
|
-
Note:
|
262
|
-
This returned transformed method is cached by considering the hash of all
|
263
|
-
the arguments. It will re-apply jit/vmap transformations only if needed.
|
264
|
-
|
265
|
-
Args:
|
266
|
-
fn: The method to consider.
|
267
|
-
mutability: The mutability of the object on which the method is called.
|
268
|
-
jit: Whether to apply jit transformations.
|
269
|
-
static_argnames: The names of the arguments that should be considered static.
|
270
|
-
vmap: Whether to apply vmap transformations.
|
271
|
-
in_axes: The axes along which to vmap input arguments.
|
272
|
-
out_axes: The axes along which to vmap output arguments.
|
273
|
-
|
274
|
-
Note:
|
275
|
-
In order to simplify the application of vmap, we close the method arguments
|
276
|
-
over all the non-mapped input arguments. Furthermore, for improving the
|
277
|
-
compatibility with jit, we also close the vmap application over the static
|
278
|
-
arguments.
|
279
|
-
|
280
|
-
Returns:
|
281
|
-
The transformed method operating on an object with the desired mutability.
|
282
|
-
We maintain the same signature of the original method.
|
283
|
-
"""
|
284
|
-
|
285
|
-
# Extract the signature of the function
|
286
|
-
sig = inspect.signature(fn)
|
287
|
-
|
288
|
-
# All static arguments must be actual arguments of fn
|
289
|
-
for name in static_argnames:
|
290
|
-
if name not in sig.parameters:
|
291
|
-
raise ValueError(f"Static argument '{name}' not found in {fn}")
|
292
|
-
|
293
|
-
# If in_axes is a tuple, its dimension should match the number of arguments
|
294
|
-
if isinstance(in_axes, tuple) and len(in_axes) != len(sig.parameters):
|
295
|
-
msg = "The length of 'in_axes' must match the number of arguments ({})"
|
296
|
-
raise ValueError(msg.format(len(sig.parameters)))
|
297
|
-
|
298
|
-
# Check that static arguments are not mapped with vmap.
|
299
|
-
# This case would not work since static arguments are not traces and vmap need
|
300
|
-
# to trace arguments in order to map them.
|
301
|
-
if isinstance(in_axes, tuple):
|
302
|
-
for mapped_axis, arg_name in zip(in_axes, sig.parameters.keys()):
|
303
|
-
if mapped_axis is not None and arg_name in static_argnames:
|
304
|
-
raise ValueError(
|
305
|
-
f"Static argument '{arg_name}' cannot be mapped with vmap"
|
306
|
-
)
|
307
|
-
|
308
|
-
def fn_tf_vmap(*args, function_to_vmap: Callable, **kwargs):
|
309
|
-
"""Wrapper applying the vmap transformation"""
|
310
|
-
|
311
|
-
# Canonicalize the arguments so that all of them are kwargs
|
312
|
-
bound = sig.bind(*args, **kwargs)
|
313
|
-
bound.apply_defaults()
|
314
|
-
|
315
|
-
# Build a dictionary mapping all arguments to a mapped axis, even when
|
316
|
-
# the None is passed (defaults to in_axes=0) or and int is passed (defaults
|
317
|
-
# to in_axes=<int>).
|
318
|
-
match in_axes:
|
319
|
-
case None:
|
320
|
-
argname_to_mapped_axis = {name: 0 for name in bound.arguments}
|
321
|
-
case tuple():
|
322
|
-
argname_to_mapped_axis = {
|
323
|
-
name: in_axes[i] for i, name in enumerate(bound.arguments)
|
324
|
-
}
|
325
|
-
case int():
|
326
|
-
argname_to_mapped_axis = {name: in_axes for name in bound.arguments}
|
327
|
-
case _:
|
328
|
-
raise ValueError(in_axes)
|
329
|
-
|
330
|
-
# Build a dictionary (argument_name -> argument) for all mapped arguments.
|
331
|
-
# Note that a mapped argument is an argument whose axis is not None and
|
332
|
-
# is not a static jit argument.
|
333
|
-
vmap_mapped_args = {
|
334
|
-
arg: value
|
335
|
-
for arg, value in bound.arguments.items()
|
336
|
-
if argname_to_mapped_axis[arg] is not None
|
337
|
-
and arg not in static_argnames
|
338
|
-
}
|
339
|
-
|
340
|
-
# Build a dictionary (argument_name -> argument) for all unmapped arguments
|
341
|
-
vmap_unmapped_args = {
|
342
|
-
arg: value
|
343
|
-
for arg, value in bound.arguments.items()
|
344
|
-
if arg not in vmap_mapped_args
|
345
|
-
}
|
346
|
-
|
347
|
-
# Disable mapping of non-vectorized default arguments
|
348
|
-
for arg, value in argname_to_mapped_axis.items():
|
349
|
-
if arg in vmap_mapped_args and value == sig.parameters[arg].default:
|
350
|
-
logging.debug(f"Disabling vmapping of default argument '{arg}'")
|
351
|
-
argname_to_mapped_axis[arg] = None
|
352
|
-
|
353
|
-
# Close the function over the unmapped arguments of vmap
|
354
|
-
fn_closed = lambda *mapped_args: function_to_vmap(
|
355
|
-
**vmap_unmapped_args, **dict(zip(vmap_mapped_args.keys(), mapped_args))
|
356
|
-
)
|
357
|
-
|
358
|
-
# Create the in_axes tuple of only the mapped arguments
|
359
|
-
in_axes_mapped = tuple(
|
360
|
-
argname_to_mapped_axis[name] for name in vmap_mapped_args
|
361
|
-
)
|
362
|
-
|
363
|
-
# If all in_axes are the same, simplify in_axes tuple to be just an integer
|
364
|
-
if len(set(in_axes_mapped)) == 1:
|
365
|
-
in_axes_mapped = list(set(in_axes_mapped))[0]
|
366
|
-
|
367
|
-
# If, instead, in_axes has different elements, we need to replace the mapped
|
368
|
-
# axis of "self" with a pytree having as leafs the mapped axis.
|
369
|
-
# This is because the vmap in_axes specification must be a tree prefix of
|
370
|
-
# the corresponding value.
|
371
|
-
if isinstance(in_axes_mapped, tuple) and "self" in vmap_mapped_args:
|
372
|
-
argname_to_mapped_axis["self"] = jax.tree_util.tree_map(
|
373
|
-
lambda _: argname_to_mapped_axis["self"], vmap_mapped_args["self"]
|
374
|
-
)
|
375
|
-
in_axes_mapped = tuple(
|
376
|
-
argname_to_mapped_axis[name] for name in vmap_mapped_args
|
377
|
-
)
|
378
|
-
|
379
|
-
# Apply the vmap transformation and call the function passing only the
|
380
|
-
# mapped arguments. The unmapped arguments have been closed over.
|
381
|
-
# Note: we altered the "in_axes" tuple so that it does not have any
|
382
|
-
# None elements.
|
383
|
-
# Note: if "in_axes_mapped" is a tuple, the following fails if we pass kwargs,
|
384
|
-
# we need to pass the unpacked args tuple instead.
|
385
|
-
return jax.vmap(
|
386
|
-
fn_closed,
|
387
|
-
in_axes=in_axes_mapped,
|
388
|
-
**dict(out_axes=out_axes) if out_axes is not None else {},
|
389
|
-
)(*list(vmap_mapped_args.values()))
|
390
|
-
|
391
|
-
def fn_tf_jit(*args, function_to_jit: Callable, **kwargs):
|
392
|
-
"""Wrapper applying the jit transformation"""
|
393
|
-
|
394
|
-
# Canonicalize the arguments so that all of them are kwargs
|
395
|
-
bound = sig.bind(*args, **kwargs)
|
396
|
-
bound.apply_defaults()
|
397
|
-
|
398
|
-
# Apply the jit transformation and call the function passing all arguments
|
399
|
-
# as keyword arguments
|
400
|
-
return jax.jit(function_to_jit, static_argnames=static_argnames)(
|
401
|
-
**bound.arguments
|
402
|
-
)
|
403
|
-
|
404
|
-
# First applied wrapper that executes fn in a mutable context
|
405
|
-
fn_mutable = functools.partial(
|
406
|
-
jax_tf.call_class_method_in_mutable_context,
|
407
|
-
fn=fn,
|
408
|
-
jit=jit,
|
409
|
-
mutability=mutability,
|
410
|
-
)
|
411
|
-
|
412
|
-
# Second applied wrapper that transforms fn with vmap
|
413
|
-
fn_vmap = (
|
414
|
-
fn_mutable
|
415
|
-
if not vmap
|
416
|
-
else functools.partial(fn_tf_vmap, function_to_vmap=fn_mutable)
|
417
|
-
)
|
418
|
-
|
419
|
-
# Third applied wrapper that transforms fn with jit
|
420
|
-
fn_jit_vmap = (
|
421
|
-
fn_vmap
|
422
|
-
if not jit
|
423
|
-
else functools.partial(fn_tf_jit, function_to_jit=fn_vmap)
|
424
|
-
)
|
425
|
-
|
426
|
-
return fn_jit_vmap
|
427
|
-
|
428
|
-
@staticmethod
|
429
|
-
def call_class_method_in_mutable_context(
|
430
|
-
*args, fn: Callable, jit: bool, mutability: Mutability, **kwargs
|
431
|
-
) -> tuple[Any, Vmappable]:
|
432
|
-
"""
|
433
|
-
Wrapper to call a method on an object with the desired mutable context.
|
434
|
-
|
435
|
-
Args:
|
436
|
-
fn: The method to call.
|
437
|
-
jit: Whether the method is being jit compiled or not.
|
438
|
-
mutability: The desired mutability context.
|
439
|
-
*args: The positional arguments to pass to the method (including self).
|
440
|
-
**kwargs: The keyword arguments to pass to the method.
|
441
|
-
|
442
|
-
Returns:
|
443
|
-
A tuple containing the return value of the method and the object
|
444
|
-
possibly updated by the method if it is in read-write.
|
445
|
-
|
446
|
-
Note:
|
447
|
-
This approach enables to jit-compile methods of a stateful object without
|
448
|
-
leaking traces, therefore obtaining a jax-compatible OOP pattern.
|
449
|
-
"""
|
450
|
-
|
451
|
-
# Log here whether the method is being jit compiled or not.
|
452
|
-
# This log message does not get printed from compiled code, so here is the
|
453
|
-
# most appropriate place to be sure that we log it correctly.
|
454
|
-
if jit:
|
455
|
-
logging.debug(msg=f"JIT compiling {fn}")
|
456
|
-
|
457
|
-
# Canonicalize the arguments so that all of them are kwargs
|
458
|
-
sig = inspect.signature(fn)
|
459
|
-
bound = sig.bind(*args, **kwargs)
|
460
|
-
bound.apply_defaults()
|
461
|
-
|
462
|
-
# Extract the class instance over which fn is called
|
463
|
-
instance: Vmappable = bound.arguments["self"]
|
464
|
-
|
465
|
-
# Select the right mutability. If the instance is mutable with validation
|
466
|
-
# disabled, we override the input mutability so that we do not fail in case
|
467
|
-
# of mismatched tree structure.
|
468
|
-
mut = (
|
469
|
-
Mutability.MUTABLE_NO_VALIDATION
|
470
|
-
if instance._mutability() is Mutability.MUTABLE_NO_VALIDATION
|
471
|
-
else mutability
|
472
|
-
)
|
473
|
-
|
474
|
-
# Call fn in a mutable context
|
475
|
-
with instance.mutable_context(mutability=mut):
|
476
|
-
# Methods could call other decorated methods. When it happens, the decorator
|
477
|
-
# of the called method is invoked, that applies jit and vmap transformations.
|
478
|
-
# This is not desired as it calls vmap inside an already vmapped method.
|
479
|
-
# We work around this occurrence by disabling the jit/vmap decorators of all
|
480
|
-
# methods called inside fn through a context manager.
|
481
|
-
# Note that we already work around this in the beginning of the wrapper
|
482
|
-
# function by detecting traced arguments, but the decorator works also
|
483
|
-
# when jit=False and vmap=False, therefore only enforcing the mutability.
|
484
|
-
with jax_tf.disabled_oop_decorators():
|
485
|
-
out = fn(**bound.arguments)
|
486
|
-
|
487
|
-
return out, instance
|
488
|
-
|
489
|
-
@staticmethod
|
490
|
-
def env_var_on(var_name: str, default_value: str = "1") -> bool:
|
491
|
-
"""
|
492
|
-
Check whether an environment variable is set to a value that is considered on.
|
493
|
-
|
494
|
-
Args:
|
495
|
-
var_name: The name of the environment variable.
|
496
|
-
default_value: The default variable value to consider if the variable has not
|
497
|
-
been exported.
|
498
|
-
|
499
|
-
Returns:
|
500
|
-
True if the environment variable contains an on value, False otherwise.
|
501
|
-
"""
|
502
|
-
|
503
|
-
on_values = {"1", "true", "on", "yes"}
|
504
|
-
return os.environ.get(var_name, default_value).lower() in on_values
|
505
|
-
|
506
|
-
@staticmethod
|
507
|
-
@contextlib.contextmanager
|
508
|
-
def disabled_oop_decorators() -> Generator[None, None, None]:
|
509
|
-
"""
|
510
|
-
Context manager to disable the application of jax transformations performed by
|
511
|
-
the decorators of this class.
|
512
|
-
|
513
|
-
Note: when the transformations are disabled, the only logic still applied is
|
514
|
-
the selection of the object mutability over which the method is running.
|
515
|
-
"""
|
516
|
-
|
517
|
-
# Check whether the environment variable is part of the environment and
|
518
|
-
# save its value. We restore the original value before exiting the context.
|
519
|
-
env_cache = (
|
520
|
-
None if jax_tf.EnvVarOOP not in os.environ else os.environ[jax_tf.EnvVarOOP]
|
521
|
-
)
|
522
|
-
|
523
|
-
# Disable both jit and vmap transformations
|
524
|
-
os.environ[jax_tf.EnvVarOOP] = "0"
|
525
|
-
|
526
|
-
try:
|
527
|
-
# Execute the code in the context with disabled transformations
|
528
|
-
yield
|
529
|
-
|
530
|
-
finally:
|
531
|
-
# Restore the original value of the environment variable or remove it if
|
532
|
-
# it was not present before entering the context
|
533
|
-
if env_cache is not None:
|
534
|
-
os.environ[jax_tf.EnvVarOOP] = env_cache
|
535
|
-
else:
|
536
|
-
_ = os.environ.pop(jax_tf.EnvVarOOP)
|