warp-lang 0.11.0__py3-none-manylinux2014_x86_64.whl → 1.0.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (170) hide show
  1. warp/__init__.py +8 -0
  2. warp/bin/warp-clang.so +0 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +7 -6
  5. warp/build_dll.py +70 -79
  6. warp/builtins.py +10 -6
  7. warp/codegen.py +51 -19
  8. warp/config.py +7 -8
  9. warp/constants.py +3 -0
  10. warp/context.py +948 -245
  11. warp/dlpack.py +198 -113
  12. warp/examples/assets/bunny.usd +0 -0
  13. warp/examples/assets/cartpole.urdf +110 -0
  14. warp/examples/assets/crazyflie.usd +0 -0
  15. warp/examples/assets/cube.usda +42 -0
  16. warp/examples/assets/nv_ant.xml +92 -0
  17. warp/examples/assets/nv_humanoid.xml +183 -0
  18. warp/examples/assets/quadruped.urdf +268 -0
  19. warp/examples/assets/rocks.nvdb +0 -0
  20. warp/examples/assets/rocks.usd +0 -0
  21. warp/examples/assets/sphere.usda +56 -0
  22. warp/examples/assets/torus.usda +105 -0
  23. warp/examples/benchmarks/benchmark_api.py +383 -0
  24. warp/examples/benchmarks/benchmark_cloth.py +279 -0
  25. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
  26. warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
  27. warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
  28. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
  29. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
  30. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
  31. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
  32. warp/examples/benchmarks/benchmark_launches.py +295 -0
  33. warp/examples/core/example_dem.py +221 -0
  34. warp/examples/core/example_fluid.py +267 -0
  35. warp/examples/core/example_graph_capture.py +129 -0
  36. warp/examples/core/example_marching_cubes.py +177 -0
  37. warp/examples/core/example_mesh.py +154 -0
  38. warp/examples/core/example_mesh_intersect.py +193 -0
  39. warp/examples/core/example_nvdb.py +169 -0
  40. warp/examples/core/example_raycast.py +89 -0
  41. warp/examples/core/example_raymarch.py +178 -0
  42. warp/examples/core/example_render_opengl.py +141 -0
  43. warp/examples/core/example_sph.py +389 -0
  44. warp/examples/core/example_torch.py +181 -0
  45. warp/examples/core/example_wave.py +249 -0
  46. warp/examples/fem/bsr_utils.py +380 -0
  47. warp/examples/fem/example_apic_fluid.py +391 -0
  48. warp/examples/fem/example_convection_diffusion.py +168 -0
  49. warp/examples/fem/example_convection_diffusion_dg.py +209 -0
  50. warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
  51. warp/examples/fem/example_deformed_geometry.py +159 -0
  52. warp/examples/fem/example_diffusion.py +173 -0
  53. warp/examples/fem/example_diffusion_3d.py +152 -0
  54. warp/examples/fem/example_diffusion_mgpu.py +214 -0
  55. warp/examples/fem/example_mixed_elasticity.py +222 -0
  56. warp/examples/fem/example_navier_stokes.py +243 -0
  57. warp/examples/fem/example_stokes.py +192 -0
  58. warp/examples/fem/example_stokes_transfer.py +249 -0
  59. warp/examples/fem/mesh_utils.py +109 -0
  60. warp/examples/fem/plot_utils.py +287 -0
  61. warp/examples/optim/example_bounce.py +248 -0
  62. warp/examples/optim/example_cloth_throw.py +210 -0
  63. warp/examples/optim/example_diffray.py +535 -0
  64. warp/examples/optim/example_drone.py +850 -0
  65. warp/examples/optim/example_inverse_kinematics.py +169 -0
  66. warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
  67. warp/examples/optim/example_spring_cage.py +234 -0
  68. warp/examples/optim/example_trajectory.py +201 -0
  69. warp/examples/sim/example_cartpole.py +128 -0
  70. warp/examples/sim/example_cloth.py +184 -0
  71. warp/examples/sim/example_granular.py +113 -0
  72. warp/examples/sim/example_granular_collision_sdf.py +185 -0
  73. warp/examples/sim/example_jacobian_ik.py +213 -0
  74. warp/examples/sim/example_particle_chain.py +106 -0
  75. warp/examples/sim/example_quadruped.py +179 -0
  76. warp/examples/sim/example_rigid_chain.py +191 -0
  77. warp/examples/sim/example_rigid_contact.py +176 -0
  78. warp/examples/sim/example_rigid_force.py +126 -0
  79. warp/examples/sim/example_rigid_gyroscopic.py +97 -0
  80. warp/examples/sim/example_rigid_soft_contact.py +124 -0
  81. warp/examples/sim/example_soft_body.py +178 -0
  82. warp/fabric.py +29 -20
  83. warp/fem/cache.py +0 -1
  84. warp/fem/dirichlet.py +0 -2
  85. warp/fem/integrate.py +0 -1
  86. warp/jax.py +45 -0
  87. warp/jax_experimental.py +339 -0
  88. warp/native/builtin.h +12 -0
  89. warp/native/bvh.cu +18 -18
  90. warp/native/clang/clang.cpp +8 -3
  91. warp/native/cuda_util.cpp +94 -5
  92. warp/native/cuda_util.h +35 -6
  93. warp/native/cutlass_gemm.cpp +1 -1
  94. warp/native/cutlass_gemm.cu +4 -1
  95. warp/native/error.cpp +66 -0
  96. warp/native/error.h +27 -0
  97. warp/native/mesh.cu +2 -2
  98. warp/native/reduce.cu +4 -4
  99. warp/native/runlength_encode.cu +2 -2
  100. warp/native/scan.cu +2 -2
  101. warp/native/sparse.cu +0 -1
  102. warp/native/temp_buffer.h +2 -2
  103. warp/native/warp.cpp +95 -60
  104. warp/native/warp.cu +1053 -218
  105. warp/native/warp.h +49 -32
  106. warp/optim/linear.py +33 -16
  107. warp/render/render_opengl.py +202 -101
  108. warp/render/render_usd.py +82 -40
  109. warp/sim/__init__.py +13 -4
  110. warp/sim/articulation.py +4 -5
  111. warp/sim/collide.py +320 -175
  112. warp/sim/import_mjcf.py +25 -30
  113. warp/sim/import_urdf.py +94 -63
  114. warp/sim/import_usd.py +51 -36
  115. warp/sim/inertia.py +3 -2
  116. warp/sim/integrator.py +233 -0
  117. warp/sim/integrator_euler.py +447 -469
  118. warp/sim/integrator_featherstone.py +1991 -0
  119. warp/sim/integrator_xpbd.py +1420 -640
  120. warp/sim/model.py +765 -487
  121. warp/sim/particles.py +2 -1
  122. warp/sim/render.py +35 -13
  123. warp/sim/utils.py +222 -11
  124. warp/stubs.py +8 -0
  125. warp/tape.py +16 -1
  126. warp/tests/aux_test_grad_customs.py +23 -0
  127. warp/tests/test_array.py +190 -1
  128. warp/tests/test_async.py +656 -0
  129. warp/tests/test_bool.py +50 -0
  130. warp/tests/test_dlpack.py +164 -11
  131. warp/tests/test_examples.py +166 -74
  132. warp/tests/test_fem.py +8 -1
  133. warp/tests/test_generics.py +15 -5
  134. warp/tests/test_grad.py +1 -1
  135. warp/tests/test_grad_customs.py +172 -12
  136. warp/tests/test_jax.py +254 -0
  137. warp/tests/test_large.py +29 -6
  138. warp/tests/test_launch.py +25 -0
  139. warp/tests/test_linear_solvers.py +20 -3
  140. warp/tests/test_matmul.py +61 -16
  141. warp/tests/test_matmul_lite.py +13 -13
  142. warp/tests/test_mempool.py +186 -0
  143. warp/tests/test_multigpu.py +3 -0
  144. warp/tests/test_options.py +16 -2
  145. warp/tests/test_peer.py +137 -0
  146. warp/tests/test_print.py +3 -1
  147. warp/tests/test_quat.py +23 -0
  148. warp/tests/test_sim_kinematics.py +97 -0
  149. warp/tests/test_snippet.py +126 -3
  150. warp/tests/test_streams.py +108 -79
  151. warp/tests/test_torch.py +16 -8
  152. warp/tests/test_utils.py +32 -27
  153. warp/tests/test_verify_fp.py +65 -0
  154. warp/tests/test_volume.py +1 -1
  155. warp/tests/unittest_serial.py +2 -0
  156. warp/tests/unittest_suites.py +12 -0
  157. warp/tests/unittest_utils.py +14 -7
  158. warp/thirdparty/unittest_parallel.py +15 -3
  159. warp/torch.py +10 -8
  160. warp/types.py +363 -246
  161. warp/utils.py +143 -19
  162. warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
  163. warp_lang-1.0.0.dist-info/METADATA +394 -0
  164. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
  165. warp/sim/optimizer.py +0 -138
  166. warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
  167. warp_lang-0.11.0.dist-info/METADATA +0 -238
  168. /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
  169. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
  170. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1991 @@
1
+ # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import warp as wp
9
+
10
+ from .model import Model, State, Control
11
+
12
+ from .integrator import Integrator
13
+
14
+ from .integrator_euler import (
15
+ eval_spring_forces,
16
+ eval_triangle_forces,
17
+ eval_triangle_contact_forces,
18
+ eval_bending_forces,
19
+ eval_tetrahedral_forces,
20
+ eval_particle_forces,
21
+ eval_particle_ground_contact_forces,
22
+ eval_particle_body_contact_forces,
23
+ eval_muscle_forces,
24
+ eval_rigid_contacts,
25
+ eval_joint_force,
26
+ )
27
+
28
+ from .articulation import (
29
+ compute_2d_rotational_dofs,
30
+ compute_3d_rotational_dofs,
31
+ )
32
+
33
+
34
# Frank & Park definition 3.20, pg 100
@wp.func
def transform_twist(t: wp.transform, x: wp.spatial_vector):
    """Carry the twist ``x`` through the rigid transform ``t``.

    The top half of ``x`` is rotated by the rotation of ``t``. The bottom half
    is rotated the same way and then offset by the cross product of the
    translation of ``t`` with the rotated top half.

    Returns:
        A spatial vector built from (rotated top, rotated bottom + p x top).
    """
    rot = wp.transform_get_rotation(t)
    trans = wp.transform_get_translation(t)

    # rotate both halves of the twist
    top_out = wp.quat_rotate(rot, wp.spatial_top(x))
    bottom_rot = wp.quat_rotate(rot, wp.spatial_bottom(x))

    # the bottom half additionally picks up the translation offset
    return wp.spatial_vector(top_out, bottom_rot + wp.cross(trans, top_out))
47
+
48
+
49
@wp.func
def transform_wrench(t: wp.transform, x: wp.spatial_vector):
    """Carry the wrench ``x`` through the rigid transform ``t``.

    The bottom half of ``x`` is rotated by the rotation of ``t``. The top half
    is rotated the same way and then offset by the cross product of the
    translation of ``t`` with the rotated bottom half.

    Returns:
        A spatial vector built from (rotated top + p x bottom, rotated bottom).
    """
    rot = wp.transform_get_rotation(t)
    trans = wp.transform_get_translation(t)

    # the bottom half is only rotated; it is needed first for the offset below
    bottom_out = wp.quat_rotate(rot, wp.spatial_bottom(x))
    top_rot = wp.quat_rotate(rot, wp.spatial_top(x))

    return wp.spatial_vector(top_rot + wp.cross(trans, bottom_out), bottom_out)
