jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -129
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +87 -16
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +62 -24
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +607 -225
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -80
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -55
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev188.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev188.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/utils/oop.py DELETED
@@ -1,536 +0,0 @@
1
- import contextlib
2
- import dataclasses
3
- import functools
4
- import inspect
5
- import os
6
- from typing import Any, Callable, Generator, TypeVar
7
-
8
- import jax
9
- import jax.flatten_util
10
- from typing_extensions import ParamSpec
11
-
12
- from jaxsim import logging
13
- from jaxsim.utils import tracing
14
-
15
- from . import Mutability, Vmappable
16
-
17
# Parameter specification and return type of a decorated callable. They are used
# in the `Callable[_P, _R]` annotations of the decorators defined in this module.
_P = ParamSpec("_P")
_R = TypeVar("_R")
19
-
20
-
21
- class jax_tf:
22
- """
23
- Class containing decorators applicable to methods of Vmappable objects.
24
- """
25
-
26
- # Environment variables that can be used to disable the transformations
27
- EnvVarOOP: str = "JAXSIM_OOP_DECORATORS"
28
- EnvVarJitOOP: str = "JAXSIM_OOP_DECORATORS_JIT"
29
- EnvVarVmapOOP: str = "JAXSIM_OOP_DECORATORS_VMAP"
30
- EnvVarCacheOOP: str = "JAXSIM_OOP_DECORATORS_CACHE"
31
-
32
- @staticmethod
33
- def method_ro(
34
- fn: Callable[_P, _R],
35
- jit: bool = True,
36
- static_argnames: tuple[str, ...] | list[str] = (),
37
- vmap: bool | None = None,
38
- vmap_in_axes: tuple[int, ...] | int | None = None,
39
- vmap_out_axes: tuple[int, ...] | int | None = None,
40
- ) -> Callable[_P, _R]:
41
- """
42
- Decorator for r/o methods of classes inheriting from Vmappable.
43
- """
44
-
45
- return jax_tf.method(
46
- fn=fn,
47
- read_only=True,
48
- validate=True,
49
- jit_enabled=jit,
50
- static_argnames=static_argnames,
51
- vmap_enabled=vmap,
52
- vmap_in_axes=vmap_in_axes,
53
- vmap_out_axes=vmap_out_axes,
54
- )
55
-
56
- @staticmethod
57
- def method_rw(
58
- fn: Callable[_P, _R],
59
- validate: bool = True,
60
- jit: bool = True,
61
- static_argnames: tuple[str, ...] | list[str] = (),
62
- vmap: bool | None = None,
63
- vmap_in_axes: tuple[int, ...] | int | None = None,
64
- vmap_out_axes: tuple[int, ...] | int | None = None,
65
- ) -> Callable[_P, _R]:
66
- """
67
- Decorator for r/w methods of classes inheriting from Vmappable.
68
- """
69
-
70
- return jax_tf.method(
71
- fn=fn,
72
- read_only=False,
73
- validate=validate,
74
- jit_enabled=jit,
75
- static_argnames=static_argnames,
76
- vmap_enabled=vmap,
77
- vmap_in_axes=vmap_in_axes,
78
- vmap_out_axes=vmap_out_axes,
79
- )
80
-
81
- @staticmethod
82
- def method(
83
- fn: Callable[_P, _R],
84
- read_only: bool = True,
85
- validate: bool = True,
86
- jit_enabled: bool = True,
87
- static_argnames: tuple[str, ...] | list[str] = (),
88
- vmap_enabled: bool | None = None,
89
- vmap_in_axes: tuple[int, ...] | int | None = None,
90
- vmap_out_axes: tuple[int, ...] | int | None = None,
91
- ):
92
- """
93
- Decorator for methods of classes inheriting from Vmappable.
94
-
95
- This decorator enables executing the methods on an object characterized by a
96
- desired mutability, that is selected considering the r/o and validation flags.
97
- It also allows to transform the method with the jit/vmap transformations.
98
- If the Vmappable object is vectorized, the method is automatically vmapped, and
99
- the in_axes are properly post-processed to simplify the combination with jit.
100
-
101
- Args:
102
- fn: The method to decorate.
103
- read_only: Whether the method operates on a read-only object.
104
- validate: Whether r/w methods should preserve the pytree structure.
105
- jit_enabled: Whether to apply the jit transformation.
106
- static_argnames: The names of the arguments that should be static.
107
- vmap_enabled: Whether to apply the vmap transformation.
108
- vmap_in_axes: The in_axes to use for the vmap transformation.
109
- vmap_out_axes: The out_axes to use for the vmap transformation.
110
-
111
- Returns:
112
- The decorated method.
113
- """
114
-
115
- @functools.wraps(fn)
116
- def wrapper(*args: _P.args, **kwargs: _P.kwargs):
117
- """The wrapper function that is returned by this decorator."""
118
-
119
- # Methods of classes inheriting from Vmappable decorated by this wrapper
120
- # automatically support jit/vmap/mutability features when called standalone.
121
- # However, when objects are arguments of plain functions transformed with
122
- # jit/vmap, and decorated methods are called inside those functions, we need
123
- # to disable this decorator to avoid double wrapping and execution errors.
124
- # We do so by iterating over the arguments, and checking whether they are
125
- # being traced by JAX.
126
- for argument in list(args) + list(kwargs.values()):
127
- try:
128
- argument_flat, _ = jax.flatten_util.ravel_pytree(argument)
129
-
130
- if tracing(argument_flat):
131
- return fn(*args, **kwargs)
132
- except:
133
- continue
134
-
135
- # ===============================================================
136
- # Wrap fn so that jit/vmap/mutability transformations are applied
137
- # ===============================================================
138
-
139
- # Initialize the mutability of the instance over which the method is running.
140
- # * In r/o methods, this approach prevents any type of mutation.
141
- # * In r/w methods, this approach allows to catch early JIT recompilations
142
- # caused by unwanted changes in the pytree structure.
143
- if read_only:
144
- mutability = Mutability.FROZEN
145
- else:
146
- mutability = (
147
- Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
148
- )
149
-
150
- # Extract the class instance over which fn is called
151
- instance: Vmappable = args[0]
152
- assert isinstance(instance, Vmappable)
153
-
154
- # Save the original mutability
155
- original_mutability = instance._mutability()
156
-
157
- # Inspect the environment to detect whether to enforce disabling jit/vmap
158
- deco_on = jax_tf.env_var_on(jax_tf.EnvVarOOP)
159
- jit_enabled_env = jax_tf.env_var_on(jax_tf.EnvVarJitOOP) and deco_on
160
- vmap_enabled_env = jax_tf.env_var_on(jax_tf.EnvVarVmapOOP) and deco_on
161
-
162
- # Allow disabling the cache of jit-compiled functions.
163
- # It can be useful for debugging or testing purposes.
164
- wrap_fn = (
165
- jax_tf.wrap_fn
166
- if jax_tf.env_var_on(jax_tf.EnvVarCacheOOP) and deco_on
167
- else jax_tf.wrap_fn.__wrapped__
168
- )
169
-
170
- # Get the transformed function (possibly cached by functools.cache).
171
- # Note that all the arguments of the following methods, when hashed, should
172
- # uniquely identify the returned function so that a new function is built
173
- # when arguments change and either jit or vmap have to be called again.
174
- fn_db = wrap_fn(
175
- fn=fn, # noqa
176
- mutability=mutability,
177
- jit=jit_enabled_env and jit_enabled,
178
- static_argnames=tuple(static_argnames),
179
- vmap=vmap_enabled_env
180
- and (
181
- vmap_enabled is True
182
- or (vmap_enabled is None and instance.vectorized)
183
- ),
184
- in_axes=vmap_in_axes,
185
- out_axes=vmap_out_axes,
186
- )
187
-
188
- # Call the transformed (mutable/jit/vmap) method
189
- out, obj = fn_db(*args, **kwargs)
190
-
191
- if read_only:
192
- # Restore the original mutability
193
- instance._set_mutability(mutability=original_mutability)
194
-
195
- return out
196
-
197
- # =================================================================
198
- # From here we assume that the wrapper is operating on a r/w method
199
- # =================================================================
200
-
201
- from jax_dataclasses._dataclasses import JDC_STATIC_MARKER
202
-
203
- # Select the right runtime mutability. The only difference here is when a r/w
204
- # method is called on a frozen object. In this case, we enable updating the
205
- # pytree data and preserve its structure only if validation is enabled.
206
- mutability_dict = {
207
- Mutability.MUTABLE_NO_VALIDATION: Mutability.MUTABLE_NO_VALIDATION,
208
- Mutability.MUTABLE: Mutability.MUTABLE,
209
- Mutability.FROZEN: (
210
- Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
211
- ),
212
- }
213
-
214
- # We need to replace all the dynamic leafs of the original instance with those
215
- # computed by the functional transformation.
216
- # We do so by iterating over the fields of the jax_dataclasses and ignoring
217
- # all the fields that are marked as static.
218
- # Caveats: https://github.com/ami-iit/jaxsim/pull/48#issuecomment-1746635121.
219
- with instance.mutable_context(
220
- mutability=mutability_dict[instance._mutability()]
221
- ):
222
- for f in dataclasses.fields(instance): # noqa
223
- if (
224
- hasattr(f, "type")
225
- and hasattr(f.type, "__metadata__")
226
- and JDC_STATIC_MARKER in f.type.__metadata__
227
- ):
228
- continue
229
-
230
- try:
231
- setattr(instance, f.name, getattr(obj, f.name))
232
- except AssertionError as exc:
233
- logging.debug(f"Old object:\n{getattr(instance, f.name)}")
234
- logging.debug(f"New object:\n{getattr(obj, f.name)}")
235
- raise RuntimeError(
236
- f"Failed to update field '{f.name}'"
237
- ) from exc
238
-
239
- return out
240
-
241
- return wrapper
242
-
243
- @staticmethod
244
- @functools.cache
245
- def wrap_fn(
246
- fn: Callable,
247
- mutability: Mutability,
248
- jit: bool,
249
- static_argnames: tuple[str, ...] | list[str],
250
- vmap: bool,
251
- in_axes: tuple[int, ...] | int | None,
252
- out_axes: tuple[int, ...] | int | None,
253
- ) -> Callable:
254
- """
255
- Transform a method with jit/vmap and execute it on an object characterized
256
- by the desired mutability.
257
-
258
- Note:
259
- The method should take the object (self) as first argument.
260
-
261
- Note:
262
- This returned transformed method is cached by considering the hash of all
263
- the arguments. It will re-apply jit/vmap transformations only if needed.
264
-
265
- Args:
266
- fn: The method to consider.
267
- mutability: The mutability of the object on which the method is called.
268
- jit: Whether to apply jit transformations.
269
- static_argnames: The names of the arguments that should be considered static.
270
- vmap: Whether to apply vmap transformations.
271
- in_axes: The axes along which to vmap input arguments.
272
- out_axes: The axes along which to vmap output arguments.
273
-
274
- Note:
275
- In order to simplify the application of vmap, we close the method arguments
276
- over all the non-mapped input arguments. Furthermore, for improving the
277
- compatibility with jit, we also close the vmap application over the static
278
- arguments.
279
-
280
- Returns:
281
- The transformed method operating on an object with the desired mutability.
282
- We maintain the same signature of the original method.
283
- """
284
-
285
- # Extract the signature of the function
286
- sig = inspect.signature(fn)
287
-
288
- # All static arguments must be actual arguments of fn
289
- for name in static_argnames:
290
- if name not in sig.parameters:
291
- raise ValueError(f"Static argument '{name}' not found in {fn}")
292
-
293
- # If in_axes is a tuple, its dimension should match the number of arguments
294
- if isinstance(in_axes, tuple) and len(in_axes) != len(sig.parameters):
295
- msg = "The length of 'in_axes' must match the number of arguments ({})"
296
- raise ValueError(msg.format(len(sig.parameters)))
297
-
298
- # Check that static arguments are not mapped with vmap.
299
- # This case would not work since static arguments are not traces and vmap need
300
- # to trace arguments in order to map them.
301
- if isinstance(in_axes, tuple):
302
- for mapped_axis, arg_name in zip(in_axes, sig.parameters.keys()):
303
- if mapped_axis is not None and arg_name in static_argnames:
304
- raise ValueError(
305
- f"Static argument '{arg_name}' cannot be mapped with vmap"
306
- )
307
-
308
- def fn_tf_vmap(*args, function_to_vmap: Callable, **kwargs):
309
- """Wrapper applying the vmap transformation"""
310
-
311
- # Canonicalize the arguments so that all of them are kwargs
312
- bound = sig.bind(*args, **kwargs)
313
- bound.apply_defaults()
314
-
315
- # Build a dictionary mapping all arguments to a mapped axis, even when
316
- # the None is passed (defaults to in_axes=0) or and int is passed (defaults
317
- # to in_axes=<int>).
318
- match in_axes:
319
- case None:
320
- argname_to_mapped_axis = {name: 0 for name in bound.arguments}
321
- case tuple():
322
- argname_to_mapped_axis = {
323
- name: in_axes[i] for i, name in enumerate(bound.arguments)
324
- }
325
- case int():
326
- argname_to_mapped_axis = {name: in_axes for name in bound.arguments}
327
- case _:
328
- raise ValueError(in_axes)
329
-
330
- # Build a dictionary (argument_name -> argument) for all mapped arguments.
331
- # Note that a mapped argument is an argument whose axis is not None and
332
- # is not a static jit argument.
333
- vmap_mapped_args = {
334
- arg: value
335
- for arg, value in bound.arguments.items()
336
- if argname_to_mapped_axis[arg] is not None
337
- and arg not in static_argnames
338
- }
339
-
340
- # Build a dictionary (argument_name -> argument) for all unmapped arguments
341
- vmap_unmapped_args = {
342
- arg: value
343
- for arg, value in bound.arguments.items()
344
- if arg not in vmap_mapped_args
345
- }
346
-
347
- # Disable mapping of non-vectorized default arguments
348
- for arg, value in argname_to_mapped_axis.items():
349
- if arg in vmap_mapped_args and value == sig.parameters[arg].default:
350
- logging.debug(f"Disabling vmapping of default argument '{arg}'")
351
- argname_to_mapped_axis[arg] = None
352
-
353
- # Close the function over the unmapped arguments of vmap
354
- fn_closed = lambda *mapped_args: function_to_vmap(
355
- **vmap_unmapped_args, **dict(zip(vmap_mapped_args.keys(), mapped_args))
356
- )
357
-
358
- # Create the in_axes tuple of only the mapped arguments
359
- in_axes_mapped = tuple(
360
- argname_to_mapped_axis[name] for name in vmap_mapped_args
361
- )
362
-
363
- # If all in_axes are the same, simplify in_axes tuple to be just an integer
364
- if len(set(in_axes_mapped)) == 1:
365
- in_axes_mapped = list(set(in_axes_mapped))[0]
366
-
367
- # If, instead, in_axes has different elements, we need to replace the mapped
368
- # axis of "self" with a pytree having as leafs the mapped axis.
369
- # This is because the vmap in_axes specification must be a tree prefix of
370
- # the corresponding value.
371
- if isinstance(in_axes_mapped, tuple) and "self" in vmap_mapped_args:
372
- argname_to_mapped_axis["self"] = jax.tree_util.tree_map(
373
- lambda _: argname_to_mapped_axis["self"], vmap_mapped_args["self"]
374
- )
375
- in_axes_mapped = tuple(
376
- argname_to_mapped_axis[name] for name in vmap_mapped_args
377
- )
378
-
379
- # Apply the vmap transformation and call the function passing only the
380
- # mapped arguments. The unmapped arguments have been closed over.
381
- # Note: we altered the "in_axes" tuple so that it does not have any
382
- # None elements.
383
- # Note: if "in_axes_mapped" is a tuple, the following fails if we pass kwargs,
384
- # we need to pass the unpacked args tuple instead.
385
- return jax.vmap(
386
- fn_closed,
387
- in_axes=in_axes_mapped,
388
- **dict(out_axes=out_axes) if out_axes is not None else {},
389
- )(*list(vmap_mapped_args.values()))
390
-
391
- def fn_tf_jit(*args, function_to_jit: Callable, **kwargs):
392
- """Wrapper applying the jit transformation"""
393
-
394
- # Canonicalize the arguments so that all of them are kwargs
395
- bound = sig.bind(*args, **kwargs)
396
- bound.apply_defaults()
397
-
398
- # Apply the jit transformation and call the function passing all arguments
399
- # as keyword arguments
400
- return jax.jit(function_to_jit, static_argnames=static_argnames)(
401
- **bound.arguments
402
- )
403
-
404
- # First applied wrapper that executes fn in a mutable context
405
- fn_mutable = functools.partial(
406
- jax_tf.call_class_method_in_mutable_context,
407
- fn=fn,
408
- jit=jit,
409
- mutability=mutability,
410
- )
411
-
412
- # Second applied wrapper that transforms fn with vmap
413
- fn_vmap = (
414
- fn_mutable
415
- if not vmap
416
- else functools.partial(fn_tf_vmap, function_to_vmap=fn_mutable)
417
- )
418
-
419
- # Third applied wrapper that transforms fn with jit
420
- fn_jit_vmap = (
421
- fn_vmap
422
- if not jit
423
- else functools.partial(fn_tf_jit, function_to_jit=fn_vmap)
424
- )
425
-
426
- return fn_jit_vmap
427
-
428
- @staticmethod
429
- def call_class_method_in_mutable_context(
430
- *args, fn: Callable, jit: bool, mutability: Mutability, **kwargs
431
- ) -> tuple[Any, Vmappable]:
432
- """
433
- Wrapper to call a method on an object with the desired mutable context.
434
-
435
- Args:
436
- fn: The method to call.
437
- jit: Whether the method is being jit compiled or not.
438
- mutability: The desired mutability context.
439
- *args: The positional arguments to pass to the method (including self).
440
- **kwargs: The keyword arguments to pass to the method.
441
-
442
- Returns:
443
- A tuple containing the return value of the method and the object
444
- possibly updated by the method if it is in read-write.
445
-
446
- Note:
447
- This approach enables to jit-compile methods of a stateful object without
448
- leaking traces, therefore obtaining a jax-compatible OOP pattern.
449
- """
450
-
451
- # Log here whether the method is being jit compiled or not.
452
- # This log message does not get printed from compiled code, so here is the
453
- # most appropriate place to be sure that we log it correctly.
454
- if jit:
455
- logging.debug(msg=f"JIT compiling {fn}")
456
-
457
- # Canonicalize the arguments so that all of them are kwargs
458
- sig = inspect.signature(fn)
459
- bound = sig.bind(*args, **kwargs)
460
- bound.apply_defaults()
461
-
462
- # Extract the class instance over which fn is called
463
- instance: Vmappable = bound.arguments["self"]
464
-
465
- # Select the right mutability. If the instance is mutable with validation
466
- # disabled, we override the input mutability so that we do not fail in case
467
- # of mismatched tree structure.
468
- mut = (
469
- Mutability.MUTABLE_NO_VALIDATION
470
- if instance._mutability() is Mutability.MUTABLE_NO_VALIDATION
471
- else mutability
472
- )
473
-
474
- # Call fn in a mutable context
475
- with instance.mutable_context(mutability=mut):
476
- # Methods could call other decorated methods. When it happens, the decorator
477
- # of the called method is invoked, that applies jit and vmap transformations.
478
- # This is not desired as it calls vmap inside an already vmapped method.
479
- # We work around this occurrence by disabling the jit/vmap decorators of all
480
- # methods called inside fn through a context manager.
481
- # Note that we already work around this in the beginning of the wrapper
482
- # function by detecting traced arguments, but the decorator works also
483
- # when jit=False and vmap=False, therefore only enforcing the mutability.
484
- with jax_tf.disabled_oop_decorators():
485
- out = fn(**bound.arguments)
486
-
487
- return out, instance
488
-
489
- @staticmethod
490
- def env_var_on(var_name: str, default_value: str = "1") -> bool:
491
- """
492
- Check whether an environment variable is set to a value that is considered on.
493
-
494
- Args:
495
- var_name: The name of the environment variable.
496
- default_value: The default variable value to consider if the variable has not
497
- been exported.
498
-
499
- Returns:
500
- True if the environment variable contains an on value, False otherwise.
501
- """
502
-
503
- on_values = {"1", "true", "on", "yes"}
504
- return os.environ.get(var_name, default_value).lower() in on_values
505
-
506
- @staticmethod
507
- @contextlib.contextmanager
508
- def disabled_oop_decorators() -> Generator[None, None, None]:
509
- """
510
- Context manager to disable the application of jax transformations performed by
511
- the decorators of this class.
512
-
513
- Note: when the transformations are disabled, the only logic still applied is
514
- the selection of the object mutability over which the method is running.
515
- """
516
-
517
- # Check whether the environment variable is part of the environment and
518
- # save its value. We restore the original value before exiting the context.
519
- env_cache = (
520
- None if jax_tf.EnvVarOOP not in os.environ else os.environ[jax_tf.EnvVarOOP]
521
- )
522
-
523
- # Disable both jit and vmap transformations
524
- os.environ[jax_tf.EnvVarOOP] = "0"
525
-
526
- try:
527
- # Execute the code in the context with disabled transformations
528
- yield
529
-
530
- finally:
531
- # Restore the original value of the environment variable or remove it if
532
- # it was not present before entering the context
533
- if env_cache is not None:
534
- os.environ[jax_tf.EnvVarOOP] = env_cache
535
- else:
536
- _ = os.environ.pop(jax_tf.EnvVarOOP)
jaxsim/utils/vmappable.py DELETED
@@ -1,117 +0,0 @@
1
- import dataclasses
2
- from typing import Type
3
-
4
- import jax
5
- import jax.numpy as jnp
6
- import jax_dataclasses
7
-
8
- from . import JaxsimDataclass, Mutability
9
-
10
- try:
11
- from typing import Self
12
- except ImportError:
13
- from typing_extensions import Self
14
-
15
-
16
@jax_dataclasses.pytree_dataclass
class Vmappable(JaxsimDataclass):
    """Abstract class with utilities for vmappable pytrees."""

    # Static field storing the size of the batch dimension. The default value 0
    # marks a pytree that is not vectorized (see the `vectorized` property).
    batch_size: jax_dataclasses.Static[int] = dataclasses.field(
        default=0, repr=False, compare=False, hash=False, kw_only=True
    )

    @property
    def vectorized(self) -> bool:
        """Marks this pytree as vectorized."""

        return self.batch_size > 0

    @classmethod
    def build_from_list(cls: Type[Self], list_of_obj: list[Self]) -> Self:
        """
        Build a vectorized pytree from a list of pytree of the same type.

        Args:
            list_of_obj: The list of pytrees to vectorize.

        Returns:
            The vectorized pytree having as leaves the stacked leaves of the input list.

        Raises:
            ValueError: If the list is empty or if it contains objects whose type
                is not exactly the class on which this method is called.
        """

        # An empty list cannot be stacked. Fail early with an explicit message
        # instead of the misleading type-mismatch error reported below.
        if not list_of_obj:
            raise ValueError("The input list cannot be empty")

        if {type(el) for el in list_of_obj} != {cls}:
            msg = "The input list must contain only objects of type '{}'"
            raise ValueError(msg.format(cls.__name__))

        # Create a pytree by stacking all the leafs of the input list.
        # Note: jax.tree_util.tree_map is used since the top-level jax.tree_map
        # alias is deprecated.
        data_vec: Vmappable = jax.tree_util.tree_map(
            lambda *leafs: jnp.array(leafs), *list_of_obj
        )

        # Store the batch dimension
        with data_vec.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
            data_vec.batch_size = len(list_of_obj)

        # Detect the most common mutability in the input list
        mutabilities = [e._mutability() for e in list_of_obj]
        mutability = max(set(mutabilities), key=mutabilities.count)

        # Update the mutability of the vectorized pytree
        data_vec._set_mutability(mutability)

        return data_vec

    def vectorize(self: Self, batch_size: int) -> Self:
        """
        Return a vectorized version of this pytree.

        Args:
            batch_size: The batch size. If 0, a copy of this pytree is returned.

        Returns:
            A vectorized version of this pytree obtained by stacking the leaves of the
            original pytree along a new batch dimension (the first one).

        Raises:
            RuntimeError: If this pytree is already vectorized.
        """

        if self.vectorized:
            raise RuntimeError("Cannot vectorize an already vectorized object")

        if batch_size == 0:
            return self.copy()

        # TODO validate if mutability is maintained

        return self.__class__.build_from_list(list_of_obj=[self] * batch_size)

    def extract_element(self: Self, index: int) -> Self:
        """
        Extract the i-th element from a vectorized pytree.

        Args:
            index: The index of the element to extract.

        Returns:
            A non vectorized pytree obtained by extracting the i-th element from the
            vectorized pytree.

        Raises:
            ValueError: If the index is negative or not smaller than the batch size.
            RuntimeError: If this pytree is not vectorized and the index is not 0.
        """

        if index < 0:
            raise ValueError("The index of the desired element cannot be negative")

        # Extracting the first element of a non-vectorized pytree is allowed and
        # returns a copy of the pytree itself.
        if index == 0 and self.batch_size == 0:
            return self.copy()

        if not self.vectorized:
            raise RuntimeError("Cannot extract elements from a non-vectorized object")

        if index >= self.batch_size:
            raise ValueError("The index must be smaller than the batch size")

        # Get the i-th pytree by extracting the i-th element from the vectorized
        # pytree
        data = jax.tree_util.tree_map(lambda leaf: leaf[index], self)

        # Update the batch size of the extracted scalar pytree
        with data.mutable_context(mutability=Mutability.MUTABLE):
            data.batch_size = 0

        return data