warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2653 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
# NumPy floating-point dtypes (single, double and half precision).
# NOTE(review): presumably the dtypes the tests in this file are run for — confirm against the test registration code.
np_float_types = [np.float32, np.float64, np.float16]

# Module-level cache of wp.Kernel objects keyed by "<function name>_<suffix>";
# it is read and populated by getkernel() so each kernel is only built once.
kernel_cache = {}
26
+
27
+
28
@wp.func
def quat_from_euler(e: wp.vec3, i: int, j: int, k: int) -> wp.quat:
    """
    Build a quaternion from a triple of Euler angles.

    The indices :math:`i, j, k` (each in :math:`[0, 1, 2]`) select which component of ``e``
    is used for the first, second and third axis of the sequence
    (:math:`i \\neq j, j \\neq k`), e.g. (0, 1, 2) for Euler sequence XYZ.

    Args:
        e (vec3): Euler angles, in radians
        i (int): Index of the first axis
        j (int): Index of the second axis
        k (int): Index of the third axis

    Returns:
        quat: The resulting quaternion
    """
    # The formula below works on half angles.
    halved = e / 2.0

    # Sine/cosine of the half angle picked for each position of the sequence.
    cos_i = wp.cos(halved[i])
    sin_i = wp.sin(halved[i])
    cos_j = wp.cos(halved[j])
    sin_j = wp.sin(halved[j])
    cos_k = wp.cos(halved[k])
    sin_k = wp.sin(halved[k])

    # Quaternion components (x, y, z, w) for the requested rotation sequence.
    qx = cos_k * sin_i * cos_j - sin_k * cos_i * sin_j
    qy = cos_k * cos_i * sin_j + sin_k * sin_i * cos_j
    qz = sin_k * cos_i * cos_j - cos_k * sin_i * sin_j
    qw = cos_k * cos_i * cos_j + sin_k * sin_i * sin_j

    return wp.quat(qx, qy, qz, qw)
63
+
64
+
65
def getkernel(func, suffix=""):
    """Return the cached ``wp.Kernel`` for ``func``, building it on first use.

    Args:
        func: The Python function to wrap as a Warp kernel.
        suffix: String appended to the function name to form the cache/kernel key,
            which lets the same function name be registered once per variant.

    Returns:
        The ``wp.Kernel`` stored in ``kernel_cache`` under ``"<func name>_<suffix>"``.
    """
    cache_key = func.__name__ + "_" + suffix

    # Fast path: this variant has already been built.
    if cache_key in kernel_cache:
        return kernel_cache[cache_key]

    built = wp.Kernel(func=func, key=cache_key)
    kernel_cache[cache_key] = built
    return kernel_cache[cache_key]
70
+
71
+
72
def get_select_kernel(dtype):
    """Return a (cached) kernel that copies one element of an array into ``out[0]``.

    Args:
        dtype: Element type of both the input and the output array; its ``__name__``
            is used as the kernel-key suffix.

    Returns:
        The kernel obtained from ``getkernel`` for this dtype.
    """

    # The function name is part of the kernel key, so it must stay as-is.
    def output_select_kernel_fn(input: wp.array(dtype=dtype), index: int, out: wp.array(dtype=dtype)):
        # Write the selected entry to the first slot of the output array.
        out[0] = input[index]

    type_suffix = dtype.__name__
    return getkernel(output_select_kernel_fn, suffix=type_suffix)
81
+
82
+
83
+ ############################################################
84
+
85
+
86
def test_constructors(test, device, dtype, register_kernels=False):
    """Check the component-wise and (vec3, scalar) quaternion constructors.

    For each constructor, a kernel builds a quaternion from four random inputs and
    writes twice each component to the output. The test verifies the forward values
    and that the gradient of each output component w.r.t. the inputs is 2 at the
    matching index and 0 elsewhere.

    Args:
        test: Test-case instance (not used by this function).
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy floating-point type under test.
        register_kernels: When true, return right after the kernels have been created.
    """
    rng = np.random.default_rng(123)

    # Forward-pass tolerance per precision; unknown dtypes require an exact match.
    tolerances = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec3 = wp._src.types.vector(length=3, dtype=wptype)
    quat = wp._src.types.quaternion(dtype=wptype)

    def check_component_constructor(
        input: wp.array(dtype=wptype),
        q: wp.array(dtype=wptype),
    ):
        qresult = quat(input[0], input[1], input[2], input[3])

        # multiply the output by 2 so we've got something to backpropagate:
        q[0] = wptype(2) * qresult[0]
        q[1] = wptype(2) * qresult[1]
        q[2] = wptype(2) * qresult[2]
        q[3] = wptype(2) * qresult[3]

    def check_vector_constructor(
        input: wp.array(dtype=wptype),
        q: wp.array(dtype=wptype),
    ):
        qresult = quat(vec3(input[0], input[1], input[2]), input[3])

        # multiply the output by 2 so we've got something to backpropagate:
        q[0] = wptype(2) * qresult[0]
        q[1] = wptype(2) * qresult[1]
        q[2] = wptype(2) * qresult[2]
        q[3] = wptype(2) * qresult[3]

    dtype_name = dtype.__name__
    kernel = getkernel(check_component_constructor, suffix=dtype_name)
    output_select_kernel = get_select_kernel(wptype)
    vec_kernel = getkernel(check_vector_constructor, suffix=dtype_name)

    if register_kernels:
        return

    # Both constructors are verified the same way; the component-wise one goes first
    # so the random draws happen in the same order for each kernel.
    for ctor_kernel in (kernel, vec_kernel):
        values = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
        doubled = wp.zeros_like(values)
        wp.launch(ctor_kernel, dim=1, inputs=[values], outputs=[doubled], device=device)

        assert_np_equal(doubled.numpy(), 2 * values.numpy(), tol=tol)

        # Backpropagate from one output component at a time.
        for component in range(4):
            selected = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
            tape = wp.Tape()
            with tape:
                wp.launch(ctor_kernel, dim=1, inputs=[values], outputs=[doubled], device=device)
                wp.launch(
                    output_select_kernel, dim=1, inputs=[doubled, component], outputs=[selected], device=device
                )
            tape.backward(loss=selected)

            expected_grads = np.zeros(len(values))
            expected_grads[component] = 2
            assert_np_equal(tape.gradients[values].numpy(), expected_grads)
            tape.zero()
165
+
166
+
167
def test_casting_constructors(test, device, dtype, register_kernels=False):
    """Check casting a quaternion of ``dtype`` to float16, float32 and float64.

    Each kernel builds a quaternion from a row of the input, converts it with
    ``wp.quaternion(q, dtype=...)`` and stores the components. With all-ones input
    and all-ones output gradients, both the cast result and the input gradient are
    expected to be all ones.

    Args:
        test: Test-case instance (not used by this function).
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy floating-point type of the source quaternion.
        register_kernels: When true, return right after the kernels have been created.
    """
    np_type = np.dtype(dtype)
    wp_type = wp._src.types.np_dtype_to_warp_type[np_type]
    quat = wp._src.types.quaternion(dtype=wp_type)

    # NumPy / Warp type pairs for each cast target.
    np16 = np.dtype(np.float16)
    wp16 = wp._src.types.np_dtype_to_warp_type[np16]

    np32 = np.dtype(np.float32)
    wp32 = wp._src.types.np_dtype_to_warp_type[np32]

    np64 = np.dtype(np.float64)
    wp64 = wp._src.types.np_dtype_to_warp_type[np64]

    def cast_float16(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp16, ndim=2)):
        tid = wp.tid()

        q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
        q2 = wp.quaternion(q1, dtype=wp16)

        b[tid, 0] = q2[0]
        b[tid, 1] = q2[1]
        b[tid, 2] = q2[2]
        b[tid, 3] = q2[3]

    def cast_float32(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp32, ndim=2)):
        tid = wp.tid()

        q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
        q2 = wp.quaternion(q1, dtype=wp32)

        b[tid, 0] = q2[0]
        b[tid, 1] = q2[1]
        b[tid, 2] = q2[2]
        b[tid, 3] = q2[3]

    def cast_float64(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp64, ndim=2)):
        tid = wp.tid()

        q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
        q2 = wp.quaternion(q1, dtype=wp64)

        b[tid, 0] = q2[0]
        b[tid, 1] = q2[1]
        b[tid, 2] = q2[2]
        b[tid, 3] = q2[3]

    source_name = dtype.__name__
    kernel_16 = getkernel(cast_float16, suffix=source_name)
    kernel_32 = getkernel(cast_float32, suffix=source_name)
    kernel_64 = getkernel(cast_float64, suffix=source_name)

    if register_kernels:
        return

    # Run the same forward/backward check for float16, float32 and float64, in that order.
    cast_cases = (
        (kernel_16, np16, wp16),
        (kernel_32, np32, wp32),
        (kernel_64, np64, wp64),
    )

    for cast_kernel, np_target, wp_target in cast_cases:
        src = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
        dst = wp.array(np.zeros((1, 4), dtype=np_target), dtype=wp_target, requires_grad=True, device=device)
        expected_dst = np.ones((1, 4), dtype=np_target)
        dst_grad = wp.array(np.ones((1, 4), dtype=np_target), dtype=wp_target, device=device)
        expected_src_grad = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, device=device)

        tape = wp.Tape()
        with tape:
            wp.launch(kernel=cast_kernel, dim=1, inputs=[src, dst], device=device)

        # Seed the backward pass with all-ones gradients on the cast output.
        tape.backward(grads={dst: dst_grad})
        src_grad = tape.gradients[src].numpy()

        assert_np_equal(dst.numpy(), expected_dst)
        assert_np_equal(src_grad, expected_src_grad.numpy())
271
+
272
+
273
def test_inverse(test, device, dtype, register_kernels=False):
    """Exercise ``wp.quat_inverse`` for one scalar type, forward and backward.

    The kernel builds a quaternion from four scalars, writes twice each
    component of its inverse, and also stores
    ``normalize(q) * quat_inverse(normalize(q))``, which is compared against
    ``(0, 0, 0, 1)``.

    Args:
        test: Test-case instance; not used by this function.
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type selecting the quaternion precision.
        register_kernels: When true, return right after the kernel has been
            created, without launching anything.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 2.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp._src.types.quaternion(dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_quat_inverse(
        input: wp.array(dtype=wptype),
        shouldbeidentity: wp.array(dtype=quat),
        q: wp.array(dtype=wptype),
    ):
        qread = quat(input[0], input[1], input[2], input[3])
        qresult = wp.quat_inverse(qread)

        # this inverse should work for normalized quaternions:
        shouldbeidentity[0] = wp.normalize(qread) * wp.quat_inverse(wp.normalize(qread))

        # multiply the output by 2 so we've got something to backpropagate:
        q[0] = wptype(2) * qresult[0]
        q[1] = wptype(2) * qresult[1]
        q[2] = wptype(2) * qresult[2]
        q[3] = wptype(2) * qresult[3]

    kernel = getkernel(check_quat_inverse, suffix=dtype.__name__)

    if register_kernels:
        return

    components = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
    identity_out = wp.array(np.zeros((1, 4)), dtype=quat, requires_grad=True, device=device)
    doubled_inverse = wp.zeros_like(components)

    # Forward-only launch: check the normalized product against (0, 0, 0, 1).
    wp.launch(
        kernel,
        dim=1,
        inputs=[components],
        outputs=[identity_out, doubled_inverse],
        device=device,
    )
    assert_np_equal(identity_out.numpy(), np.array([0, 0, 0, 1]), tol=tol)

    # Expected slope of output component i with respect to input component i:
    # -2 for the first three components, +2 for the last one.
    slopes = (-2, -2, -2, 2)
    for i, slope in enumerate(slopes):
        cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        tape = wp.Tape()
        with tape:
            wp.launch(
                kernel,
                dim=1,
                inputs=[components],
                outputs=[identity_out, doubled_inverse],
                device=device,
            )
            # Copy output component i into a single-element loss array.
            wp.launch(
                output_select_kernel,
                dim=1,
                inputs=[doubled_inverse, i],
                outputs=[cmp],
                device=device,
            )
        tape.backward(loss=cmp)

        expected = np.zeros(len(components))
        expected[i] = slope
        assert_np_equal(tape.gradients[components].numpy(), expected)
        tape.zero()
327
+
328
+
329
def test_dotproduct(test, device, dtype, register_kernels=False):
    """Check ``wp.dot`` on two quaternions, including gradients for both operands.

    The kernel stores ``2 * dot(v[0], s[0])``; the value is compared with a
    NumPy computation and the gradients with ``2 * v`` and ``2 * s``.

    Args:
        test: Test-case instance; not used by this function.
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type selecting the quaternion precision.
        register_kernels: When true, return right after the kernel has been
            created, without launching anything.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp._src.types.quaternion(dtype=wptype)

    def check_quat_dot(
        s: wp.array(dtype=quat),
        v: wp.array(dtype=quat),
        dot: wp.array(dtype=wptype),
    ):
        dot[0] = wptype(2) * wp.dot(v[0], s[0])

    dotkernel = getkernel(check_quat_dot, suffix=dtype.__name__)
    if register_kernels:
        return

    # Draw order matters for reproducibility: first operand ``s``, then ``v``.
    lhs = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
    rhs = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
    result = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(dotkernel, dim=1, inputs=[lhs, rhs], outputs=[result], device=device)

    lhs_np = lhs.numpy()
    rhs_np = rhs.numpy()
    assert_np_equal(result.numpy()[0], 2.0 * (rhs_np * lhs_np).sum(), tol=tol)

    tape.backward(loss=result)

    # d(2 * dot(v, s)) / ds == 2 * v
    lhs_grad = tape.gradients[lhs].numpy()[0]
    assert_np_equal(lhs_grad, 2.0 * rhs_np[0], tol=10 * tol)

    # d(2 * dot(v, s)) / dv == 2 * s
    rhs_grad = tape.gradients[rhs].numpy()[0]
    assert_np_equal(rhs_grad, 2.0 * lhs_np[0], tol=tol)