61
+
62
+
63
@wp.func
def spatial_adjoint(R: wp.mat33, S: wp.mat33):
    """Assemble the 6x6 block matrix ``[[R, 0], [S, R]]`` from two 3x3 blocks.

    ``R`` fills both diagonal blocks, ``S`` fills the lower-left block and the
    upper-right block is zero.
    """
    # read every entry once so the literal below reads as four 3x3 blocks
    r00 = R[0, 0]
    r01 = R[0, 1]
    r02 = R[0, 2]
    r10 = R[1, 0]
    r11 = R[1, 1]
    r12 = R[1, 2]
    r20 = R[2, 0]
    r21 = R[2, 1]
    r22 = R[2, 2]

    s00 = S[0, 0]
    s01 = S[0, 1]
    s02 = S[0, 2]
    s10 = S[1, 0]
    s11 = S[1, 1]
    s12 = S[1, 2]
    s20 = S[2, 0]
    s21 = S[2, 1]
    s22 = S[2, 2]

    # fmt: off
    return wp.spatial_matrix(
        r00, r01, r02, 0.0, 0.0, 0.0,
        r10, r11, r12, 0.0, 0.0, 0.0,
        r20, r21, r22, 0.0, 0.0, 0.0,
        s00, s01, s02, r00, r01, r02,
        s10, s11, s12, r10, r11, r12,
        s20, s21, s22, r20, r21, r22,
    )
    # fmt: on
78
+
79
+
80
@wp.kernel
def compute_spatial_inertia(
    body_inertia: wp.array(dtype=wp.mat33),
    body_mass: wp.array(dtype=float),
    # outputs
    body_I_m: wp.array(dtype=wp.spatial_matrix),
):
    """Build one 6x6 spatial inertia matrix per thread index.

    The upper-left 3x3 block is the entry of ``body_inertia``, the lower-right
    3x3 block is the entry of ``body_mass`` times the identity, and both
    off-diagonal blocks are zero. The result is written to ``body_I_m``.
    """
    body = wp.tid()
    inertia = body_inertia[body]
    mass = body_mass[body]
    # fmt: off
    body_I_m[body] = wp.spatial_matrix(
        inertia[0, 0], inertia[0, 1], inertia[0, 2], 0.0, 0.0, 0.0,
        inertia[1, 0], inertia[1, 1], inertia[1, 2], 0.0, 0.0, 0.0,
        inertia[2, 0], inertia[2, 1], inertia[2, 2], 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, mass, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, mass, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, mass,
    )
    # fmt: on
100
+
101
+
102
@wp.kernel
def compute_com_transforms(
    body_com: wp.array(dtype=wp.vec3),
    # outputs
    body_X_com: wp.array(dtype=wp.transform),
):
    """Turn each entry of ``body_com`` into a rotation-free transform.

    For every thread index, writes a transform to ``body_X_com`` whose
    translation is the ``body_com`` entry and whose rotation is the identity.
    """
    body = wp.tid()
    body_X_com[body] = wp.transform(body_com[body], wp.quat_identity())
111
+
112
+
113
# computes adj_t^-T*I*adj_t^-1 (tensor change of coordinates), Frank & Park, section 8.2.3, pg 290
@wp.func
def spatial_transform_inertia(t: wp.transform, I: wp.spatial_matrix):
    """Re-express the spatial inertia ``I`` using the inverse of transform ``t``.

    Builds the 6x6 matrix ``T`` from the rotation and translation of the
    inverse of ``t`` (via ``spatial_adjoint``) and returns ``T^T * I * T``.
    """
    inv = wp.transform_inverse(t)
    rot = wp.transform_get_rotation(inv)
    trans = wp.transform_get_translation(inv)

    # rotation block: the three rotated unit axes handed to wp.mat33 in x, y, z order
    R = wp.mat33(
        wp.quat_rotate(rot, wp.vec3(1.0, 0.0, 0.0)),
        wp.quat_rotate(rot, wp.vec3(0.0, 1.0, 0.0)),
        wp.quat_rotate(rot, wp.vec3(0.0, 0.0, 1.0)),
    )

    # lower-left block couples the translation with the rotation
    T = spatial_adjoint(R, wp.mul(wp.skew(trans), R))

    TtI = wp.mul(wp.transpose(T), I)
    return wp.mul(TtI, T)
131
+
132
+
133
# compute transform across a joint
@wp.func
def jcalc_transform(
    type: int,
    joint_axis: wp.array(dtype=wp.vec3),
    axis_start: int,
    lin_axis_count: int,
    ang_axis_count: int,
    joint_q: wp.array(dtype=float),
    start: int,
):
    """Return the transform across a joint for its current coordinates.

    Args:
        type: joint type, compared against the ``wp.sim.JOINT_*`` constants
        joint_axis: axis array, read starting at ``axis_start``
        axis_start: index of the joint's first entry in ``joint_axis``
        lin_axis_count: number of linear axes (only read by the D6 branch)
        ang_axis_count: number of angular axes (only read by the D6 branch)
        joint_q: joint coordinate array, read starting at ``start``
        start: index of the joint's first entry in ``joint_q``

    Returns:
        The joint transform. Fixed joints and any joint type not handled
        below yield the identity transform.
    """
    if type == wp.sim.JOINT_PRISMATIC:
        # pure translation: the axis scaled by the single joint coordinate
        return wp.transform(joint_axis[axis_start] * joint_q[start], wp.quat_identity())

    if type == wp.sim.JOINT_REVOLUTE:
        # pure rotation about the axis by the single joint coordinate
        return wp.transform(wp.vec3(), wp.quat_from_axis_angle(joint_axis[axis_start], joint_q[start]))

    if type == wp.sim.JOINT_BALL:
        # the four coordinates are the quaternion components in x, y, z, w order
        ball_rot = wp.quat(joint_q[start + 0], joint_q[start + 1], joint_q[start + 2], joint_q[start + 3])
        return wp.transform(wp.vec3(), ball_rot)

    if type == wp.sim.JOINT_FIXED:
        return wp.transform_identity()

    if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
        # seven coordinates: translation (3) followed by quaternion x, y, z, w (4)
        free_pos = wp.vec3(joint_q[start + 0], joint_q[start + 1], joint_q[start + 2])
        free_rot = wp.quat(joint_q[start + 3], joint_q[start + 4], joint_q[start + 5], joint_q[start + 6])
        return wp.transform(free_pos, free_rot)

    if type == wp.sim.JOINT_COMPOUND:
        # three rotational dofs; the velocity inputs are zeroed and the second output is unused
        rot, _ = compute_3d_rotational_dofs(
            joint_axis[axis_start],
            joint_axis[axis_start + 1],
            joint_axis[axis_start + 2],
            joint_q[start + 0],
            joint_q[start + 1],
            joint_q[start + 2],
            0.0,
            0.0,
            0.0,
        )
        return wp.transform(wp.vec3(), rot)

    if type == wp.sim.JOINT_UNIVERSAL:
        # two rotational dofs; the velocity inputs are zeroed and the second output is unused
        rot, _ = compute_2d_rotational_dofs(
            joint_axis[axis_start],
            joint_axis[axis_start + 1],
            joint_q[start + 0],
            joint_q[start + 1],
            0.0,
            0.0,
        )
        return wp.transform(wp.vec3(), rot)

    if type == wp.sim.JOINT_D6:
        pos = wp.vec3(0.0)
        rot = wp.quat_identity()

        # the linear axes are handled by explicit branches instead of a loop so
        # that joint actions remain differentiable (differentiating through a
        # for loop that updates a local variable is not supported)
        if lin_axis_count > 0:
            pos += joint_axis[axis_start + 0] * joint_q[start + 0]
        if lin_axis_count > 1:
            pos += joint_axis[axis_start + 1] * joint_q[start + 1]
        if lin_axis_count > 2:
            pos += joint_axis[axis_start + 2] * joint_q[start + 2]

        # angular axes and coordinates are stored right after the linear ones
        ang_axis_start = axis_start + lin_axis_count
        ang_q_start = start + lin_axis_count
        if ang_axis_count == 1:
            rot = wp.quat_from_axis_angle(joint_axis[ang_axis_start], joint_q[ang_q_start])
        if ang_axis_count == 2:
            rot, _ = compute_2d_rotational_dofs(
                joint_axis[ang_axis_start + 0],
                joint_axis[ang_axis_start + 1],
                joint_q[ang_q_start + 0],
                joint_q[ang_q_start + 1],
                0.0,
                0.0,
            )
        if ang_axis_count == 3:
            rot, _ = compute_3d_rotational_dofs(
                joint_axis[ang_axis_start + 0],
                joint_axis[ang_axis_start + 1],
                joint_axis[ang_axis_start + 2],
                joint_q[ang_q_start + 0],
                joint_q[ang_q_start + 1],
                joint_q[ang_q_start + 2],
                0.0,
                0.0,
                0.0,
            )

        return wp.transform(pos, rot)

    # default case
    return wp.transform_identity()
