jaxsim 0.1rc0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89) hide show
  1. jaxsim/__init__.py +5 -6
  2. jaxsim/_version.py +2 -2
  3. jaxsim/api/__init__.py +3 -0
  4. jaxsim/api/com.py +240 -0
  5. jaxsim/api/common.py +216 -0
  6. jaxsim/api/contact.py +271 -0
  7. jaxsim/api/data.py +821 -0
  8. jaxsim/api/joint.py +189 -0
  9. jaxsim/api/kin_dyn_parameters.py +777 -0
  10. jaxsim/api/link.py +361 -0
  11. jaxsim/api/model.py +1633 -0
  12. jaxsim/api/ode.py +295 -0
  13. jaxsim/api/ode_data.py +694 -0
  14. jaxsim/api/references.py +421 -0
  15. jaxsim/integrators/__init__.py +2 -0
  16. jaxsim/integrators/common.py +594 -0
  17. jaxsim/integrators/fixed_step.py +102 -0
  18. jaxsim/integrators/variable_step.py +610 -0
  19. jaxsim/math/__init__.py +11 -0
  20. jaxsim/math/adjoint.py +24 -2
  21. jaxsim/math/joint_model.py +335 -0
  22. jaxsim/math/quaternion.py +44 -3
  23. jaxsim/math/rotation.py +4 -4
  24. jaxsim/math/transform.py +92 -0
  25. jaxsim/mujoco/__init__.py +3 -0
  26. jaxsim/mujoco/__main__.py +192 -0
  27. jaxsim/mujoco/loaders.py +615 -0
  28. jaxsim/mujoco/model.py +414 -0
  29. jaxsim/mujoco/visualizer.py +176 -0
  30. jaxsim/parsers/descriptions/collision.py +14 -0
  31. jaxsim/parsers/descriptions/link.py +13 -2
  32. jaxsim/parsers/kinematic_graph.py +8 -3
  33. jaxsim/parsers/rod/parser.py +54 -38
  34. jaxsim/parsers/rod/utils.py +7 -8
  35. jaxsim/rbda/__init__.py +7 -0
  36. jaxsim/rbda/aba.py +295 -0
  37. jaxsim/rbda/collidable_points.py +142 -0
  38. jaxsim/{physics/algos → rbda}/crba.py +43 -42
  39. jaxsim/rbda/forward_kinematics.py +113 -0
  40. jaxsim/rbda/jacobian.py +201 -0
  41. jaxsim/rbda/rnea.py +237 -0
  42. jaxsim/rbda/soft_contacts.py +296 -0
  43. jaxsim/rbda/utils.py +152 -0
  44. jaxsim/terrain/__init__.py +2 -0
  45. jaxsim/{physics/algos → terrain}/terrain.py +4 -6
  46. jaxsim/typing.py +30 -30
  47. jaxsim/utils/__init__.py +1 -4
  48. jaxsim/utils/hashless.py +18 -0
  49. jaxsim/utils/jaxsim_dataclass.py +281 -31
  50. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/LICENSE +1 -1
  51. jaxsim-0.2.0.dist-info/METADATA +237 -0
  52. jaxsim-0.2.0.dist-info/RECORD +64 -0
  53. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/WHEEL +1 -1
  54. jaxsim/high_level/__init__.py +0 -2
  55. jaxsim/high_level/common.py +0 -11
  56. jaxsim/high_level/joint.py +0 -148
  57. jaxsim/high_level/link.py +0 -259
  58. jaxsim/high_level/model.py +0 -1695
  59. jaxsim/math/conv.py +0 -114
  60. jaxsim/math/joint.py +0 -101
  61. jaxsim/math/plucker.py +0 -100
  62. jaxsim/physics/__init__.py +0 -12
  63. jaxsim/physics/algos/__init__.py +0 -0
  64. jaxsim/physics/algos/aba.py +0 -256
  65. jaxsim/physics/algos/aba_motors.py +0 -284
  66. jaxsim/physics/algos/forward_kinematics.py +0 -79
  67. jaxsim/physics/algos/jacobian.py +0 -98
  68. jaxsim/physics/algos/rnea.py +0 -180
  69. jaxsim/physics/algos/rnea_motors.py +0 -196
  70. jaxsim/physics/algos/soft_contacts.py +0 -454
  71. jaxsim/physics/algos/utils.py +0 -69
  72. jaxsim/physics/model/__init__.py +0 -0
  73. jaxsim/physics/model/ground_contact.py +0 -55
  74. jaxsim/physics/model/physics_model.py +0 -358
  75. jaxsim/physics/model/physics_model_state.py +0 -174
  76. jaxsim/simulation/__init__.py +0 -4
  77. jaxsim/simulation/integrators.py +0 -452
  78. jaxsim/simulation/ode.py +0 -290
  79. jaxsim/simulation/ode_data.py +0 -53
  80. jaxsim/simulation/ode_integration.py +0 -125
  81. jaxsim/simulation/simulator.py +0 -544
  82. jaxsim/simulation/simulator_callbacks.py +0 -53
  83. jaxsim/simulation/utils.py +0 -15
  84. jaxsim/sixd/__init__.py +0 -2
  85. jaxsim/utils/oop.py +0 -532
  86. jaxsim/utils/vmappable.py +0 -117
  87. jaxsim-0.1rc0.dist-info/METADATA +0 -167
  88. jaxsim-0.1rc0.dist-info/RECORD +0 -64
  89. {jaxsim-0.1rc0.dist-info → jaxsim-0.2.0.dist-info}/top_level.txt +0 -0
