jaxsim 0.1.dev401__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1.dev401.dist-info/METADATA +0 -167
  88. jaxsim-0.1.dev401.dist-info/RECORD +0 -64
  89. {jaxsim-0.1.dev401.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,69 +0,0 @@
1
- from typing import Tuple
2
-
3
- import jax.numpy as jnp
4
-
5
- import jaxsim.typing as jtp
6
- from jaxsim.physics.model.physics_model import PhysicsModel
7
-
8
-
9
def process_inputs(
    physics_model: PhysicsModel,
    xfb: jtp.Vector | None = None,
    q: jtp.Vector | None = None,
    qd: jtp.Vector | None = None,
    qdd: jtp.Vector | None = None,
    tau: jtp.Vector | None = None,
    f_ext: jtp.Matrix | None = None,
) -> Tuple[jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector, jtp.Vector, jtp.Matrix]:
    """
    Adjust the inputs to the physics model.

    Missing inputs are replaced by defaults:
    - zeros for the joint-space vectors and the link external forces;
    - for ``xfb``, a 13-element vector whose first entry is 1 and the rest 0.

    Args:
        physics_model: The physics model.
        xfb: The variables of the base link.
        q: The generalized coordinates.
        qd: The generalized velocities.
        qdd: The generalized accelerations.
        tau: The generalized forces.
        f_ext: The link external forces.

    Returns:
        The tuple ``(xfb, q, qd, qdd, tau, f_ext)``. Extra singleton dimensions
        are removed, joint-space vectors are at least 1D, and ``f_ext`` is at
        least 2D.

    Raises:
        ValueError: If any input does not have the expected shape.
    """

    dofs = physics_model.dofs()
    f_ext_shape = (physics_model.NB, 6)

    # Remove extra dimensions, or build the default value when not given
    q = q.squeeze() if q is not None else jnp.zeros(dofs)
    qd = qd.squeeze() if qd is not None else jnp.zeros(dofs)
    qdd = qdd.squeeze() if qdd is not None else jnp.zeros(dofs)
    tau = tau.squeeze() if tau is not None else jnp.zeros(dofs)
    xfb = xfb.squeeze() if xfb is not None else jnp.zeros(13).at[0].set(1)
    f_ext = f_ext.squeeze() if f_ext is not None else jnp.zeros(shape=f_ext_shape).squeeze()

    # Fix case with just 1 DoF: squeeze() turns a (1,) vector into a 0-d array
    q = jnp.atleast_1d(q)
    qd = jnp.atleast_1d(qd)
    qdd = jnp.atleast_1d(qdd)
    tau = jnp.atleast_1d(tau)

    # Guard against a scalar xfb so that the check below raises ValueError
    # instead of an IndexError on shape[0]
    xfb = jnp.atleast_1d(xfb)

    # Fix case with just 1 body: squeeze() turns a (1, 6) matrix into a (6,) vector
    f_ext = jnp.atleast_2d(f_ext)

    # Validate dimensions. Every input holds a value at this point.
    if xfb.shape[0] != 13:
        raise ValueError(xfb.shape)
    if q.shape[0] != dofs:
        raise ValueError(q.shape, dofs)
    if qd.shape[0] != dofs:
        raise ValueError(qd.shape, dofs)
    if qdd.shape[0] != dofs:
        raise ValueError(qdd.shape, dofs)
    if tau.shape[0] != dofs:
        raise ValueError(tau.shape, dofs)
    if f_ext.shape != f_ext_shape:
        raise ValueError(f_ext.shape, f_ext_shape)

    return xfb, q, qd, qdd, tau, f_ext
File without changes
@@ -1,55 +0,0 @@
1
- import dataclasses
2
-
3
- import jax.numpy as jnp
4
- import jax_dataclasses
5
- import numpy as np
6
- import numpy.typing as npt
7
- from jax_dataclasses import Static
8
-
9
- from jaxsim.parsers.descriptions import ModelDescription
10
-
11
-
12
@jax_dataclasses.pytree_dataclass
class GroundContact:
    """
    A class for managing collidable points in a robot model.

    This class stores the positions of the enabled collidable points of a model,
    together with the index of the link (body) each point belongs to.

    Attributes:
        point (npt.NDArray): An array of shape (3, N) with the 3D positions of the
            N collidable points. When the model has no collidable points, N is 0.
        body (Static[npt.NDArray]): An array of N integers with the index of the
            link associated with each collidable point.
    """

    # Empty default keeps the documented (3, N) layout, with N = 0
    point: npt.NDArray = dataclasses.field(
        default_factory=lambda: jnp.zeros(shape=(3, 0))
    )
    body: Static[npt.NDArray] = dataclasses.field(
        default_factory=lambda: np.array([], dtype=int)
    )

    @staticmethod
    def build_from(
        model_description: ModelDescription,
    ) -> "GroundContact":
        """
        Build the collidable points from a model description.

        Args:
            model_description: The model description to read the enabled
                collidable points from.

        Returns:
            A GroundContact object. It is empty when the model has no collision
            shapes or no enabled collidable points.
        """

        if len(model_description.collision_shapes) == 0:
            return GroundContact()

        # Get all the links so that we can take their updated index
        links_dict = {link.name: link for link in model_description}

        # Get all the enabled collidable points of the model.
        # They are materialized once because they are iterated twice below.
        collidable_points = list(model_description.all_enabled_collidable_points())

        # Collision shapes may exist while all of their points are disabled;
        # jnp.vstack would raise on an empty list.
        if len(collidable_points) == 0:
            return GroundContact()

        # Build the GroundContact attributes
        points = jnp.vstack([cp.position for cp in collidable_points]).T
        link_index_of_points = np.array(
            [links_dict[cp.parent_link.name].index for cp in collidable_points]
        )

        # Build the object
        gc = GroundContact(point=points, body=link_index_of_points)

        assert gc.point.shape[0] == 3
        assert gc.point.shape[1] == len(gc.body)

        return gc
@@ -1,358 +0,0 @@
1
- import dataclasses
2
- from typing import Dict, Union
3
-
4
- import jax.lax
5
- import jax.numpy as jnp
6
- import jax_dataclasses
7
- import numpy as np
8
- from jax_dataclasses import Static
9
-
10
- import jaxsim.parsers
11
- import jaxsim.physics
12
- import jaxsim.typing as jtp
13
- from jaxsim.parsers.descriptions import JointDescriptor, JointType
14
- from jaxsim.physics import default_gravity
15
- from jaxsim.sixd import se3
16
- from jaxsim.utils import JaxsimDataclass, not_tracing
17
-
18
- from .ground_contact import GroundContact
19
- from .physics_model_state import PhysicsModelState
20
-
21
-
22
@jax_dataclasses.pytree_dataclass
class PhysicsModel(JaxsimDataclass):
    """
    A read-only class to store all the information necessary to run RBDAs on a model.

    This class contains information about the physics model, including the number
    of bodies, initial state, gravity, floating base configuration, ground contact
    points, and more.

    Attributes:
        NB (Static[int]): The number of bodies in the physics model, including the base link.
        initial_state (PhysicsModelState): The initial state of the physics model. When
            not given, a zero state is built in ``__post_init__``.
        gravity (jtp.Vector): The 6D gravity vector: the 3D gravity in the first three
            entries followed by three zeros (default: ``default_gravity()`` and zeros).
        is_floating_base (Static[bool]): A flag indicating whether the model has a
            floating base (default: False).
        gc (GroundContact): The ground contact points of the model (default: empty
            GroundContact instance).
        description (Static[jaxsim.parsers.descriptions.model.ModelDescription]): A
            description of the model (default: None).
    """

    NB: Static[int]
    initial_state: PhysicsModelState = dataclasses.field(default=None)
    # Same layout used by build_from() and set_gravity(): 3D gravity, then zeros
    gravity: jtp.Vector = dataclasses.field(
        default_factory=lambda: jnp.hstack(
            [jaxsim.physics.default_gravity(), np.zeros(3)]
        )
    )
    is_floating_base: Static[bool] = dataclasses.field(default=False)
    gc: GroundContact = dataclasses.field(default_factory=lambda: GroundContact())
    description: Static[
        jaxsim.parsers.descriptions.model.ModelDescription
    ] = dataclasses.field(default=None)

    # Dict from the link index to the index of its parent link λ(i)
    _parent_array_dict: Static[Dict[int, int]] = dataclasses.field(default_factory=dict)
    # Dict from the joint index (equal to its child link index) to its type
    _jtype_dict: Static[
        Dict[int, Union[JointType, JointDescriptor]]
    ] = dataclasses.field(default_factory=dict)
    _tree_transforms_dict: Dict[int, jtp.Matrix] = dataclasses.field(
        default_factory=dict
    )
    _link_inertias_dict: Dict[int, jtp.Matrix] = dataclasses.field(default_factory=dict)

    _joint_friction_static: Dict[int, float] = dataclasses.field(default_factory=dict)
    _joint_friction_viscous: Dict[int, float] = dataclasses.field(default_factory=dict)

    _joint_limit_spring: Dict[int, float] = dataclasses.field(default_factory=dict)
    _joint_limit_damper: Dict[int, float] = dataclasses.field(default_factory=dict)

    _joint_motor_inertia: Dict[int, float] = dataclasses.field(default_factory=dict)
    _joint_motor_gear_ratio: Dict[int, float] = dataclasses.field(default_factory=dict)
    _joint_motor_viscous_friction: Dict[int, float] = dataclasses.field(
        default_factory=dict
    )

    def __post_init__(self):
        # Fall back to a zero state when no initial state was provided
        if self.initial_state is None:
            initial_state = PhysicsModelState.zero(physics_model=self)
            object.__setattr__(self, "initial_state", initial_state)

    @staticmethod
    def build_from(
        model_description: jaxsim.parsers.descriptions.model.ModelDescription,
        gravity: jtp.Vector = default_gravity(),
    ) -> "PhysicsModel":
        """
        Build a physics model from a model description.

        Args:
            model_description: The parsed model description.
            gravity: The 3D gravity vector.

        Returns:
            The physics model. It is flagged as floating-base unless the
            description is fixed-base.

        Raises:
            ValueError: If ``gravity`` does not have 3 elements, or if a link
                pose is not the identity.
        """

        gravity = jnp.asarray(gravity)

        if gravity.size != 3:
            raise ValueError(gravity.size)

        # Currently, we assume that the link frame matches the frame of its parent joint
        for link in model_description:
            if not jnp.allclose(link.pose, jnp.eye(4)):
                raise ValueError(f"Link '{link.name}' has unsupported pose:\n{link.pose}")

        # ===================================
        # Initialize physics model parameters
        # ===================================

        # Get the number of bodies, including the base link
        num_of_bodies = len(model_description)

        # Build the parent array λ of the floating-base model.
        # Note: the parent of the base link is not set since it's not defined.
        parent_array_dict = {
            link.index: link.parent.index
            for link in model_description
            if link.parent is not None
        }

        # Get the 6D inertias of all links
        link_spatial_inertias_dict = {
            link.index: link.inertia for link in iter(model_description)
        }

        # Dict from the joint index to its type.
        # Note: the joint index is equal to its child link index.
        joint_types_dict = {
            joint.index: joint.jtype for joint in model_description.joints
        }

        def joint_parameter(attribute: str) -> Dict[int, jtp.Array]:
            # Dict from the joint index to the given joint attribute, as a float array.
            # Note: the joint index is equal to its child link index.
            return {
                joint.index: jnp.array(getattr(joint, attribute), dtype=float)
                for joint in model_description.joints
            }

        # Static and viscous friction
        joint_friction_static = joint_parameter("friction_static")
        joint_friction_viscous = joint_parameter("friction_viscous")

        # Spring and damper joint limits parameters
        joint_limit_spring = joint_parameter("position_limit_spring")
        joint_limit_damper = joint_parameter("position_limit_damper")

        # Motor inertia, gear ratio and viscous friction
        joint_motor_inertia = joint_parameter("motor_inertia")
        joint_motor_gear_ratio = joint_parameter("motor_gear_ratio")
        joint_motor_viscous_friction = joint_parameter("motor_viscous_friction")

        # Transform between model's root and model's base link
        # (this is just the pose of the base link in the SDF description)
        base_link = model_description.links_dict[model_description.link_names()[0]]
        R_H_B = model_description.transform(name=base_link.name)
        tree_transform_0 = se3.SE3.from_matrix(matrix=R_H_B).adjoint()

        def prei_H_λi(j):
            # Given a joint 'i', the coordinate transform between its predecessor
            # frame [pre(i)] and the frame of its parent link [λ(i)].
            return model_description.relative_transform(
                relative_to=j.name, name=j.parent.name
            )

        # Compute the tree transforms: pre(i)_X_λ(i).
        tree_transforms_dict = {
            0: tree_transform_0,
            **{
                j.index: se3.SE3.from_matrix(matrix=prei_H_λi(j)).adjoint()
                for j in model_description.joints
            },
        }

        # =======================
        # Build the initial state
        # =======================

        # Initial joint positions
        q0 = jnp.array(
            [
                model_description.joints_dict[j.name].initial_position
                for j in model_description.joints
            ]
        )

        # Build the initial state
        initial_state = PhysicsModelState(
            joint_positions=q0,
            joint_velocities=jnp.zeros_like(q0),
            base_position=model_description.root_pose.root_position,
            base_quaternion=model_description.root_pose.root_quaternion,
        )

        # =======================
        # Build the physics model
        # =======================

        # Initialize the model
        physics_model = PhysicsModel(
            NB=num_of_bodies,
            initial_state=initial_state,
            _parent_array_dict=parent_array_dict,
            _jtype_dict=joint_types_dict,
            _tree_transforms_dict=tree_transforms_dict,
            _link_inertias_dict=link_spatial_inertias_dict,
            _joint_friction_static=joint_friction_static,
            _joint_friction_viscous=joint_friction_viscous,
            _joint_limit_spring=joint_limit_spring,
            _joint_limit_damper=joint_limit_damper,
            _joint_motor_gear_ratio=joint_motor_gear_ratio,
            _joint_motor_inertia=joint_motor_inertia,
            _joint_motor_viscous_friction=joint_motor_viscous_friction,
            gravity=jnp.hstack([gravity.squeeze(), np.zeros(3)]),
            is_floating_base=True,
            gc=GroundContact.build_from(model_description=model_description),
            description=model_description,
        )

        # Floating-base models
        if not model_description.fixed_base:
            return physics_model

        # Fixed-base models
        with jax_dataclasses.copy_and_mutate(physics_model) as physics_model_fixed:
            physics_model_fixed.is_floating_base = False

        return physics_model_fixed

    def dofs(self) -> int:
        """Return the number of degrees of freedom (one per joint)."""
        return len(self._jtype_dict)

    def set_gravity(self, gravity: jtp.Vector) -> None:
        """
        Set the gravity of the model.

        Args:
            gravity: Either a 3D gravity vector, padded with three trailing zeros,
                or a full 6D vector.

        Raises:
            ValueError: If ``gravity`` has neither 3 nor 6 elements.
        """

        gravity = gravity.squeeze()

        if gravity.size == 3:
            self.gravity = jnp.hstack([gravity, 0, 0, 0])

        elif gravity.size == 6:
            self.gravity = gravity

        else:
            raise ValueError(gravity.shape)

    @property
    def parent(self) -> jtp.Vector:
        return self.parent_array()

    def parent_array(self) -> jtp.Vector:
        """Returns λ(i), with -1 as the parent of the base link (index 0)."""

        # Order the parents by link index, independently of dict insertion order
        parents = [parent for _, parent in sorted(self._parent_array_dict.items())]
        return jnp.array([-1] + parents, dtype=int)

    def support_body_array(self, body_index: jtp.Int) -> jtp.Vector:
        """Returns κ(i)"""

        κ_bool = self.support_body_array_bool(body_index=body_index)
        return jnp.array(jnp.where(κ_bool)[0], dtype=int)

    def support_body_array_bool(self, body_index: jtp.Int) -> jtp.Vector:
        """
        Return a boolean mask over the NB bodies marking κ(i), i.e. the body
        itself and all its ancestors up to the base link.
        """

        active_link = body_index
        κ_bool = jnp.zeros(self.NB, dtype=bool)

        # Loop-invariant: the parent array does not change inside the loop
        parent = self.parent_array()

        # Walk the indices backwards. Every time the active link is reached,
        # mark it and move to its parent (parents have lower indices).
        for i in np.flip(np.arange(start=0, stop=self.NB)):
            κ_bool, active_link = jax.lax.cond(
                pred=(i == active_link),
                false_fun=lambda: (κ_bool, active_link),
                true_fun=lambda: (
                    κ_bool.at[active_link].set(True),
                    parent[active_link],
                ),
            )

        return κ_bool

    @property
    def tree_transforms(self) -> jtp.Array:
        """Stack of the NB tree transforms (identity when missing)."""

        X_tree = jnp.array(
            [
                self._tree_transforms_dict.get(idx, jnp.eye(6))
                for idx in np.arange(start=0, stop=self.NB)
            ]
        )

        return X_tree

    @property
    def spatial_inertias(self) -> jtp.Array:
        """Stack of the NB 6x6 link spatial inertias (zero matrix when missing)."""

        # The fallback must be 6x6 to be stackable with the other inertias
        M_links = jnp.array(
            [
                self._link_inertias_dict.get(idx, jnp.zeros(shape=(6, 6)))
                for idx in np.arange(start=0, stop=self.NB)
            ]
        )

        return M_links

    def jtype(self, joint_index: int) -> JointType:
        """
        Return the type of the given joint.

        Raises:
            ValueError: If ``joint_index`` is 0 (base link) or not smaller than NB.
        """

        if joint_index == 0 or joint_index >= self.NB:
            raise ValueError(joint_index)

        return self._jtype_dict[joint_index]

    def joint_transforms(self, q: jtp.Vector) -> jtp.Array:
        """
        Return the stacked joint transforms computed by ``jcalc``.

        Entry 0 (base link) is a 6x6 zero matrix.
        """

        from jaxsim.math.joint import jcalc

        if not_tracing(q):
            if q.shape[0] != self.dofs():
                raise ValueError(q.shape)

        Xj = jnp.stack(
            [jnp.zeros(shape=(6, 6))]
            + [
                jcalc(jtyp=self.jtype(index + 1), q=joint_position)[0]
                for index, joint_position in enumerate(q)
            ]
        )

        return Xj

    def motion_subspaces(self, q: jtp.Vector) -> jtp.Array:
        """
        Return the stacked motion subspaces computed by ``jcalc``.

        Entry 0 (base link) is a (6, 1) zero column.
        """

        from jaxsim.math.joint import jcalc

        if not_tracing(var=q):
            if q.shape[0] != self.dofs():
                raise ValueError(q.shape)

        SS = jnp.stack(
            [jnp.vstack(jnp.zeros(6))]
            + [
                jcalc(jtyp=self.jtype(index + 1), q=joint_position)[1]
                for index, joint_position in enumerate(q)
            ]
        )

        return SS

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, PhysicsModel):
            return NotImplemented

        # Compare every attribute that __hash__ (via __repr__) depends on, so
        # that equal models always have equal hashes
        return (
            self.NB == other.NB
            and self.dofs() == other.dofs()
            and self.is_floating_base == other.is_floating_base
            and np.allclose(self.gravity, other.gravity)
        )

    def __hash__(self):
        return hash(self.__repr__())

    def __repr__(self) -> str:
        attributes = [
            f"dofs: {self.dofs()},",
            f"links: {self.NB},",
            f"floating_base: {self.is_floating_base},",
        ]
        attributes_string = "\n    ".join(attributes)

        return f"{type(self).__name__}(\n    {attributes_string}\n)"