260
+
261
+
262
# compute motion subspace and velocity for a joint
@wp.func
def jcalc_motion(
    type: int,
    joint_axis: wp.array(dtype=wp.vec3),
    axis_start: int,
    lin_axis_count: int,
    ang_axis_count: int,
    X_sc: wp.transform,
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    q_start: int,
    qd_start: int,
    # outputs
    joint_S_s: wp.array(dtype=wp.spatial_vector),
):
    """Fill the joint's motion subspace and return its velocity twist.

    One spatial vector per degree of freedom is written to ``joint_S_s``
    starting at ``qd_start``; each is a joint-local axis twist mapped through
    ``X_sc`` with ``transform_twist``. The return value is the sum of those
    vectors weighted by the matching ``joint_qd`` entries.

    Args:
        type: joint type, compared against the ``wp.sim.JOINT_*`` constants
        joint_axis: axis array, read starting at ``axis_start``
        axis_start: index of the joint's first entry in ``joint_axis``
        lin_axis_count: number of linear axes (only read by the D6 branch)
        ang_axis_count: number of angular axes (only read by the D6 branch)
        X_sc: transform applied to every joint-local twist
        joint_q: joint coordinates (read by the universal/compound branches)
        joint_qd: joint velocities, read starting at ``qd_start``
        q_start: index of the joint's first entry in ``joint_q``
        qd_start: index of the joint's first entry in ``joint_qd`` / ``joint_S_s``
        joint_S_s: output array receiving the motion subspace vectors

    Returns:
        The joint velocity as a spatial vector; zero for fixed joints and for
        unsupported joint types (which also print a message).
    """
    if type == wp.sim.JOINT_PRISMATIC:
        # a linear axis occupies the bottom half of the twist
        S_lin = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), joint_axis[axis_start]))
        joint_S_s[qd_start] = S_lin
        return S_lin * joint_qd[qd_start]

    if type == wp.sim.JOINT_REVOLUTE:
        # an angular axis occupies the top half of the twist
        S_ang = transform_twist(X_sc, wp.spatial_vector(joint_axis[axis_start], wp.vec3()))
        joint_S_s[qd_start] = S_ang
        return S_ang * joint_qd[qd_start]

    if type == wp.sim.JOINT_UNIVERSAL:
        raw_0 = joint_axis[axis_start + 0]
        raw_1 = joint_axis[axis_start + 1]
        # quaternion built from the two axes, completed by their cross product
        frame = wp.quat_from_matrix(wp.mat33(raw_0, raw_1, wp.cross(raw_0, raw_1)))
        dir_0 = wp.quat_rotate(frame, wp.vec3(1.0, 0.0, 0.0))
        local_1 = wp.quat_rotate(frame, wp.vec3(0.0, 1.0, 0.0))

        # the second axis is carried along by the rotation about the first one
        turn_0 = wp.quat_from_axis_angle(dir_0, joint_q[q_start + 0])
        dir_1 = wp.quat_rotate(turn_0, local_1)

        S_0 = transform_twist(X_sc, wp.spatial_vector(dir_0, wp.vec3()))
        S_1 = transform_twist(X_sc, wp.spatial_vector(dir_1, wp.vec3()))

        joint_S_s[qd_start + 0] = S_0
        joint_S_s[qd_start + 1] = S_1

        return S_0 * joint_qd[qd_start + 0] + S_1 * joint_qd[qd_start + 1]

    if type == wp.sim.JOINT_COMPOUND:
        raw_0 = joint_axis[axis_start + 0]
        raw_1 = joint_axis[axis_start + 1]
        raw_2 = joint_axis[axis_start + 2]
        frame = wp.quat_from_matrix(wp.mat33(raw_0, raw_1, raw_2))
        dir_0 = wp.quat_rotate(frame, wp.vec3(1.0, 0.0, 0.0))
        local_1 = wp.quat_rotate(frame, wp.vec3(0.0, 1.0, 0.0))
        local_2 = wp.quat_rotate(frame, wp.vec3(0.0, 0.0, 1.0))

        # each later axis is carried along by the rotations about the earlier ones
        turn_0 = wp.quat_from_axis_angle(dir_0, joint_q[q_start + 0])
        dir_1 = wp.quat_rotate(turn_0, local_1)

        turn_1 = wp.quat_from_axis_angle(dir_1, joint_q[q_start + 1])
        dir_2 = wp.quat_rotate(turn_1 * turn_0, local_2)

        S_0 = transform_twist(X_sc, wp.spatial_vector(dir_0, wp.vec3()))
        S_1 = transform_twist(X_sc, wp.spatial_vector(dir_1, wp.vec3()))
        S_2 = transform_twist(X_sc, wp.spatial_vector(dir_2, wp.vec3()))

        joint_S_s[qd_start + 0] = S_0
        joint_S_s[qd_start + 1] = S_1
        joint_S_s[qd_start + 2] = S_2

        return S_0 * joint_qd[qd_start + 0] + S_1 * joint_qd[qd_start + 1] + S_2 * joint_qd[qd_start + 2]

    if type == wp.sim.JOINT_D6:
        vel = wp.spatial_vector()

        # linear axes come first
        if lin_axis_count > 0:
            S_axis = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), joint_axis[axis_start + 0]))
            vel += S_axis * joint_qd[qd_start + 0]
            joint_S_s[qd_start + 0] = S_axis
        if lin_axis_count > 1:
            S_axis = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), joint_axis[axis_start + 1]))
            vel += S_axis * joint_qd[qd_start + 1]
            joint_S_s[qd_start + 1] = S_axis
        if lin_axis_count > 2:
            S_axis = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), joint_axis[axis_start + 2]))
            vel += S_axis * joint_qd[qd_start + 2]
            joint_S_s[qd_start + 2] = S_axis

        # angular axes and dofs are stored right after the linear ones
        ang_axis_start = axis_start + lin_axis_count
        ang_dof_start = qd_start + lin_axis_count
        if ang_axis_count > 0:
            S_axis = transform_twist(X_sc, wp.spatial_vector(joint_axis[ang_axis_start + 0], wp.vec3()))
            vel += S_axis * joint_qd[ang_dof_start + 0]
            joint_S_s[ang_dof_start + 0] = S_axis
        if ang_axis_count > 1:
            S_axis = transform_twist(X_sc, wp.spatial_vector(joint_axis[ang_axis_start + 1], wp.vec3()))
            vel += S_axis * joint_qd[ang_dof_start + 1]
            joint_S_s[ang_dof_start + 1] = S_axis
        if ang_axis_count > 2:
            S_axis = transform_twist(X_sc, wp.spatial_vector(joint_axis[ang_axis_start + 2], wp.vec3()))
            vel += S_axis * joint_qd[ang_dof_start + 2]
            joint_S_s[ang_dof_start + 2] = S_axis

        return vel

    if type == wp.sim.JOINT_BALL:
        # three angular dofs about the unit x, y, z axes
        S_0 = transform_twist(X_sc, wp.spatial_vector(1.0, 0.0, 0.0, 0.0, 0.0, 0.0))
        S_1 = transform_twist(X_sc, wp.spatial_vector(0.0, 1.0, 0.0, 0.0, 0.0, 0.0))
        S_2 = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 1.0, 0.0, 0.0, 0.0))

        joint_S_s[qd_start + 0] = S_0
        joint_S_s[qd_start + 1] = S_1
        joint_S_s[qd_start + 2] = S_2

        return S_0 * joint_qd[qd_start + 0] + S_1 * joint_qd[qd_start + 1] + S_2 * joint_qd[qd_start + 2]

    if type == wp.sim.JOINT_FIXED:
        return wp.spatial_vector()

    if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
        # the motion subspace is the six unit twists mapped through X_sc
        joint_S_s[qd_start + 0] = transform_twist(X_sc, wp.spatial_vector(1.0, 0.0, 0.0, 0.0, 0.0, 0.0))
        joint_S_s[qd_start + 1] = transform_twist(X_sc, wp.spatial_vector(0.0, 1.0, 0.0, 0.0, 0.0, 0.0))
        joint_S_s[qd_start + 2] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 1.0, 0.0, 0.0, 0.0))
        joint_S_s[qd_start + 3] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 1.0, 0.0, 0.0))
        joint_S_s[qd_start + 4] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 1.0, 0.0))
        joint_S_s[qd_start + 5] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 0.0, 1.0))

        # the six joint velocities form the joint-local twist directly
        local_twist = wp.spatial_vector(
            joint_qd[qd_start + 0],
            joint_qd[qd_start + 1],
            joint_qd[qd_start + 2],
            joint_qd[qd_start + 3],
            joint_qd[qd_start + 4],
            joint_qd[qd_start + 5],
        )
        return transform_twist(X_sc, local_twist)

    wp.printf("jcalc_motion not implemented for joint type %d\n", type)

    # default case
    return wp.spatial_vector()
414
+
415
+
416
# computes joint space forces/torques in tau
@wp.func
def jcalc_tau(
    type: int,
    joint_target_ke: wp.array(dtype=float),
    joint_target_kd: wp.array(dtype=float),
    joint_limit_ke: wp.array(dtype=float),
    joint_limit_kd: wp.array(dtype=float),
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_act: wp.array(dtype=float),
    joint_axis_mode: wp.array(dtype=int),
    joint_limit_lower: wp.array(dtype=float),
    joint_limit_upper: wp.array(dtype=float),
    coord_start: int,
    dof_start: int,
    axis_start: int,
    lin_axis_count: int,
    ang_axis_count: int,
    body_f_s: wp.spatial_vector,
    # outputs
    tau: wp.array(dtype=float),
):
    """Write the generalized force of every dof of one joint into ``tau``.

    For each dof the value is ``-dot(S, body_f_s)``, where ``S`` is the dof's
    entry in ``joint_S_s``. Prismatic, revolute, compound, universal and D6
    joints additionally add the result of ``eval_joint_force`` evaluated with
    the per-axis target, limit and actuation parameters. Ball, free and
    distance joints get the projection term only. Other joint types leave
    ``tau`` untouched.

    Args:
        type: joint type, compared against the ``wp.sim.JOINT_*`` constants
        coord_start: index of the joint's first entry in ``joint_q``
        dof_start: index of the joint's first entry in ``joint_qd``,
            ``joint_S_s`` and ``tau``
        axis_start: index of the joint's first entry in the per-axis arrays
        lin_axis_count: number of linear axes of the joint
        ang_axis_count: number of angular axes of the joint
        body_f_s: spatial vector projected onto each motion subspace vector
        tau: output array of generalized forces
    """
    if type == wp.sim.JOINT_PRISMATIC or type == wp.sim.JOINT_REVOLUTE:
        # single-axis joints: projection of the body force plus the axis drive/limit force
        single_force = eval_joint_force(
            joint_q[coord_start],
            joint_qd[dof_start],
            joint_act[axis_start],
            joint_target_ke[axis_start],
            joint_target_kd[axis_start],
            joint_limit_lower[axis_start],
            joint_limit_upper[axis_start],
            joint_limit_ke[axis_start],
            joint_limit_kd[axis_start],
            joint_axis_mode[axis_start],
        )
        tau[dof_start] = -wp.dot(joint_S_s[dof_start], body_f_s) + single_force

        return

    if type == wp.sim.JOINT_BALL:
        # only the projection term is applied; no target stiffness/damping
        # contribution is added for ball joints here
        for k in range(3):
            tau[dof_start + k] = -wp.dot(joint_S_s[dof_start + k], body_f_s)

        return

    if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
        for k in range(6):
            tau[dof_start + k] = -wp.dot(joint_S_s[dof_start + k], body_f_s)

        return

    if type == wp.sim.JOINT_COMPOUND or type == wp.sim.JOINT_UNIVERSAL or type == wp.sim.JOINT_D6:
        num_axes = lin_axis_count + ang_axis_count

        for k in range(num_axes):
            # drive/limit force for this axis
            axis_force = eval_joint_force(
                joint_q[coord_start + k],
                joint_qd[dof_start + k],
                joint_act[axis_start + k],
                joint_target_ke[axis_start + k],
                joint_target_kd[axis_start + k],
                joint_limit_lower[axis_start + k],
                joint_limit_upper[axis_start + k],
                joint_limit_ke[axis_start + k],
                joint_limit_kd[axis_start + k],
                joint_axis_mode[axis_start + k],
            )

            # total torque / force on the joint axis
            tau[dof_start + k] = -wp.dot(joint_S_s[dof_start + k], body_f_s) + axis_force

        return
