warp-lang 1.0.2__py3-none-manylinux2014_x86_64.whl → 1.2.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (356)
  1. warp/__init__.py +108 -97
  2. warp/__init__.pyi +1 -1
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +88 -113
  6. warp/build_dll.py +383 -375
  7. warp/builtins.py +3693 -3354
  8. warp/codegen.py +2925 -2792
  9. warp/config.py +40 -36
  10. warp/constants.py +49 -45
  11. warp/context.py +5409 -5102
  12. warp/dlpack.py +442 -442
  13. warp/examples/__init__.py +16 -16
  14. warp/examples/assets/bear.usd +0 -0
  15. warp/examples/assets/bunny.usd +0 -0
  16. warp/examples/assets/cartpole.urdf +110 -110
  17. warp/examples/assets/crazyflie.usd +0 -0
  18. warp/examples/assets/cube.usd +0 -0
  19. warp/examples/assets/nv_ant.xml +92 -92
  20. warp/examples/assets/nv_humanoid.xml +183 -183
  21. warp/examples/assets/quadruped.urdf +267 -267
  22. warp/examples/assets/rocks.nvdb +0 -0
  23. warp/examples/assets/rocks.usd +0 -0
  24. warp/examples/assets/sphere.usd +0 -0
  25. warp/examples/benchmarks/benchmark_api.py +381 -383
  26. warp/examples/benchmarks/benchmark_cloth.py +278 -277
  27. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
  28. warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
  29. warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
  30. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
  31. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
  32. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
  33. warp/examples/benchmarks/benchmark_cloth_warp.py +145 -146
  34. warp/examples/benchmarks/benchmark_launches.py +293 -295
  35. warp/examples/browse.py +29 -29
  36. warp/examples/core/example_dem.py +232 -219
  37. warp/examples/core/example_fluid.py +291 -267
  38. warp/examples/core/example_graph_capture.py +142 -126
  39. warp/examples/core/example_marching_cubes.py +186 -174
  40. warp/examples/core/example_mesh.py +172 -155
  41. warp/examples/core/example_mesh_intersect.py +203 -193
  42. warp/examples/core/example_nvdb.py +174 -170
  43. warp/examples/core/example_raycast.py +103 -90
  44. warp/examples/core/example_raymarch.py +197 -178
  45. warp/examples/core/example_render_opengl.py +183 -141
  46. warp/examples/core/example_sph.py +403 -387
  47. warp/examples/core/example_torch.py +219 -181
  48. warp/examples/core/example_wave.py +261 -248
  49. warp/examples/fem/bsr_utils.py +378 -380
  50. warp/examples/fem/example_apic_fluid.py +432 -389
  51. warp/examples/fem/example_burgers.py +262 -0
  52. warp/examples/fem/example_convection_diffusion.py +180 -168
  53. warp/examples/fem/example_convection_diffusion_dg.py +217 -209
  54. warp/examples/fem/example_deformed_geometry.py +175 -159
  55. warp/examples/fem/example_diffusion.py +199 -173
  56. warp/examples/fem/example_diffusion_3d.py +178 -152
  57. warp/examples/fem/example_diffusion_mgpu.py +219 -214
  58. warp/examples/fem/example_mixed_elasticity.py +242 -222
  59. warp/examples/fem/example_navier_stokes.py +257 -243
  60. warp/examples/fem/example_stokes.py +218 -192
  61. warp/examples/fem/example_stokes_transfer.py +263 -249
  62. warp/examples/fem/mesh_utils.py +133 -109
  63. warp/examples/fem/plot_utils.py +292 -287
  64. warp/examples/optim/example_bounce.py +258 -246
  65. warp/examples/optim/example_cloth_throw.py +220 -209
  66. warp/examples/optim/example_diffray.py +564 -536
  67. warp/examples/optim/example_drone.py +862 -835
  68. warp/examples/optim/example_inverse_kinematics.py +174 -168
  69. warp/examples/optim/example_inverse_kinematics_torch.py +183 -169
  70. warp/examples/optim/example_spring_cage.py +237 -231
  71. warp/examples/optim/example_trajectory.py +221 -199
  72. warp/examples/optim/example_walker.py +304 -293
  73. warp/examples/sim/example_cartpole.py +137 -129
  74. warp/examples/sim/example_cloth.py +194 -186
  75. warp/examples/sim/example_granular.py +122 -111
  76. warp/examples/sim/example_granular_collision_sdf.py +195 -186
  77. warp/examples/sim/example_jacobian_ik.py +234 -214
  78. warp/examples/sim/example_particle_chain.py +116 -105
  79. warp/examples/sim/example_quadruped.py +191 -180
  80. warp/examples/sim/example_rigid_chain.py +195 -187
  81. warp/examples/sim/example_rigid_contact.py +187 -177
  82. warp/examples/sim/example_rigid_force.py +125 -125
  83. warp/examples/sim/example_rigid_gyroscopic.py +107 -95
  84. warp/examples/sim/example_rigid_soft_contact.py +132 -122
  85. warp/examples/sim/example_soft_body.py +188 -177
  86. warp/fabric.py +337 -335
  87. warp/fem/__init__.py +61 -27
  88. warp/fem/cache.py +403 -388
  89. warp/fem/dirichlet.py +178 -179
  90. warp/fem/domain.py +262 -263
  91. warp/fem/field/__init__.py +100 -101
  92. warp/fem/field/field.py +148 -149
  93. warp/fem/field/nodal_field.py +298 -299
  94. warp/fem/field/restriction.py +22 -21
  95. warp/fem/field/test.py +180 -181
  96. warp/fem/field/trial.py +183 -183
  97. warp/fem/geometry/__init__.py +16 -19
  98. warp/fem/geometry/closest_point.py +69 -70
  99. warp/fem/geometry/deformed_geometry.py +270 -271
  100. warp/fem/geometry/element.py +748 -744
  101. warp/fem/geometry/geometry.py +184 -186
  102. warp/fem/geometry/grid_2d.py +380 -373
  103. warp/fem/geometry/grid_3d.py +437 -435
  104. warp/fem/geometry/hexmesh.py +953 -953
  105. warp/fem/geometry/nanogrid.py +455 -0
  106. warp/fem/geometry/partition.py +374 -376
  107. warp/fem/geometry/quadmesh_2d.py +532 -532
  108. warp/fem/geometry/tetmesh.py +840 -840
  109. warp/fem/geometry/trimesh_2d.py +577 -577
  110. warp/fem/integrate.py +1684 -1615
  111. warp/fem/operator.py +190 -191
  112. warp/fem/polynomial.py +214 -213
  113. warp/fem/quadrature/__init__.py +2 -2
  114. warp/fem/quadrature/pic_quadrature.py +243 -245
  115. warp/fem/quadrature/quadrature.py +295 -294
  116. warp/fem/space/__init__.py +179 -292
  117. warp/fem/space/basis_space.py +522 -489
  118. warp/fem/space/collocated_function_space.py +100 -105
  119. warp/fem/space/dof_mapper.py +236 -236
  120. warp/fem/space/function_space.py +148 -145
  121. warp/fem/space/grid_2d_function_space.py +148 -267
  122. warp/fem/space/grid_3d_function_space.py +167 -306
  123. warp/fem/space/hexmesh_function_space.py +253 -352
  124. warp/fem/space/nanogrid_function_space.py +202 -0
  125. warp/fem/space/partition.py +350 -350
  126. warp/fem/space/quadmesh_2d_function_space.py +261 -369
  127. warp/fem/space/restriction.py +161 -160
  128. warp/fem/space/shape/__init__.py +90 -15
  129. warp/fem/space/shape/cube_shape_function.py +728 -738
  130. warp/fem/space/shape/shape_function.py +102 -103
  131. warp/fem/space/shape/square_shape_function.py +611 -611
  132. warp/fem/space/shape/tet_shape_function.py +565 -567
  133. warp/fem/space/shape/triangle_shape_function.py +429 -429
  134. warp/fem/space/tetmesh_function_space.py +224 -292
  135. warp/fem/space/topology.py +297 -295
  136. warp/fem/space/trimesh_2d_function_space.py +153 -221
  137. warp/fem/types.py +77 -77
  138. warp/fem/utils.py +495 -495
  139. warp/jax.py +166 -141
  140. warp/jax_experimental.py +341 -339
  141. warp/native/array.h +1081 -1025
  142. warp/native/builtin.h +1603 -1560
  143. warp/native/bvh.cpp +402 -398
  144. warp/native/bvh.cu +533 -525
  145. warp/native/bvh.h +430 -429
  146. warp/native/clang/clang.cpp +496 -464
  147. warp/native/crt.cpp +42 -32
  148. warp/native/crt.h +352 -335
  149. warp/native/cuda_crt.h +1049 -1049
  150. warp/native/cuda_util.cpp +549 -540
  151. warp/native/cuda_util.h +288 -203
  152. warp/native/cutlass_gemm.cpp +34 -34
  153. warp/native/cutlass_gemm.cu +372 -372
  154. warp/native/error.cpp +66 -66
  155. warp/native/error.h +27 -27
  156. warp/native/exports.h +187 -0
  157. warp/native/fabric.h +228 -228
  158. warp/native/hashgrid.cpp +301 -278
  159. warp/native/hashgrid.cu +78 -77
  160. warp/native/hashgrid.h +227 -227
  161. warp/native/initializer_array.h +32 -32
  162. warp/native/intersect.h +1204 -1204
  163. warp/native/intersect_adj.h +365 -365
  164. warp/native/intersect_tri.h +322 -322
  165. warp/native/marching.cpp +2 -2
  166. warp/native/marching.cu +497 -497
  167. warp/native/marching.h +2 -2
  168. warp/native/mat.h +1545 -1498
  169. warp/native/matnn.h +333 -333
  170. warp/native/mesh.cpp +203 -203
  171. warp/native/mesh.cu +292 -293
  172. warp/native/mesh.h +1887 -1887
  173. warp/native/nanovdb/GridHandle.h +366 -0
  174. warp/native/nanovdb/HostBuffer.h +590 -0
  175. warp/native/nanovdb/NanoVDB.h +6624 -4782
  176. warp/native/nanovdb/PNanoVDB.h +3390 -2553
  177. warp/native/noise.h +850 -850
  178. warp/native/quat.h +1112 -1085
  179. warp/native/rand.h +303 -299
  180. warp/native/range.h +108 -108
  181. warp/native/reduce.cpp +156 -156
  182. warp/native/reduce.cu +348 -348
  183. warp/native/runlength_encode.cpp +61 -61
  184. warp/native/runlength_encode.cu +46 -46
  185. warp/native/scan.cpp +30 -30
  186. warp/native/scan.cu +36 -36
  187. warp/native/scan.h +7 -7
  188. warp/native/solid_angle.h +442 -442
  189. warp/native/sort.cpp +94 -94
  190. warp/native/sort.cu +97 -97
  191. warp/native/sort.h +14 -14
  192. warp/native/sparse.cpp +337 -337
  193. warp/native/sparse.cu +544 -544
  194. warp/native/spatial.h +630 -630
  195. warp/native/svd.h +562 -562
  196. warp/native/temp_buffer.h +30 -30
  197. warp/native/vec.h +1177 -1133
  198. warp/native/volume.cpp +529 -297
  199. warp/native/volume.cu +58 -32
  200. warp/native/volume.h +960 -538
  201. warp/native/volume_builder.cu +446 -425
  202. warp/native/volume_builder.h +34 -19
  203. warp/native/volume_impl.h +61 -0
  204. warp/native/warp.cpp +1057 -1052
  205. warp/native/warp.cu +2949 -2828
  206. warp/native/warp.h +321 -305
  207. warp/optim/__init__.py +9 -9
  208. warp/optim/adam.py +120 -120
  209. warp/optim/linear.py +1104 -939
  210. warp/optim/sgd.py +104 -92
  211. warp/render/__init__.py +10 -10
  212. warp/render/render_opengl.py +3356 -3204
  213. warp/render/render_usd.py +768 -749
  214. warp/render/utils.py +152 -150
  215. warp/sim/__init__.py +52 -59
  216. warp/sim/articulation.py +685 -685
  217. warp/sim/collide.py +1594 -1590
  218. warp/sim/import_mjcf.py +489 -481
  219. warp/sim/import_snu.py +220 -221
  220. warp/sim/import_urdf.py +536 -516
  221. warp/sim/import_usd.py +887 -881
  222. warp/sim/inertia.py +316 -317
  223. warp/sim/integrator.py +234 -233
  224. warp/sim/integrator_euler.py +1956 -1956
  225. warp/sim/integrator_featherstone.py +1917 -1991
  226. warp/sim/integrator_xpbd.py +3288 -3312
  227. warp/sim/model.py +4473 -4314
  228. warp/sim/particles.py +113 -112
  229. warp/sim/render.py +417 -403
  230. warp/sim/utils.py +413 -410
  231. warp/sparse.py +1289 -1227
  232. warp/stubs.py +2192 -2469
  233. warp/tape.py +1162 -225
  234. warp/tests/__init__.py +1 -1
  235. warp/tests/__main__.py +4 -4
  236. warp/tests/assets/test_index_grid.nvdb +0 -0
  237. warp/tests/assets/torus.usda +105 -105
  238. warp/tests/aux_test_class_kernel.py +26 -26
  239. warp/tests/aux_test_compile_consts_dummy.py +10 -10
  240. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
  241. warp/tests/aux_test_dependent.py +20 -22
  242. warp/tests/aux_test_grad_customs.py +21 -23
  243. warp/tests/aux_test_reference.py +9 -11
  244. warp/tests/aux_test_reference_reference.py +8 -10
  245. warp/tests/aux_test_square.py +15 -17
  246. warp/tests/aux_test_unresolved_func.py +14 -14
  247. warp/tests/aux_test_unresolved_symbol.py +14 -14
  248. warp/tests/disabled_kinematics.py +237 -239
  249. warp/tests/run_coverage_serial.py +31 -31
  250. warp/tests/test_adam.py +155 -157
  251. warp/tests/test_arithmetic.py +1088 -1124
  252. warp/tests/test_array.py +2415 -2326
  253. warp/tests/test_array_reduce.py +148 -150
  254. warp/tests/test_async.py +666 -656
  255. warp/tests/test_atomic.py +139 -141
  256. warp/tests/test_bool.py +212 -149
  257. warp/tests/test_builtins_resolution.py +1290 -1292
  258. warp/tests/test_bvh.py +162 -171
  259. warp/tests/test_closest_point_edge_edge.py +227 -228
  260. warp/tests/test_codegen.py +562 -553
  261. warp/tests/test_compile_consts.py +217 -101
  262. warp/tests/test_conditional.py +244 -246
  263. warp/tests/test_copy.py +230 -215
  264. warp/tests/test_ctypes.py +630 -632
  265. warp/tests/test_dense.py +65 -67
  266. warp/tests/test_devices.py +89 -98
  267. warp/tests/test_dlpack.py +528 -529
  268. warp/tests/test_examples.py +403 -378
  269. warp/tests/test_fabricarray.py +952 -955
  270. warp/tests/test_fast_math.py +60 -54
  271. warp/tests/test_fem.py +1298 -1278
  272. warp/tests/test_fp16.py +128 -130
  273. warp/tests/test_func.py +336 -337
  274. warp/tests/test_generics.py +596 -571
  275. warp/tests/test_grad.py +885 -640
  276. warp/tests/test_grad_customs.py +331 -336
  277. warp/tests/test_hash_grid.py +208 -164
  278. warp/tests/test_import.py +37 -39
  279. warp/tests/test_indexedarray.py +1132 -1134
  280. warp/tests/test_intersect.py +65 -67
  281. warp/tests/test_jax.py +305 -307
  282. warp/tests/test_large.py +169 -164
  283. warp/tests/test_launch.py +352 -354
  284. warp/tests/test_lerp.py +217 -261
  285. warp/tests/test_linear_solvers.py +189 -171
  286. warp/tests/test_lvalue.py +419 -493
  287. warp/tests/test_marching_cubes.py +63 -65
  288. warp/tests/test_mat.py +1799 -1827
  289. warp/tests/test_mat_lite.py +113 -115
  290. warp/tests/test_mat_scalar_ops.py +2905 -2889
  291. warp/tests/test_math.py +124 -193
  292. warp/tests/test_matmul.py +498 -499
  293. warp/tests/test_matmul_lite.py +408 -410
  294. warp/tests/test_mempool.py +186 -190
  295. warp/tests/test_mesh.py +281 -324
  296. warp/tests/test_mesh_query_aabb.py +226 -241
  297. warp/tests/test_mesh_query_point.py +690 -702
  298. warp/tests/test_mesh_query_ray.py +290 -303
  299. warp/tests/test_mlp.py +274 -276
  300. warp/tests/test_model.py +108 -110
  301. warp/tests/test_module_hashing.py +111 -0
  302. warp/tests/test_modules_lite.py +36 -39
  303. warp/tests/test_multigpu.py +161 -163
  304. warp/tests/test_noise.py +244 -248
  305. warp/tests/test_operators.py +248 -250
  306. warp/tests/test_options.py +121 -125
  307. warp/tests/test_peer.py +131 -137
  308. warp/tests/test_pinned.py +76 -78
  309. warp/tests/test_print.py +52 -54
  310. warp/tests/test_quat.py +2084 -2086
  311. warp/tests/test_rand.py +324 -288
  312. warp/tests/test_reload.py +207 -217
  313. warp/tests/test_rounding.py +177 -179
  314. warp/tests/test_runlength_encode.py +188 -190
  315. warp/tests/test_sim_grad.py +241 -0
  316. warp/tests/test_sim_kinematics.py +89 -97
  317. warp/tests/test_smoothstep.py +166 -168
  318. warp/tests/test_snippet.py +303 -266
  319. warp/tests/test_sparse.py +466 -460
  320. warp/tests/test_spatial.py +2146 -2148
  321. warp/tests/test_special_values.py +362 -0
  322. warp/tests/test_streams.py +484 -473
  323. warp/tests/test_struct.py +708 -675
  324. warp/tests/test_tape.py +171 -148
  325. warp/tests/test_torch.py +741 -743
  326. warp/tests/test_transient_module.py +85 -87
  327. warp/tests/test_types.py +554 -659
  328. warp/tests/test_utils.py +488 -499
  329. warp/tests/test_vec.py +1262 -1268
  330. warp/tests/test_vec_lite.py +71 -73
  331. warp/tests/test_vec_scalar_ops.py +2097 -2099
  332. warp/tests/test_verify_fp.py +92 -94
  333. warp/tests/test_volume.py +961 -736
  334. warp/tests/test_volume_write.py +338 -265
  335. warp/tests/unittest_serial.py +38 -37
  336. warp/tests/unittest_suites.py +367 -359
  337. warp/tests/unittest_utils.py +434 -578
  338. warp/tests/unused_test_misc.py +69 -71
  339. warp/tests/walkthrough_debug.py +85 -85
  340. warp/thirdparty/appdirs.py +598 -598
  341. warp/thirdparty/dlpack.py +143 -143
  342. warp/thirdparty/unittest_parallel.py +563 -561
  343. warp/torch.py +321 -295
  344. warp/types.py +4941 -4450
  345. warp/utils.py +1008 -821
  346. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/LICENSE.md +126 -126
  347. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/METADATA +365 -400
  348. warp_lang-1.2.0.dist-info/RECORD +359 -0
  349. warp/examples/assets/cube.usda +0 -42
  350. warp/examples/assets/sphere.usda +0 -56
  351. warp/examples/assets/torus.usda +0 -105
  352. warp/examples/fem/example_convection_diffusion_dg0.py +0 -194
  353. warp/native/nanovdb/PNanoVDBWrite.h +0 -295
  354. warp_lang-1.0.2.dist-info/RECORD +0 -352
  355. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/WHEEL +0 -0
  356. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/top_level.txt +0 -0
@@ -1,1991 +1,1917 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
-
8
- import warp as wp
9
-
10
- from .model import Model, State, Control
11
-
12
- from .integrator import Integrator
13
-
14
- from .integrator_euler import (
15
- eval_spring_forces,
16
- eval_triangle_forces,
17
- eval_triangle_contact_forces,
18
- eval_bending_forces,
19
- eval_tetrahedral_forces,
20
- eval_particle_forces,
21
- eval_particle_ground_contact_forces,
22
- eval_particle_body_contact_forces,
23
- eval_muscle_forces,
24
- eval_rigid_contacts,
25
- eval_joint_force,
26
- )
27
-
28
- from .articulation import (
29
- compute_2d_rotational_dofs,
30
- compute_3d_rotational_dofs,
31
- )
32
-
33
-
34
# Frank & Park definition 3.20, pg 100
@wp.func
def transform_twist(t: wp.transform, x: wp.spatial_vector):
    """Express the spatial twist ``x`` in the frame described by ``t``.

    With ``t = (p, q)`` and ``R`` the rotation of ``q``, this applies the
    adjoint map ``[R 0; [p]R R]`` to the twist ``(w, v)`` (top, bottom):
    ``w' = R w`` and ``v' = R v + p x w'``.
    """
    rot = wp.transform_get_rotation(t)
    offset = wp.transform_get_translation(t)

    # the angular (top) part is only rotated
    ang = wp.quat_rotate(rot, wp.spatial_top(x))
    # the linear (bottom) part also picks up the offset crossed with the rotated angular part
    lin = wp.quat_rotate(rot, wp.spatial_bottom(x)) + wp.cross(offset, ang)

    return wp.spatial_vector(ang, lin)
47
-
48
-
49
@wp.func
def transform_wrench(t: wp.transform, x: wp.spatial_vector):
    """Express the spatial wrench ``x`` in the frame described by ``t``.

    This is the dual of ``transform_twist``. The bottom (force-like) part is
    rotated, ``f' = R f``. The top (moment-like) part is rotated and picks up
    the offset crossed with the rotated bottom part, ``m' = R m + p x f'``.
    """
    rot = wp.transform_get_rotation(t)
    offset = wp.transform_get_translation(t)

    bottom = wp.quat_rotate(rot, wp.spatial_bottom(x))
    top = wp.quat_rotate(rot, wp.spatial_top(x)) + wp.cross(offset, bottom)

    return wp.spatial_vector(top, bottom)