@@ -1,174 +0,0 @@
1
- import jax.numpy as jnp
2
- import jax_dataclasses
3
-
4
- import jaxsim.physics.model.physics_model
5
- import jaxsim.typing as jtp
6
- from jaxsim.utils import JaxsimDataclass
7
-
8
-
9
@jax_dataclasses.pytree_dataclass
class PhysicsModelState(JaxsimDataclass):
    """
    State of a physics model.

    Holds the joint positions and velocities, and the state of the base link:
    its position, orientation (quaternion), linear velocity and angular velocity.

    Attributes:
        joint_positions (jtp.Vector): The joint positions.
        joint_velocities (jtp.Vector): The joint velocities.
        base_position (jtp.Vector): The base position (default: zeros).
        base_quaternion (jtp.Vector): The base quaternion (default: [1.0, 0, 0, 0]).
        base_linear_velocity (jtp.Vector): The base linear velocity (default: zeros).
        base_angular_velocity (jtp.Vector): The base angular velocity (default: zeros).
    """

    # Joint state
    joint_positions: jtp.Vector
    joint_velocities: jtp.Vector

    # Base state
    base_position: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.zeros(3)
    )
    base_quaternion: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.array([1.0, 0, 0, 0])
    )
    base_linear_velocity: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.zeros(3)
    )
    base_angular_velocity: jtp.Vector = jax_dataclasses.field(
        default_factory=lambda: jnp.zeros(3)
    )

    @staticmethod
    def zero(
        physics_model: "jaxsim.physics.model.physics_model.PhysicsModel",
    ) -> "PhysicsModelState":
        """Build a state with zero joint positions/velocities and default base state."""

        n_dofs = physics_model.dofs()

        return PhysicsModelState(
            joint_positions=jnp.zeros(n_dofs),
            joint_velocities=jnp.zeros(n_dofs),
        )

    def position(self) -> jtp.Vector:
        """Concatenate base position, base quaternion and joint positions."""

        parts = (self.base_position, self.base_quaternion, self.joint_positions)
        return jnp.hstack(parts)

    def velocity(self) -> jtp.Vector:
        """Concatenate base linear velocity, base angular velocity and joint velocities."""

        # W_v_WB: inertial-fixed representation of the base velocity
        parts = (
            self.base_linear_velocity,
            self.base_angular_velocity,
            self.joint_velocities,
        )
        return jnp.hstack(parts)

    def xfb(self) -> jtp.Vector:
        """Concatenate base quaternion, position, angular and linear velocity."""

        parts = (
            self.base_quaternion,
            self.base_position,
            self.base_angular_velocity,
            self.base_linear_velocity,
        )
        return jnp.hstack(parts)

    def valid(
        self, physics_model: "jaxsim.physics.model.physics_model.PhysicsModel"
    ) -> bool:
        """Check that every attribute has the shape expected for the given model."""

        from jaxsim.simulation.utils import check_valid_shape

        n_dofs = physics_model.dofs()

        # Attribute name -> expected shape, checked in this order
        expected = (
            ("joint_positions", (n_dofs,)),
            ("joint_velocities", (n_dofs,)),
            ("base_position", (3,)),
            ("base_quaternion", (4,)),
            ("base_linear_velocity", (3,)),
            ("base_angular_velocity", (3,)),
        )

        is_valid = True

        for attribute_name, expected_shape in expected:
            is_valid = check_valid_shape(
                what=attribute_name,
                shape=getattr(self, attribute_name).shape,
                expected_shape=expected_shape,
                valid=is_valid,
            )

        return is_valid
