jaxsim 0.1rc0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1rc0.dist-info/METADATA +0 -167
  88. jaxsim-0.1rc0.dist-info/RECORD +0 -64
  89. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/utils/oop.py DELETED
@@ -1,532 +0,0 @@
1
- import contextlib
2
- import dataclasses
3
- import functools
4
- import inspect
5
- import os
6
- from typing import Any, Callable, Generator
7
-
8
- import jax
9
- import jax.flatten_util
10
-
11
- from jaxsim import logging
12
- from jaxsim.utils import tracing
13
-
14
- from . import Mutability, Vmappable
15
-
16
-
17
- class jax_tf:
18
- """
19
- Class containing decorators applicable to methods of Vmappable objects.
20
- """
21
-
22
- # Environment variables that can be used to disable the transformations
23
- EnvVarOOP: str = "JAXSIM_OOP_DECORATORS"
24
- EnvVarJitOOP: str = "JAXSIM_OOP_DECORATORS_JIT"
25
- EnvVarVmapOOP: str = "JAXSIM_OOP_DECORATORS_VMAP"
26
- EnvVarCacheOOP: str = "JAXSIM_OOP_DECORATORS_CACHE"
27
-
28
- @staticmethod
29
- def method_ro(
30
- fn: Callable,
31
- jit: bool = True,
32
- static_argnames: tuple[str, ...] | list[str] = (),
33
- vmap: bool | None = None,
34
- vmap_in_axes: tuple[int, ...] | int | None = None,
35
- vmap_out_axes: tuple[int, ...] | int | None = None,
36
- ):
37
- """
38
- Decorator for r/o methods of classes inheriting from Vmappable.
39
- """
40
-
41
- return jax_tf.method(
42
- fn=fn,
43
- read_only=True,
44
- validate=True,
45
- jit_enabled=jit,
46
- static_argnames=static_argnames,
47
- vmap_enabled=vmap,
48
- vmap_in_axes=vmap_in_axes,
49
- vmap_out_axes=vmap_out_axes,
50
- )
51
-
52
- @staticmethod
53
- def method_rw(
54
- fn: Callable,
55
- validate: bool = True,
56
- jit: bool = True,
57
- static_argnames: tuple[str, ...] | list[str] = (),
58
- vmap: bool | None = None,
59
- vmap_in_axes: tuple[int, ...] | int | None = None,
60
- vmap_out_axes: tuple[int, ...] | int | None = None,
61
- ):
62
- """
63
- Decorator for r/w methods of classes inheriting from Vmappable.
64
- """
65
-
66
- return jax_tf.method(
67
- fn=fn,
68
- read_only=False,
69
- validate=validate,
70
- jit_enabled=jit,
71
- static_argnames=static_argnames,
72
- vmap_enabled=vmap,
73
- vmap_in_axes=vmap_in_axes,
74
- vmap_out_axes=vmap_out_axes,
75
- )
76
-
77
- @staticmethod
78
- def method(
79
- fn: Callable,
80
- read_only: bool = True,
81
- validate: bool = True,
82
- jit_enabled: bool = True,
83
- static_argnames: tuple[str, ...] | list[str] = (),
84
- vmap_enabled: bool | None = None,
85
- vmap_in_axes: tuple[int, ...] | int | None = None,
86
- vmap_out_axes: tuple[int, ...] | int | None = None,
87
- ):
88
- """
89
- Decorator for methods of classes inheriting from Vmappable.
90
-
91
- This decorator enables executing the methods on an object characterized by a
92
- desired mutability, that is selected considering the r/o and validation flags.
93
- It also allows to transform the method with the jit/vmap transformations.
94
- If the Vmappable object is vectorized, the method is automatically vmapped, and
95
- the in_axes are properly post-processed to simplify the combination with jit.
96
-
97
- Args:
98
- fn: The method to decorate.
99
- read_only: Whether the method operates on a read-only object.
100
- validate: Whether r/w methods should preserve the pytree structure.
101
- jit_enabled: Whether to apply the jit transformation.
102
- static_argnames: The names of the arguments that should be static.
103
- vmap_enabled: Whether to apply the vmap transformation.
104
- vmap_in_axes: The in_axes to use for the vmap transformation.
105
- vmap_out_axes: The out_axes to use for the vmap transformation.
106
-
107
- Returns:
108
- The decorated method.
109
- """
110
-
111
- @functools.wraps(fn)
112
- def wrapper(*args, **kwargs):
113
- """The wrapper function that is returned by this decorator."""
114
-
115
- # Methods of classes inheriting from Vmappable decorated by this wrapper
116
- # automatically support jit/vmap/mutability features when called standalone.
117
- # However, when objects are arguments of plain functions transformed with
118
- # jit/vmap, and decorated methods are called inside those functions, we need
119
- # to disable this decorator to avoid double wrapping and execution errors.
120
- # We do so by iterating over the arguments, and checking whether they are
121
- # being traced by JAX.
122
- for argument in list(args) + list(kwargs.values()):
123
- try:
124
- argument_flat, _ = jax.flatten_util.ravel_pytree(argument)
125
-
126
- if tracing(argument_flat):
127
- return fn(*args, **kwargs)
128
- except:
129
- continue
130
-
131
- # ===============================================================
132
- # Wrap fn so that jit/vmap/mutability transformations are applied
133
- # ===============================================================
134
-
135
- # Initialize the mutability of the instance over which the method is running.
136
- # * In r/o methods, this approach prevents any type of mutation.
137
- # * In r/w methods, this approach allows to catch early JIT recompilations
138
- # caused by unwanted changes in the pytree structure.
139
- if read_only:
140
- mutability = Mutability.FROZEN
141
- else:
142
- mutability = (
143
- Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
144
- )
145
-
146
- # Extract the class instance over which fn is called
147
- instance: Vmappable = args[0]
148
- assert isinstance(instance, Vmappable)
149
-
150
- # Save the original mutability
151
- original_mutability = instance._mutability()
152
-
153
- # Inspect the environment to detect whether to enforce disabling jit/vmap
154
- deco_on = jax_tf.env_var_on(jax_tf.EnvVarOOP)
155
- jit_enabled_env = jax_tf.env_var_on(jax_tf.EnvVarJitOOP) and deco_on
156
- vmap_enabled_env = jax_tf.env_var_on(jax_tf.EnvVarVmapOOP) and deco_on
157
-
158
- # Allow disabling the cache of jit-compiled functions.
159
- # It can be useful for debugging or testing purposes.
160
- wrap_fn = (
161
- jax_tf.wrap_fn
162
- if jax_tf.env_var_on(jax_tf.EnvVarCacheOOP) and deco_on
163
- else jax_tf.wrap_fn.__wrapped__
164
- )
165
-
166
- # Get the transformed function (possibly cached by functools.cache).
167
- # Note that all the arguments of the following methods, when hashed, should
168
- # uniquely identify the returned function so that a new function is built
169
- # when arguments change and either jit or vmap have to be called again.
170
- fn_db = wrap_fn(
171
- fn=fn, # noqa
172
- mutability=mutability,
173
- jit=jit_enabled_env and jit_enabled,
174
- static_argnames=tuple(static_argnames),
175
- vmap=vmap_enabled_env
176
- and (
177
- vmap_enabled is True
178
- or (vmap_enabled is None and instance.vectorized)
179
- ),
180
- in_axes=vmap_in_axes,
181
- out_axes=vmap_out_axes,
182
- )
183
-
184
- # Call the transformed (mutable/jit/vmap) method
185
- out, obj = fn_db(*args, **kwargs)
186
-
187
- if read_only:
188
- # Restore the original mutability
189
- instance._set_mutability(mutability=original_mutability)
190
-
191
- return out
192
-
193
- # =================================================================
194
- # From here we assume that the wrapper is operating on a r/w method
195
- # =================================================================
196
-
197
- from jax_dataclasses._dataclasses import JDC_STATIC_MARKER
198
-
199
- # Select the right runtime mutability. The only difference here is when a r/w
200
- # method is called on a frozen object. In this case, we enable updating the
201
- # pytree data and preserve its structure only if validation is enabled.
202
- mutability_dict = {
203
- Mutability.MUTABLE_NO_VALIDATION: Mutability.MUTABLE_NO_VALIDATION,
204
- Mutability.MUTABLE: Mutability.MUTABLE,
205
- Mutability.FROZEN: Mutability.MUTABLE
206
- if validate
207
- else Mutability.MUTABLE_NO_VALIDATION,
208
- }
209
-
210
- # We need to replace all the dynamic leafs of the original instance with those
211
- # computed by the functional transformation.
212
- # We do so by iterating over the fields of the jax_dataclasses and ignoring
213
- # all the fields that are marked as static.
214
- # Caveats: https://github.com/ami-iit/jaxsim/pull/48#issuecomment-1746635121.
215
- with instance.mutable_context(
216
- mutability=mutability_dict[instance._mutability()]
217
- ):
218
- for f in dataclasses.fields(instance): # noqa
219
- if (
220
- hasattr(f, "type")
221
- and hasattr(f.type, "__metadata__")
222
- and JDC_STATIC_MARKER in f.type.__metadata__
223
- ):
224
- continue
225
-
226
- try:
227
- setattr(instance, f.name, getattr(obj, f.name))
228
- except AssertionError as exc:
229
- logging.debug(f"Old object:\n{getattr(instance, f.name)}")
230
- logging.debug(f"New object:\n{getattr(obj, f.name)}")
231
- raise RuntimeError(
232
- f"Failed to update field '{f.name}'"
233
- ) from exc
234
-
235
- return out
236
-
237
- return wrapper
238
-
239
- @staticmethod
240
- @functools.cache
241
- def wrap_fn(
242
- fn: Callable,
243
- mutability: Mutability,
244
- jit: bool,
245
- static_argnames: tuple[str, ...] | list[str],
246
- vmap: bool,
247
- in_axes: tuple[int, ...] | int | None,
248
- out_axes: tuple[int, ...] | int | None,
249
- ) -> Callable:
250
- """
251
- Transform a method with jit/vmap and execute it on an object characterized
252
- by the desired mutability.
253
-
254
- Note:
255
- The method should take the object (self) as first argument.
256
-
257
- Note:
258
- This returned transformed method is cached by considering the hash of all
259
- the arguments. It will re-apply jit/vmap transformations only if needed.
260
-
261
- Args:
262
- fn: The method to consider.
263
- mutability: The mutability of the object on which the method is called.
264
- jit: Whether to apply jit transformations.
265
- static_argnames: The names of the arguments that should be considered static.
266
- vmap: Whether to apply vmap transformations.
267
- in_axes: The axes along which to vmap input arguments.
268
- out_axes: The axes along which to vmap output arguments.
269
-
270
- Note:
271
- In order to simplify the application of vmap, we close the method arguments
272
- over all the non-mapped input arguments. Furthermore, for improving the
273
- compatibility with jit, we also close the vmap application over the static
274
- arguments.
275
-
276
- Returns:
277
- The transformed method operating on an object with the desired mutability.
278
- We maintain the same signature of the original method.
279
- """
280
-
281
- # Extract the signature of the function
282
- sig = inspect.signature(fn)
283
-
284
- # All static arguments must be actual arguments of fn
285
- for name in static_argnames:
286
- if name not in sig.parameters:
287
- raise ValueError(f"Static argument '{name}' not found in {fn}")
288
-
289
- # If in_axes is a tuple, its dimension should match the number of arguments
290
- if isinstance(in_axes, tuple) and len(in_axes) != len(sig.parameters):
291
- msg = "The length of 'in_axes' must match the number of arguments ({})"
292
- raise ValueError(msg.format(len(sig.parameters)))
293
-
294
- # Check that static arguments are not mapped with vmap.
295
- # This case would not work since static arguments are not traces and vmap need
296
- # to trace arguments in order to map them.
297
- if isinstance(in_axes, tuple):
298
- for mapped_axis, arg_name in zip(in_axes, sig.parameters.keys()):
299
- if mapped_axis is not None and arg_name in static_argnames:
300
- raise ValueError(
301
- f"Static argument '{arg_name}' cannot be mapped with vmap"
302
- )
303
-
304
- def fn_tf_vmap(*args, function_to_vmap: Callable, **kwargs):
305
- """Wrapper applying the vmap transformation"""
306
-
307
- # Canonicalize the arguments so that all of them are kwargs
308
- bound = sig.bind(*args, **kwargs)
309
- bound.apply_defaults()
310
-
311
- # Build a dictionary mapping all arguments to a mapped axis, even when
312
- # the None is passed (defaults to in_axes=0) or and int is passed (defaults
313
- # to in_axes=<int>).
314
- match in_axes:
315
- case None:
316
- argname_to_mapped_axis = {name: 0 for name in bound.arguments}
317
- case tuple():
318
- argname_to_mapped_axis = {
319
- name: in_axes[i] for i, name in enumerate(bound.arguments)
320
- }
321
- case int():
322
- argname_to_mapped_axis = {name: in_axes for name in bound.arguments}
323
- case _:
324
- raise ValueError(in_axes)
325
-
326
- # Build a dictionary (argument_name -> argument) for all mapped arguments.
327
- # Note that a mapped argument is an argument whose axis is not None and
328
- # is not a static jit argument.
329
- vmap_mapped_args = {
330
- arg: value
331
- for arg, value in bound.arguments.items()
332
- if argname_to_mapped_axis[arg] is not None
333
- and arg not in static_argnames
334
- }
335
-
336
- # Build a dictionary (argument_name -> argument) for all unmapped arguments
337
- vmap_unmapped_args = {
338
- arg: value
339
- for arg, value in bound.arguments.items()
340
- if arg not in vmap_mapped_args
341
- }
342
-
343
- # Disable mapping of non-vectorized default arguments
344
- for arg, value in argname_to_mapped_axis.items():
345
- if arg in vmap_mapped_args and value == sig.parameters[arg].default:
346
- logging.debug(f"Disabling vmapping of default argument '{arg}'")
347
- argname_to_mapped_axis[arg] = None
348
-
349
- # Close the function over the unmapped arguments of vmap
350
- fn_closed = lambda *mapped_args: function_to_vmap(
351
- **vmap_unmapped_args, **dict(zip(vmap_mapped_args.keys(), mapped_args))
352
- )
353
-
354
- # Create the in_axes tuple of only the mapped arguments
355
- in_axes_mapped = tuple(
356
- argname_to_mapped_axis[name] for name in vmap_mapped_args
357
- )
358
-
359
- # If all in_axes are the same, simplify in_axes tuple to be just an integer
360
- if len(set(in_axes_mapped)) == 1:
361
- in_axes_mapped = list(set(in_axes_mapped))[0]
362
-
363
- # If, instead, in_axes has different elements, we need to replace the mapped
364
- # axis of "self" with a pytree having as leafs the mapped axis.
365
- # This is because the vmap in_axes specification must be a tree prefix of
366
- # the corresponding value.
367
- if isinstance(in_axes_mapped, tuple) and "self" in vmap_mapped_args:
368
- argname_to_mapped_axis["self"] = jax.tree_util.tree_map(
369
- lambda _: argname_to_mapped_axis["self"], vmap_mapped_args["self"]
370
- )
371
- in_axes_mapped = tuple(
372
- argname_to_mapped_axis[name] for name in vmap_mapped_args
373
- )
374
-
375
- # Apply the vmap transformation and call the function passing only the
376
- # mapped arguments. The unmapped arguments have been closed over.
377
- # Note: we altered the "in_axes" tuple so that it does not have any
378
- # None elements.
379
- # Note: if "in_axes_mapped" is a tuple, the following fails if we pass kwargs,
380
- # we need to pass the unpacked args tuple instead.
381
- return jax.vmap(
382
- fn_closed,
383
- in_axes=in_axes_mapped,
384
- **dict(out_axes=out_axes) if out_axes is not None else {},
385
- )(*list(vmap_mapped_args.values()))
386
-
387
- def fn_tf_jit(*args, function_to_jit: Callable, **kwargs):
388
- """Wrapper applying the jit transformation"""
389
-
390
- # Canonicalize the arguments so that all of them are kwargs
391
- bound = sig.bind(*args, **kwargs)
392
- bound.apply_defaults()
393
-
394
- # Apply the jit transformation and call the function passing all arguments
395
- # as keyword arguments
396
- return jax.jit(function_to_jit, static_argnames=static_argnames)(
397
- **bound.arguments
398
- )
399
-
400
- # First applied wrapper that executes fn in a mutable context
401
- fn_mutable = functools.partial(
402
- jax_tf.call_class_method_in_mutable_context,
403
- fn=fn,
404
- jit=jit,
405
- mutability=mutability,
406
- )
407
-
408
- # Second applied wrapper that transforms fn with vmap
409
- fn_vmap = (
410
- fn_mutable
411
- if not vmap
412
- else functools.partial(fn_tf_vmap, function_to_vmap=fn_mutable)
413
- )
414
-
415
- # Third applied wrapper that transforms fn with jit
416
- fn_jit_vmap = (
417
- fn_vmap
418
- if not jit
419
- else functools.partial(fn_tf_jit, function_to_jit=fn_vmap)
420
- )
421
-
422
- return fn_jit_vmap
423
-
424
- @staticmethod
425
- def call_class_method_in_mutable_context(
426
- *args, fn: Callable, jit: bool, mutability: Mutability, **kwargs
427
- ) -> tuple[Any, Vmappable]:
428
- """
429
- Wrapper to call a method on an object with the desired mutable context.
430
-
431
- Args:
432
- fn: The method to call.
433
- jit: Whether the method is being jit compiled or not.
434
- mutability: The desired mutability context.
435
- *args: The positional arguments to pass to the method (including self).
436
- **kwargs: The keyword arguments to pass to the method.
437
-
438
- Returns:
439
- A tuple containing the return value of the method and the object
440
- possibly updated by the method if it is in read-write.
441
-
442
- Note:
443
- This approach enables to jit-compile methods of a stateful object without
444
- leaking traces, therefore obtaining a jax-compatible OOP pattern.
445
- """
446
-
447
- # Log here whether the method is being jit compiled or not.
448
- # This log message does not get printed from compiled code, so here is the
449
- # most appropriate place to be sure that we log it correctly.
450
- if jit:
451
- logging.debug(msg=f"JIT compiling {fn}")
452
-
453
- # Canonicalize the arguments so that all of them are kwargs
454
- sig = inspect.signature(fn)
455
- bound = sig.bind(*args, **kwargs)
456
- bound.apply_defaults()
457
-
458
- # Extract the class instance over which fn is called
459
- instance: Vmappable = bound.arguments["self"]
460
-
461
- # Select the right mutability. If the instance is mutable with validation
462
- # disabled, we override the input mutability so that we do not fail in case
463
- # of mismatched tree structure.
464
- mut = (
465
- Mutability.MUTABLE_NO_VALIDATION
466
- if instance._mutability() is Mutability.MUTABLE_NO_VALIDATION
467
- else mutability
468
- )
469
-
470
- # Call fn in a mutable context
471
- with instance.mutable_context(mutability=mut):
472
- # Methods could call other decorated methods. When it happens, the decorator
473
- # of the called method is invoked, that applies jit and vmap transformations.
474
- # This is not desired as it calls vmap inside an already vmapped method.
475
- # We work around this occurrence by disabling the jit/vmap decorators of all
476
- # methods called inside fn through a context manager.
477
- # Note that we already work around this in the beginning of the wrapper
478
- # function by detecting traced arguments, but the decorator works also
479
- # when jit=False and vmap=False, therefore only enforcing the mutability.
480
- with jax_tf.disabled_oop_decorators():
481
- out = fn(**bound.arguments)
482
-
483
- return out, instance
484
-
485
- @staticmethod
486
- def env_var_on(var_name: str, default_value: str = "1") -> bool:
487
- """
488
- Check whether an environment variable is set to a value that is considered on.
489
-
490
- Args:
491
- var_name: The name of the environment variable.
492
- default_value: The default variable value to consider if the variable has not
493
- been exported.
494
-
495
- Returns:
496
- True if the environment variable contains an on value, False otherwise.
497
- """
498
-
499
- on_values = {"1", "true", "on", "yes"}
500
- return os.environ.get(var_name, default_value).lower() in on_values
501
-
502
- @staticmethod
503
- @contextlib.contextmanager
504
- def disabled_oop_decorators() -> Generator[None, None, None]:
505
- """
506
- Context manager to disable the application of jax transformations performed by
507
- the decorators of this class.
508
-
509
- Note: when the transformations are disabled, the only logic still applied is
510
- the selection of the object mutability over which the method is running.
511
- """
512
-
513
- # Check whether the environment variable is part of the environment and
514
- # save its value. We restore the original value before exiting the context.
515
- env_cache = (
516
- None if jax_tf.EnvVarOOP not in os.environ else os.environ[jax_tf.EnvVarOOP]
517
- )
518
-
519
- # Disable both jit and vmap transformations
520
- os.environ[jax_tf.EnvVarOOP] = "0"
521
-
522
- try:
523
- # Execute the code in the context with disabled transformations
524
- yield
525
-
526
- finally:
527
- # Restore the original value of the environment variable or remove it if
528
- # it was not present before entering the context
529
- if env_cache is not None:
530
- os.environ[jax_tf.EnvVarOOP] = env_cache
531
- else:
532
- _ = os.environ.pop(jax_tf.EnvVarOOP)
jaxsim/utils/vmappable.py DELETED
@@ -1,117 +0,0 @@
1
- import dataclasses
2
- from typing import Type
3
-
4
- import jax
5
- import jax.numpy as jnp
6
- import jax_dataclasses
7
-
8
- from . import JaxsimDataclass, Mutability
9
-
10
- try:
11
- from typing import Self
12
- except ImportError:
13
- from typing_extensions import Self
14
-
15
-
16
@jax_dataclasses.pytree_dataclass
class Vmappable(JaxsimDataclass):
    """Abstract class with utilities for vmappable pytrees."""

    # Number of stacked elements along the leading axis of the leaves; 0 marks a
    # non-vectorized pytree. Declared Static, so tree_map does not map over it.
    batch_size: jax_dataclasses.Static[int] = dataclasses.field(
        default=int(0), repr=False, compare=False, hash=False, kw_only=True
    )

    @property
    def vectorized(self) -> bool:
        """Marks this pytree as vectorized."""

        return self.batch_size > 0

    @classmethod
    def build_from_list(cls: Type[Self], list_of_obj: list[Self]) -> Self:
        """
        Build a vectorized pytree from a list of pytree of the same type.

        Args:
            list_of_obj: The list of pytrees to vectorize.

        Returns:
            The vectorized pytree having as leaves the stacked leaves of the input list.
            Its mutability is the most common one in the input list (ties are
            broken in favor of the mutability that appears first).

        Raises:
            ValueError: If the list is empty or contains objects of other types.
        """

        if set(type(el) for el in list_of_obj) != {cls}:
            msg = "The input list must contain only objects of type '{}'"
            raise ValueError(msg.format(cls.__name__))

        # Create a pytree by stacking all the leafs of the input list
        data_vec: Vmappable = jax.tree_util.tree_map(
            lambda *leafs: jnp.array(leafs), *list_of_obj
        )

        # Store the batch dimension. Validation is disabled because the static
        # batch_size field is being changed.
        with data_vec.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
            data_vec.batch_size = len(list_of_obj)

        # Detect the most common mutability in the input list.
        # dict.fromkeys deduplicates while preserving first-seen order, so that
        # ties are resolved deterministically (iterating a set would depend on
        # the per-process hash randomization).
        mutabilities = [e._mutability() for e in list_of_obj]
        mutability = max(dict.fromkeys(mutabilities), key=mutabilities.count)

        # Update the mutability of the vectorized pytree
        data_vec._set_mutability(mutability)

        return data_vec

    def vectorize(self: Self, batch_size: int) -> Self:
        """
        Return a vectorized version of this pytree.

        Args:
            batch_size: The batch size.

        Returns:
            A vectorized version of this pytree obtained by stacking the leaves of the
            original pytree along a new batch dimension (the first one).
            A copy of this pytree is returned if batch_size is 0.

        Raises:
            RuntimeError: If this pytree is already vectorized.
            ValueError: If batch_size is negative.
        """

        if self.vectorized:
            raise RuntimeError("Cannot vectorize an already vectorized object")

        if batch_size < 0:
            raise ValueError("The batch size cannot be negative")

        if batch_size == 0:
            return self.copy()

        # All the stacked elements share the mutability of self, which
        # build_from_list propagates to the vectorized pytree.
        return self.__class__.build_from_list(list_of_obj=[self] * batch_size)

    def extract_element(self: Self, index: int) -> Self:
        """
        Extract the i-th element from a vectorized pytree.

        Args:
            index: The index of the element to extract.

        Returns:
            A non vectorized pytree obtained by extracting the i-th element from the
            vectorized pytree. A copy of this pytree is returned if it is not
            vectorized and index is 0.

        Raises:
            ValueError: If index is negative or not smaller than the batch size.
            RuntimeError: If this pytree is not vectorized and index is not 0.
        """

        if index < 0:
            raise ValueError("The index of the desired element cannot be negative")

        if index == 0 and self.batch_size == 0:
            return self.copy()

        if not self.vectorized:
            raise RuntimeError("Cannot extract elements from a non-vectorized object")

        if index >= self.batch_size:
            raise ValueError("The index must be smaller than the batch size")

        # Get the i-th pytree by extracting the i-th element from the vectorized pytree
        data = jax.tree_util.tree_map(lambda leaf: leaf[index], self)

        # Update the batch size of the extracted scalar pytree
        with data.mutable_context(mutability=Mutability.MUTABLE):
            data.batch_size = 0

        return data