61
-
62
-
63
@wp.func
def spatial_adjoint(R: wp.mat33, S: wp.mat33):
    """Assemble the 6x6 block matrix ``[[R, 0], [S, R]]``.

    With ``S = skew(p) @ R`` (as ``spatial_transform_inertia`` passes it) this
    is the spatial adjoint of the transform with rotation ``R`` and
    translation ``p``.
    """
    zero = 0.0
    # fmt: off
    return wp.spatial_matrix(
        R[0, 0], R[0, 1], R[0, 2], zero,    zero,    zero,
        R[1, 0], R[1, 1], R[1, 2], zero,    zero,    zero,
        R[2, 0], R[2, 1], R[2, 2], zero,    zero,    zero,
        S[0, 0], S[0, 1], S[0, 2], R[0, 0], R[0, 1], R[0, 2],
        S[1, 0], S[1, 1], S[1, 2], R[1, 0], R[1, 1], R[1, 2],
        S[2, 0], S[2, 1], S[2, 2], R[2, 0], R[2, 1], R[2, 2],
    )
    # fmt: on
78
-
79
-
80
@wp.kernel
def compute_spatial_inertia(
    body_inertia: wp.array(dtype=wp.mat33),
    body_mass: wp.array(dtype=float),
    # outputs
    body_I_m: wp.array(dtype=wp.spatial_matrix),
):
    """Build each body's 6x6 spatial inertia ``[[I, 0], [0, m * Id3]]``.

    Each thread handles one entry. The 3x3 inertia fills the upper-left block
    and the mass fills the diagonal of the lower-right block.
    """
    body = wp.tid()
    inertia = body_inertia[body]
    mass = body_mass[body]
    zero = 0.0
    # fmt: off
    body_I_m[body] = wp.spatial_matrix(
        inertia[0, 0], inertia[0, 1], inertia[0, 2], zero, zero, zero,
        inertia[1, 0], inertia[1, 1], inertia[1, 2], zero, zero, zero,
        inertia[2, 0], inertia[2, 1], inertia[2, 2], zero, zero, zero,
        zero,          zero,          zero,          mass, zero, zero,
        zero,          zero,          zero,          zero, mass, zero,
        zero,          zero,          zero,          zero, zero, mass,
    )
    # fmt: on
100
-
101
-
102
@wp.kernel
def compute_com_transforms(
    body_com: wp.array(dtype=wp.vec3),
    # outputs
    body_X_com: wp.array(dtype=wp.transform),
):
    """Write, per entry, a translation-only transform to the offset ``body_com[i]``."""
    i = wp.tid()
    body_X_com[i] = wp.transform(body_com[i], wp.quat_identity())
111
-
112
-
113
# computes adj_t^-T*I*adj_t^-1 (tensor change of coordinates), Frank & Park, section 8.2.3, pg 290
@wp.func
def spatial_transform_inertia(t: wp.transform, I: wp.spatial_matrix):
    """Change the coordinates of the spatial inertia ``I`` by the transform ``t``.

    Builds the spatial adjoint ``T`` of ``t^-1`` and returns ``T^T I T``.
    """
    t_inv = wp.transform_inverse(t)
    rot = wp.transform_get_rotation(t_inv)
    offset = wp.transform_get_translation(t_inv)

    # R is assembled from the three unit basis vectors rotated by the inverse rotation
    R = wp.mat33(
        wp.quat_rotate(rot, wp.vec3(1.0, 0.0, 0.0)),
        wp.quat_rotate(rot, wp.vec3(0.0, 1.0, 0.0)),
        wp.quat_rotate(rot, wp.vec3(0.0, 0.0, 1.0)),
    )
    adj = spatial_adjoint(R, wp.skew(offset) @ R)

    # left-associative: (T^T @ I) @ T
    return wp.transpose(adj) @ I @ adj
131
-
132
-
133
# compute transform across a joint
@wp.func
def jcalc_transform(
    type: int,
    joint_axis: wp.array(dtype=wp.vec3),
    axis_start: int,
    lin_axis_count: int,
    ang_axis_count: int,
    joint_q: wp.array(dtype=float),
    start: int,
):
    """Return the joint transform ``X_jc`` built from one joint's coordinates.

    Args:
        type: joint type, compared against the ``wp.sim.JOINT_*`` constants.
        joint_axis: joint axis directions; this joint's axes begin at ``axis_start``.
        axis_start: index of this joint's first axis in ``joint_axis``.
        lin_axis_count: number of linear axes (read only by the ``JOINT_D6`` branch).
        ang_axis_count: number of angular axes (read only by the ``JOINT_D6`` branch).
        joint_q: generalized coordinates; this joint's coordinates begin at ``start``.
        start: index of this joint's first coordinate in ``joint_q``.

    Returns:
        The joint transform. For ``JOINT_FIXED`` and for any unrecognized
        joint type it is the identity.
    """
    if type == wp.sim.JOINT_PRISMATIC:
        # translation of q along the single axis, no rotation
        q = joint_q[start]
        axis = joint_axis[axis_start]
        X_jc = wp.transform(axis * q, wp.quat_identity())
        return X_jc

    if type == wp.sim.JOINT_REVOLUTE:
        # rotation by angle q about the single axis, no translation
        q = joint_q[start]
        axis = joint_axis[axis_start]
        X_jc = wp.transform(wp.vec3(), wp.quat_from_axis_angle(axis, q))
        return X_jc

    if type == wp.sim.JOINT_BALL:
        # four coordinates hold the rotation quaternion in (x, y, z, w) order
        qx = joint_q[start + 0]
        qy = joint_q[start + 1]
        qz = joint_q[start + 2]
        qw = joint_q[start + 3]

        X_jc = wp.transform(wp.vec3(), wp.quat(qx, qy, qz, qw))
        return X_jc

    if type == wp.sim.JOINT_FIXED:
        X_jc = wp.transform_identity()
        return X_jc

    if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
        # seven coordinates: position (3) followed by quaternion (x, y, z, w)
        px = joint_q[start + 0]
        py = joint_q[start + 1]
        pz = joint_q[start + 2]

        qx = joint_q[start + 3]
        qy = joint_q[start + 4]
        qz = joint_q[start + 5]
        qw = joint_q[start + 6]

        X_jc = wp.transform(wp.vec3(px, py, pz), wp.quat(qx, qy, qz, qw))
        return X_jc

    if type == wp.sim.JOINT_COMPOUND:
        # three rotational coordinates about three axes. Velocities are passed
        # as 0.0 and the velocity result is discarded; only the rotation is used.
        rot, _ = compute_3d_rotational_dofs(
            joint_axis[axis_start],
            joint_axis[axis_start + 1],
            joint_axis[axis_start + 2],
            joint_q[start + 0],
            joint_q[start + 1],
            joint_q[start + 2],
            0.0,
            0.0,
            0.0,
        )

        X_jc = wp.transform(wp.vec3(), rot)
        return X_jc

    if type == wp.sim.JOINT_UNIVERSAL:
        # two rotational coordinates about two axes; the velocity result is discarded
        rot, _ = compute_2d_rotational_dofs(
            joint_axis[axis_start],
            joint_axis[axis_start + 1],
            joint_q[start + 0],
            joint_q[start + 1],
            0.0,
            0.0,
        )

        X_jc = wp.transform(wp.vec3(), rot)
        return X_jc

    if type == wp.sim.JOINT_D6:
        # generic joint: up to 3 linear axes first, then up to 3 angular axes,
        # with coordinates laid out in the same order
        pos = wp.vec3(0.0)
        rot = wp.quat_identity()

        # unroll for loop to ensure joint actions remain differentiable
        # (since differentiating through a for loop that updates a local variable is not supported)

        # translation is the sum of each linear axis scaled by its coordinate
        if lin_axis_count > 0:
            axis = joint_axis[axis_start + 0]
            pos += axis * joint_q[start + 0]
        if lin_axis_count > 1:
            axis = joint_axis[axis_start + 1]
            pos += axis * joint_q[start + 1]
        if lin_axis_count > 2:
            axis = joint_axis[axis_start + 2]
            pos += axis * joint_q[start + 2]

        # angular axes and coordinates follow the linear ones
        ia = axis_start + lin_axis_count
        iq = start + lin_axis_count
        if ang_axis_count == 1:
            axis = joint_axis[ia]
            rot = wp.quat_from_axis_angle(axis, joint_q[iq])
        if ang_axis_count == 2:
            rot, _ = compute_2d_rotational_dofs(
                joint_axis[ia + 0],
                joint_axis[ia + 1],
                joint_q[iq + 0],
                joint_q[iq + 1],
                0.0,
                0.0,
            )
        if ang_axis_count == 3:
            rot, _ = compute_3d_rotational_dofs(
                joint_axis[ia + 0],
                joint_axis[ia + 1],
                joint_axis[ia + 2],
                joint_q[iq + 0],
                joint_q[iq + 1],
                joint_q[iq + 2],
                0.0,
                0.0,
                0.0,
            )

        X_jc = wp.transform(pos, rot)
        return X_jc

    # default case
    return wp.transform_identity()
260
-
261
-
262
# compute motion subspace and velocity for a joint
@wp.func
def jcalc_motion(
    type: int,
    joint_axis: wp.array(dtype=wp.vec3),
    axis_start: int,
    lin_axis_count: int,
    ang_axis_count: int,
    X_sc: wp.transform,
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    q_start: int,
    qd_start: int,
    # outputs
    joint_S_s: wp.array(dtype=wp.spatial_vector),
):
    """Write one joint's motion-subspace columns and return its joint twist.

    For each degree of freedom ``i`` the joint's twist direction is mapped
    through ``X_sc`` with ``transform_twist`` and stored in
    ``joint_S_s[qd_start + i]``. The returned twist is the sum of those
    columns weighted by ``joint_qd``. ``joint_q`` is read only by the
    universal and compound branches, where later axes are rotated by the
    earlier joint angles. ``X_sc`` is presumably the transform into the
    spatial/world frame (the ``_s`` suffix suggests so). TODO confirm.

    Unrecognized joint types print a message and return a zero twist without
    writing ``joint_S_s``. Fixed joints also return a zero twist and write nothing.
    """
    if type == wp.sim.JOINT_PRISMATIC:
        # pure linear motion along the axis
        axis = joint_axis[axis_start]
        S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
        v_j_s = S_s * joint_qd[qd_start]
        joint_S_s[qd_start] = S_s
        return v_j_s

    if type == wp.sim.JOINT_REVOLUTE:
        # pure angular motion about the axis
        axis = joint_axis[axis_start]
        S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
        v_j_s = S_s * joint_qd[qd_start]
        joint_S_s[qd_start] = S_s
        return v_j_s

    if type == wp.sim.JOINT_UNIVERSAL:
        # frame spanned by the two axes and their cross product
        axis_0 = joint_axis[axis_start + 0]
        axis_1 = joint_axis[axis_start + 1]
        q_off = wp.quat_from_matrix(wp.mat33(axis_0, axis_1, wp.cross(axis_0, axis_1)))
        local_0 = wp.quat_rotate(q_off, wp.vec3(1.0, 0.0, 0.0))
        local_1 = wp.quat_rotate(q_off, wp.vec3(0.0, 1.0, 0.0))

        axis_0 = local_0
        q_0 = wp.quat_from_axis_angle(axis_0, joint_q[q_start + 0])

        # the second axis is carried along by the rotation about the first
        axis_1 = wp.quat_rotate(q_0, local_1)

        S_0 = transform_twist(X_sc, wp.spatial_vector(axis_0, wp.vec3()))
        S_1 = transform_twist(X_sc, wp.spatial_vector(axis_1, wp.vec3()))

        joint_S_s[qd_start + 0] = S_0
        joint_S_s[qd_start + 1] = S_1

        return S_0 * joint_qd[qd_start + 0] + S_1 * joint_qd[qd_start + 1]

    if type == wp.sim.JOINT_COMPOUND:
        axis_0 = joint_axis[axis_start + 0]
        axis_1 = joint_axis[axis_start + 1]
        axis_2 = joint_axis[axis_start + 2]
        q_off = wp.quat_from_matrix(wp.mat33(axis_0, axis_1, axis_2))
        local_0 = wp.quat_rotate(q_off, wp.vec3(1.0, 0.0, 0.0))
        local_1 = wp.quat_rotate(q_off, wp.vec3(0.0, 1.0, 0.0))
        local_2 = wp.quat_rotate(q_off, wp.vec3(0.0, 0.0, 1.0))

        # each later axis is rotated by the accumulated earlier rotations
        axis_0 = local_0
        q_0 = wp.quat_from_axis_angle(axis_0, joint_q[q_start + 0])

        axis_1 = wp.quat_rotate(q_0, local_1)
        q_1 = wp.quat_from_axis_angle(axis_1, joint_q[q_start + 1])

        axis_2 = wp.quat_rotate(q_1 * q_0, local_2)

        S_0 = transform_twist(X_sc, wp.spatial_vector(axis_0, wp.vec3()))
        S_1 = transform_twist(X_sc, wp.spatial_vector(axis_1, wp.vec3()))
        S_2 = transform_twist(X_sc, wp.spatial_vector(axis_2, wp.vec3()))

        joint_S_s[qd_start + 0] = S_0
        joint_S_s[qd_start + 1] = S_1
        joint_S_s[qd_start + 2] = S_2

        return S_0 * joint_qd[qd_start + 0] + S_1 * joint_qd[qd_start + 1] + S_2 * joint_qd[qd_start + 2]

    if type == wp.sim.JOINT_D6:
        # linear axes first, then angular axes; unrolled like jcalc_transform
        # NOTE(review): the angular columns here use the raw joint axes. However,
        # jcalc_transform composes D6 rotations with compute_2d/3d_rotational_dofs
        # when ang_axis_count >= 2, and the universal/compound branches above rotate
        # later axes by earlier angles. Confirm whether D6 with multiple angular
        # axes is meant to differ.
        v_j_s = wp.spatial_vector()
        if lin_axis_count > 0:
            axis = joint_axis[axis_start + 0]
            S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
            v_j_s += S_s * joint_qd[qd_start + 0]
            joint_S_s[qd_start + 0] = S_s
        if lin_axis_count > 1:
            axis = joint_axis[axis_start + 1]
            S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
            v_j_s += S_s * joint_qd[qd_start + 1]
            joint_S_s[qd_start + 1] = S_s
        if lin_axis_count > 2:
            axis = joint_axis[axis_start + 2]
            S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
            v_j_s += S_s * joint_qd[qd_start + 2]
            joint_S_s[qd_start + 2] = S_s
        if ang_axis_count > 0:
            axis = joint_axis[axis_start + lin_axis_count + 0]
            S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
            v_j_s += S_s * joint_qd[qd_start + lin_axis_count + 0]
            joint_S_s[qd_start + lin_axis_count + 0] = S_s
        if ang_axis_count > 1:
            axis = joint_axis[axis_start + lin_axis_count + 1]
            S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
            v_j_s += S_s * joint_qd[qd_start + lin_axis_count + 1]
            joint_S_s[qd_start + lin_axis_count + 1] = S_s
        if ang_axis_count > 2:
            axis = joint_axis[axis_start + lin_axis_count + 2]
            S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
            v_j_s += S_s * joint_qd[qd_start + lin_axis_count + 2]
            joint_S_s[qd_start + lin_axis_count + 2] = S_s

        return v_j_s

    if type == wp.sim.JOINT_BALL:
        # three angular DOFs about the unit x, y, z directions
        S_0 = transform_twist(X_sc, wp.spatial_vector(1.0, 0.0, 0.0, 0.0, 0.0, 0.0))
        S_1 = transform_twist(X_sc, wp.spatial_vector(0.0, 1.0, 0.0, 0.0, 0.0, 0.0))
        S_2 = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 1.0, 0.0, 0.0, 0.0))

        joint_S_s[qd_start + 0] = S_0
        joint_S_s[qd_start + 1] = S_1
        joint_S_s[qd_start + 2] = S_2

        return S_0 * joint_qd[qd_start + 0] + S_1 * joint_qd[qd_start + 1] + S_2 * joint_qd[qd_start + 2]

    if type == wp.sim.JOINT_FIXED:
        return wp.spatial_vector()

    if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
        # six DOFs: the six joint velocities form the twist directly. Since
        # transform_twist is linear, this equals sum(S_i * qd_i) over the columns below.
        v_j_s = transform_twist(
            X_sc,
            wp.spatial_vector(
                joint_qd[qd_start + 0],
                joint_qd[qd_start + 1],
                joint_qd[qd_start + 2],
                joint_qd[qd_start + 3],
                joint_qd[qd_start + 4],
                joint_qd[qd_start + 5],
            ),
        )

        joint_S_s[qd_start + 0] = transform_twist(X_sc, wp.spatial_vector(1.0, 0.0, 0.0, 0.0, 0.0, 0.0))
        joint_S_s[qd_start + 1] = transform_twist(X_sc, wp.spatial_vector(0.0, 1.0, 0.0, 0.0, 0.0, 0.0))
        joint_S_s[qd_start + 2] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 1.0, 0.0, 0.0, 0.0))
        joint_S_s[qd_start + 3] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 1.0, 0.0, 0.0))
        joint_S_s[qd_start + 4] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 1.0, 0.0))
        joint_S_s[qd_start + 5] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 0.0, 1.0))

        return v_j_s

    wp.printf("jcalc_motion not implemented for joint type %d\n", type)

    # default case
    return wp.spatial_vector()
414
-
415
-
416
# computes joint space forces/torques in tau
@wp.func
def jcalc_tau(
    type: int,
    joint_target_ke: wp.array(dtype=float),
    joint_target_kd: wp.array(dtype=float),
    joint_limit_ke: wp.array(dtype=float),
    joint_limit_kd: wp.array(dtype=float),
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_act: wp.array(dtype=float),
    joint_axis_mode: wp.array(dtype=int),
    joint_limit_lower: wp.array(dtype=float),
    joint_limit_upper: wp.array(dtype=float),
    coord_start: int,
    dof_start: int,
    axis_start: int,
    lin_axis_count: int,
    ang_axis_count: int,
    body_f_s: wp.spatial_vector,
    # outputs
    tau: wp.array(dtype=float),
):
    """Write one joint's generalized forces into ``tau[dof_start + i]``.

    Every DOF gets ``-dot(S_i, body_f_s)``, where ``S_i`` is its motion-subspace
    column from ``joint_S_s``. Prismatic, revolute, compound, universal and D6
    joints also add ``eval_joint_force(...)`` (actuation, target and limit terms
    read per axis). Ball, free and distance joints get only the projection
    term. Fixed joints and unrecognized types write nothing.
    """
    if type == wp.sim.JOINT_PRISMATIC or type == wp.sim.JOINT_REVOLUTE:
        # single DOF: per-axis parameters are indexed by axis_start
        S_s = joint_S_s[dof_start]

        q = joint_q[coord_start]
        qd = joint_qd[dof_start]
        act = joint_act[axis_start]

        lower = joint_limit_lower[axis_start]
        upper = joint_limit_upper[axis_start]

        limit_ke = joint_limit_ke[axis_start]
        limit_kd = joint_limit_kd[axis_start]
        target_ke = joint_target_ke[axis_start]
        target_kd = joint_target_kd[axis_start]
        mode = joint_axis_mode[axis_start]

        # total torque / force on the joint
        t = -wp.dot(S_s, body_f_s) + eval_joint_force(
            q, qd, act, target_ke, target_kd, lower, upper, limit_ke, limit_kd, mode
        )

        tau[dof_start] = t

        return

    if type == wp.sim.JOINT_BALL:
        # NOTE: ball joints only project the body force onto their three angular
        # DOFs. No target, limit or actuation terms are applied here (an earlier
        # draft of a target_ke/target_kd drive was left disabled).
        for i in range(3):
            S_s = joint_S_s[dof_start + i]

            tau[dof_start + i] = -wp.dot(S_s, body_f_s)

        return

    if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
        # six DOFs, projection term only
        for i in range(6):
            S_s = joint_S_s[dof_start + i]
            tau[dof_start + i] = -wp.dot(S_s, body_f_s)

        return

    if type == wp.sim.JOINT_COMPOUND or type == wp.sim.JOINT_UNIVERSAL or type == wp.sim.JOINT_D6:
        axis_count = lin_axis_count + ang_axis_count

        # one coordinate, one DOF and one axis per i. For these joint types
        # jcalc_transform also reads joint_q one coordinate per axis.
        for i in range(axis_count):
            S_s = joint_S_s[dof_start + i]

            q = joint_q[coord_start + i]
            qd = joint_qd[dof_start + i]
            act = joint_act[axis_start + i]

            lower = joint_limit_lower[axis_start + i]
            upper = joint_limit_upper[axis_start + i]
            limit_ke = joint_limit_ke[axis_start + i]
            limit_kd = joint_limit_kd[axis_start + i]
            target_ke = joint_target_ke[axis_start + i]
            target_kd = joint_target_kd[axis_start + i]
            mode = joint_axis_mode[axis_start + i]

            f = eval_joint_force(q, qd, act, target_ke, target_kd, lower, upper, limit_ke, limit_kd, mode)

            # total torque / force on the joint
            t = -wp.dot(S_s, body_f_s) + f

            tau[dof_start + i] = t

        return
