jaxsim 0.2.dev191__py3-none-any.whl → 0.2.dev364__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79) hide show
  1. jaxsim/__init__.py +3 -4
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -1
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +13 -2
  6. jaxsim/api/contact.py +120 -43
  7. jaxsim/api/data.py +112 -71
  8. jaxsim/api/joint.py +77 -36
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +150 -75
  11. jaxsim/api/model.py +542 -269
  12. jaxsim/api/ode.py +86 -74
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +12 -11
  15. jaxsim/integrators/__init__.py +2 -2
  16. jaxsim/integrators/common.py +110 -24
  17. jaxsim/integrators/fixed_step.py +11 -67
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +93 -0
  25. jaxsim/parsers/descriptions/link.py +2 -2
  26. jaxsim/parsers/rod/utils.py +7 -8
  27. jaxsim/rbda/__init__.py +7 -0
  28. jaxsim/rbda/aba.py +295 -0
  29. jaxsim/rbda/collidable_points.py +142 -0
  30. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  31. jaxsim/rbda/forward_kinematics.py +113 -0
  32. jaxsim/rbda/jacobian.py +201 -0
  33. jaxsim/rbda/rnea.py +237 -0
  34. jaxsim/rbda/soft_contacts.py +296 -0
  35. jaxsim/rbda/utils.py +152 -0
  36. jaxsim/terrain/__init__.py +2 -0
  37. jaxsim/utils/__init__.py +1 -4
  38. jaxsim/utils/hashless.py +18 -0
  39. jaxsim/utils/jaxsim_dataclass.py +281 -30
  40. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/METADATA +4 -6
  41. jaxsim-0.2.dev364.dist-info/RECORD +64 -0
  42. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/WHEEL +1 -1
  43. jaxsim/high_level/__init__.py +0 -2
  44. jaxsim/high_level/common.py +0 -11
  45. jaxsim/high_level/joint.py +0 -148
  46. jaxsim/high_level/link.py +0 -259
  47. jaxsim/high_level/model.py +0 -1686
  48. jaxsim/math/conv.py +0 -114
  49. jaxsim/math/joint.py +0 -102
  50. jaxsim/math/plucker.py +0 -100
  51. jaxsim/physics/__init__.py +0 -12
  52. jaxsim/physics/algos/__init__.py +0 -0
  53. jaxsim/physics/algos/aba.py +0 -254
  54. jaxsim/physics/algos/aba_motors.py +0 -284
  55. jaxsim/physics/algos/forward_kinematics.py +0 -79
  56. jaxsim/physics/algos/jacobian.py +0 -98
  57. jaxsim/physics/algos/rnea.py +0 -180
  58. jaxsim/physics/algos/rnea_motors.py +0 -196
  59. jaxsim/physics/algos/soft_contacts.py +0 -523
  60. jaxsim/physics/algos/utils.py +0 -69
  61. jaxsim/physics/model/__init__.py +0 -0
  62. jaxsim/physics/model/ground_contact.py +0 -53
  63. jaxsim/physics/model/physics_model.py +0 -388
  64. jaxsim/physics/model/physics_model_state.py +0 -283
  65. jaxsim/simulation/__init__.py +0 -4
  66. jaxsim/simulation/integrators.py +0 -393
  67. jaxsim/simulation/ode.py +0 -290
  68. jaxsim/simulation/ode_data.py +0 -96
  69. jaxsim/simulation/ode_integration.py +0 -62
  70. jaxsim/simulation/simulator.py +0 -543
  71. jaxsim/simulation/simulator_callbacks.py +0 -79
  72. jaxsim/simulation/utils.py +0 -15
  73. jaxsim/sixd/__init__.py +0 -2
  74. jaxsim/utils/oop.py +0 -536
  75. jaxsim/utils/vmappable.py +0 -117
  76. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  77. /jaxsim/{physics/algos → terrain}/terrain.py +0 -0
  78. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/LICENSE +0 -0
  79. {jaxsim-0.2.dev191.dist-info → jaxsim-0.2.dev364.dist-info}/top_level.txt +0 -0
