warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429) hide show
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2315 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ import warp.sim
22
+ from warp.tests.unittest_utils import *
23
+
24
+ np_float_types = [np.float32, np.float64, np.float16]
25
+
26
+ kernel_cache = {}
27
+
28
+
29
+ def getkernel(func, suffix=""):
30
+ key = func.__name__ + "_" + suffix
31
+ if key not in kernel_cache:
32
+ kernel_cache[key] = wp.Kernel(func=func, key=key)
33
+ return kernel_cache[key]
34
+
35
+
36
+ def get_select_kernel(dtype):
37
+ def output_select_kernel_fn(
38
+ input: wp.array(dtype=dtype),
39
+ index: int,
40
+ out: wp.array(dtype=dtype),
41
+ ):
42
+ out[0] = input[index]
43
+
44
+ return getkernel(output_select_kernel_fn, suffix=dtype.__name__)
45
+
46
+
47
+ ############################################################
48
+
49
+
50
+ def test_constructors(test, device, dtype, register_kernels=False):
51
+ rng = np.random.default_rng(123)
52
+
53
+ tol = {
54
+ np.float16: 5.0e-3,
55
+ np.float32: 1.0e-6,
56
+ np.float64: 1.0e-8,
57
+ }.get(dtype, 0)
58
+
59
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
60
+ vec3 = wp.types.vector(length=3, dtype=wptype)
61
+ quat = wp.types.quaternion(dtype=wptype)
62
+
63
+ def check_component_constructor(
64
+ input: wp.array(dtype=wptype),
65
+ q: wp.array(dtype=wptype),
66
+ ):
67
+ qresult = quat(input[0], input[1], input[2], input[3])
68
+
69
+ # multiply the output by 2 so we've got something to backpropagate:
70
+ q[0] = wptype(2) * qresult[0]
71
+ q[1] = wptype(2) * qresult[1]
72
+ q[2] = wptype(2) * qresult[2]
73
+ q[3] = wptype(2) * qresult[3]
74
+
75
+ def check_vector_constructor(
76
+ input: wp.array(dtype=wptype),
77
+ q: wp.array(dtype=wptype),
78
+ ):
79
+ qresult = quat(vec3(input[0], input[1], input[2]), input[3])
80
+
81
+ # multiply the output by 2 so we've got something to backpropagate:
82
+ q[0] = wptype(2) * qresult[0]
83
+ q[1] = wptype(2) * qresult[1]
84
+ q[2] = wptype(2) * qresult[2]
85
+ q[3] = wptype(2) * qresult[3]
86
+
87
+ kernel = getkernel(check_component_constructor, suffix=dtype.__name__)
88
+ output_select_kernel = get_select_kernel(wptype)
89
+ vec_kernel = getkernel(check_vector_constructor, suffix=dtype.__name__)
90
+
91
+ if register_kernels:
92
+ return
93
+
94
+ input = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
95
+ output = wp.zeros_like(input)
96
+ wp.launch(kernel, dim=1, inputs=[input], outputs=[output], device=device)
97
+
98
+ assert_np_equal(output.numpy(), 2 * input.numpy(), tol=tol)
99
+
100
+ for i in range(4):
101
+ cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
102
+ tape = wp.Tape()
103
+ with tape:
104
+ wp.launch(kernel, dim=1, inputs=[input], outputs=[output], device=device)
105
+ wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[cmp], device=device)
106
+ tape.backward(loss=cmp)
107
+ expectedgrads = np.zeros(len(input))
108
+ expectedgrads[i] = 2
109
+ assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
110
+ tape.zero()
111
+
112
+ input = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
113
+ output = wp.zeros_like(input)
114
+ wp.launch(vec_kernel, dim=1, inputs=[input], outputs=[output], device=device)
115
+
116
+ assert_np_equal(output.numpy(), 2 * input.numpy(), tol=tol)
117
+
118
+ for i in range(4):
119
+ cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
120
+ tape = wp.Tape()
121
+ with tape:
122
+ wp.launch(vec_kernel, dim=1, inputs=[input], outputs=[output], device=device)
123
+ wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[cmp], device=device)
124
+ tape.backward(loss=cmp)
125
+ expectedgrads = np.zeros(len(input))
126
+ expectedgrads[i] = 2
127
+ assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
128
+ tape.zero()
129
+
130
+
131
+ def test_casting_constructors(test, device, dtype, register_kernels=False):
132
+ np_type = np.dtype(dtype)
133
+ wp_type = wp.types.np_dtype_to_warp_type[np_type]
134
+ quat = wp.types.quaternion(dtype=wp_type)
135
+
136
+ np16 = np.dtype(np.float16)
137
+ wp16 = wp.types.np_dtype_to_warp_type[np16]
138
+
139
+ np32 = np.dtype(np.float32)
140
+ wp32 = wp.types.np_dtype_to_warp_type[np32]
141
+
142
+ np64 = np.dtype(np.float64)
143
+ wp64 = wp.types.np_dtype_to_warp_type[np64]
144
+
145
+ def cast_float16(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp16, ndim=2)):
146
+ tid = wp.tid()
147
+
148
+ q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
149
+ q2 = wp.quaternion(q1, dtype=wp16)
150
+
151
+ b[tid, 0] = q2[0]
152
+ b[tid, 1] = q2[1]
153
+ b[tid, 2] = q2[2]
154
+ b[tid, 3] = q2[3]
155
+
156
+ def cast_float32(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp32, ndim=2)):
157
+ tid = wp.tid()
158
+
159
+ q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
160
+ q2 = wp.quaternion(q1, dtype=wp32)
161
+
162
+ b[tid, 0] = q2[0]
163
+ b[tid, 1] = q2[1]
164
+ b[tid, 2] = q2[2]
165
+ b[tid, 3] = q2[3]
166
+
167
+ def cast_float64(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp64, ndim=2)):
168
+ tid = wp.tid()
169
+
170
+ q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
171
+ q2 = wp.quaternion(q1, dtype=wp64)
172
+
173
+ b[tid, 0] = q2[0]
174
+ b[tid, 1] = q2[1]
175
+ b[tid, 2] = q2[2]
176
+ b[tid, 3] = q2[3]
177
+
178
+ kernel_16 = getkernel(cast_float16, suffix=dtype.__name__)
179
+ kernel_32 = getkernel(cast_float32, suffix=dtype.__name__)
180
+ kernel_64 = getkernel(cast_float64, suffix=dtype.__name__)
181
+
182
+ if register_kernels:
183
+ return
184
+
185
+ # check casting to float 16
186
+ a = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
187
+ b = wp.array(np.zeros((1, 4), dtype=np16), dtype=wp16, requires_grad=True, device=device)
188
+ b_result = np.ones((1, 4), dtype=np16)
189
+ b_grad = wp.array(np.ones((1, 4), dtype=np16), dtype=wp16, device=device)
190
+ a_grad = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, device=device)
191
+
192
+ tape = wp.Tape()
193
+ with tape:
194
+ wp.launch(kernel=kernel_16, dim=1, inputs=[a, b], device=device)
195
+
196
+ tape.backward(grads={b: b_grad})
197
+ out = tape.gradients[a].numpy()
198
+
199
+ assert_np_equal(b.numpy(), b_result)
200
+ assert_np_equal(out, a_grad.numpy())
201
+
202
+ # check casting to float 32
203
+ a = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
204
+ b = wp.array(np.zeros((1, 4), dtype=np32), dtype=wp32, requires_grad=True, device=device)
205
+ b_result = np.ones((1, 4), dtype=np32)
206
+ b_grad = wp.array(np.ones((1, 4), dtype=np32), dtype=wp32, device=device)
207
+ a_grad = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, device=device)
208
+
209
+ tape = wp.Tape()
210
+ with tape:
211
+ wp.launch(kernel=kernel_32, dim=1, inputs=[a, b], device=device)
212
+
213
+ tape.backward(grads={b: b_grad})
214
+ out = tape.gradients[a].numpy()
215
+
216
+ assert_np_equal(b.numpy(), b_result)
217
+ assert_np_equal(out, a_grad.numpy())
218
+
219
+ # check casting to float 64
220
+ a = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
221
+ b = wp.array(np.zeros((1, 4), dtype=np64), dtype=wp64, requires_grad=True, device=device)
222
+ b_result = np.ones((1, 4), dtype=np64)
223
+ b_grad = wp.array(np.ones((1, 4), dtype=np64), dtype=wp64, device=device)
224
+ a_grad = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, device=device)
225
+
226
+ tape = wp.Tape()
227
+ with tape:
228
+ wp.launch(kernel=kernel_64, dim=1, inputs=[a, b], device=device)
229
+
230
+ tape.backward(grads={b: b_grad})
231
+ out = tape.gradients[a].numpy()
232
+
233
+ assert_np_equal(b.numpy(), b_result)
234
+ assert_np_equal(out, a_grad.numpy())
235
+
236
+
237
+ def test_inverse(test, device, dtype, register_kernels=False):
238
+ rng = np.random.default_rng(123)
239
+
240
+ tol = {
241
+ np.float16: 2.0e-3,
242
+ np.float32: 1.0e-6,
243
+ np.float64: 1.0e-8,
244
+ }.get(dtype, 0)
245
+
246
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
247
+ quat = wp.types.quaternion(dtype=wptype)
248
+
249
+ output_select_kernel = get_select_kernel(wptype)
250
+
251
+ def check_quat_inverse(
252
+ input: wp.array(dtype=wptype),
253
+ shouldbeidentity: wp.array(dtype=quat),
254
+ q: wp.array(dtype=wptype),
255
+ ):
256
+ qread = quat(input[0], input[1], input[2], input[3])
257
+ qresult = wp.quat_inverse(qread)
258
+
259
+ # this inverse should work for normalized quaternions:
260
+ shouldbeidentity[0] = wp.normalize(qread) * wp.quat_inverse(wp.normalize(qread))
261
+
262
+ # multiply the output by 2 so we've got something to backpropagate:
263
+ q[0] = wptype(2) * qresult[0]
264
+ q[1] = wptype(2) * qresult[1]
265
+ q[2] = wptype(2) * qresult[2]
266
+ q[3] = wptype(2) * qresult[3]
267
+
268
+ kernel = getkernel(check_quat_inverse, suffix=dtype.__name__)
269
+
270
+ if register_kernels:
271
+ return
272
+
273
+ input = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
274
+ shouldbeidentity = wp.array(np.zeros((1, 4)), dtype=quat, requires_grad=True, device=device)
275
+ output = wp.zeros_like(input)
276
+ wp.launch(kernel, dim=1, inputs=[input], outputs=[shouldbeidentity, output], device=device)
277
+
278
+ assert_np_equal(shouldbeidentity.numpy(), np.array([0, 0, 0, 1]), tol=tol)
279
+
280
+ for i in range(4):
281
+ cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
282
+ tape = wp.Tape()
283
+ with tape:
284
+ wp.launch(kernel, dim=1, inputs=[input], outputs=[shouldbeidentity, output], device=device)
285
+ wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[cmp], device=device)
286
+ tape.backward(loss=cmp)
287
+ expectedgrads = np.zeros(len(input))
288
+ expectedgrads[i] = -2 if i != 3 else 2
289
+ assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
290
+ tape.zero()
291
+
292
+
293
+ def test_dotproduct(test, device, dtype, register_kernels=False):
294
+ rng = np.random.default_rng(123)
295
+
296
+ tol = {
297
+ np.float16: 1.0e-2,
298
+ np.float32: 1.0e-6,
299
+ np.float64: 1.0e-8,
300
+ }.get(dtype, 0)
301
+
302
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
303
+ quat = wp.types.quaternion(dtype=wptype)
304
+
305
+ def check_quat_dot(
306
+ s: wp.array(dtype=quat),
307
+ v: wp.array(dtype=quat),
308
+ dot: wp.array(dtype=wptype),
309
+ ):
310
+ dot[0] = wptype(2) * wp.dot(v[0], s[0])
311
+
312
+ dotkernel = getkernel(check_quat_dot, suffix=dtype.__name__)
313
+ if register_kernels:
314
+ return
315
+
316
+ s = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
317
+ v = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
318
+ dot = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
319
+
320
+ tape = wp.Tape()
321
+ with tape:
322
+ wp.launch(
323
+ dotkernel,
324
+ dim=1,
325
+ inputs=[
326
+ s,
327
+ v,
328
+ ],
329
+ outputs=[dot],
330
+ device=device,
331
+ )
332
+
333
+ assert_np_equal(dot.numpy()[0], 2.0 * (v.numpy() * s.numpy()).sum(), tol=tol)
334
+
335
+ tape.backward(loss=dot)
336
+ sgrads = tape.gradients[s].numpy()[0]
337
+ expected_grads = 2.0 * v.numpy()[0]
338
+ assert_np_equal(sgrads, expected_grads, tol=10 * tol)
339
+
340
+ vgrads = tape.gradients[v].numpy()[0]
341
+ expected_grads = 2.0 * s.numpy()[0]
342
+ assert_np_equal(vgrads, expected_grads, tol=tol)
343
+
344
+
345
+ def test_length(test, device, dtype, register_kernels=False):
346
+ rng = np.random.default_rng(123)
347
+
348
+ tol = {
349
+ np.float16: 5.0e-3,
350
+ np.float32: 1.0e-6,
351
+ np.float64: 1.0e-7,
352
+ }.get(dtype, 0)
353
+
354
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
355
+ quat = wp.types.quaternion(dtype=wptype)
356
+
357
+ def check_quat_length(
358
+ q: wp.array(dtype=quat),
359
+ l: wp.array(dtype=wptype),
360
+ l2: wp.array(dtype=wptype),
361
+ ):
362
+ l[0] = wptype(2) * wp.length(q[0])
363
+ l2[0] = wptype(2) * wp.length_sq(q[0])
364
+
365
+ kernel = getkernel(check_quat_length, suffix=dtype.__name__)
366
+
367
+ if register_kernels:
368
+ return
369
+
370
+ q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
371
+ l = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
372
+ l2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
373
+
374
+ tape = wp.Tape()
375
+ with tape:
376
+ wp.launch(
377
+ kernel,
378
+ dim=1,
379
+ inputs=[
380
+ q,
381
+ ],
382
+ outputs=[l, l2],
383
+ device=device,
384
+ )
385
+
386
+ assert_np_equal(l.numpy()[0], 2 * np.linalg.norm(q.numpy()), tol=10 * tol)
387
+ assert_np_equal(l2.numpy()[0], 2 * np.linalg.norm(q.numpy()) ** 2, tol=10 * tol)
388
+
389
+ tape.backward(loss=l)
390
+ grad = tape.gradients[q].numpy()[0]
391
+ expected_grad = 2 * q.numpy()[0] / np.linalg.norm(q.numpy())
392
+ assert_np_equal(grad, expected_grad, tol=10 * tol)
393
+ tape.zero()
394
+
395
+ tape.backward(loss=l2)
396
+ grad = tape.gradients[q].numpy()[0]
397
+ expected_grad = 4 * q.numpy()[0]
398
+ assert_np_equal(grad, expected_grad, tol=10 * tol)
399
+ tape.zero()
400
+
401
+
402
+ def test_normalize(test, device, dtype, register_kernels=False):
403
+ rng = np.random.default_rng(123)
404
+
405
+ tol = {
406
+ np.float16: 5.0e-3,
407
+ np.float32: 1.0e-6,
408
+ np.float64: 1.0e-8,
409
+ }.get(dtype, 0)
410
+
411
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
412
+ quat = wp.types.quaternion(dtype=wptype)
413
+
414
+ def check_normalize(
415
+ q: wp.array(dtype=quat),
416
+ n0: wp.array(dtype=wptype),
417
+ n1: wp.array(dtype=wptype),
418
+ n2: wp.array(dtype=wptype),
419
+ n3: wp.array(dtype=wptype),
420
+ ):
421
+ n = wptype(2) * (wp.normalize(q[0]))
422
+
423
+ n0[0] = n[0]
424
+ n1[0] = n[1]
425
+ n2[0] = n[2]
426
+ n3[0] = n[3]
427
+
428
+ def check_normalize_alt(
429
+ q: wp.array(dtype=quat),
430
+ n0: wp.array(dtype=wptype),
431
+ n1: wp.array(dtype=wptype),
432
+ n2: wp.array(dtype=wptype),
433
+ n3: wp.array(dtype=wptype),
434
+ ):
435
+ n = wptype(2) * (q[0] / wp.length(q[0]))
436
+
437
+ n0[0] = n[0]
438
+ n1[0] = n[1]
439
+ n2[0] = n[2]
440
+ n3[0] = n[3]
441
+
442
+ normalize_kernel = getkernel(check_normalize, suffix=dtype.__name__)
443
+ normalize_alt_kernel = getkernel(check_normalize_alt, suffix=dtype.__name__)
444
+
445
+ if register_kernels:
446
+ return
447
+
448
+ # I've already tested the things I'm using in check_normalize_alt, so I'll just
449
+ # make sure the two are giving the same results/gradients
450
+ q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
451
+
452
+ n0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
453
+ n1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
454
+ n2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
455
+ n3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
456
+
457
+ n0_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
458
+ n1_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
459
+ n2_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
460
+ n3_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
461
+
462
+ outputs0 = [
463
+ n0,
464
+ n1,
465
+ n2,
466
+ n3,
467
+ ]
468
+ tape0 = wp.Tape()
469
+ with tape0:
470
+ wp.launch(normalize_kernel, dim=1, inputs=[q], outputs=outputs0, device=device)
471
+
472
+ outputs1 = [
473
+ n0_alt,
474
+ n1_alt,
475
+ n2_alt,
476
+ n3_alt,
477
+ ]
478
+ tape1 = wp.Tape()
479
+ with tape1:
480
+ wp.launch(
481
+ normalize_alt_kernel,
482
+ dim=1,
483
+ inputs=[
484
+ q,
485
+ ],
486
+ outputs=outputs1,
487
+ device=device,
488
+ )
489
+
490
+ assert_np_equal(n0.numpy()[0], n0_alt.numpy()[0], tol=tol)
491
+ assert_np_equal(n1.numpy()[0], n1_alt.numpy()[0], tol=tol)
492
+ assert_np_equal(n2.numpy()[0], n2_alt.numpy()[0], tol=tol)
493
+ assert_np_equal(n3.numpy()[0], n3_alt.numpy()[0], tol=tol)
494
+
495
+ for ncmp, ncmpalt in zip(outputs0, outputs1):
496
+ tape0.backward(loss=ncmp)
497
+ tape1.backward(loss=ncmpalt)
498
+ assert_np_equal(tape0.gradients[q].numpy()[0], tape1.gradients[q].numpy()[0], tol=tol)
499
+ tape0.zero()
500
+ tape1.zero()
501
+
502
+
503
+ def test_addition(test, device, dtype, register_kernels=False):
504
+ rng = np.random.default_rng(123)
505
+
506
+ tol = {
507
+ np.float16: 5.0e-3,
508
+ np.float32: 1.0e-6,
509
+ np.float64: 1.0e-8,
510
+ }.get(dtype, 0)
511
+
512
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
513
+ quat = wp.types.quaternion(dtype=wptype)
514
+
515
+ def check_quat_add(
516
+ q: wp.array(dtype=quat),
517
+ v: wp.array(dtype=quat),
518
+ r0: wp.array(dtype=wptype),
519
+ r1: wp.array(dtype=wptype),
520
+ r2: wp.array(dtype=wptype),
521
+ r3: wp.array(dtype=wptype),
522
+ ):
523
+ result = q[0] + v[0]
524
+
525
+ r0[0] = wptype(2) * result[0]
526
+ r1[0] = wptype(2) * result[1]
527
+ r2[0] = wptype(2) * result[2]
528
+ r3[0] = wptype(2) * result[3]
529
+
530
+ kernel = getkernel(check_quat_add, suffix=dtype.__name__)
531
+
532
+ if register_kernels:
533
+ return
534
+
535
+ q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
536
+ v = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
537
+
538
+ r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
539
+ r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
540
+ r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
541
+ r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
542
+
543
+ tape = wp.Tape()
544
+ with tape:
545
+ wp.launch(
546
+ kernel,
547
+ dim=1,
548
+ inputs=[
549
+ q,
550
+ v,
551
+ ],
552
+ outputs=[r0, r1, r2, r3],
553
+ device=device,
554
+ )
555
+
556
+ assert_np_equal(r0.numpy()[0], 2 * (v.numpy()[0, 0] + q.numpy()[0, 0]), tol=tol)
557
+ assert_np_equal(r1.numpy()[0], 2 * (v.numpy()[0, 1] + q.numpy()[0, 1]), tol=tol)
558
+ assert_np_equal(r2.numpy()[0], 2 * (v.numpy()[0, 2] + q.numpy()[0, 2]), tol=tol)
559
+ assert_np_equal(r3.numpy()[0], 2 * (v.numpy()[0, 3] + q.numpy()[0, 3]), tol=tol)
560
+
561
+ for i, l in enumerate([r0, r1, r2, r3]):
562
+ tape.backward(loss=l)
563
+ qgrads = tape.gradients[q].numpy()[0]
564
+ expected_grads = np.zeros_like(qgrads)
565
+
566
+ expected_grads[i] = 2
567
+ assert_np_equal(qgrads, expected_grads, tol=10 * tol)
568
+
569
+ vgrads = tape.gradients[v].numpy()[0]
570
+ assert_np_equal(vgrads, expected_grads, tol=tol)
571
+
572
+ tape.zero()
573
+
574
+
575
+ def test_subtraction(test, device, dtype, register_kernels=False):
576
+ rng = np.random.default_rng(123)
577
+
578
+ tol = {
579
+ np.float16: 5.0e-3,
580
+ np.float32: 1.0e-6,
581
+ np.float64: 1.0e-8,
582
+ }.get(dtype, 0)
583
+
584
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
585
+ quat = wp.types.quaternion(dtype=wptype)
586
+
587
+ def check_quat_sub(
588
+ q: wp.array(dtype=quat),
589
+ v: wp.array(dtype=quat),
590
+ r0: wp.array(dtype=wptype),
591
+ r1: wp.array(dtype=wptype),
592
+ r2: wp.array(dtype=wptype),
593
+ r3: wp.array(dtype=wptype),
594
+ ):
595
+ result = v[0] - q[0]
596
+
597
+ r0[0] = wptype(2) * result[0]
598
+ r1[0] = wptype(2) * result[1]
599
+ r2[0] = wptype(2) * result[2]
600
+ r3[0] = wptype(2) * result[3]
601
+
602
+ kernel = getkernel(check_quat_sub, suffix=dtype.__name__)
603
+
604
+ if register_kernels:
605
+ return
606
+
607
+ q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
608
+ v = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
609
+
610
+ r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
611
+ r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
612
+ r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
613
+ r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
614
+
615
+ tape = wp.Tape()
616
+ with tape:
617
+ wp.launch(
618
+ kernel,
619
+ dim=1,
620
+ inputs=[
621
+ q,
622
+ v,
623
+ ],
624
+ outputs=[r0, r1, r2, r3],
625
+ device=device,
626
+ )
627
+
628
+ assert_np_equal(r0.numpy()[0], 2 * (v.numpy()[0, 0] - q.numpy()[0, 0]), tol=tol)
629
+ assert_np_equal(r1.numpy()[0], 2 * (v.numpy()[0, 1] - q.numpy()[0, 1]), tol=tol)
630
+ assert_np_equal(r2.numpy()[0], 2 * (v.numpy()[0, 2] - q.numpy()[0, 2]), tol=tol)
631
+ assert_np_equal(r3.numpy()[0], 2 * (v.numpy()[0, 3] - q.numpy()[0, 3]), tol=tol)
632
+
633
+ for i, l in enumerate([r0, r1, r2, r3]):
634
+ tape.backward(loss=l)
635
+ qgrads = tape.gradients[q].numpy()[0]
636
+ expected_grads = np.zeros_like(qgrads)
637
+
638
+ expected_grads[i] = -2
639
+ assert_np_equal(qgrads, expected_grads, tol=10 * tol)
640
+
641
+ vgrads = tape.gradients[v].numpy()[0]
642
+ expected_grads[i] = 2
643
+ assert_np_equal(vgrads, expected_grads, tol=tol)
644
+
645
+ tape.zero()
646
+
647
+
648
+ def test_scalar_multiplication(test, device, dtype, register_kernels=False):
649
+ rng = np.random.default_rng(123)
650
+
651
+ tol = {
652
+ np.float16: 5.0e-3,
653
+ np.float32: 1.0e-6,
654
+ np.float64: 1.0e-8,
655
+ }.get(dtype, 0)
656
+
657
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
658
+ quat = wp.types.quaternion(dtype=wptype)
659
+
660
+ def check_quat_scalar_mul(
661
+ s: wp.array(dtype=wptype),
662
+ q: wp.array(dtype=quat),
663
+ l0: wp.array(dtype=wptype),
664
+ l1: wp.array(dtype=wptype),
665
+ l2: wp.array(dtype=wptype),
666
+ l3: wp.array(dtype=wptype),
667
+ r0: wp.array(dtype=wptype),
668
+ r1: wp.array(dtype=wptype),
669
+ r2: wp.array(dtype=wptype),
670
+ r3: wp.array(dtype=wptype),
671
+ ):
672
+ lresult = s[0] * q[0]
673
+ rresult = q[0] * s[0]
674
+
675
+ # multiply outputs by 2 so we've got something to backpropagate:
676
+ l0[0] = wptype(2) * lresult[0]
677
+ l1[0] = wptype(2) * lresult[1]
678
+ l2[0] = wptype(2) * lresult[2]
679
+ l3[0] = wptype(2) * lresult[3]
680
+
681
+ r0[0] = wptype(2) * rresult[0]
682
+ r1[0] = wptype(2) * rresult[1]
683
+ r2[0] = wptype(2) * rresult[2]
684
+ r3[0] = wptype(2) * rresult[3]
685
+
686
+ kernel = getkernel(check_quat_scalar_mul, suffix=dtype.__name__)
687
+
688
+ if register_kernels:
689
+ return
690
+
691
+ s = wp.array(rng.standard_normal(size=1).astype(dtype), requires_grad=True, device=device)
692
+ q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
693
+
694
+ l0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
695
+ l1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
696
+ l2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
697
+ l3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
698
+
699
+ r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
700
+ r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
701
+ r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
702
+ r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
703
+
704
+ tape = wp.Tape()
705
+ with tape:
706
+ wp.launch(
707
+ kernel,
708
+ dim=1,
709
+ inputs=[s, q],
710
+ outputs=[
711
+ l0,
712
+ l1,
713
+ l2,
714
+ l3,
715
+ r0,
716
+ r1,
717
+ r2,
718
+ r3,
719
+ ],
720
+ device=device,
721
+ )
722
+
723
+ assert_np_equal(l0.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 0], tol=tol)
724
+ assert_np_equal(l1.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 1], tol=tol)
725
+ assert_np_equal(l2.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 2], tol=tol)
726
+ assert_np_equal(l3.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 3], tol=tol)
727
+
728
+ assert_np_equal(r0.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 0], tol=tol)
729
+ assert_np_equal(r1.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 1], tol=tol)
730
+ assert_np_equal(r2.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 2], tol=tol)
731
+ assert_np_equal(r3.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 3], tol=tol)
732
+
733
+ if dtype in np_float_types:
734
+ for i, outputs in enumerate([(l0, r0), (l1, r1), (l2, r2), (l3, r3)]):
735
+ for l in outputs:
736
+ tape.backward(loss=l)
737
+ sgrad = tape.gradients[s].numpy()[0]
738
+ assert_np_equal(sgrad, 2 * q.numpy()[0, i], tol=tol)
739
+ allgrads = tape.gradients[q].numpy()[0]
740
+ expected_grads = np.zeros_like(allgrads)
741
+ expected_grads[i] = s.numpy()[0] * 2
742
+ assert_np_equal(allgrads, expected_grads, tol=10 * tol)
743
+ tape.zero()
744
+
745
+
746
+ def test_scalar_division(test, device, dtype, register_kernels=False):
747
+ rng = np.random.default_rng(123)
748
+
749
+ tol = {
750
+ np.float16: 1.0e-3,
751
+ np.float32: 1.0e-6,
752
+ np.float64: 1.0e-8,
753
+ }.get(dtype, 0)
754
+
755
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
756
+ quat = wp.types.quaternion(dtype=wptype)
757
+
758
+ def check_quat_scalar_div(
759
+ s: wp.array(dtype=wptype),
760
+ q: wp.array(dtype=quat),
761
+ r0: wp.array(dtype=wptype),
762
+ r1: wp.array(dtype=wptype),
763
+ r2: wp.array(dtype=wptype),
764
+ r3: wp.array(dtype=wptype),
765
+ ):
766
+ result = q[0] / s[0]
767
+
768
+ # multiply outputs by 2 so we've got something to backpropagate:
769
+ r0[0] = wptype(2) * result[0]
770
+ r1[0] = wptype(2) * result[1]
771
+ r2[0] = wptype(2) * result[2]
772
+ r3[0] = wptype(2) * result[3]
773
+
774
+ kernel = getkernel(check_quat_scalar_div, suffix=dtype.__name__)
775
+
776
+ if register_kernels:
777
+ return
778
+
779
+ s = wp.array(rng.standard_normal(size=1).astype(dtype), requires_grad=True, device=device)
780
+ q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
781
+
782
+ r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
783
+ r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
784
+ r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
785
+ r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
786
+
787
+ tape = wp.Tape()
788
+ with tape:
789
+ wp.launch(
790
+ kernel,
791
+ dim=1,
792
+ inputs=[s, q],
793
+ outputs=[
794
+ r0,
795
+ r1,
796
+ r2,
797
+ r3,
798
+ ],
799
+ device=device,
800
+ )
801
+ assert_np_equal(r0.numpy()[0], 2 * q.numpy()[0, 0] / s.numpy()[0], tol=tol)
802
+ assert_np_equal(r1.numpy()[0], 2 * q.numpy()[0, 1] / s.numpy()[0], tol=tol)
803
+ assert_np_equal(r2.numpy()[0], 2 * q.numpy()[0, 2] / s.numpy()[0], tol=tol)
804
+ assert_np_equal(r3.numpy()[0], 2 * q.numpy()[0, 3] / s.numpy()[0], tol=tol)
805
+
806
+ if dtype in np_float_types:
807
+ for i, r in enumerate([r0, r1, r2, r3]):
808
+ tape.backward(loss=r)
809
+ sgrad = tape.gradients[s].numpy()[0]
810
+ assert_np_equal(sgrad, -2 * q.numpy()[0, i] / (s.numpy()[0] * s.numpy()[0]), tol=tol)
811
+
812
+ allgrads = tape.gradients[q].numpy()[0]
813
+ expected_grads = np.zeros_like(allgrads)
814
+ expected_grads[i] = 2 / s.numpy()[0]
815
+ assert_np_equal(allgrads, expected_grads, tol=10 * tol)
816
+ tape.zero()
817
+
818
+
819
+ def test_quat_multiplication(test, device, dtype, register_kernels=False):
820
+ rng = np.random.default_rng(123)
821
+
822
+ tol = {
823
+ np.float16: 1.0e-2,
824
+ np.float32: 1.0e-6,
825
+ np.float64: 1.0e-8,
826
+ }.get(dtype, 0)
827
+
828
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
829
+ quat = wp.types.quaternion(dtype=wptype)
830
+
831
+ def check_quat_mul(
832
+ s: wp.array(dtype=quat),
833
+ q: wp.array(dtype=quat),
834
+ r0: wp.array(dtype=wptype),
835
+ r1: wp.array(dtype=wptype),
836
+ r2: wp.array(dtype=wptype),
837
+ r3: wp.array(dtype=wptype),
838
+ ):
839
+ result = s[0] * q[0]
840
+
841
+ # multiply outputs by 2 so we've got something to backpropagate:
842
+ r0[0] = wptype(2) * result[0]
843
+ r1[0] = wptype(2) * result[1]
844
+ r2[0] = wptype(2) * result[2]
845
+ r3[0] = wptype(2) * result[3]
846
+
847
+ kernel = getkernel(check_quat_mul, suffix=dtype.__name__)
848
+
849
+ if register_kernels:
850
+ return
851
+
852
+ s = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
853
+ q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
854
+
855
+ r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
856
+ r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
857
+ r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
858
+ r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
859
+
860
+ tape = wp.Tape()
861
+ with tape:
862
+ wp.launch(
863
+ kernel,
864
+ dim=1,
865
+ inputs=[s, q],
866
+ outputs=[
867
+ r0,
868
+ r1,
869
+ r2,
870
+ r3,
871
+ ],
872
+ device=device,
873
+ )
874
+
875
+ a = s.numpy()
876
+ b = q.numpy()
877
+ assert_np_equal(
878
+ r0.numpy()[0], 2 * (a[0, 3] * b[0, 0] + b[0, 3] * a[0, 0] + a[0, 1] * b[0, 2] - b[0, 1] * a[0, 2]), tol=tol
879
+ )
880
+ assert_np_equal(
881
+ r1.numpy()[0], 2 * (a[0, 3] * b[0, 1] + b[0, 3] * a[0, 1] + a[0, 2] * b[0, 0] - b[0, 2] * a[0, 0]), tol=tol
882
+ )
883
+ assert_np_equal(
884
+ r2.numpy()[0], 2 * (a[0, 3] * b[0, 2] + b[0, 3] * a[0, 2] + a[0, 0] * b[0, 1] - b[0, 0] * a[0, 1]), tol=tol
885
+ )
886
+ assert_np_equal(
887
+ r3.numpy()[0], 2 * (a[0, 3] * b[0, 3] - a[0, 0] * b[0, 0] - a[0, 1] * b[0, 1] - a[0, 2] * b[0, 2]), tol=tol
888
+ )
889
+
890
+ tape.backward(loss=r0)
891
+ agrad = tape.gradients[s].numpy()[0]
892
+ assert_np_equal(agrad, 2 * np.array([b[0, 3], b[0, 2], -b[0, 1], b[0, 0]]), tol=tol)
893
+
894
+ bgrad = tape.gradients[q].numpy()[0]
895
+ assert_np_equal(bgrad, 2 * np.array([a[0, 3], -a[0, 2], a[0, 1], a[0, 0]]), tol=tol)
896
+ tape.zero()
897
+
898
+ tape.backward(loss=r1)
899
+ agrad = tape.gradients[s].numpy()[0]
900
+ assert_np_equal(agrad, 2 * np.array([-b[0, 2], b[0, 3], b[0, 0], b[0, 1]]), tol=tol)
901
+
902
+ bgrad = tape.gradients[q].numpy()[0]
903
+ assert_np_equal(bgrad, 2 * np.array([a[0, 2], a[0, 3], -a[0, 0], a[0, 1]]), tol=tol)
904
+ tape.zero()
905
+
906
+ tape.backward(loss=r2)
907
+ agrad = tape.gradients[s].numpy()[0]
908
+ assert_np_equal(agrad, 2 * np.array([b[0, 1], -b[0, 0], b[0, 3], b[0, 2]]), tol=tol)
909
+
910
+ bgrad = tape.gradients[q].numpy()[0]
911
+ assert_np_equal(bgrad, 2 * np.array([-a[0, 1], a[0, 0], a[0, 3], a[0, 2]]), tol=tol)
912
+ tape.zero()
913
+
914
+ tape.backward(loss=r3)
915
+ agrad = tape.gradients[s].numpy()[0]
916
+ assert_np_equal(agrad, 2 * np.array([-b[0, 0], -b[0, 1], -b[0, 2], b[0, 3]]), tol=tol)
917
+
918
+ bgrad = tape.gradients[q].numpy()[0]
919
+ assert_np_equal(bgrad, 2 * np.array([-a[0, 0], -a[0, 1], -a[0, 2], a[0, 3]]), tol=tol)
920
+ tape.zero()
921
+
922
+
923
+ def test_indexing(test, device, dtype, register_kernels=False):
924
+ rng = np.random.default_rng(123)
925
+
926
+ tol = {
927
+ np.float16: 5.0e-3,
928
+ np.float32: 1.0e-6,
929
+ np.float64: 1.0e-8,
930
+ }.get(dtype, 0)
931
+
932
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
933
+ quat = wp.types.quaternion(dtype=wptype)
934
+
935
+ def check_quat_indexing(
936
+ q: wp.array(dtype=quat),
937
+ r0: wp.array(dtype=wptype),
938
+ r1: wp.array(dtype=wptype),
939
+ r2: wp.array(dtype=wptype),
940
+ r3: wp.array(dtype=wptype),
941
+ ):
942
+ # multiply outputs by 2 so we've got something to backpropagate:
943
+ r0[0] = wptype(2) * q[0][0]
944
+ r1[0] = wptype(2) * q[0][1]
945
+ r2[0] = wptype(2) * q[0][2]
946
+ r3[0] = wptype(2) * q[0][3]
947
+
948
+ kernel = getkernel(check_quat_indexing, suffix=dtype.__name__)
949
+
950
+ if register_kernels:
951
+ return
952
+
953
+ q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
954
+ r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
955
+ r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
956
+ r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
957
+ r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
958
+
959
+ tape = wp.Tape()
960
+ with tape:
961
+ wp.launch(kernel, dim=1, inputs=[q], outputs=[r0, r1, r2, r3], device=device)
962
+
963
+ for i, l in enumerate([r0, r1, r2, r3]):
964
+ tape.backward(loss=l)
965
+ allgrads = tape.gradients[q].numpy()[0]
966
+ expected_grads = np.zeros_like(allgrads)
967
+ expected_grads[i] = 2
968
+ assert_np_equal(allgrads, expected_grads, tol=tol)
969
+ tape.zero()
970
+
971
+ assert_np_equal(r0.numpy()[0], 2.0 * q.numpy()[0, 0], tol=tol)
972
+ assert_np_equal(r1.numpy()[0], 2.0 * q.numpy()[0, 1], tol=tol)
973
+ assert_np_equal(r2.numpy()[0], 2.0 * q.numpy()[0, 2], tol=tol)
974
+ assert_np_equal(r3.numpy()[0], 2.0 * q.numpy()[0, 3], tol=tol)
975
+
976
+
977
+ @wp.kernel
978
+ def test_assignment():
979
+ q = wp.quat(1.0, 2.0, 3.0, 4.0)
980
+ q[0] = 1.23
981
+ q[1] = 2.34
982
+ q[2] = 3.45
983
+ q[3] = 4.56
984
+ wp.expect_eq(q[0], 1.23)
985
+ wp.expect_eq(q[1], 2.34)
986
+ wp.expect_eq(q[2], 3.45)
987
+ wp.expect_eq(q[3], 4.56)
988
+
989
+
990
+ def test_quat_lerp(test, device, dtype, register_kernels=False):
991
+ rng = np.random.default_rng(123)
992
+
993
+ tol = {
994
+ np.float16: 1.0e-2,
995
+ np.float32: 1.0e-6,
996
+ np.float64: 1.0e-8,
997
+ }.get(dtype, 0)
998
+
999
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1000
+ quat = wp.types.quaternion(dtype=wptype)
1001
+
1002
+ def check_quat_lerp(
1003
+ s: wp.array(dtype=quat),
1004
+ q: wp.array(dtype=quat),
1005
+ t: wp.array(dtype=wptype),
1006
+ r0: wp.array(dtype=wptype),
1007
+ r1: wp.array(dtype=wptype),
1008
+ r2: wp.array(dtype=wptype),
1009
+ r3: wp.array(dtype=wptype),
1010
+ ):
1011
+ result = wp.lerp(s[0], q[0], t[0])
1012
+
1013
+ # multiply outputs by 2 so we've got something to backpropagate:
1014
+ r0[0] = wptype(2) * result[0]
1015
+ r1[0] = wptype(2) * result[1]
1016
+ r2[0] = wptype(2) * result[2]
1017
+ r3[0] = wptype(2) * result[3]
1018
+
1019
+ kernel = getkernel(check_quat_lerp, suffix=dtype.__name__)
1020
+
1021
+ if register_kernels:
1022
+ return
1023
+
1024
+ s = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
1025
+ q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
1026
+ t = wp.array(rng.uniform(size=1).astype(dtype), dtype=wptype, requires_grad=True, device=device)
1027
+
1028
+ r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1029
+ r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1030
+ r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1031
+ r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1032
+
1033
+ tape = wp.Tape()
1034
+ with tape:
1035
+ wp.launch(
1036
+ kernel,
1037
+ dim=1,
1038
+ inputs=[s, q, t],
1039
+ outputs=[
1040
+ r0,
1041
+ r1,
1042
+ r2,
1043
+ r3,
1044
+ ],
1045
+ device=device,
1046
+ )
1047
+
1048
+ a = s.numpy()
1049
+ b = q.numpy()
1050
+ tt = t.numpy()
1051
+ assert_np_equal(r0.numpy()[0], 2 * ((1 - tt) * a[0, 0] + tt * b[0, 0]), tol=tol)
1052
+ assert_np_equal(r1.numpy()[0], 2 * ((1 - tt) * a[0, 1] + tt * b[0, 1]), tol=tol)
1053
+ assert_np_equal(r2.numpy()[0], 2 * ((1 - tt) * a[0, 2] + tt * b[0, 2]), tol=tol)
1054
+ assert_np_equal(r3.numpy()[0], 2 * ((1 - tt) * a[0, 3] + tt * b[0, 3]), tol=tol)
1055
+
1056
+ for i, l in enumerate([r0, r1, r2, r3]):
1057
+ tape.backward(loss=l)
1058
+ agrad = tape.gradients[s].numpy()[0]
1059
+ bgrad = tape.gradients[q].numpy()[0]
1060
+ tgrad = tape.gradients[t].numpy()[0]
1061
+ expected_grads = np.zeros_like(agrad)
1062
+ expected_grads[i] = 2 * (1 - tt)
1063
+ assert_np_equal(agrad, expected_grads, tol=tol)
1064
+ expected_grads[i] = 2 * tt
1065
+ assert_np_equal(bgrad, expected_grads, tol=tol)
1066
+ assert_np_equal(tgrad, 2 * (b[0, i] - a[0, i]), tol=tol)
1067
+
1068
+ tape.zero()
1069
+
1070
+
1071
+ def test_quat_rotate(test, device, dtype, register_kernels=False):
1072
+ rng = np.random.default_rng(123)
1073
+
1074
+ tol = {
1075
+ np.float16: 1.0e-2,
1076
+ np.float32: 1.0e-6,
1077
+ np.float64: 1.0e-8,
1078
+ }.get(dtype, 0)
1079
+
1080
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1081
+ quat = wp.types.quaternion(dtype=wptype)
1082
+ vec3 = wp.types.vector(length=3, dtype=wptype)
1083
+
1084
+ def check_quat_rotate(
1085
+ q: wp.array(dtype=quat),
1086
+ v: wp.array(dtype=vec3),
1087
+ outputs: wp.array(dtype=wptype),
1088
+ outputs_inv: wp.array(dtype=wptype),
1089
+ outputs_manual: wp.array(dtype=wptype),
1090
+ outputs_inv_manual: wp.array(dtype=wptype),
1091
+ ):
1092
+ result = wp.quat_rotate(q[0], v[0])
1093
+ result_inv = wp.quat_rotate_inv(q[0], v[0])
1094
+
1095
+ qv = vec3(q[0][0], q[0][1], q[0][2])
1096
+ qw = q[0][3]
1097
+
1098
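+ # for a unit quaternion q = (qv, qw), rotating v expands to v * (2 * qw^2 - 1) + 2 * qw * (qv x v) + 2 * qv * (qv . v); the inverse rotation only flips the sign of the cross-product term.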
+ result_manual = v[0] * (wptype(2) * qw * qw - wptype(1))
1099
+ result_manual += wp.cross(qv, v[0]) * qw * wptype(2)
1100
+ result_manual += qv * wp.dot(qv, v[0]) * wptype(2)
1101
+
1102
+ result_inv_manual = v[0] * (wptype(2) * qw * qw - wptype(1))
1103
+ result_inv_manual -= wp.cross(qv, v[0]) * qw * wptype(2)
1104
+ result_inv_manual += qv * wp.dot(qv, v[0]) * wptype(2)
1105
+
1106
+ for i in range(3):
1107
+ # multiply outputs by 2 so we've got something to backpropagate:
1108
+ outputs[i] = wptype(2) * result[i]
1109
+ outputs_inv[i] = wptype(2) * result_inv[i]
1110
+ outputs_manual[i] = wptype(2) * result_manual[i]
1111
+ outputs_inv_manual[i] = wptype(2) * result_inv_manual[i]
1112
+
1113
+ kernel = getkernel(check_quat_rotate, suffix=dtype.__name__)
1114
+ output_select_kernel = get_select_kernel(wptype)
1115
+
1116
+ if register_kernels:
1117
+ return
1118
+
1119
+ q = rng.standard_normal(size=(1, 4))
1120
+ q /= np.linalg.norm(q)
1121
+ q = wp.array(q.astype(dtype), dtype=quat, requires_grad=True, device=device)
1122
+ v = wp.array(0.5 * rng.standard_normal(size=(1, 3)).astype(dtype), dtype=vec3, requires_grad=True, device=device)
1123
+
1124
+ # test values against the manually computed result:
1125
+ outputs = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
1126
+ outputs_inv = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
1127
+ outputs_manual = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
1128
+ outputs_inv_manual = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
1129
+
1130
+ wp.launch(
1131
+ kernel,
1132
+ dim=1,
1133
+ inputs=[q, v],
1134
+ outputs=[
1135
+ outputs,
1136
+ outputs_inv,
1137
+ outputs_manual,
1138
+ outputs_inv_manual,
1139
+ ],
1140
+ device=device,
1141
+ )
1142
+
1143
+ assert_np_equal(outputs.numpy(), outputs_manual.numpy(), tol=tol)
1144
+ assert_np_equal(outputs_inv.numpy(), outputs_inv_manual.numpy(), tol=tol)
1145
+
1146
+ # test gradients against the manually computed result:
1147
+ for i in range(3):
1148
+ cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1149
+ cmp_inv = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1150
+ cmp_manual = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1151
+ cmp_inv_manual = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1152
+ tape = wp.Tape()
1153
+ with tape:
1154
+ wp.launch(
1155
+ kernel,
1156
+ dim=1,
1157
+ inputs=[q, v],
1158
+ outputs=[
1159
+ outputs,
1160
+ outputs_inv,
1161
+ outputs_manual,
1162
+ outputs_inv_manual,
1163
+ ],
1164
+ device=device,
1165
+ )
1166
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, i], outputs=[cmp], device=device)
1167
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs_inv, i], outputs=[cmp_inv], device=device)
1168
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs_manual, i], outputs=[cmp_manual], device=device)
1169
+ wp.launch(
1170
+ output_select_kernel, dim=1, inputs=[outputs_inv_manual, i], outputs=[cmp_inv_manual], device=device
1171
+ )
1172
+
1173
+ tape.backward(loss=cmp)
1174
+ qgrads = 1.0 * tape.gradients[q].numpy()
1175
+ vgrads = 1.0 * tape.gradients[v].numpy()
1176
+ tape.zero()
1177
+ tape.backward(loss=cmp_inv)
1178
+ qgrads_inv = 1.0 * tape.gradients[q].numpy()
1179
+ vgrads_inv = 1.0 * tape.gradients[v].numpy()
1180
+ tape.zero()
1181
+ tape.backward(loss=cmp_manual)
1182
+ qgrads_manual = 1.0 * tape.gradients[q].numpy()
1183
+ vgrads_manual = 1.0 * tape.gradients[v].numpy()
1184
+ tape.zero()
1185
+ tape.backward(loss=cmp_inv_manual)
1186
+ qgrads_inv_manual = 1.0 * tape.gradients[q].numpy()
1187
+ vgrads_inv_manual = 1.0 * tape.gradients[v].numpy()
1188
+ tape.zero()
1189
+
1190
+ assert_np_equal(qgrads, qgrads_manual, tol=tol)
1191
+ assert_np_equal(vgrads, vgrads_manual, tol=tol)
1192
+
1193
+ assert_np_equal(qgrads_inv, qgrads_inv_manual, tol=tol)
1194
+ assert_np_equal(vgrads_inv, vgrads_inv_manual, tol=tol)
1195
+
1196
+
1197
+ def test_quat_to_matrix(test, device, dtype, register_kernels=False):
1198
+ rng = np.random.default_rng(123)
1199
+
1200
+ tol = {
1201
+ np.float16: 1.0e-2,
1202
+ np.float32: 1.0e-6,
1203
+ np.float64: 1.0e-8,
1204
+ }.get(dtype, 0)
1205
+
1206
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1207
+ quat = wp.types.quaternion(dtype=wptype)
1208
+ vec3 = wp.types.vector(length=3, dtype=wptype)
1209
+
1210
+ def check_quat_to_matrix(
1211
+ q: wp.array(dtype=quat),
1212
+ outputs: wp.array(dtype=wptype),
1213
+ outputs_manual: wp.array(dtype=wptype),
1214
+ ):
1215
+ result = wp.quat_to_matrix(q[0])
1216
+
1217
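+ # the columns of the rotation matrix are the rotated basis vectors, so quat_to_matrix(q) should match matrix_from_cols(q * e_x, q * e_y, q * e_z) built below.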
+ xaxis = wp.quat_rotate(
1218
+ q[0],
1219
+ vec3(
1220
+ wptype(1),
1221
+ wptype(0),
1222
+ wptype(0),
1223
+ ),
1224
+ )
1225
+ yaxis = wp.quat_rotate(
1226
+ q[0],
1227
+ vec3(
1228
+ wptype(0),
1229
+ wptype(1),
1230
+ wptype(0),
1231
+ ),
1232
+ )
1233
+ zaxis = wp.quat_rotate(
1234
+ q[0],
1235
+ vec3(
1236
+ wptype(0),
1237
+ wptype(0),
1238
+ wptype(1),
1239
+ ),
1240
+ )
1241
+ result_manual = wp.matrix_from_cols(xaxis, yaxis, zaxis)
1242
+
1243
+ idx = 0
1244
+ for i in range(3):
1245
+ for j in range(3):
1246
+ # multiply outputs by 2 so we've got something to backpropagate:
1247
+ outputs[idx] = wptype(2) * result[i, j]
1248
+ outputs_manual[idx] = wptype(2) * result_manual[i, j]
1249
+
1250
+ idx = idx + 1
1251
+
1252
+ kernel = getkernel(check_quat_to_matrix, suffix=dtype.__name__)
1253
+ output_select_kernel = get_select_kernel(wptype)
1254
+
1255
+ if register_kernels:
1256
+ return
1257
+
1258
+ q = rng.standard_normal(size=(1, 4))
1259
+ q /= np.linalg.norm(q)
1260
+ q = wp.array(q.astype(dtype), dtype=quat, requires_grad=True, device=device)
1261
+
1262
+ # test values against the manually computed result:
1263
+ outputs = wp.zeros(3 * 3, dtype=wptype, requires_grad=True, device=device)
1264
+ outputs_manual = wp.zeros(3 * 3, dtype=wptype, requires_grad=True, device=device)
1265
+
1266
+ wp.launch(
1267
+ kernel,
1268
+ dim=1,
1269
+ inputs=[q],
1270
+ outputs=[
1271
+ outputs,
1272
+ outputs_manual,
1273
+ ],
1274
+ device=device,
1275
+ )
1276
+
1277
+ assert_np_equal(outputs.numpy(), outputs_manual.numpy(), tol=tol)
1278
+
1279
+ # sanity check: divide by 2 to remove that scale factor we put in there, and
1280
+ # it should be a rotation matrix
1281
+ R = 0.5 * outputs.numpy().reshape(3, 3)
1282
+ assert_np_equal(np.matmul(R, R.T), np.eye(3), tol=tol)
1283
+
1284
+ # test gradients against the manually computed result:
1285
+ idx = 0
1286
+ for _i in range(3):
1287
+ for _j in range(3):
1288
+ cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1289
+ cmp_manual = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1290
+ tape = wp.Tape()
1291
+ with tape:
1292
+ wp.launch(
1293
+ kernel,
1294
+ dim=1,
1295
+ inputs=[q],
1296
+ outputs=[
1297
+ outputs,
1298
+ outputs_manual,
1299
+ ],
1300
+ device=device,
1301
+ )
1302
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, idx], outputs=[cmp], device=device)
1303
+ wp.launch(
1304
+ output_select_kernel, dim=1, inputs=[outputs_manual, idx], outputs=[cmp_manual], device=device
1305
+ )
1306
+ tape.backward(loss=cmp)
1307
+ qgrads = 1.0 * tape.gradients[q].numpy()
1308
+ tape.zero()
1309
+ tape.backward(loss=cmp_manual)
1310
+ qgrads_manual = 1.0 * tape.gradients[q].numpy()
1311
+ tape.zero()
1312
+
1313
+ assert_np_equal(qgrads, qgrads_manual, tol=tol)
1314
+ idx = idx + 1
1315
+
1316
+
1317
+ ############################################################
1318
+
1319
+
1320
+ def test_slerp_grad(test, device, dtype, register_kernels=False):
1321
+ rng = np.random.default_rng(123)
1322
+ seed = 42
1323
+
1324
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1325
+ vec3 = wp.types.vector(3, wptype)
1326
+ quat = wp.types.quaternion(wptype)
1327
+
1328
+ def slerp_kernel(
1329
+ q0: wp.array(dtype=quat),
1330
+ q1: wp.array(dtype=quat),
1331
+ t: wp.array(dtype=wptype),
1332
+ loss: wp.array(dtype=wptype),
1333
+ index: int,
1334
+ ):
1335
+ tid = wp.tid()
1336
+
1337
+ q = wp.quat_slerp(q0[tid], q1[tid], t[tid])
1338
+ wp.atomic_add(loss, 0, q[index])
1339
+
1340
+ slerp_kernel = getkernel(slerp_kernel, suffix=dtype.__name__)
1341
+
1342
+ def slerp_kernel_forward(
1343
+ q0: wp.array(dtype=quat),
1344
+ q1: wp.array(dtype=quat),
1345
+ t: wp.array(dtype=wptype),
1346
+ loss: wp.array(dtype=wptype),
1347
+ index: int,
1348
+ ):
1349
+ tid = wp.tid()
1350
+
1351
+ axis = vec3()
1352
+ angle = wptype(0.0)
1353
+
1354
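+ # slerp(q0, q1, t) = q0 * (q0^-1 * q1)^t: recover the axis/angle of the relative rotation, scale the angle by t and reapply it, so autodiff of this expansion can be compared against the builtin adjoint of wp.quat_slerp.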
+ wp.quat_to_axis_angle(wp.mul(wp.quat_inverse(q0[tid]), q1[tid]), axis, angle)
1355
+ q = wp.mul(q0[tid], wp.quat_from_axis_angle(axis, t[tid] * angle))
1356
+
1357
+ wp.atomic_add(loss, 0, q[index])
1358
+
1359
+ slerp_kernel_forward = getkernel(slerp_kernel_forward, suffix=dtype.__name__)
1360
+
1361
+ def quat_sampler_slerp(kernel_seed: int, quats: wp.array(dtype=quat)):
1362
+ tid = wp.tid()
1363
+
1364
+ state = wp.rand_init(kernel_seed, tid)
1365
+
1366
+ angle = wp.randf(state, 0.0, 2.0 * 3.1415926535)
1367
+ dir = wp.sample_unit_sphere_surface(state) * wp.sin(angle * 0.5)
1368
+
1369
+ q = quat(wptype(dir[0]), wptype(dir[1]), wptype(dir[2]), wptype(wp.cos(angle * 0.5)))
1370
+ qn = wp.normalize(q)
1371
+
1372
+ quats[tid] = qn
1373
+
1374
+ quat_sampler = getkernel(quat_sampler_slerp, suffix=dtype.__name__)
1375
+
1376
+ if register_kernels:
1377
+ return
1378
+
1379
+ N = 50
1380
+
1381
+ q0 = wp.zeros(N, dtype=quat, device=device, requires_grad=True)
1382
+ q1 = wp.zeros(N, dtype=quat, device=device, requires_grad=True)
1383
+
1384
+ wp.launch(kernel=quat_sampler, dim=N, inputs=[seed, q0], device=device)
1385
+ wp.launch(kernel=quat_sampler, dim=N, inputs=[seed + 1, q1], device=device)
1386
+
1387
+ t = rng.uniform(low=0.0, high=1.0, size=N)
1388
+ t = wp.array(t, dtype=wptype, device=device, requires_grad=True)
1389
+
1390
+ def compute_gradients(kernel, wrt, index):
1391
+ loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
1392
+ tape = wp.Tape()
1393
+ with tape:
1394
+ wp.launch(kernel=kernel, dim=N, inputs=[q0, q1, t, loss, index], device=device)
1395
+
1396
+ tape.backward(loss)
1397
+
1398
+ gradients = 1.0 * tape.gradients[wrt].numpy()
1399
+ tape.zero()
1400
+
1401
+ return loss.numpy()[0], gradients
1402
+
1403
+ eps = {
1404
+ np.float16: 2.0e-2,
1405
+ np.float32: 1.0e-5,
1406
+ np.float64: 1.0e-8,
1407
+ }.get(dtype, 0)
1408
+
1409
+ # wrt t
1410
+
1411
+ # gather gradients from builtin adjoints
1412
+ xcmp, gradients_x = compute_gradients(slerp_kernel, t, 0)
1413
+ ycmp, gradients_y = compute_gradients(slerp_kernel, t, 1)
1414
+ zcmp, gradients_z = compute_gradients(slerp_kernel, t, 2)
1415
+ wcmp, gradients_w = compute_gradients(slerp_kernel, t, 3)
1416
+
1417
+ # gather gradients from autodiff
1418
+ xcmp_auto, gradients_x_auto = compute_gradients(slerp_kernel_forward, t, 0)
1419
+ ycmp_auto, gradients_y_auto = compute_gradients(slerp_kernel_forward, t, 1)
1420
+ zcmp_auto, gradients_z_auto = compute_gradients(slerp_kernel_forward, t, 2)
1421
+ wcmp_auto, gradients_w_auto = compute_gradients(slerp_kernel_forward, t, 3)
1422
+
1423
+ assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1424
+ assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1425
+ assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1426
+ assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1427
+ assert_np_equal(xcmp, xcmp_auto, tol=eps)
1428
+ assert_np_equal(ycmp, ycmp_auto, tol=eps)
1429
+ assert_np_equal(zcmp, zcmp_auto, tol=eps)
1430
+ assert_np_equal(wcmp, wcmp_auto, tol=eps)
1431
+
1432
+ # wrt q0
1433
+
1434
+ # gather gradients from builtin adjoints
1435
+ xcmp, gradients_x = compute_gradients(slerp_kernel, q0, 0)
1436
+ ycmp, gradients_y = compute_gradients(slerp_kernel, q0, 1)
1437
+ zcmp, gradients_z = compute_gradients(slerp_kernel, q0, 2)
1438
+ wcmp, gradients_w = compute_gradients(slerp_kernel, q0, 3)
1439
+
1440
+ # gather gradients from autodiff
1441
+ xcmp_auto, gradients_x_auto = compute_gradients(slerp_kernel_forward, q0, 0)
1442
+ ycmp_auto, gradients_y_auto = compute_gradients(slerp_kernel_forward, q0, 1)
1443
+ zcmp_auto, gradients_z_auto = compute_gradients(slerp_kernel_forward, q0, 2)
1444
+ wcmp_auto, gradients_w_auto = compute_gradients(slerp_kernel_forward, q0, 3)
1445
+
1446
+ assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1447
+ assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1448
+ assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1449
+ assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1450
+ assert_np_equal(xcmp, xcmp_auto, tol=eps)
1451
+ assert_np_equal(ycmp, ycmp_auto, tol=eps)
1452
+ assert_np_equal(zcmp, zcmp_auto, tol=eps)
1453
+ assert_np_equal(wcmp, wcmp_auto, tol=eps)
1454
+
1455
+ # wrt q1
1456
+
1457
+ # gather gradients from builtin adjoints
1458
+ xcmp, gradients_x = compute_gradients(slerp_kernel, q1, 0)
1459
+ ycmp, gradients_y = compute_gradients(slerp_kernel, q1, 1)
1460
+ zcmp, gradients_z = compute_gradients(slerp_kernel, q1, 2)
1461
+ wcmp, gradients_w = compute_gradients(slerp_kernel, q1, 3)
1462
+
1463
+ # gather gradients from autodiff
1464
+ xcmp_auto, gradients_x_auto = compute_gradients(slerp_kernel_forward, q1, 0)
1465
+ ycmp_auto, gradients_y_auto = compute_gradients(slerp_kernel_forward, q1, 1)
1466
+ zcmp_auto, gradients_z_auto = compute_gradients(slerp_kernel_forward, q1, 2)
1467
+ wcmp_auto, gradients_w_auto = compute_gradients(slerp_kernel_forward, q1, 3)
1468
+
1469
+ assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1470
+ assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1471
+ assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1472
+ assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1473
+ assert_np_equal(xcmp, xcmp_auto, tol=eps)
1474
+ assert_np_equal(ycmp, ycmp_auto, tol=eps)
1475
+ assert_np_equal(zcmp, zcmp_auto, tol=eps)
1476
+ assert_np_equal(wcmp, wcmp_auto, tol=eps)
1477
+
1478
+
1479
+ ############################################################
1480
+
1481
+
1482
+ def test_quat_to_axis_angle_grad(test, device, dtype, register_kernels=False):
1483
+ rng = np.random.default_rng(123)
1484
+ seed = 42
1485
+ num_rand = 50
1486
+
1487
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1488
+ vec3 = wp.types.vector(3, wptype)
1489
+ vec4 = wp.types.vector(4, wptype)
1490
+ quat = wp.types.quaternion(wptype)
1491
+
1492
+ def quat_to_axis_angle_kernel(quats: wp.array(dtype=quat), loss: wp.array(dtype=wptype), coord_idx: int):
1493
+ tid = wp.tid()
1494
+ axis = vec3()
1495
+ angle = wptype(0.0)
1496
+
1497
+ wp.quat_to_axis_angle(quats[tid], axis, angle)
1498
+ a = vec4(axis[0], axis[1], axis[2], angle)
1499
+
1500
+ wp.atomic_add(loss, 0, a[coord_idx])
1501
+
1502
+ quat_to_axis_angle_kernel = getkernel(quat_to_axis_angle_kernel, suffix=dtype.__name__)
1503
+
1504
+ def quat_to_axis_angle_kernel_forward(quats: wp.array(dtype=quat), loss: wp.array(dtype=wptype), coord_idx: int):
1505
+ tid = wp.tid()
1506
+ q = quats[tid]
1507
+ axis = vec3()
1508
+ angle = wptype(0.0)
1509
+
1510
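+ # reference expansion: axis = +/- normalize(v) (sign chosen so the effective w is non-negative) and angle = 2 * atan2(|v|, |w|), which keeps the angle in [0, pi].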
+ v = vec3(q[0], q[1], q[2])
1511
+ if q[3] < wptype(0):
1512
+ axis = -wp.normalize(v)
1513
+ else:
1514
+ axis = wp.normalize(v)
1515
+
1516
+ angle = wptype(2) * wp.atan2(wp.length(v), wp.abs(q[3]))
1517
+ a = vec4(axis[0], axis[1], axis[2], angle)
1518
+
1519
+ wp.atomic_add(loss, 0, a[coord_idx])
1520
+
1521
+ quat_to_axis_angle_kernel_forward = getkernel(quat_to_axis_angle_kernel_forward, suffix=dtype.__name__)
1522
+
1523
+ def quat_sampler(kernel_seed: int, angles: wp.array(dtype=float), quats: wp.array(dtype=quat)):
1524
+ tid = wp.tid()
1525
+
1526
+ state = wp.rand_init(kernel_seed, tid)
1527
+
1528
+ angle = angles[tid]
1529
+ dir = wp.sample_unit_sphere_surface(state) * wp.sin(angle * 0.5)
1530
+
1531
+ q = quat(wptype(dir[0]), wptype(dir[1]), wptype(dir[2]), wptype(wp.cos(angle * 0.5)))
1532
+ qn = wp.normalize(q)
1533
+
1534
+ quats[tid] = qn
1535
+
1536
+ quat_sampler = getkernel(quat_sampler, suffix=dtype.__name__)
1537
+
1538
+ if register_kernels:
1539
+ return
1540
+
1541
+ quats = wp.zeros(num_rand, dtype=quat, device=device, requires_grad=True)
1542
+ angles = wp.array(
1543
+ np.linspace(0.0, 2.0 * np.pi, num_rand, endpoint=False, dtype=np.float32), dtype=float, device=device
1544
+ )
1545
+ wp.launch(kernel=quat_sampler, dim=num_rand, inputs=[seed, angles, quats], device=device)
1546
+
1547
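+ # edge cases that stress the axis normalization in the adjoint, including a 180-degree rotation (w = 0) and the all-zero quaternion.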
+ edge_cases = np.array(
1548
+ [(1.0, 0.0, 0.0, 0.0), (0.0, 1.0 / np.sqrt(3), 1.0 / np.sqrt(3), 1.0 / np.sqrt(3)), (0.0, 0.0, 0.0, 0.0)]
1549
+ )
1550
+ num_edge = len(edge_cases)
1551
+ edge_cases = wp.array(edge_cases, dtype=quat, device=device, requires_grad=True)
1552
+
1553
+ def compute_gradients(arr, kernel, dim, index):
1554
+ loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
1555
+ tape = wp.Tape()
1556
+ with tape:
1557
+ wp.launch(kernel=kernel, dim=dim, inputs=[arr, loss, index], device=device)
1558
+
1559
+ tape.backward(loss)
1560
+
1561
+ gradients = 1.0 * tape.gradients[arr].numpy()
1562
+ tape.zero()
1563
+
1564
+ return loss.numpy()[0], gradients
1565
+
1566
+ # gather gradients from builtin adjoints
1567
+ xcmp, gradients_x = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 0)
1568
+ ycmp, gradients_y = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 1)
1569
+ zcmp, gradients_z = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 2)
1570
+ wcmp, gradients_w = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 3)
1571
+
1572
+ # gather gradients from autodiff
1573
+ xcmp_auto, gradients_x_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 0)
1574
+ ycmp_auto, gradients_y_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 1)
1575
+ zcmp_auto, gradients_z_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 2)
1576
+ wcmp_auto, gradients_w_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 3)
1577
+
1578
+ # edge cases: gather gradients from builtin adjoints
1579
+ _, edge_gradients_x = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 0)
1580
+ _, edge_gradients_y = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 1)
1581
+ _, edge_gradients_z = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 2)
1582
+ _, edge_gradients_w = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 3)
1583
+
1584
+ # edge cases: gather gradients from autodiff
1585
+ _, edge_gradients_x_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 0)
1586
+ _, edge_gradients_y_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 1)
1587
+ _, edge_gradients_z_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 2)
1588
+ _, edge_gradients_w_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 3)
1589
+
1590
+ eps = {
1591
+ np.float16: 2.0e-1,
1592
+ np.float32: 2.0e-4,
1593
+ np.float64: 2.0e-7,
1594
+ }.get(dtype, 0)
1595
+
1596
+ assert_np_equal(xcmp, xcmp_auto, tol=eps)
1597
+ assert_np_equal(ycmp, ycmp_auto, tol=eps)
1598
+ assert_np_equal(zcmp, zcmp_auto, tol=eps)
1599
+ assert_np_equal(wcmp, wcmp_auto, tol=eps)
1600
+
1601
+ assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1602
+ assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1603
+ assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1604
+ assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1605
+
1606
+ assert_np_equal(edge_gradients_x, edge_gradients_x_auto, tol=eps)
1607
+ assert_np_equal(edge_gradients_y, edge_gradients_y_auto, tol=eps)
1608
+ assert_np_equal(edge_gradients_z, edge_gradients_z_auto, tol=eps)
1609
+ assert_np_equal(edge_gradients_w, edge_gradients_w_auto, tol=eps)
1610
+
1611
+
1612
+ ############################################################
1613
+
1614
+
1615
+ def test_quat_rpy_grad(test, device, dtype, register_kernels=False):
1616
+ rng = np.random.default_rng(123)
1617
+ N = 3
1618
+
1619
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1620
+
1621
+ vec3 = wp.types.vector(3, wptype)
1622
+ quat = wp.types.quaternion(wptype)
1623
+
1624
+ def rpy_to_quat_kernel(rpy_arr: wp.array(dtype=vec3), loss: wp.array(dtype=wptype), coord_idx: int):
1625
+ tid = wp.tid()
1626
+ rpy = rpy_arr[tid]
1627
+ roll = rpy[0]
1628
+ pitch = rpy[1]
1629
+ yaw = rpy[2]
1630
+
1631
+ q = wp.quat_rpy(roll, pitch, yaw)
1632
+
1633
+ wp.atomic_add(loss, 0, q[coord_idx])
1634
+
1635
+ rpy_to_quat_kernel = getkernel(rpy_to_quat_kernel, suffix=dtype.__name__)
1636
+
1637
+ def rpy_to_quat_kernel_forward(rpy_arr: wp.array(dtype=vec3), loss: wp.array(dtype=wptype), coord_idx: int):
1638
+ tid = wp.tid()
1639
+ rpy = rpy_arr[tid]
1640
+ roll = rpy[0]
1641
+ pitch = rpy[1]
1642
+ yaw = rpy[2]
1643
+
1644
+ cy = wp.cos(yaw * wptype(0.5))
1645
+ sy = wp.sin(yaw * wptype(0.5))
1646
+ cr = wp.cos(roll * wptype(0.5))
1647
+ sr = wp.sin(roll * wptype(0.5))
1648
+ cp = wp.cos(pitch * wptype(0.5))
1649
+ sp = wp.sin(pitch * wptype(0.5))
1650
+
1651
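+ # half-angle expansion of the roll/pitch/yaw quaternion (roll about x, pitch about y, yaw about z), matching the convention of wp.quat_rpy.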
+ w = cy * cr * cp + sy * sr * sp
1652
+ x = cy * sr * cp - sy * cr * sp
1653
+ y = cy * cr * sp + sy * sr * cp
1654
+ z = sy * cr * cp - cy * sr * sp
1655
+
1656
+ q = quat(x, y, z, w)
1657
+
1658
+ wp.atomic_add(loss, 0, q[coord_idx])
1659
+
1660
+ rpy_to_quat_kernel_forward = getkernel(rpy_to_quat_kernel_forward, suffix=dtype.__name__)
1661
+
1662
+ if register_kernels:
1663
+ return
1664
+
1665
+ rpy_arr = rng.uniform(low=-np.pi, high=np.pi, size=(N, 3))
1666
+ rpy_arr = wp.array(rpy_arr, dtype=vec3, device=device, requires_grad=True)
1667
+
1668
+ def compute_gradients(kernel, wrt, index):
1669
+ loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
1670
+ tape = wp.Tape()
1671
+ with tape:
1672
+ wp.launch(kernel=kernel, dim=N, inputs=[wrt, loss, index], device=device)
1673
+
1674
+ tape.backward(loss)
1675
+
1676
+ gradients = 1.0 * tape.gradients[wrt].numpy()
1677
+ tape.zero()
1678
+
1679
+ return loss.numpy()[0], gradients
1680
+
1681
+ # wrt rpy
1682
+ # gather gradients from builtin adjoints
1683
+ rcmp, gradients_r = compute_gradients(rpy_to_quat_kernel, rpy_arr, 0)
1684
+ pcmp, gradients_p = compute_gradients(rpy_to_quat_kernel, rpy_arr, 1)
1685
+ ycmp, gradients_y = compute_gradients(rpy_to_quat_kernel, rpy_arr, 2)
1686
+
1687
+ # gather gradients from autodiff
1688
+ rcmp_auto, gradients_r_auto = compute_gradients(rpy_to_quat_kernel_forward, rpy_arr, 0)
1689
+ pcmp_auto, gradients_p_auto = compute_gradients(rpy_to_quat_kernel_forward, rpy_arr, 1)
1690
+ ycmp_auto, gradients_y_auto = compute_gradients(rpy_to_quat_kernel_forward, rpy_arr, 2)
1691
+
1692
+ eps = {
1693
+ np.float16: 2.0e-2,
1694
+ np.float32: 1.0e-5,
1695
+ np.float64: 1.0e-8,
1696
+ }.get(dtype, 0)
1697
+
1698
+ assert_np_equal(rcmp, rcmp_auto, tol=eps)
1699
+ assert_np_equal(pcmp, pcmp_auto, tol=eps)
1700
+ assert_np_equal(ycmp, ycmp_auto, tol=eps)
1701
+
1702
+ assert_np_equal(gradients_r, gradients_r_auto, tol=eps)
1703
+ assert_np_equal(gradients_p, gradients_p_auto, tol=eps)
1704
+ assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1705
+
1706
+
1707
+ ############################################################
1708
+
1709
+
1710
+ def test_quat_from_matrix(test, device, dtype, register_kernels=False):
1711
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1712
+ mat33 = wp.types.matrix((3, 3), wptype)
1713
+ mat44 = wp.types.matrix((4, 4), wptype)
1714
+ quat = wp.types.quaternion(wptype)
1715
+
1716
+ def quat_from_matrix(m: wp.array2d(dtype=wptype), loss: wp.array(dtype=wptype), idx: int):
1717
+ tid = wp.tid()
1718
+
1719
+ # fmt: off
1720
+ m3 = mat33(
1721
+ m[tid, 0], m[tid, 1], m[tid, 2],
1722
+ m[tid, 3], m[tid, 4], m[tid, 5],
1723
+ m[tid, 6], m[tid, 7], m[tid, 8],
1724
+ )
1725
+ q1 = wp.quat_from_matrix(m3)
1726
+
1727
+ m4 = mat44(
1728
+ m[tid, 0], m[tid, 1], m[tid, 2], wptype(0.0),
1729
+ m[tid, 3], m[tid, 4], m[tid, 5], wptype(0.0),
1730
+ m[tid, 6], m[tid, 7], m[tid, 8], wptype(0.0),
1731
+ wptype(0.0), wptype(0.0), wptype(0.0), wptype(1.0),
1732
+ )
1733
+ q2 = wp.quat_from_matrix(m4)
1734
+ # fmt: on
1735
+
1736
+ wp.expect_eq(q1, q2)
1737
+ wp.atomic_add(loss, 0, q1[idx])
1738
+
1739
+ def quat_from_matrix_forward(mats: wp.array2d(dtype=wptype), loss: wp.array(dtype=wptype), idx: int):
1740
+ tid = wp.tid()
1741
+
1742
+ m = mat33(
1743
+ mats[tid, 0],
1744
+ mats[tid, 1],
1745
+ mats[tid, 2],
1746
+ mats[tid, 3],
1747
+ mats[tid, 4],
1748
+ mats[tid, 5],
1749
+ mats[tid, 6],
1750
+ mats[tid, 7],
1751
+ mats[tid, 8],
1752
+ )
1753
+
1754
+ tr = m[0][0] + m[1][1] + m[2][2]
1755
+ x = wptype(0)
1756
+ y = wptype(0)
1757
+ z = wptype(0)
1758
+ w = wptype(0)
1759
+ h = wptype(0)
1760
+
1761
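+ # Shepperd-style matrix-to-quaternion conversion: branch on the trace or the largest diagonal entry so the sqrt argument stays well away from zero, then normalize.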
+ if tr >= wptype(0):
1762
+ h = wp.sqrt(tr + wptype(1))
1763
+ w = wptype(0.5) * h
1764
+ h = wptype(0.5) / h
1765
+
1766
+ x = (m[2][1] - m[1][2]) * h
1767
+ y = (m[0][2] - m[2][0]) * h
1768
+ z = (m[1][0] - m[0][1]) * h
1769
+ else:
1770
+ max_diag = 0
1771
+ if m[1][1] > m[0][0]:
1772
+ max_diag = 1
1773
+ if m[2][2] > m[max_diag][max_diag]:
1774
+ max_diag = 2
1775
+
1776
+ if max_diag == 0:
1777
+ h = wp.sqrt((m[0][0] - (m[1][1] + m[2][2])) + wptype(1))
1778
+ x = wptype(0.5) * h
1779
+ h = wptype(0.5) / h
1780
+
1781
+ y = (m[0][1] + m[1][0]) * h
1782
+ z = (m[2][0] + m[0][2]) * h
1783
+ w = (m[2][1] - m[1][2]) * h
1784
+ elif max_diag == 1:
1785
+ h = wp.sqrt((m[1][1] - (m[2][2] + m[0][0])) + wptype(1))
1786
+ y = wptype(0.5) * h
1787
+ h = wptype(0.5) / h
1788
+
1789
+ z = (m[1][2] + m[2][1]) * h
1790
+ x = (m[0][1] + m[1][0]) * h
1791
+ w = (m[0][2] - m[2][0]) * h
1792
+ elif max_diag == 2:
1793
+ h = wp.sqrt((m[2][2] - (m[0][0] + m[1][1])) + wptype(1))
1794
+ z = wptype(0.5) * h
1795
+ h = wptype(0.5) / h
1796
+
1797
+ x = (m[2][0] + m[0][2]) * h
1798
+ y = (m[1][2] + m[2][1]) * h
1799
+ w = (m[1][0] - m[0][1]) * h
1800
+
1801
+ q = wp.normalize(quat(x, y, z, w))
1802
+
1803
+ wp.atomic_add(loss, 0, q[idx])
1804
+
1805
+ quat_from_matrix = getkernel(quat_from_matrix, suffix=dtype.__name__)
1806
+ quat_from_matrix_forward = getkernel(quat_from_matrix_forward, suffix=dtype.__name__)
1807
+
1808
+ if register_kernels:
1809
+ return
1810
+
1811
+ m = np.array(
1812
+ [
1813
+ [1.0, 0.0, 0.0, 0.0, 0.5, 0.866, 0.0, -0.866, 0.5],
1814
+ [0.866, 0.0, 0.25, -0.433, 0.5, 0.75, -0.25, -0.866, 0.433],
1815
+ [0.866, -0.433, 0.25, 0.0, 0.5, 0.866, -0.5, -0.75, 0.433],
1816
+ [-1.2, -1.6, -2.3, 0.25, -0.6, -0.33, 3.2, -1.0, -2.2],
1817
+ ]
1818
+ )
1819
+ m = wp.array2d(m, dtype=wptype, device=device, requires_grad=True)
1820
+
1821
+ N = m.shape[0]
1822
+
1823
+ def compute_gradients(kernel, wrt, index):
1824
+ loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
1825
+ tape = wp.Tape()
1826
+
1827
+ with tape:
1828
+ wp.launch(kernel=kernel, dim=N, inputs=[m, loss, index], device=device)
1829
+
1830
+ tape.backward(loss)
1831
+
1832
+ gradients = 1.0 * tape.gradients[wrt].numpy()
1833
+ tape.zero()
1834
+
1835
+ return loss.numpy()[0], gradients
1836
+
1837
+ # gather gradients from builtin adjoints
1838
+ cmpx, gradients_x = compute_gradients(quat_from_matrix, m, 0)
1839
+ cmpy, gradients_y = compute_gradients(quat_from_matrix, m, 1)
1840
+ cmpz, gradients_z = compute_gradients(quat_from_matrix, m, 2)
1841
+ cmpw, gradients_w = compute_gradients(quat_from_matrix, m, 3)
1842
+
1843
+ # gather gradients from autodiff
1844
+ cmpx_auto, gradients_x_auto = compute_gradients(quat_from_matrix_forward, m, 0)
1845
+ cmpy_auto, gradients_y_auto = compute_gradients(quat_from_matrix_forward, m, 1)
1846
+ cmpz_auto, gradients_z_auto = compute_gradients(quat_from_matrix_forward, m, 2)
1847
+ cmpw_auto, gradients_w_auto = compute_gradients(quat_from_matrix_forward, m, 3)
1848
+
1849
+ # compare
1850
+
1852
+ eps = {
1853
+ np.float16: 2.0e-2,
1854
+ np.float32: 1.0e-5,
1855
+ np.float64: 1.0e-8,
1856
+ }.get(dtype, 0)
1857
+
1858
+ assert_np_equal(cmpx, cmpx_auto, tol=eps)
1859
+ assert_np_equal(cmpy, cmpy_auto, tol=eps)
1860
+ assert_np_equal(cmpz, cmpz_auto, tol=eps)
1861
+ assert_np_equal(cmpw, cmpw_auto, tol=eps)
1862
+
1863
+ assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1864
+ assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1865
+ assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1866
+ assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1867
+
1868
+
1869
+ def test_quat_identity(test, device, dtype, register_kernels=False):
1870
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1871
+
1872
+ def quat_identity_test(output: wp.array(dtype=wptype)):
1873
+ q = wp.quat_identity(dtype=wptype)
1874
+ output[0] = q[0]
1875
+ output[1] = q[1]
1876
+ output[2] = q[2]
1877
+ output[3] = q[3]
1878
+
1879
+ def quat_identity_test_default(output: wp.array(dtype=wp.float32)):
1880
+ q = wp.quat_identity()
1881
+ output[0] = q[0]
1882
+ output[1] = q[1]
1883
+ output[2] = q[2]
1884
+ output[3] = q[3]
1885
+
1886
+ quat_identity_kernel = getkernel(quat_identity_test, suffix=dtype.__name__)
1887
+ quat_identity_default_kernel = getkernel(quat_identity_test_default, suffix=np.float32.__name__)
1888
+
1889
+ if register_kernels:
1890
+ return
1891
+
1892
+ output = wp.zeros(4, dtype=wptype, device=device)
1893
+ wp.launch(quat_identity_kernel, dim=1, inputs=[], outputs=[output], device=device)
1894
+ expected = np.zeros_like(output.numpy())
1895
+ expected[3] = 1
1896
+ assert_np_equal(output.numpy(), expected)
1897
+
1898
+ # let's just test that it defaults to float32:
1899
+ output = wp.zeros(4, dtype=wp.float32, device=device)
1900
+ wp.launch(quat_identity_default_kernel, dim=1, inputs=[], outputs=[output], device=device)
1901
+ expected = np.zeros_like(output.numpy())
1902
+ expected[3] = 1
1903
+ assert_np_equal(output.numpy(), expected)
1904
+
1905
+
1906
+ ############################################################
1907
+
1908
+
1909
+ def test_quat_assign_inplace(test, device, dtype, register_kernels=False):
1910
+ np_type = np.dtype(dtype)
1911
+ wp_type = wp.types.np_dtype_to_warp_type[np_type]
1912
+
1913
+ quat = wp.types.quaternion(dtype=wp_type)
1914
+
1915
+ def quattest_read_write_store(x: wp.array(dtype=wp_type), a: wp.array(dtype=quat)):
1916
+ tid = wp.tid()
1917
+
1918
+ t = a[tid]
1919
+ t[0] = x[tid]
1920
+ a[tid] = t
1921
+
1922
+ def quattest_in_register(x: wp.array(dtype=wp_type), a: wp.array(dtype=quat)):
1923
+ tid = wp.tid()
1924
+
1925
+ g = wp_type(0.0)
1926
+ q = a[tid]
1927
+ g = q[0] + wp_type(2.0) * q[1] + wp_type(3.0) * q[2] + wp_type(4.0) * q[3]
1928
+ x[tid] = g
1929
+
1930
+ def quattest_component(x: wp.array(dtype=quat), y: wp.array(dtype=wp_type)):
1931
+ i = wp.tid()
1932
+
1933
+ a = quat()
1934
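+ # components can also be written through the .x/.y/.z/.w attributes; the gradient w.r.t. y accumulates to 1 + 2 + 3 + 4 = 10 below.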
+ a.x = wp_type(1.0) * y[i]
1935
+ a.y = wp_type(2.0) * y[i]
1936
+ a.z = wp_type(3.0) * y[i]
1937
+ a.w = wp_type(4.0) * y[i]
1938
+ x[i] = a
1939
+
1940
+ kernel_read_write_store = getkernel(quattest_read_write_store, suffix=dtype.__name__)
1941
+ kernel_in_register = getkernel(quattest_in_register, suffix=dtype.__name__)
1942
+ kernel_component = getkernel(quattest_component, suffix=dtype.__name__)
1943
+
1944
+ if register_kernels:
1945
+ return
1946
+
1947
+ a = wp.ones(1, dtype=quat, device=device, requires_grad=True)
1948
+ x = wp.full(1, value=2.0, dtype=wp_type, device=device, requires_grad=True)
1949
+
1950
+ tape = wp.Tape()
1951
+ with tape:
1952
+ wp.launch(kernel_read_write_store, dim=1, inputs=[x, a], device=device)
1953
+
1954
+ tape.backward(grads={a: wp.ones_like(a, requires_grad=False)})
1955
+
1956
+ assert_np_equal(a.numpy(), np.array([[2.0, 1.0, 1.0, 1.0]], dtype=np_type))
1957
+ assert_np_equal(x.grad.numpy(), np.array([1.0], dtype=np_type))
1958
+
1959
+ tape.reset()
1960
+
1961
+ a = wp.ones(1, dtype=quat, device=device, requires_grad=True)
1962
+ x = wp.zeros(1, dtype=wp_type, device=device, requires_grad=True)
1963
+
1964
+ with tape:
1965
+ wp.launch(kernel_in_register, dim=1, inputs=[x, a], device=device)
1966
+
1967
+ tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
1968
+
1969
+ assert_np_equal(x.numpy(), np.array([10.0], dtype=np_type))
1970
+ assert_np_equal(a.grad.numpy(), np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np_type))
1971
+
1972
+ tape.reset()
1973
+
1974
+ x = wp.zeros(1, dtype=quat, requires_grad=True, device=device)
1975
+ y = wp.ones(1, dtype=wp_type, requires_grad=True, device=device)
1976
+
1977
+ with tape:
1978
+ wp.launch(kernel_component, dim=1, inputs=[x, y], device=device)
1979
+
1980
+ tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
1981
+
1982
+ assert_np_equal(x.numpy(), np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np_type))
1983
+ assert_np_equal(y.grad.numpy(), np.array([10.0], dtype=np_type))
1984
+
1985
+
1986
+ ############################################################
1987
+
1988
+
1989
+ def test_quat_euler_conversion(test, device, dtype, register_kernels=False):
1990
+ rng = np.random.default_rng(123)
1991
+ N = 3
1992
+
1993
+ rpy_arr = rng.uniform(low=-np.pi, high=np.pi, size=(N, 3))
1994
+
1995
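+ # quat_from_euler with axis order (0, 1, 2) should agree with quat_rpy for the same roll/pitch/yaw angles.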
+ quats_from_euler = [list(wp.sim.quat_from_euler(wp.vec3(*rpy), 0, 1, 2)) for rpy in rpy_arr]
1996
+ quats_from_rpy = [list(wp.quat_rpy(rpy[0], rpy[1], rpy[2])) for rpy in rpy_arr]
1997
+
1998
+ assert_np_equal(np.array(quats_from_euler), np.array(quats_from_rpy), tol=1e-4)
1999
+
2000
+
2001
+ def test_anon_type_instance(test, device, dtype, register_kernels=False):
2002
+ rng = np.random.default_rng(123)
2003
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2004
+
2005
+ def quat_create_test(input: wp.array(dtype=wptype), output: wp.array(dtype=wptype)):
2006
+ # component constructor:
2007
+ q = wp.quaternion(input[0], input[1], input[2], input[3])
2008
+ output[0] = wptype(2) * q[0]
2009
+ output[1] = wptype(2) * q[1]
2010
+ output[2] = wptype(2) * q[2]
2011
+ output[3] = wptype(2) * q[3]
2012
+
2013
+ # vector / scalar constructor:
2014
+ q2 = wp.quaternion(wp.vector(input[4], input[5], input[6]), input[7])
2015
+ output[4] = wptype(2) * q2[0]
2016
+ output[5] = wptype(2) * q2[1]
2017
+ output[6] = wptype(2) * q2[2]
2018
+ output[7] = wptype(2) * q2[3]
2019
+
2020
+ quat_create_kernel = getkernel(quat_create_test, suffix=dtype.__name__)
2021
+ output_select_kernel = get_select_kernel(wptype)
2022
+
2023
+ if register_kernels:
2024
+ return
2025
+
2026
+ input = wp.array(rng.standard_normal(size=8).astype(dtype), requires_grad=True, device=device)
2027
+ output = wp.zeros(8, dtype=wptype, requires_grad=True, device=device)
2028
+ wp.launch(quat_create_kernel, dim=1, inputs=[input], outputs=[output], device=device)
2029
+ assert_np_equal(output.numpy(), 2 * input.numpy())
2030
+
2031
+ for i in range(len(input)):
2032
+ cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2033
+ tape = wp.Tape()
2034
+ with tape:
2035
+ wp.launch(quat_create_kernel, dim=1, inputs=[input], outputs=[output], device=device)
2036
+ wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[cmp], device=device)
2037
+ tape.backward(loss=cmp)
2038
+ expectedgrads = np.zeros(len(input))
2039
+ expectedgrads[i] = 2
2040
+ assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
2041
+ tape.zero()
2042
+
2043
+
2044
+ # Same as above but with a default (float) type
2045
+ # which tests some different code paths that
2046
+ # need to ensure types are correctly canonicalized
2047
+ # during codegen
2048
+ @wp.kernel
2049
+ def test_constructor_default():
2050
+ qzero = wp.quat()
2051
+ wp.expect_eq(qzero[0], 0.0)
2052
+ wp.expect_eq(qzero[1], 0.0)
2053
+ wp.expect_eq(qzero[2], 0.0)
2054
+ wp.expect_eq(qzero[3], 0.0)
2055
+
2056
+ qval = wp.quat(1.0, 2.0, 3.0, 4.0)
2057
+ wp.expect_eq(qval[0], 1.0)
2058
+ wp.expect_eq(qval[1], 2.0)
2059
+ wp.expect_eq(qval[2], 3.0)
2060
+ wp.expect_eq(qval[3], 4.0)
2061
+
2062
+ qeye = wp.quat_identity()
2063
+ wp.expect_eq(qeye[0], 0.0)
2064
+ wp.expect_eq(qeye[1], 0.0)
2065
+ wp.expect_eq(qeye[2], 0.0)
2066
+ wp.expect_eq(qeye[3], 1.0)
2067
+
2068
+
2069
+ def test_py_arithmetic_ops(test, device, dtype):
2070
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2071
+
2072
+ def make_quat(*args):
2073
+ if wptype in wp.types.int_types:
2074
+ # Cast to the correct integer type to simulate wrapping.
2075
+ return tuple(wptype._type_(x).value for x in args)
2076
+
2077
+ return args
2078
+
2079
+ quat_cls = wp.types.quaternion(wptype)
2080
+
2081
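+ # python-side quaternion objects support unary +/-, elementwise addition/subtraction and scalar multiplication/division, behaving like 4-vectors.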
+ v = quat_cls(1, -2, 3, -4)
2082
+ test.assertSequenceEqual(+v, make_quat(1, -2, 3, -4))
2083
+ test.assertSequenceEqual(-v, make_quat(-1, 2, -3, 4))
2084
+ test.assertSequenceEqual(v + quat_cls(5, 5, 5, 5), make_quat(6, 3, 8, 1))
2085
+ test.assertSequenceEqual(v - quat_cls(5, 5, 5, 5), make_quat(-4, -7, -2, -9))
2086
+
2087
+ v = quat_cls(2, 4, 6, 8)
2088
+ test.assertSequenceEqual(v * wptype(2), make_quat(4, 8, 12, 16))
2089
+ test.assertSequenceEqual(wptype(2) * v, make_quat(4, 8, 12, 16))
2090
+ test.assertSequenceEqual(v / wptype(2), make_quat(1, 2, 3, 4))
2091
+ test.assertSequenceEqual(wptype(24) / v, make_quat(12, 6, 4, 3))
2092
+
2093
+
2094
+ @wp.kernel
2095
+ def quat_len_kernel(
2096
+ q: wp.quat,
2097
+ out: wp.array(dtype=int),
2098
+ ):
2099
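+ # len() of a quaternion is 4, whether evaluated at codegen time via wp.static() or at runtime inside the kernel.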
+ length = wp.static(len(q))
2100
+ wp.expect_eq(wp.static(len(q)), 4)
2101
+ out[0] = wp.static(len(q))
2102
+
2103
+ foo = wp.quat()
2104
+ length = len(foo)
2105
+ wp.expect_eq(len(foo), 4)
2106
+ out[1] = len(foo)
2107
+
2108
+
2109
+ def test_quat_len(test, device):
2110
+ q = wp.quat()
2111
+ out = wp.empty(2, dtype=int, device=device)
2112
+ wp.launch(quat_len_kernel, dim=(1,), inputs=(q,), outputs=(out,), device=device)
2113
+
2114
+ test.assertEqual(out.numpy()[0], 4)
2115
+ test.assertEqual(out.numpy()[1], 4)
2116
+
2117
+
2118
+ @wp.kernel
2119
+ def quat_augassign_kernel(
2120
+ a: wp.array(dtype=wp.quat), b: wp.array(dtype=wp.quat), c: wp.array(dtype=wp.quat), d: wp.array(dtype=wp.quat)
2121
+ ):
2122
+ i = wp.tid()
2123
+
2124
+ q1 = wp.quat()
2125
+ q2 = b[i]
2126
+
2127
+ q1[0] += q2[0]
2128
+ q1[1] += q2[1]
2129
+ q1[2] += q2[2]
2130
+ q1[3] += q2[3]
2131
+
2132
+ a[i] = q1
2133
+
2134
+ q3 = wp.quat()
2135
+ q4 = d[i]
2136
+
2137
+ q3[0] -= q4[0]
2138
+ q3[1] -= q4[1]
2139
+ q3[2] -= q4[2]
2140
+ q3[3] -= q4[3]
2141
+
2142
+ c[i] = q3
2143
+
2144
+
2145
+ def test_quat_augassign(test, device):
2146
+ N = 3
2147
+
2148
+ a = wp.zeros(N, dtype=wp.quat, requires_grad=True, device=device)
2149
+ b = wp.ones(N, dtype=wp.quat, requires_grad=True, device=device)
2150
+
2151
+ c = wp.zeros(N, dtype=wp.quat, requires_grad=True, device=device)
2152
+ d = wp.ones(N, dtype=wp.quat, requires_grad=True, device=device)
2153
+
2154
+ tape = wp.Tape()
2155
+ with tape:
2156
+ wp.launch(quat_augassign_kernel, N, inputs=[a, b, c, d], device=device)
2157
+
2158
+ tape.backward(grads={a: wp.ones_like(a), c: wp.ones_like(c)})
2159
+
2160
+ assert_np_equal(a.numpy(), wp.ones_like(a).numpy())
2161
+ assert_np_equal(a.grad.numpy(), wp.ones_like(a).numpy())
2162
+ assert_np_equal(b.grad.numpy(), wp.ones_like(a).numpy())
2163
+
2164
+ assert_np_equal(c.numpy(), -wp.ones_like(c).numpy())
2165
+ assert_np_equal(c.grad.numpy(), wp.ones_like(c).numpy())
2166
+ assert_np_equal(d.grad.numpy(), -wp.ones_like(d).numpy())
2167
+
2168
+
2169
+ def test_quat_assign_copy(test, device):
2170
+ saved_enable_vector_component_overwrites_setting = wp.config.enable_vector_component_overwrites
2171
+ try:
2172
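+ # wp.config.enable_vector_component_overwrites is temporarily enabled for this test (it changes how component writes like f[1] = 3.0 are differentiated); the original setting is restored in the finally block.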
+ wp.config.enable_vector_component_overwrites = True
2173
+
2174
+ @wp.kernel
2175
+ def quat_in_register_overwrite(x: wp.array(dtype=wp.quat), a: wp.array(dtype=wp.quat)):
2176
+ tid = wp.tid()
2177
+
2178
+ f = wp.quat()
2179
+ a_quat = a[tid]
2180
+ f = a_quat
2181
+ f[1] = 3.0
2182
+
2183
+ x[tid] = f
2184
+
2185
+ x = wp.zeros(1, dtype=wp.quat, device=device, requires_grad=True)
2186
+ a = wp.ones(1, dtype=wp.quat, device=device, requires_grad=True)
2187
+
2188
+ tape = wp.Tape()
2189
+ with tape:
2190
+ wp.launch(quat_in_register_overwrite, dim=1, inputs=[x, a], device=device)
2191
+
2192
+ tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
2193
+
2194
+ assert_np_equal(x.numpy(), np.array([[1.0, 3.0, 1.0, 1.0]], dtype=float))
2195
+ assert_np_equal(a.grad.numpy(), np.array([[1.0, 0.0, 1.0, 1.0]], dtype=float))
2196
+
2197
+ finally:
2198
+ wp.config.enable_vector_component_overwrites = saved_enable_vector_component_overwrites_setting
2199
+
2200
+
2201
+ devices = get_test_devices()
2202
+
2203
+
2204
+ class TestQuat(unittest.TestCase):
2205
+ pass
2206
+
2207
+
2208
+ add_kernel_test(TestQuat, test_constructor_default, dim=1, devices=devices)
2209
+ add_kernel_test(TestQuat, test_assignment, dim=1, devices=devices)
2210
+
2211
+ for dtype in np_float_types:
2212
+ add_function_test_register_kernel(
2213
+ TestQuat, f"test_constructors_{dtype.__name__}", test_constructors, devices=devices, dtype=dtype
2214
+ )
2215
+ add_function_test_register_kernel(
2216
+ TestQuat,
2217
+ f"test_casting_constructors_{dtype.__name__}",
2218
+ test_casting_constructors,
2219
+ devices=devices,
2220
+ dtype=dtype,
2221
+ )
2222
+ add_function_test_register_kernel(
2223
+ TestQuat, f"test_anon_type_instance_{dtype.__name__}", test_anon_type_instance, devices=devices, dtype=dtype
2224
+ )
2225
+ add_function_test_register_kernel(
2226
+ TestQuat, f"test_inverse_{dtype.__name__}", test_inverse, devices=devices, dtype=dtype
2227
+ )
2228
+ add_function_test_register_kernel(
2229
+ TestQuat, f"test_quat_identity_{dtype.__name__}", test_quat_identity, devices=devices, dtype=dtype
2230
+ )
2231
+ add_function_test_register_kernel(
2232
+ TestQuat, f"test_dotproduct_{dtype.__name__}", test_dotproduct, devices=devices, dtype=dtype
2233
+ )
2234
+ add_function_test_register_kernel(
2235
+ TestQuat, f"test_length_{dtype.__name__}", test_length, devices=devices, dtype=dtype
2236
+ )
2237
+ add_function_test_register_kernel(
2238
+ TestQuat, f"test_normalize_{dtype.__name__}", test_normalize, devices=devices, dtype=dtype
2239
+ )
2240
+ add_function_test_register_kernel(
2241
+ TestQuat, f"test_addition_{dtype.__name__}", test_addition, devices=devices, dtype=dtype
2242
+ )
2243
+ add_function_test_register_kernel(
2244
+ TestQuat, f"test_subtraction_{dtype.__name__}", test_subtraction, devices=devices, dtype=dtype
2245
+ )
2246
+ add_function_test_register_kernel(
2247
+ TestQuat,
2248
+ f"test_scalar_multiplication_{dtype.__name__}",
2249
+ test_scalar_multiplication,
2250
+ devices=devices,
2251
+ dtype=dtype,
2252
+ )
2253
+ add_function_test_register_kernel(
2254
+ TestQuat, f"test_scalar_division_{dtype.__name__}", test_scalar_division, devices=devices, dtype=dtype
2255
+ )
2256
+ add_function_test_register_kernel(
2257
+ TestQuat,
2258
+ f"test_quat_multiplication_{dtype.__name__}",
2259
+ test_quat_multiplication,
2260
+ devices=devices,
2261
+ dtype=dtype,
2262
+ )
2263
+ add_function_test_register_kernel(
2264
+ TestQuat, f"test_indexing_{dtype.__name__}", test_indexing, devices=devices, dtype=dtype
2265
+ )
2266
+ add_function_test_register_kernel(
2267
+ TestQuat, f"test_quat_lerp_{dtype.__name__}", test_quat_lerp, devices=devices, dtype=dtype
2268
+ )
2269
+ add_function_test_register_kernel(
2270
+ TestQuat,
2271
+ f"test_quat_to_axis_angle_grad_{dtype.__name__}",
2272
+ test_quat_to_axis_angle_grad,
2273
+ devices=devices,
2274
+ dtype=dtype,
2275
+ )
2276
+ add_function_test_register_kernel(
2277
+ TestQuat, f"test_slerp_grad_{dtype.__name__}", test_slerp_grad, devices=devices, dtype=dtype
2278
+ )
2279
+ add_function_test_register_kernel(
2280
+ TestQuat, f"test_quat_rpy_grad_{dtype.__name__}", test_quat_rpy_grad, devices=devices, dtype=dtype
2281
+ )
2282
+ add_function_test_register_kernel(
2283
+ TestQuat, f"test_quat_from_matrix_{dtype.__name__}", test_quat_from_matrix, devices=devices, dtype=dtype
2284
+ )
2285
+ add_function_test_register_kernel(
2286
+ TestQuat, f"test_quat_rotate_{dtype.__name__}", test_quat_rotate, devices=devices, dtype=dtype
2287
+ )
2288
+ add_function_test_register_kernel(
2289
+ TestQuat, f"test_quat_to_matrix_{dtype.__name__}", test_quat_to_matrix, devices=devices, dtype=dtype
2290
+ )
2291
+ add_function_test_register_kernel(
2292
+ TestQuat,
2293
+ f"test_quat_euler_conversion_{dtype.__name__}",
2294
+ test_quat_euler_conversion,
2295
+ devices=devices,
2296
+ dtype=dtype,
2297
+ )
2298
+ add_function_test_register_kernel(
2299
+ TestQuat,
2300
+ f"test_quat_assign_inplace_{dtype.__name__}",
2301
+ test_quat_assign_inplace,
2302
+ devices=devices,
2303
+ dtype=dtype,
2304
+ )
2305
+ add_function_test(
2306
+ TestQuat, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype
2307
+ )
2308
+
2309
+ add_function_test(TestQuat, "test_quat_len", test_quat_len, devices=devices)
2310
+ add_function_test(TestQuat, "test_quat_augassign", test_quat_augassign, devices=devices)
2311
+ add_function_test(TestQuat, "test_quat_assign_copy", test_quat_assign_copy, devices=devices)
2312
+
2313
+ if __name__ == "__main__":
2314
+ wp.clear_kernel_cache()
2315
+ unittest.main(verbosity=2)