512
-
513
-
514
@wp.func
def jcalc_integrate(
    type: int,
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_qdd: wp.array(dtype=float),
    coord_start: int,
    dof_start: int,
    lin_axis_count: int,
    ang_axis_count: int,
    dt: float,
    # outputs
    joint_q_new: wp.array(dtype=float),
    joint_qd_new: wp.array(dtype=float),
):
    """Advance one joint's coordinates and velocities by a symplectic (semi-implicit) Euler step.

    Velocities are updated first from ``joint_qdd``, and the updated velocities then advance the
    coordinates. Quaternion coordinates (ball, free and distance joints) are renormalized after
    the update. Fixed joints, and joint types not handled below, write nothing.
    """
    if type == wp.sim.JOINT_FIXED:
        return

    # prismatic / revolute: one scalar coordinate and one dof
    if type == wp.sim.JOINT_PRISMATIC or type == wp.sim.JOINT_REVOLUTE:
        qd_single = joint_qd[dof_start] + joint_qdd[dof_start] * dt
        joint_qd_new[dof_start] = qd_single
        joint_q_new[coord_start] = joint_q[coord_start] + qd_single * dt
        return

    # ball: dofs are an angular velocity (3), coords a rotation quaternion (4)
    if type == wp.sim.JOINT_BALL:
        w_ball = wp.vec3(joint_qd[dof_start + 0], joint_qd[dof_start + 1], joint_qd[dof_start + 2])
        dw_ball = wp.vec3(joint_qdd[dof_start + 0], joint_qdd[dof_start + 1], joint_qdd[dof_start + 2])
        r_ball = wp.quat(
            joint_q[coord_start + 0], joint_q[coord_start + 1], joint_q[coord_start + 2], joint_q[coord_start + 3]
        )

        w_ball_next = w_ball + dw_ball * dt
        drdt_ball = wp.quat(w_ball_next, 0.0) * r_ball * 0.5
        r_ball_next = wp.normalize(r_ball + drdt_ball * dt)

        for k in range(4):
            joint_q_new[coord_start + k] = r_ball_next[k]
        for k in range(3):
            joint_qd_new[dof_start + k] = w_ball_next[k]
        return

    # free / distance joint
    if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
        # dofs:   qd = (omega_x, omega_y, omega_z, vel_x, vel_y, vel_z)
        # coords: q  = (trans_x, trans_y, trans_z, quat_x, quat_y, quat_z, quat_w)
        w_free = wp.vec3(joint_qd[dof_start + 0], joint_qd[dof_start + 1], joint_qd[dof_start + 2])
        v_free = wp.vec3(joint_qd[dof_start + 3], joint_qd[dof_start + 4], joint_qd[dof_start + 5])
        dw_free = wp.vec3(joint_qdd[dof_start + 0], joint_qdd[dof_start + 1], joint_qdd[dof_start + 2])
        dv_free = wp.vec3(joint_qdd[dof_start + 3], joint_qdd[dof_start + 4], joint_qdd[dof_start + 5])

        # velocities first (symplectic Euler)
        w_free_next = w_free + dw_free * dt
        v_free_next = v_free + dv_free * dt

        p_free = wp.vec3(joint_q[coord_start + 0], joint_q[coord_start + 1], joint_q[coord_start + 2])
        r_free = wp.quat(
            joint_q[coord_start + 3], joint_q[coord_start + 4], joint_q[coord_start + 5], joint_q[coord_start + 6]
        )

        # the twist (w, v) is expressed in the space frame; the translation p therefore moves
        # with velocity v + w x p (note q/qd store linear/angular parts in opposite order)
        p_free_next = p_free + (v_free_next + wp.cross(w_free_next, p_free)) * dt

        # new orientation (normalized)
        drdt_free = wp.quat(w_free_next, 0.0) * r_free * 0.5
        r_free_next = wp.normalize(r_free + drdt_free * dt)

        for k in range(3):
            joint_q_new[coord_start + k] = p_free_next[k]
        for k in range(4):
            joint_q_new[coord_start + 3 + k] = r_free_next[k]
        for k in range(3):
            joint_qd_new[dof_start + k] = w_free_next[k]
        for k in range(3):
            joint_qd_new[dof_start + 3 + k] = v_free_next[k]
        return

    # compound, universal, D6: one scalar coordinate and one dof per axis
    if type == wp.sim.JOINT_COMPOUND or type == wp.sim.JOINT_UNIVERSAL or type == wp.sim.JOINT_D6:
        for k in range(lin_axis_count + ang_axis_count):
            qd_axis = joint_qd[dof_start + k] + joint_qdd[dof_start + k] * dt
            joint_qd_new[dof_start + k] = qd_axis
            joint_q_new[coord_start + k] = joint_q[coord_start + k] + qd_axis * dt
        return
647
-
648
-
649
@wp.func
def compute_link_transform(
    i: int,
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    body_X_com: wp.array(dtype=wp.transform),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    # outputs
    body_q: wp.array(dtype=wp.transform),
    body_q_com: wp.array(dtype=wp.transform),
):
    """Forward kinematics for joint ``i``.

    The child body's world transform is composed from:
      1. the parent body's world transform ``body_q[parent]``, skipped when ``parent < 0``;
      2. the joint's parent anchor frame ``joint_X_p[i]``;
      3. the joint transform returned by ``jcalc_transform``;
      4. the inverse of the child anchor frame ``joint_X_c[i]``.

    The result is written to ``body_q[child]``. The child's center-of-mass frame
    (``body_q[child] * body_X_com[child]``) is written to ``body_q_com[child]``.
    """
    child = joint_child[i]
    parent = joint_parent[i]

    # parent anchor frame expressed in world space
    X_world_anchor = joint_X_p[i]
    if parent >= 0:
        X_world_anchor = body_q[parent] * X_world_anchor

    X_joint = jcalc_transform(
        joint_type[i],
        joint_axis,
        joint_axis_start[i],
        joint_axis_dim[i, 0],
        joint_axis_dim[i, 1],
        joint_q,
        joint_q_start[i],
    )

    # world -> child anchor -> child body frame
    X_world_child = X_world_anchor * X_joint * wp.transform_inverse(joint_X_c[i])

    body_q[child] = X_world_child
    body_q_com[child] = X_world_child * body_X_com[child]
701
-
702
-
703
@wp.kernel
def eval_rigid_fk(
    articulation_start: wp.array(dtype=int),
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    body_X_com: wp.array(dtype=wp.transform),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    # outputs
    body_q: wp.array(dtype=wp.transform),
    body_q_com: wp.array(dtype=wp.transform),
):
    """Forward kinematics with one thread per articulation.

    The joints ``articulation_start[tid] .. articulation_start[tid + 1] - 1`` are processed in
    increasing index order with ``compute_link_transform``. That routine reads ``body_q[parent]``,
    so this assumes a parent's joint precedes its children within the articulation.
    """
    articulation = wp.tid()
    joint_begin = articulation_start[articulation]
    joint_end = articulation_start[articulation + 1]

    for joint in range(joint_begin, joint_end):
        compute_link_transform(
            joint,
            joint_type,
            joint_parent,
            joint_child,
            joint_q_start,
            joint_q,
            joint_X_p,
            joint_X_c,
            body_X_com,
            joint_axis,
            joint_axis_start,
            joint_axis_dim,
            body_q,
            body_q_com,
        )
744
-
745
-
746
@wp.func
def spatial_cross(a: wp.spatial_vector, b: wp.spatial_vector):
    """Spatial motion cross product ``a x b``.

    With ``a = (w_a, v_a)`` and ``b = (w_b, v_b)`` split into (top, bottom) 3-vectors, returns
    ``(w_a x w_b, w_a x v_b + v_a x w_b)``.
    """
    top_a = wp.spatial_top(a)
    top_b = wp.spatial_top(b)
    bottom_a = wp.spatial_bottom(a)
    bottom_b = wp.spatial_bottom(b)

    angular = wp.cross(top_a, top_b)
    linear = wp.cross(top_a, bottom_b) + wp.cross(bottom_a, top_b)
    return wp.spatial_vector(angular, linear)
758
-
759
-
760
@wp.func
def spatial_cross_dual(a: wp.spatial_vector, b: wp.spatial_vector):
    """Spatial force cross product ``a x* b`` (motion vector ``a`` acting on force-like ``b``).

    With ``a = (w_a, v_a)`` and ``b = (w_b, v_b)`` split into (top, bottom) 3-vectors, returns
    ``(w_a x w_b + v_a x v_b, w_a x v_b)``.
    """
    top_a = wp.spatial_top(a)
    top_b = wp.spatial_top(b)
    bottom_a = wp.spatial_bottom(a)
    bottom_b = wp.spatial_bottom(b)

    moment = wp.cross(top_a, top_b) + wp.cross(bottom_a, bottom_b)
    force = wp.cross(top_a, bottom_b)
    return wp.spatial_vector(moment, force)
772
-
773
-
774
@wp.func
def dense_index(stride: int, i: int, j: int):
    """Flat offset of element (row ``i``, column ``j``) in a row-major matrix with ``stride`` columns."""
    return j + stride * i
777
-
778
-
779
@wp.func
def compute_link_velocity(
    i: int,
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    body_I_m: wp.array(dtype=wp.spatial_matrix),
    body_q: wp.array(dtype=wp.transform),
    body_q_com: wp.array(dtype=wp.transform),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    gravity: wp.vec3,
    # outputs
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    body_I_s: wp.array(dtype=wp.spatial_matrix),
    body_v_s: wp.array(dtype=wp.spatial_vector),
    body_f_s: wp.array(dtype=wp.spatial_vector),
    body_a_s: wp.array(dtype=wp.spatial_vector),
):
    """Root-to-leaf step of the recursive Newton-Euler algorithm for joint ``i``.

    For the joint's child body this computes, in spatial (world) coordinates:
      - the joint motion subspace, written to ``joint_S_s`` by ``jcalc_motion``;
      - the body velocity ``v = v_parent + v_J``, written to ``body_v_s``;
      - the velocity-product acceleration ``a = a_parent + v x v_J``, written to ``body_a_s``;
      - the spatial inertia, written to ``body_I_s``;
      - the bias force ``I a + v x* (I v)`` minus the gravity wrench, written to ``body_f_s``.

    Parent quantities are taken as zero when ``parent < 0``. ``joint_X_c`` is kept in the
    signature for interface compatibility but is not needed here.
    """
    type = joint_type[i]
    child = joint_child[i]
    parent = joint_parent[i]
    q_start = joint_q_start[i]
    qd_start = joint_qd_start[i]

    # parent anchor frame in world space
    X_wpj = joint_X_p[i]
    if parent >= 0:
        X_wp = body_q[parent]
        X_wpj = X_wp * X_wpj

    # motion subspace and velocity across the joint (jcalc_motion also stores S_s to global memory)
    axis_start = joint_axis_start[i]
    lin_axis_count = joint_axis_dim[i, 0]
    ang_axis_count = joint_axis_dim[i, 1]
    v_j_s = jcalc_motion(
        type,
        joint_axis,
        axis_start,
        lin_axis_count,
        ang_axis_count,
        X_wpj,
        joint_q,
        joint_qd,
        q_start,
        qd_start,
        joint_S_s,
    )

    # parent velocity / acceleration (zero for bodies attached to the world)
    v_parent_s = wp.spatial_vector()
    a_parent_s = wp.spatial_vector()

    if parent >= 0:
        v_parent_s = body_v_s[parent]
        a_parent_s = body_a_s[parent]

    # body velocity and velocity-product acceleration; the S * qdd term is not included here
    v_s = v_parent_s + v_j_s
    a_s = a_parent_s + spatial_cross(v_s, v_j_s)

    # body inertia at its center of mass and the COM frame in world space
    X_sm = body_q_com[child]
    I_m = body_I_m[child]

    # gravity wrench about the world origin
    # NOTE(review): assumes I_m[3, 3] holds the body mass (lower-right block = m * identity)
    m = I_m[3, 3]

    f_g = m * gravity
    r_com = wp.transform_get_translation(X_sm)
    f_g_s = wp.spatial_vector(wp.cross(r_com, f_g), f_g)

    # spatial inertia in world coordinates and the resulting body bias force
    I_s = spatial_transform_inertia(X_sm, I_m)

    f_b_s = I_s * a_s + spatial_cross_dual(v_s, I_s * v_s)

    body_v_s[child] = v_s
    body_a_s[child] = a_s
    body_f_s[child] = f_b_s - f_g_s
    body_I_s[child] = I_s
870
-
871
-
872
# Inverse dynamics via Recursive Newton-Euler algorithm (Featherstone Table 5.1)
@wp.kernel
def eval_rigid_id(
    articulation_start: wp.array(dtype=int),
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    body_I_m: wp.array(dtype=wp.spatial_matrix),
    body_q: wp.array(dtype=wp.transform),
    body_q_com: wp.array(dtype=wp.transform),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    gravity: wp.vec3,
    # outputs
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    body_I_s: wp.array(dtype=wp.spatial_matrix),
    body_v_s: wp.array(dtype=wp.spatial_vector),
    body_f_s: wp.array(dtype=wp.spatial_vector),
    body_a_s: wp.array(dtype=wp.spatial_vector),
):
    """Root-to-leaf sweep of RNEA with one thread per articulation.

    For every joint of the articulation, in increasing index order, ``compute_link_velocity``
    writes the motion subspace, spatial inertia, velocity, velocity-product acceleration and
    bias force of the joint's child body.
    """
    articulation = wp.tid()
    joint_begin = articulation_start[articulation]
    joint_end = articulation_start[articulation + 1]

    for joint in range(joint_begin, joint_end):
        compute_link_velocity(
            joint,
            joint_type,
            joint_parent,
            joint_child,
            joint_q_start,
            joint_qd_start,
            joint_q,
            joint_qd,
            joint_axis,
            joint_axis_start,
            joint_axis_dim,
            body_I_m,
            body_q,
            body_q_com,
            joint_X_p,
            joint_X_c,
            gravity,
            joint_S_s,
            body_I_s,
            body_v_s,
            body_f_s,
            body_a_s,
        )
931
-
932
-
933
@wp.kernel
def eval_rigid_tau(
    articulation_start: wp.array(dtype=int),
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    joint_axis_mode: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_act: wp.array(dtype=float),
    joint_target_ke: wp.array(dtype=float),
    joint_target_kd: wp.array(dtype=float),
    joint_limit_lower: wp.array(dtype=float),
    joint_limit_upper: wp.array(dtype=float),
    joint_limit_ke: wp.array(dtype=float),
    joint_limit_kd: wp.array(dtype=float),
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    body_fb_s: wp.array(dtype=wp.spatial_vector),
    body_f_ext: wp.array(dtype=wp.spatial_vector),
    # outputs
    body_ft_s: wp.array(dtype=wp.spatial_vector),
    tau: wp.array(dtype=float),
):
    """Leaf-to-root sweep of RNEA with one thread per articulation.

    Joints are visited from the last index of the articulation down to the first. For each joint:
      - the child's total spatial force is ``body_fb_s + body_ft_s + body_f_ext``;
      - ``jcalc_tau`` projects that force into generalized forces in ``tau``;
      - the same force is atomically added to the parent's ``body_ft_s`` entry when
        ``parent >= 0``.
    """
    articulation = wp.tid()
    joint_begin = articulation_start[articulation]
    joint_end = articulation_start[articulation + 1]

    # reverse traversal so a child's accumulated force is complete before it reaches its
    # parent (presumes parents have lower joint indices than their children)
    for k in range(joint_end - joint_begin):
        j = joint_end - 1 - k

        child = joint_child[j]
        parent = joint_parent[j]

        # total spatial force on the child body
        f_total = body_fb_s[child] + body_ft_s[child] + body_f_ext[child]

        # joint-space forces, written to tau
        jcalc_tau(
            joint_type[j],
            joint_target_ke,
            joint_target_kd,
            joint_limit_ke,
            joint_limit_kd,
            joint_S_s,
            joint_q,
            joint_qd,
            joint_act,
            joint_axis_mode,
            joint_limit_lower,
            joint_limit_upper,
            joint_q_start[j],
            joint_qd_start[j],
            joint_axis_start[j],
            joint_axis_dim[j, 0],
            joint_axis_dim[j, 1],
            f_total,
            tau,
        )

        # propagate the force to the parent body
        if parent >= 0:
            wp.atomic_add(body_ft_s, parent, f_total)
1013
-
1014
-
1015
# builds spatial Jacobian J which is an (joint_count*6)x(dof_count) matrix
@wp.kernel
def eval_rigid_jacobian(
    articulation_start: wp.array(dtype=int),
    articulation_J_start: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    # outputs
    J: wp.array(dtype=float),
):
    """Assemble one articulation's dense spatial Jacobian per thread.

    Layout of the articulation's block in ``J``:
      - it starts at ``articulation_J_start[tid]`` and is row-major;
      - it has ``6 * joint_count`` rows and ``dof_count`` columns.

    Row block ``r`` (rows ``6r .. 6r + 5``) belongs to the articulation's ``r``-th joint. It is
    filled with the ``joint_S_s`` columns of that joint and of every joint reached by following
    ``joint_parent`` until ``-1``. Entries that are not on such a chain are not written by this
    kernel.

    NOTE(review): ``joint_parent`` holds parent *body* indices elsewhere in this file, and here they
    are used as joint indices. This assumes body ``k`` is the child of joint ``k``; confirm.
    """
    tid = wp.tid()

    joint_first = articulation_start[tid]
    joint_last = articulation_start[tid + 1]
    J_base = articulation_J_start[tid]

    dof_first = joint_qd_start[joint_first]
    dof_total = joint_qd_start[joint_last] - dof_first

    for r in range(joint_last - joint_first):
        row_base = 6 * r

        # walk from this joint up to the root, filling one column per dof on the way
        ancestor = joint_first + r
        while ancestor != -1:
            dof_begin = joint_qd_start[ancestor]
            dof_span = joint_qd_start[ancestor + 1] - dof_begin

            for d in range(dof_span):
                col = dof_begin - dof_first + d
                S = joint_S_s[dof_begin + d]
                for c in range(6):
                    J[J_base + dense_index(dof_total, row_base + c, col)] = S[c]

            ancestor = joint_parent[ancestor]
1057
-
1058
-
1059
@wp.func
def spatial_mass(
    body_I_s: wp.array(dtype=wp.spatial_matrix),
    joint_start: int,
    joint_count: int,
    M_start: int,
    # outputs
    M: wp.array(dtype=float),
):
    """Write an articulation's block-diagonal spatial mass matrix into ``M``.

    The block is row-major and ``6 * joint_count`` square, and it starts at ``M_start``. The
    diagonal 6x6 block ``k`` is ``body_I_s[joint_start + k]``. Off-diagonal blocks are not written.
    NOTE(review): uses joint indices as body indices (assumes body ``k`` is the child of joint ``k``).
    """
    size = 6 * joint_count
    for body in range(joint_count):
        I_body = body_I_s[joint_start + body]
        base = 6 * body
        for r in range(6):
            for c in range(6):
                M[M_start + dense_index(size, base + r, base + c)] = I_body[r, c]
1074
-
1075
-
1076
@wp.kernel
def eval_rigid_mass(
    articulation_start: wp.array(dtype=int),
    articulation_M_start: wp.array(dtype=int),
    body_I_s: wp.array(dtype=wp.spatial_matrix),
    # outputs
    M: wp.array(dtype=float),
):
    """Fill each articulation's block-diagonal spatial mass matrix (one thread per articulation)."""
    tid = wp.tid()

    first_joint = articulation_start[tid]
    n_joints = articulation_start[tid + 1] - first_joint

    spatial_mass(body_I_s, first_joint, n_joints, articulation_M_start[tid], M)
1094
-
1095
-
1096
@wp.func
def dense_gemm(
    m: int,
    n: int,
    p: int,
    transpose_A: bool,
    transpose_B: bool,
    add_to_C: bool,
    A_start: int,
    B_start: int,
    C_start: int,
    A: wp.array(dtype=float),
    B: wp.array(dtype=float),
    # outputs
    C: wp.array(dtype=float),
):
    """Row-major dense product ``C = op(A) @ op(B)``, or ``C += op(A) @ op(B)`` when ``add_to_C``.

    ``op(A)`` is ``m x p`` and ``op(B)`` is ``p x n``:
      - with ``transpose_A`` the block of ``A`` at ``A_start`` is stored as ``p x m``;
      - with ``transpose_B`` the block of ``B`` at ``B_start`` is stored as ``n x p``.

    ``C`` is an ``m x n`` block at ``C_start``.
    """
    for row in range(m):
        for col in range(n):
            acc = float(0.0)
            for k in range(p):
                a_idx = row * p + k
                if transpose_A:
                    a_idx = k * m + row
                b_idx = k * n + col
                if transpose_B:
                    b_idx = col * p + k
                acc += A[A_start + a_idx] * B[B_start + b_idx]

            c_idx = C_start + row * n + col
            if add_to_C:
                C[c_idx] += acc
            else:
                C[c_idx] = acc
1131
-
1132
-
1133
@wp.func_grad(dense_gemm)
def adj_dense_gemm(
    m: int,
    n: int,
    p: int,
    transpose_A: bool,
    transpose_B: bool,
    add_to_C: bool,
    A_start: int,
    B_start: int,
    C_start: int,
    A: wp.array(dtype=float),
    B: wp.array(dtype=float),
    # outputs
    C: wp.array(dtype=float),
):
    """Custom adjoint of :func:`dense_gemm`.

    With ``op(A)`` (``m x p``), ``op(B)`` (``p x n``) and ``C = op(A) @ op(B)``::

        adj(op(A)) = adj(C) @ op(B)^T      (m x p)
        adj(op(B)) = op(A)^T @ adj(C)      (p x n)

    When an operand is stored transposed, its gradient is written transposed as well, so each
    gradient matches the storage layout of its operand. Every operand is addressed with its own
    start offset: ``adj(C)`` at ``C_start``, ``A``/``adj(A)`` at ``A_start`` and ``B``/``adj(B)``
    at ``B_start``. Gradients are always accumulated into ``wp.adjoint[A]`` and ``wp.adjoint[B]``.
    """
    accumulate = True

    # gradient w.r.t. A
    if transpose_A:
        # A stored p x m: adj(A) = (adj(C) @ op(B)^T)^T = op(B) @ adj(C)^T
        dense_gemm(
            p, m, n, transpose_B, True, accumulate, B_start, C_start, A_start, B, wp.adjoint[C], wp.adjoint[A]
        )
    else:
        # A stored m x p: adj(A) = adj(C) @ op(B)^T
        dense_gemm(
            m, p, n, False, not transpose_B, accumulate, C_start, B_start, A_start, wp.adjoint[C], B, wp.adjoint[A]
        )

    # gradient w.r.t. B
    if transpose_B:
        # B stored n x p: adj(B) = (op(A)^T @ adj(C))^T = adj(C)^T @ op(A)
        dense_gemm(
            n, p, m, True, transpose_A, accumulate, C_start, A_start, B_start, wp.adjoint[C], A, wp.adjoint[B]
        )
    else:
        # B stored p x n: adj(B) = op(A)^T @ adj(C)
        dense_gemm(
            p, n, m, not transpose_A, False, accumulate, A_start, C_start, B_start, A, wp.adjoint[C], wp.adjoint[B]
        )