jaxsim/api/joint.py ADDED
@@ -0,0 +1,189 @@
1
+ import functools
2
+ from typing import Sequence
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import numpy as np
7
+
8
+ import jaxsim.api as js
9
+ import jaxsim.typing as jtp
10
+
11
+ # =======================
12
+ # Index-related functions
13
+ # =======================
14
+
15
+
16
@functools.partial(jax.jit, static_argnames="joint_name")
def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int:
    """
    Look up the index of a joint given its name.

    Args:
        model: The model to consider.
        joint_name: The name of the joint (a static argument of the jitted function).

    Returns:
        The index of the joint, or -1 if no joint of the model has this name.
    """

    all_names = model.kin_dyn_parameters.joint_model.joint_names

    # Guard clause: names that are not part of the model map to -1.
    if joint_name not in all_names:
        return jnp.array(-1).astype(int)

    # Positions of the entries matching the requested name.
    matches = np.argwhere(np.array(all_names) == joint_name)

    # Note: the index of the joint for RBDAs starts from 1, but the index
    # for accessing the right element starts from 0, hence the -1 shift.
    return jnp.array(matches - 1).squeeze().astype(int)
45
+
46
+
47
def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str:
    """
    Look up the name of a joint given its index.

    Args:
        model: The model to consider.
        joint_index: The index of the joint.

    Returns:
        The name of the joint.
    """

    all_names = model.kin_dyn_parameters.joint_model.joint_names

    # The stored names are offset by one with respect to the joint index
    # (this mirrors the -1 applied when converting a name to an index).
    position = joint_index + 1

    return all_names[position]
60
+
61
+
62
@functools.partial(jax.jit, static_argnames="joint_names")
def names_to_idxs(
    model: js.model.JaxSimModel, *, joint_names: Sequence[str]
) -> jax.Array:
    """
    Look up the indices of several joints given their names.

    Args:
        model: The model to consider.
        joint_names: The names of the joints (a static argument of the jitted function).

    Returns:
        The indices of the joints, in the same order as the given names.
    """

    # Bind the model once so that only the name varies between lookups.
    lookup = functools.partial(name_to_idx, model=model)

    indices = [lookup(joint_name=name) for name in joint_names]

    return jnp.array(indices).astype(int)