379
+
380
+
381
def test_length(test, device, dtype, register_kernels=False):
    """Check ``wp.length`` and ``wp.length_sq`` on a quaternion with gradients.

    The kernel writes ``2 * length(q)`` and ``2 * length_sq(q)``; both values
    and their gradients with respect to ``q`` are compared with NumPy.

    Args:
        test: Test-case instance; not used by this function.
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type selecting the quaternion precision.
        register_kernels: When true, return right after the kernel has been
            created, without launching anything.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-7,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp._src.types.quaternion(dtype=wptype)

    def check_quat_length(
        q: wp.array(dtype=quat),
        l: wp.array(dtype=wptype),
        l2: wp.array(dtype=wptype),
    ):
        l[0] = wptype(2) * wp.length(q[0])
        l2[0] = wptype(2) * wp.length_sq(q[0])

    kernel = getkernel(check_quat_length, suffix=dtype.__name__)

    if register_kernels:
        return

    q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
    length_out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    length_sq_out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[q], outputs=[length_out, length_sq_out], device=device)

    q_np = q.numpy()
    norm = np.linalg.norm(q_np)
    loose_tol = 10 * tol

    assert_np_equal(length_out.numpy()[0], 2 * norm, tol=loose_tol)
    assert_np_equal(length_sq_out.numpy()[0], 2 * norm**2, tol=loose_tol)

    # Gradient of 2 * |q| is 2 * q / |q|.
    tape.backward(loss=length_out)
    assert_np_equal(tape.gradients[q].numpy()[0], 2 * q_np[0] / norm, tol=loose_tol)
    tape.zero()

    # Gradient of 2 * |q|^2 is 4 * q.
    tape.backward(loss=length_sq_out)
    assert_np_equal(tape.gradients[q].numpy()[0], 4 * q_np[0], tol=loose_tol)
    tape.zero()
436
+
437
+
438
def test_normalize(test, device, dtype, register_kernels=False):
    """Compare ``wp.normalize(q)`` against ``q / wp.length(q)``.

    Two kernels compute twice the normalized quaternion, one through
    ``wp.normalize`` and one through an explicit division by the length.
    Their outputs and their gradients with respect to ``q`` must agree.

    Args:
        test: Test-case instance; not used by this function.
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type selecting the quaternion precision.
        register_kernels: When true, return right after both kernels have been
            created, without launching anything.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp._src.types.quaternion(dtype=wptype)

    def check_normalize(
        q: wp.array(dtype=quat),
        n0: wp.array(dtype=wptype),
        n1: wp.array(dtype=wptype),
        n2: wp.array(dtype=wptype),
        n3: wp.array(dtype=wptype),
    ):
        n = wptype(2) * (wp.normalize(q[0]))

        n0[0] = n[0]
        n1[0] = n[1]
        n2[0] = n[2]
        n3[0] = n[3]

    def check_normalize_alt(
        q: wp.array(dtype=quat),
        n0: wp.array(dtype=wptype),
        n1: wp.array(dtype=wptype),
        n2: wp.array(dtype=wptype),
        n3: wp.array(dtype=wptype),
    ):
        n = wptype(2) * (q[0] / wp.length(q[0]))

        n0[0] = n[0]
        n1[0] = n[1]
        n2[0] = n[2]
        n3[0] = n[3]

    normalize_kernel = getkernel(check_normalize, suffix=dtype.__name__)
    normalize_alt_kernel = getkernel(check_normalize_alt, suffix=dtype.__name__)

    if register_kernels:
        return

    # The building blocks of check_normalize_alt are covered by other tests, so
    # it is enough to verify that both kernels give the same results/gradients.
    q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)

    # One single-element output array per quaternion component, per kernel.
    outputs0 = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(4)]
    outputs1 = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(4)]

    tape0 = wp.Tape()
    with tape0:
        wp.launch(normalize_kernel, dim=1, inputs=[q], outputs=outputs0, device=device)

    tape1 = wp.Tape()
    with tape1:
        wp.launch(normalize_alt_kernel, dim=1, inputs=[q], outputs=outputs1, device=device)

    # Forward values must match component by component.
    for direct, alt in zip(outputs0, outputs1):
        assert_np_equal(direct.numpy()[0], alt.numpy()[0], tol=tol)

    # Gradients of each component with respect to q must match as well.
    for direct, alt in zip(outputs0, outputs1):
        tape0.backward(loss=direct)
        tape1.backward(loss=alt)
        direct_grad = tape0.gradients[q].numpy()[0]
        alt_grad = tape1.gradients[q].numpy()[0]
        assert_np_equal(direct_grad, alt_grad, tol=tol)
        tape0.zero()
        tape1.zero()
537
+
538
+
539
def test_addition(test, device, dtype, register_kernels=False):
    """Check quaternion addition ``q + v`` and its gradients.

    The kernel writes twice each component of ``q[0] + v[0]`` to its own
    single-element array, so each component can be used as a separate loss.

    Args:
        test: Test-case instance; not used by this function.
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type selecting the quaternion precision.
        register_kernels: When true, return right after the kernel has been
            created, without launching anything.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp._src.types.quaternion(dtype=wptype)

    def check_quat_add(
        q: wp.array(dtype=quat),
        v: wp.array(dtype=quat),
        r0: wp.array(dtype=wptype),
        r1: wp.array(dtype=wptype),
        r2: wp.array(dtype=wptype),
        r3: wp.array(dtype=wptype),
    ):
        result = q[0] + v[0]

        r0[0] = wptype(2) * result[0]
        r1[0] = wptype(2) * result[1]
        r2[0] = wptype(2) * result[2]
        r3[0] = wptype(2) * result[3]

    kernel = getkernel(check_quat_add, suffix=dtype.__name__)

    if register_kernels:
        return

    q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
    v = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)

    sums = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(4)]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[q, v], outputs=list(sums), device=device)

    q_np = q.numpy()
    v_np = v.numpy()
    for i, component in enumerate(sums):
        assert_np_equal(component.numpy()[0], 2 * (v_np[0, i] + q_np[0, i]), tol=tol)

    # Each output depends only on the matching component of both operands,
    # with slope 2.
    for i, loss in enumerate(sums):
        tape.backward(loss=loss)

        q_grad = tape.gradients[q].numpy()[0]
        expected = np.zeros_like(q_grad)
        expected[i] = 2
        assert_np_equal(q_grad, expected, tol=10 * tol)

        v_grad = tape.gradients[v].numpy()[0]
        assert_np_equal(v_grad, expected, tol=tol)

        tape.zero()
609
+
610
+
611
def test_subtraction(test, device, dtype, register_kernels=False):
    """Check quaternion subtraction ``v - q`` and its gradients.

    The kernel writes twice each component of ``v[0] - q[0]`` to its own
    single-element array, so each component can be used as a separate loss.

    Args:
        test: Test-case instance; not used by this function.
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type selecting the quaternion precision.
        register_kernels: When true, return right after the kernel has been
            created, without launching anything.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp._src.types.quaternion(dtype=wptype)

    def check_quat_sub(
        q: wp.array(dtype=quat),
        v: wp.array(dtype=quat),
        r0: wp.array(dtype=wptype),
        r1: wp.array(dtype=wptype),
        r2: wp.array(dtype=wptype),
        r3: wp.array(dtype=wptype),
    ):
        result = v[0] - q[0]

        r0[0] = wptype(2) * result[0]
        r1[0] = wptype(2) * result[1]
        r2[0] = wptype(2) * result[2]
        r3[0] = wptype(2) * result[3]

    kernel = getkernel(check_quat_sub, suffix=dtype.__name__)

    if register_kernels:
        return

    q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
    v = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)

    differences = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(4)]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[q, v], outputs=list(differences), device=device)

    q_np = q.numpy()
    v_np = v.numpy()
    for i, component in enumerate(differences):
        assert_np_equal(component.numpy()[0], 2 * (v_np[0, i] - q_np[0, i]), tol=tol)

    # Slope is -2 with respect to q (the subtrahend) and +2 with respect to v.
    for i, loss in enumerate(differences):
        tape.backward(loss=loss)

        q_grad = tape.gradients[q].numpy()[0]
        q_expected = np.zeros_like(q_grad)
        q_expected[i] = -2
        assert_np_equal(q_grad, q_expected, tol=10 * tol)

        v_grad = tape.gradients[v].numpy()[0]
        v_expected = np.zeros_like(q_grad)
        v_expected[i] = 2
        assert_np_equal(v_grad, v_expected, tol=tol)

        tape.zero()
682
+
683
+
684
def test_scalar_multiplication(test, device, dtype, register_kernels=False):
    """Check ``scalar * quat`` and ``quat * scalar`` together with gradients.

    The kernel computes both ``s[0] * q[0]`` and ``q[0] * s[0]`` and writes
    twice each component of each product to its own single-element array.
    Gradients are only checked when ``dtype`` is in ``np_float_types``.

    Args:
        test: Test-case instance; not used by this function.
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type selecting the quaternion precision.
        register_kernels: When true, return right after the kernel has been
            created, without launching anything.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp._src.types.quaternion(dtype=wptype)

    def check_quat_scalar_mul(
        s: wp.array(dtype=wptype),
        q: wp.array(dtype=quat),
        l0: wp.array(dtype=wptype),
        l1: wp.array(dtype=wptype),
        l2: wp.array(dtype=wptype),
        l3: wp.array(dtype=wptype),
        r0: wp.array(dtype=wptype),
        r1: wp.array(dtype=wptype),
        r2: wp.array(dtype=wptype),
        r3: wp.array(dtype=wptype),
    ):
        lresult = s[0] * q[0]
        rresult = q[0] * s[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        l0[0] = wptype(2) * lresult[0]
        l1[0] = wptype(2) * lresult[1]
        l2[0] = wptype(2) * lresult[2]
        l3[0] = wptype(2) * lresult[3]

        r0[0] = wptype(2) * rresult[0]
        r1[0] = wptype(2) * rresult[1]
        r2[0] = wptype(2) * rresult[2]
        r3[0] = wptype(2) * rresult[3]

    kernel = getkernel(check_quat_scalar_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    s = wp.array(rng.standard_normal(size=1).astype(dtype), requires_grad=True, device=device)
    q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)

    # Outputs of the left product (s * q) first, then of the right product (q * s).
    left = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(4)]
    right = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(4)]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[s, q], outputs=left + right, device=device)

    scale = s.numpy()[0]
    q_np = q.numpy()

    # Both products must give 2 * s * q[i] in every component.
    for products in (left, right):
        for i, component in enumerate(products):
            assert_np_equal(component.numpy()[0], 2 * scale * q_np[0, i], tol=tol)

    if dtype in np_float_types:
        for i, pair in enumerate(zip(left, right)):
            for loss in pair:
                tape.backward(loss=loss)

                s_grad = tape.gradients[s].numpy()[0]
                assert_np_equal(s_grad, 2 * q_np[0, i], tol=tol)

                q_grad = tape.gradients[q].numpy()[0]
                expected = np.zeros_like(q_grad)
                expected[i] = scale * 2
                assert_np_equal(q_grad, expected, tol=10 * tol)

                tape.zero()
780
+
781
+
782
def test_scalar_division(test, device, dtype, register_kernels=False):
    """Check ``quat / scalar`` together with gradients for both operands.

    The kernel writes twice each component of ``q[0] / s[0]`` to its own
    single-element array. Gradients are only checked when ``dtype`` is in
    ``np_float_types``.

    Args:
        test: Test-case instance; not used by this function.
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type selecting the quaternion precision.
        register_kernels: When true, return right after the kernel has been
            created, without launching anything.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 1.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp._src.types.quaternion(dtype=wptype)

    def check_quat_scalar_div(
        s: wp.array(dtype=wptype),
        q: wp.array(dtype=quat),
        r0: wp.array(dtype=wptype),
        r1: wp.array(dtype=wptype),
        r2: wp.array(dtype=wptype),
        r3: wp.array(dtype=wptype),
    ):
        result = q[0] / s[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        r0[0] = wptype(2) * result[0]
        r1[0] = wptype(2) * result[1]
        r2[0] = wptype(2) * result[2]
        r3[0] = wptype(2) * result[3]

    kernel = getkernel(check_quat_scalar_div, suffix=dtype.__name__)

    if register_kernels:
        return

    s = wp.array(rng.standard_normal(size=1).astype(dtype), requires_grad=True, device=device)
    q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)

    quotients = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(4)]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[s, q], outputs=list(quotients), device=device)

    divisor = s.numpy()[0]
    q_np = q.numpy()

    for i, component in enumerate(quotients):
        assert_np_equal(component.numpy()[0], 2 * q_np[0, i] / divisor, tol=tol)

    if dtype in np_float_types:
        for i, loss in enumerate(quotients):
            tape.backward(loss=loss)

            # d(2 * q_i / s) / ds == -2 * q_i / s^2
            s_grad = tape.gradients[s].numpy()[0]
            assert_np_equal(s_grad, -2 * q_np[0, i] / (divisor * divisor), tol=tol)

            # d(2 * q_i / s) / dq_i == 2 / s, zero for the other components.
            q_grad = tape.gradients[q].numpy()[0]
            expected = np.zeros_like(q_grad)
            expected[i] = 2 / divisor
            assert_np_equal(q_grad, expected, tol=10 * tol)

            tape.zero()