1158
-
1159
-
1160
@wp.kernel
def eval_dense_gemm_batched(
    m: wp.array(dtype=int),
    n: wp.array(dtype=int),
    p: wp.array(dtype=int),
    transpose_A: bool,
    transpose_B: bool,
    A_start: wp.array(dtype=int),
    B_start: wp.array(dtype=int),
    C_start: wp.array(dtype=int),
    A: wp.array(dtype=float),
    B: wp.array(dtype=float),
    C: wp.array(dtype=float),
):
    """Batched dense product: thread ``b`` computes ``C_b = op(A_b) @ op(B_b)``.

    Each thread performs the whole multiply for its batch entry. Sizes and start offsets are
    read per batch from ``m``, ``n``, ``p``, ``A_start``, ``B_start`` and ``C_start``. The output
    block of ``C`` is overwritten, not accumulated.
    """
    b = wp.tid()
    accumulate = False

    dense_gemm(
        m[b],
        n[b],
        p[b],
        transpose_A,
        transpose_B,
        accumulate,
        A_start[b],
        B_start[b],
        C_start[b],
        A,
        B,
        C,
    )
1193
-
1194
-
1195
@wp.func
def dense_cholesky(
    n: int,
    A: wp.array(dtype=float),
    R: wp.array(dtype=float),
    A_start: int,
    R_start: int,
    # outputs
    L: wp.array(dtype=float),
):
    """Cholesky factorization ``A + diag(R) = L L^T`` of one dense ``n x n`` block.

    ``A`` and ``L`` share the same row-major layout and the same start offset ``A_start``. ``R``
    holds ``n`` diagonal regularization values starting at ``R_start``. Only the lower triangle of
    ``L``, diagonal included, is written. There is no pivoting and no check for a non-positive
    pivot.
    """
    for col in range(n):
        # diagonal entry
        diag = A[A_start + dense_index(n, col, col)] + R[R_start + col]
        for k in range(col):
            l_ck = L[A_start + dense_index(n, col, k)]
            diag -= l_ck * l_ck

        diag = wp.sqrt(diag)
        inv_diag = 1.0 / diag
        L[A_start + dense_index(n, col, col)] = diag

        # entries below the diagonal in this column
        for row in range(col + 1, n):
            off = A[A_start + dense_index(n, row, col)]
            for k in range(col):
                off -= L[A_start + dense_index(n, row, k)] * L[A_start + dense_index(n, col, k)]
            L[A_start + dense_index(n, row, col)] = off * inv_diag
1225
-
1226
-
1227
@wp.func_grad(dense_cholesky)
def adj_dense_cholesky(
    n: int,
    A: wp.array(dtype=float),
    R: wp.array(dtype=float),
    A_start: int,
    R_start: int,
    # outputs
    L: wp.array(dtype=float),
):
    """Custom adjoint of :func:`dense_cholesky`; intentionally empty.

    Gradients are meant to flow through ``adj_dense_solve`` (the ``(A^-1) b = x`` step) instead,
    so nothing is propagated to ``A`` or ``R`` here.
    """
    pass
1239
-
1240
-
1241
@wp.kernel
def eval_dense_cholesky_batched(
    A_starts: wp.array(dtype=int),
    A_dim: wp.array(dtype=int),
    A: wp.array(dtype=float),
    R: wp.array(dtype=float),
    L: wp.array(dtype=float),
):
    """Factorize one regularized dense block per thread: ``A_b + diag(R_b) = L_b L_b^T``.

    Block ``b`` is ``A_dim[b] x A_dim[b]`` and starts at ``A_starts[b]`` in both ``A`` and ``L``.
    Its diagonal regularization is read from ``R`` starting at the sum of the dimensions of all
    preceding blocks, so ``R`` is expected to hold the per-block diagonals back to back.
    """
    batch = wp.tid()

    n = A_dim[batch]
    A_start = A_starts[batch]

    # Offset of this block's diagonal in R. Blocks may differ in size, so use the prefix sum of
    # the preceding block dimensions. The previous `n * batch` was only correct when every block
    # had the same size, and it gives the same value in that case. Cost is O(batch) integer
    # reads per thread.
    R_start = int(0)
    for b in range(batch):
        R_start += A_dim[b]

    dense_cholesky(n, A, R, A_start, R_start, L)
1256
-
1257
-
1258
@wp.func
def dense_subs(
    n: int,
    L_start: int,
    b_start: int,
    L: wp.array(dtype=float),
    b: wp.array(dtype=float),
    # outputs
    x: wp.array(dtype=float),
):
    """Solve ``(L L^T) x = b`` given the lower-triangular Cholesky factor ``L``.

    The solve runs in two passes:
      1. forward substitution solves ``L y = b`` and stores ``y`` in ``x``;
      2. back substitution solves ``L^T x = y`` in place.

    ``b`` and ``x`` share the offset ``b_start``. ``L`` is an ``n x n`` row-major block at
    ``L_start``.
    """
    # forward substitution: L y = b
    for r in range(n):
        acc = b[b_start + r]
        for c in range(r):
            acc -= L[L_start + dense_index(n, r, c)] * x[b_start + c]
        x[b_start + r] = acc / L[L_start + dense_index(n, r, r)]

    # backward substitution: L^T x = y, reading L^T[r, c] as L[c, r]
    for r in range(n - 1, -1, -1):
        rem = x[b_start + r]
        for c in range(r + 1, n):
            rem -= L[L_start + dense_index(n, c, r)] * x[b_start + c]
        x[b_start + r] = rem / L[L_start + dense_index(n, r, r)]
1286
-
1287
-
1288
@wp.func
def dense_solve(
    n: int,
    L_start: int,
    b_start: int,
    L: wp.array(dtype=float),
    b: wp.array(dtype=float),
    # outputs
    x: wp.array(dtype=float),
    tmp: wp.array(dtype=float),
):
    """Solve ``(L L^T) x = b`` by delegating to :func:`dense_subs`.

    ``tmp`` is unused in the forward pass. It is part of the signature so that the custom adjoint
    ``adj_dense_solve`` has scratch storage laid out like ``b``.
    """
    dense_subs(n, L_start, b_start, L, b, x)
1301
-
1302
-
1303
@wp.func_grad(dense_solve)
def adj_dense_solve(
    n: int,
    L_start: int,
    b_start: int,
    L: wp.array(dtype=float),
    b: wp.array(dtype=float),
    # outputs
    x: wp.array(dtype=float),
    tmp: wp.array(dtype=float),
):
    """Custom adjoint of :func:`dense_solve`.

    Steps:
      1. solve ``(L L^T) tmp = adj(x)``, using ``tmp`` as scratch;
      2. accumulate ``tmp`` into ``adj(b)``;
      3. add ``-tmp x^T`` into ``adj(L)``.

    NOTE(review): ``-tmp x^T`` is the gradient w.r.t. the product ``A = L L^T``, not w.r.t. the
    factor ``L``. Because ``adj_dense_cholesky`` is a no-op, it is not propagated any further.
    """
    for r in range(n):
        tmp[b_start + r] = 0.0

    dense_subs(n, L_start, b_start, L, wp.adjoint[x], tmp)

    for r in range(n):
        g = tmp[b_start + r]
        wp.adjoint[b][b_start + r] += g
        # A* = -adj_b * x^T
        for c in range(n):
            wp.adjoint[L][L_start + dense_index(n, r, c)] += -g * x[b_start + c]
1326
-
1327
-
1328
@wp.kernel
def eval_dense_solve_batched(
    L_start: wp.array(dtype=int),
    L_dim: wp.array(dtype=int),
    b_start: wp.array(dtype=int),
    L: wp.array(dtype=float),
    b: wp.array(dtype=float),
    x: wp.array(dtype=float),
    tmp: wp.array(dtype=float),
):
    """Solve one Cholesky-factored system ``(L_k L_k^T) x_k = b_k`` per thread ``k``."""
    k = wp.tid()
    n = L_dim[k]

    dense_solve(n, L_start[k], b_start[k], L, b, x, tmp)
1341
-
1342
-
1343
@wp.kernel
def integrate_generalized_joints(
    joint_type: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_qdd: wp.array(dtype=float),
    dt: float,
    # outputs
    joint_q_new: wp.array(dtype=float),
    joint_qd_new: wp.array(dtype=float),
):
    """Integrate each joint's generalized state with ``jcalc_integrate``.

    Launched with one thread per joint: ``wp.tid()`` indexes ``joint_type``.
    """
    j = wp.tid()

    jcalc_integrate(
        joint_type[j],
        joint_q,
        joint_qd,
        joint_qdd,
        joint_q_start[j],
        joint_qd_start[j],
        joint_axis_dim[j, 0],
        joint_axis_dim[j, 1],
        dt,
        joint_q_new,
        joint_qd_new,
    )
1379
-
1380
-
1381
@wp.kernel
def eval_body_inertial_velocities(
    body_q: wp.array(dtype=wp.transform),
    body_v_s: wp.array(dtype=wp.spatial_vector),
    # outputs
    body_qd: wp.array(dtype=wp.spatial_vector),
):
    """Convert space-frame spatial twists into per-body (angular, linear) velocities.

    One thread per body. The angular part is copied unchanged. The linear part becomes
    ``v + w x p``, where ``p`` is the translation of ``body_q``, i.e. the twist's linear velocity
    re-evaluated at the body frame origin.
    """
    body = wp.tid()

    twist = body_v_s[body]
    omega = wp.spatial_top(twist)
    origin = wp.transform_get_translation(body_q[body])

    body_qd[body] = wp.spatial_vector(omega, wp.spatial_bottom(twist) + wp.cross(omega, origin))
1398
-
1399
-
1400
- class FeatherstoneIntegrator(Integrator):
1401
- """A semi-implicit integrator using symplectic Euler that operates
1402
- on reduced (also called generalized) coordinates to simulate articulated rigid body dynamics
1403
- based on Featherstone's composite rigid body algorithm (CRBA).
1404
-
1405
- See: Featherstone, Roy. Rigid Body Dynamics Algorithms. Springer US, 2014.
1406
-
1407
- Instead of maximal coordinates :attr:`State.body_q` (rigid body positions) and :attr:`State.body_qd`
1408
- (rigid body velocities) as is the case with :class:`SemiImplicitIntegrator`, :class:`FeatherstoneIntegrator`
1409
- uses :attr:`State.joint_q` and :attr:`State.joint_qd` to represent the positions and velocities of
1410
- joints without allowing any redundant degrees of freedom.
1411
-
1412
- After constructing :class:`Model` and :class:`State` objects this time-integrator
1413
- may be used to advance the simulation state forward in time.
1414
-
1415
- Note:
1416
- Unlike :class:`SemiImplicitIntegrator` and :class:`XPBDIntegrator`, :class:`FeatherstoneIntegrator` does not simulate rigid bodies with nonzero mass as floating bodies if they are not connected through any joints. Floating-base systems require an explicit free joint with which the body is connected to the world, see :meth:`ModelBuilder.add_joint_free`.
1417
-
1418
- Semi-implicit time integration is a variational integrator that
1419
- preserves energy, however it is not unconditionally stable, and requires a time-step
1420
- small enough to support the required stiffness and damping forces.
1421
-
1422
- See: https://en.wikipedia.org/wiki/Semi-implicit_Euler_method
1423
-
1424
- Example
1425
- -------
1426
-
1427
- .. code-block:: python
1428
-
1429
- integrator = wp.sim.FeatherstoneIntegrator(model)
1430
-
1431
- # simulation loop
1432
- for i in range(100):
1433
- state = integrator.simulate(model, state_in, state_out, dt)
1434
-
1435
- Note:
1436
- The :class:`FeatherstoneIntegrator` requires the :class:`Model` to be passed in as a constructor argument.
1437
-
1438
- """
1439
-
1440
def __init__(self, model, angular_damping=0.05, update_mass_matrix_every=1):
    """
    Args:
        model (Model): the model to be simulated.
        angular_damping (float, optional): Angular damping factor. Defaults to 0.05.
        update_mass_matrix_every (int, optional): How often to update the mass matrix (every n-th time the :meth:`simulate` function gets called). Defaults to 1.
    """
    # user-facing settings
    self.update_mass_matrix_every = update_mass_matrix_every
    self.angular_damping = angular_damping

    # matrix sizes/offsets must exist before the buffers that depend on them are allocated
    self.compute_articulation_indices(model)
    self.allocate_model_aux_vars(model)

    # number of simulate() calls so far (presumably used with update_mass_matrix_every)
    self._step = 0
1452
-
1453
def compute_articulation_indices(self, model):
    """Compute per-articulation offsets and sizes of the dense system matrices.

    Each articulation has ``joint_count`` joints and ``dof_count`` degrees of freedom. Its block
    sizes are:
      - Jacobian: ``(6 * joint_count) x dof_count``;
      - spatial mass matrix: ``(6 * joint_count) x (6 * joint_count)``;
      - joint-space matrix: ``dof_count x dof_count``.

    Running totals are stored in ``self.J_size``, ``self.M_size`` and ``self.H_size``. The
    per-articulation start offsets, row/column counts and first dof/coordinate indices are uploaded
    as ``int32`` arrays on ``model.device``. Nothing is done when the model has no joints.
    """
    if not model.joint_count:
        return

    self.J_size = 0
    self.M_size = 0
    self.H_size = 0

    art_start = model.articulation_start.numpy()
    q_start = model.joint_q_start.numpy()
    qd_start = model.joint_qd_start.numpy()

    J_starts, M_starts, H_starts = [], [], []
    M_rows, H_rows, J_rows, J_cols = [], [], [], []
    dof_starts, coord_starts = [], []

    for a in range(model.articulation_count):
        joint_begin, joint_end = art_start[a], art_start[a + 1]
        dof_begin = qd_start[joint_begin]

        n_joints = joint_end - joint_begin
        n_dofs = qd_start[joint_end] - dof_begin

        # offsets of this articulation's blocks in the flat system buffers
        J_starts.append(self.J_size)
        M_starts.append(self.M_size)
        H_starts.append(self.H_size)
        dof_starts.append(dof_begin)
        coord_starts.append(q_start[joint_begin])

        # row/column counts (partly redundant, kept separate for readability)
        M_rows.append(n_joints * 6)
        H_rows.append(n_dofs)
        J_rows.append(n_joints * 6)
        J_cols.append(n_dofs)

        self.J_size += 6 * n_joints * n_dofs
        self.M_size += 6 * n_joints * 6 * n_joints
        self.H_size += n_dofs * n_dofs

    device = model.device

    def as_int32(values):
        # upload a list of indices/sizes as an int32 array on the model's device
        return wp.array(values, dtype=wp.int32, device=device)

    # matrix offsets for batched gemm
    self.articulation_J_start = as_int32(J_starts)
    self.articulation_M_start = as_int32(M_starts)
    self.articulation_H_start = as_int32(H_starts)

    self.articulation_M_rows = as_int32(M_rows)
    self.articulation_H_rows = as_int32(H_rows)
    self.articulation_J_rows = as_int32(J_rows)
    self.articulation_J_cols = as_int32(J_cols)

    self.articulation_dof_start = as_int32(dof_starts)
    self.articulation_coord_start = as_int32(coord_starts)