@@ -1,388 +0,0 @@
1
- import dataclasses
2
- from typing import Dict, Union
3
-
4
- import jax.lax
5
- import jax.numpy as jnp
6
- import jax_dataclasses
7
- import numpy as np
8
- from jax_dataclasses import Static
9
-
10
- import jaxsim.parsers
11
- import jaxsim.physics
12
- import jaxsim.typing as jtp
13
- from jaxsim.parsers.descriptions import JointDescriptor, JointType
14
- from jaxsim.physics import default_gravity
15
- from jaxsim.sixd import se3
16
- from jaxsim.utils import JaxsimDataclass, not_tracing
17
-
18
- from .ground_contact import GroundContact
19
- from .physics_model_state import PhysicsModelState
20
-
21
-
22
@jax_dataclasses.pytree_dataclass
class PhysicsModel(JaxsimDataclass):
    """
    A read-only class to store all the information necessary to run RBDAs on a model.

    This class contains information about the physics model, including the number of bodies, initial state, gravity,
    floating base configuration, ground contact points, and more.

    Attributes:
        NB (Static[int]): The number of bodies in the physics model.
        initial_state (PhysicsModelState): The initial state of the physics model
            (default: None, replaced in __post_init__ by PhysicsModelState.zero).
        gravity (jtp.Vector): The 6D gravity vector, with the 3D linear gravity stacked
            before a zero angular part (default: [default_gravity(), 0, 0, 0]).
        is_floating_base (Static[bool]): A flag indicating whether the model has a floating base (default: False).
        gc (GroundContact): The ground contact points of the model (default: empty GroundContact instance).
        description (Static[jaxsim.parsers.descriptions.model.ModelDescription]): A description of the model (default: None).
    """

    NB: Static[int]
    initial_state: PhysicsModelState = dataclasses.field(default=None)

    # Linear part first, angular part last: this is the same layout produced by
    # build_from() and set_gravity(), which both stack the 3D gravity before three zeros.
    gravity: jtp.Vector = dataclasses.field(
        default_factory=lambda: jnp.hstack(
            [jaxsim.physics.default_gravity(), np.zeros(3)]
        )
    )
    is_floating_base: Static[bool] = dataclasses.field(default=False)
    gc: GroundContact = dataclasses.field(default_factory=GroundContact)
    description: Static[jaxsim.parsers.descriptions.model.ModelDescription] = (
        dataclasses.field(default=None)
    )

    # Per-body / per-joint parameters keyed by index.
    # build_from() notes that a joint index equals the index of its child link.
    _parent_array_dict: Static[Dict[int, int]] = dataclasses.field(default_factory=dict)
    _jtype_dict: Static[Dict[int, Union[JointType, JointDescriptor]]] = (
        dataclasses.field(default_factory=dict)
    )
    _tree_transforms_dict: Dict[int, jtp.Matrix] = dataclasses.field(
        default_factory=dict
    )
    _link_inertias_dict: Dict[int, jtp.Matrix] = dataclasses.field(default_factory=dict)

    _joint_friction_static: Dict[int, float] = dataclasses.field(default_factory=dict)
    _joint_friction_viscous: Dict[int, float] = dataclasses.field(default_factory=dict)

    _joint_limit_spring: Dict[int, float] = dataclasses.field(default_factory=dict)
    _joint_limit_damper: Dict[int, float] = dataclasses.field(default_factory=dict)

    _joint_motor_inertia: Dict[int, float] = dataclasses.field(default_factory=dict)
    _joint_motor_gear_ratio: Dict[int, float] = dataclasses.field(default_factory=dict)
    _joint_motor_viscous_friction: Dict[int, float] = dataclasses.field(
        default_factory=dict
    )

    # Derived arrays, computed in __post_init__ from the description and the dicts above.
    _link_masses: jtp.Vector = dataclasses.field(init=False)
    _link_spatial_inertias: jtp.Vector = dataclasses.field(init=False)
    _joint_position_limits_min: jtp.Matrix = dataclasses.field(init=False)
    _joint_position_limits_max: jtp.Matrix = dataclasses.field(init=False)

    def __post_init__(self):
        """Fill the default initial state and the derived per-link / per-joint arrays."""

        # Without an explicit initial state, use the zero state sized from this model.
        if self.initial_state is None:
            initial_state = PhysicsModelState.zero(physics_model=self)
            object.__setattr__(self, "initial_state", initial_state)

        ordered_links = sorted(
            self.description.links_dict.values(),
            key=lambda link: link.index,
        )

        ordered_joints = sorted(
            self.description.joints_dict.values(),
            key=lambda joint: joint.index,
        )

        from jaxsim.utils import Mutability

        # The derived fields are declared with init=False; they are written here
        # inside a mutable context.
        with self.mutable_context(
            mutability=Mutability.MUTABLE_NO_VALIDATION, restore_after_exception=False
        ):
            self._link_masses = jnp.stack([link.mass for link in ordered_links])
            self._link_spatial_inertias = jnp.stack(
                [self._link_inertias_dict[link.index] for link in ordered_links]
            )

            # Take the element-wise min/max of each limit pair, so the stored limits
            # are ordered even if a description lists them swapped.
            s_min = jnp.array([j.position_limit[0] for j in ordered_joints])
            s_max = jnp.array([j.position_limit[1] for j in ordered_joints])
            limits = jnp.vstack([s_min, s_max])
            self._joint_position_limits_min = limits.min(axis=0)
            self._joint_position_limits_max = limits.max(axis=0)

    @staticmethod
    def build_from(
        model_description: jaxsim.parsers.descriptions.model.ModelDescription,
        gravity: jtp.Vector = default_gravity(),
    ) -> "PhysicsModel":
        """
        Build a physics model from a parsed model description.

        Args:
            model_description: The description of the model.
            gravity: The 3D gravity vector (default: default_gravity()).

        Returns:
            The physics model. is_floating_base is True unless the description is
            fixed-base.

        Raises:
            ValueError: If gravity does not have 3 elements, or if a link pose is not
                the identity.
        """

        if gravity.size != 3:
            raise ValueError(gravity.size)

        # Currently, we assume that the link frame matches the frame of its parent joint
        for link in model_description:
            if not jnp.allclose(link.pose, jnp.eye(4)):
                raise ValueError(f"Link '{link.name}' has unsupported pose:\n{link.pose}")

        # ===================================
        # Initialize physics model parameters
        # ===================================

        # Get the number of bodies, including the base link
        num_of_bodies = len(model_description)

        # Build the parent array λ of the floating-base model.
        # Note: the parent of the base link is not set since it's not defined.
        parent_array_dict = {
            link.index: link.parent.index
            for link in model_description
            if link.parent is not None
        }

        # Get the 6D inertias of all links
        link_spatial_inertias_dict = {
            link.index: link.inertia for link in iter(model_description)
        }

        # Dict from the joint index to its type.
        # Note: the joint index is equal to its child link index.
        joint_types_dict = {
            joint.index: joint.jtype for joint in model_description.joints
        }

        # Dicts from the joint index to the static and viscous friction.
        # Note: the joint index is equal to its child link index.
        joint_friction_static = {
            joint.index: jnp.array(joint.friction_static, dtype=float)
            for joint in model_description.joints
        }
        joint_friction_viscous = {
            joint.index: jnp.array(joint.friction_viscous, dtype=float)
            for joint in model_description.joints
        }

        # Dicts from the joint index to the spring and damper joint limits parameters.
        # Note: the joint index is equal to its child link index.
        joint_limit_spring = {
            joint.index: jnp.array(joint.position_limit_spring, dtype=float)
            for joint in model_description.joints
        }
        joint_limit_damper = {
            joint.index: jnp.array(joint.position_limit_damper, dtype=float)
            for joint in model_description.joints
        }

        # Dicts from the joint index to the motor inertia, gear ratio and viscous friction.
        # Note: the joint index is equal to its child link index.
        joint_motor_inertia = {
            joint.index: jnp.array(joint.motor_inertia, dtype=float)
            for joint in model_description.joints
        }
        joint_motor_gear_ratio = {
            joint.index: jnp.array(joint.motor_gear_ratio, dtype=float)
            for joint in model_description.joints
        }
        joint_motor_viscous_friction = {
            joint.index: jnp.array(joint.motor_viscous_friction, dtype=float)
            for joint in model_description.joints
        }

        # Transform between model's root and model's base link
        # (this is just the pose of the base link in the SDF description)
        base_link = model_description.links_dict[model_description.link_names()[0]]
        R_H_B = model_description.transform(name=base_link.name)
        tree_transform_0 = se3.SE3.from_matrix(matrix=R_H_B).adjoint()

        # Helper to compute the transform pre(i)_H_λ(i).
        # Given a joint 'i', it is the coordinate transform between its predecessor
        # frame [pre(i)] and the frame of its parent link [λ(i)].
        prei_H_λi = lambda j: model_description.relative_transform(  # noqa: E731
            relative_to=j.name, name=j.parent.name
        )

        # Compute the tree transforms: pre(i)_X_λ(i).
        # Index 0 holds the root-to-base transform; the other entries are keyed by joint index.
        tree_transforms_dict = {
            0: tree_transform_0,
            **{
                j.index: se3.SE3.from_matrix(matrix=prei_H_λi(j)).adjoint()
                for j in model_description.joints
            },
        }

        # =======================
        # Build the initial state
        # =======================

        # Initial joint positions
        q0 = jnp.array(
            [
                model_description.joints_dict[j.name].initial_position
                for j in model_description.joints
            ]
        )

        # Build the initial state
        initial_state = PhysicsModelState(
            joint_positions=q0,
            joint_velocities=jnp.zeros_like(q0),
            base_position=model_description.root_pose.root_position,
            base_quaternion=model_description.root_pose.root_quaternion,
        )

        # =======================
        # Build the physics model
        # =======================

        # Initialize the model
        physics_model = PhysicsModel(
            NB=num_of_bodies,
            initial_state=initial_state,
            _parent_array_dict=parent_array_dict,
            _jtype_dict=joint_types_dict,
            _tree_transforms_dict=tree_transforms_dict,
            _link_inertias_dict=link_spatial_inertias_dict,
            _joint_friction_static=joint_friction_static,
            _joint_friction_viscous=joint_friction_viscous,
            _joint_limit_spring=joint_limit_spring,
            _joint_limit_damper=joint_limit_damper,
            _joint_motor_gear_ratio=joint_motor_gear_ratio,
            _joint_motor_inertia=joint_motor_inertia,
            _joint_motor_viscous_friction=joint_motor_viscous_friction,
            gravity=jnp.hstack([gravity.squeeze(), np.zeros(3)]),
            is_floating_base=True,
            gc=GroundContact.build_from(model_description=model_description),
            description=model_description,
        )

        # Floating-base models
        if not model_description.fixed_base:
            return physics_model

        # Fixed-base models
        with jax_dataclasses.copy_and_mutate(physics_model) as physics_model_fixed:
            physics_model_fixed.is_floating_base = False

        return physics_model_fixed

    def dofs(self) -> int:
        """Return the number of joints, i.e. the number of entries in the joint-type dict."""

        return len(self._jtype_dict)

    def set_gravity(self, gravity: jtp.Vector) -> None:
        """
        Set the gravity vector.

        Args:
            gravity: Either a 3D gravity vector, which is padded with three zeros, or a
                full 6D vector, which is stored as-is.

        Raises:
            ValueError: If gravity has neither 3 nor 6 elements after squeezing.
        """

        gravity = gravity.squeeze()

        if gravity.size == 3:
            self.gravity = jnp.hstack([gravity, 0, 0, 0])

        elif gravity.size == 6:
            self.gravity = gravity

        else:
            raise ValueError(gravity.shape)

    @property
    def parent(self) -> jtp.Vector:
        """Alias of parent_array()."""

        return self.parent_array()

    def parent_array(self) -> jtp.Vector:
        """
        Return λ(i), the parent index of every body.

        The base (index 0) has no parent and gets -1. The other entries are ordered by
        body index rather than by dict insertion order.
        """

        parents = [self._parent_array_dict[idx] for idx in sorted(self._parent_array_dict)]
        return jnp.array([-1] + parents, dtype=int)

    def support_body_array(self, body_index: jtp.Int) -> jtp.Vector:
        """Return κ(i), the indices of the bodies flagged by support_body_array_bool."""

        κ_bool = self.support_body_array_bool(body_index=body_index)
        return jnp.array(jnp.where(κ_bool)[0], dtype=int)

    def support_body_array_bool(self, body_index: jtp.Int) -> jtp.Vector:
        """
        Return a boolean mask of length NB flagging the support κ(i) of a body.

        Starting from body_index, the loop sweeps the bodies from NB-1 down to 0. Each
        time the current body is reached, it is flagged and the walk moves to its parent.
        jax.lax.cond is used so that body_index may be a traced value.

        NOTE(review): the backward sweep only reaches every ancestor if parent indices
        are smaller than child indices — presumably guaranteed by the link ordering;
        TODO confirm.
        """

        # Computed once instead of once per loop iteration.
        λ = self.parent_array()

        active_link = body_index
        κ_bool = jnp.zeros(self.NB, dtype=bool)

        for i in np.flip(np.arange(start=0, stop=self.NB)):
            κ_bool, active_link = jax.lax.cond(
                pred=(i == active_link),
                false_fun=lambda: (κ_bool, active_link),
                true_fun=lambda: (
                    κ_bool.at[active_link].set(True),
                    λ[active_link],
                ),
            )

        return κ_bool

    @property
    def tree_transforms(self) -> jtp.Array:
        """Stack the tree transforms of all NB bodies (identity for missing indices)."""

        X_tree = jnp.array(
            [
                self._tree_transforms_dict.get(idx, jnp.eye(6))
                for idx in np.arange(start=0, stop=self.NB)
            ]
        )

        return X_tree

    @property
    def spatial_inertias(self) -> jtp.Array:
        """
        Stack the spatial inertias of all NB bodies.

        Missing indices get a 6x6 zero matrix. This matches the matrix-typed entries of
        _link_inertias_dict; a (6,) fallback could not be stacked with them.
        """

        M_links = jnp.array(
            [
                self._link_inertias_dict.get(idx, jnp.zeros(shape=(6, 6)))
                for idx in np.arange(start=0, stop=self.NB)
            ]
        )

        return M_links

    def jtype(self, joint_index: int) -> JointType:
        """
        Return the type of the joint with the given index.

        Raises:
            ValueError: If joint_index is 0 (the base has no joint) or >= NB.
        """

        if joint_index == 0 or joint_index >= self.NB:
            raise ValueError(joint_index)

        return self._jtype_dict[joint_index]

    def joint_transforms(self, q: jtp.Vector) -> jtp.Array:
        """
        Stack the joint transforms for joint positions q.

        Entry 0 is a 6x6 zero placeholder for the base. Entry i+1 is the first element
        returned by jcalc for joint i+1.

        Raises:
            ValueError: If q is concrete and its length differs from dofs().
        """

        from jaxsim.math.joint import jcalc

        if not_tracing(q):
            if q.shape[0] != self.dofs():
                raise ValueError(q.shape)

        Xj = jnp.stack(
            [jnp.zeros(shape=(6, 6))]
            + [
                jcalc(jtyp=self.jtype(index + 1), q=joint_position)[0]
                for index, joint_position in enumerate(q)
            ]
        )

        return Xj

    def motion_subspaces(self, q: jtp.Vector) -> jtp.Array:
        """
        Stack the motion subspaces for joint positions q.

        Entry 0 is a (6, 1) zero placeholder for the base. Entry i+1 is the second
        element returned by jcalc for joint i+1.

        Raises:
            ValueError: If q is concrete and its length differs from dofs().
        """

        from jaxsim.math.joint import jcalc

        if not_tracing(var=q):
            if q.shape[0] != self.dofs():
                raise ValueError(q.shape)

        SS = jnp.stack(
            [jnp.vstack(jnp.zeros(6))]
            + [
                jcalc(jtyp=self.jtype(index + 1), q=joint_position)[1]
                for index, joint_position in enumerate(q)
            ]
        )

        return SS

    def __eq__(self, other: object) -> bool:
        """
        Compare NB, dofs, floating-base flag and gravity.

        Every attribute that feeds __hash__ (through __repr__) is compared here, so
        equal models always hash equally. Non-PhysicsModel operands are deferred to
        Python via NotImplemented.
        """

        if not isinstance(other, PhysicsModel):
            return NotImplemented

        return (
            self.NB == other.NB
            and self.dofs() == other.dofs()
            and self.is_floating_base == other.is_floating_base
            and bool(np.allclose(self.gravity, other.gravity))
        )

    def __hash__(self):
        return hash(self.__repr__())

    def __repr__(self) -> str:
        attributes = [
            f"dofs: {self.dofs()},",
            f"links: {self.NB},",
            f"floating_base: {self.is_floating_base},",
        ]
        attributes_string = "\n    ".join(attributes)

        return f"{type(self).__name__}(\n    {attributes_string}\n)"