512
+
513
+
514
@wp.func
def jcalc_integrate(
    type: int,
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_qdd: wp.array(dtype=float),
    coord_start: int,
    dof_start: int,
    lin_axis_count: int,
    ang_axis_count: int,
    dt: float,
    # outputs
    joint_q_new: wp.array(dtype=float),
    joint_qd_new: wp.array(dtype=float),
):
    """Advance the coordinates and velocities of one joint by ``dt``.

    Velocities are updated first from ``joint_qdd`` and the updated velocities
    are then used to advance the coordinates (symplectic Euler). Quaternion
    coordinates of ball and free/distance joints are re-normalized after the
    step. Fixed joints and unhandled joint types write nothing.

    Args:
        type: joint type, compared against the ``wp.sim.JOINT_*`` constants
        joint_q: current joint coordinates
        joint_qd: current joint velocities
        joint_qdd: joint accelerations
        coord_start: index of the joint's first coordinate
        dof_start: index of the joint's first dof
        lin_axis_count: number of linear axes of the joint
        ang_axis_count: number of angular axes of the joint
        dt: time step
        joint_q_new: output array receiving the new coordinates
        joint_qd_new: output array receiving the new velocities
    """
    if type == wp.sim.JOINT_FIXED:
        return

    # prismatic / revolute: a single scalar dof
    if type == wp.sim.JOINT_PRISMATIC or type == wp.sim.JOINT_REVOLUTE:
        vel_next = joint_qd[dof_start] + joint_qdd[dof_start] * dt
        pos_next = joint_q[coord_start] + vel_next * dt

        joint_qd_new[dof_start] = vel_next
        joint_q_new[coord_start] = pos_next

        return

    # ball: three angular dofs, quaternion coordinates
    if type == wp.sim.JOINT_BALL:
        ball_acc = wp.vec3(joint_qdd[dof_start + 0], joint_qdd[dof_start + 1], joint_qdd[dof_start + 2])
        ball_vel = wp.vec3(joint_qd[dof_start + 0], joint_qd[dof_start + 1], joint_qd[dof_start + 2])

        ball_rot = wp.quat(
            joint_q[coord_start + 0], joint_q[coord_start + 1], joint_q[coord_start + 2], joint_q[coord_start + 3]
        )

        # symplectic Euler: velocity first, then orientation from the new velocity
        ball_vel_next = ball_vel + ball_acc * dt

        # quaternion time derivative
        ball_rot_rate = wp.quat(ball_vel_next, 0.0) * ball_rot * 0.5

        # new orientation (normalized)
        ball_rot_next = wp.normalize(ball_rot + ball_rot_rate * dt)

        # update joint coords
        joint_q_new[coord_start + 0] = ball_rot_next[0]
        joint_q_new[coord_start + 1] = ball_rot_next[1]
        joint_q_new[coord_start + 2] = ball_rot_next[2]
        joint_q_new[coord_start + 3] = ball_rot_next[3]

        # update joint vel
        joint_qd_new[dof_start + 0] = ball_vel_next[0]
        joint_qd_new[dof_start + 1] = ball_vel_next[1]
        joint_qd_new[dof_start + 2] = ball_vel_next[2]

        return

    # free joint
    if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
        # dofs: qd = (omega_x, omega_y, omega_z, vel_x, vel_y, vel_z)
        # coords: q = (trans_x, trans_y, trans_z, quat_x, quat_y, quat_z, quat_w)

        # angular and linear acceleration
        ang_acc = wp.vec3(joint_qdd[dof_start + 0], joint_qdd[dof_start + 1], joint_qdd[dof_start + 2])
        lin_acc = wp.vec3(joint_qdd[dof_start + 3], joint_qdd[dof_start + 4], joint_qdd[dof_start + 5])

        # angular and linear velocity
        ang_vel = wp.vec3(joint_qd[dof_start + 0], joint_qd[dof_start + 1], joint_qd[dof_start + 2])
        lin_vel = wp.vec3(joint_qd[dof_start + 3], joint_qd[dof_start + 4], joint_qd[dof_start + 5])

        # symplectic Euler: velocities first
        ang_vel_next = ang_vel + ang_acc * dt
        lin_vel_next = lin_vel + lin_acc * dt

        # translation of origin
        origin = wp.vec3(joint_q[coord_start + 0], joint_q[coord_start + 1], joint_q[coord_start + 2])

        # linear velocity of the origin (note q/qd switch the order of linear and
        # angular elements); the body twist in the space frame is converted here
        # to compute the center of mass velocity
        origin_rate = lin_vel_next + wp.cross(ang_vel_next, origin)

        # quat and quat derivative
        free_rot = wp.quat(
            joint_q[coord_start + 3], joint_q[coord_start + 4], joint_q[coord_start + 5], joint_q[coord_start + 6]
        )
        free_rot_rate = wp.quat(ang_vel_next, 0.0) * free_rot * 0.5

        # new position and new orientation (normalized)
        origin_next = origin + origin_rate * dt
        free_rot_next = wp.normalize(free_rot + free_rot_rate * dt)

        # update transform
        joint_q_new[coord_start + 0] = origin_next[0]
        joint_q_new[coord_start + 1] = origin_next[1]
        joint_q_new[coord_start + 2] = origin_next[2]

        joint_q_new[coord_start + 3] = free_rot_next[0]
        joint_q_new[coord_start + 4] = free_rot_next[1]
        joint_q_new[coord_start + 5] = free_rot_next[2]
        joint_q_new[coord_start + 6] = free_rot_next[3]

        # update joint_twist
        joint_qd_new[dof_start + 0] = ang_vel_next[0]
        joint_qd_new[dof_start + 1] = ang_vel_next[1]
        joint_qd_new[dof_start + 2] = ang_vel_next[2]
        joint_qd_new[dof_start + 3] = lin_vel_next[0]
        joint_qd_new[dof_start + 4] = lin_vel_next[1]
        joint_qd_new[dof_start + 5] = lin_vel_next[2]

        return

    # other joint types (compound, universal, D6): one scalar coordinate per axis
    if type == wp.sim.JOINT_COMPOUND or type == wp.sim.JOINT_UNIVERSAL or type == wp.sim.JOINT_D6:
        num_axes = lin_axis_count + ang_axis_count

        for k in range(num_axes):
            axis_vel_next = joint_qd[dof_start + k] + joint_qdd[dof_start + k] * dt
            axis_pos_next = joint_q[coord_start + k] + axis_vel_next * dt

            joint_qd_new[dof_start + k] = axis_vel_next
            joint_q_new[coord_start + k] = axis_pos_next

        return
647
+
648
+
649
@wp.func
def compute_link_transform(
    i: int,
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    body_X_com: wp.array(dtype=wp.transform),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    # outputs
    body_q: wp.array(dtype=wp.transform),
    body_q_com: wp.array(dtype=wp.transform),
):
    """Compute the world transform of the child body of joint ``i``.

    The parent anchor ``joint_X_p[i]`` is pre-multiplied by the parent body's
    entry in ``body_q`` when the joint has a parent (``joint_parent[i] >= 0``),
    then chained with the joint transform from ``jcalc_transform`` and the
    inverse of the child anchor ``joint_X_c[i]``. The result is stored in
    ``body_q[child]``, and its product with ``body_X_com[child]`` is stored in
    ``body_q_com[child]``.
    """
    parent = joint_parent[i]
    child = joint_child[i]

    # parent anchor frame; lifted into world space when a parent body exists
    X_wpj = joint_X_p[i]
    if parent >= 0:
        X_wpj = body_q[parent] * X_wpj

    # transform across the joint for its current coordinates
    X_j = jcalc_transform(
        joint_type[i],
        joint_axis,
        joint_axis_start[i],
        joint_axis_dim[i, 0],
        joint_axis_dim[i, 1],
        joint_q,
        joint_q_start[i],
    )

    # world -> joint anchor frame at the child body, then -> child body frame
    X_wcj = X_wpj * X_j
    X_wc = X_wcj * wp.transform_inverse(joint_X_c[i])

    # store the body transform and the transform of its center of mass
    body_q[child] = X_wc
    body_q_com[child] = X_wc * body_X_com[child]
701
+
702
+
703
@wp.kernel
def eval_rigid_fk(
    articulation_start: wp.array(dtype=int),
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    body_X_com: wp.array(dtype=wp.transform),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    # outputs
    body_q: wp.array(dtype=wp.transform),
    body_q_com: wp.array(dtype=wp.transform),
):
    """Run forward kinematics for one articulation per thread.

    The joints of articulation ``a`` are the index range
    ``articulation_start[a]`` to ``articulation_start[a + 1]`` (exclusive).
    They are visited in increasing order and ``compute_link_transform`` is
    called for each, filling ``body_q`` and ``body_q_com``.
    """
    # one thread per-articulation
    articulation = wp.tid()

    first_joint = articulation_start[articulation]
    joint_end = articulation_start[articulation + 1]

    for joint in range(first_joint, joint_end):
        compute_link_transform(
            joint,
            joint_type,
            joint_parent,
            joint_child,
            joint_q_start,
            joint_q,
            joint_X_p,
            joint_X_c,
            body_X_com,
            joint_axis,
            joint_axis_start,
            joint_axis_dim,
            body_q,
            body_q_com,
        )
744
+
745
+
746
# Cross product of two spatial vectors, assembled from the top and bottom
# 3-vectors of each operand:
#   top    = a_top x b_top
#   bottom = a_top x b_bottom + a_bottom x b_top
@wp.func
def spatial_cross(a: wp.spatial_vector, b: wp.spatial_vector):
    a_top = wp.spatial_top(a)
    a_bottom = wp.spatial_bottom(a)
    b_top = wp.spatial_top(b)
    b_bottom = wp.spatial_bottom(b)

    top = wp.cross(a_top, b_top)
    bottom = wp.cross(a_top, b_bottom) + wp.cross(a_bottom, b_top)

    return wp.spatial_vector(top, bottom)
758
+
759
+
760
# Dual counterpart of spatial_cross, assembled from the top and bottom
# 3-vectors of each operand:
#   top    = a_top x b_top + a_bottom x b_bottom
#   bottom = a_top x b_bottom
@wp.func
def spatial_cross_dual(a: wp.spatial_vector, b: wp.spatial_vector):
    a_top = wp.spatial_top(a)
    a_bottom = wp.spatial_bottom(a)
    b_top = wp.spatial_top(b)
    b_bottom = wp.spatial_bottom(b)

    top = wp.cross(a_top, b_top) + wp.cross(a_bottom, b_bottom)
    bottom = wp.cross(a_top, b_bottom)

    return wp.spatial_vector(top, bottom)
772
+
773
+
774
# Flat offset of element (i, j) in a dense matrix stored row by row,
# where `stride` is the number of entries per row.
@wp.func
def dense_index(stride: int, i: int, j: int):
    row_offset = i * stride
    return row_offset + j
777
+
778
+
779
# Per-joint step of the velocity/force pass: for joint `i` this computes the
# spatial velocity and acceleration of the joint's child body from the parent's
# values plus the velocity across the joint, then the body's spatial inertia
# and force terms, and stores the results in body_v_s, body_a_s, body_f_s and
# body_I_s at the child's index. jcalc_motion additionally fills joint_S_s.
# Reads body_v_s / body_a_s of the parent, so parents must be processed first.
@wp.func
def compute_link_velocity(
    i: int,
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    body_I_m: wp.array(dtype=wp.spatial_matrix),
    body_q: wp.array(dtype=wp.transform),
    body_q_com: wp.array(dtype=wp.transform),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    gravity: wp.vec3,
    # outputs
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    body_I_s: wp.array(dtype=wp.spatial_matrix),
    body_v_s: wp.array(dtype=wp.spatial_vector),
    body_f_s: wp.array(dtype=wp.spatial_vector),
    body_a_s: wp.array(dtype=wp.spatial_vector),
):
    type = joint_type[i]
    child = joint_child[i]
    parent = joint_parent[i]
    q_start = joint_q_start[i]
    qd_start = joint_qd_start[i]

    X_pj = joint_X_p[i]
    # NOTE: X_cj is loaded but not referenced again in this function
    X_cj = joint_X_c[i]

    # parent anchor frame in world space
    # (a negative parent index means there is no parent body; X_pj is used as is)
    X_wpj = X_pj
    if parent >= 0:
        X_wp = body_q[parent]
        X_wpj = X_wp * X_wpj

    # compute motion subspace and velocity across the joint (also stores S_s to global memory)
    axis_start = joint_axis_start[i]
    lin_axis_count = joint_axis_dim[i, 0]
    ang_axis_count = joint_axis_dim[i, 1]
    v_j_s = jcalc_motion(
        type,
        joint_axis,
        axis_start,
        lin_axis_count,
        ang_axis_count,
        X_wpj,
        joint_q,
        joint_qd,
        q_start,
        qd_start,
        joint_S_s,
    )

    # parent velocity
    # (default-constructed spatial vectors are used when there is no parent)
    v_parent_s = wp.spatial_vector()
    a_parent_s = wp.spatial_vector()

    if parent >= 0:
        v_parent_s = body_v_s[parent]
        a_parent_s = body_a_s[parent]

    # body velocity, acceleration
    v_s = v_parent_s + v_j_s
    a_s = a_parent_s + spatial_cross(v_s, v_j_s)  # + joint_S_s[i]*self.joint_qdd[i]

    # compute body forces
    X_sm = body_q_com[child]
    I_m = body_I_m[child]

    # gravity and external forces (expressed in frame aligned with s but centered at body mass)
    # the mass is read from entry (3, 3) of the body's spatial inertia matrix
    m = I_m[3, 3]

    f_g = m * gravity
    r_com = wp.transform_get_translation(X_sm)
    # gravity as a spatial force: moment about the origin on top, linear force below
    f_g_s = wp.spatial_vector(wp.cross(r_com, f_g), f_g)

    # body forces
    I_s = spatial_transform_inertia(X_sm, I_m)

    f_b_s = I_s * a_s + spatial_cross_dual(v_s, I_s * v_s)

    # store per-body results; the gravity force is subtracted from the body force
    body_v_s[child] = v_s
    body_a_s[child] = a_s
    body_f_s[child] = f_b_s - f_g_s
    body_I_s[child] = I_s