1516
-
1517
- def allocate_model_aux_vars(self, model):
1518
- # allocate mass, Jacobian matrices, and other auxiliary variables pertaining to the model
1519
- if model.joint_count:
1520
- # system matrices
1521
- self.M = wp.zeros((self.M_size,), dtype=wp.float32, device=model.device, requires_grad=model.requires_grad)
1522
- self.J = wp.zeros((self.J_size,), dtype=wp.float32, device=model.device, requires_grad=model.requires_grad)
1523
- self.P = wp.empty_like(self.J, requires_grad=model.requires_grad)
1524
- self.H = wp.empty((self.H_size,), dtype=wp.float32, device=model.device, requires_grad=model.requires_grad)
1525
-
1526
- # zero since only upper triangle is set which can trigger NaN detection
1527
- self.L = wp.zeros_like(self.H)
1528
-
1529
- if model.body_count:
1530
- # TODO use requires_grad here?
1531
- self.body_I_m = wp.empty((model.body_count,), dtype=wp.spatial_matrix, device=model.device)
1532
- wp.launch(
1533
- compute_spatial_inertia,
1534
- model.body_count,
1535
- inputs=[model.body_inertia, model.body_mass],
1536
- outputs=[self.body_I_m],
1537
- device=model.device,
1538
- )
1539
- self.body_X_com = wp.empty((model.body_count,), dtype=wp.transform, device=model.device)
1540
- wp.launch(
1541
- compute_com_transforms,
1542
- model.body_count,
1543
- inputs=[model.body_com],
1544
- outputs=[self.body_X_com],
1545
- device=model.device,
1546
- )
1547
-
1548
- def allocate_state_aux_vars(self, model, target, requires_grad):
1549
- # allocate auxiliary variables that vary with state
1550
- if model.body_count:
1551
- # joints
1552
- target.joint_qdd = wp.zeros_like(model.joint_qd, requires_grad=requires_grad)
1553
- target.joint_tau = wp.empty_like(model.joint_qd, requires_grad=requires_grad)
1554
- if requires_grad:
1555
- # used in the custom grad implementation of eval_dense_solve_batched
1556
- target.joint_solve_tmp = wp.zeros_like(model.joint_qd, requires_grad=True)
1557
- else:
1558
- target.joint_solve_tmp = None
1559
- target.joint_S_s = wp.empty(
1560
- (model.joint_dof_count,),
1561
- dtype=wp.spatial_vector,
1562
- device=model.device,
1563
- requires_grad=requires_grad,
1564
- )
1565
-
1566
- # derived rigid body data (maximal coordinates)
1567
- target.body_q_com = wp.empty_like(model.body_q, requires_grad=requires_grad)
1568
- target.body_I_s = wp.empty(
1569
- (model.body_count,), dtype=wp.spatial_matrix, device=model.device, requires_grad=requires_grad
1570
- )
1571
- target.body_v_s = wp.empty(
1572
- (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad
1573
- )
1574
- target.body_a_s = wp.empty(
1575
- (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad
1576
- )
1577
- target.body_f_s = wp.zeros(
1578
- (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad
1579
- )
1580
- target.body_ft_s = wp.zeros(
1581
- (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad
1582
- )
1583
-
1584
- target._featherstone_augmented = True
1585
-
1586
- def simulate(self, model: Model, state_in: State, state_out: State, dt: float, control: Control = None):
1587
- requires_grad = state_in.requires_grad
1588
-
1589
- # optionally create dynamical auxiliary variables
1590
- if requires_grad:
1591
- state_aug = state_out
1592
- else:
1593
- state_aug = self
1594
-
1595
- if not getattr(state_aug, "_featherstone_augmented", False):
1596
- self.allocate_state_aux_vars(model, state_aug, requires_grad)
1597
- if control is None:
1598
- control = model.control(clone_variables=False)
1599
-
1600
- with wp.ScopedTimer("simulate", False):
1601
- particle_f = None
1602
- body_f = None
1603
-
1604
- if state_in.particle_count:
1605
- particle_f = state_in.particle_f
1606
-
1607
- if state_in.body_count:
1608
- body_f = state_in.body_f
1609
-
1610
- # damped springs
1611
- eval_spring_forces(model, state_in, particle_f)
1612
-
1613
- # triangle elastic and lift/drag forces
1614
- eval_triangle_forces(model, state_in, control, particle_f)
1615
-
1616
- # triangle/triangle contacts
1617
- eval_triangle_contact_forces(model, state_in, particle_f)
1618
-
1619
- # triangle bending
1620
- eval_bending_forces(model, state_in, particle_f)
1621
-
1622
- # tetrahedral FEM
1623
- eval_tetrahedral_forces(model, state_in, control, particle_f)
1624
-
1625
- # particle-particle interactions
1626
- eval_particle_forces(model, state_in, particle_f)
1627
-
1628
- # particle ground contacts
1629
- eval_particle_ground_contact_forces(model, state_in, particle_f)
1630
-
1631
- # particle shape contact
1632
- eval_particle_body_contact_forces(model, state_in, particle_f, body_f)
1633
-
1634
- # muscles
1635
- if False:
1636
- eval_muscle_forces(model, state_in, control, body_f)
1637
-
1638
- # ----------------------------
1639
- # articulations
1640
-
1641
- if model.joint_count:
1642
-
1643
- # evaluate body transforms
1644
- wp.launch(
1645
- eval_rigid_fk,
1646
- dim=model.articulation_count,
1647
- inputs=[
1648
- model.articulation_start,
1649
- model.joint_type,
1650
- model.joint_parent,
1651
- model.joint_child,
1652
- model.joint_q_start,
1653
- state_in.joint_q,
1654
- model.joint_X_p,
1655
- model.joint_X_c,
1656
- self.body_X_com,
1657
- model.joint_axis,
1658
- model.joint_axis_start,
1659
- model.joint_axis_dim,
1660
- ],
1661
- outputs=[state_in.body_q, state_aug.body_q_com],
1662
- device=model.device,
1663
- )
1664
-
1665
- # print("body_X_sc:")
1666
- # print(state_in.body_q.numpy())
1667
-
1668
- # evaluate joint inertias, motion vectors, and forces
1669
- state_aug.body_f_s.zero_()
1670
- wp.launch(
1671
- eval_rigid_id,
1672
- dim=model.articulation_count,
1673
- inputs=[
1674
- model.articulation_start,
1675
- model.joint_type,
1676
- model.joint_parent,
1677
- model.joint_child,
1678
- model.joint_q_start,
1679
- model.joint_qd_start,
1680
- state_in.joint_q,
1681
- state_in.joint_qd,
1682
- model.joint_axis,
1683
- model.joint_axis_start,
1684
- model.joint_axis_dim,
1685
- self.body_I_m,
1686
- state_in.body_q,
1687
- state_aug.body_q_com,
1688
- model.joint_X_p,
1689
- model.joint_X_c,
1690
- model.gravity,
1691
- ],
1692
- outputs=[
1693
- state_aug.joint_S_s,
1694
- state_aug.body_I_s,
1695
- state_aug.body_v_s,
1696
- state_aug.body_f_s,
1697
- state_aug.body_a_s,
1698
- ],
1699
- device=model.device,
1700
- )
1701
-
1702
- if model.rigid_contact_max and (
1703
- model.ground and model.shape_ground_contact_pair_count or model.shape_contact_pair_count
1704
- ):
1705
- wp.launch(
1706
- kernel=eval_rigid_contacts,
1707
- dim=model.rigid_contact_max,
1708
- inputs=[
1709
- state_in.body_q,
1710
- state_aug.body_v_s,
1711
- model.body_com,
1712
- model.shape_materials,
1713
- model.shape_geo,
1714
- model.shape_body,
1715
- model.rigid_contact_count,
1716
- model.rigid_contact_point0,
1717
- model.rigid_contact_point1,
1718
- model.rigid_contact_normal,
1719
- model.rigid_contact_shape0,
1720
- model.rigid_contact_shape1,
1721
- True,
1722
- ],
1723
- outputs=[body_f],
1724
- device=model.device,
1725
- )
1726
-
1727
- # if model.rigid_contact_count.numpy()[0] > 0:
1728
- # print(body_f.numpy())
1729
-
1730
- if model.articulation_count:
1731
- # evaluate joint torques
1732
- state_aug.body_ft_s.zero_()
1733
- wp.launch(
1734
- eval_rigid_tau,
1735
- dim=model.articulation_count,
1736
- inputs=[
1737
- model.articulation_start,
1738
- model.joint_type,
1739
- model.joint_parent,
1740
- model.joint_child,
1741
- model.joint_q_start,
1742
- model.joint_qd_start,
1743
- model.joint_axis_start,
1744
- model.joint_axis_dim,
1745
- model.joint_axis_mode,
1746
- state_in.joint_q,
1747
- state_in.joint_qd,
1748
- control.joint_act,
1749
- model.joint_target_ke,
1750
- model.joint_target_kd,
1751
- model.joint_limit_lower,
1752
- model.joint_limit_upper,
1753
- model.joint_limit_ke,
1754
- model.joint_limit_kd,
1755
- state_aug.joint_S_s,
1756
- state_aug.body_f_s,
1757
- body_f,
1758
- ],
1759
- outputs=[
1760
- state_aug.body_ft_s,
1761
- state_aug.joint_tau,
1762
- ],
1763
- device=model.device,
1764
- )
1765
-
1766
- # print("joint_tau:")
1767
- # print(state_aug.joint_tau.numpy())
1768
- # print("body_q:")
1769
- # print(state_in.body_q.numpy())
1770
- # print("body_qd:")
1771
- # print(state_in.body_qd.numpy())
1772
-
1773
- if self._step % self.update_mass_matrix_every == 0:
1774
- # build J
1775
- wp.launch(
1776
- eval_rigid_jacobian,
1777
- dim=model.articulation_count,
1778
- inputs=[
1779
- model.articulation_start,
1780
- self.articulation_J_start,
1781
- model.joint_parent,
1782
- model.joint_qd_start,
1783
- state_aug.joint_S_s,
1784
- ],
1785
- outputs=[self.J],
1786
- device=model.device,
1787
- )
1788
-
1789
- # build M
1790
- wp.launch(
1791
- eval_rigid_mass,
1792
- dim=model.articulation_count,
1793
- inputs=[
1794
- model.articulation_start,
1795
- self.articulation_M_start,
1796
- state_aug.body_I_s,
1797
- ],
1798
- outputs=[self.M],
1799
- device=model.device,
1800
- )
1801
-
1802
- # form P = M*J
1803
- wp.launch(
1804
- eval_dense_gemm_batched,
1805
- dim=model.articulation_count,
1806
- inputs=[
1807
- self.articulation_M_rows,
1808
- self.articulation_J_cols,
1809
- self.articulation_J_rows,
1810
- False,
1811
- False,
1812
- self.articulation_M_start,
1813
- self.articulation_J_start,
1814
- # P start is the same as J start since it has the same dims as J
1815
- self.articulation_J_start,
1816
- self.M,
1817
- self.J,
1818
- ],
1819
- outputs=[self.P],
1820
- device=model.device,
1821
- )
1822
-
1823
- # form H = J^T*P
1824
- wp.launch(
1825
- eval_dense_gemm_batched,
1826
- dim=model.articulation_count,
1827
- inputs=[
1828
- self.articulation_J_cols,
1829
- self.articulation_J_cols,
1830
- # P rows is the same as J rows
1831
- self.articulation_J_rows,
1832
- True,
1833
- False,
1834
- self.articulation_J_start,
1835
- # P start is the same as J start since it has the same dims as J
1836
- self.articulation_J_start,
1837
- self.articulation_H_start,
1838
- self.J,
1839
- self.P,
1840
- ],
1841
- outputs=[self.H],
1842
- device=model.device,
1843
- )
1844
-
1845
- # compute decomposition
1846
- wp.launch(
1847
- eval_dense_cholesky_batched,
1848
- dim=model.articulation_count,
1849
- inputs=[
1850
- self.articulation_H_start,
1851
- self.articulation_H_rows,
1852
- self.H,
1853
- model.joint_armature,
1854
- ],
1855
- outputs=[self.L],
1856
- device=model.device,
1857
- )
1858
-
1859
- # print("joint_act:")
1860
- # print(control.joint_act.numpy())
1861
- # print("joint_tau:")
1862
- # print(state_aug.joint_tau.numpy())
1863
- # print("H:")
1864
- # print(self.H.numpy())
1865
- # print("L:")
1866
- # print(self.L.numpy())
1867
-
1868
- # solve for qdd
1869
- state_aug.joint_qdd.zero_()
1870
- wp.launch(
1871
- eval_dense_solve_batched,
1872
- dim=model.articulation_count,
1873
- inputs=[
1874
- self.articulation_H_start,
1875
- self.articulation_H_rows,
1876
- self.articulation_dof_start,
1877
- self.L,
1878
- state_aug.joint_tau,
1879
- ],
1880
- outputs=[
1881
- state_aug.joint_qdd,
1882
- state_aug.joint_solve_tmp,
1883
- ],
1884
- device=model.device,
1885
- )
1886
- # if wp.context.runtime.tape:
1887
- # wp.context.runtime.tape.record_func(
1888
- # backward=lambda: adj_matmul(
1889
- # a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith, device
1890
- # ),
1891
- # arrays=[a, b, c, d],
1892
- # )
1893
- # print("joint_qdd:")
1894
- # print(state_aug.joint_qdd.numpy())
1895
- # print("\n\n")
1896
-
1897
- # -------------------------------------
1898
- # integrate bodies
1899
-
1900
- if model.joint_count:
1901
- wp.launch(
1902
- kernel=integrate_generalized_joints,
1903
- dim=model.joint_count,
1904
- inputs=[
1905
- model.joint_type,
1906
- model.joint_q_start,
1907
- model.joint_qd_start,
1908
- model.joint_axis_dim,
1909
- state_in.joint_q,
1910
- state_in.joint_qd,
1911
- state_aug.joint_qdd,
1912
- dt,
1913
- ],
1914
- outputs=[state_out.joint_q, state_out.joint_qd],
1915
- device=model.device,
1916
- )
1917
-
1918
- wp.launch(
1919
- eval_rigid_fk,
1920
- dim=model.articulation_count,
1921
- inputs=[
1922
- model.articulation_start,
1923
- model.joint_type,
1924
- model.joint_parent,
1925
- model.joint_child,
1926
- model.joint_q_start,
1927
- state_out.joint_q,
1928
- model.joint_X_p,
1929
- model.joint_X_c,
1930
- self.body_X_com,
1931
- model.joint_axis,
1932
- model.joint_axis_start,
1933
- model.joint_axis_dim,
1934
- ],
1935
- outputs=[state_out.body_q, state_aug.body_q_com],
1936
- device=model.device,
1937
- )
1938
-
1939
- # compute body_qd
1940
- state_aug.body_f_s.zero_()
1941
- wp.launch(
1942
- eval_rigid_id,
1943
- dim=model.articulation_count,
1944
- inputs=[
1945
- model.articulation_start,
1946
- model.joint_type,
1947
- model.joint_parent,
1948
- model.joint_child,
1949
- model.joint_q_start,
1950
- model.joint_qd_start,
1951
- state_out.joint_q,
1952
- state_out.joint_qd,
1953
- model.joint_axis,
1954
- model.joint_axis_start,
1955
- model.joint_axis_dim,
1956
- self.body_I_m,
1957
- state_out.body_q,
1958
- state_aug.body_q_com,
1959
- model.joint_X_p,
1960
- model.joint_X_c,
1961
- model.gravity,
1962
- ],
1963
- outputs=[
1964
- state_aug.joint_S_s,
1965
- state_aug.body_I_s,
1966
- state_aug.body_v_s,
1967
- state_aug.body_f_s,
1968
- state_aug.body_a_s,
1969
- ],
1970
- device=model.device,
1971
- )
1972
-
1973
- # body velocity in inertial frame
1974
- wp.launch(
1975
- kernel=eval_body_inertial_velocities,
1976
- dim=model.body_count,
1977
- inputs=[
1978
- state_out.body_q,
1979
- state_aug.body_v_s,
1980
- ],
1981
- outputs=[
1982
- state_out.body_qd,
1983
- ],
1984
- device=model.device,
1985
- )
1986
-
1987
- self.integrate_particles(model, state_in, state_out, dt)
1988
-
1989
- self._step += 1
1990
-
1991
- return state_out
1
+ # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import warp as wp
9
+
10
+ from .articulation import (
11
+ compute_2d_rotational_dofs,
12
+ compute_3d_rotational_dofs,
13
+ eval_fk,
14
+ )
15
+ from .integrator import Integrator
16
+ from .integrator_euler import (
17
+ eval_bending_forces,
18
+ eval_joint_force,
19
+ eval_muscle_forces,
20
+ eval_particle_body_contact_forces,
21
+ eval_particle_forces,
22
+ eval_particle_ground_contact_forces,
23
+ eval_rigid_contacts,
24
+ eval_spring_forces,
25
+ eval_tetrahedral_forces,
26
+ eval_triangle_contact_forces,
27
+ eval_triangle_forces,
28
+ )
29
+ from .model import Control, Model, State
30
+
31
+
32
+ # Frank & Park definition 3.20, pg 100
33
+ @wp.func
34
+ def transform_twist(t: wp.transform, x: wp.spatial_vector):
35
+ q = wp.transform_get_rotation(t)
36
+ p = wp.transform_get_translation(t)
37
+
38
+ w = wp.spatial_top(x)
39
+ v = wp.spatial_bottom(x)
40
+
41
+ w = wp.quat_rotate(q, w)
42
+ v = wp.quat_rotate(q, v) + wp.cross(p, w)
43
+
44
+ return wp.spatial_vector(w, v)
45
+
46
+
47
+ @wp.func
48
+ def transform_wrench(t: wp.transform, x: wp.spatial_vector):
49
+ q = wp.transform_get_rotation(t)
50
+ p = wp.transform_get_translation(t)
51
+
52
+ w = wp.spatial_top(x)
53
+ v = wp.spatial_bottom(x)
54
+
55
+ v = wp.quat_rotate(q, v)
56
+ w = wp.quat_rotate(q, w) + wp.cross(p, v)
57
+
58
+ return wp.spatial_vector(w, v)
59
+
60
+
61
+ @wp.func
62
+ def spatial_adjoint(R: wp.mat33, S: wp.mat33):
63
+ # T = [R 0]
64
+ # [S R]
65
+
66
+ # fmt: off
67
+ return wp.spatial_matrix(
68
+ R[0, 0], R[0, 1], R[0, 2], 0.0, 0.0, 0.0,
69
+ R[1, 0], R[1, 1], R[1, 2], 0.0, 0.0, 0.0,
70
+ R[2, 0], R[2, 1], R[2, 2], 0.0, 0.0, 0.0,
71
+ S[0, 0], S[0, 1], S[0, 2], R[0, 0], R[0, 1], R[0, 2],
72
+ S[1, 0], S[1, 1], S[1, 2], R[1, 0], R[1, 1], R[1, 2],
73
+ S[2, 0], S[2, 1], S[2, 2], R[2, 0], R[2, 1], R[2, 2],
74
+ )
75
+ # fmt: on
76
+
77
+
78
+ @wp.kernel
79
+ def compute_spatial_inertia(
80
+ body_inertia: wp.array(dtype=wp.mat33),
81
+ body_mass: wp.array(dtype=float),
82
+ # outputs
83
+ body_I_m: wp.array(dtype=wp.spatial_matrix),
84
+ ):
85
+ tid = wp.tid()
86
+ I = body_inertia[tid]
87
+ m = body_mass[tid]
88
+ # fmt: off
89
+ body_I_m[tid] = wp.spatial_matrix(
90
+ I[0, 0], I[0, 1], I[0, 2], 0.0, 0.0, 0.0,
91
+ I[1, 0], I[1, 1], I[1, 2], 0.0, 0.0, 0.0,
92
+ I[2, 0], I[2, 1], I[2, 2], 0.0, 0.0, 0.0,
93
+ 0.0, 0.0, 0.0, m, 0.0, 0.0,
94
+ 0.0, 0.0, 0.0, 0.0, m, 0.0,
95
+ 0.0, 0.0, 0.0, 0.0, 0.0, m,
96
+ )
97
+ # fmt: on
98
+
99
+
100
+ @wp.kernel
101
+ def compute_com_transforms(
102
+ body_com: wp.array(dtype=wp.vec3),
103
+ # outputs
104
+ body_X_com: wp.array(dtype=wp.transform),
105
+ ):
106
+ tid = wp.tid()
107
+ com = body_com[tid]
108
+ body_X_com[tid] = wp.transform(com, wp.quat_identity())
109
+
110
+
111
+ # computes adj_t^-T*I*adj_t^-1 (tensor change of coordinates), Frank & Park, section 8.2.3, pg 290
112
+ @wp.func
113
+ def spatial_transform_inertia(t: wp.transform, I: wp.spatial_matrix):
114
+ t_inv = wp.transform_inverse(t)
115
+
116
+ q = wp.transform_get_rotation(t_inv)
117
+ p = wp.transform_get_translation(t_inv)
118
+
119
+ r1 = wp.quat_rotate(q, wp.vec3(1.0, 0.0, 0.0))
120
+ r2 = wp.quat_rotate(q, wp.vec3(0.0, 1.0, 0.0))
121
+ r3 = wp.quat_rotate(q, wp.vec3(0.0, 0.0, 1.0))
122
+
123
+ R = wp.mat33(r1, r2, r3)
124
+ S = wp.skew(p) @ R
125
+
126
+ T = spatial_adjoint(R, S)
127
+
128
+ return wp.mul(wp.mul(wp.transpose(T), I), T)
129
+
130
+
131
+ # compute transform across a joint
132
+ @wp.func
133
+ def jcalc_transform(
134
+ type: int,
135
+ joint_axis: wp.array(dtype=wp.vec3),
136
+ axis_start: int,
137
+ lin_axis_count: int,
138
+ ang_axis_count: int,
139
+ joint_q: wp.array(dtype=float),
140
+ start: int,
141
+ ):
142
+ if type == wp.sim.JOINT_PRISMATIC:
143
+ q = joint_q[start]
144
+ axis = joint_axis[axis_start]
145
+ X_jc = wp.transform(axis * q, wp.quat_identity())
146
+ return X_jc
147
+
148
+ if type == wp.sim.JOINT_REVOLUTE:
149
+ q = joint_q[start]
150
+ axis = joint_axis[axis_start]
151
+ X_jc = wp.transform(wp.vec3(), wp.quat_from_axis_angle(axis, q))
152
+ return X_jc
153
+
154
+ if type == wp.sim.JOINT_BALL:
155
+ qx = joint_q[start + 0]
156
+ qy = joint_q[start + 1]
157
+ qz = joint_q[start + 2]
158
+ qw = joint_q[start + 3]
159
+
160
+ X_jc = wp.transform(wp.vec3(), wp.quat(qx, qy, qz, qw))
161
+ return X_jc
162
+
163
+ if type == wp.sim.JOINT_FIXED:
164
+ X_jc = wp.transform_identity()
165
+ return X_jc
166
+
167
+ if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
168
+ px = joint_q[start + 0]
169
+ py = joint_q[start + 1]
170
+ pz = joint_q[start + 2]
171
+
172
+ qx = joint_q[start + 3]
173
+ qy = joint_q[start + 4]
174
+ qz = joint_q[start + 5]
175
+ qw = joint_q[start + 6]
176
+
177
+ X_jc = wp.transform(wp.vec3(px, py, pz), wp.quat(qx, qy, qz, qw))
178
+ return X_jc
179
+
180
+ if type == wp.sim.JOINT_COMPOUND:
181
+ rot, _ = compute_3d_rotational_dofs(
182
+ joint_axis[axis_start],
183
+ joint_axis[axis_start + 1],
184
+ joint_axis[axis_start + 2],
185
+ joint_q[start + 0],
186
+ joint_q[start + 1],
187
+ joint_q[start + 2],
188
+ 0.0,
189
+ 0.0,
190
+ 0.0,
191
+ )
192
+
193
+ X_jc = wp.transform(wp.vec3(), rot)
194
+ return X_jc
195
+
196
+ if type == wp.sim.JOINT_UNIVERSAL:
197
+ rot, _ = compute_2d_rotational_dofs(
198
+ joint_axis[axis_start],
199
+ joint_axis[axis_start + 1],
200
+ joint_q[start + 0],
201
+ joint_q[start + 1],
202
+ 0.0,
203
+ 0.0,
204
+ )
205
+
206
+ X_jc = wp.transform(wp.vec3(), rot)
207
+ return X_jc
208
+
209
+ if type == wp.sim.JOINT_D6:
210
+ pos = wp.vec3(0.0)
211
+ rot = wp.quat_identity()
212
+
213
+ # unroll for loop to ensure joint actions remain differentiable
214
+ # (since differentiating through a for loop that updates a local variable is not supported)
215
+
216
+ if lin_axis_count > 0:
217
+ axis = joint_axis[axis_start + 0]
218
+ pos += axis * joint_q[start + 0]
219
+ if lin_axis_count > 1:
220
+ axis = joint_axis[axis_start + 1]
221
+ pos += axis * joint_q[start + 1]
222
+ if lin_axis_count > 2:
223
+ axis = joint_axis[axis_start + 2]
224
+ pos += axis * joint_q[start + 2]
225
+
226
+ ia = axis_start + lin_axis_count
227
+ iq = start + lin_axis_count
228
+ if ang_axis_count == 1:
229
+ axis = joint_axis[ia]
230
+ rot = wp.quat_from_axis_angle(axis, joint_q[iq])
231
+ if ang_axis_count == 2:
232
+ rot, _ = compute_2d_rotational_dofs(
233
+ joint_axis[ia + 0],
234
+ joint_axis[ia + 1],
235
+ joint_q[iq + 0],
236
+ joint_q[iq + 1],
237
+ 0.0,
238
+ 0.0,
239
+ )
240
+ if ang_axis_count == 3:
241
+ rot, _ = compute_3d_rotational_dofs(
242
+ joint_axis[ia + 0],
243
+ joint_axis[ia + 1],
244
+ joint_axis[ia + 2],
245
+ joint_q[iq + 0],
246
+ joint_q[iq + 1],
247
+ joint_q[iq + 2],
248
+ 0.0,
249
+ 0.0,
250
+ 0.0,
251
+ )
252
+
253
+ X_jc = wp.transform(pos, rot)
254
+ return X_jc
255
+
256
+ # default case
257
+ return wp.transform_identity()
258
+
259
+
260
+ # compute motion subspace and velocity for a joint
261
+ @wp.func
262
+ def jcalc_motion(
263
+ type: int,
264
+ joint_axis: wp.array(dtype=wp.vec3),
265
+ axis_start: int,
266
+ lin_axis_count: int,
267
+ ang_axis_count: int,
268
+ X_sc: wp.transform,
269
+ joint_q: wp.array(dtype=float),
270
+ joint_qd: wp.array(dtype=float),
271
+ q_start: int,
272
+ qd_start: int,
273
+ # outputs
274
+ joint_S_s: wp.array(dtype=wp.spatial_vector),
275
+ ):
276
+ if type == wp.sim.JOINT_PRISMATIC:
277
+ axis = joint_axis[axis_start]
278
+ S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
279
+ v_j_s = S_s * joint_qd[qd_start]
280
+ joint_S_s[qd_start] = S_s
281
+ return v_j_s
282
+
283
+ if type == wp.sim.JOINT_REVOLUTE:
284
+ axis = joint_axis[axis_start]
285
+ S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
286
+ v_j_s = S_s * joint_qd[qd_start]
287
+ joint_S_s[qd_start] = S_s
288
+ return v_j_s
289
+
290
+ if type == wp.sim.JOINT_UNIVERSAL:
291
+ axis_0 = joint_axis[axis_start + 0]
292
+ axis_1 = joint_axis[axis_start + 1]
293
+ q_off = wp.quat_from_matrix(wp.mat33(axis_0, axis_1, wp.cross(axis_0, axis_1)))
294
+ local_0 = wp.quat_rotate(q_off, wp.vec3(1.0, 0.0, 0.0))
295
+ local_1 = wp.quat_rotate(q_off, wp.vec3(0.0, 1.0, 0.0))
296
+
297
+ axis_0 = local_0
298
+ q_0 = wp.quat_from_axis_angle(axis_0, joint_q[q_start + 0])
299
+
300
+ axis_1 = wp.quat_rotate(q_0, local_1)
301
+
302
+ S_0 = transform_twist(X_sc, wp.spatial_vector(axis_0, wp.vec3()))
303
+ S_1 = transform_twist(X_sc, wp.spatial_vector(axis_1, wp.vec3()))
304
+
305
+ joint_S_s[qd_start + 0] = S_0
306
+ joint_S_s[qd_start + 1] = S_1
307
+
308
+ return S_0 * joint_qd[qd_start + 0] + S_1 * joint_qd[qd_start + 1]
309
+
310
+ if type == wp.sim.JOINT_COMPOUND:
311
+ axis_0 = joint_axis[axis_start + 0]
312
+ axis_1 = joint_axis[axis_start + 1]
313
+ axis_2 = joint_axis[axis_start + 2]
314
+ q_off = wp.quat_from_matrix(wp.mat33(axis_0, axis_1, axis_2))
315
+ local_0 = wp.quat_rotate(q_off, wp.vec3(1.0, 0.0, 0.0))
316
+ local_1 = wp.quat_rotate(q_off, wp.vec3(0.0, 1.0, 0.0))
317
+ local_2 = wp.quat_rotate(q_off, wp.vec3(0.0, 0.0, 1.0))
318
+
319
+ axis_0 = local_0
320
+ q_0 = wp.quat_from_axis_angle(axis_0, joint_q[q_start + 0])
321
+
322
+ axis_1 = wp.quat_rotate(q_0, local_1)
323
+ q_1 = wp.quat_from_axis_angle(axis_1, joint_q[q_start + 1])
324
+
325
+ axis_2 = wp.quat_rotate(q_1 * q_0, local_2)
326
+
327
+ S_0 = transform_twist(X_sc, wp.spatial_vector(axis_0, wp.vec3()))
328
+ S_1 = transform_twist(X_sc, wp.spatial_vector(axis_1, wp.vec3()))
329
+ S_2 = transform_twist(X_sc, wp.spatial_vector(axis_2, wp.vec3()))
330
+
331
+ joint_S_s[qd_start + 0] = S_0
332
+ joint_S_s[qd_start + 1] = S_1
333
+ joint_S_s[qd_start + 2] = S_2
334
+
335
+ return S_0 * joint_qd[qd_start + 0] + S_1 * joint_qd[qd_start + 1] + S_2 * joint_qd[qd_start + 2]
336
+
337
+ if type == wp.sim.JOINT_D6:
338
+ v_j_s = wp.spatial_vector()
339
+ if lin_axis_count > 0:
340
+ axis = joint_axis[axis_start + 0]
341
+ S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
342
+ v_j_s += S_s * joint_qd[qd_start + 0]
343
+ joint_S_s[qd_start + 0] = S_s
344
+ if lin_axis_count > 1:
345
+ axis = joint_axis[axis_start + 1]
346
+ S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
347
+ v_j_s += S_s * joint_qd[qd_start + 1]
348
+ joint_S_s[qd_start + 1] = S_s
349
+ if lin_axis_count > 2:
350
+ axis = joint_axis[axis_start + 2]
351
+ S_s = transform_twist(X_sc, wp.spatial_vector(wp.vec3(), axis))
352
+ v_j_s += S_s * joint_qd[qd_start + 2]
353
+ joint_S_s[qd_start + 2] = S_s
354
+ if ang_axis_count > 0:
355
+ axis = joint_axis[axis_start + lin_axis_count + 0]
356
+ S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
357
+ v_j_s += S_s * joint_qd[qd_start + lin_axis_count + 0]
358
+ joint_S_s[qd_start + lin_axis_count + 0] = S_s
359
+ if ang_axis_count > 1:
360
+ axis = joint_axis[axis_start + lin_axis_count + 1]
361
+ S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
362
+ v_j_s += S_s * joint_qd[qd_start + lin_axis_count + 1]
363
+ joint_S_s[qd_start + lin_axis_count + 1] = S_s
364
+ if ang_axis_count > 2:
365
+ axis = joint_axis[axis_start + lin_axis_count + 2]
366
+ S_s = transform_twist(X_sc, wp.spatial_vector(axis, wp.vec3()))
367
+ v_j_s += S_s * joint_qd[qd_start + lin_axis_count + 2]
368
+ joint_S_s[qd_start + lin_axis_count + 2] = S_s
369
+
370
+ return v_j_s
371
+
372
+ if type == wp.sim.JOINT_BALL:
373
+ S_0 = transform_twist(X_sc, wp.spatial_vector(1.0, 0.0, 0.0, 0.0, 0.0, 0.0))
374
+ S_1 = transform_twist(X_sc, wp.spatial_vector(0.0, 1.0, 0.0, 0.0, 0.0, 0.0))
375
+ S_2 = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 1.0, 0.0, 0.0, 0.0))
376
+
377
+ joint_S_s[qd_start + 0] = S_0
378
+ joint_S_s[qd_start + 1] = S_1
379
+ joint_S_s[qd_start + 2] = S_2
380
+
381
+ return S_0 * joint_qd[qd_start + 0] + S_1 * joint_qd[qd_start + 1] + S_2 * joint_qd[qd_start + 2]
382
+
383
+ if type == wp.sim.JOINT_FIXED:
384
+ return wp.spatial_vector()
385
+
386
+ if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
387
+ v_j_s = transform_twist(
388
+ X_sc,
389
+ wp.spatial_vector(
390
+ joint_qd[qd_start + 0],
391
+ joint_qd[qd_start + 1],
392
+ joint_qd[qd_start + 2],
393
+ joint_qd[qd_start + 3],
394
+ joint_qd[qd_start + 4],
395
+ joint_qd[qd_start + 5],
396
+ ),
397
+ )
398
+
399
+ joint_S_s[qd_start + 0] = transform_twist(X_sc, wp.spatial_vector(1.0, 0.0, 0.0, 0.0, 0.0, 0.0))
400
+ joint_S_s[qd_start + 1] = transform_twist(X_sc, wp.spatial_vector(0.0, 1.0, 0.0, 0.0, 0.0, 0.0))
401
+ joint_S_s[qd_start + 2] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 1.0, 0.0, 0.0, 0.0))
402
+ joint_S_s[qd_start + 3] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 1.0, 0.0, 0.0))
403
+ joint_S_s[qd_start + 4] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 1.0, 0.0))
404
+ joint_S_s[qd_start + 5] = transform_twist(X_sc, wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 0.0, 1.0))
405
+
406
+ return v_j_s
407
+
408
+ wp.printf("jcalc_motion not implemented for joint type %d\n", type)
409
+
410
+ # default case
411
+ return wp.spatial_vector()
412
+
413
+
414
+ # computes joint space forces/torques in tau
415
+ @wp.func
416
+ def jcalc_tau(
417
+ type: int,
418
+ joint_target_ke: wp.array(dtype=float),
419
+ joint_target_kd: wp.array(dtype=float),
420
+ joint_limit_ke: wp.array(dtype=float),
421
+ joint_limit_kd: wp.array(dtype=float),
422
+ joint_S_s: wp.array(dtype=wp.spatial_vector),
423
+ joint_q: wp.array(dtype=float),
424
+ joint_qd: wp.array(dtype=float),
425
+ joint_act: wp.array(dtype=float),
426
+ joint_axis_mode: wp.array(dtype=int),
427
+ joint_limit_lower: wp.array(dtype=float),
428
+ joint_limit_upper: wp.array(dtype=float),
429
+ coord_start: int,
430
+ dof_start: int,
431
+ axis_start: int,
432
+ lin_axis_count: int,
433
+ ang_axis_count: int,
434
+ body_f_s: wp.spatial_vector,
435
+ # outputs
436
+ tau: wp.array(dtype=float),
437
+ ):
438
+ if type == wp.sim.JOINT_PRISMATIC or type == wp.sim.JOINT_REVOLUTE:
439
+ S_s = joint_S_s[dof_start]
440
+
441
+ q = joint_q[coord_start]
442
+ qd = joint_qd[dof_start]
443
+ act = joint_act[axis_start]
444
+
445
+ lower = joint_limit_lower[axis_start]
446
+ upper = joint_limit_upper[axis_start]
447
+
448
+ limit_ke = joint_limit_ke[axis_start]
449
+ limit_kd = joint_limit_kd[axis_start]
450
+ target_ke = joint_target_ke[axis_start]
451
+ target_kd = joint_target_kd[axis_start]
452
+ mode = joint_axis_mode[axis_start]
453
+
454
+ # total torque / force on the joint
455
+ t = -wp.dot(S_s, body_f_s) + eval_joint_force(
456
+ q, qd, act, target_ke, target_kd, lower, upper, limit_ke, limit_kd, mode
457
+ )
458
+
459
+ tau[dof_start] = t
460
+
461
+ return
462
+
463
+ if type == wp.sim.JOINT_BALL:
464
+ # target_ke = joint_target_ke[axis_start]
465
+ # target_kd = joint_target_kd[axis_start]
466
+
467
+ for i in range(3):
468
+ S_s = joint_S_s[dof_start + i]
469
+
470
+ # w = joint_qd[dof_start + i]
471
+ # r = joint_q[coord_start + i]
472
+
473
+ tau[dof_start + i] = -wp.dot(S_s, body_f_s) # - w * target_kd - r * target_ke
474
+
475
+ return
476
+
477
+ if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
478
+ for i in range(6):
479
+ S_s = joint_S_s[dof_start + i]
480
+ tau[dof_start + i] = -wp.dot(S_s, body_f_s)
481
+
482
+ return
483
+
484
+ if type == wp.sim.JOINT_COMPOUND or type == wp.sim.JOINT_UNIVERSAL or type == wp.sim.JOINT_D6:
485
+ axis_count = lin_axis_count + ang_axis_count
486
+
487
+ for i in range(axis_count):
488
+ S_s = joint_S_s[dof_start + i]
489
+
490
+ q = joint_q[coord_start + i]
491
+ qd = joint_qd[dof_start + i]
492
+ act = joint_act[axis_start + i]
493
+
494
+ lower = joint_limit_lower[axis_start + i]
495
+ upper = joint_limit_upper[axis_start + i]
496
+ limit_ke = joint_limit_ke[axis_start + i]
497
+ limit_kd = joint_limit_kd[axis_start + i]
498
+ target_ke = joint_target_ke[axis_start + i]
499
+ target_kd = joint_target_kd[axis_start + i]
500
+ mode = joint_axis_mode[axis_start + i]
501
+
502
+ f = eval_joint_force(q, qd, act, target_ke, target_kd, lower, upper, limit_ke, limit_kd, mode)
503
+
504
+ # total torque / force on the joint
505
+ t = -wp.dot(S_s, body_f_s) + f
506
+
507
+ tau[dof_start + i] = t
508
+
509
+ return
510
+
511
+
512
@wp.func
def jcalc_integrate(
    type: int,
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_qdd: wp.array(dtype=float),
    coord_start: int,
    dof_start: int,
    lin_axis_count: int,
    ang_axis_count: int,
    dt: float,
    # outputs
    joint_q_new: wp.array(dtype=float),
    joint_qd_new: wp.array(dtype=float),
):
    # Advance the generalized coordinates / velocities of ONE joint by a semi-implicit
    # (symplectic) Euler step: the velocity is updated from the acceleration first, and the
    # *updated* velocity is then used to advance the coordinates.
    #
    # type:            joint type (one of the wp.sim.JOINT_* constants)
    # joint_q/qd/qdd:  current coordinates, velocities and accelerations (read only)
    # coord_start:     offset of this joint's coordinates in joint_q / joint_q_new
    # dof_start:       offset of this joint's dofs in joint_qd / joint_qdd / joint_qd_new
    # lin_axis_count, ang_axis_count: axis counts, only used for compound / universal / D6
    # dt:              time step
    # Results are written to joint_q_new / joint_qd_new. Fixed joints and joint types not
    # matched below write nothing.

    # fixed joints have no degrees of freedom
    if type == wp.sim.JOINT_FIXED:
        return

    # prismatic / revolute
    if type == wp.sim.JOINT_PRISMATIC or type == wp.sim.JOINT_REVOLUTE:
        qdd = joint_qdd[dof_start]
        qd = joint_qd[dof_start]
        q = joint_q[coord_start]

        # symplectic Euler: velocity first, then position from the new velocity
        qd_new = qd + qdd * dt
        q_new = q + qd_new * dt

        joint_qd_new[dof_start] = qd_new
        joint_q_new[coord_start] = q_new

        return

    # ball
    if type == wp.sim.JOINT_BALL:
        # 3 angular dofs, 4 quaternion coordinates
        m_j = wp.vec3(joint_qdd[dof_start + 0], joint_qdd[dof_start + 1], joint_qdd[dof_start + 2])
        w_j = wp.vec3(joint_qd[dof_start + 0], joint_qd[dof_start + 1], joint_qd[dof_start + 2])

        r_j = wp.quat(
            joint_q[coord_start + 0], joint_q[coord_start + 1], joint_q[coord_start + 2], joint_q[coord_start + 3]
        )

        # symplectic Euler
        w_j_new = w_j + m_j * dt

        # quaternion kinematics: dr/dt = 0.5 * (w, 0) * r, using the updated angular velocity
        drdt_j = wp.quat(w_j_new, 0.0) * r_j * 0.5

        # new orientation (normalized)
        r_j_new = wp.normalize(r_j + drdt_j * dt)

        # update joint coords
        joint_q_new[coord_start + 0] = r_j_new[0]
        joint_q_new[coord_start + 1] = r_j_new[1]
        joint_q_new[coord_start + 2] = r_j_new[2]
        joint_q_new[coord_start + 3] = r_j_new[3]

        # update joint vel
        joint_qd_new[dof_start + 0] = w_j_new[0]
        joint_qd_new[dof_start + 1] = w_j_new[1]
        joint_qd_new[dof_start + 2] = w_j_new[2]

        return

    # free joint
    if type == wp.sim.JOINT_FREE or type == wp.sim.JOINT_DISTANCE:
        # dofs: qd = (omega_x, omega_y, omega_z, vel_x, vel_y, vel_z)
        # coords: q = (trans_x, trans_y, trans_z, quat_x, quat_y, quat_z, quat_w)

        # angular and linear acceleration
        m_s = wp.vec3(joint_qdd[dof_start + 0], joint_qdd[dof_start + 1], joint_qdd[dof_start + 2])
        a_s = wp.vec3(joint_qdd[dof_start + 3], joint_qdd[dof_start + 4], joint_qdd[dof_start + 5])

        # angular and linear velocity
        w_s = wp.vec3(joint_qd[dof_start + 0], joint_qd[dof_start + 1], joint_qd[dof_start + 2])
        v_s = wp.vec3(joint_qd[dof_start + 3], joint_qd[dof_start + 4], joint_qd[dof_start + 5])

        # symplectic Euler
        w_s = w_s + m_s * dt
        v_s = v_s + a_s * dt

        # translation of origin
        p_s = wp.vec3(joint_q[coord_start + 0], joint_q[coord_start + 1], joint_q[coord_start + 2])

        # linear vel of origin (note q/qd switch order of linear angular elements)
        # the twist (w_s, v_s) is expressed in the space frame, so the velocity of the point
        # located at p_s is v_s + w_s x p_s (the original note calls this the center of mass velocity)
        dpdt_s = v_s + wp.cross(w_s, p_s)

        # quat and quat derivative
        r_s = wp.quat(
            joint_q[coord_start + 3], joint_q[coord_start + 4], joint_q[coord_start + 5], joint_q[coord_start + 6]
        )

        drdt_s = wp.quat(w_s, 0.0) * r_s * 0.5

        # new position and orientation (orientation is renormalized)
        p_s_new = p_s + dpdt_s * dt
        r_s_new = wp.normalize(r_s + drdt_s * dt)

        # update transform
        joint_q_new[coord_start + 0] = p_s_new[0]
        joint_q_new[coord_start + 1] = p_s_new[1]
        joint_q_new[coord_start + 2] = p_s_new[2]

        joint_q_new[coord_start + 3] = r_s_new[0]
        joint_q_new[coord_start + 4] = r_s_new[1]
        joint_q_new[coord_start + 5] = r_s_new[2]
        joint_q_new[coord_start + 6] = r_s_new[3]

        # update joint_twist
        joint_qd_new[dof_start + 0] = w_s[0]
        joint_qd_new[dof_start + 1] = w_s[1]
        joint_qd_new[dof_start + 2] = w_s[2]
        joint_qd_new[dof_start + 3] = v_s[0]
        joint_qd_new[dof_start + 4] = v_s[1]
        joint_qd_new[dof_start + 5] = v_s[2]

        return

    # other joint types (compound, universal, D6)
    if type == wp.sim.JOINT_COMPOUND or type == wp.sim.JOINT_UNIVERSAL or type == wp.sim.JOINT_D6:
        # one scalar coordinate per dof: integrate each axis independently
        axis_count = lin_axis_count + ang_axis_count

        for i in range(axis_count):
            qdd = joint_qdd[dof_start + i]
            qd = joint_qd[dof_start + i]
            q = joint_q[coord_start + i]

            qd_new = qd + qdd * dt
            q_new = q + qd_new * dt

            joint_qd_new[dof_start + i] = qd_new
            joint_q_new[coord_start + i] = q_new

        return
