jaxsim 0.2.dev191__py3-none-any.whl → 0.6.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. jaxsim/__init__.py +73 -22
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +13 -1
  4. jaxsim/api/com.py +423 -0
  5. jaxsim/api/common.py +48 -19
  6. jaxsim/api/contact.py +604 -52
  7. jaxsim/api/data.py +308 -163
  8. jaxsim/api/frame.py +471 -0
  9. jaxsim/api/joint.py +166 -37
  10. jaxsim/api/kin_dyn_parameters.py +901 -0
  11. jaxsim/api/link.py +277 -78
  12. jaxsim/api/model.py +1572 -362
  13. jaxsim/api/ode.py +324 -133
  14. jaxsim/api/ode_data.py +401 -0
  15. jaxsim/api/references.py +216 -80
  16. jaxsim/exceptions.py +80 -0
  17. jaxsim/integrators/__init__.py +2 -2
  18. jaxsim/integrators/common.py +191 -107
  19. jaxsim/integrators/fixed_step.py +97 -102
  20. jaxsim/integrators/variable_step.py +706 -0
  21. jaxsim/logging.py +1 -2
  22. jaxsim/math/__init__.py +13 -0
  23. jaxsim/math/adjoint.py +57 -22
  24. jaxsim/math/cross.py +16 -7
  25. jaxsim/math/inertia.py +10 -8
  26. jaxsim/math/joint_model.py +289 -0
  27. jaxsim/math/quaternion.py +54 -20
  28. jaxsim/math/rotation.py +27 -21
  29. jaxsim/math/skew.py +16 -5
  30. jaxsim/math/transform.py +102 -0
  31. jaxsim/math/utils.py +31 -0
  32. jaxsim/mujoco/__init__.py +2 -1
  33. jaxsim/mujoco/loaders.py +216 -29
  34. jaxsim/mujoco/model.py +163 -33
  35. jaxsim/mujoco/utils.py +228 -0
  36. jaxsim/mujoco/visualizer.py +107 -22
  37. jaxsim/parsers/__init__.py +0 -1
  38. jaxsim/parsers/descriptions/__init__.py +8 -2
  39. jaxsim/parsers/descriptions/collision.py +83 -26
  40. jaxsim/parsers/descriptions/joint.py +80 -87
  41. jaxsim/parsers/descriptions/link.py +58 -31
  42. jaxsim/parsers/descriptions/model.py +101 -68
  43. jaxsim/parsers/kinematic_graph.py +606 -229
  44. jaxsim/parsers/rod/meshes.py +104 -0
  45. jaxsim/parsers/rod/parser.py +125 -82
  46. jaxsim/parsers/rod/utils.py +127 -82
  47. jaxsim/rbda/__init__.py +11 -0
  48. jaxsim/rbda/aba.py +289 -0
  49. jaxsim/rbda/collidable_points.py +156 -0
  50. jaxsim/rbda/contacts/__init__.py +13 -0
  51. jaxsim/rbda/contacts/common.py +313 -0
  52. jaxsim/rbda/contacts/relaxed_rigid.py +605 -0
  53. jaxsim/rbda/contacts/rigid.py +462 -0
  54. jaxsim/rbda/contacts/soft.py +480 -0
  55. jaxsim/rbda/contacts/visco_elastic.py +1066 -0
  56. jaxsim/rbda/crba.py +167 -0
  57. jaxsim/rbda/forward_kinematics.py +117 -0
  58. jaxsim/rbda/jacobian.py +330 -0
  59. jaxsim/rbda/rnea.py +235 -0
  60. jaxsim/rbda/utils.py +160 -0
  61. jaxsim/terrain/__init__.py +2 -0
  62. jaxsim/terrain/terrain.py +238 -0
  63. jaxsim/typing.py +24 -24
  64. jaxsim/utils/__init__.py +1 -4
  65. jaxsim/utils/jaxsim_dataclass.py +289 -34
  66. jaxsim/utils/tracing.py +5 -11
  67. jaxsim/utils/wrappers.py +159 -0
  68. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/LICENSE +1 -1
  69. jaxsim-0.6.1.dev2.dist-info/METADATA +465 -0
  70. jaxsim-0.6.1.dev2.dist-info/RECORD +74 -0
  71. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/WHEEL +1 -1
  72. jaxsim/high_level/__init__.py +0 -2
  73. jaxsim/high_level/common.py +0 -11
  74. jaxsim/high_level/joint.py +0 -148
  75. jaxsim/high_level/link.py +0 -259
  76. jaxsim/high_level/model.py +0 -1686
  77. jaxsim/math/conv.py +0 -114
  78. jaxsim/math/joint.py +0 -102
  79. jaxsim/math/plucker.py +0 -100
  80. jaxsim/physics/__init__.py +0 -12
  81. jaxsim/physics/algos/__init__.py +0 -0
  82. jaxsim/physics/algos/aba.py +0 -254
  83. jaxsim/physics/algos/aba_motors.py +0 -284
  84. jaxsim/physics/algos/crba.py +0 -154
  85. jaxsim/physics/algos/forward_kinematics.py +0 -79
  86. jaxsim/physics/algos/jacobian.py +0 -98
  87. jaxsim/physics/algos/rnea.py +0 -180
  88. jaxsim/physics/algos/rnea_motors.py +0 -196
  89. jaxsim/physics/algos/soft_contacts.py +0 -523
  90. jaxsim/physics/algos/terrain.py +0 -78
  91. jaxsim/physics/algos/utils.py +0 -69
  92. jaxsim/physics/model/__init__.py +0 -0
  93. jaxsim/physics/model/ground_contact.py +0 -53
  94. jaxsim/physics/model/physics_model.py +0 -388
  95. jaxsim/physics/model/physics_model_state.py +0 -283
  96. jaxsim/simulation/__init__.py +0 -4
  97. jaxsim/simulation/integrators.py +0 -393
  98. jaxsim/simulation/ode.py +0 -290
  99. jaxsim/simulation/ode_data.py +0 -96
  100. jaxsim/simulation/ode_integration.py +0 -62
  101. jaxsim/simulation/simulator.py +0 -543
  102. jaxsim/simulation/simulator_callbacks.py +0 -79
  103. jaxsim/simulation/utils.py +0 -15
  104. jaxsim/sixd/__init__.py +0 -2
  105. jaxsim/utils/oop.py +0 -536
  106. jaxsim/utils/vmappable.py +0 -117
  107. jaxsim-0.2.dev191.dist-info/METADATA +0 -184
  108. jaxsim-0.2.dev191.dist-info/RECORD +0 -81
  109. {jaxsim-0.2.dev191.dist-info → jaxsim-0.6.1.dev2.dist-info}/top_level.txt +0 -0