@@ -1,283 +0,0 @@
1
- from typing import Union
2
-
3
- import jax.numpy as jnp
4
- import jax_dataclasses
5
-
6
- import jaxsim.physics.model.physics_model
7
- import jaxsim.typing as jtp
8
- from jaxsim.utils import JaxsimDataclass
9
-
10
-
11
@jax_dataclasses.pytree_dataclass
class PhysicsModelState(JaxsimDataclass):
    """
    State of a physics model: the joint state plus the state of the base.

    Attributes:
        joint_positions (jtp.Vector): The joint positions.
        joint_velocities (jtp.Vector): The joint velocities.
        base_position (jtp.Vector): The base position (default: zeros(3)).
        base_quaternion (jtp.Vector): The base orientation quaternion (default: [1.0, 0, 0, 0]).
        base_linear_velocity (jtp.Vector): The base linear velocity (default: zeros(3)).
        base_angular_velocity (jtp.Vector): The base angular velocity (default: zeros(3)).
    """

    # Joint state.
    joint_positions: jtp.Vector
    joint_velocities: jtp.Vector

    # Base state.
    base_position: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.zeros(3)
    )
    base_quaternion: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.array([1.0, 0, 0, 0])
    )
    base_linear_velocity: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.zeros(3)
    )
    base_angular_velocity: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.zeros(3)
    )

    @staticmethod
    def build(
        joint_positions: jtp.Vector | None = None,
        joint_velocities: jtp.Vector | None = None,
        base_position: jtp.Vector | None = None,
        base_quaternion: jtp.Vector | None = None,
        base_linear_velocity: jtp.Vector | None = None,
        base_angular_velocity: jtp.Vector | None = None,
        number_of_dofs: jtp.Int | None = None,
    ) -> "PhysicsModelState":
        """
        Build a state, replacing every missing entry with its default.

        Args:
            joint_positions: The joint positions (default: zeros(number_of_dofs)).
            joint_velocities: The joint velocities (default: zeros(number_of_dofs)).
            base_position: The base position (default: zeros(3)).
            base_quaternion: The base quaternion (default: [1.0, 0, 0, 0]).
            base_linear_velocity: The base linear velocity (default: zeros(3)).
            base_angular_velocity: The base angular velocity (default: zeros(3)).
            number_of_dofs: The size of the default joint vectors. It is only used when
                a joint vector is not given.

        Returns:
            A PhysicsModelState whose arrays are all converted to dtype float.
        """

        if joint_positions is None:
            joint_positions = jnp.zeros(number_of_dofs)

        if joint_velocities is None:
            joint_velocities = jnp.zeros(number_of_dofs)

        if base_position is None:
            base_position = jnp.zeros(3)

        if base_quaternion is None:
            base_quaternion = jnp.array([1.0, 0, 0, 0])

        if base_linear_velocity is None:
            base_linear_velocity = jnp.zeros(3)

        if base_angular_velocity is None:
            base_angular_velocity = jnp.zeros(3)

        def as_float(value):
            return jnp.array(value, dtype=float)

        return PhysicsModelState(
            joint_positions=as_float(joint_positions),
            joint_velocities=as_float(joint_velocities),
            base_position=as_float(base_position),
            base_quaternion=as_float(base_quaternion),
            base_linear_velocity=as_float(base_linear_velocity),
            base_angular_velocity=as_float(base_angular_velocity),
        )

    @staticmethod
    def build_from_physics_model(
        joint_positions: jtp.Vector | None = None,
        joint_velocities: jtp.Vector | None = None,
        base_position: jtp.Vector | None = None,
        base_quaternion: jtp.Vector | None = None,
        base_linear_velocity: jtp.Vector | None = None,
        base_angular_velocity: jtp.Vector | None = None,
        physics_model: Union[
            "jaxsim.physics.model.physics_model.PhysicsModel", None
        ] = None,
    ) -> "PhysicsModelState":
        """
        Same as build(), with number_of_dofs taken from physics_model.dofs().

        physics_model must be provided, despite its None default.
        """

        dofs = physics_model.dofs()

        return PhysicsModelState.build(
            joint_positions=joint_positions,
            joint_velocities=joint_velocities,
            base_position=base_position,
            base_quaternion=base_quaternion,
            base_linear_velocity=base_linear_velocity,
            base_angular_velocity=base_angular_velocity,
            number_of_dofs=dofs,
        )

    @staticmethod
    def zero(
        physics_model: "jaxsim.physics.model.physics_model.PhysicsModel",
    ) -> "PhysicsModelState":
        """Return the all-default state sized for physics_model."""

        return PhysicsModelState.build_from_physics_model(physics_model=physics_model)

    def position(self) -> jtp.Vector:
        """Return [base_position, base_quaternion, joint_positions] as one vector."""

        parts = (self.base_position, self.base_quaternion, self.joint_positions)
        return jnp.hstack(list(parts))

    def velocity(self) -> jtp.Vector:
        """Return [base_linear_velocity, base_angular_velocity, joint_velocities] as one vector."""

        # W_v_WB: inertial-fixed representation of the base velocity
        parts = (
            self.base_linear_velocity,
            self.base_angular_velocity,
            self.joint_velocities,
        )
        return jnp.hstack(list(parts))

    def xfb(self) -> jtp.Vector:
        """Return the base state as [quaternion, position, angular velocity, linear velocity]."""

        parts = (
            self.base_quaternion,
            self.base_position,
            self.base_angular_velocity,
            self.base_linear_velocity,
        )
        return jnp.hstack(list(parts))

    def valid(
        self, physics_model: "jaxsim.physics.model.physics_model.PhysicsModel"
    ) -> bool:
        """Check the shape of every field against physics_model, using check_valid_shape."""

        from jaxsim.simulation.utils import check_valid_shape

        dofs = physics_model.dofs()

        # (field name, actual shape, expected shape), checked in this order.
        expectations = (
            ("joint_positions", self.joint_positions.shape, (dofs,)),
            ("joint_velocities", self.joint_velocities.shape, (dofs,)),
            ("base_position", self.base_position.shape, (3,)),
            ("base_quaternion", self.base_quaternion.shape, (4,)),
            ("base_linear_velocity", self.base_linear_velocity.shape, (3,)),
            ("base_angular_velocity", self.base_angular_velocity.shape, (3,)),
        )

        is_valid = True

        for field_name, actual_shape, expected in expectations:
            is_valid = check_valid_shape(
                what=field_name,
                shape=actual_shape,
                expected_shape=expected,
                valid=is_valid,
            )

        return is_valid