645
+
646
+
647
@wp.func
def compute_link_transform(
    i: int,
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    body_X_com: wp.array(dtype=wp.transform),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    # outputs
    body_q: wp.array(dtype=wp.transform),
    body_q_com: wp.array(dtype=wp.transform),
):
    # Forward kinematics for joint i: computes the world transform of the joint's child body
    # (body_q[child]) and of its center of mass (body_q_com[child]).
    # Reads body_q[parent], so the parent's transform must already be up to date when this is
    # called (presumably guaranteed by parents preceding children in joint order -- TODO confirm).

    # parent transform
    parent = joint_parent[i]
    child = joint_child[i]

    # parent transform in spatial coordinates
    X_pj = joint_X_p[i]
    X_cj = joint_X_c[i]
    # parent anchor frame in world space
    X_wpj = X_pj
    if parent >= 0:
        X_wp = body_q[parent]
        X_wpj = X_wp * X_wpj

    type = joint_type[i]
    axis_start = joint_axis_start[i]
    lin_axis_count = joint_axis_dim[i, 0]
    ang_axis_count = joint_axis_dim[i, 1]
    coord_start = joint_q_start[i]

    # compute transform across joint
    X_j = jcalc_transform(type, joint_axis, axis_start, lin_axis_count, ang_axis_count, joint_q, coord_start)

    # transform from world to joint anchor frame at child body
    X_wcj = X_wpj * X_j
    # transform from world to child body frame
    X_wc = X_wcj * wp.transform_inverse(X_cj)

    # compute transform of center of mass
    X_cm = body_X_com[child]
    X_sm = X_wc * X_cm

    # store geometry transforms
    body_q[child] = X_wc
    body_q_com[child] = X_sm
699
+
700
+
701
@wp.kernel
def eval_rigid_fk(
    articulation_start: wp.array(dtype=int),
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    body_X_com: wp.array(dtype=wp.transform),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    # outputs
    body_q: wp.array(dtype=wp.transform),
    body_q_com: wp.array(dtype=wp.transform),
):
    # Forward kinematics: writes body_q / body_q_com for every body of one articulation.
    # Joints are processed sequentially in index order inside a single thread, because each
    # link transform depends on its parent's (assumes parents come before children -- TODO confirm).

    # one thread per-articulation
    index = wp.tid()

    # joints of this articulation are [start, end)
    start = articulation_start[index]
    end = articulation_start[index + 1]

    for i in range(start, end):
        compute_link_transform(
            i,
            joint_type,
            joint_parent,
            joint_child,
            joint_q_start,
            joint_q,
            joint_X_p,
            joint_X_c,
            body_X_com,
            joint_axis,
            joint_axis_start,
            joint_axis_dim,
            body_q,
            body_q_com,
        )
742
+
743
+
744
@wp.func
def spatial_cross(a: wp.spatial_vector, b: wp.spatial_vector):
    # Spatial motion cross product a x b for spatial vectors stored as (top, bottom) = (w, v):
    #   top    = w_a x w_b
    #   bottom = w_a x v_b + v_a x w_b
    ang_a = wp.spatial_top(a)
    ang_b = wp.spatial_top(b)
    lin_a = wp.spatial_bottom(a)
    lin_b = wp.spatial_bottom(b)

    return wp.spatial_vector(
        wp.cross(ang_a, ang_b),
        wp.cross(ang_a, lin_b) + wp.cross(lin_a, ang_b),
    )
756
+
757
+
758
@wp.func
def spatial_cross_dual(a: wp.spatial_vector, b: wp.spatial_vector):
    # Spatial force (dual) cross product a x* b for spatial vectors stored as (top, bottom) = (w, v):
    #   top    = w_a x w_b + v_a x v_b
    #   bottom = w_a x v_b
    ang_a = wp.spatial_top(a)
    ang_b = wp.spatial_top(b)
    lin_a = wp.spatial_bottom(a)
    lin_b = wp.spatial_bottom(b)

    return wp.spatial_vector(
        wp.cross(ang_a, ang_b) + wp.cross(lin_a, lin_b),
        wp.cross(ang_a, lin_b),
    )