80
+
81
+
82
def idxs_to_names(
    model: js.model.JaxSimModel,
    *,
    joint_indices: Sequence[jtp.IntLike] | jtp.VectorLike,
) -> tuple[str, ...]:
    """
    Look up the names of several joints given their indices.

    Args:
        model: The model to consider.
        joint_indices: The indices of the joints.

    Returns:
        The names of the joints, in the same order as the given indices.
    """

    names = []

    # Resolve the indices one at a time, preserving their order.
    for joint_index in joint_indices:
        names.append(idx_to_name(model=model, joint_index=joint_index))

    return tuple(names)
99
+
100
+
101
+ # ============
102
+ # Joint limits
103
+ # ============
104
+
105
+
106
@jax.jit
def position_limit(
    model: js.model.JaxSimModel, *, joint_index: jtp.IntLike
) -> tuple[jtp.Float, jtp.Float]:
    """
    Get the position limits of a joint.

    Args:
        model: The model to consider.
        joint_index: The index of the joint.

    Returns:
        A tuple with the lower and the upper position limit of the joint.
        If the model has no joints, both elements are empty arrays.
    """

    # Only a model without joints has no limits to report. A model with
    # exactly one joint must still return the limits of that joint, otherwise
    # vmapping this function over a single index yields empty limits.
    if model.number_of_joints() == 0:
        return jnp.empty(0).astype(float), jnp.empty(0).astype(float)

    joint_parameters = model.kin_dyn_parameters.joint_parameters

    # NOTE(review): atleast_1d keeps the indexing valid in case the limits of
    # a single-joint model are stored as 0-d arrays -- TODO confirm how
    # the joint parameters are built.
    s_min = jnp.atleast_1d(joint_parameters.position_limits_min)[joint_index]
    s_max = jnp.atleast_1d(joint_parameters.position_limits_max)[joint_index]

    return s_min.astype(float), s_max.astype(float)
128
+
129
+
130
@functools.partial(jax.jit, static_argnames=["joint_names"])
def position_limits(
    model: js.model.JaxSimModel, *, joint_names: Sequence[str] | None = None
) -> tuple[jtp.Vector, jtp.Vector]:
    """
    Get the position limits of a list of joints.

    Args:
        model: The model to consider.
        joint_names:
            The names of the joints. If None, the names returned by
            `model.joint_names()` are used.

    Returns:
        A tuple with the lower and the upper position limits of the joints.
        If there are no joints to consider, both elements are empty arrays.
    """

    if joint_names is None:
        joint_names = model.joint_names()

    # Nothing to vmap over when no joint is selected.
    if len(joint_names) == 0:
        return jnp.empty(0).astype(float), jnp.empty(0).astype(float)

    joint_idxs = names_to_idxs(model=model, joint_names=joint_names)

    def limits_of(index):
        # Limits of a single joint, mapped below over all selected indices.
        return position_limit(model=model, joint_index=index)

    return jax.vmap(limits_of)(joint_idxs)
152
+
153
+
154
+ # ======================
155
+ # Random data generation
156
+ # ======================
157
+
158
+
159
@functools.partial(jax.jit, static_argnames=["joint_names"])
def random_joint_positions(
    model: js.model.JaxSimModel,
    *,
    joint_names: Sequence[str] | None = None,
    key: jax.Array | None = None,
) -> jtp.Vector:
    """
    Sample random joint positions between the joint position limits.

    Args:
        model: The model to consider.
        joint_names: The names of the joints, forwarded to `position_limits`.
        key:
            The random key. If None, a key built from the fixed seed 0 is used,
            so repeated calls without a key return the same sample.

    Returns:
        The random joint positions, with the same shape as the lower limits.
    """

    if key is None:
        key = jax.random.PRNGKey(seed=0)

    lower, upper = position_limits(model=model, joint_names=joint_names)

    # Draw one uniform sample per joint inside its own limits.
    return jax.random.uniform(
        key=key,
        shape=lower.shape,
        minval=lower,
        maxval=upper,
    )