jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -129
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +87 -16
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +62 -24
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +607 -225
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -80
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev188.dist-info/METADATA +0 -184
- jaxsim-0.2.dev188.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/utils/oop.py
DELETED
@@ -1,536 +0,0 @@
|
|
1
|
-
import contextlib
|
2
|
-
import dataclasses
|
3
|
-
import functools
|
4
|
-
import inspect
|
5
|
-
import os
|
6
|
-
from typing import Any, Callable, Generator, TypeVar
|
7
|
-
|
8
|
-
import jax
|
9
|
-
import jax.flatten_util
|
10
|
-
from typing_extensions import ParamSpec
|
11
|
-
|
12
|
-
from jaxsim import logging
|
13
|
-
from jaxsim.utils import tracing
|
14
|
-
|
15
|
-
from . import Mutability, Vmappable
|
16
|
-
|
17
|
-
_P = ParamSpec("_P")
|
18
|
-
_R = TypeVar("_R")
|
19
|
-
|
20
|
-
|
21
|
-
class jax_tf:
|
22
|
-
"""
|
23
|
-
Class containing decorators applicable to methods of Vmappable objects.
|
24
|
-
"""
|
25
|
-
|
26
|
-
# Environment variables that can be used to disable the transformations
|
27
|
-
EnvVarOOP: str = "JAXSIM_OOP_DECORATORS"
|
28
|
-
EnvVarJitOOP: str = "JAXSIM_OOP_DECORATORS_JIT"
|
29
|
-
EnvVarVmapOOP: str = "JAXSIM_OOP_DECORATORS_VMAP"
|
30
|
-
EnvVarCacheOOP: str = "JAXSIM_OOP_DECORATORS_CACHE"
|
31
|
-
|
32
|
-
@staticmethod
|
33
|
-
def method_ro(
|
34
|
-
fn: Callable[_P, _R],
|
35
|
-
jit: bool = True,
|
36
|
-
static_argnames: tuple[str, ...] | list[str] = (),
|
37
|
-
vmap: bool | None = None,
|
38
|
-
vmap_in_axes: tuple[int, ...] | int | None = None,
|
39
|
-
vmap_out_axes: tuple[int, ...] | int | None = None,
|
40
|
-
) -> Callable[_P, _R]:
|
41
|
-
"""
|
42
|
-
Decorator for r/o methods of classes inheriting from Vmappable.
|
43
|
-
"""
|
44
|
-
|
45
|
-
return jax_tf.method(
|
46
|
-
fn=fn,
|
47
|
-
read_only=True,
|
48
|
-
validate=True,
|
49
|
-
jit_enabled=jit,
|
50
|
-
static_argnames=static_argnames,
|
51
|
-
vmap_enabled=vmap,
|
52
|
-
vmap_in_axes=vmap_in_axes,
|
53
|
-
vmap_out_axes=vmap_out_axes,
|
54
|
-
)
|
55
|
-
|
56
|
-
@staticmethod
|
57
|
-
def method_rw(
|
58
|
-
fn: Callable[_P, _R],
|
59
|
-
validate: bool = True,
|
60
|
-
jit: bool = True,
|
61
|
-
static_argnames: tuple[str, ...] | list[str] = (),
|
62
|
-
vmap: bool | None = None,
|
63
|
-
vmap_in_axes: tuple[int, ...] | int | None = None,
|
64
|
-
vmap_out_axes: tuple[int, ...] | int | None = None,
|
65
|
-
) -> Callable[_P, _R]:
|
66
|
-
"""
|
67
|
-
Decorator for r/w methods of classes inheriting from Vmappable.
|
68
|
-
"""
|
69
|
-
|
70
|
-
return jax_tf.method(
|
71
|
-
fn=fn,
|
72
|
-
read_only=False,
|
73
|
-
validate=validate,
|
74
|
-
jit_enabled=jit,
|
75
|
-
static_argnames=static_argnames,
|
76
|
-
vmap_enabled=vmap,
|
77
|
-
vmap_in_axes=vmap_in_axes,
|
78
|
-
vmap_out_axes=vmap_out_axes,
|
79
|
-
)
|
80
|
-
|
81
|
-
@staticmethod
|
82
|
-
def method(
|
83
|
-
fn: Callable[_P, _R],
|
84
|
-
read_only: bool = True,
|
85
|
-
validate: bool = True,
|
86
|
-
jit_enabled: bool = True,
|
87
|
-
static_argnames: tuple[str, ...] | list[str] = (),
|
88
|
-
vmap_enabled: bool | None = None,
|
89
|
-
vmap_in_axes: tuple[int, ...] | int | None = None,
|
90
|
-
vmap_out_axes: tuple[int, ...] | int | None = None,
|
91
|
-
):
|
92
|
-
"""
|
93
|
-
Decorator for methods of classes inheriting from Vmappable.
|
94
|
-
|
95
|
-
This decorator enables executing the methods on an object characterized by a
|
96
|
-
desired mutability, that is selected considering the r/o and validation flags.
|
97
|
-
It also allows to transform the method with the jit/vmap transformations.
|
98
|
-
If the Vmappable object is vectorized, the method is automatically vmapped, and
|
99
|
-
the in_axes are properly post-processed to simplify the combination with jit.
|
100
|
-
|
101
|
-
Args:
|
102
|
-
fn: The method to decorate.
|
103
|
-
read_only: Whether the method operates on a read-only object.
|
104
|
-
validate: Whether r/w methods should preserve the pytree structure.
|
105
|
-
jit_enabled: Whether to apply the jit transformation.
|
106
|
-
static_argnames: The names of the arguments that should be static.
|
107
|
-
vmap_enabled: Whether to apply the vmap transformation.
|
108
|
-
vmap_in_axes: The in_axes to use for the vmap transformation.
|
109
|
-
vmap_out_axes: The out_axes to use for the vmap transformation.
|
110
|
-
|
111
|
-
Returns:
|
112
|
-
The decorated method.
|
113
|
-
"""
|
114
|
-
|
115
|
-
@functools.wraps(fn)
|
116
|
-
def wrapper(*args: _P.args, **kwargs: _P.kwargs):
|
117
|
-
"""The wrapper function that is returned by this decorator."""
|
118
|
-
|
119
|
-
# Methods of classes inheriting from Vmappable decorated by this wrapper
|
120
|
-
# automatically support jit/vmap/mutability features when called standalone.
|
121
|
-
# However, when objects are arguments of plain functions transformed with
|
122
|
-
# jit/vmap, and decorated methods are called inside those functions, we need
|
123
|
-
# to disable this decorator to avoid double wrapping and execution errors.
|
124
|
-
# We do so by iterating over the arguments, and checking whether they are
|
125
|
-
# being traced by JAX.
|
126
|
-
for argument in list(args) + list(kwargs.values()):
|
127
|
-
try:
|
128
|
-
argument_flat, _ = jax.flatten_util.ravel_pytree(argument)
|
129
|
-
|
130
|
-
if tracing(argument_flat):
|
131
|
-
return fn(*args, **kwargs)
|
132
|
-
except:
|
133
|
-
continue
|
134
|
-
|
135
|
-
# ===============================================================
|
136
|
-
# Wrap fn so that jit/vmap/mutability transformations are applied
|
137
|
-
# ===============================================================
|
138
|
-
|
139
|
-
# Initialize the mutability of the instance over which the method is running.
|
140
|
-
# * In r/o methods, this approach prevents any type of mutation.
|
141
|
-
# * In r/w methods, this approach allows to catch early JIT recompilations
|
142
|
-
# caused by unwanted changes in the pytree structure.
|
143
|
-
if read_only:
|
144
|
-
mutability = Mutability.FROZEN
|
145
|
-
else:
|
146
|
-
mutability = (
|
147
|
-
Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
|
148
|
-
)
|
149
|
-
|
150
|
-
# Extract the class instance over which fn is called
|
151
|
-
instance: Vmappable = args[0]
|
152
|
-
assert isinstance(instance, Vmappable)
|
153
|
-
|
154
|
-
# Save the original mutability
|
155
|
-
original_mutability = instance._mutability()
|
156
|
-
|
157
|
-
# Inspect the environment to detect whether to enforce disabling jit/vmap
|
158
|
-
deco_on = jax_tf.env_var_on(jax_tf.EnvVarOOP)
|
159
|
-
jit_enabled_env = jax_tf.env_var_on(jax_tf.EnvVarJitOOP) and deco_on
|
160
|
-
vmap_enabled_env = jax_tf.env_var_on(jax_tf.EnvVarVmapOOP) and deco_on
|
161
|
-
|
162
|
-
# Allow disabling the cache of jit-compiled functions.
|
163
|
-
# It can be useful for debugging or testing purposes.
|
164
|
-
wrap_fn = (
|
165
|
-
jax_tf.wrap_fn
|
166
|
-
if jax_tf.env_var_on(jax_tf.EnvVarCacheOOP) and deco_on
|
167
|
-
else jax_tf.wrap_fn.__wrapped__
|
168
|
-
)
|
169
|
-
|
170
|
-
# Get the transformed function (possibly cached by functools.cache).
|
171
|
-
# Note that all the arguments of the following methods, when hashed, should
|
172
|
-
# uniquely identify the returned function so that a new function is built
|
173
|
-
# when arguments change and either jit or vmap have to be called again.
|
174
|
-
fn_db = wrap_fn(
|
175
|
-
fn=fn, # noqa
|
176
|
-
mutability=mutability,
|
177
|
-
jit=jit_enabled_env and jit_enabled,
|
178
|
-
static_argnames=tuple(static_argnames),
|
179
|
-
vmap=vmap_enabled_env
|
180
|
-
and (
|
181
|
-
vmap_enabled is True
|
182
|
-
or (vmap_enabled is None and instance.vectorized)
|
183
|
-
),
|
184
|
-
in_axes=vmap_in_axes,
|
185
|
-
out_axes=vmap_out_axes,
|
186
|
-
)
|
187
|
-
|
188
|
-
# Call the transformed (mutable/jit/vmap) method
|
189
|
-
out, obj = fn_db(*args, **kwargs)
|
190
|
-
|
191
|
-
if read_only:
|
192
|
-
# Restore the original mutability
|
193
|
-
instance._set_mutability(mutability=original_mutability)
|
194
|
-
|
195
|
-
return out
|
196
|
-
|
197
|
-
# =================================================================
|
198
|
-
# From here we assume that the wrapper is operating on a r/w method
|
199
|
-
# =================================================================
|
200
|
-
|
201
|
-
from jax_dataclasses._dataclasses import JDC_STATIC_MARKER
|
202
|
-
|
203
|
-
# Select the right runtime mutability. The only difference here is when a r/w
|
204
|
-
# method is called on a frozen object. In this case, we enable updating the
|
205
|
-
# pytree data and preserve its structure only if validation is enabled.
|
206
|
-
mutability_dict = {
|
207
|
-
Mutability.MUTABLE_NO_VALIDATION: Mutability.MUTABLE_NO_VALIDATION,
|
208
|
-
Mutability.MUTABLE: Mutability.MUTABLE,
|
209
|
-
Mutability.FROZEN: (
|
210
|
-
Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
|
211
|
-
),
|
212
|
-
}
|
213
|
-
|
214
|
-
# We need to replace all the dynamic leafs of the original instance with those
|
215
|
-
# computed by the functional transformation.
|
216
|
-
# We do so by iterating over the fields of the jax_dataclasses and ignoring
|
217
|
-
# all the fields that are marked as static.
|
218
|
-
# Caveats: https://github.com/ami-iit/jaxsim/pull/48#issuecomment-1746635121.
|
219
|
-
with instance.mutable_context(
|
220
|
-
mutability=mutability_dict[instance._mutability()]
|
221
|
-
):
|
222
|
-
for f in dataclasses.fields(instance): # noqa
|
223
|
-
if (
|
224
|
-
hasattr(f, "type")
|
225
|
-
and hasattr(f.type, "__metadata__")
|
226
|
-
and JDC_STATIC_MARKER in f.type.__metadata__
|
227
|
-
):
|
228
|
-
continue
|
229
|
-
|
230
|
-
try:
|
231
|
-
setattr(instance, f.name, getattr(obj, f.name))
|
232
|
-
except AssertionError as exc:
|
233
|
-
logging.debug(f"Old object:\n{getattr(instance, f.name)}")
|
234
|
-
logging.debug(f"New object:\n{getattr(obj, f.name)}")
|
235
|
-
raise RuntimeError(
|
236
|
-
f"Failed to update field '{f.name}'"
|
237
|
-
) from exc
|
238
|
-
|
239
|
-
return out
|
240
|
-
|
241
|
-
return wrapper
|
242
|
-
|
243
|
-
@staticmethod
|
244
|
-
@functools.cache
|
245
|
-
def wrap_fn(
|
246
|
-
fn: Callable,
|
247
|
-
mutability: Mutability,
|
248
|
-
jit: bool,
|
249
|
-
static_argnames: tuple[str, ...] | list[str],
|
250
|
-
vmap: bool,
|
251
|
-
in_axes: tuple[int, ...] | int | None,
|
252
|
-
out_axes: tuple[int, ...] | int | None,
|
253
|
-
) -> Callable:
|
254
|
-
"""
|
255
|
-
Transform a method with jit/vmap and execute it on an object characterized
|
256
|
-
by the desired mutability.
|
257
|
-
|
258
|
-
Note:
|
259
|
-
The method should take the object (self) as first argument.
|
260
|
-
|
261
|
-
Note:
|
262
|
-
This returned transformed method is cached by considering the hash of all
|
263
|
-
the arguments. It will re-apply jit/vmap transformations only if needed.
|
264
|
-
|
265
|
-
Args:
|
266
|
-
fn: The method to consider.
|
267
|
-
mutability: The mutability of the object on which the method is called.
|
268
|
-
jit: Whether to apply jit transformations.
|
269
|
-
static_argnames: The names of the arguments that should be considered static.
|
270
|
-
vmap: Whether to apply vmap transformations.
|
271
|
-
in_axes: The axes along which to vmap input arguments.
|
272
|
-
out_axes: The axes along which to vmap output arguments.
|
273
|
-
|
274
|
-
Note:
|
275
|
-
In order to simplify the application of vmap, we close the method arguments
|
276
|
-
over all the non-mapped input arguments. Furthermore, for improving the
|
277
|
-
compatibility with jit, we also close the vmap application over the static
|
278
|
-
arguments.
|
279
|
-
|
280
|
-
Returns:
|
281
|
-
The transformed method operating on an object with the desired mutability.
|
282
|
-
We maintain the same signature of the original method.
|
283
|
-
"""
|
284
|
-
|
285
|
-
# Extract the signature of the function
|
286
|
-
sig = inspect.signature(fn)
|
287
|
-
|
288
|
-
# All static arguments must be actual arguments of fn
|
289
|
-
for name in static_argnames:
|
290
|
-
if name not in sig.parameters:
|
291
|
-
raise ValueError(f"Static argument '{name}' not found in {fn}")
|
292
|
-
|
293
|
-
# If in_axes is a tuple, its dimension should match the number of arguments
|
294
|
-
if isinstance(in_axes, tuple) and len(in_axes) != len(sig.parameters):
|
295
|
-
msg = "The length of 'in_axes' must match the number of arguments ({})"
|
296
|
-
raise ValueError(msg.format(len(sig.parameters)))
|
297
|
-
|
298
|
-
# Check that static arguments are not mapped with vmap.
|
299
|
-
# This case would not work since static arguments are not traces and vmap need
|
300
|
-
# to trace arguments in order to map them.
|
301
|
-
if isinstance(in_axes, tuple):
|
302
|
-
for mapped_axis, arg_name in zip(in_axes, sig.parameters.keys()):
|
303
|
-
if mapped_axis is not None and arg_name in static_argnames:
|
304
|
-
raise ValueError(
|
305
|
-
f"Static argument '{arg_name}' cannot be mapped with vmap"
|
306
|
-
)
|
307
|
-
|
308
|
-
def fn_tf_vmap(*args, function_to_vmap: Callable, **kwargs):
|
309
|
-
"""Wrapper applying the vmap transformation"""
|
310
|
-
|
311
|
-
# Canonicalize the arguments so that all of them are kwargs
|
312
|
-
bound = sig.bind(*args, **kwargs)
|
313
|
-
bound.apply_defaults()
|
314
|
-
|
315
|
-
# Build a dictionary mapping all arguments to a mapped axis, even when
|
316
|
-
# the None is passed (defaults to in_axes=0) or and int is passed (defaults
|
317
|
-
# to in_axes=<int>).
|
318
|
-
match in_axes:
|
319
|
-
case None:
|
320
|
-
argname_to_mapped_axis = {name: 0 for name in bound.arguments}
|
321
|
-
case tuple():
|
322
|
-
argname_to_mapped_axis = {
|
323
|
-
name: in_axes[i] for i, name in enumerate(bound.arguments)
|
324
|
-
}
|
325
|
-
case int():
|
326
|
-
argname_to_mapped_axis = {name: in_axes for name in bound.arguments}
|
327
|
-
case _:
|
328
|
-
raise ValueError(in_axes)
|
329
|
-
|
330
|
-
# Build a dictionary (argument_name -> argument) for all mapped arguments.
|
331
|
-
# Note that a mapped argument is an argument whose axis is not None and
|
332
|
-
# is not a static jit argument.
|
333
|
-
vmap_mapped_args = {
|
334
|
-
arg: value
|
335
|
-
for arg, value in bound.arguments.items()
|
336
|
-
if argname_to_mapped_axis[arg] is not None
|
337
|
-
and arg not in static_argnames
|
338
|
-
}
|
339
|
-
|
340
|
-
# Build a dictionary (argument_name -> argument) for all unmapped arguments
|
341
|
-
vmap_unmapped_args = {
|
342
|
-
arg: value
|
343
|
-
for arg, value in bound.arguments.items()
|
344
|
-
if arg not in vmap_mapped_args
|
345
|
-
}
|
346
|
-
|
347
|
-
# Disable mapping of non-vectorized default arguments
|
348
|
-
for arg, value in argname_to_mapped_axis.items():
|
349
|
-
if arg in vmap_mapped_args and value == sig.parameters[arg].default:
|
350
|
-
logging.debug(f"Disabling vmapping of default argument '{arg}'")
|
351
|
-
argname_to_mapped_axis[arg] = None
|
352
|
-
|
353
|
-
# Close the function over the unmapped arguments of vmap
|
354
|
-
fn_closed = lambda *mapped_args: function_to_vmap(
|
355
|
-
**vmap_unmapped_args, **dict(zip(vmap_mapped_args.keys(), mapped_args))
|
356
|
-
)
|
357
|
-
|
358
|
-
# Create the in_axes tuple of only the mapped arguments
|
359
|
-
in_axes_mapped = tuple(
|
360
|
-
argname_to_mapped_axis[name] for name in vmap_mapped_args
|
361
|
-
)
|
362
|
-
|
363
|
-
# If all in_axes are the same, simplify in_axes tuple to be just an integer
|
364
|
-
if len(set(in_axes_mapped)) == 1:
|
365
|
-
in_axes_mapped = list(set(in_axes_mapped))[0]
|
366
|
-
|
367
|
-
# If, instead, in_axes has different elements, we need to replace the mapped
|
368
|
-
# axis of "self" with a pytree having as leafs the mapped axis.
|
369
|
-
# This is because the vmap in_axes specification must be a tree prefix of
|
370
|
-
# the corresponding value.
|
371
|
-
if isinstance(in_axes_mapped, tuple) and "self" in vmap_mapped_args:
|
372
|
-
argname_to_mapped_axis["self"] = jax.tree_util.tree_map(
|
373
|
-
lambda _: argname_to_mapped_axis["self"], vmap_mapped_args["self"]
|
374
|
-
)
|
375
|
-
in_axes_mapped = tuple(
|
376
|
-
argname_to_mapped_axis[name] for name in vmap_mapped_args
|
377
|
-
)
|
378
|
-
|
379
|
-
# Apply the vmap transformation and call the function passing only the
|
380
|
-
# mapped arguments. The unmapped arguments have been closed over.
|
381
|
-
# Note: we altered the "in_axes" tuple so that it does not have any
|
382
|
-
# None elements.
|
383
|
-
# Note: if "in_axes_mapped" is a tuple, the following fails if we pass kwargs,
|
384
|
-
# we need to pass the unpacked args tuple instead.
|
385
|
-
return jax.vmap(
|
386
|
-
fn_closed,
|
387
|
-
in_axes=in_axes_mapped,
|
388
|
-
**dict(out_axes=out_axes) if out_axes is not None else {},
|
389
|
-
)(*list(vmap_mapped_args.values()))
|
390
|
-
|
391
|
-
def fn_tf_jit(*args, function_to_jit: Callable, **kwargs):
|
392
|
-
"""Wrapper applying the jit transformation"""
|
393
|
-
|
394
|
-
# Canonicalize the arguments so that all of them are kwargs
|
395
|
-
bound = sig.bind(*args, **kwargs)
|
396
|
-
bound.apply_defaults()
|
397
|
-
|
398
|
-
# Apply the jit transformation and call the function passing all arguments
|
399
|
-
# as keyword arguments
|
400
|
-
return jax.jit(function_to_jit, static_argnames=static_argnames)(
|
401
|
-
**bound.arguments
|
402
|
-
)
|
403
|
-
|
404
|
-
# First applied wrapper that executes fn in a mutable context
|
405
|
-
fn_mutable = functools.partial(
|
406
|
-
jax_tf.call_class_method_in_mutable_context,
|
407
|
-
fn=fn,
|
408
|
-
jit=jit,
|
409
|
-
mutability=mutability,
|
410
|
-
)
|
411
|
-
|
412
|
-
# Second applied wrapper that transforms fn with vmap
|
413
|
-
fn_vmap = (
|
414
|
-
fn_mutable
|
415
|
-
if not vmap
|
416
|
-
else functools.partial(fn_tf_vmap, function_to_vmap=fn_mutable)
|
417
|
-
)
|
418
|
-
|
419
|
-
# Third applied wrapper that transforms fn with jit
|
420
|
-
fn_jit_vmap = (
|
421
|
-
fn_vmap
|
422
|
-
if not jit
|
423
|
-
else functools.partial(fn_tf_jit, function_to_jit=fn_vmap)
|
424
|
-
)
|
425
|
-
|
426
|
-
return fn_jit_vmap
|
427
|
-
|
428
|
-
@staticmethod
|
429
|
-
def call_class_method_in_mutable_context(
|
430
|
-
*args, fn: Callable, jit: bool, mutability: Mutability, **kwargs
|
431
|
-
) -> tuple[Any, Vmappable]:
|
432
|
-
"""
|
433
|
-
Wrapper to call a method on an object with the desired mutable context.
|
434
|
-
|
435
|
-
Args:
|
436
|
-
fn: The method to call.
|
437
|
-
jit: Whether the method is being jit compiled or not.
|
438
|
-
mutability: The desired mutability context.
|
439
|
-
*args: The positional arguments to pass to the method (including self).
|
440
|
-
**kwargs: The keyword arguments to pass to the method.
|
441
|
-
|
442
|
-
Returns:
|
443
|
-
A tuple containing the return value of the method and the object
|
444
|
-
possibly updated by the method if it is in read-write.
|
445
|
-
|
446
|
-
Note:
|
447
|
-
This approach enables to jit-compile methods of a stateful object without
|
448
|
-
leaking traces, therefore obtaining a jax-compatible OOP pattern.
|
449
|
-
"""
|
450
|
-
|
451
|
-
# Log here whether the method is being jit compiled or not.
|
452
|
-
# This log message does not get printed from compiled code, so here is the
|
453
|
-
# most appropriate place to be sure that we log it correctly.
|
454
|
-
if jit:
|
455
|
-
logging.debug(msg=f"JIT compiling {fn}")
|
456
|
-
|
457
|
-
# Canonicalize the arguments so that all of them are kwargs
|
458
|
-
sig = inspect.signature(fn)
|
459
|
-
bound = sig.bind(*args, **kwargs)
|
460
|
-
bound.apply_defaults()
|
461
|
-
|
462
|
-
# Extract the class instance over which fn is called
|
463
|
-
instance: Vmappable = bound.arguments["self"]
|
464
|
-
|
465
|
-
# Select the right mutability. If the instance is mutable with validation
|
466
|
-
# disabled, we override the input mutability so that we do not fail in case
|
467
|
-
# of mismatched tree structure.
|
468
|
-
mut = (
|
469
|
-
Mutability.MUTABLE_NO_VALIDATION
|
470
|
-
if instance._mutability() is Mutability.MUTABLE_NO_VALIDATION
|
471
|
-
else mutability
|
472
|
-
)
|
473
|
-
|
474
|
-
# Call fn in a mutable context
|
475
|
-
with instance.mutable_context(mutability=mut):
|
476
|
-
# Methods could call other decorated methods. When it happens, the decorator
|
477
|
-
# of the called method is invoked, that applies jit and vmap transformations.
|
478
|
-
# This is not desired as it calls vmap inside an already vmapped method.
|
479
|
-
# We work around this occurrence by disabling the jit/vmap decorators of all
|
480
|
-
# methods called inside fn through a context manager.
|
481
|
-
# Note that we already work around this in the beginning of the wrapper
|
482
|
-
# function by detecting traced arguments, but the decorator works also
|
483
|
-
# when jit=False and vmap=False, therefore only enforcing the mutability.
|
484
|
-
with jax_tf.disabled_oop_decorators():
|
485
|
-
out = fn(**bound.arguments)
|
486
|
-
|
487
|
-
return out, instance
|
488
|
-
|
489
|
-
@staticmethod
|
490
|
-
def env_var_on(var_name: str, default_value: str = "1") -> bool:
|
491
|
-
"""
|
492
|
-
Check whether an environment variable is set to a value that is considered on.
|
493
|
-
|
494
|
-
Args:
|
495
|
-
var_name: The name of the environment variable.
|
496
|
-
default_value: The default variable value to consider if the variable has not
|
497
|
-
been exported.
|
498
|
-
|
499
|
-
Returns:
|
500
|
-
True if the environment variable contains an on value, False otherwise.
|
501
|
-
"""
|
502
|
-
|
503
|
-
on_values = {"1", "true", "on", "yes"}
|
504
|
-
return os.environ.get(var_name, default_value).lower() in on_values
|
505
|
-
|
506
|
-
@staticmethod
|
507
|
-
@contextlib.contextmanager
|
508
|
-
def disabled_oop_decorators() -> Generator[None, None, None]:
|
509
|
-
"""
|
510
|
-
Context manager to disable the application of jax transformations performed by
|
511
|
-
the decorators of this class.
|
512
|
-
|
513
|
-
Note: when the transformations are disabled, the only logic still applied is
|
514
|
-
the selection of the object mutability over which the method is running.
|
515
|
-
"""
|
516
|
-
|
517
|
-
# Check whether the environment variable is part of the environment and
|
518
|
-
# save its value. We restore the original value before exiting the context.
|
519
|
-
env_cache = (
|
520
|
-
None if jax_tf.EnvVarOOP not in os.environ else os.environ[jax_tf.EnvVarOOP]
|
521
|
-
)
|
522
|
-
|
523
|
-
# Disable both jit and vmap transformations
|
524
|
-
os.environ[jax_tf.EnvVarOOP] = "0"
|
525
|
-
|
526
|
-
try:
|
527
|
-
# Execute the code in the context with disabled transformations
|
528
|
-
yield
|
529
|
-
|
530
|
-
finally:
|
531
|
-
# Restore the original value of the environment variable or remove it if
|
532
|
-
# it was not present before entering the context
|
533
|
-
if env_cache is not None:
|
534
|
-
os.environ[jax_tf.EnvVarOOP] = env_cache
|
535
|
-
else:
|
536
|
-
_ = os.environ.pop(jax_tf.EnvVarOOP)
|
jaxsim/utils/vmappable.py
DELETED
@@ -1,117 +0,0 @@
|
|
1
|
-
import dataclasses
|
2
|
-
from typing import Type
|
3
|
-
|
4
|
-
import jax
|
5
|
-
import jax.numpy as jnp
|
6
|
-
import jax_dataclasses
|
7
|
-
|
8
|
-
from . import JaxsimDataclass, Mutability
|
9
|
-
|
10
|
-
try:
|
11
|
-
from typing import Self
|
12
|
-
except ImportError:
|
13
|
-
from typing_extensions import Self
|
14
|
-
|
15
|
-
|
16
|
-
@jax_dataclasses.pytree_dataclass
class Vmappable(JaxsimDataclass):
    """Abstract class with utilities for vmappable pytrees."""

    # Number of elements stacked along the leading axis of every leaf; 0 means the
    # pytree is not vectorized. Declared Static, so it is not a pytree leaf.
    batch_size: jax_dataclasses.Static[int] = dataclasses.field(
        default=int(0), repr=False, compare=False, hash=False, kw_only=True
    )

    @property
    def vectorized(self) -> bool:
        """Marks this pytree as vectorized (i.e. its batch size is positive)."""

        return self.batch_size > 0

    @classmethod
    def build_from_list(cls: Type[Self], list_of_obj: list[Self]) -> Self:
        """
        Build a vectorized pytree from a list of pytree of the same type.

        Args:
            list_of_obj: The list of pytrees to vectorize.

        Returns:
            The vectorized pytree having as leaves the stacked leaves of the input list.
            Its mutability is the most common one among the inputs (on ties, the one
            appearing first in the list).

        Raises:
            ValueError: If the list is empty or contains objects of a different type.
        """

        if {type(el) for el in list_of_obj} != {cls}:
            msg = "The input list must contain only objects of type '{}'"
            raise ValueError(msg.format(cls.__name__))

        # Create a pytree by stacking all the leafs of the input list.
        # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
        data_vec: Vmappable = jax.tree_util.tree_map(
            lambda *leafs: jnp.array(leafs), *list_of_obj
        )

        # Store the batch dimension. batch_size is static, so changing it alters the
        # pytree structure and validation must be disabled.
        with data_vec.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
            data_vec.batch_size = len(list_of_obj)

        # Detect the most common mutability in the input list. Candidates are visited
        # in first-occurrence order, so ties are resolved deterministically (iterating
        # a set would depend on the hash-based ordering of its elements).
        mutabilities = [e._mutability() for e in list_of_obj]
        mutability = max(dict.fromkeys(mutabilities), key=mutabilities.count)

        # Update the mutability of the vectorized pytree
        data_vec._set_mutability(mutability)

        return data_vec

    def vectorize(self: Self, batch_size: int) -> Self:
        """
        Return a vectorized version of this pytree.

        Args:
            batch_size: The batch size. Zero returns a non-vectorized copy.

        Returns:
            A vectorized version of this pytree obtained by stacking the leaves of the
            original pytree along a new batch dimension (the first one).

        Raises:
            RuntimeError: If the pytree is already vectorized.
            ValueError: If the batch size is negative.
        """

        if self.vectorized:
            raise RuntimeError("Cannot vectorize an already vectorized object")

        if batch_size < 0:
            raise ValueError("The batch size cannot be negative")

        if batch_size == 0:
            return self.copy()

        # TODO validate if mutability is maintained

        return self.__class__.build_from_list(list_of_obj=[self] * batch_size)

    def extract_element(self: Self, index: int) -> Self:
        """
        Extract the i-th element from a vectorized pytree.

        Args:
            index: The index of the element to extract.

        Returns:
            A non vectorized pytree obtained by extracting the i-th element from the
            vectorized pytree. For a non-vectorized pytree, index 0 returns a copy.

        Raises:
            ValueError: If the index is negative or not smaller than the batch size.
            RuntimeError: If a non-zero index is requested on a non-vectorized pytree.
        """

        if index < 0:
            raise ValueError("The index of the desired element cannot be negative")

        if index == 0 and self.batch_size == 0:
            return self.copy()

        if not self.vectorized:
            raise RuntimeError("Cannot extract elements from a non-vectorized object")

        if index >= self.batch_size:
            raise ValueError("The index must be smaller than the batch size")

        # Get the i-th pytree by extracting the i-th element from the vectorized pytree
        data = jax.tree_util.tree_map(lambda leaf: leaf[index], self)

        # Update the batch size of the extracted scalar pytree. As in build_from_list,
        # changing the static batch_size alters the pytree structure, so validation is
        # disabled.
        with data.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
            data.batch_size = 0

        return data
|