770
+
771
+
772
@wp.func
def dense_index(stride: int, i: int, j: int):
    # Flat offset of entry (i, j) in a row-major dense matrix whose rows hold `stride` elements.
    row_offset = i * stride
    return row_offset + j
775
+
776
+
777
@wp.func
def compute_link_velocity(
    i: int,
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    body_I_m: wp.array(dtype=wp.spatial_matrix),
    body_q: wp.array(dtype=wp.transform),
    body_q_com: wp.array(dtype=wp.transform),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    gravity: wp.vec3,
    # outputs
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    body_I_s: wp.array(dtype=wp.spatial_matrix),
    body_v_s: wp.array(dtype=wp.spatial_vector),
    body_f_s: wp.array(dtype=wp.spatial_vector),
    body_a_s: wp.array(dtype=wp.spatial_vector),
):
    # Outward (root-to-leaf) step of the recursive Newton-Euler pass for joint i.
    # Computes, in spatial (world) coordinates, the joint motion subspace (joint_S_s), and the
    # child body's velocity, velocity-product acceleration, spatial inertia and net
    # bias force (inertial + Coriolis minus gravity). Reads body_v_s / body_a_s of the parent,
    # so the parent must have been processed first.

    type = joint_type[i]
    child = joint_child[i]
    parent = joint_parent[i]
    q_start = joint_q_start[i]
    qd_start = joint_qd_start[i]

    X_pj = joint_X_p[i]
    # X_cj = joint_X_c[i]

    # parent anchor frame in world space
    X_wpj = X_pj
    if parent >= 0:
        X_wp = body_q[parent]
        X_wpj = X_wp * X_wpj

    # compute motion subspace and velocity across the joint (also stores S_s to global memory)
    axis_start = joint_axis_start[i]
    lin_axis_count = joint_axis_dim[i, 0]
    ang_axis_count = joint_axis_dim[i, 1]
    v_j_s = jcalc_motion(
        type,
        joint_axis,
        axis_start,
        lin_axis_count,
        ang_axis_count,
        X_wpj,
        joint_q,
        joint_qd,
        q_start,
        qd_start,
        joint_S_s,
    )

    # parent velocity (zero for joints attached to the world, i.e. parent < 0)
    v_parent_s = wp.spatial_vector()
    a_parent_s = wp.spatial_vector()

    if parent >= 0:
        v_parent_s = body_v_s[parent]
        a_parent_s = body_a_s[parent]

    # body velocity, acceleration
    # the acceleration only contains the velocity-product term; the S*qdd term is omitted here
    v_s = v_parent_s + v_j_s
    a_s = a_parent_s + spatial_cross(v_s, v_j_s)  # + joint_S_s[i]*self.joint_qdd[i]

    # compute body forces
    X_sm = body_q_com[child]
    I_m = body_I_m[child]

    # gravity force acts at the center of mass r_com; f_g_s expresses it as a spatial wrench
    # about the world origin (moment r_com x f_g, force f_g)
    # NOTE(review): reads the mass from entry (3, 3) of the spatial inertia -- assumes the
    # lower-right 3x3 block is m * identity, as built by compute_spatial_inertia (confirm)
    m = I_m[3, 3]

    f_g = m * gravity
    r_com = wp.transform_get_translation(X_sm)
    f_g_s = wp.spatial_vector(wp.cross(r_com, f_g), f_g)

    # body forces: rotate/translate the body-frame inertia into spatial coordinates
    I_s = spatial_transform_inertia(X_sm, I_m)

    # Newton-Euler: f = I a + v x* (I v)
    f_b_s = I_s * a_s + spatial_cross_dual(v_s, I_s * v_s)

    body_v_s[child] = v_s
    body_a_s[child] = a_s
    body_f_s[child] = f_b_s - f_g_s
    body_I_s[child] = I_s
868
+
869
+
870
# Inverse dynamics via Recursive Newton-Euler algorithm (Featherstone Table 5.1)
# This kernel performs the outward pass (velocities, velocity-product accelerations and bias
# forces per body); joint torques are computed afterwards by the inward pass in eval_rigid_tau.
@wp.kernel
def eval_rigid_id(
    articulation_start: wp.array(dtype=int),
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_axis: wp.array(dtype=wp.vec3),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    body_I_m: wp.array(dtype=wp.spatial_matrix),
    body_q: wp.array(dtype=wp.transform),
    body_q_com: wp.array(dtype=wp.transform),
    joint_X_p: wp.array(dtype=wp.transform),
    joint_X_c: wp.array(dtype=wp.transform),
    gravity: wp.vec3,
    # outputs
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    body_I_s: wp.array(dtype=wp.spatial_matrix),
    body_v_s: wp.array(dtype=wp.spatial_vector),
    body_f_s: wp.array(dtype=wp.spatial_vector),
    body_a_s: wp.array(dtype=wp.spatial_vector),
):
    # one thread per-articulation
    index = wp.tid()

    # joints of this articulation are [start, end), processed in increasing index order
    start = articulation_start[index]
    end = articulation_start[index + 1]

    # compute link velocities and coriolis forces
    for i in range(start, end):
        compute_link_velocity(
            i,
            joint_type,
            joint_parent,
            joint_child,
            joint_q_start,
            joint_qd_start,
            joint_q,
            joint_qd,
            joint_axis,
            joint_axis_start,
            joint_axis_dim,
            body_I_m,
            body_q,
            body_q_com,
            joint_X_p,
            joint_X_c,
            gravity,
            joint_S_s,
            body_I_s,
            body_v_s,
            body_f_s,
            body_a_s,
        )
929
+
930
+
931
@wp.kernel
def eval_rigid_tau(
    articulation_start: wp.array(dtype=int),
    joint_type: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_child: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_axis_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    joint_axis_mode: wp.array(dtype=int),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_act: wp.array(dtype=float),
    joint_target_ke: wp.array(dtype=float),
    joint_target_kd: wp.array(dtype=float),
    joint_limit_lower: wp.array(dtype=float),
    joint_limit_upper: wp.array(dtype=float),
    joint_limit_ke: wp.array(dtype=float),
    joint_limit_kd: wp.array(dtype=float),
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    body_fb_s: wp.array(dtype=wp.spatial_vector),
    body_f_ext: wp.array(dtype=wp.spatial_vector),
    # outputs
    body_ft_s: wp.array(dtype=wp.spatial_vector),
    tau: wp.array(dtype=float),
):
    # Inward (leaf-to-root) pass of the recursive Newton-Euler algorithm: for every joint of
    # one articulation, sums the body's bias force, the forces already propagated from its
    # children (body_ft_s) and the external force, projects the total onto the joint's motion
    # subspace (plus joint drive/limit forces) into tau, and propagates it to the parent.
    # Joints are visited in decreasing index order -- assumes children have larger indices
    # than their parents (TODO confirm).

    # one thread per-articulation
    index = wp.tid()

    start = articulation_start[index]
    end = articulation_start[index + 1]
    count = end - start

    # compute joint forces
    for offset in range(count):
        # for backwards traversal
        i = end - offset - 1

        type = joint_type[i]
        parent = joint_parent[i]
        child = joint_child[i]
        dof_start = joint_qd_start[i]
        coord_start = joint_q_start[i]
        axis_start = joint_axis_start[i]
        lin_axis_count = joint_axis_dim[i, 0]
        ang_axis_count = joint_axis_dim[i, 1]

        # total forces on body
        f_b_s = body_fb_s[child]
        f_t_s = body_ft_s[child]
        f_ext = body_f_ext[child]
        f_s = f_b_s + f_t_s + f_ext

        # compute joint-space forces, writes out tau
        jcalc_tau(
            type,
            joint_target_ke,
            joint_target_kd,
            joint_limit_ke,
            joint_limit_kd,
            joint_S_s,
            joint_q,
            joint_qd,
            joint_act,
            joint_axis_mode,
            joint_limit_lower,
            joint_limit_upper,
            coord_start,
            dof_start,
            axis_start,
            lin_axis_count,
            ang_axis_count,
            f_s,
            tau,
        )

        # update parent forces, todo: check that this is valid for the backwards pass
        if parent >= 0:
            wp.atomic_add(body_ft_s, parent, f_s)
1011
+
1012
+
1013
# builds spatial Jacobian J which is an (joint_count*6)x(dof_count) matrix
# (one matrix per articulation, stored row-major at offset articulation_J_start[articulation])
@wp.kernel
def eval_rigid_jacobian(
    articulation_start: wp.array(dtype=int),
    articulation_J_start: wp.array(dtype=int),
    joint_parent: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_S_s: wp.array(dtype=wp.spatial_vector),
    # outputs
    J: wp.array(dtype=float),
):
    # Row block i (6 rows) holds the spatial velocity of joint i's child body as a function of
    # all articulation dofs: it contains the motion subspace columns of joint i and of every
    # ancestor joint. Entries for non-ancestor joints are never written, so J must be
    # zero-initialized beforehand.

    # one thread per-articulation
    index = wp.tid()

    joint_start = articulation_start[index]
    joint_end = articulation_start[index + 1]
    joint_count = joint_end - joint_start

    J_offset = articulation_J_start[index]

    articulation_dof_start = joint_qd_start[joint_start]
    articulation_dof_end = joint_qd_start[joint_end]
    articulation_dof_count = articulation_dof_end - articulation_dof_start

    for i in range(joint_count):
        row_start = i * 6

        # walk from joint i up to the root (joint_parent == -1)
        j = joint_start + i
        while j != -1:
            joint_dof_start = joint_qd_start[j]
            joint_dof_end = joint_qd_start[j + 1]
            joint_dof_count = joint_dof_end - joint_dof_start

            # fill out each row of the Jacobian walking up the tree
            for dof in range(joint_dof_count):
                # column index is relative to the first dof of this articulation
                col = (joint_dof_start - articulation_dof_start) + dof
                S = joint_S_s[joint_dof_start + dof]

                for k in range(6):
                    J[J_offset + dense_index(articulation_dof_count, row_start + k, col)] = S[k]

            j = joint_parent[j]
1055
+
1056
+
1057
@wp.func
def spatial_mass(
    body_I_s: wp.array(dtype=wp.spatial_matrix),
    joint_start: int,
    joint_count: int,
    M_start: int,
    # outputs
    M: wp.array(dtype=float),
):
    # Write the block-diagonal spatial mass matrix of one articulation into M (row-major,
    # (6*joint_count) x (6*joint_count), starting at M_start): diagonal block l is the 6x6
    # spatial inertia of link l. Off-diagonal blocks are not written, so M must be zeroed.
    # NOTE(review): body_I_s is indexed with a *joint* index (joint_start + l), which assumes
    # body index == joint index (one child body per joint, same ordering) -- TODO confirm.
    stride = joint_count * 6
    for l in range(joint_count):
        I = body_I_s[joint_start + l]
        for i in range(6):
            for j in range(6):
                M[M_start + dense_index(stride, l * 6 + i, l * 6 + j)] = I[i, j]
1072
+
1073
+
1074
@wp.kernel
def eval_rigid_mass(
    articulation_start: wp.array(dtype=int),
    articulation_M_start: wp.array(dtype=int),
    body_I_s: wp.array(dtype=wp.spatial_matrix),
    # outputs
    M: wp.array(dtype=float),
):
    # Assemble the block-diagonal spatial mass matrix of each articulation (see spatial_mass).

    # one thread per-articulation
    index = wp.tid()

    joint_start = articulation_start[index]
    joint_end = articulation_start[index + 1]
    joint_count = joint_end - joint_start

    # offset of this articulation's matrix inside the flat M buffer
    M_offset = articulation_M_start[index]

    spatial_mass(body_I_s, joint_start, joint_count, M_offset, M)
1092
+
1093
+
1094
@wp.func
def dense_gemm(
    m: int,
    n: int,
    p: int,
    transpose_A: bool,
    transpose_B: bool,
    add_to_C: bool,
    A_start: int,
    B_start: int,
    C_start: int,
    A: wp.array(dtype=float),
    B: wp.array(dtype=float),
    # outputs
    C: wp.array(dtype=float),
):
    # Naive triple-loop GEMM on row-major matrices stored in flat arrays at the given offsets.
    # transpose_A: A is stored as a (p x m) matrix and A^T is used.
    # transpose_B: B is stored as a (n x p) matrix and B^T is used.
    # add_to_C:    accumulate into C instead of overwriting it.
    # multiply a `m x p` matrix A by a `p x n` matrix B to produce a `m x n` matrix C
    for i in range(m):
        for j in range(n):
            sum = float(0.0)
            for k in range(p):
                if transpose_A:
                    # element (k, i) of the stored p x m matrix
                    a_i = k * m + i
                else:
                    a_i = i * p + k
                if transpose_B:
                    # element (j, k) of the stored n x p matrix
                    b_j = j * p + k
                else:
                    b_j = k * n + j
                sum += A[A_start + a_i] * B[B_start + b_j]

            if add_to_C:
                C[C_start + i * n + j] += sum
            else:
                C[C_start + i * n + j] = sum
1129
+
1130
+
1131
+ # @wp.func_grad(dense_gemm)
1132
+ # def adj_dense_gemm(
1133
+ # m: int,
1134
+ # n: int,
1135
+ # p: int,
1136
+ # transpose_A: bool,
1137
+ # transpose_B: bool,
1138
+ # add_to_C: bool,
1139
+ # A_start: int,
1140
+ # B_start: int,
1141
+ # C_start: int,
1142
+ # A: wp.array(dtype=float),
1143
+ # B: wp.array(dtype=float),
1144
+ # # outputs
1145
+ # C: wp.array(dtype=float),
1146
+ # ):
1147
+ # add_to_C = True
1148
+ # if transpose_A:
1149
+ # dense_gemm(p, m, n, False, True, add_to_C, A_start, B_start, C_start, B, wp.adjoint[C], wp.adjoint[A])
1150
+ # dense_gemm(p, n, m, False, False, add_to_C, A_start, B_start, C_start, A, wp.adjoint[C], wp.adjoint[B])
1151
+ # else:
1152
+ # dense_gemm(
1153
+ # m, p, n, False, not transpose_B, add_to_C, A_start, B_start, C_start, wp.adjoint[C], B, wp.adjoint[A]
1154
+ # )
1155
+ # dense_gemm(p, n, m, True, False, add_to_C, A_start, B_start, C_start, A, wp.adjoint[C], wp.adjoint[B])
1156
+
1157
+
1158
@wp.kernel
def eval_dense_gemm_batched(
    m: wp.array(dtype=int),
    n: wp.array(dtype=int),
    p: wp.array(dtype=int),
    transpose_A: bool,
    transpose_B: bool,
    A_start: wp.array(dtype=int),
    B_start: wp.array(dtype=int),
    C_start: wp.array(dtype=int),
    A: wp.array(dtype=float),
    B: wp.array(dtype=float),
    C: wp.array(dtype=float),
):
    # Batched C = op(A) * op(B): one thread per batch entry, and that thread computes the
    # entire (m[batch] x n[batch]) product with dense_gemm (on both CPU and GPU).
    # Per-batch sizes and offsets into the flat A/B/C buffers come from the index arrays.
    batch = wp.tid()
    # C is overwritten, not accumulated
    add_to_C = False

    dense_gemm(
        m[batch],
        n[batch],
        p[batch],
        transpose_A,
        transpose_B,
        add_to_C,
        A_start[batch],
        B_start[batch],
        C_start[batch],
        A,
        B,
        C,
    )
1191
+
1192
+
1193
@wp.func
def dense_cholesky(
    n: int,
    A: wp.array(dtype=float),
    R: wp.array(dtype=float),
    A_start: int,
    R_start: int,
    # outputs
    L: wp.array(dtype=float),
):
    # Cholesky-Crout factorization of the n x n row-major matrix (A + diag(R)) = L L^T.
    # L uses the same layout and offset (A_start) as A. Only the diagonal and the strictly
    # lower triangle of L are written; the upper triangle is left untouched.
    # If the regularized matrix is not positive definite, wp.sqrt receives a negative value
    # and the factor will contain NaNs.
    # compute the Cholesky factorization of A = L L^T with diagonal regularization R
    for j in range(n):
        s = A[A_start + dense_index(n, j, j)] + R[R_start + j]

        for k in range(j):
            r = L[A_start + dense_index(n, j, k)]
            s -= r * r

        s = wp.sqrt(s)
        invS = 1.0 / s

        L[A_start + dense_index(n, j, j)] = s

        for i in range(j + 1, n):
            s = A[A_start + dense_index(n, i, j)]

            for k in range(j):
                s -= L[A_start + dense_index(n, i, k)] * L[A_start + dense_index(n, j, k)]

            L[A_start + dense_index(n, i, j)] = s * invS
1223
+
1224
+
1225
@wp.func_grad(dense_cholesky)
def adj_dense_cholesky(
    n: int,
    A: wp.array(dtype=float),
    R: wp.array(dtype=float),
    A_start: int,
    R_start: int,
    # outputs
    L: wp.array(dtype=float),
):
    # Custom (empty) adjoint of dense_cholesky: no gradient is propagated through the
    # factorization itself.
    # nop, use dense_solve to differentiate through (A^-1)b = x
    pass
1237
+
1238
+
1239
@wp.kernel
def eval_dense_cholesky_batched(
    A_starts: wp.array(dtype=int),
    A_dim: wp.array(dtype=int),
    A: wp.array(dtype=float),
    R: wp.array(dtype=float),
    L: wp.array(dtype=float),
):
    # Batched Cholesky factorization (see dense_cholesky): one thread per matrix.
    batch = wp.tid()

    n = A_dim[batch]
    A_start = A_starts[batch]
    # NOTE(review): the regularization offset assumes every batch entry has the same size n.
    # If R holds one value per dof of the whole model (e.g. joint armature) and the batched
    # matrices have different sizes, this reads the wrong R entries -- the dof start of the
    # batch entry would be the correct offset. TODO confirm against the caller.
    R_start = n * batch

    dense_cholesky(n, A, R, A_start, R_start, L)
1254
+
1255
+
1256
@wp.func
def dense_subs(
    n: int,
    L_start: int,
    b_start: int,
    L: wp.array(dtype=float),
    b: wp.array(dtype=float),
    # outputs
    x: wp.array(dtype=float),
):
    # x is used as scratch for the intermediate y of the forward pass and then overwritten
    # in place by the backward pass; only the lower triangle (and diagonal) of L is read.
    # Solves (L L^T) x = b for x given the Cholesky factor L
    # forward substitution solves the lower triangular system L y = b for y
    for i in range(n):
        s = b[b_start + i]

        for j in range(i):
            s -= L[L_start + dense_index(n, i, j)] * x[b_start + j]

        x[b_start + i] = s / L[L_start + dense_index(n, i, i)]

    # backward substitution solves the upper triangular system L^T x = y for x
    # (L^T[i, j] is read as L[j, i])
    for i in range(n - 1, -1, -1):
        s = x[b_start + i]

        for j in range(i + 1, n):
            s -= L[L_start + dense_index(n, j, i)] * x[b_start + j]

        x[b_start + i] = s / L[L_start + dense_index(n, i, i)]
1284
+
1285
+
1286
@wp.func
def dense_solve(
    n: int,
    L_start: int,
    b_start: int,
    A: wp.array(dtype=float),
    L: wp.array(dtype=float),
    b: wp.array(dtype=float),
    # outputs
    x: wp.array(dtype=float),
    tmp: wp.array(dtype=float),
):
    # Solve (L L^T) x = b. A and tmp are unused in the forward pass; they exist so that the
    # custom adjoint (adj_dense_solve) can accumulate into adjoint[A] and use tmp as scratch.
    # helper function to include tmp argument for backward pass
    dense_subs(n, L_start, b_start, L, b, x)
1300
+
1301
+
1302
@wp.func_grad(dense_solve)
def adj_dense_solve(
    n: int,
    L_start: int,
    b_start: int,
    A: wp.array(dtype=float),
    L: wp.array(dtype=float),
    b: wp.array(dtype=float),
    # outputs
    x: wp.array(dtype=float),
    tmp: wp.array(dtype=float),
):
    # Custom adjoint of x = (L L^T)^-1 b:
    #   adj_b += (L L^T)^-1 adj_x   (computed into tmp by reusing dense_subs)
    #   adj_A += -adj_b x^T
    # Skipped entirely when tmp or the needed adjoint arrays are missing (tmp is None when the
    # state was allocated without requires_grad).
    # NOTE(review): wp.adjoint[b] is written below but not checked here; and adj_L receives the
    # same -adj_b x^T update as adj_A, which is not the exact gradient w.r.t. the factor L --
    # presumably intentional, confirm.
    if not tmp or not wp.adjoint[x] or not wp.adjoint[A] or not wp.adjoint[L]:
        return
    for i in range(n):
        tmp[b_start + i] = 0.0

    dense_subs(n, L_start, b_start, L, wp.adjoint[x], tmp)

    for i in range(n):
        wp.adjoint[b][b_start + i] += tmp[b_start + i]

    # A* = -adj_b*x^T
    for i in range(n):
        for j in range(n):
            wp.adjoint[L][L_start + dense_index(n, i, j)] += -tmp[b_start + i] * x[b_start + j]

    # A shares the matrix offset (L_start) with L
    for i in range(n):
        for j in range(n):
            wp.adjoint[A][L_start + dense_index(n, i, j)] += -tmp[b_start + i] * x[b_start + j]
1332
+
1333
+
1334
@wp.kernel
def eval_dense_solve_batched(
    L_start: wp.array(dtype=int),
    L_dim: wp.array(dtype=int),
    b_start: wp.array(dtype=int),
    A: wp.array(dtype=float),
    L: wp.array(dtype=float),
    b: wp.array(dtype=float),
    # outputs
    x: wp.array(dtype=float),
    tmp: wp.array(dtype=float),
):
    # Batched solve of (L L^T) x = b, one thread per system; per-system size and offsets into
    # the flat L / b / x buffers come from the index arrays. tmp is only used by the adjoint.
    batch = wp.tid()

    dense_solve(L_dim[batch], L_start[batch], b_start[batch], A, L, b, x, tmp)