202
-
203
-
204
@jax_dataclasses.pytree_dataclass
class PhysicsModelInput(JaxsimDataclass):
    """
    A class representing the input to a physics model.

    This class stores the joint torques and external forces acting on the bodies of a physics model.

    Attributes:
        tau: An array representing the joint torques.
        f_ext: A matrix representing the external forces acting on the bodies of the physics model.
    """

    tau: jtp.VectorJax
    f_ext: jtp.MatrixJax

    @staticmethod
    def build(
        tau: jtp.VectorJax | None = None,
        f_ext: jtp.MatrixJax | None = None,
        number_of_dofs: jtp.Int | None = None,
        number_of_links: jtp.Int | None = None,
    ) -> "PhysicsModelInput":
        """
        Build an input, replacing every missing entry with zeros.

        Args:
            tau: The joint torques (default: zeros(number_of_dofs)).
            f_ext: The external forces (default: zeros((number_of_links, 6))).
            number_of_dofs: The size of the default tau; only used when tau is None.
            number_of_links: The number of rows of the default f_ext; only used when
                f_ext is None.

        Returns:
            A PhysicsModelInput whose arrays are converted to dtype float.
        """

        tau = tau if tau is not None else jnp.zeros(number_of_dofs)
        f_ext = f_ext if f_ext is not None else jnp.zeros(shape=(number_of_links, 6))

        return PhysicsModelInput(
            tau=jnp.array(tau, dtype=float), f_ext=jnp.array(f_ext, dtype=float)
        )

    @staticmethod
    def build_from_physics_model(
        tau: jtp.VectorJax | None = None,
        f_ext: jtp.MatrixJax | None = None,
        physics_model: Union[
            "jaxsim.physics.model.physics_model.PhysicsModel", None
        ] = None,
    ) -> "PhysicsModelInput":
        """
        Same as build(), with the sizes taken from physics_model.

        number_of_dofs is physics_model.dofs() and number_of_links is physics_model.NB.
        physics_model must be provided, despite its None default.
        """

        return PhysicsModelInput.build(
            tau=tau,
            f_ext=f_ext,
            number_of_dofs=physics_model.dofs(),
            number_of_links=physics_model.NB,
        )

    @staticmethod
    def zero(
        physics_model: "jaxsim.physics.model.physics_model.PhysicsModel",
    ) -> "PhysicsModelInput":
        """Return the all-zero input sized for physics_model."""

        return PhysicsModelInput.build_from_physics_model(physics_model=physics_model)

    def replace(self, validate: bool = True, **kwargs) -> "PhysicsModelInput":
        """
        Return a copy of this input with the fields named in kwargs replaced.

        Args:
            validate: Forwarded to jax_dataclasses.copy_and_mutate.
            **kwargs: Field names mapped to their new values.

        Returns:
            The updated copy; self is left untouched.
        """

        with jax_dataclasses.copy_and_mutate(self, validate=validate) as updated_input:
            # A plain loop: the assignments are side effects, not a list to build.
            for name, value in kwargs.items():
                setattr(updated_input, name, value)

        return updated_input

    def valid(
        self, physics_model: "jaxsim.physics.model.physics_model.PhysicsModel"
    ) -> bool:
        """
        Check the shapes of tau and f_ext against physics_model.

        The expected shapes are (dofs,) for tau and (NB, 6) for f_ext.
        """

        from jaxsim.simulation.utils import check_valid_shape

        valid = True

        valid = check_valid_shape(
            what="tau",
            shape=self.tau.shape,
            expected_shape=(physics_model.dofs(),),
            valid=valid,
        )

        valid = check_valid_shape(
            what="f_ext",
            shape=self.f_ext.shape,
            expected_shape=(physics_model.NB, 6),
            valid=valid,
        )

        return valid
@@ -1,4 +0,0 @@
1
- from . import integrators, ode, ode_data, simulator
2
- from .ode_data import ODEInput, ODEState
3
- from .ode_integration import IntegratorType
4
- from .simulator import JaxSim, SimulatorData