jaxsim/api/joint.py CHANGED
@@ -1,19 +1,21 @@
1
1
  import functools
2
- from typing import Sequence
2
+ from collections.abc import Sequence
3
3
 
4
4
  import jax
5
5
  import jax.numpy as jnp
6
6
 
7
+ import jaxsim.api as js
7
8
  import jaxsim.typing as jtp
8
-
9
- from . import model as Model
9
+ from jaxsim import exceptions
10
10
 
11
11
  # =======================
12
12
  # Index-related functions
13
13
  # =======================
14
14
 
15
15
 
16
- def name_to_idx(model: Model.JaxSimModel, *, joint_name: str) -> jtp.Int:
16
+ @functools.partial(jax.jit, static_argnames="joint_name")
17
+ @js.common.named_scope
18
+ def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int:
17
19
  """
18
20
  Convert the name of a joint to its index.
19
21
 
@@ -25,12 +27,21 @@ def name_to_idx(model: Model.JaxSimModel, *, joint_name: str) -> jtp.Int:
25
27
  The index of the joint.
26
28
  """
27
29
 
28
- return jnp.array(
29
- model.physics_model.description.joints_dict[joint_name].index, dtype=int
30
+ if joint_name not in model.joint_names():
31
+ raise ValueError(f"Joint '{joint_name}' not found in the model.")
32
+
33
+ # Note: the index of the joint for RBDAs starts from 1, but the index for
34
+ # accessing the right element starts from 0. Therefore, there is a -1.
35
+ return (
36
+ jnp.array(
37
+ model.kin_dyn_parameters.joint_model.joint_names.index(joint_name) - 1
38
+ )
39
+ .astype(int)
40
+ .squeeze()
30
41
  )
31
42
 
32
43
 
33
- def idx_to_name(model: Model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:
44
+ def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:
34
45
  """
35
46
  Convert the index of a joint to its name.
36
47
 
@@ -42,11 +53,20 @@ def idx_to_name(model: Model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:
42
53
  The name of the joint.
43
54
  """
44
55
 
45
- d = {j.index: j.name for j in model.physics_model.description.joints_dict.values()}
46
- return d[joint_index]
56
+ exceptions.raise_value_error_if(
57
+ condition=joint_index < 0,
58
+ msg="Invalid joint index '{idx}'",
59
+ idx=joint_index,
60
+ )
61
+
62
+ return model.kin_dyn_parameters.joint_model.joint_names[joint_index + 1]
47
63
 
48
64
 
49
- def names_to_idxs(model: Model.JaxSimModel, *, joint_names: Sequence[str]) -> jax.Array:
65
+ @functools.partial(jax.jit, static_argnames="joint_names")
66
+ @js.common.named_scope
67
+ def names_to_idxs(
68
+ model: js.model.JaxSimModel, *, joint_names: Sequence[str]
69
+ ) -> jax.Array:
50
70
  """
51
71
  Convert a sequence of joint names to their corresponding indices.
52
72
 
@@ -59,19 +79,14 @@ def names_to_idxs(model: Model.JaxSimModel, *, joint_names: Sequence[str]) -> ja
59
79
  """
60
80
 
61
81
  return jnp.array(
62
- [
63
- # Note: the index of the joint for RBDAs starts from 1, but
64
- # the index for accessing the right element starts from 0.
65
- # Therefore, there is a -1.
66
- model.physics_model.description.joints_dict[name].index - 1
67
- for name in joint_names
68
- ],
69
- dtype=int,
70
- )
82
+ [name_to_idx(model=model, joint_name=name) for name in joint_names],
83
+ ).astype(int)
71
84
 
72
85
 
73
86
  def idxs_to_names(
74
- model: Model.JaxSimModel, *, joint_indices: Sequence[jtp.IntLike] | jtp.VectorLike
87
+ model: js.model.JaxSimModel,
88
+ *,
89
+ joint_indices: Sequence[jtp.IntLike] | jtp.VectorLike,
75
90
  ) -> tuple[str, ...]:
76
91
  """
77
92
  Convert a sequence of joint indices to their corresponding names.
@@ -84,12 +99,7 @@ def idxs_to_names(
84
99
  The names of the joints.
85
100
  """
86
101
 
87
- d = {
88
- j.index - 1: j.name
89
- for j in model.physics_model.description.joints_dict.values()
90
- }
91
-
92
- return tuple(d[i] for i in joint_indices)
102
+ return tuple(idx_to_name(model=model, joint_index=idx) for idx in joint_indices)
93
103
 
94
104
 
95
105
  # ============
@@ -99,25 +109,69 @@ def idxs_to_names(
99
109
 
100
110
  @jax.jit
101
111
  def position_limit(
102
- model: Model.JaxSimModel, *, joint_index: jtp.IntLike
112
+ model: js.model.JaxSimModel, *, joint_index: jtp.IntLike
103
113
  ) -> tuple[jtp.Float, jtp.Float]:
104
- """"""
114
+ """
115
+ Get the position limits of a joint.
105
116
 
106
- min = model.physics_model._joint_position_limits_min[joint_index]
107
- max = model.physics_model._joint_position_limits_max[joint_index]
117
+ Args:
118
+ model: The model to consider.
119
+ joint_index: The index of the joint.
108
120
 
109
- return min.astype(float), max.astype(float)
121
+ Returns:
122
+ The position limits of the joint.
123
+ """
124
+
125
+ if model.number_of_joints() == 0:
126
+ return jnp.empty(0).astype(float), jnp.empty(0).astype(float)
127
+
128
+ exceptions.raise_value_error_if(
129
+ condition=jnp.array(
130
+ [joint_index < 0, joint_index >= model.number_of_joints()]
131
+ ).any(),
132
+ msg="Invalid joint index '{idx}'",
133
+ idx=joint_index,
134
+ )
135
+
136
+ s_min = jnp.atleast_1d(
137
+ model.kin_dyn_parameters.joint_parameters.position_limits_min
138
+ )[joint_index]
139
+ s_max = jnp.atleast_1d(
140
+ model.kin_dyn_parameters.joint_parameters.position_limits_max
141
+ )[joint_index]
142
+
143
+ return s_min.astype(float), s_max.astype(float)
110
144
 
111
145
 
112
146
  @functools.partial(jax.jit, static_argnames=["joint_names"])
147
+ @js.common.named_scope
113
148
  def position_limits(
114
- model: Model.JaxSimModel, *, joint_names: Sequence[str] | None = None
149
+ model: js.model.JaxSimModel, *, joint_names: Sequence[str] | None = None
115
150
  ) -> tuple[jtp.Vector, jtp.Vector]:
151
+ """
152
+ Get the position limits of a list of joint.
116
153
 
117
- joint_names = joint_names if joint_names is not None else model.joint_names()
154
+ Args:
155
+ model: The model to consider.
156
+ joint_names: The names of the joints.
157
+
158
+ Returns:
159
+ The position limits of the joints.
160
+ """
161
+
162
+ joint_idxs = (
163
+ names_to_idxs(joint_names=joint_names, model=model)
164
+ if joint_names is not None
165
+ else jnp.arange(model.number_of_joints())
166
+ )
167
+
168
+ if len(joint_idxs) == 0:
169
+ return jnp.empty(0).astype(float), jnp.empty(0).astype(float)
118
170
 
119
- joint_idxs = names_to_idxs(joint_names=joint_names, model=model)
120
- return jax.vmap(lambda i: position_limit(model=model, joint_index=i))(joint_idxs)
171
+ s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min[joint_idxs]
172
+ s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max[joint_idxs]
173
+
174
+ return s_min.astype(float), s_max.astype(float)
121
175
 
122
176
 
123
177
  # ======================
@@ -126,18 +180,93 @@ def position_limits(
126
180
 
127
181
 
128
182
  @functools.partial(jax.jit, static_argnames=["joint_names"])
183
+ @js.common.named_scope
129
184
  def random_joint_positions(
130
- model: Model.JaxSimModel,
185
+ model: js.model.JaxSimModel,
131
186
  *,
132
187
  joint_names: Sequence[str] | None = None,
133
188
  key: jax.Array | None = None,
134
189
  ) -> jtp.Vector:
135
- """"""
190
+ """
191
+ Generate random joint positions.
192
+
193
+ Args:
194
+ model: The model to consider.
195
+ joint_names: The names of the considered joints (all if None).
196
+ key: The random key (initialized from seed 0 if None).
136
197
 
198
+ Note:
199
+ If the joint range or revolute joints is larger than 2π, their joint positions
200
+ will be sampled from an interval of size 2π.
201
+
202
+ Returns:
203
+ The random joint positions.
204
+ """
205
+
206
+ # Consider the key corresponding to a zero seed if it was not passed.
137
207
  key = key if key is not None else jax.random.PRNGKey(seed=0)
138
208
 
209
+ # Get the joint limits parsed from the model description.
139
210
  s_min, s_max = position_limits(model=model, joint_names=joint_names)
140
211
 
212
+ # Get the joint indices.
213
+ # Note that it will trigger an exception if the given `joint_names` are not valid.
214
+ joint_names = joint_names if joint_names is not None else model.joint_names()
215
+ joint_indices = (
216
+ names_to_idxs(model=model, joint_names=joint_names)
217
+ if joint_names is not None
218
+ else jnp.arange(model.number_of_joints())
219
+ )
220
+
221
+ from jaxsim.parsers.descriptions.joint import JointType
222
+
223
+ # Filter for revolute joints.
224
+ is_revolute = jnp.where(
225
+ jnp.array(model.kin_dyn_parameters.joint_model.joint_types[1:])[joint_indices]
226
+ == JointType.Revolute,
227
+ True,
228
+ False,
229
+ )
230
+
231
+ # Shorthand for π.
232
+ π = jnp.pi
233
+
234
+ # Filter for revolute with full range (or continuous).
235
+ is_revolute_full_range = jnp.logical_and(is_revolute, s_max - s_min >= 2 * π)
236
+
237
+ # Clip the lower limit to -π if the joint range is larger than [-π, π].
238
+ s_min = jnp.where(
239
+ jnp.logical_and(
240
+ is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
241
+ ),
242
+ -π,
243
+ s_min,
244
+ )
245
+
246
+ # Clip the upper limit to +π if the joint range is larger than [-π, π].
247
+ s_max = jnp.where(
248
+ jnp.logical_and(
249
+ is_revolute_full_range, jnp.logical_and(s_min <= -π, s_max >= π)
250
+ ),
251
+ π,
252
+ s_max,
253
+ )
254
+
255
+ # Shift the lower limit if the upper limit is smaller than +π.
256
+ s_min = jnp.where(
257
+ jnp.logical_and(is_revolute_full_range, s_max < π),
258
+ s_max - 2 * π,
259
+ s_min,
260
+ )
261
+
262
+ # Shift the upper limit if the lower limit is larger than -π.
263
+ s_max = jnp.where(
264
+ jnp.logical_and(is_revolute_full_range, s_min > -π),
265
+ s_min + 2 * π,
266
+ s_max,
267
+ )
268
+
269
+ # Sample the joint positions.
141
270
  s_random = jax.random.uniform(
142
271
  minval=s_min,
143
272
  maxval=s_max,