1349
+
1350
+
1351
@wp.kernel
def integrate_generalized_joints(
    joint_type: wp.array(dtype=int),
    joint_q_start: wp.array(dtype=int),
    joint_qd_start: wp.array(dtype=int),
    joint_axis_dim: wp.array(dtype=int, ndim=2),
    joint_q: wp.array(dtype=float),
    joint_qd: wp.array(dtype=float),
    joint_qdd: wp.array(dtype=float),
    dt: float,
    # outputs
    joint_q_new: wp.array(dtype=float),
    joint_qd_new: wp.array(dtype=float),
):
    # Integrate every joint's generalized coordinates and velocities by one semi-implicit
    # Euler step (see jcalc_integrate).

    # one thread per joint: the thread index is used directly as the joint index
    index = wp.tid()

    type = joint_type[index]
    coord_start = joint_q_start[index]
    dof_start = joint_qd_start[index]
    lin_axis_count = joint_axis_dim[index, 0]
    ang_axis_count = joint_axis_dim[index, 1]

    jcalc_integrate(
        type,
        joint_q,
        joint_qd,
        joint_qdd,
        coord_start,
        dof_start,
        lin_axis_count,
        ang_axis_count,
        dt,
        joint_q_new,
        joint_qd_new,
    )
1387
+
1388
+
1389
+ class FeatherstoneIntegrator(Integrator):
1390
+ """A semi-implicit integrator using symplectic Euler that operates
1391
+ on reduced (also called generalized) coordinates to simulate articulated rigid body dynamics
1392
+ based on Featherstone's composite rigid body algorithm (CRBA).
1393
+
1394
+ See: Featherstone, Roy. Rigid Body Dynamics Algorithms. Springer US, 2014.
1395
+
1396
+ Instead of maximal coordinates :attr:`State.body_q` (rigid body positions) and :attr:`State.body_qd`
1397
+ (rigid body velocities), as is the case with :class:`SemiImplicitIntegrator`, :class:`FeatherstoneIntegrator`
1398
+ uses :attr:`State.joint_q` and :attr:`State.joint_qd` to represent the positions and velocities of
1399
+ joints without allowing any redundant degrees of freedom.
1400
+
1401
+ After constructing :class:`Model` and :class:`State` objects this time-integrator
1402
+ may be used to advance the simulation state forward in time.
1403
+
1404
+ Note:
1405
+ Unlike :class:`SemiImplicitIntegrator` and :class:`XPBDIntegrator`, :class:`FeatherstoneIntegrator` does not simulate rigid bodies with nonzero mass as floating bodies if they are not connected through any joints. Floating-base systems require an explicit free joint with which the body is connected to the world, see :meth:`ModelBuilder.add_joint_free`.
1406
+
1407
+ Semi-implicit time integration is a variational integrator that
1408
+ preserves energy well; however, it is not unconditionally stable and requires a time-step
1409
+ small enough to support the required stiffness and damping forces.
1410
+
1411
+ See: https://en.wikipedia.org/wiki/Semi-implicit_Euler_method
1412
+
1413
+ Example
1414
+ -------
1415
+
1416
+ .. code-block:: python
1417
+
1418
+ integrator = wp.sim.FeatherstoneIntegrator(model)
1419
+
1420
+ # simulation loop
1421
+ for i in range(100):
1422
+ state = integrator.simulate(model, state_in, state_out, dt)
1423
+
1424
+ Note:
1425
+ The :class:`FeatherstoneIntegrator` requires the :class:`Model` to be passed in as a constructor argument.
1426
+
1427
+ """
1428
+
1429
+ def __init__(self, model, angular_damping=0.05, update_mass_matrix_every=1):
1430
+ """
1431
+ Args:
1432
+ model (Model): the model to be simulated.
1433
+ angular_damping (float, optional): Angular damping factor. Defaults to 0.05.
1434
+ update_mass_matrix_every (int, optional): How often to update the mass matrix (every n-th time the :meth:`simulate` function gets called). Defaults to 1.
1435
+ """
1436
+ self.angular_damping = angular_damping
1437
+ self.update_mass_matrix_every = update_mass_matrix_every
1438
+ self.compute_articulation_indices(model)
1439
+ self.allocate_model_aux_vars(model)
1440
+ self._step = 0
1441
+
1442
+ def compute_articulation_indices(self, model):
1443
+ # calculate total size and offsets of Jacobian and mass matrices for entire system
1444
+ if model.joint_count:
1445
+ self.J_size = 0
1446
+ self.M_size = 0
1447
+ self.H_size = 0
1448
+
1449
+ articulation_J_start = []
1450
+ articulation_M_start = []
1451
+ articulation_H_start = []
1452
+
1453
+ articulation_M_rows = []
1454
+ articulation_H_rows = []
1455
+ articulation_J_rows = []
1456
+ articulation_J_cols = []
1457
+
1458
+ articulation_dof_start = []
1459
+ articulation_coord_start = []
1460
+
1461
+ articulation_start = model.articulation_start.numpy()
1462
+ joint_q_start = model.joint_q_start.numpy()
1463
+ joint_qd_start = model.joint_qd_start.numpy()
1464
+
1465
+ for i in range(model.articulation_count):
1466
+ first_joint = articulation_start[i]
1467
+ last_joint = articulation_start[i + 1]
1468
+
1469
+ first_coord = joint_q_start[first_joint]
1470
+
1471
+ first_dof = joint_qd_start[first_joint]
1472
+ last_dof = joint_qd_start[last_joint]
1473
+
1474
+ joint_count = last_joint - first_joint
1475
+ dof_count = last_dof - first_dof
1476
+
1477
+ articulation_J_start.append(self.J_size)
1478
+ articulation_M_start.append(self.M_size)
1479
+ articulation_H_start.append(self.H_size)
1480
+ articulation_dof_start.append(first_dof)
1481
+ articulation_coord_start.append(first_coord)
1482
+
1483
+ # bit of data duplication here, but will leave it as such for clarity
1484
+ articulation_M_rows.append(joint_count * 6)
1485
+ articulation_H_rows.append(dof_count)
1486
+ articulation_J_rows.append(joint_count * 6)
1487
+ articulation_J_cols.append(dof_count)
1488
+
1489
+ self.J_size += 6 * joint_count * dof_count
1490
+ self.M_size += 6 * joint_count * 6 * joint_count
1491
+ self.H_size += dof_count * dof_count
1492
+
1493
+ # matrix offsets for batched gemm
1494
+ self.articulation_J_start = wp.array(articulation_J_start, dtype=wp.int32, device=model.device)
1495
+ self.articulation_M_start = wp.array(articulation_M_start, dtype=wp.int32, device=model.device)
1496
+ self.articulation_H_start = wp.array(articulation_H_start, dtype=wp.int32, device=model.device)
1497
+
1498
+ self.articulation_M_rows = wp.array(articulation_M_rows, dtype=wp.int32, device=model.device)
1499
+ self.articulation_H_rows = wp.array(articulation_H_rows, dtype=wp.int32, device=model.device)
1500
+ self.articulation_J_rows = wp.array(articulation_J_rows, dtype=wp.int32, device=model.device)
1501
+ self.articulation_J_cols = wp.array(articulation_J_cols, dtype=wp.int32, device=model.device)
1502
+
1503
+ self.articulation_dof_start = wp.array(articulation_dof_start, dtype=wp.int32, device=model.device)
1504
+ self.articulation_coord_start = wp.array(articulation_coord_start, dtype=wp.int32, device=model.device)
1505
+
1506
def allocate_model_aux_vars(self, model):
    """Allocate buffers that depend only on the model (not on the simulation state).

    When the model has joints, this allocates the flat, batched system-matrix storage.
    ``self.M``, ``self.J``, ``self.P``, ``self.H`` and ``self.L`` are 1-D float32 arrays
    sized from ``self.M_size``, ``self.J_size`` and ``self.H_size``, so those sizes must
    already be computed.

    When the model has bodies, this also computes two per-body arrays once:

    - ``self.body_I_m``: spatial inertia matrices, built from ``model.body_inertia`` and
      ``model.body_mass`` by the ``compute_spatial_inertia`` kernel.
    - ``self.body_X_com``: centre-of-mass transforms, built from ``model.body_com`` by
      the ``compute_com_transforms`` kernel.

    Args:
        model: The simulation model; its ``device`` and ``requires_grad`` settings are
            applied to every array allocated here.
    """
    device = model.device
    grad = model.requires_grad

    if model.joint_count:
        # dense system matrices, all articulations packed back to back
        self.M = wp.zeros((self.M_size,), dtype=wp.float32, device=device, requires_grad=grad)
        self.J = wp.zeros((self.J_size,), dtype=wp.float32, device=device, requires_grad=grad)
        # P has the same layout as J
        self.P = wp.empty_like(self.J, requires_grad=grad)
        self.H = wp.empty((self.H_size,), dtype=wp.float32, device=device, requires_grad=grad)

        # Zero-initialised on purpose: the factorisation only writes one triangle, and
        # leftover uninitialised values in the other half can trigger NaN detection.
        self.L = wp.zeros_like(self.H)

    if not model.body_count:
        return

    per_body = (model.body_count,)

    self.body_I_m = wp.empty(per_body, dtype=wp.spatial_matrix, device=device, requires_grad=grad)
    wp.launch(
        compute_spatial_inertia,
        dim=model.body_count,
        inputs=[model.body_inertia, model.body_mass],
        outputs=[self.body_I_m],
        device=device,
    )

    self.body_X_com = wp.empty(per_body, dtype=wp.transform, device=device, requires_grad=grad)
    wp.launch(
        compute_com_transforms,
        dim=model.body_count,
        inputs=[model.body_com],
        outputs=[self.body_X_com],
        device=device,
    )
+
1540
def allocate_state_aux_vars(self, model, target, requires_grad):
    """Attach the state-dependent scratch arrays used by ``simulate`` to ``target``.

    ``target`` is whatever object ``simulate`` stores its intermediates on: either the
    integrator itself or an output ``State``. If the model has no bodies, nothing is
    allocated and the ``_featherstone_augmented`` marker is left unset.

    Args:
        model: The simulation model; it provides the sizes and the device.
        target: The object that receives the new attributes.
        requires_grad: Gradient flag for the allocated arrays. When true, an extra
            ``joint_solve_tmp`` buffer is allocated; otherwise that attribute is
            set to ``None``.
    """
    if not model.body_count:
        return

    device = model.device
    per_body = (model.body_count,)

    # joint-space quantities, shaped like model.joint_qd
    target.joint_qdd = wp.zeros_like(model.joint_qd, requires_grad=requires_grad)
    target.joint_tau = wp.empty_like(model.joint_qd, requires_grad=requires_grad)
    # scratch for the custom gradient of eval_dense_solve_batched; only needed when differentiating
    target.joint_solve_tmp = wp.zeros_like(model.joint_qd, requires_grad=True) if requires_grad else None
    target.joint_S_s = wp.empty(
        (model.joint_dof_count,), dtype=wp.spatial_vector, device=device, requires_grad=requires_grad
    )

    # derived rigid-body data in maximal coordinates
    target.body_q_com = wp.empty_like(model.body_q, requires_grad=requires_grad)
    target.body_I_s = wp.empty(per_body, dtype=wp.spatial_matrix, device=device, requires_grad=requires_grad)
    target.body_v_s = wp.empty(per_body, dtype=wp.spatial_vector, device=device, requires_grad=requires_grad)
    target.body_a_s = wp.empty(per_body, dtype=wp.spatial_vector, device=device, requires_grad=requires_grad)
    target.body_f_s = wp.zeros(per_body, dtype=wp.spatial_vector, device=device, requires_grad=requires_grad)
    target.body_ft_s = wp.zeros(per_body, dtype=wp.spatial_vector, device=device, requires_grad=requires_grad)

    # marker checked by simulate() so the buffers are only allocated once per target
    target._featherstone_augmented = True
+
1578
def simulate(self, model: Model, state_in: State, state_out: State, dt: float, control: Control = None):
    """Advance ``state_in`` by ``dt`` and write the result into ``state_out``.

    Pipeline:

    1. Accumulate particle forces (springs, triangles, bending, FEM, particle contacts,
       particle/body contacts) with the generic force helpers.
    2. If the model has joints, compute in reduced (joint) coordinates:
       - forward kinematics;
       - inverse dynamics;
       - rigid contact forces;
       - joint torques.
    3. Every ``self.update_mass_matrix_every`` steps, rebuild the Jacobian ``J``, the
       spatial mass matrix ``M``, and ``H = J^T (M J)``, then factor ``H``.
    4. Solve for the joint accelerations.
    5. Integrate the joints, update maximal coordinates with ``eval_fk``, and finally
       integrate the particles.

    Args:
        model: The simulation model.
        state_in: Input state. Note: ``state_in.body_q`` is overwritten by the
            forward-kinematics launch.
        state_out: Output state that receives the integrated joint coordinates and
            velocities, body transforms, and particle state.
        dt: Time step.
        control: Control inputs. If ``None``, one is created with
            ``model.control(clone_variables=False)``.

    Returns:
        ``state_out``.
    """
    requires_grad = state_in.requires_grad

    # When differentiating, the auxiliary buffers live on state_out (presumably so each
    # step keeps its own intermediates for the backward pass). Otherwise they live on
    # the integrator itself, allocated once and reused.
    if requires_grad:
        state_aug = state_out
    else:
        state_aug = self

    if not getattr(state_aug, "_featherstone_augmented", False):
        self.allocate_state_aux_vars(model, state_aug, requires_grad)
    if control is None:
        control = model.control(clone_variables=False)

    with wp.ScopedTimer("simulate", False):
        particle_f = None
        body_f = None

        if state_in.particle_count:
            particle_f = state_in.particle_f

        if state_in.body_count:
            body_f = state_in.body_f

        # damped springs
        eval_spring_forces(model, state_in, particle_f)

        # triangle elastic and lift/drag forces
        eval_triangle_forces(model, state_in, control, particle_f)

        # triangle/triangle contacts
        eval_triangle_contact_forces(model, state_in, particle_f)

        # triangle bending
        eval_bending_forces(model, state_in, particle_f)

        # tetrahedral FEM
        eval_tetrahedral_forces(model, state_in, control, particle_f)

        # particle-particle interactions
        eval_particle_forces(model, state_in, particle_f)

        # particle ground contacts
        eval_particle_ground_contact_forces(model, state_in, particle_f)

        # particle shape contact
        eval_particle_body_contact_forces(model, state_in, particle_f, body_f)

        # NOTE: muscle forces (eval_muscle_forces) are not evaluated by this integrator.

        # ----------------------------
        # articulations

        if model.joint_count:
            # forward kinematics: body transforms (written into state_in.body_q) and COM frames
            wp.launch(
                eval_rigid_fk,
                dim=model.articulation_count,
                inputs=[
                    model.articulation_start,
                    model.joint_type,
                    model.joint_parent,
                    model.joint_child,
                    model.joint_q_start,
                    state_in.joint_q,
                    model.joint_X_p,
                    model.joint_X_c,
                    self.body_X_com,
                    model.joint_axis,
                    model.joint_axis_start,
                    model.joint_axis_dim,
                ],
                outputs=[state_in.body_q, state_aug.body_q_com],
                device=model.device,
            )

            # inverse dynamics: joint motion subspaces, spatial inertias, velocities,
            # forces and accelerations (body_f_s is cleared first)
            state_aug.body_f_s.zero_()
            wp.launch(
                eval_rigid_id,
                dim=model.articulation_count,
                inputs=[
                    model.articulation_start,
                    model.joint_type,
                    model.joint_parent,
                    model.joint_child,
                    model.joint_q_start,
                    model.joint_qd_start,
                    state_in.joint_q,
                    state_in.joint_qd,
                    model.joint_axis,
                    model.joint_axis_start,
                    model.joint_axis_dim,
                    self.body_I_m,
                    state_in.body_q,
                    state_aug.body_q_com,
                    model.joint_X_p,
                    model.joint_X_c,
                    model.gravity,
                ],
                outputs=[
                    state_aug.joint_S_s,
                    state_aug.body_I_s,
                    state_aug.body_v_s,
                    state_aug.body_f_s,
                    state_aug.body_a_s,
                ],
                device=model.device,
            )

            # Rigid contacts are evaluated when there is contact capacity and either
            # ground contact pairs (with a ground plane) or shape-shape pairs exist.
            if model.rigid_contact_max and (
                (model.ground and model.shape_ground_contact_pair_count) or model.shape_contact_pair_count
            ):
                wp.launch(
                    kernel=eval_rigid_contacts,
                    dim=model.rigid_contact_max,
                    inputs=[
                        state_in.body_q,
                        state_aug.body_v_s,
                        model.body_com,
                        model.shape_materials,
                        model.shape_geo,
                        model.shape_body,
                        model.rigid_contact_count,
                        model.rigid_contact_point0,
                        model.rigid_contact_point1,
                        model.rigid_contact_normal,
                        model.rigid_contact_shape0,
                        model.rigid_contact_shape1,
                        True,
                    ],
                    outputs=[body_f],
                    device=model.device,
                )

            if model.articulation_count:
                # evaluate joint torques
                state_aug.body_ft_s.zero_()
                wp.launch(
                    eval_rigid_tau,
                    dim=model.articulation_count,
                    inputs=[
                        model.articulation_start,
                        model.joint_type,
                        model.joint_parent,
                        model.joint_child,
                        model.joint_q_start,
                        model.joint_qd_start,
                        model.joint_axis_start,
                        model.joint_axis_dim,
                        model.joint_axis_mode,
                        state_in.joint_q,
                        state_in.joint_qd,
                        control.joint_act,
                        model.joint_target_ke,
                        model.joint_target_kd,
                        model.joint_limit_lower,
                        model.joint_limit_upper,
                        model.joint_limit_ke,
                        model.joint_limit_kd,
                        state_aug.joint_S_s,
                        state_aug.body_f_s,
                        body_f,
                    ],
                    outputs=[
                        state_aug.body_ft_s,
                        state_aug.joint_tau,
                    ],
                    device=model.device,
                )

                # The mass matrix and its factorisation are only rebuilt every
                # `update_mass_matrix_every` steps; in between, the previous L is reused.
                if self._step % self.update_mass_matrix_every == 0:
                    # build J
                    wp.launch(
                        eval_rigid_jacobian,
                        dim=model.articulation_count,
                        inputs=[
                            model.articulation_start,
                            self.articulation_J_start,
                            model.joint_parent,
                            model.joint_qd_start,
                            state_aug.joint_S_s,
                        ],
                        outputs=[self.J],
                        device=model.device,
                    )

                    # build M
                    wp.launch(
                        eval_rigid_mass,
                        dim=model.articulation_count,
                        inputs=[
                            model.articulation_start,
                            self.articulation_M_start,
                            state_aug.body_I_s,
                        ],
                        outputs=[self.M],
                        device=model.device,
                    )

                    # form P = M*J
                    wp.launch(
                        eval_dense_gemm_batched,
                        dim=model.articulation_count,
                        inputs=[
                            self.articulation_M_rows,
                            self.articulation_J_cols,
                            self.articulation_J_rows,
                            False,
                            False,
                            self.articulation_M_start,
                            self.articulation_J_start,
                            # P start is the same as J start since it has the same dims as J
                            self.articulation_J_start,
                            self.M,
                            self.J,
                        ],
                        outputs=[self.P],
                        device=model.device,
                    )

                    # form H = J^T*P
                    wp.launch(
                        eval_dense_gemm_batched,
                        dim=model.articulation_count,
                        inputs=[
                            self.articulation_J_cols,
                            self.articulation_J_cols,
                            # P rows is the same as J rows
                            self.articulation_J_rows,
                            True,
                            False,
                            self.articulation_J_start,
                            # P start is the same as J start since it has the same dims as J
                            self.articulation_J_start,
                            self.articulation_H_start,
                            self.J,
                            self.P,
                        ],
                        outputs=[self.H],
                        device=model.device,
                    )

                    # compute decomposition (joint armature is passed in alongside H)
                    wp.launch(
                        eval_dense_cholesky_batched,
                        dim=model.articulation_count,
                        inputs=[
                            self.articulation_H_start,
                            self.articulation_H_rows,
                            self.H,
                            model.joint_armature,
                        ],
                        outputs=[self.L],
                        device=model.device,
                    )

                # solve for qdd (every step, using the most recent factorisation)
                state_aug.joint_qdd.zero_()
                wp.launch(
                    eval_dense_solve_batched,
                    dim=model.articulation_count,
                    inputs=[
                        self.articulation_H_start,
                        self.articulation_H_rows,
                        self.articulation_dof_start,
                        self.H,
                        self.L,
                        state_aug.joint_tau,
                    ],
                    outputs=[
                        state_aug.joint_qdd,
                        # scratch for the custom gradient; None when not differentiating
                        state_aug.joint_solve_tmp,
                    ],
                    device=model.device,
                )

        # -------------------------------------
        # integrate bodies

        if model.joint_count:
            wp.launch(
                kernel=integrate_generalized_joints,
                dim=model.joint_count,
                inputs=[
                    model.joint_type,
                    model.joint_q_start,
                    model.joint_qd_start,
                    model.joint_axis_dim,
                    state_in.joint_q,
                    state_in.joint_qd,
                    state_aug.joint_qdd,
                    dt,
                ],
                outputs=[state_out.joint_q, state_out.joint_qd],
                device=model.device,
            )

            # update maximal coordinates
            eval_fk(model, state_out.joint_q, state_out.joint_qd, None, state_out)

        self.integrate_particles(model, state_in, state_out, dt)

        self._step += 1

        return state_out