853
+
854
+
855
def test_quat_multiplication(test, device, dtype, register_kernels=False):
    """Check the quaternion product ``s * q`` and its gradients.

    The kernel writes twice each component of ``s[0] * q[0]``. Values are
    compared with the product written out explicitly on the NumPy side
    (components indexed 0..2 and 3 are named x, y, z and w below), and the
    gradients with the corresponding partial derivatives.

    Args:
        test: Test-case instance; not used by this function.
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type selecting the quaternion precision.
        register_kernels: When true, return right after the kernel has been
            created, without launching anything.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp._src.types.quaternion(dtype=wptype)

    def check_quat_mul(
        s: wp.array(dtype=quat),
        q: wp.array(dtype=quat),
        r0: wp.array(dtype=wptype),
        r1: wp.array(dtype=wptype),
        r2: wp.array(dtype=wptype),
        r3: wp.array(dtype=wptype),
    ):
        result = s[0] * q[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        r0[0] = wptype(2) * result[0]
        r1[0] = wptype(2) * result[1]
        r2[0] = wptype(2) * result[2]
        r3[0] = wptype(2) * result[3]

    kernel = getkernel(check_quat_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    s = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
    q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)

    products = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(4)]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[s, q], outputs=list(products), device=device)

    a = s.numpy()
    b = q.numpy()
    ax, ay, az, aw = a[0, 0], a[0, 1], a[0, 2], a[0, 3]
    bx, by, bz, bw = b[0, 0], b[0, 1], b[0, 2], b[0, 3]

    # Expected doubled product, one entry per output component.
    expected_values = [
        2 * (aw * bx + bw * ax + ay * bz - by * az),
        2 * (aw * by + bw * ay + az * bx - bz * ax),
        2 * (aw * bz + bw * az + ax * by - bx * ay),
        2 * (aw * bw - ax * bx - ay * by - az * bz),
    ]
    for component, expected in zip(products, expected_values):
        assert_np_equal(component.numpy()[0], expected, tol=tol)

    # For each output component: (gradient w.r.t. s, gradient w.r.t. q).
    expected_grads = [
        (2 * np.array([bw, bz, -by, bx]), 2 * np.array([aw, -az, ay, ax])),
        (2 * np.array([-bz, bw, bx, by]), 2 * np.array([az, aw, -ax, ay])),
        (2 * np.array([by, -bx, bw, bz]), 2 * np.array([-ay, ax, aw, az])),
        (2 * np.array([-bx, -by, -bz, bw]), 2 * np.array([-ax, -ay, -az, aw])),
    ]
    for loss, (s_expected, q_expected) in zip(products, expected_grads):
        tape.backward(loss=loss)

        s_grad = tape.gradients[s].numpy()[0]
        assert_np_equal(s_grad, s_expected, tol=tol)

        q_grad = tape.gradients[q].numpy()[0]
        assert_np_equal(q_grad, q_expected, tol=tol)

        tape.zero()
957
+
958
+
959
def test_indexing(test, device, dtype, register_kernels=False):
    """Check reading quaternion components by index, with gradients.

    The kernel writes ``2 * q[0][i]`` for ``i`` in 0..3 to four single-element
    arrays. Gradients are checked first, then the forward values.

    Args:
        test: Test-case instance; not used by this function.
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type selecting the quaternion precision.
        register_kernels: When true, return right after the kernel has been
            created, without launching anything.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp._src.types.quaternion(dtype=wptype)

    def check_quat_indexing(
        q: wp.array(dtype=quat),
        r0: wp.array(dtype=wptype),
        r1: wp.array(dtype=wptype),
        r2: wp.array(dtype=wptype),
        r3: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        r0[0] = wptype(2) * q[0][0]
        r1[0] = wptype(2) * q[0][1]
        r2[0] = wptype(2) * q[0][2]
        r3[0] = wptype(2) * q[0][3]

    kernel = getkernel(check_quat_indexing, suffix=dtype.__name__)

    if register_kernels:
        return

    q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
    picked = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(4)]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[q], outputs=list(picked), device=device)

    # Output i depends only on component i of q, with slope 2.
    for i, loss in enumerate(picked):
        tape.backward(loss=loss)
        q_grad = tape.gradients[q].numpy()[0]
        expected = np.zeros_like(q_grad)
        expected[i] = 2
        assert_np_equal(q_grad, expected, tol=tol)
        tape.zero()

    q_np = q.numpy()
    for i, component in enumerate(picked):
        assert_np_equal(component.numpy()[0], 2.0 * q_np[0, i], tol=tol)
1011
+
1012
+
1013
@wp.kernel
def test_assignment():
    # Overwrite every component of a quaternion through indexed assignment,
    # then verify that each component reads back the value that was stored.
    value = wp.quat(1.0, 2.0, 3.0, 4.0)

    value[0] = 1.23
    value[1] = 2.34
    value[2] = 3.45
    value[3] = 4.56

    wp.expect_eq(value[0], 1.23)
    wp.expect_eq(value[1], 2.34)
    wp.expect_eq(value[2], 3.45)
    wp.expect_eq(value[3], 4.56)
1024
+
1025
+
1026
def test_quat_lerp(test, device, dtype, register_kernels=False):
    """Check ``wp.lerp`` on quaternions, including gradients for all inputs.

    The kernel writes twice each component of ``lerp(s[0], q[0], t[0])`` to
    its own single-element array. Values are compared with
    ``2 * ((1 - t) * s + t * q)`` and the gradients with the matching
    partial derivatives.

    Args:
        test: Test-case instance; not used by this function.
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type selecting the quaternion precision.
        register_kernels: When true, return right after the kernel has been
            created, without launching anything.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp._src.types.quaternion(dtype=wptype)

    def check_quat_lerp(
        s: wp.array(dtype=quat),
        q: wp.array(dtype=quat),
        t: wp.array(dtype=wptype),
        r0: wp.array(dtype=wptype),
        r1: wp.array(dtype=wptype),
        r2: wp.array(dtype=wptype),
        r3: wp.array(dtype=wptype),
    ):
        result = wp.lerp(s[0], q[0], t[0])

        # multiply outputs by 2 so we've got something to backpropagate:
        r0[0] = wptype(2) * result[0]
        r1[0] = wptype(2) * result[1]
        r2[0] = wptype(2) * result[2]
        r3[0] = wptype(2) * result[3]

    kernel = getkernel(check_quat_lerp, suffix=dtype.__name__)

    if register_kernels:
        return

    s = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
    q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
    t = wp.array(rng.uniform(size=1).astype(dtype), dtype=wptype, requires_grad=True, device=device)

    r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    results = [r0, r1, r2, r3]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[s, q, t],
            outputs=[r0, r1, r2, r3],
            device=device,
        )

    a = s.numpy()
    b = q.numpy()
    # Read the interpolation factor as a scalar. Keeping the shape-(1,) array
    # would make ``expected_grads[i] = ...`` below store an array into a scalar
    # slot, which relies on NumPy's deprecated size-1 array-to-scalar conversion.
    tt = t.numpy()[0]

    for i, component in enumerate(results):
        assert_np_equal(component.numpy()[0], 2 * ((1 - tt) * a[0, i] + tt * b[0, i]), tol=tol)

    for i, loss in enumerate(results):
        tape.backward(loss=loss)
        agrad = tape.gradients[s].numpy()[0]
        bgrad = tape.gradients[q].numpy()[0]
        tgrad = tape.gradients[t].numpy()[0]

        # d/ds_i of 2 * lerp is 2 * (1 - t); zero for the other components.
        expected_grads = np.zeros_like(agrad)
        expected_grads[i] = 2 * (1 - tt)
        assert_np_equal(agrad, expected_grads, tol=tol)

        # d/dq_i of 2 * lerp is 2 * t; zero for the other components.
        expected_grads[i] = 2 * tt
        assert_np_equal(bgrad, expected_grads, tol=tol)

        # d/dt of 2 * lerp component i is 2 * (q_i - s_i).
        assert_np_equal(tgrad, 2 * (b[0, i] - a[0, i]), tol=tol)

        tape.zero()
1105
+
1106
+
1107
def test_quat_rotate(test, device, dtype, register_kernels=False):
    """Compare ``wp.quat_rotate`` / ``wp.quat_rotate_inv`` with a manual formula.

    The kernel rotates a 3-vector with both builtins and also evaluates the
    rotation by hand from the quaternion's vector and scalar parts. Values and
    gradients (with respect to the quaternion and the vector) of the builtin
    results must match the hand-written ones.

    Args:
        test: Test-case instance; not used by this function.
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type selecting the quaternion precision.
        register_kernels: When true, return right after the kernels have been
            created, without launching anything.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp._src.types.quaternion(dtype=wptype)
    vec3 = wp._src.types.vector(length=3, dtype=wptype)

    def check_quat_rotate(
        q: wp.array(dtype=quat),
        v: wp.array(dtype=vec3),
        outputs: wp.array(dtype=wptype),
        outputs_inv: wp.array(dtype=wptype),
        outputs_manual: wp.array(dtype=wptype),
        outputs_inv_manual: wp.array(dtype=wptype),
    ):
        result = wp.quat_rotate(q[0], v[0])
        result_inv = wp.quat_rotate_inv(q[0], v[0])

        qv = vec3(q[0][0], q[0][1], q[0][2])
        qw = q[0][3]

        result_manual = v[0] * (wptype(2) * qw * qw - wptype(1))
        result_manual += wp.cross(qv, v[0]) * qw * wptype(2)
        result_manual += qv * wp.dot(qv, v[0]) * wptype(2)

        result_inv_manual = v[0] * (wptype(2) * qw * qw - wptype(1))
        result_inv_manual -= wp.cross(qv, v[0]) * qw * wptype(2)
        result_inv_manual += qv * wp.dot(qv, v[0]) * wptype(2)

        for i in range(3):
            # multiply outputs by 2 so we've got something to backpropagate:
            outputs[i] = wptype(2) * result[i]
            outputs_inv[i] = wptype(2) * result_inv[i]
            outputs_manual[i] = wptype(2) * result_manual[i]
            outputs_inv_manual[i] = wptype(2) * result_inv_manual[i]

    kernel = getkernel(check_quat_rotate, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel(wptype)

    if register_kernels:
        return

    # Unit quaternion: normalize in double precision before casting to dtype.
    q_np = rng.standard_normal(size=(1, 4))
    q_np = q_np / np.linalg.norm(q_np)
    q = wp.array(q_np.astype(dtype), dtype=quat, requires_grad=True, device=device)
    v = wp.array(0.5 * rng.standard_normal(size=(1, 3)).astype(dtype), dtype=vec3, requires_grad=True, device=device)

    # Kernel outputs, in the order of the kernel's output parameters:
    # builtin, builtin inverse, manual, manual inverse.
    results = [wp.zeros(3, dtype=wptype, requires_grad=True, device=device) for _ in range(4)]
    outputs, outputs_inv, outputs_manual, outputs_inv_manual = results

    # test values against the manually computed result:
    wp.launch(kernel, dim=1, inputs=[q, v], outputs=list(results), device=device)

    assert_np_equal(outputs.numpy(), outputs_manual.numpy(), tol=tol)
    assert_np_equal(outputs_inv.numpy(), outputs_inv_manual.numpy(), tol=tol)

    def backprop(tape, loss):
        # Run one backward pass and return copies of the (q, v) gradients;
        # the tape's gradients are cleared before returning.
        tape.backward(loss=loss)
        q_grad = 1.0 * tape.gradients[q].numpy()
        v_grad = 1.0 * tape.gradients[v].numpy()
        tape.zero()
        return q_grad, v_grad

    # test gradients against the manually computed result:
    for i in range(3):
        # One single-element loss per kernel output, same order as ``results``.
        losses = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(4)]
        cmp, cmp_inv, cmp_manual, cmp_inv_manual = losses

        tape = wp.Tape()
        with tape:
            wp.launch(kernel, dim=1, inputs=[q, v], outputs=list(results), device=device)
            # Copy component i of every output into its loss array.
            for source, selected in zip(results, losses):
                wp.launch(output_select_kernel, dim=1, inputs=[source, i], outputs=[selected], device=device)

        qgrads, vgrads = backprop(tape, cmp)
        qgrads_inv, vgrads_inv = backprop(tape, cmp_inv)
        qgrads_manual, vgrads_manual = backprop(tape, cmp_manual)
        qgrads_inv_manual, vgrads_inv_manual = backprop(tape, cmp_inv_manual)

        assert_np_equal(qgrads, qgrads_manual, tol=tol)
        assert_np_equal(vgrads, vgrads_manual, tol=tol)

        assert_np_equal(qgrads_inv, qgrads_inv_manual, tol=tol)
        assert_np_equal(vgrads_inv, vgrads_inv_manual, tol=tol)