870
+
871
+
872
# Inverse dynamics via Recursive Newton-Euler algorithm (Featherstone Table 5.1).
# Each thread owns one articulation and runs compute_link_velocity over its
# joints in increasing index order, filling joint_S_s, body_I_s, body_v_s,
# body_f_s and body_a_s.
@wp.kernel
def eval_rigid_id(
    articulation_start: wp.array(dtype=int),
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    body_I_m: wp.array(dtype=wp.spatial_matrix),
    body_q: wp.array(dtype=wp.transform),
    body_q_com: wp.array(dtype=wp.transform),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    gravity: wp.vec3,
    # outputs
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    body_I_s: wp.array(dtype=wp.spatial_matrix),
    body_v_s: wp.array(dtype=wp.spatial_vector),
    body_f_s: wp.array(dtype=wp.spatial_vector),
    body_a_s: wp.array(dtype=wp.spatial_vector),
):
    # the thread index selects the articulation
    articulation = wp.tid()

    # joints in [first_joint, joint_limit) belong to this articulation
    first_joint = articulation_start[articulation]
    joint_limit = articulation_start[articulation + 1]

    # link velocities, accelerations and velocity-dependent forces, joint by joint
    for joint in range(first_joint, joint_limit):
        compute_link_velocity(
            joint,
            joint_type,
            joint_parent,
            joint_child,
            joint_q_start,
            joint_qd_start,
            joint_q,
            joint_qd,
            joint_axis,
            joint_axis_start,
            joint_axis_dim,
            body_I_m,
            body_q,
            body_q_com,
            joint_X_p,
            joint_X_c,
            gravity,
            joint_S_s,
            body_I_s,
            body_v_s,
            body_f_s,
            body_a_s,
        )
931
+
932
+
933
# Joint-force kernel: each thread walks the joints of one articulation from the
# last joint to the first. For every joint it sums the forces acting on the
# child body (body_fb_s + body_ft_s + body_f_ext), hands that total to
# jcalc_tau (which writes `tau`), and then adds the total onto the parent's
# entry in body_ft_s so it is included when the parent is visited later.
@wp.kernel
def eval_rigid_tau(
    articulation_start: wp.array(dtype=int),
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    joint_axis_mode: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_act: wp.array(dtype=float),
    joint_target_ke: wp.array(dtype=float),
    joint_target_kd: wp.array(dtype=float),
    joint_limit_lower: wp.array(dtype=float),
    joint_limit_upper: wp.array(dtype=float),
    joint_limit_ke: wp.array(dtype=float),
    joint_limit_kd: wp.array(dtype=float),
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    body_fb_s: wp.array(dtype=wp.spatial_vector),
    body_f_ext: wp.array(dtype=wp.spatial_vector),
    # outputs
    body_ft_s: wp.array(dtype=wp.spatial_vector),
    tau: wp.array(dtype=float),
):
    # one thread per-articulation
    index = wp.tid()

    # joints in [start, end) belong to this articulation
    start = articulation_start[index]
    end = articulation_start[index + 1]
    count = end - start

    # compute joint forces
    for offset in range(count):
        # for backwards traversal
        i = end - offset - 1

        type = joint_type[i]
        parent = joint_parent[i]
        child = joint_child[i]
        dof_start = joint_qd_start[i]
        coord_start = joint_q_start[i]
        axis_start = joint_axis_start[i]
        lin_axis_count = joint_axis_dim[i, 0]
        ang_axis_count = joint_axis_dim[i, 1]

        # total forces on body
        # (body_ft_s[child] holds whatever later-indexed joints have accumulated so far;
        # presumably those are this body's descendants - verify the joint ordering guarantees this)
        f_b_s = body_fb_s[child]
        f_t_s = body_ft_s[child]
        f_ext = body_f_ext[child]
        f_s = f_b_s + f_t_s + f_ext

        # compute joint-space forces, writes out tau
        jcalc_tau(
            type,
            joint_target_ke,
            joint_target_kd,
            joint_limit_ke,
            joint_limit_kd,
            joint_S_s,
            joint_q,
            joint_qd,
            joint_act,
            joint_axis_mode,
            joint_limit_lower,
            joint_limit_upper,
            coord_start,
            dof_start,
            axis_start,
            lin_axis_count,
            ang_axis_count,
            f_s,
            tau,
        )

        # update parent forces, todo: check that this is valid for the backwards pass
        # (a negative parent index means there is no parent body to propagate to)
        if parent >= 0:
            wp.atomic_add(body_ft_s, parent, f_s)
1013
+
1014
+
1015
# builds spatial Jacobian J which is a (joint_count*6)x(dof_count) matrix
# Each thread fills the Jacobian of one articulation, stored flat and row by row
# starting at articulation_J_start[index]. Joint i owns rows [6*i, 6*i + 6); for
# that joint and each joint reached by following joint_parent, the 6 components
# of every motion-subspace vector in joint_S_s are written into the column of
# the corresponding degree of freedom. Entries that are never visited are not
# written by this kernel.
@wp.kernel
def eval_rigid_jacobian(
    articulation_start: wp.array(dtype=int),
    articulation_J_start: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    # outputs
    J: wp.array(dtype=float),
):
    # one thread per-articulation
    index = wp.tid()

    joint_start = articulation_start[index]
    joint_end = articulation_start[index + 1]
    joint_count = joint_end - joint_start

    # offset of this articulation's Jacobian inside the flat J array
    J_offset = articulation_J_start[index]

    # number of degrees of freedom of the articulation = number of Jacobian columns
    articulation_dof_start = joint_qd_start[joint_start]
    articulation_dof_end = joint_qd_start[joint_end]
    articulation_dof_count = articulation_dof_end - articulation_dof_start

    for i in range(joint_count):
        row_start = i * 6

        j = joint_start + i
        # NOTE(review): joint_parent[j] is used below as the next *joint* index, while
        # other functions in this file index body arrays with it - this presumably
        # relies on joint and body indices coinciding; confirm.
        while j != -1:
            joint_dof_start = joint_qd_start[j]
            joint_dof_end = joint_qd_start[j + 1]
            joint_dof_count = joint_dof_end - joint_dof_start

            # fill out each row of the Jacobian walking up the tree
            for dof in range(joint_dof_count):
                # column index is relative to the articulation's first degree of freedom
                col = (joint_dof_start - articulation_dof_start) + dof
                S = joint_S_s[joint_dof_start + dof]

                for k in range(6):
                    J[J_offset + dense_index(articulation_dof_count, row_start + k, col)] = S[k]

            j = joint_parent[j]
1057
+
1058
+
1059
# Writes the 6x6 spatial inertias body_I_s[joint_start + k], k = 0..joint_count-1,
# onto the diagonal blocks of a flat, row-by-row (6*joint_count) x (6*joint_count)
# matrix that starts at M[M_start]. Off-diagonal blocks are not touched.
@wp.func
def spatial_mass(
    body_I_s: wp.array(dtype=wp.spatial_matrix),
    joint_start: int,
    joint_count: int,
    M_start: int,
    # outputs
    M: wp.array(dtype=float),
):
    row_stride = joint_count * 6
    for link in range(joint_count):
        inertia = body_I_s[joint_start + link]
        for r in range(6):
            for c in range(6):
                M[M_start + dense_index(row_stride, link * 6 + r, link * 6 + c)] = inertia[r, c]
1074
+
1075
+
1076
# Mass-matrix kernel: each thread assembles the block-diagonal spatial mass
# matrix of one articulation via spatial_mass, writing into the flat array M
# at the offset given by articulation_M_start.
@wp.kernel
def eval_rigid_mass(
    articulation_start: wp.array(dtype=int),
    articulation_M_start: wp.array(dtype=int),
    body_I_s: wp.array(dtype=wp.spatial_matrix),
    # outputs
    M: wp.array(dtype=float),
):
    # the thread index selects the articulation
    articulation = wp.tid()

    first_joint = articulation_start[articulation]
    joint_limit = articulation_start[articulation + 1]
    num_joints = joint_limit - first_joint

    matrix_offset = articulation_M_start[articulation]

    spatial_mass(body_I_s, first_joint, num_joints, matrix_offset, M)
1094
+
1095
+
1096
# Dense matrix product on flat, row-by-row storage: C (m x n) = op(A) * op(B),
# where op(A) is m x p and op(B) is p x n. With transpose_A the array A holds a
# p x m matrix that is read transposed; likewise B holds an n x p matrix when
# transpose_B is set. Each matrix starts at its own offset (A_start, B_start,
# C_start). When add_to_C is set the product is accumulated onto C instead of
# overwriting it.
@wp.func
def dense_gemm(
    m: int,
    n: int,
    p: int,
    transpose_A: bool,
    transpose_B: bool,
    add_to_C: bool,
    A_start: int,
    B_start: int,
    C_start: int,
    A: wp.array(dtype=float),
    B: wp.array(dtype=float),
    # outputs
    C: wp.array(dtype=float),
):
    for row in range(m):
        for col in range(n):
            # dot product of row `row` of op(A) with column `col` of op(B)
            acc = float(0.0)
            for inner in range(p):
                if transpose_A:
                    a_offset = inner * m + row
                else:
                    a_offset = row * p + inner
                if transpose_B:
                    b_offset = col * p + inner
                else:
                    b_offset = inner * n + col
                acc += A[A_start + a_offset] * B[B_start + b_offset]

            if add_to_C:
                C[C_start + row * n + col] += acc
            else:
                C[C_start + row * n + col] = acc
1131
+
1132
+
1133
# Custom adjoint of dense_gemm (registered through wp.func_grad): accumulates the
# gradients of A and B from the gradient of C by calling dense_gemm again with
# permuted operands.
#
# dense_gemm indexes its first/second/third array with its first/second/third
# start offset, so every call below passes the offsets in the same permutation
# as the arrays (e.g. wp.adjoint[C] is always paired with C_start). Passing
# A_start, B_start, C_start unpermuted would index each array with another
# matrix's offset whenever the three offsets differ.
#
# NOTE(review): the transpose_A branch does not consult transpose_B, and the
# gradient of B is always laid out as p x n; confirm the transposed cases are
# either unused or intended.
@wp.func_grad(dense_gemm)
def adj_dense_gemm(
    m: int,
    n: int,
    p: int,
    transpose_A: bool,
    transpose_B: bool,
    add_to_C: bool,
    A_start: int,
    B_start: int,
    C_start: int,
    A: wp.array(dtype=float),
    B: wp.array(dtype=float),
    # outputs
    C: wp.array(dtype=float),
):
    # gradients are always accumulated, never overwritten
    add_to_C = True
    if transpose_A:
        # adj_A (p x m) += B * adj_C^T
        dense_gemm(p, m, n, False, True, add_to_C, B_start, C_start, A_start, B, wp.adjoint[C], wp.adjoint[A])
        # adj_B (p x n) += A * adj_C
        dense_gemm(p, n, m, False, False, add_to_C, A_start, C_start, B_start, A, wp.adjoint[C], wp.adjoint[B])
    else:
        # adj_A (m x p) += adj_C * op(B)^T
        dense_gemm(
            m, p, n, False, not transpose_B, add_to_C, C_start, B_start, A_start, wp.adjoint[C], B, wp.adjoint[A]
        )
        # adj_B (p x n) += A^T * adj_C
        dense_gemm(p, n, m, True, False, add_to_C, A_start, C_start, B_start, A, wp.adjoint[C], wp.adjoint[B])