128
-
129
-
130
@jax_dataclasses.pytree_dataclass
class PhysicsModelInput(JaxsimDataclass):
    """
    Input of a physics model.

    Attributes:
        tau: The joint torques, one per degree of freedom.
        f_ext: The link external forces, with shape (NB, 6).
    """

    tau: jtp.VectorJax
    f_ext: jtp.MatrixJax

    @staticmethod
    def zero(
        physics_model: "jaxsim.physics.model.physics_model.PhysicsModel",
    ) -> "PhysicsModelInput":
        """Build an input with zero joint torques and zero link external forces."""

        ode_input = PhysicsModelInput(
            tau=jnp.zeros(physics_model.dofs()),
            f_ext=jnp.zeros(shape=(physics_model.NB, 6)),
        )

        assert ode_input.valid(physics_model)
        return ode_input

    def replace(self, validate: bool = True, **kwargs) -> "PhysicsModelInput":
        """
        Return a copy of this input with the given attributes replaced.

        Args:
            validate: Forwarded to ``jax_dataclasses.copy_and_mutate``.
            **kwargs: The attributes to replace and their new values.

        Returns:
            The updated copy. The original object is not modified.
        """

        with jax_dataclasses.copy_and_mutate(self, validate=validate) as updated_input:
            for name, value in kwargs.items():
                setattr(updated_input, name, value)

        return updated_input

    def valid(
        self, physics_model: "jaxsim.physics.model.physics_model.PhysicsModel"
    ) -> bool:
        """Check that ``tau`` and ``f_ext`` have the shapes expected for the model."""

        from jaxsim.simulation.utils import check_valid_shape

        valid = True

        valid = check_valid_shape(
            what="tau",
            shape=self.tau.shape,
            expected_shape=(physics_model.dofs(),),
            valid=valid,
        )

        valid = check_valid_shape(
            what="f_ext",
            shape=self.f_ext.shape,
            expected_shape=(physics_model.NB, 6),
            valid=valid,
        )

        return valid
@@ -1,4 +0,0 @@
1
- from . import integrators, ode, ode_data, simulator
2
- from .ode_data import ODEInput, ODEState
3
- from .ode_integration import IntegratorType
4
- from .simulator import JaxSim, SimulatorData