1231
+
1232
+
1233
def test_quat_to_matrix(test, device, dtype, register_kernels=False):
    """Check ``wp.quat_to_matrix`` values and gradients against a manual build.

    The reference matrix is assembled from the three unit axes rotated with
    ``wp.quat_rotate``. Both results are scaled by 2 inside the kernel so that
    the backward pass has a non-trivial factor to propagate.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp._src.types.quaternion(dtype=wptype)
    vec3 = wp._src.types.vector(length=3, dtype=wptype)

    def check_quat_to_matrix(
        q: wp.array(dtype=quat),
        outputs: wp.array(dtype=wptype),
        outputs_manual: wp.array(dtype=wptype),
    ):
        builtin_mat = wp.quat_to_matrix(q[0])

        # rotate each basis vector; these become the columns of the reference matrix
        col_x = wp.quat_rotate(q[0], vec3(wptype(1), wptype(0), wptype(0)))
        col_y = wp.quat_rotate(q[0], vec3(wptype(0), wptype(1), wptype(0)))
        col_z = wp.quat_rotate(q[0], vec3(wptype(0), wptype(0), wptype(1)))
        manual_mat = wp.matrix_from_cols(col_x, col_y, col_z)

        flat = 0
        for i in range(3):
            for j in range(3):
                # multiply outputs by 2 so we've got something to backpropagate:
                outputs[flat] = wptype(2) * builtin_mat[i, j]
                outputs_manual[flat] = wptype(2) * manual_mat[i, j]

                flat = flat + 1

    kernel = getkernel(check_quat_to_matrix, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel(wptype)

    if register_kernels:
        return

    # random unit quaternion
    q_np = rng.standard_normal(size=(1, 4))
    q_np /= np.linalg.norm(q_np)
    q = wp.array(q_np.astype(dtype), dtype=quat, requires_grad=True, device=device)

    num_entries = 3 * 3
    outputs = wp.zeros(num_entries, dtype=wptype, requires_grad=True, device=device)
    outputs_manual = wp.zeros(num_entries, dtype=wptype, requires_grad=True, device=device)

    def launch_check():
        wp.launch(
            kernel,
            dim=1,
            inputs=[q],
            outputs=[outputs, outputs_manual],
            device=device,
        )

    # test values against the manually computed result:
    launch_check()
    assert_np_equal(outputs.numpy(), outputs_manual.numpy(), tol=tol)

    # sanity check: divide by 2 to remove that scale factor we put in there, and
    # it should be a rotation matrix
    rotation = 0.5 * outputs.numpy().reshape(3, 3)
    assert_np_equal(np.matmul(rotation, rotation.T), np.eye(3), tol=tol)

    # test gradients against the manually computed result, one matrix entry at a time:
    for entry in range(num_entries):
        cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        cmp_manual = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

        tape = wp.Tape()
        with tape:
            launch_check()
            wp.launch(output_select_kernel, dim=1, inputs=[outputs, entry], outputs=[cmp], device=device)
            wp.launch(
                output_select_kernel, dim=1, inputs=[outputs_manual, entry], outputs=[cmp_manual], device=device
            )

        tape.backward(loss=cmp)
        qgrads = 1.0 * tape.gradients[q].numpy()
        tape.zero()

        tape.backward(loss=cmp_manual)
        qgrads_manual = 1.0 * tape.gradients[q].numpy()
        tape.zero()

        assert_np_equal(qgrads, qgrads_manual, tol=tol)
1351
+
1352
+
1353
+ ############################################################
1354
+
1355
+
1356
def test_slerp_grad(test, device, dtype, register_kernels=False):
    """Compare the built-in ``wp.quat_slerp`` adjoint with autodiff of a manual slerp.

    The manual kernel expresses slerp through ``wp.quat_to_axis_angle`` and
    ``wp.quat_from_axis_angle``. Losses and gradients of every quaternion
    component are compared with respect to ``t``, ``q0`` and ``q1``.
    """
    rng = np.random.default_rng(123)
    seed = 42

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec3 = wp._src.types.vector(3, wptype)
    quat = wp._src.types.quaternion(wptype)

    def slerp_kernel(
        q0: wp.array(dtype=quat),
        q1: wp.array(dtype=quat),
        t: wp.array(dtype=wptype),
        loss: wp.array(dtype=wptype),
        index: int,
    ):
        tid = wp.tid()

        blended = wp.quat_slerp(q0[tid], q1[tid], t[tid])
        wp.atomic_add(loss, 0, blended[index])

    builtin_kernel = getkernel(slerp_kernel, suffix=dtype.__name__)

    def slerp_kernel_forward(
        q0: wp.array(dtype=quat),
        q1: wp.array(dtype=quat),
        t: wp.array(dtype=wptype),
        loss: wp.array(dtype=wptype),
        index: int,
    ):
        tid = wp.tid()

        axis = vec3()
        angle = wptype(0.0)

        # relative rotation q0^-1 * q1, scaled by t and re-applied to q0
        wp.quat_to_axis_angle(wp.mul(wp.quat_inverse(q0[tid]), q1[tid]), axis, angle)
        blended = wp.mul(q0[tid], wp.quat_from_axis_angle(axis, t[tid] * angle))

        wp.atomic_add(loss, 0, blended[index])

    manual_kernel = getkernel(slerp_kernel_forward, suffix=dtype.__name__)

    def quat_sampler_slerp(kernel_seed: int, quats: wp.array(dtype=quat)):
        tid = wp.tid()

        state = wp.rand_init(kernel_seed, tid)

        angle = wp.randf(state, 0.0, 2.0 * 3.1415926535)
        scaled_axis = wp.sample_unit_sphere_surface(state) * wp.sin(angle * 0.5)

        sample = quat(
            wptype(scaled_axis[0]), wptype(scaled_axis[1]), wptype(scaled_axis[2]), wptype(wp.cos(angle * 0.5))
        )
        quats[tid] = wp.normalize(sample)

    quat_sampler = getkernel(quat_sampler_slerp, suffix=dtype.__name__)

    if register_kernels:
        return

    N = 50

    q0 = wp.zeros(N, dtype=quat, device=device, requires_grad=True)
    q1 = wp.zeros(N, dtype=quat, device=device, requires_grad=True)

    wp.launch(kernel=quat_sampler, dim=N, inputs=[seed, q0], device=device)
    wp.launch(kernel=quat_sampler, dim=N, inputs=[seed + 1, q1], device=device)

    t_np = rng.uniform(low=0.0, high=1.0, size=N)
    t = wp.array(t_np, dtype=wptype, device=device, requires_grad=True)

    def compute_gradients(kernel, wrt, index):
        # returns (loss value, gradient of the loss with respect to ``wrt``)
        loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
        tape = wp.Tape()
        with tape:
            wp.launch(kernel=kernel, dim=N, inputs=[q0, q1, t, loss, index], device=device)

        tape.backward(loss)

        gradients = 1.0 * tape.gradients[wrt].numpy()
        tape.zero()

        return loss.numpy()[0], gradients

    eps = {
        np.float16: 2.0e-2,
        np.float32: 1.0e-5,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    components = range(4)  # x, y, z, w

    # differentiate with respect to t, then q0, then q1
    for wrt in (t, q0, q1):
        # gather gradients from builtin adjoints
        builtin = [compute_gradients(builtin_kernel, wrt, c) for c in components]

        # gather gradients from autodiff
        auto = [compute_gradients(manual_kernel, wrt, c) for c in components]

        for (_, grad_builtin), (_, grad_auto) in zip(builtin, auto):
            assert_np_equal(grad_builtin, grad_auto, tol=eps)

        for (loss_builtin, _), (loss_auto, _) in zip(builtin, auto):
            assert_np_equal(loss_builtin, loss_auto, tol=eps)
1513
+
1514
+
1515
+ ############################################################
1516
+
1517
+
1518
def test_quat_to_axis_angle_grad(test, device, dtype, register_kernels=False):
    """Compare the ``wp.quat_to_axis_angle`` adjoint with autodiff of a manual version.

    Gradients are checked on quaternions sampled at evenly spaced angles and on
    three hand-picked edge cases, including the all-zero quaternion.
    """
    seed = 42
    num_rand = 50

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec3 = wp._src.types.vector(3, wptype)
    vec4 = wp._src.types.vector(4, wptype)
    quat = wp._src.types.quaternion(wptype)

    def quat_to_axis_angle_kernel(quats: wp.array(dtype=quat), loss: wp.array(dtype=wptype), coord_idx: int):
        tid = wp.tid()
        axis = vec3()
        angle = wptype(0.0)

        wp.quat_to_axis_angle(quats[tid], axis, angle)
        # pack axis and angle together so one index selects any output
        packed = vec4(axis[0], axis[1], axis[2], angle)

        wp.atomic_add(loss, 0, packed[coord_idx])

    builtin_kernel = getkernel(quat_to_axis_angle_kernel, suffix=dtype.__name__)

    def quat_to_axis_angle_kernel_forward(quats: wp.array(dtype=quat), loss: wp.array(dtype=wptype), coord_idx: int):
        tid = wp.tid()
        q = quats[tid]
        axis = vec3()
        angle = wptype(0.0)

        v = vec3(q[0], q[1], q[2])
        # flip the axis when the scalar part is negative
        if q[3] < wptype(0):
            axis = -wp.normalize(v)
        else:
            axis = wp.normalize(v)

        angle = wptype(2) * wp.atan2(wp.length(v), wp.abs(q[3]))
        packed = vec4(axis[0], axis[1], axis[2], angle)

        wp.atomic_add(loss, 0, packed[coord_idx])

    manual_kernel = getkernel(quat_to_axis_angle_kernel_forward, suffix=dtype.__name__)

    def quat_sampler(kernel_seed: int, angles: wp.array(dtype=float), quats: wp.array(dtype=quat)):
        tid = wp.tid()

        state = wp.rand_init(kernel_seed, tid)

        angle = angles[tid]
        scaled_axis = wp.sample_unit_sphere_surface(state) * wp.sin(angle * 0.5)

        sample = quat(
            wptype(scaled_axis[0]), wptype(scaled_axis[1]), wptype(scaled_axis[2]), wptype(wp.cos(angle * 0.5))
        )
        quats[tid] = wp.normalize(sample)

    sampler_kernel = getkernel(quat_sampler, suffix=dtype.__name__)

    if register_kernels:
        return

    quats = wp.zeros(num_rand, dtype=quat, device=device, requires_grad=True)
    angle_values = np.linspace(0.0, 2.0 * np.pi, num_rand, endpoint=False, dtype=np.float32)
    angles = wp.array(angle_values, dtype=float, device=device)
    wp.launch(kernel=sampler_kernel, dim=num_rand, inputs=[seed, angles, quats], device=device)

    edge_np = np.array(
        [(1.0, 0.0, 0.0, 0.0), (0.0, 1.0 / np.sqrt(3), 1.0 / np.sqrt(3), 1.0 / np.sqrt(3)), (0.0, 0.0, 0.0, 0.0)]
    )
    num_edge = len(edge_np)
    edge_cases = wp.array(edge_np, dtype=quat, device=device, requires_grad=True)

    def compute_gradients(arr, kernel, dim, index):
        # returns (loss value, gradient of the loss with respect to ``arr``)
        loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
        tape = wp.Tape()
        with tape:
            wp.launch(kernel=kernel, dim=dim, inputs=[arr, loss, index], device=device)

        tape.backward(loss)

        gradients = 1.0 * tape.gradients[arr].numpy()
        tape.zero()

        return loss.numpy()[0], gradients

    components = range(4)  # axis x, y, z and the angle

    # gather gradients from builtin adjoints
    builtin = [compute_gradients(quats, builtin_kernel, num_rand, c) for c in components]

    # gather gradients from autodiff
    auto = [compute_gradients(quats, manual_kernel, num_rand, c) for c in components]

    # edge cases: gather gradients from builtin adjoints
    edge_builtin = [compute_gradients(edge_cases, builtin_kernel, num_edge, c)[1] for c in components]

    # edge cases: gather gradients from autodiff
    edge_auto = [compute_gradients(edge_cases, manual_kernel, num_edge, c)[1] for c in components]

    eps = {
        np.float16: 2.0e-1,
        np.float32: 2.0e-4,
        np.float64: 2.0e-7,
    }.get(dtype, 0)

    for (loss_builtin, _), (loss_auto, _) in zip(builtin, auto):
        assert_np_equal(loss_builtin, loss_auto, tol=eps)

    for (_, grad_builtin), (_, grad_auto) in zip(builtin, auto):
        assert_np_equal(grad_builtin, grad_auto, tol=eps)

    for grad_builtin, grad_auto in zip(edge_builtin, edge_auto):
        assert_np_equal(grad_builtin, grad_auto, tol=eps)
1645
+
1646
+
1647
+ ############################################################
1648
+
1649
+
1650
def test_quat_rpy_grad(test, device, dtype, register_kernels=False):
    """Compare the built-in ``wp.quat_rpy`` adjoint with autodiff of a manual version.

    ``coord_idx`` selects a component of the resulting quaternion (x, y, z, w),
    not a roll/pitch/yaw angle. All four components are compared so that the
    adjoint contribution of the scalar part ``w`` is covered as well.
    """
    rng = np.random.default_rng(123)
    N = 3

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]

    vec3 = wp._src.types.vector(3, wptype)
    quat = wp._src.types.quaternion(wptype)

    def rpy_to_quat_kernel(rpy_arr: wp.array(dtype=vec3), loss: wp.array(dtype=wptype), coord_idx: int):
        tid = wp.tid()
        rpy = rpy_arr[tid]
        roll = rpy[0]
        pitch = rpy[1]
        yaw = rpy[2]

        q = wp.quat_rpy(roll, pitch, yaw)

        wp.atomic_add(loss, 0, q[coord_idx])

    builtin_kernel = getkernel(rpy_to_quat_kernel, suffix=dtype.__name__)

    def rpy_to_quat_kernel_forward(rpy_arr: wp.array(dtype=vec3), loss: wp.array(dtype=wptype), coord_idx: int):
        tid = wp.tid()
        rpy = rpy_arr[tid]
        roll = rpy[0]
        pitch = rpy[1]
        yaw = rpy[2]

        # half-angle sines and cosines
        cy = wp.cos(yaw * wptype(0.5))
        sy = wp.sin(yaw * wptype(0.5))
        cr = wp.cos(roll * wptype(0.5))
        sr = wp.sin(roll * wptype(0.5))
        cp = wp.cos(pitch * wptype(0.5))
        sp = wp.sin(pitch * wptype(0.5))

        w = cy * cr * cp + sy * sr * sp
        x = cy * sr * cp - sy * cr * sp
        y = cy * cr * sp + sy * sr * cp
        z = sy * cr * cp - cy * sr * sp

        q = quat(x, y, z, w)

        wp.atomic_add(loss, 0, q[coord_idx])

    manual_kernel = getkernel(rpy_to_quat_kernel_forward, suffix=dtype.__name__)

    if register_kernels:
        return

    rpy_np = rng.uniform(low=-np.pi, high=np.pi, size=(N, 3))
    rpy_arr = wp.array(rpy_np, dtype=vec3, device=device, requires_grad=True)

    def compute_gradients(kernel, wrt, index):
        # returns (loss value, gradient of the loss with respect to ``wrt``)
        loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
        tape = wp.Tape()
        with tape:
            wp.launch(kernel=kernel, dim=N, inputs=[wrt, loss, index], device=device)

        tape.backward(loss)

        gradients = 1.0 * tape.gradients[wrt].numpy()
        tape.zero()

        return loss.numpy()[0], gradients

    # quaternion components x, y, z, w; the original test stopped at z
    components = range(4)

    # wrt rpy
    # gather gradients from builtin adjoints
    builtin = [compute_gradients(builtin_kernel, rpy_arr, c) for c in components]

    # gather gradients from autodiff
    auto = [compute_gradients(manual_kernel, rpy_arr, c) for c in components]

    eps = {
        np.float16: 2.0e-2,
        np.float32: 1.0e-5,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    for (loss_builtin, _), (loss_auto, _) in zip(builtin, auto):
        assert_np_equal(loss_builtin, loss_auto, tol=eps)

    for (_, grad_builtin), (_, grad_auto) in zip(builtin, auto):
        assert_np_equal(grad_builtin, grad_auto, tol=eps)
1740
+
1741
+
1742
+ ############################################################
1743
+
1744
+
1745
def test_quat_from_matrix(test, device, dtype, register_kernels=False):
    """Compare the ``wp.quat_from_matrix`` adjoint with autodiff of a manual conversion.

    The built-in kernel also checks that the 3x3 and 4x4 overloads return the
    same quaternion. Each row of the input array holds one flattened,
    row-major 3x3 matrix.
    """
    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat33 = wp._src.types.matrix((3, 3), wptype)
    mat44 = wp._src.types.matrix((4, 4), wptype)
    quat = wp._src.types.quaternion(wptype)

    def quat_from_matrix(m: wp.array2d(dtype=wptype), loss: wp.array(dtype=wptype), idx: int):
        tid = wp.tid()

        # fmt: off
        m3 = mat33(
            m[tid, 0], m[tid, 1], m[tid, 2],
            m[tid, 3], m[tid, 4], m[tid, 5],
            m[tid, 6], m[tid, 7], m[tid, 8],
        )
        q1 = wp.quat_from_matrix(m3)

        m4 = mat44(
            m[tid, 0], m[tid, 1], m[tid, 2], wptype(0.0),
            m[tid, 3], m[tid, 4], m[tid, 5], wptype(0.0),
            m[tid, 6], m[tid, 7], m[tid, 8], wptype(0.0),
            wptype(0.0), wptype(0.0), wptype(0.0), wptype(1.0),
        )
        q2 = wp.quat_from_matrix(m4)
        # fmt: on

        wp.expect_eq(q1, q2)
        wp.atomic_add(loss, 0, q1[idx])

    def quat_from_matrix_forward(mats: wp.array2d(dtype=wptype), loss: wp.array(dtype=wptype), idx: int):
        tid = wp.tid()

        m = mat33(
            mats[tid, 0],
            mats[tid, 1],
            mats[tid, 2],
            mats[tid, 3],
            mats[tid, 4],
            mats[tid, 5],
            mats[tid, 6],
            mats[tid, 7],
            mats[tid, 8],
        )

        tr = m[0][0] + m[1][1] + m[2][2]
        x = wptype(0)
        y = wptype(0)
        z = wptype(0)
        w = wptype(0)
        h = wptype(0)

        if tr >= wptype(0):
            h = wp.sqrt(tr + wptype(1))
            w = wptype(0.5) * h
            h = wptype(0.5) / h

            x = (m[2][1] - m[1][2]) * h
            y = (m[0][2] - m[2][0]) * h
            z = (m[1][0] - m[0][1]) * h
        else:
            # negative trace: branch on the largest diagonal entry
            max_diag = 0
            if m[1][1] > m[0][0]:
                max_diag = 1
            if m[2][2] > m[max_diag][max_diag]:
                max_diag = 2

            if max_diag == 0:
                h = wp.sqrt((m[0][0] - (m[1][1] + m[2][2])) + wptype(1))
                x = wptype(0.5) * h
                h = wptype(0.5) / h

                y = (m[0][1] + m[1][0]) * h
                z = (m[2][0] + m[0][2]) * h
                w = (m[2][1] - m[1][2]) * h
            elif max_diag == 1:
                h = wp.sqrt((m[1][1] - (m[2][2] + m[0][0])) + wptype(1))
                y = wptype(0.5) * h
                h = wptype(0.5) / h

                z = (m[1][2] + m[2][1]) * h
                x = (m[0][1] + m[1][0]) * h
                w = (m[0][2] - m[2][0]) * h
            elif max_diag == 2:
                # was a bare `if`; the three cases are mutually exclusive
                h = wp.sqrt((m[2][2] - (m[0][0] + m[1][1])) + wptype(1))
                z = wptype(0.5) * h
                h = wptype(0.5) / h

                x = (m[2][0] + m[0][2]) * h
                y = (m[1][2] + m[2][1]) * h
                w = (m[1][0] - m[0][1]) * h

        q = wp.normalize(quat(x, y, z, w))

        wp.atomic_add(loss, 0, q[idx])

    quat_from_matrix = getkernel(quat_from_matrix, suffix=dtype.__name__)
    quat_from_matrix_forward = getkernel(quat_from_matrix_forward, suffix=dtype.__name__)

    if register_kernels:
        return

    # The first three rows have a positive trace. The last row has a negative
    # trace with m[1][1] as its largest diagonal entry, so it takes the
    # max_diag == 1 branch of the manual kernel.
    m = np.array(
        [
            [1.0, 0.0, 0.0, 0.0, 0.5, 0.866, 0.0, -0.866, 0.5],
            [0.866, 0.0, 0.25, -0.433, 0.5, 0.75, -0.25, -0.866, 0.433],
            [0.866, -0.433, 0.25, 0.0, 0.5, 0.866, -0.5, -0.75, 0.433],
            [-1.2, -1.6, -2.3, 0.25, -0.6, -0.33, 3.2, -1.0, -2.2],
        ]
    )
    m = wp.array2d(m, dtype=wptype, device=device, requires_grad=True)

    N = m.shape[0]

    def compute_gradients(kernel, wrt, index):
        # returns (loss value, gradient of the loss with respect to ``wrt``)
        loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
        tape = wp.Tape()

        with tape:
            # launch with the array being differentiated instead of the captured `m`
            wp.launch(kernel=kernel, dim=N, inputs=[wrt, loss, index], device=device)

        tape.backward(loss)

        gradients = 1.0 * tape.gradients[wrt].numpy()
        tape.zero()

        return loss.numpy()[0], gradients

    components = range(4)  # x, y, z, w

    # gather gradients from builtin adjoints
    builtin = [compute_gradients(quat_from_matrix, m, c) for c in components]

    # gather gradients from autodiff
    auto = [compute_gradients(quat_from_matrix_forward, m, c) for c in components]

    # compare
    eps = {
        np.float16: 2.0e-2,
        np.float32: 1.0e-5,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    for (loss_builtin, _), (loss_auto, _) in zip(builtin, auto):
        assert_np_equal(loss_builtin, loss_auto, tol=eps)

    for (_, grad_builtin), (_, grad_auto) in zip(builtin, auto):
        assert_np_equal(grad_builtin, grad_auto, tol=eps)
1902
+
1903
+
1904
def test_quat_identity(test, device, dtype, register_kernels=False):
    """Check that ``wp.quat_identity`` yields (0, 0, 0, 1), with and without an explicit dtype."""
    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def quat_identity_test(output: wp.array(dtype=wptype)):
        ident = wp.quat_identity(dtype=wptype)
        output[0] = ident[0]
        output[1] = ident[1]
        output[2] = ident[2]
        output[3] = ident[3]

    def quat_identity_test_default(output: wp.array(dtype=wp.float32)):
        ident = wp.quat_identity()
        output[0] = ident[0]
        output[1] = ident[1]
        output[2] = ident[2]
        output[3] = ident[3]

    quat_identity_kernel = getkernel(quat_identity_test, suffix=dtype.__name__)
    quat_identity_default_kernel = getkernel(quat_identity_test_default, suffix=np.float32.__name__)

    if register_kernels:
        return

    # first the explicitly typed kernel, then the one that should default to float32
    cases = (
        (quat_identity_kernel, wptype),
        (quat_identity_default_kernel, wp.float32),
    )
    for kernel, out_type in cases:
        output = wp.zeros(4, dtype=out_type, device=device)
        wp.launch(kernel, dim=1, inputs=[], outputs=[output], device=device)

        expected = np.zeros_like(output.numpy())
        expected[3] = 1
        assert_np_equal(output.numpy(), expected)
1939
+
1940
+
1941
def test_quat_euler_conversion(test, device, dtype, register_kernels=False):
    """Check that ``quat_from_euler`` with axis order (0, 1, 2) agrees with ``wp.quat_rpy``."""
    rng = np.random.default_rng(123)
    N = 3

    rpy_arr = rng.uniform(low=-np.pi, high=np.pi, size=(N, 3))

    quats_from_euler = []
    for rpy in rpy_arr:
        quats_from_euler.append(list(quat_from_euler(wp.vec3(*rpy), 0, 1, 2)))

    quats_from_rpy = []
    for roll, pitch, yaw in rpy_arr:
        quats_from_rpy.append(list(wp.quat_rpy(roll, pitch, yaw)))

    assert_np_equal(np.array(quats_from_euler), np.array(quats_from_rpy), tol=1e-4)
1951
+
1952
+
1953
def test_anon_type_instance(test, device, dtype, register_kernels=False):
    """Check both ``wp.quaternion`` constructors when the dtype is inferred from the arguments.

    Every output is twice the matching input, so the gradient of output ``i``
    with respect to the inputs is 2 at position ``i`` and 0 elsewhere.
    """
    rng = np.random.default_rng(123)
    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def quat_create_test(input: wp.array(dtype=wptype), output: wp.array(dtype=wptype)):
        # component constructor:
        from_components = wp.quaternion(input[0], input[1], input[2], input[3])
        output[0] = wptype(2) * from_components[0]
        output[1] = wptype(2) * from_components[1]
        output[2] = wptype(2) * from_components[2]
        output[3] = wptype(2) * from_components[3]

        # vector / scalar constructor:
        from_vec_scalar = wp.quaternion(wp.vector(input[4], input[5], input[6]), input[7])
        output[4] = wptype(2) * from_vec_scalar[0]
        output[5] = wptype(2) * from_vec_scalar[1]
        output[6] = wptype(2) * from_vec_scalar[2]
        output[7] = wptype(2) * from_vec_scalar[3]

    quat_create_kernel = getkernel(quat_create_test, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel(wptype)

    if register_kernels:
        return

    src = wp.array(rng.standard_normal(size=8).astype(dtype), requires_grad=True, device=device)
    dst = wp.zeros(8, dtype=wptype, requires_grad=True, device=device)
    wp.launch(quat_create_kernel, dim=1, inputs=[src], outputs=[dst], device=device)
    assert_np_equal(dst.numpy(), 2 * src.numpy())

    count = len(src)
    for i in range(count):
        cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        tape = wp.Tape()
        with tape:
            wp.launch(quat_create_kernel, dim=1, inputs=[src], outputs=[dst], device=device)
            wp.launch(output_select_kernel, dim=1, inputs=[dst, i], outputs=[cmp], device=device)
        tape.backward(loss=cmp)

        expectedgrads = np.zeros(count)
        expectedgrads[i] = 2
        assert_np_equal(tape.gradients[src].numpy(), expectedgrads)
        tape.zero()
1994
+
1995
+
1996
+ # Same as above but with a default (float) type
1997
+ # which tests some different code paths that
1998
+ # need to ensure types are correctly canonicalized
1999
+ # during codegen
2000
@wp.kernel
def test_constructor_default():
    # default constructor: every component is zero
    zero_q = wp.quat()
    wp.expect_eq(zero_q[0], 0.0)
    wp.expect_eq(zero_q[1], 0.0)
    wp.expect_eq(zero_q[2], 0.0)
    wp.expect_eq(zero_q[3], 0.0)

    # component-wise constructor
    value_q = wp.quat(1.0, 2.0, 3.0, 4.0)
    wp.expect_eq(value_q[0], 1.0)
    wp.expect_eq(value_q[1], 2.0)
    wp.expect_eq(value_q[2], 3.0)
    wp.expect_eq(value_q[3], 4.0)

    # identity: zero vector part, unit scalar part
    identity_q = wp.quat_identity()
    wp.expect_eq(identity_q[0], 0.0)
    wp.expect_eq(identity_q[1], 0.0)
    wp.expect_eq(identity_q[2], 0.0)
    wp.expect_eq(identity_q[3], 1.0)

    # generic constructor with an explicit float dtype
    literal_q = wp.quaternion(1.0, 2.0, 3.0, 4.0, dtype=float)
    wp.expect_eq(literal_q[0], 1.0)
    wp.expect_eq(literal_q[1], 2.0)
    wp.expect_eq(literal_q[2], 3.0)
    wp.expect_eq(literal_q[3], 4.0)
2025
+
2026
+
2027
def test_py_arithmetic_ops(test, device, dtype):
    """Check unary, additive and scalar arithmetic on quaternions in Python scope."""
    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def make_quat(*components):
        if wptype not in wp._src.types.int_types:
            return components

        # Cast to the correct integer type to simulate wrapping.
        return tuple(wptype._type_(c).value for c in components)

    quat_cls = wp._src.types.quaternion(wptype)

    # unary plus/minus and quaternion addition/subtraction
    signed = quat_cls(1, -2, 3, -4)
    test.assertSequenceEqual(+signed, make_quat(1, -2, 3, -4))
    test.assertSequenceEqual(-signed, make_quat(-1, 2, -3, 4))
    test.assertSequenceEqual(signed + quat_cls(5, 5, 5, 5), make_quat(6, 3, 8, 1))
    test.assertSequenceEqual(signed - quat_cls(5, 5, 5, 5), make_quat(-4, -7, -2, -9))

    # scalar multiplication and division, from either side
    evens = quat_cls(2, 4, 6, 8)
    test.assertSequenceEqual(evens * wptype(2), make_quat(4, 8, 12, 16))
    test.assertSequenceEqual(wptype(2) * evens, make_quat(4, 8, 12, 16))
    test.assertSequenceEqual(evens / wptype(2), make_quat(1, 2, 3, 4))
    test.assertSequenceEqual(wptype(24) / evens, make_quat(12, 6, 4, 3))
2050
+
2051
+
2052
@wp.kernel
def quat_grad(q: wp.quat):
    # the quaternion arrives by value; its scalar part is expected to be 1
    wp.expect_eq(q.w, 1.0)
2055
+
2056
+
2057
+ # Test passing of a quaternion in the backward pass
2058
def test_quat_backward(test, device):
    """Record a launch that takes a quaternion by value and run the tape backward."""
    identity = wp.quat_identity()

    tape = wp.Tape()
    with tape:
        wp.launch(quat_grad, dim=1, inputs=[identity], device=device)

    tape.backward()
2066
+
2067
+
2068
@wp.kernel
def quat_len_kernel(
    q: wp.quat,
    out: wp.array(dtype=int),
):
    # Uses len() on a quaternion in three positions: assigned to a local, passed
    # to wp.expect_eq, and stored into an array element. `length` is never read;
    # the assignments exist so that this usage is compiled.

    # len() of the kernel argument, wrapped in wp.static; result goes to out[0].
    length = wp.static(len(q))
    wp.expect_eq(wp.static(len(q)), 4)
    out[0] = wp.static(len(q))

    # len() of a quaternion constructed inside the kernel, called directly;
    # result goes to out[1].
    foo = wp.quat()
    length = len(foo)
    wp.expect_eq(len(foo), 4)
    out[1] = len(foo)
2081
+
2082
+
2083
def test_quat_len(test, device):
    """Launch ``quat_len_kernel`` once and check that both recorded lengths equal 4."""
    lengths = wp.empty(2, dtype=int, device=device)
    wp.launch(quat_len_kernel, dim=(1,), inputs=(wp.quat(),), outputs=(lengths,), device=device)

    # Slot 0 and slot 1 are written by the two halves of the kernel.
    for slot in (0, 1):
        test.assertEqual(lengths.numpy()[slot], 4)
2090
+
2091
+
2092
@wp.kernel
def quat_extract_subscript(x: wp.array(dtype=wp.quat), y: wp.array(dtype=float)):
    # Reads the four components of x[tid] by integer subscript and writes their
    # weighted sum (weights 1, 2, 3, 4) to y[tid].
    tid = wp.tid()

    a = x[tid]
    b = a[0] + 2.0 * a[1] + 3.0 * a[2] + 4.0 * a[3]
    y[tid] = b
2099
+
2100
+
2101
@wp.kernel
def quat_extract_attribute(x: wp.array(dtype=wp.quat), y: wp.array(dtype=float)):
    # Attribute-access counterpart of quat_extract_subscript: reads .x/.y/.z/.w
    # of x[tid] and writes the weighted sum (weights 1, 2, 3, 4) to y[tid].
    tid = wp.tid()

    a = x[tid]
    # The second weight is spelled float(2.0) rather than a bare literal.
    b = a.x + float(2.0) * a.y + 3.0 * a.z + 4.0 * a.w
    y[tid] = b
2108
+
2109
+
2110
def test_quat_extract(test, device):
    """Check component reads from a quaternion, forward and backward.

    The subscript and attribute kernels are each launched on one all-ones quaternion. Both
    must output the weighted sum 1 + 2 + 3 + 4 = 10, and the weights must appear as the
    gradient of the input quaternion.
    """
    expected_sum = np.array([10.0], dtype=float)
    expected_grad = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=float)

    for kernel in (quat_extract_subscript, quat_extract_attribute):
        quats = wp.ones(1, dtype=wp.quat, requires_grad=True, device=device)
        sums = wp.zeros(1, dtype=float, requires_grad=True, device=device)

        tape = wp.Tape()
        with tape:
            wp.launch(kernel, dim=1, inputs=[quats], outputs=[sums], device=device)

        # Seed the output adjoint with ones before replaying the tape.
        sums.grad = wp.ones_like(sums)
        tape.backward()

        assert_np_equal(sums.numpy(), expected_sum)
        assert_np_equal(quats.grad.numpy(), expected_grad)
2127
+
2128
+
2129
@wp.kernel
def quat_assign_subscript(x: wp.array(dtype=float), y: wp.array(dtype=wp.quat)):
    # Builds a quaternion by assigning each component through an integer
    # subscript (component k receives (k + 1) * x[i]) and stores it in y[i].
    i = wp.tid()

    a = wp.quat()
    a[0] = 1.0 * x[i]
    a[1] = 2.0 * x[i]
    a[2] = 3.0 * x[i]
    a[3] = 4.0 * x[i]
    y[i] = a
2139
+
2140
+
2141
@wp.kernel
def quat_assign_attribute(x: wp.array(dtype=float), y: wp.array(dtype=wp.quat)):
    # Attribute-access counterpart of quat_assign_subscript: assigns
    # .x/.y/.z/.w with 1x, 2x, 3x, 4x of x[i] and stores the result in y[i].
    i = wp.tid()

    a = wp.quat()
    a.x = 1.0 * x[i]
    a.y = 2.0 * x[i]
    a.z = 3.0 * x[i]
    a.w = 4.0 * x[i]
    y[i] = a
2151
+
2152
+
2153
def test_quat_assign(test, device):
    """Check component writes into a local quaternion, forward and backward.

    With a scalar input of 1 the subscript and attribute kernels must both produce the
    quaternion (1, 2, 3, 4); with an all-ones output adjoint the input gradient is the sum
    of the four weights, 10.
    """
    expected_quat = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=float)
    expected_grad = np.array([10.0], dtype=float)

    for kernel in (quat_assign_subscript, quat_assign_attribute):
        scalars = wp.ones(1, dtype=float, requires_grad=True, device=device)
        quats = wp.zeros(1, dtype=wp.quat, requires_grad=True, device=device)

        tape = wp.Tape()
        with tape:
            wp.launch(kernel, dim=1, inputs=[scalars], outputs=[quats], device=device)

        # Seed the output adjoint with ones before replaying the tape.
        quats.grad = wp.ones_like(quats)
        tape.backward()

        assert_np_equal(quats.numpy(), expected_quat)
        assert_np_equal(scalars.grad.numpy(), expected_grad)
2170
+
2171
+
2172
@wp.kernel
def quat_array_extract_subscript(x: wp.array2d(dtype=wp.quat), y: wp.array2d(dtype=float)):
    # Reads each component straight from the 2D array element (x[i, j][k]),
    # without first copying the quaternion into a local, and writes the
    # weighted sum (weights 1, 2, 3, 4) to y[i, j].
    i, j = wp.tid()
    a = x[i, j][0]
    b = x[i, j][1]
    c = x[i, j][2]
    d = x[i, j][3]
    y[i, j] = 1.0 * a + 2.0 * b + 3.0 * c + 4.0 * d
2180
+
2181
+
2182
@wp.kernel
def quat_array_extract_attribute(x: wp.array2d(dtype=wp.quat), y: wp.array2d(dtype=float)):
    # Attribute-access counterpart of quat_array_extract_subscript: reads
    # x[i, j].x/.y/.z/.w directly from the array element and writes the
    # weighted sum (weights 1, 2, 3, 4) to y[i, j].
    i, j = wp.tid()
    a = x[i, j].x
    b = x[i, j].y
    c = x[i, j].z
    d = x[i, j].w
    y[i, j] = 1.0 * a + 2.0 * b + 3.0 * c + 4.0 * d
2190
+
2191
+
2192
def test_quat_array_extract(test, device):
    """Check component reads taken directly from a 2D quaternion array element.

    Each kernel runs on a 1x1 array holding an all-ones quaternion. The forward result must
    be 10 and the gradient of the input must be the weights (1, 2, 3, 4).
    """
    expected_sum = np.array([[10.0]], dtype=float)
    expected_grad = np.array([[[1.0, 2.0, 3.0, 4.0]]], dtype=float)

    for kernel in (quat_array_extract_subscript, quat_array_extract_attribute):
        quats = wp.ones((1, 1), dtype=wp.quat, requires_grad=True, device=device)
        sums = wp.zeros((1, 1), dtype=float, requires_grad=True, device=device)

        tape = wp.Tape()
        with tape:
            wp.launch(kernel, dim=(1, 1), inputs=[quats], outputs=[sums], device=device)

        # Seed the output adjoint with ones before replaying the tape.
        sums.grad = wp.ones_like(sums)
        tape.backward()

        assert_np_equal(sums.numpy(), expected_sum)
        assert_np_equal(quats.grad.numpy(), expected_grad)
2209
+
2210
+
2211
@wp.kernel
def quat_array_assign_subscript(x: wp.array2d(dtype=float), y: wp.array2d(dtype=wp.quat)):
    # Writes each component of the array element in place (y[i, j][k] = ...),
    # with component k receiving (k + 1) * x[i, j]; no local quaternion is used.
    i, j = wp.tid()

    y[i, j][0] = 1.0 * x[i, j]
    y[i, j][1] = 2.0 * x[i, j]
    y[i, j][2] = 3.0 * x[i, j]
    y[i, j][3] = 4.0 * x[i, j]
2219
+
2220
+
2221
@wp.kernel
def quat_array_assign_attribute(x: wp.array2d(dtype=float), y: wp.array2d(dtype=wp.quat)):
    # Attribute-access counterpart of quat_array_assign_subscript: writes
    # y[i, j].x/.y/.z/.w in place with 1x, 2x, 3x, 4x of x[i, j].
    i, j = wp.tid()

    y[i, j].x = 1.0 * x[i, j]
    y[i, j].y = 2.0 * x[i, j]
    y[i, j].z = 3.0 * x[i, j]
    y[i, j].w = 4.0 * x[i, j]
2229
+
2230
+
2231
def test_quat_array_assign(test, device):
    """Check in-place component writes into a 2D quaternion array element.

    Only the forward result (1, 2, 3, 4) is verified. The tape is still replayed backwards,
    but the input-gradient check is disabled (see the TODO below).
    """
    expected_quat = np.array([[[1.0, 2.0, 3.0, 4.0]]], dtype=float)

    for kernel in (quat_array_assign_subscript, quat_array_assign_attribute):
        scalars = wp.ones((1, 1), dtype=float, requires_grad=True, device=device)
        quats = wp.zeros((1, 1), dtype=wp.quat, requires_grad=True, device=device)

        tape = wp.Tape()
        with tape:
            wp.launch(kernel, dim=(1, 1), inputs=[scalars], outputs=[quats], device=device)

        # Seed the output adjoint with ones before replaying the tape.
        quats.grad = wp.ones_like(quats)
        tape.backward()

        assert_np_equal(quats.numpy(), expected_quat)
        # TODO: gradient propagation for in-place array assignment
        # assert_np_equal(x.grad.numpy(), np.array([[10.0]], dtype=float))
2249
+
2250
+
2251
@wp.kernel
def quat_add_inplace_subscript(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
    # Accumulates into a default-constructed local quaternion with `+=` on
    # subscripted components (component k gains (k + 1) * x[i][k]) and stores
    # the result in y[i].
    i = wp.tid()

    a = wp.quat()
    b = x[i]

    a[0] += 1.0 * b[0]
    a[1] += 2.0 * b[1]
    a[2] += 3.0 * b[2]
    a[3] += 4.0 * b[3]

    y[i] = a
2264
+
2265
+
2266
@wp.kernel
def quat_add_inplace_attribute(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
    # Attribute-access counterpart of quat_add_inplace_subscript: applies `+=`
    # to .x/.y/.z/.w of a default-constructed local quaternion with weights
    # 1, 2, 3, 4 and stores the result in y[i].
    i = wp.tid()

    a = wp.quat()
    b = x[i]

    a.x += 1.0 * b.x
    a.y += 2.0 * b.y
    a.z += 3.0 * b.z
    a.w += 4.0 * b.w

    y[i] = a
2279
+
2280
+
2281
def test_quat_add_inplace(test, device):
    """Check ``+=`` on the components of a local quaternion, forward and backward.

    For an all-ones input quaternion both kernels must output (1, 2, 3, 4), and the same
    weights must appear as the gradient of the input.
    """
    expected_quat = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=float)
    expected_grad = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=float)

    for kernel in (quat_add_inplace_subscript, quat_add_inplace_attribute):
        src = wp.ones(1, dtype=wp.quat, requires_grad=True, device=device)
        dst = wp.zeros(1, dtype=wp.quat, requires_grad=True, device=device)

        tape = wp.Tape()
        with tape:
            wp.launch(kernel, dim=1, inputs=[src], outputs=[dst], device=device)

        # Seed the output adjoint with ones before replaying the tape.
        dst.grad = wp.ones_like(dst)
        tape.backward()

        assert_np_equal(dst.numpy(), expected_quat)
        assert_np_equal(src.grad.numpy(), expected_grad)
2298
+
2299
+
2300
@wp.kernel
def quat_sub_inplace_subscript(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
    # Applies `-=` to subscripted components of a default-constructed local
    # quaternion (component k loses (k + 1) * x[i][k]) and stores the result
    # in y[i].
    i = wp.tid()

    a = wp.quat()
    b = x[i]

    a[0] -= 1.0 * b[0]
    a[1] -= 2.0 * b[1]
    a[2] -= 3.0 * b[2]
    a[3] -= 4.0 * b[3]

    y[i] = a
2313
+
2314
+
2315
@wp.kernel
def quat_sub_inplace_attribute(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
    # Attribute-access counterpart of quat_sub_inplace_subscript: applies `-=`
    # to .x/.y/.z/.w of a default-constructed local quaternion with weights
    # 1, 2, 3, 4 and stores the result in y[i].
    i = wp.tid()

    a = wp.quat()
    b = x[i]

    a.x -= 1.0 * b.x
    a.y -= 2.0 * b.y
    a.z -= 3.0 * b.z
    a.w -= 4.0 * b.w

    y[i] = a
2328
+
2329
+
2330
def test_quat_sub_inplace(test, device):
    """Check ``-=`` on the components of a local quaternion, forward and backward.

    For an all-ones input quaternion both kernels must output (-1, -2, -3, -4), and the same
    negated weights must appear as the gradient of the input.
    """
    expected_quat = np.array([[-1.0, -2.0, -3.0, -4.0]], dtype=float)
    expected_grad = np.array([[-1.0, -2.0, -3.0, -4.0]], dtype=float)

    for kernel in (quat_sub_inplace_subscript, quat_sub_inplace_attribute):
        src = wp.ones(1, dtype=wp.quat, requires_grad=True, device=device)
        dst = wp.zeros(1, dtype=wp.quat, requires_grad=True, device=device)

        tape = wp.Tape()
        with tape:
            wp.launch(kernel, dim=1, inputs=[src], outputs=[dst], device=device)

        # Seed the output adjoint with ones before replaying the tape.
        dst.grad = wp.ones_like(dst)
        tape.backward()

        assert_np_equal(dst.numpy(), expected_quat)
        assert_np_equal(src.grad.numpy(), expected_grad)
2347
+
2348
+
2349
@wp.kernel
def quat_array_add_inplace(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
    # Whole-quaternion `+=` applied directly to an array element.
    i = wp.tid()

    y[i] += x[i]
2354
+
2355
+
2356
def test_quat_array_add_inplace(test, device):
    """Check whole-quaternion ``+=`` on an array element, forward and backward.

    Adding an all-ones quaternion to a zero-initialised element must give all ones, and the
    gradient reaching the added operand must be all ones as well.
    """
    addend = wp.ones(1, dtype=wp.quat, requires_grad=True, device=device)
    accum = wp.zeros(1, dtype=wp.quat, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(quat_array_add_inplace, dim=1, inputs=[addend], outputs=[accum], device=device)

    # Seed the output adjoint with ones before replaying the tape.
    accum.grad = wp.ones_like(accum)
    tape.backward()

    all_ones = np.array([[1.0, 1.0, 1.0, 1.0]], dtype=float)
    assert_np_equal(accum.numpy(), all_ones)
    assert_np_equal(addend.grad.numpy(), np.array([[1.0, 1.0, 1.0, 1.0]], dtype=float))
2369
+
2370
+
2371
@wp.kernel
def quat_array_sub_inplace(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
    # Whole-quaternion `-=` applied directly to an array element.
    i = wp.tid()

    y[i] -= x[i]
2376
+
2377
+
2378
def test_quat_array_sub_inplace(test, device):
    """Check whole-quaternion ``-=`` on an array element, forward and backward.

    Subtracting an all-ones quaternion from a zero-initialised element must give all minus
    ones, and the gradient reaching the subtracted operand must be all minus ones as well.
    """
    subtrahend = wp.ones(1, dtype=wp.quat, requires_grad=True, device=device)
    accum = wp.zeros(1, dtype=wp.quat, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(quat_array_sub_inplace, dim=1, inputs=[subtrahend], outputs=[accum], device=device)

    # Seed the output adjoint with ones before replaying the tape.
    accum.grad = wp.ones_like(accum)
    tape.backward()

    all_minus_ones = np.array([[-1.0, -1.0, -1.0, -1.0]], dtype=float)
    assert_np_equal(accum.numpy(), all_minus_ones)
    assert_np_equal(subtrahend.grad.numpy(), np.array([[-1.0, -1.0, -1.0, -1.0]], dtype=float))
2391
+
2392
+
2393
@wp.kernel
def scalar_quat_div(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
    # Scalar-on-the-left division: a float literal divided by a quaternion.
    i = wp.tid()
    y[i] = 1.0 / x[i]
2397
+
2398
+
2399
def test_scalar_quat_div(test, device):
    """Check ``scalar / quaternion`` inside a kernel, forward and backward.

    For the input (1, 2, 4, 8) the expected output is the component-wise reciprocal
    (1, 0.5, 0.25, 0.125) and the expected input gradient is ``-1 / c**2`` per component.
    """
    divisor = wp.array((wp.quat(1.0, 2.0, 4.0, 8.0),), dtype=wp.quat, requires_grad=True, device=device)
    quotient = wp.ones(1, dtype=wp.quat, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(scalar_quat_div, dim=1, inputs=(divisor,), outputs=(quotient,), device=device)

    # Seed the output adjoint with ones before replaying the tape.
    quotient.grad = wp.ones_like(quotient)
    tape.backward()

    expected_quotient = np.array(((1.0, 0.5, 0.25, 0.125),), dtype=float)
    expected_grad = np.array(((-1.0, -0.25, -0.0625, -0.015625),), dtype=float)
    assert_np_equal(quotient.numpy(), expected_quotient)
    assert_np_equal(divisor.grad.numpy(), expected_grad)
2412
+
2413
+
2414
def test_quat_indexing_assign(test, device):
    """Check plain and augmented assignment to quaternion components by index.

    The checks live in a ``wp.func`` so the same body can be run twice: once
    from a kernel launched on ``device`` and once by calling the function
    directly from this test.
    """

    @wp.func
    def fn():
        q = wp.quat(1.0, 2.0, 3.0, 4.0)

        # Non-negative indices: overwrite component 0, scale component 1.
        q[0] = 123.0
        q[1] *= 2.0

        wp.expect_eq(q[0], 123.0)
        wp.expect_eq(q[1], 4.0)
        wp.expect_eq(q[2], 3.0)
        wp.expect_eq(q[3], 4.0)

        # Negative indices count from the end: -1 is the last component and
        # -2 the third, as the expectations below confirm.
        q[-1] = 123.0
        q[-2] *= 2.0

        wp.expect_eq(q[-1], 123.0)
        wp.expect_eq(q[-2], 6.0)
        wp.expect_eq(q[-3], 4.0)
        wp.expect_eq(q[-4], 123.0)

    @wp.kernel(module="unique")
    def kernel():
        fn()

    # Run the checks inside a kernel on the requested device ...
    wp.launch(kernel, 1, device=device)
    wp.synchronize()
    # ... and again by calling the function outside of any kernel.
    fn()
2442
+
2443
+
2444
def test_quat_slicing_assign(test, device):
    """Check reading and writing quaternion components through slices.

    Slice reads are compared against vectors of the matching length, so vector
    types of length 0 to 4 are created up front. The checks live in a
    ``wp.func`` that is run both from a kernel launched on ``device`` and by a
    direct call from this test.
    """
    vec0 = wp.vec(0, float)
    vec1 = wp.vec(1, float)
    vec2 = wp.vec(2, float)
    vec3 = wp.vec(3, float)
    vec4 = wp.vec(4, float)

    @wp.func
    def fn():
        q = wp.quat(1.0, 2.0, 3.0, 4.0)

        # Reads: full range, out-of-range bounds, and an oversized step.
        wp.expect_eq(q[:] == vec4(1.0, 2.0, 3.0, 4.0), True)
        wp.expect_eq(q[-123:123] == vec4(1.0, 2.0, 3.0, 4.0), True)
        wp.expect_eq(q[123:] == vec0(), True)
        wp.expect_eq(q[:-123] == vec0(), True)
        wp.expect_eq(q[::123] == vec1(1.0), True)

        # Reads: partial ranges, strides, and reversed (negative-step) slices.
        wp.expect_eq(q[1:] == vec3(2.0, 3.0, 4.0), True)
        wp.expect_eq(q[-2:] == vec2(3.0, 4.0), True)
        wp.expect_eq(q[:2] == vec2(1.0, 2.0), True)
        wp.expect_eq(q[:-1] == vec3(1.0, 2.0, 3.0), True)
        wp.expect_eq(q[::2] == vec2(1.0, 3.0), True)
        wp.expect_eq(q[1::2] == vec2(2.0, 4.0), True)
        wp.expect_eq(q[::-1] == vec4(4.0, 3.0, 2.0, 1.0), True)
        wp.expect_eq(q[::-2] == vec2(4.0, 2.0), True)
        wp.expect_eq(q[1::-2] == vec1(2.0), True)

        # Writes: each assignment builds on the state left by the previous
        # one, so the order of the statements below matters.
        q[1:] = vec3(5.0, 6.0, 7.0)
        wp.expect_eq(q == wp.quat(1.0, 5.0, 6.0, 7.0), True)

        q[-2:] = vec2(8.0, 9.0)
        wp.expect_eq(q == wp.quat(1.0, 5.0, 8.0, 9.0), True)

        q[:2] = vec2(10.0, 11.0)
        wp.expect_eq(q == wp.quat(10.0, 11.0, 8.0, 9.0), True)

        q[:-1] = vec3(12.0, 13.0, 14.0)
        wp.expect_eq(q == wp.quat(12.0, 13.0, 14.0, 9.0), True)

        q[::2] = vec2(15.0, 16.0)
        wp.expect_eq(q == wp.quat(15.0, 13.0, 16.0, 9.0), True)

        q[1::2] = vec2(17.0, 18.0)
        wp.expect_eq(q == wp.quat(15.0, 17.0, 16.0, 18.0), True)

        # Negative step starting at index 1 touches only that one component.
        q[1::-2] = vec1(19.0)
        wp.expect_eq(q == wp.quat(15.0, 19.0, 16.0, 18.0), True)

        # Augmented assignment through a slice.
        q[1:] += vec3(20.0, 21.0, 22.0)
        wp.expect_eq(q == wp.quat(15.0, 39.0, 37.0, 40.0), True)

        q[:-1] -= vec3(23.0, 24.0, 25.0)
        wp.expect_eq(q == wp.quat(-8.0, 15.0, 12.0, 40.0), True)

    @wp.kernel(module="unique")
    def kernel():
        fn()

    # Run the checks inside a kernel on the requested device ...
    wp.launch(kernel, 1, device=device)
    wp.synchronize()
    # ... and again by calling the function outside of any kernel.
    fn()
2505
+
2506
+
2507
def test_quat_slicing_assign_backward(test, device):
    """Check forward values and gradients of slice assignment on a quaternion.

    Starting from a zero quaternion and a vec2 input of (1, 1), the kernel
    performs one slice assignment, one slice ``+=`` and one reversed-slice
    ``-=``; the expected output is (1, 2, 0, -1) and the expected input
    gradient is (1, 1).
    """

    @wp.kernel(module="unique")
    def kernel(arr_x: wp.array(dtype=wp.vec2), arr_y: wp.array(dtype=wp.quat)):
        i = wp.tid()

        y = arr_y[i]

        # Components 0 and 1 take the two input values.
        y[:2] = arr_x[i]
        # Components 1 and 2 are incremented by the input values.
        y[1:-1] += arr_x[i][:2]
        # Reversed slice: components 3 and 2 (in that order) are decremented
        # by the input values.
        y[3:1:-1] -= arr_x[i][0:]

        arr_y[i] = y

    x = wp.ones(1, dtype=wp.vec2, requires_grad=True, device=device)
    y = wp.zeros(1, dtype=wp.quat, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, 1, inputs=(x,), outputs=(y,), device=device)

    # Seed the output adjoint with ones before replaying the tape.
    y.grad = wp.ones_like(y)
    tape.backward()

    assert_np_equal(y.numpy(), np.array(((1.0, 2.0, 0.0, -1.0),), dtype=float))
    # Each input value feeds two additions and one subtraction, hence 1.
    assert_np_equal(x.grad.numpy(), np.array(((1.0, 1.0),), dtype=float))
2532
+
2533
+
2534
# Devices used by the registrations below.
devices = get_test_devices()


# Empty container class; its test methods are attached by the add_*_test
# registration calls that follow.
class TestQuat(unittest.TestCase):
    pass
2539
+
2540
+
2541
+ add_kernel_test(TestQuat, test_constructor_default, dim=1, devices=devices)
2542
+ add_kernel_test(TestQuat, test_assignment, dim=1, devices=devices)
2543
+
2544
# Register one copy of every dtype-parameterised test per floating-point type;
# the dtype name is appended to the test name to keep the names unique.
for dtype in np_float_types:
    add_function_test_register_kernel(
        TestQuat, f"test_constructors_{dtype.__name__}", test_constructors, devices=devices, dtype=dtype
    )
    add_function_test_register_kernel(
        TestQuat,
        f"test_casting_constructors_{dtype.__name__}",
        test_casting_constructors,
        devices=devices,
        dtype=dtype,
    )
    add_function_test_register_kernel(
        TestQuat, f"test_anon_type_instance_{dtype.__name__}", test_anon_type_instance, devices=devices, dtype=dtype
    )
    add_function_test_register_kernel(
        TestQuat, f"test_inverse_{dtype.__name__}", test_inverse, devices=devices, dtype=dtype
    )
    add_function_test_register_kernel(
        TestQuat, f"test_quat_identity_{dtype.__name__}", test_quat_identity, devices=devices, dtype=dtype
    )
    add_function_test_register_kernel(
        TestQuat, f"test_dotproduct_{dtype.__name__}", test_dotproduct, devices=devices, dtype=dtype
    )
    add_function_test_register_kernel(
        TestQuat, f"test_length_{dtype.__name__}", test_length, devices=devices, dtype=dtype
    )
    add_function_test_register_kernel(
        TestQuat, f"test_normalize_{dtype.__name__}", test_normalize, devices=devices, dtype=dtype
    )
    add_function_test_register_kernel(
        TestQuat, f"test_addition_{dtype.__name__}", test_addition, devices=devices, dtype=dtype
    )
    add_function_test_register_kernel(
        TestQuat, f"test_subtraction_{dtype.__name__}", test_subtraction, devices=devices, dtype=dtype
    )
    add_function_test_register_kernel(
        TestQuat,
        f"test_scalar_multiplication_{dtype.__name__}",
        test_scalar_multiplication,
        devices=devices,
        dtype=dtype,
    )
    add_function_test_register_kernel(
        TestQuat, f"test_scalar_division_{dtype.__name__}", test_scalar_division, devices=devices, dtype=dtype
    )
    add_function_test_register_kernel(
        TestQuat,
        f"test_quat_multiplication_{dtype.__name__}",
        test_quat_multiplication,
        devices=devices,
        dtype=dtype,
    )
    add_function_test_register_kernel(
        TestQuat, f"test_indexing_{dtype.__name__}", test_indexing, devices=devices, dtype=dtype
    )
    add_function_test_register_kernel(
        TestQuat, f"test_quat_lerp_{dtype.__name__}", test_quat_lerp, devices=devices, dtype=dtype
    )
    add_function_test_register_kernel(
        TestQuat,
        f"test_quat_to_axis_angle_grad_{dtype.__name__}",
        test_quat_to_axis_angle_grad,
        devices=devices,
        dtype=dtype,
    )
    add_function_test_register_kernel(
        TestQuat, f"test_slerp_grad_{dtype.__name__}", test_slerp_grad, devices=devices, dtype=dtype
    )
    add_function_test_register_kernel(
        TestQuat, f"test_quat_rpy_grad_{dtype.__name__}", test_quat_rpy_grad, devices=devices, dtype=dtype
    )
    add_function_test_register_kernel(
        TestQuat, f"test_quat_from_matrix_{dtype.__name__}", test_quat_from_matrix, devices=devices, dtype=dtype
    )
    add_function_test_register_kernel(
        TestQuat, f"test_quat_rotate_{dtype.__name__}", test_quat_rotate, devices=devices, dtype=dtype
    )
    add_function_test_register_kernel(
        TestQuat, f"test_quat_to_matrix_{dtype.__name__}", test_quat_to_matrix, devices=devices, dtype=dtype
    )
    add_function_test_register_kernel(
        TestQuat,
        f"test_quat_euler_conversion_{dtype.__name__}",
        test_quat_euler_conversion,
        devices=devices,
        dtype=dtype,
    )
    # Registered through add_function_test with devices=None: this test never
    # uses its `device` argument, it only exercises Python-scope operators.
    add_function_test(
        TestQuat, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype
    )
2634
+
2635
+ add_function_test(TestQuat, "test_quat_backward", test_quat_backward, devices=devices)
2636
+ add_function_test(TestQuat, "test_quat_len", test_quat_len, devices=devices)
2637
+ add_function_test(TestQuat, "test_quat_extract", test_quat_extract, devices=devices)
2638
+ add_function_test(TestQuat, "test_quat_assign", test_quat_assign, devices=devices)
2639
+ add_function_test(TestQuat, "test_quat_array_extract", test_quat_array_extract, devices=devices)
2640
+ add_function_test(TestQuat, "test_quat_array_assign", test_quat_array_assign, devices=devices)
2641
+ add_function_test(TestQuat, "test_quat_add_inplace", test_quat_add_inplace, devices=devices)
2642
+ add_function_test(TestQuat, "test_quat_sub_inplace", test_quat_sub_inplace, devices=devices)
2643
+ add_function_test(TestQuat, "test_quat_array_add_inplace", test_quat_array_add_inplace, devices=devices)
2644
+ add_function_test(TestQuat, "test_quat_array_sub_inplace", test_quat_array_sub_inplace, devices=devices)
2645
+ add_function_test(TestQuat, "test_scalar_quat_div", test_scalar_quat_div, devices=devices)
2646
+ add_function_test(TestQuat, "test_quat_indexing_assign", test_quat_indexing_assign, devices=devices)
2647
+ add_function_test(TestQuat, "test_quat_slicing_assign", test_quat_slicing_assign, devices=devices)
2648
+ add_function_test(TestQuat, "test_quat_slicing_assign_backward", test_quat_slicing_assign_backward, devices=devices)
2649
+
2650
+
2651
# Script entry point: clear Warp's kernel cache first (presumably so kernels
# are rebuilt rather than reused from an earlier run), then run the tests
# verbosely.
if __name__ == "__main__":
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)