1158
+
1159
+
1160
# Batched dense matrix multiply: each thread performs one complete dense_gemm
# for the batch entry selected by its thread index, using that entry's
# dimensions and start offsets. The result overwrites C (no accumulation).
@wp.kernel
def eval_dense_gemm_batched(
    m: wp.array(dtype=int),
    n: wp.array(dtype=int),
    p: wp.array(dtype=int),
    transpose_A: bool,
    transpose_B: bool,
    A_start: wp.array(dtype=int),
    B_start: wp.array(dtype=int),
    C_start: wp.array(dtype=int),
    A: wp.array(dtype=float),
    B: wp.array(dtype=float),
    C: wp.array(dtype=float),
):
    # the thread index selects the batch entry
    entry = wp.tid()
    accumulate = False

    dense_gemm(
        m[entry],
        n[entry],
        p[entry],
        transpose_A,
        transpose_B,
        accumulate,
        A_start[entry],
        B_start[entry],
        C_start[entry],
        A,
        B,
        C,
    )
1193
+
1194
+
1195
# Column-by-column Cholesky factorization of the n x n matrix stored flat and
# row by row at A[A_start...], with R[R_start + j] added to the j-th diagonal
# entry before factoring. Only the diagonal and the entries below it are
# written to L; entries above the diagonal are left untouched.
# Note that L is addressed with A_start, i.e. the factor is stored at the same
# offsets as the input matrix.
@wp.func
def dense_cholesky(
    n: int,
    A: wp.array(dtype=float),
    R: wp.array(dtype=float),
    A_start: int,
    R_start: int,
    # outputs
    L: wp.array(dtype=float),
):
    # compute the Cholesky factorization of A = L L^T with diagonal regularization R
    for j in range(n):
        # regularized diagonal entry
        s = A[A_start + dense_index(n, j, j)] + R[R_start + j]

        # subtract the squares of the already computed entries of row j
        for k in range(j):
            r = L[A_start + dense_index(n, j, k)]
            s -= r * r

        # no check is made that s is positive before taking the square root
        s = wp.sqrt(s)
        invS = 1.0 / s

        L[A_start + dense_index(n, j, j)] = s

        # entries of column j below the diagonal
        for i in range(j + 1, n):
            s = A[A_start + dense_index(n, i, j)]

            for k in range(j):
                s -= L[A_start + dense_index(n, i, k)] * L[A_start + dense_index(n, j, k)]

            L[A_start + dense_index(n, i, j)] = s * invS
1225
+
1226
+
1227
# Custom adjoint registered for dense_cholesky. It is intentionally empty, so
# no gradient is propagated through the factorization by this function.
@wp.func_grad(dense_cholesky)
def adj_dense_cholesky(
    n: int,
    A: wp.array(dtype=float),
    R: wp.array(dtype=float),
    A_start: int,
    R_start: int,
    # outputs
    L: wp.array(dtype=float),
):
    # nop, use dense_solve to differentiate through (A^-1)b = x
    pass
1239
+
1240
+
1241
# Batched Cholesky factorization: each thread factors the matrix of one batch
# entry with dense_cholesky, reading its size from A_dim and its offset from
# A_starts.
@wp.kernel
def eval_dense_cholesky_batched(
    A_starts: wp.array(dtype=int),
    A_dim: wp.array(dtype=int),
    A: wp.array(dtype=float),
    R: wp.array(dtype=float),
    L: wp.array(dtype=float),
):
    # the thread index selects the batch entry
    entry = wp.tid()

    size = A_dim[entry]
    matrix_offset = A_starts[entry]
    # NOTE(review): the regularization offset is derived as size * entry, which
    # lines up with consecutive per-entry segments of R only if every preceding
    # entry has the same dimension - confirm for mixed-size batches.
    regularizer_offset = size * entry

    dense_cholesky(size, A, R, matrix_offset, regularizer_offset, L)
1256
+
1257
+
1258
# Solves (L L^T) x = b by forward then backward substitution, given the n x n
# factor stored flat and row by row at L[L_start...]. The right-hand side is
# read from b[b_start...] and the solution is written to x[b_start...] (the same
# offset is used for both vectors). x also holds the intermediate result of the
# forward pass, which the backward pass then overwrites in place.
@wp.func
def dense_subs(
    n: int,
    L_start: int,
    b_start: int,
    L: wp.array(dtype=float),
    b: wp.array(dtype=float),
    # outputs
    x: wp.array(dtype=float),
):
    # Solves (L L^T) x = b for x given the Cholesky factor L
    # forward substitution solves the lower triangular system L y = b for y
    for i in range(n):
        s = b[b_start + i]

        # subtract the contribution of the already solved entries (stored in x)
        for j in range(i):
            s -= L[L_start + dense_index(n, i, j)] * x[b_start + j]

        x[b_start + i] = s / L[L_start + dense_index(n, i, i)]

    # backward substitution solves the upper triangular system L^T x = y for x
    # (iterates i = n-1 down to 0; L^T is accessed by swapping the indices of L)
    for i in range(n - 1, -1, -1):
        s = x[b_start + i]

        for j in range(i + 1, n):
            s -= L[L_start + dense_index(n, j, i)] * x[b_start + j]

        x[b_start + i] = s / L[L_start + dense_index(n, i, i)]
1286
+
1287
+
1288
# Thin wrapper around dense_subs with an extra `tmp` array in its signature.
# The forward computation does not touch `tmp`; the argument exists so that the
# custom adjoint registered for this function receives a scratch buffer.
@wp.func
def dense_solve(
    n: int,
    L_start: int,
    b_start: int,
    L: wp.array(dtype=float),
    b: wp.array(dtype=float),
    # outputs
    x: wp.array(dtype=float),
    tmp: wp.array(dtype=float),
):
    # helper function to include tmp argument for backward pass
    dense_subs(n, L_start, b_start, L, b, x)
1301
+
1302
+
1303
# Custom adjoint of dense_solve (registered through wp.func_grad). It solves the
# same factored system with the adjoint of x as right-hand side, storing the
# result in `tmp`, then accumulates `tmp` into the adjoint of b and the outer
# product -tmp * x^T into the adjoint array of L.
@wp.func_grad(dense_solve)
def adj_dense_solve(
    n: int,
    L_start: int,
    b_start: int,
    L: wp.array(dtype=float),
    b: wp.array(dtype=float),
    # outputs
    x: wp.array(dtype=float),
    tmp: wp.array(dtype=float),
):
    # clear the scratch segment that will receive the adjoint solve
    for i in range(n):
        tmp[b_start + i] = 0.0

    # tmp = (L L^T)^-1 * adj_x
    dense_subs(n, L_start, b_start, L, wp.adjoint[x], tmp)

    # adj_b += tmp
    for i in range(n):
        wp.adjoint[b][b_start + i] += tmp[b_start + i]

    # A* = -adj_b*x^T
    # NOTE(review): the result is accumulated into wp.adjoint[L] at the factor's
    # offsets although the comment above names it the adjoint of A - presumably
    # the caller treats the adjoint of L as the adjoint of the factored matrix; confirm.
    for i in range(n):
        for j in range(n):
            wp.adjoint[L][L_start + dense_index(n, i, j)] += -tmp[b_start + i] * x[b_start + j]
1326
+
1327
+
1328
# Batched triangular solve: each thread calls dense_solve for one batch entry,
# using that entry's system size and its offsets into L and b.
@wp.kernel
def eval_dense_solve_batched(
    L_start: wp.array(dtype=int),
    L_dim: wp.array(dtype=int),
    b_start: wp.array(dtype=int),
    L: wp.array(dtype=float),
    b: wp.array(dtype=float),
    x: wp.array(dtype=float),
    tmp: wp.array(dtype=float),
):
    # the thread index selects the batch entry
    entry = wp.tid()

    size = L_dim[entry]
    factor_offset = L_start[entry]
    rhs_offset = b_start[entry]

    dense_solve(size, factor_offset, rhs_offset, L, b, x, tmp)
1341
+
1342
+
1343
# Time-integration kernel for generalized joint coordinates: the thread index
# is used to look up one joint's type, coordinate/DOF offsets and axis counts,
# which are then passed to jcalc_integrate together with the current state
# arrays, the accelerations and the time step. Results go to joint_q_new and
# joint_qd_new.
@wp.kernel
def integrate_generalized_joints(
    joint_type: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_qdd: wp.array(dtype=float),
    dt: float,
    # outputs
    joint_q_new: wp.array(dtype=float),
    joint_qd_new: wp.array(dtype=float),
):
    # the thread index selects the joint (it indexes the per-joint arrays)
    joint = wp.tid()

    kind = joint_type[joint]
    q_offset = joint_q_start[joint]
    qd_offset = joint_qd_start[joint]
    num_linear_axes = joint_axis_dim[joint, 0]
    num_angular_axes = joint_axis_dim[joint, 1]

    jcalc_integrate(
        kind,
        joint_q,
        joint_qd,
        joint_qdd,
        q_offset,
        qd_offset,
        num_linear_axes,
        num_angular_axes,
        dt,
        joint_q_new,
        joint_qd_new,
    )
1379
+
1380
+
1381
# Converts the spatial velocity body_v_s of each body into body_qd: the top
# (angular) part is kept, and the bottom (linear) part becomes
# v + w x p, where p is the translation of the body's transform in body_q.
@wp.kernel
def eval_body_inertial_velocities(
    body_q: wp.array(dtype=wp.transform),
    body_v_s: wp.array(dtype=wp.spatial_vector),
    # outputs
    body_qd: wp.array(dtype=wp.spatial_vector),
):
    # the thread index selects the body
    body = wp.tid()

    pose = body_q[body]
    twist = body_v_s[body]
    angular = wp.spatial_top(twist)
    linear = wp.spatial_bottom(twist)

    position = wp.transform_get_translation(pose)
    linear_at_body = linear + wp.cross(angular, position)

    body_qd[body] = wp.spatial_vector(angular, linear_at_body)
