jaxsim 0.1rc0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxsim/__init__.py +5 -6
- jaxsim/_version.py +2 -2
- jaxsim/api/__init__.py +3 -0
- jaxsim/api/com.py +240 -0
- jaxsim/api/common.py +216 -0
- jaxsim/api/contact.py +271 -0
- jaxsim/api/data.py +821 -0
- jaxsim/api/joint.py +189 -0
- jaxsim/api/kin_dyn_parameters.py +777 -0
- jaxsim/api/link.py +361 -0
- jaxsim/api/model.py +1633 -0
- jaxsim/api/ode.py +295 -0
- jaxsim/api/ode_data.py +694 -0
- jaxsim/api/references.py +421 -0
- jaxsim/integrators/__init__.py +2 -0
- jaxsim/integrators/common.py +594 -0
- jaxsim/integrators/fixed_step.py +102 -0
- jaxsim/integrators/variable_step.py +610 -0
- jaxsim/math/__init__.py +11 -0
- jaxsim/math/adjoint.py +24 -2
- jaxsim/math/joint_model.py +335 -0
- jaxsim/math/quaternion.py +44 -3
- jaxsim/math/rotation.py +4 -4
- jaxsim/math/transform.py +92 -0
- jaxsim/mujoco/__init__.py +3 -0
- jaxsim/mujoco/__main__.py +192 -0
- jaxsim/mujoco/loaders.py +615 -0
- jaxsim/mujoco/model.py +414 -0
- jaxsim/mujoco/visualizer.py +176 -0
- jaxsim/parsers/descriptions/collision.py +14 -0
- jaxsim/parsers/descriptions/link.py +13 -2
- jaxsim/parsers/kinematic_graph.py +8 -3
- jaxsim/parsers/rod/parser.py +54 -38
- jaxsim/parsers/rod/utils.py +7 -8
- jaxsim/rbda/__init__.py +7 -0
- jaxsim/rbda/aba.py +295 -0
- jaxsim/rbda/collidable_points.py +142 -0
- jaxsim/{physics/algos → rbda}/crba.py +43 -42
- jaxsim/rbda/forward_kinematics.py +113 -0
- jaxsim/rbda/jacobian.py +201 -0
- jaxsim/rbda/rnea.py +237 -0
- jaxsim/rbda/soft_contacts.py +296 -0
- jaxsim/rbda/utils.py +152 -0
- jaxsim/terrain/__init__.py +2 -0
- jaxsim/{physics/algos → terrain}/terrain.py +4 -6
- jaxsim/typing.py +30 -30
- jaxsim/utils/__init__.py +1 -4
- jaxsim/utils/hashless.py +18 -0
- jaxsim/utils/jaxsim_dataclass.py +281 -31
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
- jaxsim-0.2.0.dist-info/METADATA +237 -0
- jaxsim-0.2.0.dist-info/RECORD +64 -0
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
- jaxsim/high_level/__init__.py +0 -2
- jaxsim/high_level/common.py +0 -11
- jaxsim/high_level/joint.py +0 -148
- jaxsim/high_level/link.py +0 -259
- jaxsim/high_level/model.py +0 -1695
- jaxsim/math/conv.py +0 -114
- jaxsim/math/joint.py +0 -101
- jaxsim/math/plucker.py +0 -100
- jaxsim/physics/__init__.py +0 -12
- jaxsim/physics/algos/__init__.py +0 -0
- jaxsim/physics/algos/aba.py +0 -256
- jaxsim/physics/algos/aba_motors.py +0 -284
- jaxsim/physics/algos/forward_kinematics.py +0 -79
- jaxsim/physics/algos/jacobian.py +0 -98
- jaxsim/physics/algos/rnea.py +0 -180
- jaxsim/physics/algos/rnea_motors.py +0 -196
- jaxsim/physics/algos/soft_contacts.py +0 -454
- jaxsim/physics/algos/utils.py +0 -69
- jaxsim/physics/model/__init__.py +0 -0
- jaxsim/physics/model/ground_contact.py +0 -55
- jaxsim/physics/model/physics_model.py +0 -358
- jaxsim/physics/model/physics_model_state.py +0 -174
- jaxsim/simulation/__init__.py +0 -4
- jaxsim/simulation/integrators.py +0 -452
- jaxsim/simulation/ode.py +0 -290
- jaxsim/simulation/ode_data.py +0 -53
- jaxsim/simulation/ode_integration.py +0 -125
- jaxsim/simulation/simulator.py +0 -544
- jaxsim/simulation/simulator_callbacks.py +0 -53
- jaxsim/simulation/utils.py +0 -15
- jaxsim/sixd/__init__.py +0 -2
- jaxsim/utils/oop.py +0 -532
- jaxsim/utils/vmappable.py +0 -117
- jaxsim-0.1rc0.dist-info/METADATA +0 -167
- jaxsim-0.1rc0.dist-info/RECORD +0 -64
- {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/api/joint.py
ADDED
@@ -0,0 +1,189 @@
|
|
1
|
+
import functools
|
2
|
+
from typing import Sequence
|
3
|
+
|
4
|
+
import jax
|
5
|
+
import jax.numpy as jnp
|
6
|
+
import numpy as np
|
7
|
+
|
8
|
+
import jaxsim.api as js
|
9
|
+
import jaxsim.typing as jtp
|
10
|
+
|
11
|
+
# =======================
|
12
|
+
# Index-related functions
|
13
|
+
# =======================
|
14
|
+
|
15
|
+
|
16
|
+
@functools.partial(jax.jit, static_argnames="joint_name")
def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int:
    """
    Look up the index of a joint from its name.

    Args:
        model: The model whose joints are searched.
        joint_name: The name of the joint to look up.

    Returns:
        The index of the joint, or -1 if the model has no joint
        with the given name.
    """

    all_names = model.kin_dyn_parameters.joint_model.joint_names

    # Unknown names are reported with the sentinel -1 instead of raising.
    if joint_name not in all_names:
        return jnp.array(-1).astype(int)

    # Positions in the stored names at which the requested name occurs.
    matches = np.argwhere(np.array(all_names) == joint_name)

    # Note: the index of the joint for RBDAs starts from 1, but the index
    # for accessing the right element starts from 0, hence the shift by one.
    return jnp.array(matches - 1).squeeze().astype(int)
|
45
|
+
|
46
|
+
|
47
|
+
def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:
    """
    Look up the name of a joint from its index.

    Args:
        model: The model whose joints are searched.
        joint_index: The index of the joint.

    Returns:
        The name of the joint stored for the given index.
    """

    all_names = model.kin_dyn_parameters.joint_model.joint_names

    # The stored names are offset by one with respect to the joint index,
    # so the lookup position is shifted accordingly.
    position = joint_index + 1

    return all_names[position]
|
60
|
+
|
61
|
+
|
62
|
+
@functools.partial(jax.jit, static_argnames="joint_names")
def names_to_idxs(
    model: js.model.JaxSimModel, *, joint_names: Sequence[str]
) -> jax.Array:
    """
    Look up the indices of several joints from their names.

    Args:
        model: The model whose joints are searched.
        joint_names: The names of the joints, in the order in which
            their indices should be returned.

    Returns:
        An integer array with one index per requested name.
    """

    # Resolve each name individually, preserving the requested order.
    indices = []
    for joint_name in joint_names:
        indices.append(name_to_idx(model=model, joint_name=joint_name))

    return jnp.array(indices).astype(int)
|
80
|
+
|
81
|
+
|
82
|
+
def idxs_to_names(
    model: js.model.JaxSimModel,
    *,
    joint_indices: Sequence[jtp.IntLike] | jtp.VectorLike,
) -> tuple[str, ...]:
    """
    Look up the names of several joints from their indices.

    Args:
        model: The model whose joints are searched.
        joint_indices: The indices of the joints, in the order in which
            their names should be returned.

    Returns:
        A tuple with one name per requested index.
    """

    # Resolve each index individually, preserving the requested order.
    names = []
    for joint_index in joint_indices:
        names.append(idx_to_name(model=model, joint_index=joint_index))

    return tuple(names)
|
99
|
+
|
100
|
+
|
101
|
+
# ============
|
102
|
+
# Joint limits
|
103
|
+
# ============
|
104
|
+
|
105
|
+
|
106
|
+
@jax.jit
def position_limit(
    model: js.model.JaxSimModel, *, joint_index: jtp.IntLike
) -> tuple[jtp.Float, jtp.Float]:
    """
    Get the position limits of a joint.

    Args:
        model: The model to consider.
        joint_index: The index of the joint.

    Returns:
        A pair with the lower and the upper position limit of the joint.
        If the model has no joints, a pair of empty arrays is returned.
    """

    # Only a model without joints has no limits to report. The guard was
    # previously `<= 1`, which also discarded the limits of models having
    # exactly one joint.
    if model.number_of_joints() < 1:
        return jnp.empty(0).astype(float), jnp.empty(0).astype(float)

    # NOTE(review): assumes the limit arrays store one entry per joint,
    # addressed by the 0-based joint index -- confirm against the
    # definition of the joint parameters.
    joint_parameters = model.kin_dyn_parameters.joint_parameters
    s_min = joint_parameters.position_limits_min[joint_index]
    s_max = joint_parameters.position_limits_max[joint_index]

    return s_min.astype(float), s_max.astype(float)
|
128
|
+
|
129
|
+
|
130
|
+
@functools.partial(jax.jit, static_argnames=["joint_names"])
def position_limits(
    model: js.model.JaxSimModel, *, joint_names: Sequence[str] | None = None
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Get the position limits of a list of joints.

    Args:
        model: The model to consider.
        joint_names: The names of the joints. When omitted, the names
            returned by `model.joint_names()` are used.

    Returns:
        A pair with the lower and the upper position limits of the
        selected joints. If no joints are selected, a pair of empty
        arrays is returned.
    """

    if joint_names is None:
        joint_names = model.joint_names()

    # Nothing to look up: return empty limits without calling vmap.
    if len(joint_names) == 0:
        return jnp.empty(0).astype(float), jnp.empty(0).astype(float)

    selected_idxs = names_to_idxs(joint_names=joint_names, model=model)

    def limits_of(joint_index):
        # Limits of a single joint; vectorized over all the selected indices.
        return position_limit(model=model, joint_index=joint_index)

    return jax.vmap(limits_of)(selected_idxs)
|
152
|
+
|
153
|
+
|
154
|
+
# ======================
|
155
|
+
# Random data generation
|
156
|
+
# ======================
|
157
|
+
|
158
|
+
|
159
|
+
@functools.partial(jax.jit, static_argnames=["joint_names"])
def random_joint_positions(
    model: js.model.JaxSimModel,
    *,
    joint_names: Sequence[str] | None = None,
    key: jax.Array | None = None,
) -> jtp.Vector:
    """
    Sample joint positions uniformly between the joint position limits.

    Args:
        model: The model to consider.
        joint_names: The names of the joints, forwarded to `position_limits`.
        key: The random key. When omitted, a key built from the fixed
            seed 0 is used, so repeated calls return the same sample.

    Returns:
        The random joint positions, shaped like the lower position limits.
    """

    if key is None:
        key = jax.random.PRNGKey(seed=0)

    # The limits define the sampling interval of each joint.
    lower, upper = position_limits(model=model, joint_names=joint_names)

    return jax.random.uniform(
        key=key,
        shape=lower.shape,
        minval=lower,
        maxval=upper,
    )
|