jaxsim 0.2.dev188__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +73 -22
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +13 -1
- jaxsim/api/com.py +423 -0
- jaxsim/api/common.py +48 -19
- jaxsim/api/contact.py +604 -52
- jaxsim/api/data.py +308 -163
- jaxsim/api/frame.py +471 -0
- jaxsim/api/joint.py +166 -37
- jaxsim/api/kin_dyn_parameters.py +901 -0
- jaxsim/api/link.py +277 -78
- jaxsim/api/model.py +1572 -362
- jaxsim/api/ode.py +324 -129
- jaxsim/api/ode_data.py +401 -0
- jaxsim/api/references.py +216 -80
- jaxsim/exceptions.py +80 -0
- jaxsim/integrators/__init__.py +2 -2
- jaxsim/integrators/common.py +191 -107
- jaxsim/integrators/fixed_step.py +97 -102
- jaxsim/integrators/variable_step.py +706 -0
- jaxsim/logging.py +1 -2
- jaxsim/math/__init__.py +13 -0
- jaxsim/math/adjoint.py +57 -22
- jaxsim/math/cross.py +16 -7
- jaxsim/math/inertia.py +10 -8
- jaxsim/math/joint_model.py +289 -0
- jaxsim/math/quaternion.py +54 -20
- jaxsim/math/rotation.py +27 -21
- jaxsim/math/skew.py +16 -5
- jaxsim/math/transform.py +102 -0
- jaxsim/math/utils.py +31 -0
- jaxsim/mujoco/__init__.py +2 -1
- jaxsim/mujoco/loaders.py +216 -29
- jaxsim/mujoco/model.py +163 -33
- jaxsim/mujoco/utils.py +228 -0
- jaxsim/mujoco/visualizer.py +107 -22
- jaxsim/parsers/__init__.py +0 -1
- jaxsim/parsers/descriptions/__init__.py +8 -2
- jaxsim/parsers/descriptions/collision.py +87 -16
- jaxsim/parsers/descriptions/joint.py +80 -87
- jaxsim/parsers/descriptions/link.py +62 -24
- jaxsim/parsers/descriptions/model.py +101 -68
- jaxsim/parsers/kinematic_graph.py +607 -225
- jaxsim/parsers/rod/meshes.py +104 -0
- jaxsim/parsers/rod/parser.py +125 -82
- jaxsim/parsers/rod/utils.py +127 -82
- jaxsim/rbda/__init__.py +11 -0
- jaxsim/rbda/aba.py +289 -0
- jaxsim/rbda/collidable_points.py +156 -0
- jaxsim/rbda/contacts/__init__.py +13 -0
- jaxsim/rbda/contacts/common.py +313 -0
- jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
- jaxsim/rbda/contacts/rigid.py +462 -0
- jaxsim/rbda/contacts/soft.py +480 -0
- jaxsim/rbda/contacts/visco_elastic.py +1066 -0
- jaxsim/rbda/crba.py +167 -0
- jaxsim/rbda/forward_kinematics.py +117 -0
- jaxsim/rbda/jacobian.py +330 -0
- jaxsim/rbda/rnea.py +235 -0
- jaxsim/rbda/utils.py +160 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/terrain/terrain.py +238 -0
- jaxsim/typing.py +24 -24
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/jaxsim_dataclass.py +289 -34
- jaxsim/utils/tracing.py +5 -11
- jaxsim/utils/wrappers.py +159 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
- jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
- jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1686
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -102
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -254
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/crba.py +0 -154
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -523
- jaxsim/physics/algos/terrain.py +0 -80
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -388
- jaxsim/physics/model/physics_model_state.py +0 -283
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -393
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -96
- jaxsim/simulation/ode_integration.py +0 -62
- jaxsim/simulation/simulator.py +0 -543
- jaxsim/simulation/simulator_callbacks.py +0 -79
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -536
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.2.dev188.dist-info/METADATA +0 -184
- jaxsim-0.2.dev188.dist-info/RECORD +0 -81
- {jaxsim-0.2.dev188.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
@@ -1,55 +0,0 @@
|
|
1
|
-
import dataclasses
|
2
|
-
|
3
|
-
import jax.numpy as jnp
|
4
|
-
import jax_dataclasses
|
5
|
-
import numpy as np
|
6
|
-
import numpy.typing as npt
|
7
|
-
from jax_dataclasses import Static
|
8
|
-
|
9
|
-
from jaxsim.parsers.descriptions import ModelDescription
|
10
|
-
|
11
|
-
|
12
|
-
@jax_dataclasses.pytree_dataclass
class GroundContact:
    """
    A class for managing collidable points in a robot model.

    This class is used to store and manage information about collidable points on a robot model,
    such as their positions and the corresponding bodies (links) they are associated with.

    Attributes:
        point (npt.NDArray): An array of shape (3, N) representing the 3D positions of collidable points.
            The default (empty) instance stores an empty array instead.
        body (Static[npt.NDArray]): An array of integers representing the indices of the bodies (links) associated with each collidable point.
    """

    point: npt.NDArray = dataclasses.field(default_factory=lambda: jnp.array([]))
    body: Static[npt.NDArray] = dataclasses.field(
        default_factory=lambda: np.array([], dtype=int)
    )

    @staticmethod
    def build_from(
        model_description: ModelDescription,
    ) -> "GroundContact":
        """
        Build the collidable points of a model from its description.

        Args:
            model_description: The description of the model, providing its collision
                shapes, its links, and its enabled collidable points.

        Returns:
            A GroundContact whose ``point`` stacks the positions of all enabled
            collidable points column-wise (shape (3, N)) and whose ``body`` holds the
            index of the parent link of each point. An empty GroundContact is
            returned when the model has no collision shapes or no enabled
            collidable points.
        """

        if len(model_description.collision_shapes) == 0:
            return GroundContact()

        # Get all the enabled collidable points of the model.
        # Materialize them once since they are iterated twice below.
        collidable_points = list(model_description.all_enabled_collidable_points())

        # Collision shapes may exist while all their points are disabled:
        # jnp.vstack would raise on an empty list, so return the empty instance.
        if not collidable_points:
            return GroundContact()

        # Get all the links so that we can take their updated index
        links_dict = {link.name: link for link in model_description}

        # Build the GroundContact attributes
        points = jnp.vstack([cp.position for cp in collidable_points]).T
        link_index_of_points = np.array(
            [links_dict[cp.parent_link.name].index for cp in collidable_points],
            dtype=int,
        )

        # Build the object
        gc = GroundContact(point=points, body=link_index_of_points)

        # Internal consistency checks: one 3D column per collidable point.
        assert gc.point.shape[0] == 3
        assert gc.point.shape[1] == len(gc.body)

        return gc
|
@@ -1,388 +0,0 @@
|
|
1
|
-
import dataclasses
|
2
|
-
from typing import Dict, Union
|
3
|
-
|
4
|
-
import jax.lax
|
5
|
-
import jax.numpy as jnp
|
6
|
-
import jax_dataclasses
|
7
|
-
import numpy as np
|
8
|
-
from jax_dataclasses import Static
|
9
|
-
|
10
|
-
import jaxsim.parsers
|
11
|
-
import jaxsim.physics
|
12
|
-
import jaxsim.typing as jtp
|
13
|
-
from jaxsim.parsers.descriptions import JointDescriptor, JointType
|
14
|
-
from jaxsim.physics import default_gravity
|
15
|
-
from jaxsim.sixd import se3
|
16
|
-
from jaxsim.utils import JaxsimDataclass, not_tracing
|
17
|
-
|
18
|
-
from .ground_contact import GroundContact
|
19
|
-
from .physics_model_state import PhysicsModelState
|
20
|
-
|
21
|
-
|
22
|
-
@jax_dataclasses.pytree_dataclass
class PhysicsModel(JaxsimDataclass):
    """
    A read-only class to store all the information necessary to run RBDAs on a model.

    This class contains information about the physics model, including the number of bodies, initial state, gravity,
    floating base configuration, ground contact points, and more.

    Attributes:
        NB (Static[int]): The number of bodies in the physics model.
        initial_state (PhysicsModelState): The initial state of the physics model
            (default: None, replaced in __post_init__ by PhysicsModelState.zero).
        gravity (jtp.Vector): The 6D gravity vector, storing the 3D gravity in the first
            three entries and zeros in the last three, as produced by build_from and
            set_gravity (default: built from jaxsim.physics.default_gravity()).
        is_floating_base (Static[bool]): A flag indicating whether the model has a floating base (default: False).
        gc (GroundContact): The ground contact points of the model (default: empty GroundContact instance).
        description (Static[jaxsim.parsers.descriptions.model.ModelDescription]): A description of the model (default: None).
    """

    NB: Static[int]
    initial_state: PhysicsModelState = dataclasses.field(default=None)
    gravity: jtp.Vector = dataclasses.field(
        # Same layout used by build_from() and set_gravity(): 3D gravity, then zeros.
        default_factory=lambda: jnp.hstack(
            [jaxsim.physics.default_gravity(), np.zeros(3)]
        )
    )
    is_floating_base: Static[bool] = dataclasses.field(default=False)
    gc: GroundContact = dataclasses.field(default_factory=lambda: GroundContact())
    description: Static[jaxsim.parsers.descriptions.model.ModelDescription] = (
        dataclasses.field(default=None)
    )

    # Per-index data filled by build_from(). Link data is keyed by link index;
    # joint data is keyed by joint index (build_from notes it equals the child link index).
    _parent_array_dict: Static[Dict[int, int]] = dataclasses.field(default_factory=dict)
    _jtype_dict: Static[Dict[int, Union[JointType, JointDescriptor]]] = (
        dataclasses.field(default_factory=dict)
    )
    _tree_transforms_dict: Dict[int, jtp.Matrix] = dataclasses.field(
        default_factory=dict
    )
    _link_inertias_dict: Dict[int, jtp.Matrix] = dataclasses.field(default_factory=dict)

    _joint_friction_static: Dict[int, float] = dataclasses.field(default_factory=dict)
    _joint_friction_viscous: Dict[int, float] = dataclasses.field(default_factory=dict)

    _joint_limit_spring: Dict[int, float] = dataclasses.field(default_factory=dict)
    _joint_limit_damper: Dict[int, float] = dataclasses.field(default_factory=dict)

    _joint_motor_inertia: Dict[int, float] = dataclasses.field(default_factory=dict)
    _joint_motor_gear_ratio: Dict[int, float] = dataclasses.field(default_factory=dict)
    _joint_motor_viscous_friction: Dict[int, float] = dataclasses.field(
        default_factory=dict
    )

    # Dense arrays computed in __post_init__ from the description and the dicts above.
    _link_masses: jtp.Vector = dataclasses.field(init=False)
    _link_spatial_inertias: jtp.Vector = dataclasses.field(init=False)
    _joint_position_limits_min: jtp.Matrix = dataclasses.field(init=False)
    _joint_position_limits_max: jtp.Matrix = dataclasses.field(init=False)

    def __post_init__(self):
        """Fill a missing initial state and compute the derived dense arrays."""

        if self.initial_state is None:
            initial_state = PhysicsModelState.zero(physics_model=self)
            object.__setattr__(self, "initial_state", initial_state)

        ordered_links = sorted(
            self.description.links_dict.values(),
            key=lambda link: link.index,
        )

        ordered_joints = sorted(
            self.description.joints_dict.values(),
            key=lambda joint: joint.index,
        )

        from jaxsim.utils import Mutability

        with self.mutable_context(
            mutability=Mutability.MUTABLE_NO_VALIDATION, restore_after_exception=False
        ):
            self._link_masses = jnp.stack([link.mass for link in ordered_links])
            self._link_spatial_inertias = jnp.stack(
                [self._link_inertias_dict[link.index] for link in ordered_links]
            )

            # Normalize the limits so that min <= max element-wise, regardless of
            # the order in which they are stored in the joint description.
            s_min = jnp.array([joint.position_limit[0] for joint in ordered_joints])
            s_max = jnp.array([joint.position_limit[1] for joint in ordered_joints])
            self._joint_position_limits_min = jnp.vstack([s_min, s_max]).min(axis=0)
            self._joint_position_limits_max = jnp.vstack([s_min, s_max]).max(axis=0)

    @staticmethod
    def build_from(
        model_description: jaxsim.parsers.descriptions.model.ModelDescription,
        gravity: jtp.Vector = default_gravity(),
    ) -> "PhysicsModel":
        """
        Build a PhysicsModel from a model description.

        Args:
            model_description: The parsed description of the model.
            gravity: The 3D gravity vector (any shape with 3 elements).

        Returns:
            The physics model. ``is_floating_base`` is False when the description
            is fixed-base, True otherwise.

        Raises:
            ValueError: If gravity does not have 3 elements, or if a link pose is
                not the identity.
        """

        gravity = jnp.asarray(gravity)

        if gravity.size != 3:
            raise ValueError(gravity.size)

        # Currently, we assume that the link frame matches the frame of its parent joint
        for link in model_description:
            if not jnp.allclose(link.pose, jnp.eye(4)):
                raise ValueError(f"Link '{link.name}' has unsupported pose:\n{link.pose}")

        # ===================================
        # Initialize physics model parameters
        # ===================================

        # Get the number of bodies, including the base link
        num_of_bodies = len(model_description)

        # Build the parent array λ of the floating-base model.
        # Note: the parent of the base link is not set since it's not defined.
        parent_array_dict = {
            link.index: link.parent.index
            for link in model_description
            if link.parent is not None
        }

        # Get the 6D inertias of all links
        link_spatial_inertias_dict = {
            link.index: link.inertia for link in model_description
        }

        # Dict from the joint index to its type.
        # Note: the joint index is equal to its child link index.
        joint_types_dict = {
            joint.index: joint.jtype for joint in model_description.joints
        }

        # Dicts from the joint index to the static and viscous friction.
        # Note: the joint index is equal to its child link index.
        joint_friction_static = {
            joint.index: jnp.array(joint.friction_static, dtype=float)
            for joint in model_description.joints
        }
        joint_friction_viscous = {
            joint.index: jnp.array(joint.friction_viscous, dtype=float)
            for joint in model_description.joints
        }

        # Dicts from the joint index to the spring and damper joint limits parameters.
        # Note: the joint index is equal to its child link index.
        joint_limit_spring = {
            joint.index: jnp.array(joint.position_limit_spring, dtype=float)
            for joint in model_description.joints
        }
        joint_limit_damper = {
            joint.index: jnp.array(joint.position_limit_damper, dtype=float)
            for joint in model_description.joints
        }

        # Dicts from the joint index to the motor inertia, gear ratio and viscous friction.
        # Note: the joint index is equal to its child link index.
        joint_motor_inertia = {
            joint.index: jnp.array(joint.motor_inertia, dtype=float)
            for joint in model_description.joints
        }
        joint_motor_gear_ratio = {
            joint.index: jnp.array(joint.motor_gear_ratio, dtype=float)
            for joint in model_description.joints
        }
        joint_motor_viscous_friction = {
            joint.index: jnp.array(joint.motor_viscous_friction, dtype=float)
            for joint in model_description.joints
        }

        # Transform between model's root and model's base link
        # (this is just the pose of the base link in the SDF description)
        base_link = model_description.links_dict[model_description.link_names()[0]]
        R_H_B = model_description.transform(name=base_link.name)
        tree_transform_0 = se3.SE3.from_matrix(matrix=R_H_B).adjoint()

        # Helper to compute the transform pre(i)_H_λ(i).
        # Given a joint 'i', it is the coordinate transform between its predecessor
        # frame [pre(i)] and the frame of its parent link [λ(i)].
        def prei_H_λi(j):
            return model_description.relative_transform(
                relative_to=j.name, name=j.parent.name
            )

        # Compute the tree transforms: pre(i)_X_λ(i).
        # Given a joint 'i', it is the coordinate transform between its predecessor
        # frame [pre(i)] and the frame of its parent link [λ(i)].
        tree_transforms_dict = {
            0: tree_transform_0,
            **{
                j.index: se3.SE3.from_matrix(matrix=prei_H_λi(j)).adjoint()
                for j in model_description.joints
            },
        }

        # =======================
        # Build the initial state
        # =======================

        # Initial joint positions
        q0 = jnp.array(
            [
                model_description.joints_dict[j.name].initial_position
                for j in model_description.joints
            ]
        )

        # Build the initial state
        initial_state = PhysicsModelState(
            joint_positions=q0,
            joint_velocities=jnp.zeros_like(q0),
            base_position=model_description.root_pose.root_position,
            base_quaternion=model_description.root_pose.root_quaternion,
        )

        # =======================
        # Build the physics model
        # =======================

        # Initialize the model
        physics_model = PhysicsModel(
            NB=num_of_bodies,
            initial_state=initial_state,
            _parent_array_dict=parent_array_dict,
            _jtype_dict=joint_types_dict,
            _tree_transforms_dict=tree_transforms_dict,
            _link_inertias_dict=link_spatial_inertias_dict,
            _joint_friction_static=joint_friction_static,
            _joint_friction_viscous=joint_friction_viscous,
            _joint_limit_spring=joint_limit_spring,
            _joint_limit_damper=joint_limit_damper,
            _joint_motor_gear_ratio=joint_motor_gear_ratio,
            _joint_motor_inertia=joint_motor_inertia,
            _joint_motor_viscous_friction=joint_motor_viscous_friction,
            # 6D gravity: the 3D gravity first, followed by three zeros.
            gravity=jnp.hstack([gravity.squeeze(), np.zeros(3)]),
            is_floating_base=True,
            gc=GroundContact.build_from(model_description=model_description),
            description=model_description,
        )

        # Floating-base models
        if not model_description.fixed_base:
            return physics_model

        # Fixed-base models
        with jax_dataclasses.copy_and_mutate(physics_model) as physics_model_fixed:
            physics_model_fixed.is_floating_base = False

        return physics_model_fixed

    def dofs(self) -> int:
        """Return the number of joints (one entry of the joint-type dict per joint)."""
        return len(self._jtype_dict)

    def set_gravity(self, gravity: jtp.Vector) -> None:
        """
        Set the gravity vector.

        Args:
            gravity: Either a 3D gravity (padded with three trailing zeros) or a
                full 6D gravity vector.

        Raises:
            ValueError: If gravity has neither 3 nor 6 elements.
        """

        gravity = gravity.squeeze()

        if gravity.size == 3:
            self.gravity = jnp.hstack([gravity, 0, 0, 0])

        elif gravity.size == 6:
            self.gravity = gravity

        else:
            raise ValueError(gravity.shape)

    @property
    def parent(self) -> jtp.Vector:
        return self.parent_array()

    def parent_array(self) -> jtp.Vector:
        """Returns λ(i)"""

        # Order the parents by child link index rather than by dict insertion order,
        # so that entry i of the returned array is always λ(i). The base link has
        # no parent and is marked with -1.
        parents = [
            self._parent_array_dict[index]
            for index in sorted(self._parent_array_dict)
        ]
        return jnp.array([-1] + parents, dtype=int)

    def support_body_array(self, body_index: jtp.Int) -> jtp.Vector:
        """Returns κ(i)"""

        κ_bool = self.support_body_array_bool(body_index=body_index)
        return jnp.array(jnp.where(κ_bool)[0], dtype=int)

    def support_body_array_bool(self, body_index: jtp.Int) -> jtp.Vector:
        """
        Return a boolean mask of the links on the path from body_index to the root.

        The path is followed by walking the parent array from body_index, marking
        each visited link.
        """

        # The parent array does not change within the loop: build it only once.
        parent = self.parent_array()

        active_link = body_index
        κ_bool = jnp.zeros(self.NB, dtype=bool)

        for i in np.flip(np.arange(start=0, stop=self.NB)):
            κ_bool, active_link = jax.lax.cond(
                pred=(i == active_link),
                false_fun=lambda: (κ_bool, active_link),
                true_fun=lambda: (
                    κ_bool.at[active_link].set(True),
                    parent[active_link],
                ),
            )

        return κ_bool

    @property
    def tree_transforms(self) -> jtp.Array:
        """Stack the tree transforms of all bodies, using identity for missing ones."""

        X_tree = jnp.array(
            [
                self._tree_transforms_dict.get(idx, jnp.eye(6))
                for idx in np.arange(start=0, stop=self.NB)
            ]
        )

        return X_tree

    @property
    def spatial_inertias(self) -> jtp.Array:
        """Stack the 6x6 spatial inertias of all bodies, using zeros for missing ones."""

        # The fallback must have the same (6, 6) shape as the stored inertias,
        # otherwise the entries could not be stacked into a single array.
        M_links = jnp.array(
            [
                self._link_inertias_dict.get(idx, jnp.zeros(shape=(6, 6)))
                for idx in np.arange(start=0, stop=self.NB)
            ]
        )

        return M_links

    def jtype(self, joint_index: int) -> JointType:
        """
        Return the type of a joint.

        Raises:
            ValueError: If joint_index is not in the range [1, NB).
        """

        if not 0 < joint_index < self.NB:
            raise ValueError(joint_index)

        return self._jtype_dict[joint_index]

    def joint_transforms(self, q: jtp.Vector) -> jtp.Array:
        """Stack the joint transforms, with a zero 6x6 placeholder for index 0."""

        from jaxsim.math.joint import jcalc

        if not_tracing(q):
            if q.shape[0] != self.dofs():
                raise ValueError(q.shape)

        Xj = jnp.stack(
            [jnp.zeros(shape=(6, 6))]
            + [
                jcalc(jtyp=self.jtype(index + 1), q=joint_position)[0]
                for index, joint_position in enumerate(q)
            ]
        )

        return Xj

    def motion_subspaces(self, q: jtp.Vector) -> jtp.Array:
        """Stack the joint motion subspaces, with a zero 6x1 placeholder for index 0."""

        from jaxsim.math.joint import jcalc

        if not_tracing(var=q):
            if q.shape[0] != self.dofs():
                raise ValueError(q.shape)

        SS = jnp.stack(
            [jnp.vstack(jnp.zeros(6))]
            + [
                jcalc(jtyp=self.jtype(index + 1), q=joint_position)[1]
                for index, joint_position in enumerate(q)
            ]
        )

        return SS

    def __eq__(self, other: "PhysicsModel") -> bool:
        # Compare every attribute that also appears in __repr__ (used by __hash__),
        # so that equal models are guaranteed to have equal hashes.
        if not isinstance(other, PhysicsModel):
            return False

        return (
            self.NB == other.NB
            and self.dofs() == other.dofs()
            and self.is_floating_base == other.is_floating_base
            and bool(np.allclose(self.gravity, other.gravity))
        )

    def __hash__(self):
        return hash(self.__repr__())

    def __repr__(self) -> str:
        attributes = [
            f"dofs: {self.dofs()},",
            f"links: {self.NB},",
            f"floating_base: {self.is_floating_base},",
        ]
        attributes_string = "\n    ".join(attributes)

        return f"{type(self).__name__}(\n    {attributes_string}\n)"
|