1398
+
1399
+
1400
+ class FeatherstoneIntegrator(Integrator):
1401
+ """A semi-implicit integrator using symplectic Euler that operates
1402
+ on reduced (also called generalized) coordinates to simulate articulated rigid body dynamics
1403
+ based on Featherstone's composite rigid body algorithm (CRBA).
1404
+
1405
+ See: Featherstone, Roy. Rigid Body Dynamics Algorithms. Springer US, 2014.
1406
+
1407
+ Instead of maximal coordinates :attr:`State.body_q` (rigid body positions) and :attr:`State.body_qd`
1408
+ (rigid body velocities) as is the case in :class:`SemiImplicitIntegrator`, :class:`FeatherstoneIntegrator`
1409
+ uses :attr:`State.joint_q` and :attr:`State.joint_qd` to represent the positions and velocities of
1410
+ joints without allowing any redundant degrees of freedom.
1411
+
1412
+ After constructing :class:`Model` and :class:`State` objects this time-integrator
1413
+ may be used to advance the simulation state forward in time.
1414
+
1415
+ Note:
1416
+ Unlike :class:`SemiImplicitIntegrator` and :class:`XPBDIntegrator`, :class:`FeatherstoneIntegrator` does not simulate rigid bodies with nonzero mass as floating bodies if they are not connected through any joints. Floating-base systems require an explicit free joint with which the body is connected to the world, see :meth:`ModelBuilder.add_joint_free`.
1417
+
1418
+ Semi-implicit time integration is a variational integrator that
1419
+ preserves energy, however it is not unconditionally stable, and requires a time-step
1420
+ small enough to support the required stiffness and damping forces.
1421
+
1422
+ See: https://en.wikipedia.org/wiki/Semi-implicit_Euler_method
1423
+
1424
+ Example
1425
+ -------
1426
+
1427
+ .. code-block:: python
1428
+
1429
+ integrator = wp.FeatherstoneIntegrator(model)
1430
+
1431
+ # simulation loop
1432
+ for i in range(100):
1433
+ state = integrator.simulate(model, state_in, state_out, dt)
1434
+
1435
+ Note:
1436
+ The :class:`FeatherstoneIntegrator` requires the :class:`Model` to be passed in as a constructor argument.
1437
+
1438
+ """
1439
+
1440
def __init__(self, model, angular_damping=0.05, update_mass_matrix_every=1):
    """
    Args:
        model (Model): the model to be simulated.
        angular_damping (float, optional): Angular damping factor. Defaults to 0.05.
        update_mass_matrix_every (int, optional): How often to update the mass matrix (every n-th time the :meth:`simulate` function gets called). Must be nonzero. Defaults to 1.

    Raises:
        ValueError: if ``update_mass_matrix_every`` is zero.
    """
    # simulate() evaluates `self._step % self.update_mass_matrix_every`; a zero
    # interval would only surface there as a ZeroDivisionError, so reject it up front
    if update_mass_matrix_every == 0:
        raise ValueError("update_mass_matrix_every must be nonzero")

    self.angular_damping = angular_damping
    self.update_mass_matrix_every = update_mass_matrix_every
    # the index computation must run first: it sets the matrix sizes that the
    # allocation step reads
    self.compute_articulation_indices(model)
    self.allocate_model_aux_vars(model)
    # call counter compared against update_mass_matrix_every in simulate()
    self._step = 0
1452
+
1453
def compute_articulation_indices(self, model):
    """Compute per-articulation offsets and dimensions of the flat system matrices.

    For every articulation this records where its Jacobian (J), spatial mass
    matrix (M) and joint-space matrix (H) start inside the flat arrays, their
    row/column counts, and the articulation's first degree of freedom and first
    coordinate. The totals are stored in ``J_size``, ``M_size`` and ``H_size``,
    and the per-articulation lists are uploaded as ``int32`` arrays on the
    model's device. Nothing is set when the model has no joints.
    """
    if not model.joint_count:
        return

    def to_int_array(values):
        # upload a Python list as an int32 array on the model's device
        return wp.array(values, dtype=wp.int32, device=model.device)

    # running totals double as the start offset of the next articulation
    self.J_size = 0
    self.M_size = 0
    self.H_size = 0

    J_starts, M_starts, H_starts = [], [], []
    M_rows, H_rows, J_rows, J_cols = [], [], [], []
    dof_starts, coord_starts = [], []

    joint_ranges = model.articulation_start.numpy()
    coord_offsets = model.joint_q_start.numpy()
    dof_offsets = model.joint_qd_start.numpy()

    for articulation in range(model.articulation_count):
        first_joint = joint_ranges[articulation]
        last_joint = joint_ranges[articulation + 1]

        first_coord = coord_offsets[first_joint]

        first_dof = dof_offsets[first_joint]
        last_dof = dof_offsets[last_joint]

        joint_count = last_joint - first_joint
        dof_count = last_dof - first_dof

        J_starts.append(self.J_size)
        M_starts.append(self.M_size)
        H_starts.append(self.H_size)
        dof_starts.append(first_dof)
        coord_starts.append(first_coord)

        # M is (6 * joints) square, H is (dofs) square, J is (6 * joints) x (dofs)
        M_rows.append(joint_count * 6)
        H_rows.append(dof_count)
        J_rows.append(joint_count * 6)
        J_cols.append(dof_count)

        self.J_size += 6 * joint_count * dof_count
        self.M_size += 6 * joint_count * 6 * joint_count
        self.H_size += dof_count * dof_count

    # matrix offsets for the batched kernels
    self.articulation_J_start = to_int_array(J_starts)
    self.articulation_M_start = to_int_array(M_starts)
    self.articulation_H_start = to_int_array(H_starts)

    # matrix dimensions
    self.articulation_M_rows = to_int_array(M_rows)
    self.articulation_H_rows = to_int_array(H_rows)
    self.articulation_J_rows = to_int_array(J_rows)
    self.articulation_J_cols = to_int_array(J_cols)

    # first degree of freedom / coordinate of each articulation
    self.articulation_dof_start = to_int_array(dof_starts)
    self.articulation_coord_start = to_int_array(coord_starts)
1516
+
1517
def allocate_model_aux_vars(self, model):
    """Allocate the buffers that depend only on the model.

    With joints present this creates the flat system matrices ``M``, ``J``,
    ``P``, ``H`` and ``L`` using the sizes stored in ``M_size``, ``J_size`` and
    ``H_size``. With bodies present it creates ``body_I_m`` and ``body_X_com``
    and fills them by launching ``compute_spatial_inertia`` and
    ``compute_com_transforms`` once per body.
    """
    device = model.device

    if model.joint_count:
        track_grad = model.requires_grad

        # system matrices
        self.M = wp.zeros((self.M_size,), dtype=wp.float32, device=device, requires_grad=track_grad)
        self.J = wp.zeros((self.J_size,), dtype=wp.float32, device=device, requires_grad=track_grad)
        self.P = wp.empty_like(self.J, requires_grad=track_grad)
        self.H = wp.empty((self.H_size,), dtype=wp.float32, device=device, requires_grad=track_grad)

        # dense_cholesky only writes the diagonal and the entries below it, so L is
        # zero-initialized to keep the remaining entries defined (uninitialized
        # values can trigger NaN detection)
        self.L = wp.zeros_like(self.H)

    if model.body_count:
        body_dim = (model.body_count,)

        # TODO use requires_grad here?
        self.body_I_m = wp.empty(body_dim, dtype=wp.spatial_matrix, device=device)
        wp.launch(
            compute_spatial_inertia,
            dim=model.body_count,
            inputs=[model.body_inertia, model.body_mass],
            outputs=[self.body_I_m],
            device=device,
        )

        self.body_X_com = wp.empty(body_dim, dtype=wp.transform, device=device)
        wp.launch(
            compute_com_transforms,
            dim=model.body_count,
            inputs=[model.body_com],
            outputs=[self.body_X_com],
            device=device,
        )
1547
+
1548
def allocate_state_aux_vars(self, model, target, requires_grad):
    """Attach the state-dependent work buffers to ``target``.

    When the model has bodies this creates the joint-space arrays
    (``joint_qdd``, ``joint_tau``, ``joint_solve_tmp``, ``joint_S_s``) and the
    per-body arrays (``body_q_com``, ``body_I_s``, ``body_v_s``, ``body_a_s``,
    ``body_f_s``, ``body_ft_s``) as attributes of ``target``. In every case
    ``target._featherstone_augmented`` is set to ``True`` afterwards, which
    :meth:`simulate` checks to avoid allocating twice.

    Args:
        model: the model whose array sizes and device are used.
        target: object that receives the buffers as attributes.
        requires_grad (bool): whether the buffers should track gradients.
    """

    def per_body(factory, dtype):
        # one entry per body on the model's device
        return factory((model.body_count,), dtype=dtype, device=model.device, requires_grad=requires_grad)

    if model.body_count:
        # joints
        target.joint_qdd = wp.zeros_like(model.joint_qd, requires_grad=requires_grad)
        target.joint_tau = wp.empty_like(model.joint_qd, requires_grad=requires_grad)
        # scratch space for the custom gradient of eval_dense_solve_batched;
        # only needed when gradients are tracked
        target.joint_solve_tmp = wp.zeros_like(model.joint_qd, requires_grad=True) if requires_grad else None
        target.joint_S_s = wp.empty(
            (model.joint_dof_count,),
            dtype=wp.spatial_vector,
            device=model.device,
            requires_grad=requires_grad,
        )

        # derived rigid body data (maximal coordinates)
        target.body_q_com = wp.empty_like(model.body_q, requires_grad=requires_grad)
        target.body_I_s = per_body(wp.empty, wp.spatial_matrix)
        target.body_v_s = per_body(wp.empty, wp.spatial_vector)
        target.body_a_s = per_body(wp.empty, wp.spatial_vector)
        # the force buffers start out zeroed
        target.body_f_s = per_body(wp.zeros, wp.spatial_vector)
        target.body_ft_s = per_body(wp.zeros, wp.spatial_vector)

    target._featherstone_augmented = True
