jaxsim 0.1rc0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +5 -6
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -0
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +216 -0
- jaxsim/api/contact.py +271 -0
- jaxsim/api/data.py +821 -0
- jaxsim/api/joint.py +189 -0
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +361 -0
- jaxsim/api/model.py +1633 -0
- jaxsim/api/ode.py +295 -0
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +421 -0
- jaxsim/integrators/__init__.py +2 -0
- jaxsim/integrators/common.py +594 -0
- jaxsim/integrators/fixed_step.py +102 -0
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +92 -0
- jaxsim/mujoco/__init__.py +3 -0
- jaxsim/mujoco/__main__.py +192 -0
- jaxsim/mujoco/loaders.py +615 -0
- jaxsim/mujoco/model.py +414 -0
- jaxsim/mujoco/visualizer.py +176 -0
- jaxsim/parsers/descriptions/collision.py +14 -0
- jaxsim/parsers/descriptions/link.py +13 -2
- jaxsim/parsers/kinematic_graph.py +8 -3
- jaxsim/parsers/rod/parser.py +54 -38
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/{physics/algos → terrain}/terrain.py +4 -6
- jaxsim/typing.py +30 -30
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -31
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
- jaxsim-0.2.0.dist-info/METADATA +237 -0
- jaxsim-0.2.0.dist-info/RECORD +64 -0
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1695
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -101
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -256
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -454
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -358
- jaxsim/physics/model/physics_model_state.py +0 -174
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -452
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -53
- jaxsim/simulation/ode_integration.py +0 -125
- jaxsim/simulation/simulator.py +0 -544
- jaxsim/simulation/simulator_callbacks.py +0 -53
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -532
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.1rc0.dist-info/METADATA +0 -167
- jaxsim-0.1rc0.dist-info/RECORD +0 -64
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
@@ -46,23 +46,21 @@ class FlatTerrain(Terrain):
|
|
46
46
|
|
47
47
|
@jax_dataclasses.pytree_dataclass
|
48
48
|
class PlaneTerrain(Terrain):
|
49
|
-
plane_normal:
|
50
|
-
default_factory=lambda: jnp.array([0, 0, 1.0])
|
51
|
-
)
|
49
|
+
plane_normal: list = jax_dataclasses.field(default_factory=lambda: [0, 0, 1.0])
|
52
50
|
|
53
51
|
@staticmethod
|
54
|
-
def build(plane_normal:
|
52
|
+
def build(plane_normal: list) -> "PlaneTerrain":
|
55
53
|
"""
|
56
54
|
Create a PlaneTerrain instance with a specified plane normal vector.
|
57
55
|
|
58
56
|
Args:
|
59
|
-
plane_normal (
|
57
|
+
plane_normal (list): The normal vector of the terrain plane.
|
60
58
|
|
61
59
|
Returns:
|
62
60
|
PlaneTerrain: A PlaneTerrain instance.
|
63
61
|
"""
|
64
62
|
|
65
|
-
return PlaneTerrain(plane_normal=
|
63
|
+
return PlaneTerrain(plane_normal=plane_normal)
|
66
64
|
|
67
65
|
def height(self, x: float, y: float) -> float:
|
68
66
|
"""
|
jaxsim/typing.py
CHANGED
@@ -1,39 +1,39 @@
|
|
1
|
-
from typing import Any,
|
1
|
+
from typing import Any, Hashable
|
2
2
|
|
3
|
-
import jax
|
4
|
-
import numpy.typing as npt
|
3
|
+
import jax
|
5
4
|
|
5
|
+
# =========
|
6
6
|
# JAX types
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
jnp.uint32,
|
16
|
-
jnp.uint64,
|
17
|
-
]
|
18
|
-
ArrayJax = jnp.ndarray
|
19
|
-
TensorJax = jnp.ndarray
|
7
|
+
# =========
|
8
|
+
|
9
|
+
ScalarJax = jax.Array
|
10
|
+
IntJax = ScalarJax
|
11
|
+
BoolJax = ScalarJax
|
12
|
+
FloatJax = ScalarJax
|
13
|
+
|
14
|
+
ArrayJax = jax.Array
|
20
15
|
VectorJax = ArrayJax
|
21
16
|
MatrixJax = ArrayJax
|
22
|
-
PyTree = Union[
|
23
|
-
TensorJax,
|
24
|
-
Dict[Hashable, "PyTree"],
|
25
|
-
List["PyTree"],
|
26
|
-
NamedTuple,
|
27
|
-
Tuple["PyTree"],
|
28
|
-
None,
|
29
|
-
Any,
|
30
|
-
]
|
31
17
|
|
18
|
+
PyTree = (
|
19
|
+
dict[Hashable, "PyTree"] | list["PyTree"] | tuple["PyTree"] | None | jax.Array | Any
|
20
|
+
)
|
21
|
+
|
22
|
+
# =======================
|
32
23
|
# Mixed JAX / NumPy types
|
33
|
-
|
34
|
-
|
24
|
+
# =======================
|
25
|
+
|
26
|
+
Array = jax.typing.ArrayLike
|
35
27
|
Vector = Array
|
36
28
|
Matrix = Array
|
37
|
-
|
38
|
-
Int =
|
39
|
-
|
29
|
+
|
30
|
+
Int = int | IntJax
|
31
|
+
Bool = bool | ArrayJax
|
32
|
+
Float = float | FloatJax
|
33
|
+
|
34
|
+
ArrayLike = Array
|
35
|
+
VectorLike = Vector
|
36
|
+
MatrixLike = Matrix
|
37
|
+
IntLike = Int
|
38
|
+
BoolLike = Bool
|
39
|
+
FloatLike = Float
|
jaxsim/utils/__init__.py
CHANGED
@@ -1,8 +1,5 @@
|
|
1
1
|
from jax_dataclasses._copy_and_mutate import _Mutability as Mutability
|
2
2
|
|
3
|
+
from .hashless import HashlessObject
|
3
4
|
from .jaxsim_dataclass import JaxsimDataclass
|
4
5
|
from .tracing import not_tracing, tracing
|
5
|
-
from .vmappable import Vmappable
|
6
|
-
|
7
|
-
# Leave this below the others to prevent circular imports
|
8
|
-
from .oop import jax_tf # isort: skip
|
jaxsim/utils/hashless.py
ADDED
@@ -0,0 +1,18 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import dataclasses
|
4
|
+
from typing import Generic, TypeVar
|
5
|
+
|
6
|
+
T = TypeVar("T")
|
7
|
+
|
8
|
+
|
9
|
+
@dataclasses.dataclass
|
10
|
+
class HashlessObject(Generic[T]):
|
11
|
+
|
12
|
+
obj: T
|
13
|
+
|
14
|
+
def get(self: HashlessObject[T]) -> T:
|
15
|
+
return self.obj
|
16
|
+
|
17
|
+
def __hash__(self) -> int:
|
18
|
+
return 0
|
jaxsim/utils/jaxsim_dataclass.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
1
|
import abc
|
2
2
|
import contextlib
|
3
|
-
import copy
|
4
3
|
import dataclasses
|
5
|
-
|
4
|
+
import functools
|
5
|
+
from collections.abc import Iterator
|
6
|
+
from typing import Any, Callable, ClassVar, Sequence, Type
|
6
7
|
|
7
8
|
import jax.flatten_util
|
8
9
|
import jax_dataclasses
|
@@ -19,51 +20,219 @@ except ImportError:
|
|
19
20
|
|
20
21
|
@jax_dataclasses.pytree_dataclass
|
21
22
|
class JaxsimDataclass(abc.ABC):
|
22
|
-
""""""
|
23
|
+
"""Class extending `jax_dataclasses.pytree_dataclass` instances with utilities."""
|
23
24
|
|
24
25
|
# This attribute is set by jax_dataclasses
|
25
26
|
__mutability__: ClassVar[Mutability] = Mutability.FROZEN
|
26
27
|
|
27
28
|
@contextlib.contextmanager
|
28
|
-
def editable(self: Self, validate: bool = True) ->
|
29
|
-
"""
|
29
|
+
def editable(self: Self, validate: bool = True) -> Iterator[Self]:
|
30
|
+
"""
|
31
|
+
Context manager to operate on a mutable copy of the object.
|
32
|
+
|
33
|
+
Args:
|
34
|
+
validate: Whether to validate the output PyTree upon exiting the context.
|
35
|
+
|
36
|
+
Yields:
|
37
|
+
A mutable copy of the object.
|
38
|
+
|
39
|
+
Note:
|
40
|
+
This context manager is useful to operate on an r/w copy of a PyTree making
|
41
|
+
sure that the output object does not trigger JIT recompilations.
|
42
|
+
"""
|
30
43
|
|
31
44
|
mutability = (
|
32
45
|
Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
|
33
46
|
)
|
34
47
|
|
35
|
-
with
|
48
|
+
with self.copy().mutable_context(mutability=mutability) as obj:
|
36
49
|
yield obj
|
37
50
|
|
38
51
|
@contextlib.contextmanager
|
39
52
|
def mutable_context(
|
40
53
|
self: Self, mutability: Mutability, restore_after_exception: bool = True
|
41
|
-
) ->
|
42
|
-
"""
|
54
|
+
) -> Iterator[Self]:
|
55
|
+
"""
|
56
|
+
Context manager to temporarily change the mutability of the object.
|
57
|
+
|
58
|
+
Args:
|
59
|
+
mutability: The mutability to set.
|
60
|
+
restore_after_exception:
|
61
|
+
Whether to restore the original object in case of an exception
|
62
|
+
occurring within the context.
|
63
|
+
|
64
|
+
Yields:
|
65
|
+
The object with the new mutability.
|
66
|
+
|
67
|
+
Note:
|
68
|
+
This context manager is useful to operate in place on a PyTree without
|
69
|
+
the need to make a copy while optionally keeping active the checks on
|
70
|
+
the PyTree structure, shapes, and dtypes.
|
71
|
+
"""
|
43
72
|
|
44
73
|
if restore_after_exception:
|
45
74
|
self_copy = self.copy()
|
46
75
|
|
47
|
-
original_mutability = self.
|
76
|
+
original_mutability = self.mutability()
|
77
|
+
|
78
|
+
original_dtypes = JaxsimDataclass.get_leaf_dtypes(tree=self)
|
79
|
+
original_shapes = JaxsimDataclass.get_leaf_shapes(tree=self)
|
80
|
+
original_weak_types = JaxsimDataclass.get_leaf_weak_types(tree=self)
|
81
|
+
original_structure = jax.tree_util.tree_structure(tree=self)
|
48
82
|
|
49
|
-
def restore_self():
|
50
|
-
self.
|
83
|
+
def restore_self() -> None:
|
84
|
+
self.set_mutability(mutability=Mutability.MUTABLE_NO_VALIDATION)
|
51
85
|
for f in dataclasses.fields(self_copy):
|
52
86
|
setattr(self, f.name, getattr(self_copy, f.name))
|
53
87
|
|
54
88
|
try:
|
55
|
-
self.
|
89
|
+
self.set_mutability(mutability)
|
56
90
|
yield self
|
91
|
+
|
92
|
+
if mutability is not Mutability.MUTABLE_NO_VALIDATION:
|
93
|
+
new_structure = jax.tree_util.tree_structure(tree=self)
|
94
|
+
if original_structure != new_structure:
|
95
|
+
msg = "Pytree structure has changed from {} to {}"
|
96
|
+
raise ValueError(msg.format(original_structure, new_structure))
|
97
|
+
|
98
|
+
new_shapes = JaxsimDataclass.get_leaf_shapes(tree=self)
|
99
|
+
if original_shapes != new_shapes:
|
100
|
+
msg = "Leaves shapes have changed from {} to {}"
|
101
|
+
raise ValueError(msg.format(original_shapes, new_shapes))
|
102
|
+
|
103
|
+
new_dtypes = JaxsimDataclass.get_leaf_dtypes(tree=self)
|
104
|
+
if original_dtypes != new_dtypes:
|
105
|
+
msg = "Leaves dtypes have changed from {} to {}"
|
106
|
+
raise ValueError(msg.format(original_dtypes, new_dtypes))
|
107
|
+
|
108
|
+
new_weak_types = JaxsimDataclass.get_leaf_weak_types(tree=self)
|
109
|
+
if original_weak_types != new_weak_types:
|
110
|
+
msg = "Leaves weak types have changed from {} to {}"
|
111
|
+
raise ValueError(msg.format(original_weak_types, new_weak_types))
|
112
|
+
|
57
113
|
except Exception as e:
|
58
114
|
if restore_after_exception:
|
59
115
|
restore_self()
|
60
|
-
self.
|
116
|
+
self.set_mutability(original_mutability)
|
61
117
|
raise e
|
118
|
+
|
62
119
|
finally:
|
63
|
-
self.
|
120
|
+
self.set_mutability(original_mutability)
|
121
|
+
|
122
|
+
@staticmethod
|
123
|
+
def get_leaf_shapes(tree: jtp.PyTree) -> tuple[tuple[int, ...] | None]:
|
124
|
+
"""
|
125
|
+
Helper method to get the leaf shapes of a PyTree.
|
126
|
+
|
127
|
+
Args:
|
128
|
+
tree: The PyTree to consider.
|
129
|
+
|
130
|
+
Returns:
|
131
|
+
A tuple containing the leaf shapes of the PyTree or `None` is the leaf is
|
132
|
+
not a numpy-like array.
|
133
|
+
"""
|
134
|
+
|
135
|
+
return tuple( # noqa
|
136
|
+
leaf.shape if hasattr(leaf, "shape") else None
|
137
|
+
for leaf in jax.tree_util.tree_leaves(tree)
|
138
|
+
if hasattr(leaf, "shape")
|
139
|
+
)
|
140
|
+
|
141
|
+
@staticmethod
|
142
|
+
def get_leaf_dtypes(tree: jtp.PyTree) -> tuple:
|
143
|
+
"""
|
144
|
+
Helper method to get the leaf dtypes of a PyTree.
|
145
|
+
|
146
|
+
Args:
|
147
|
+
tree: The PyTree to consider.
|
148
|
+
|
149
|
+
Returns:
|
150
|
+
A tuple containing the leaf dtypes of the PyTree or `None` is the leaf is
|
151
|
+
not a numpy-like array.
|
152
|
+
"""
|
153
|
+
|
154
|
+
return tuple(
|
155
|
+
leaf.dtype if hasattr(leaf, "dtype") else None
|
156
|
+
for leaf in jax.tree_util.tree_leaves(tree)
|
157
|
+
if hasattr(leaf, "dtype")
|
158
|
+
)
|
159
|
+
|
160
|
+
@staticmethod
|
161
|
+
def get_leaf_weak_types(tree: jtp.PyTree) -> tuple[bool, ...]:
|
162
|
+
"""
|
163
|
+
Helper method to get the leaf weak types of a PyTree.
|
164
|
+
|
165
|
+
Args:
|
166
|
+
tree: The PyTree to consider.
|
167
|
+
|
168
|
+
Returns:
|
169
|
+
A tuple marking whether the leaf contains a JAX array with weak type.
|
170
|
+
"""
|
171
|
+
|
172
|
+
return tuple(
|
173
|
+
leaf.weak_type if hasattr(leaf, "weak_type") else False
|
174
|
+
for leaf in jax.tree_util.tree_leaves(tree)
|
175
|
+
if hasattr(leaf, "weak_type")
|
176
|
+
)
|
177
|
+
|
178
|
+
@staticmethod
|
179
|
+
def check_compatibility(*trees: Sequence[Any]) -> None:
|
180
|
+
"""
|
181
|
+
Check whether the PyTrees are compatible in structure, shape, and dtype.
|
182
|
+
|
183
|
+
Args:
|
184
|
+
*trees: The PyTrees to compare.
|
185
|
+
|
186
|
+
Raises:
|
187
|
+
ValueError: If the PyTrees have incompatible structures, shapes, or dtypes.
|
188
|
+
"""
|
189
|
+
|
190
|
+
target_structure = jax.tree_util.tree_structure(trees[0])
|
191
|
+
|
192
|
+
compatible_structure = functools.reduce(
|
193
|
+
lambda compatible, tree: compatible
|
194
|
+
and jax.tree_util.tree_structure(tree) == target_structure,
|
195
|
+
trees[1:],
|
196
|
+
True,
|
197
|
+
)
|
198
|
+
|
199
|
+
if not compatible_structure:
|
200
|
+
raise ValueError("Pytrees have incompatible structures.")
|
201
|
+
|
202
|
+
target_shapes = JaxsimDataclass.get_leaf_shapes(trees[0])
|
203
|
+
|
204
|
+
compatible_shapes = functools.reduce(
|
205
|
+
lambda compatible, tree: compatible
|
206
|
+
and JaxsimDataclass.get_leaf_shapes(tree) == target_shapes,
|
207
|
+
trees[1:],
|
208
|
+
True,
|
209
|
+
)
|
210
|
+
|
211
|
+
if not compatible_shapes:
|
212
|
+
raise ValueError("Pytrees have incompatible shapes.")
|
213
|
+
|
214
|
+
target_dtypes = JaxsimDataclass.get_leaf_dtypes(trees[0])
|
215
|
+
|
216
|
+
compatible_dtypes = functools.reduce(
|
217
|
+
lambda compatible, tree: compatible
|
218
|
+
and JaxsimDataclass.get_leaf_dtypes(tree) == target_dtypes,
|
219
|
+
trees[1:],
|
220
|
+
True,
|
221
|
+
)
|
222
|
+
|
223
|
+
if not compatible_dtypes:
|
224
|
+
raise ValueError("Pytrees have incompatible dtypes.")
|
64
225
|
|
65
226
|
def is_mutable(self, validate: bool = False) -> bool:
|
66
|
-
"""
|
227
|
+
"""
|
228
|
+
Check whether the object is mutable.
|
229
|
+
|
230
|
+
Args:
|
231
|
+
validate: Additionally checks if the object also has validation enabled.
|
232
|
+
|
233
|
+
Returns:
|
234
|
+
True if the object is mutable, False otherwise.
|
235
|
+
"""
|
67
236
|
|
68
237
|
return (
|
69
238
|
self.__mutability__ is Mutability.MUTABLE
|
@@ -71,39 +240,120 @@ class JaxsimDataclass(abc.ABC):
|
|
71
240
|
else self.__mutability__ is Mutability.MUTABLE_NO_VALIDATION
|
72
241
|
)
|
73
242
|
|
74
|
-
def
|
75
|
-
|
76
|
-
|
77
|
-
else:
|
78
|
-
mutability = (
|
79
|
-
Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
|
80
|
-
)
|
243
|
+
def mutability(self) -> Mutability:
|
244
|
+
"""
|
245
|
+
Get the mutability type of the object.
|
81
246
|
|
82
|
-
|
247
|
+
Returns:
|
248
|
+
The mutability type of the object.
|
249
|
+
"""
|
83
250
|
|
84
|
-
def _mutability(self) -> Mutability:
|
85
251
|
return self.__mutability__
|
86
252
|
|
87
|
-
def
|
253
|
+
def set_mutability(self, mutability: Mutability) -> None:
|
254
|
+
"""
|
255
|
+
Set the mutability of the object in-place.
|
256
|
+
|
257
|
+
Args:
|
258
|
+
mutability: The desired mutability type.
|
259
|
+
"""
|
260
|
+
|
88
261
|
jax_dataclasses._copy_and_mutate._mark_mutable(
|
89
262
|
self, mutable=mutability, visited=set()
|
90
263
|
)
|
91
264
|
|
92
265
|
def mutable(self: Self, mutable: bool = True, validate: bool = False) -> Self:
|
93
|
-
|
266
|
+
"""
|
267
|
+
Return a mutable reference of the object.
|
268
|
+
|
269
|
+
Args:
|
270
|
+
mutable: Whether to make the object mutable.
|
271
|
+
validate: Whether to enable validation on the object.
|
272
|
+
|
273
|
+
Returns:
|
274
|
+
A mutable reference of the object.
|
275
|
+
"""
|
276
|
+
|
277
|
+
if mutable:
|
278
|
+
mutability = (
|
279
|
+
Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
|
280
|
+
)
|
281
|
+
else:
|
282
|
+
mutability = Mutability.FROZEN
|
283
|
+
|
284
|
+
self.set_mutability(mutability=mutability)
|
94
285
|
return self
|
95
286
|
|
96
287
|
def copy(self: Self) -> Self:
|
288
|
+
"""
|
289
|
+
Return a copy of the object.
|
290
|
+
|
291
|
+
Returns:
|
292
|
+
A copy of the object.
|
293
|
+
"""
|
294
|
+
|
295
|
+
# Make a copy calling tree_map.
|
97
296
|
obj = jax.tree_util.tree_map(lambda leaf: leaf, self)
|
98
|
-
|
297
|
+
|
298
|
+
# Make sure that the copied object and all the copied leaves have the same
|
299
|
+
# mutability of the original object.
|
300
|
+
obj.set_mutability(mutability=self.mutability())
|
301
|
+
|
99
302
|
return obj
|
100
303
|
|
101
304
|
def replace(self: Self, validate: bool = True, **kwargs) -> Self:
|
102
|
-
|
103
|
-
|
305
|
+
"""
|
306
|
+
Return a new object replacing in-place the specified fields with new values.
|
307
|
+
|
308
|
+
Args:
|
309
|
+
validate: Whether to validate that the new fields do not alter the PyTree.
|
310
|
+
**kwargs: The fields to replace.
|
311
|
+
|
312
|
+
Returns:
|
313
|
+
A reference of the object with the specified fields replaced.
|
314
|
+
"""
|
315
|
+
|
316
|
+
# Use the dataclasses replace method.
|
317
|
+
obj = dataclasses.replace(self, **kwargs)
|
318
|
+
|
319
|
+
if validate:
|
320
|
+
JaxsimDataclass.check_compatibility(self, obj)
|
321
|
+
|
322
|
+
# Make sure that all the new leaves have the same mutability of the object.
|
323
|
+
obj.set_mutability(mutability=self.mutability())
|
104
324
|
|
105
|
-
obj._set_mutability(mutability=self._mutability())
|
106
325
|
return obj
|
107
326
|
|
108
327
|
def flatten(self) -> jtp.VectorJax:
|
109
|
-
|
328
|
+
"""
|
329
|
+
Flatten the object into a 1D vector.
|
330
|
+
|
331
|
+
Returns:
|
332
|
+
A 1D vector containing the flattened object.
|
333
|
+
"""
|
334
|
+
|
335
|
+
return self.flatten_fn()(self)
|
336
|
+
|
337
|
+
@classmethod
|
338
|
+
def flatten_fn(cls: Type[Self]) -> Callable[[Self], jtp.VectorJax]:
|
339
|
+
"""
|
340
|
+
Return a function to flatten the object into a 1D vector.
|
341
|
+
|
342
|
+
Returns:
|
343
|
+
A function to flatten the object into a 1D vector.
|
344
|
+
"""
|
345
|
+
|
346
|
+
return lambda pytree: jax.flatten_util.ravel_pytree(pytree)[0]
|
347
|
+
|
348
|
+
def unflatten_fn(self: Self) -> Callable[[jtp.VectorJax], Self]:
|
349
|
+
"""
|
350
|
+
Return a function to unflatten a 1D vector into the object.
|
351
|
+
|
352
|
+
Returns:
|
353
|
+
A function to unflatten a 1D vector into the object.
|
354
|
+
|
355
|
+
Notes:
|
356
|
+
Due to JAX internals, the function to unflatten a PyTree needs to be
|
357
|
+
created from an existing instance of the PyTree.
|
358
|
+
"""
|
359
|
+
return jax.flatten_util.ravel_pytree(self)[1]
|