1585
+
1586
+ def simulate(self, model: Model, state_in: State, state_out: State, dt: float, control: Control = None):
1587
+ requires_grad = state_in.requires_grad
1588
+
1589
+ # optionally create dynamical auxiliary variables
1590
+ if requires_grad:
1591
+ state_aug = state_out
1592
+ else:
1593
+ state_aug = self
1594
+
1595
+ if not getattr(state_aug, "_featherstone_augmented", False):
1596
+ self.allocate_state_aux_vars(model, state_aug, requires_grad)
1597
+ if control is None:
1598
+ control = model.control(clone_variables=False)
1599
+
1600
+ with wp.ScopedTimer("simulate", False):
1601
+ particle_f = None
1602
+ body_f = None
1603
+
1604
+ if state_in.particle_count:
1605
+ particle_f = state_in.particle_f
1606
+
1607
+ if state_in.body_count:
1608
+ body_f = state_in.body_f
1609
+
1610
+ # damped springs
1611
+ eval_spring_forces(model, state_in, particle_f)
1612
+
1613
+ # triangle elastic and lift/drag forces
1614
+ eval_triangle_forces(model, state_in, control, particle_f)
1615
+
1616
+ # triangle/triangle contacts
1617
+ eval_triangle_contact_forces(model, state_in, particle_f)
1618
+
1619
+ # triangle bending
1620
+ eval_bending_forces(model, state_in, particle_f)
1621
+
1622
+ # tetrahedral FEM
1623
+ eval_tetrahedral_forces(model, state_in, control, particle_f)
1624
+
1625
+ # particle-particle interactions
1626
+ eval_particle_forces(model, state_in, particle_f)
1627
+
1628
+ # particle ground contacts
1629
+ eval_particle_ground_contact_forces(model, state_in, particle_f)
1630
+
1631
+ # particle shape contact
1632
+ eval_particle_body_contact_forces(model, state_in, particle_f, body_f)
1633
+
1634
+ # muscles
1635
+ if False:
1636
+ eval_muscle_forces(model, state_in, control, body_f)
1637
+
1638
+ # ----------------------------
1639
+ # articulations
1640
+
1641
+ if model.joint_count:
1642
+
1643
+ # evaluate body transforms
1644
+ wp.launch(
1645
+ eval_rigid_fk,
1646
+ dim=model.articulation_count,
1647
+ inputs=[
1648
+ model.articulation_start,
1649
+ model.joint_type,
1650
+ model.joint_parent,
1651
+ model.joint_child,
1652
+ model.joint_q_start,
1653
+ state_in.joint_q,
1654
+ model.joint_X_p,
1655
+ model.joint_X_c,
1656
+ self.body_X_com,
1657
+ model.joint_axis,
1658
+ model.joint_axis_start,
1659
+ model.joint_axis_dim,
1660
+ ],
1661
+ outputs=[state_in.body_q, state_aug.body_q_com],
1662
+ device=model.device,
1663
+ )
1664
+
1665
+ # print("body_X_sc:")
1666
+ # print(state_in.body_q.numpy())
1667
+
1668
+ # evaluate joint inertias, motion vectors, and forces
1669
+ state_aug.body_f_s.zero_()
1670
+ wp.launch(
1671
+ eval_rigid_id,
1672
+ dim=model.articulation_count,
1673
+ inputs=[
1674
+ model.articulation_start,
1675
+ model.joint_type,
1676
+ model.joint_parent,
1677
+ model.joint_child,
1678
+ model.joint_q_start,
1679
+ model.joint_qd_start,
1680
+ state_in.joint_q,
1681
+ state_in.joint_qd,
1682
+ model.joint_axis,
1683
+ model.joint_axis_start,
1684
+ model.joint_axis_dim,
1685
+ self.body_I_m,
1686
+ state_in.body_q,
1687
+ state_aug.body_q_com,
1688
+ model.joint_X_p,
1689
+ model.joint_X_c,
1690
+ model.gravity,
1691
+ ],
1692
+ outputs=[
1693
+ state_aug.joint_S_s,
1694
+ state_aug.body_I_s,
1695
+ state_aug.body_v_s,
1696
+ state_aug.body_f_s,
1697
+ state_aug.body_a_s,
1698
+ ],
1699
+ device=model.device,
1700
+ )
1701
+
1702
+ if model.rigid_contact_max and (
1703
+ model.ground and model.shape_ground_contact_pair_count or model.shape_contact_pair_count
1704
+ ):
1705
+ wp.launch(
1706
+ kernel=eval_rigid_contacts,
1707
+ dim=model.rigid_contact_max,
1708
+ inputs=[
1709
+ state_in.body_q,
1710
+ state_aug.body_v_s,
1711
+ model.body_com,
1712
+ model.shape_materials,
1713
+ model.shape_geo,
1714
+ model.shape_body,
1715
+ model.rigid_contact_count,
1716
+ model.rigid_contact_point0,
1717
+ model.rigid_contact_point1,
1718
+ model.rigid_contact_normal,
1719
+ model.rigid_contact_shape0,
1720
+ model.rigid_contact_shape1,
1721
+ True,
1722
+ ],
1723
+ outputs=[body_f],
1724
+ device=model.device,
1725
+ )
1726
+
1727
+ # if model.rigid_contact_count.numpy()[0] > 0:
1728
+ # print(body_f.numpy())
1729
+
1730
+ if model.articulation_count:
1731
+ # evaluate joint torques
1732
+ state_aug.body_ft_s.zero_()
1733
+ wp.launch(
1734
+ eval_rigid_tau,
1735
+ dim=model.articulation_count,
1736
+ inputs=[
1737
+ model.articulation_start,
1738
+ model.joint_type,
1739
+ model.joint_parent,
1740
+ model.joint_child,
1741
+ model.joint_q_start,
1742
+ model.joint_qd_start,
1743
+ model.joint_axis_start,
1744
+ model.joint_axis_dim,
1745
+ model.joint_axis_mode,
1746
+ state_in.joint_q,
1747
+ state_in.joint_qd,
1748
+ control.joint_act,
1749
+ model.joint_target_ke,
1750
+ model.joint_target_kd,
1751
+ model.joint_limit_lower,
1752
+ model.joint_limit_upper,
1753
+ model.joint_limit_ke,
1754
+ model.joint_limit_kd,
1755
+ state_aug.joint_S_s,
1756
+ state_aug.body_f_s,
1757
+ body_f,
1758
+ ],
1759
+ outputs=[
1760
+ state_aug.body_ft_s,
1761
+ state_aug.joint_tau,
1762
+ ],
1763
+ device=model.device,
1764
+ )
1765
+
1766
+ # print("joint_tau:")
1767
+ # print(state_aug.joint_tau.numpy())
1768
+ # print("body_q:")
1769
+ # print(state_in.body_q.numpy())
1770
+ # print("body_qd:")
1771
+ # print(state_in.body_qd.numpy())
1772
+
1773
+ if self._step % self.update_mass_matrix_every == 0:
1774
+ # build J
1775
+ wp.launch(
1776
+ eval_rigid_jacobian,
1777
+ dim=model.articulation_count,
1778
+ inputs=[
1779
+ model.articulation_start,
1780
+ self.articulation_J_start,
1781
+ model.joint_parent,
1782
+ model.joint_qd_start,
1783
+ state_aug.joint_S_s,
1784
+ ],
1785
+ outputs=[self.J],
1786
+ device=model.device,
1787
+ )
1788
+
1789
+ # build M
1790
+ wp.launch(
1791
+ eval_rigid_mass,
1792
+ dim=model.articulation_count,
1793
+ inputs=[
1794
+ model.articulation_start,
1795
+ self.articulation_M_start,
1796
+ state_aug.body_I_s,
1797
+ ],
1798
+ outputs=[self.M],
1799
+ device=model.device,
1800
+ )
1801
+
1802
+ # form P = M*J
1803
+ wp.launch(
1804
+ eval_dense_gemm_batched,
1805
+ dim=model.articulation_count,
1806
+ inputs=[
1807
+ self.articulation_M_rows,
1808
+ self.articulation_J_cols,
1809
+ self.articulation_J_rows,
1810
+ False,
1811
+ False,
1812
+ self.articulation_M_start,
1813
+ self.articulation_J_start,
1814
+ # P start is the same as J start since it has the same dims as J
1815
+ self.articulation_J_start,
1816
+ self.M,
1817
+ self.J,
1818
+ ],
1819
+ outputs=[self.P],
1820
+ device=model.device,
1821
+ )
1822
+
1823
+ # form H = J^T*P
1824
+ wp.launch(
1825
+ eval_dense_gemm_batched,
1826
+ dim=model.articulation_count,
1827
+ inputs=[
1828
+ self.articulation_J_cols,
1829
+ self.articulation_J_cols,
1830
+ # P rows is the same as J rows
1831
+ self.articulation_J_rows,
1832
+ True,
1833
+ False,
1834
+ self.articulation_J_start,
1835
+ # P start is the same as J start since it has the same dims as J
1836
+ self.articulation_J_start,
1837
+ self.articulation_H_start,
1838
+ self.J,
1839
+ self.P,
1840
+ ],
1841
+ outputs=[self.H],
1842
+ device=model.device,
1843
+ )
1844
+
1845
+ # compute decomposition
1846
+ wp.launch(
1847
+ eval_dense_cholesky_batched,
1848
+ dim=model.articulation_count,
1849
+ inputs=[
1850
+ self.articulation_H_start,
1851
+ self.articulation_H_rows,
1852
+ self.H,
1853
+ model.joint_armature,
1854
+ ],
1855
+ outputs=[self.L],
1856
+ device=model.device,
1857
+ )
1858
+
1859
+ # print("joint_act:")
1860
+ # print(control.joint_act.numpy())
1861
+ # print("joint_tau:")
1862
+ # print(state_aug.joint_tau.numpy())
1863
+ # print("H:")
1864
+ # print(self.H.numpy())
1865
+ # print("L:")
1866
+ # print(self.L.numpy())
1867
+
1868
+ # solve for qdd
1869
+ state_aug.joint_qdd.zero_()
1870
+ wp.launch(
1871
+ eval_dense_solve_batched,
1872
+ dim=model.articulation_count,
1873
+ inputs=[
1874
+ self.articulation_H_start,
1875
+ self.articulation_H_rows,
1876
+ self.articulation_dof_start,
1877
+ self.L,
1878
+ state_aug.joint_tau,
1879
+ ],
1880
+ outputs=[
1881
+ state_aug.joint_qdd,
1882
+ state_aug.joint_solve_tmp,
1883
+ ],
1884
+ device=model.device,
1885
+ )
1886
+ # if wp.context.runtime.tape:
1887
+ # wp.context.runtime.tape.record_func(
1888
+ # backward=lambda: adj_matmul(
1889
+ # a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith, device
1890
+ # ),
1891
+ # arrays=[a, b, c, d],
1892
+ # )
1893
+ # print("joint_qdd:")
1894
+ # print(state_aug.joint_qdd.numpy())
1895
+ # print("\n\n")
1896
+
1897
+ # -------------------------------------
1898
+ # integrate bodies
1899
+
1900
+ if model.joint_count:
1901
+ wp.launch(
1902
+ kernel=integrate_generalized_joints,
1903
+ dim=model.joint_count,
1904
+ inputs=[
1905
+ model.joint_type,
1906
+ model.joint_q_start,
1907
+ model.joint_qd_start,
1908
+ model.joint_axis_dim,
1909
+ state_in.joint_q,
1910
+ state_in.joint_qd,
1911
+ state_aug.joint_qdd,
1912
+ dt,
1913
+ ],
1914
+ outputs=[state_out.joint_q, state_out.joint_qd],
1915
+ device=model.device,
1916
+ )
1917
+
1918
+ wp.launch(
1919
+ eval_rigid_fk,
1920
+ dim=model.articulation_count,
1921
+ inputs=[
1922
+ model.articulation_start,
1923
+ model.joint_type,
1924
+ model.joint_parent,
1925
+ model.joint_child,
1926
+ model.joint_q_start,
1927
+ state_out.joint_q,
1928
+ model.joint_X_p,
1929
+ model.joint_X_c,
1930
+ self.body_X_com,
1931
+ model.joint_axis,
1932
+ model.joint_axis_start,
1933
+ model.joint_axis_dim,
1934
+ ],
1935
+ outputs=[state_out.body_q, state_aug.body_q_com],
1936
+ device=model.device,
1937
+ )
1938
+
1939
+ # compute body_qd
1940
+ state_aug.body_f_s.zero_()
1941
+ wp.launch(
1942
+ eval_rigid_id,
1943
+ dim=model.articulation_count,
1944
+ inputs=[
1945
+ model.articulation_start,
1946
+ model.joint_type,
1947
+ model.joint_parent,
1948
+ model.joint_child,
1949
+ model.joint_q_start,
1950
+ model.joint_qd_start,
1951
+ state_out.joint_q,
1952
+ state_out.joint_qd,
1953
+ model.joint_axis,
1954
+ model.joint_axis_start,
1955
+ model.joint_axis_dim,
1956
+ self.body_I_m,
1957
+ state_out.body_q,
1958
+ state_aug.body_q_com,
1959
+ model.joint_X_p,
1960
+ model.joint_X_c,
1961
+ model.gravity,
1962
+ ],
1963
+ outputs=[
1964
+ state_aug.joint_S_s,
1965
+ state_aug.body_I_s,
1966
+ state_aug.body_v_s,
1967
+ state_aug.body_f_s,
1968
+ state_aug.body_a_s,
1969
+ ],
1970
+ device=model.device,
1971
+ )
1972
+
1973
+ # body velocity in inertial frame
1974
+ wp.launch(
1975
+ kernel=eval_body_inertial_velocities,
1976
+ dim=model.body_count,
1977
+ inputs=[
1978
+ state_out.body_q,
1979
+ state_aug.body_v_s,
1980
+ ],
1981
+ outputs=[
1982
+ state_out.body_qd,
1983
+ ],
1984
+ device=model.device,
1985
+ )
1986
+
1987
+ self.integrate_particles(model, state_in, state_out, dt)
1988
+
1989
+ self._step += 1
1990
+
1991
+ return state_out