warp-lang 1.0.1__py3-none-macosx_10_13_universal2.whl → 1.1.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (346) hide show
  1. warp/__init__.py +108 -97
  2. warp/__init__.pyi +1 -1
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +115 -113
  6. warp/build_dll.py +383 -375
  7. warp/builtins.py +3425 -3354
  8. warp/codegen.py +2878 -2792
  9. warp/config.py +40 -36
  10. warp/constants.py +45 -45
  11. warp/context.py +5194 -5102
  12. warp/dlpack.py +442 -442
  13. warp/examples/__init__.py +16 -16
  14. warp/examples/assets/bear.usd +0 -0
  15. warp/examples/assets/bunny.usd +0 -0
  16. warp/examples/assets/cartpole.urdf +110 -110
  17. warp/examples/assets/crazyflie.usd +0 -0
  18. warp/examples/assets/cube.usd +0 -0
  19. warp/examples/assets/nv_ant.xml +92 -92
  20. warp/examples/assets/nv_humanoid.xml +183 -183
  21. warp/examples/assets/quadruped.urdf +267 -267
  22. warp/examples/assets/rocks.nvdb +0 -0
  23. warp/examples/assets/rocks.usd +0 -0
  24. warp/examples/assets/sphere.usd +0 -0
  25. warp/examples/benchmarks/benchmark_api.py +383 -383
  26. warp/examples/benchmarks/benchmark_cloth.py +278 -279
  27. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
  28. warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
  29. warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
  30. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
  31. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
  32. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
  33. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -146
  34. warp/examples/benchmarks/benchmark_launches.py +295 -295
  35. warp/examples/browse.py +29 -28
  36. warp/examples/core/example_dem.py +234 -221
  37. warp/examples/core/example_fluid.py +293 -267
  38. warp/examples/core/example_graph_capture.py +144 -129
  39. warp/examples/core/example_marching_cubes.py +188 -176
  40. warp/examples/core/example_mesh.py +174 -154
  41. warp/examples/core/example_mesh_intersect.py +205 -193
  42. warp/examples/core/example_nvdb.py +176 -169
  43. warp/examples/core/example_raycast.py +105 -89
  44. warp/examples/core/example_raymarch.py +199 -178
  45. warp/examples/core/example_render_opengl.py +185 -141
  46. warp/examples/core/example_sph.py +405 -389
  47. warp/examples/core/example_torch.py +222 -181
  48. warp/examples/core/example_wave.py +263 -249
  49. warp/examples/fem/bsr_utils.py +378 -380
  50. warp/examples/fem/example_apic_fluid.py +407 -391
  51. warp/examples/fem/example_convection_diffusion.py +182 -168
  52. warp/examples/fem/example_convection_diffusion_dg.py +219 -209
  53. warp/examples/fem/example_convection_diffusion_dg0.py +204 -194
  54. warp/examples/fem/example_deformed_geometry.py +177 -159
  55. warp/examples/fem/example_diffusion.py +201 -173
  56. warp/examples/fem/example_diffusion_3d.py +177 -152
  57. warp/examples/fem/example_diffusion_mgpu.py +221 -214
  58. warp/examples/fem/example_mixed_elasticity.py +244 -222
  59. warp/examples/fem/example_navier_stokes.py +259 -243
  60. warp/examples/fem/example_stokes.py +220 -192
  61. warp/examples/fem/example_stokes_transfer.py +265 -249
  62. warp/examples/fem/mesh_utils.py +133 -109
  63. warp/examples/fem/plot_utils.py +292 -287
  64. warp/examples/optim/example_bounce.py +260 -248
  65. warp/examples/optim/example_cloth_throw.py +222 -210
  66. warp/examples/optim/example_diffray.py +566 -535
  67. warp/examples/optim/example_drone.py +864 -835
  68. warp/examples/optim/example_inverse_kinematics.py +176 -169
  69. warp/examples/optim/example_inverse_kinematics_torch.py +185 -170
  70. warp/examples/optim/example_spring_cage.py +239 -234
  71. warp/examples/optim/example_trajectory.py +223 -201
  72. warp/examples/optim/example_walker.py +306 -292
  73. warp/examples/sim/example_cartpole.py +139 -128
  74. warp/examples/sim/example_cloth.py +196 -184
  75. warp/examples/sim/example_granular.py +124 -113
  76. warp/examples/sim/example_granular_collision_sdf.py +197 -185
  77. warp/examples/sim/example_jacobian_ik.py +236 -213
  78. warp/examples/sim/example_particle_chain.py +118 -106
  79. warp/examples/sim/example_quadruped.py +193 -179
  80. warp/examples/sim/example_rigid_chain.py +197 -189
  81. warp/examples/sim/example_rigid_contact.py +189 -176
  82. warp/examples/sim/example_rigid_force.py +127 -126
  83. warp/examples/sim/example_rigid_gyroscopic.py +109 -97
  84. warp/examples/sim/example_rigid_soft_contact.py +134 -124
  85. warp/examples/sim/example_soft_body.py +190 -178
  86. warp/fabric.py +337 -335
  87. warp/fem/__init__.py +60 -27
  88. warp/fem/cache.py +401 -388
  89. warp/fem/dirichlet.py +178 -179
  90. warp/fem/domain.py +262 -263
  91. warp/fem/field/__init__.py +100 -101
  92. warp/fem/field/field.py +148 -149
  93. warp/fem/field/nodal_field.py +298 -299
  94. warp/fem/field/restriction.py +22 -21
  95. warp/fem/field/test.py +180 -181
  96. warp/fem/field/trial.py +183 -183
  97. warp/fem/geometry/__init__.py +15 -19
  98. warp/fem/geometry/closest_point.py +69 -70
  99. warp/fem/geometry/deformed_geometry.py +270 -271
  100. warp/fem/geometry/element.py +744 -744
  101. warp/fem/geometry/geometry.py +184 -186
  102. warp/fem/geometry/grid_2d.py +380 -373
  103. warp/fem/geometry/grid_3d.py +441 -435
  104. warp/fem/geometry/hexmesh.py +953 -953
  105. warp/fem/geometry/partition.py +374 -376
  106. warp/fem/geometry/quadmesh_2d.py +532 -532
  107. warp/fem/geometry/tetmesh.py +840 -840
  108. warp/fem/geometry/trimesh_2d.py +577 -577
  109. warp/fem/integrate.py +1630 -1615
  110. warp/fem/operator.py +190 -191
  111. warp/fem/polynomial.py +214 -213
  112. warp/fem/quadrature/__init__.py +2 -2
  113. warp/fem/quadrature/pic_quadrature.py +243 -245
  114. warp/fem/quadrature/quadrature.py +295 -294
  115. warp/fem/space/__init__.py +294 -292
  116. warp/fem/space/basis_space.py +488 -489
  117. warp/fem/space/collocated_function_space.py +100 -105
  118. warp/fem/space/dof_mapper.py +236 -236
  119. warp/fem/space/function_space.py +148 -145
  120. warp/fem/space/grid_2d_function_space.py +267 -267
  121. warp/fem/space/grid_3d_function_space.py +305 -306
  122. warp/fem/space/hexmesh_function_space.py +350 -352
  123. warp/fem/space/partition.py +350 -350
  124. warp/fem/space/quadmesh_2d_function_space.py +368 -369
  125. warp/fem/space/restriction.py +158 -160
  126. warp/fem/space/shape/__init__.py +13 -15
  127. warp/fem/space/shape/cube_shape_function.py +738 -738
  128. warp/fem/space/shape/shape_function.py +102 -103
  129. warp/fem/space/shape/square_shape_function.py +611 -611
  130. warp/fem/space/shape/tet_shape_function.py +565 -567
  131. warp/fem/space/shape/triangle_shape_function.py +429 -429
  132. warp/fem/space/tetmesh_function_space.py +294 -292
  133. warp/fem/space/topology.py +297 -295
  134. warp/fem/space/trimesh_2d_function_space.py +223 -221
  135. warp/fem/types.py +77 -77
  136. warp/fem/utils.py +495 -495
  137. warp/jax.py +166 -141
  138. warp/jax_experimental.py +341 -339
  139. warp/native/array.h +1072 -1025
  140. warp/native/builtin.h +1560 -1560
  141. warp/native/bvh.cpp +398 -398
  142. warp/native/bvh.cu +525 -525
  143. warp/native/bvh.h +429 -429
  144. warp/native/clang/clang.cpp +495 -464
  145. warp/native/crt.cpp +31 -31
  146. warp/native/crt.h +334 -334
  147. warp/native/cuda_crt.h +1049 -1049
  148. warp/native/cuda_util.cpp +549 -540
  149. warp/native/cuda_util.h +288 -203
  150. warp/native/cutlass_gemm.cpp +34 -34
  151. warp/native/cutlass_gemm.cu +372 -372
  152. warp/native/error.cpp +66 -66
  153. warp/native/error.h +27 -27
  154. warp/native/fabric.h +228 -228
  155. warp/native/hashgrid.cpp +301 -278
  156. warp/native/hashgrid.cu +78 -77
  157. warp/native/hashgrid.h +227 -227
  158. warp/native/initializer_array.h +32 -32
  159. warp/native/intersect.h +1204 -1204
  160. warp/native/intersect_adj.h +365 -365
  161. warp/native/intersect_tri.h +322 -322
  162. warp/native/marching.cpp +2 -2
  163. warp/native/marching.cu +497 -497
  164. warp/native/marching.h +2 -2
  165. warp/native/mat.h +1498 -1498
  166. warp/native/matnn.h +333 -333
  167. warp/native/mesh.cpp +203 -203
  168. warp/native/mesh.cu +293 -293
  169. warp/native/mesh.h +1887 -1887
  170. warp/native/nanovdb/NanoVDB.h +4782 -4782
  171. warp/native/nanovdb/PNanoVDB.h +2553 -2553
  172. warp/native/nanovdb/PNanoVDBWrite.h +294 -294
  173. warp/native/noise.h +850 -850
  174. warp/native/quat.h +1084 -1084
  175. warp/native/rand.h +299 -299
  176. warp/native/range.h +108 -108
  177. warp/native/reduce.cpp +156 -156
  178. warp/native/reduce.cu +348 -348
  179. warp/native/runlength_encode.cpp +61 -61
  180. warp/native/runlength_encode.cu +46 -46
  181. warp/native/scan.cpp +30 -30
  182. warp/native/scan.cu +36 -36
  183. warp/native/scan.h +7 -7
  184. warp/native/solid_angle.h +442 -442
  185. warp/native/sort.cpp +94 -94
  186. warp/native/sort.cu +97 -97
  187. warp/native/sort.h +14 -14
  188. warp/native/sparse.cpp +337 -337
  189. warp/native/sparse.cu +544 -544
  190. warp/native/spatial.h +630 -630
  191. warp/native/svd.h +562 -562
  192. warp/native/temp_buffer.h +30 -30
  193. warp/native/vec.h +1132 -1132
  194. warp/native/volume.cpp +297 -297
  195. warp/native/volume.cu +32 -32
  196. warp/native/volume.h +538 -538
  197. warp/native/volume_builder.cu +425 -425
  198. warp/native/volume_builder.h +19 -19
  199. warp/native/warp.cpp +1057 -1052
  200. warp/native/warp.cu +2943 -2828
  201. warp/native/warp.h +313 -305
  202. warp/optim/__init__.py +9 -9
  203. warp/optim/adam.py +120 -120
  204. warp/optim/linear.py +1104 -939
  205. warp/optim/sgd.py +104 -92
  206. warp/render/__init__.py +10 -10
  207. warp/render/render_opengl.py +3217 -3204
  208. warp/render/render_usd.py +768 -749
  209. warp/render/utils.py +152 -150
  210. warp/sim/__init__.py +52 -59
  211. warp/sim/articulation.py +685 -685
  212. warp/sim/collide.py +1594 -1590
  213. warp/sim/import_mjcf.py +489 -481
  214. warp/sim/import_snu.py +220 -221
  215. warp/sim/import_urdf.py +536 -516
  216. warp/sim/import_usd.py +887 -881
  217. warp/sim/inertia.py +316 -317
  218. warp/sim/integrator.py +234 -233
  219. warp/sim/integrator_euler.py +1956 -1956
  220. warp/sim/integrator_featherstone.py +1910 -1991
  221. warp/sim/integrator_xpbd.py +3294 -3312
  222. warp/sim/model.py +4473 -4314
  223. warp/sim/particles.py +113 -112
  224. warp/sim/render.py +417 -403
  225. warp/sim/utils.py +413 -410
  226. warp/sparse.py +1227 -1227
  227. warp/stubs.py +2109 -2469
  228. warp/tape.py +1162 -225
  229. warp/tests/__init__.py +1 -1
  230. warp/tests/__main__.py +4 -4
  231. warp/tests/assets/torus.usda +105 -105
  232. warp/tests/aux_test_class_kernel.py +26 -26
  233. warp/tests/aux_test_compile_consts_dummy.py +10 -10
  234. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
  235. warp/tests/aux_test_dependent.py +22 -22
  236. warp/tests/aux_test_grad_customs.py +23 -23
  237. warp/tests/aux_test_reference.py +11 -11
  238. warp/tests/aux_test_reference_reference.py +10 -10
  239. warp/tests/aux_test_square.py +17 -17
  240. warp/tests/aux_test_unresolved_func.py +14 -14
  241. warp/tests/aux_test_unresolved_symbol.py +14 -14
  242. warp/tests/disabled_kinematics.py +239 -239
  243. warp/tests/run_coverage_serial.py +31 -31
  244. warp/tests/test_adam.py +157 -157
  245. warp/tests/test_arithmetic.py +1124 -1124
  246. warp/tests/test_array.py +2417 -2326
  247. warp/tests/test_array_reduce.py +150 -150
  248. warp/tests/test_async.py +668 -656
  249. warp/tests/test_atomic.py +141 -141
  250. warp/tests/test_bool.py +204 -149
  251. warp/tests/test_builtins_resolution.py +1292 -1292
  252. warp/tests/test_bvh.py +164 -171
  253. warp/tests/test_closest_point_edge_edge.py +228 -228
  254. warp/tests/test_codegen.py +566 -553
  255. warp/tests/test_compile_consts.py +97 -101
  256. warp/tests/test_conditional.py +246 -246
  257. warp/tests/test_copy.py +232 -215
  258. warp/tests/test_ctypes.py +632 -632
  259. warp/tests/test_dense.py +67 -67
  260. warp/tests/test_devices.py +91 -98
  261. warp/tests/test_dlpack.py +530 -529
  262. warp/tests/test_examples.py +400 -378
  263. warp/tests/test_fabricarray.py +955 -955
  264. warp/tests/test_fast_math.py +62 -54
  265. warp/tests/test_fem.py +1277 -1278
  266. warp/tests/test_fp16.py +130 -130
  267. warp/tests/test_func.py +338 -337
  268. warp/tests/test_generics.py +571 -571
  269. warp/tests/test_grad.py +746 -640
  270. warp/tests/test_grad_customs.py +333 -336
  271. warp/tests/test_hash_grid.py +210 -164
  272. warp/tests/test_import.py +39 -39
  273. warp/tests/test_indexedarray.py +1134 -1134
  274. warp/tests/test_intersect.py +67 -67
  275. warp/tests/test_jax.py +307 -307
  276. warp/tests/test_large.py +167 -164
  277. warp/tests/test_launch.py +354 -354
  278. warp/tests/test_lerp.py +261 -261
  279. warp/tests/test_linear_solvers.py +191 -171
  280. warp/tests/test_lvalue.py +421 -493
  281. warp/tests/test_marching_cubes.py +65 -65
  282. warp/tests/test_mat.py +1801 -1827
  283. warp/tests/test_mat_lite.py +115 -115
  284. warp/tests/test_mat_scalar_ops.py +2907 -2889
  285. warp/tests/test_math.py +126 -193
  286. warp/tests/test_matmul.py +500 -499
  287. warp/tests/test_matmul_lite.py +410 -410
  288. warp/tests/test_mempool.py +188 -190
  289. warp/tests/test_mesh.py +284 -324
  290. warp/tests/test_mesh_query_aabb.py +228 -241
  291. warp/tests/test_mesh_query_point.py +692 -702
  292. warp/tests/test_mesh_query_ray.py +292 -303
  293. warp/tests/test_mlp.py +276 -276
  294. warp/tests/test_model.py +110 -110
  295. warp/tests/test_modules_lite.py +39 -39
  296. warp/tests/test_multigpu.py +163 -163
  297. warp/tests/test_noise.py +248 -248
  298. warp/tests/test_operators.py +250 -250
  299. warp/tests/test_options.py +123 -125
  300. warp/tests/test_peer.py +133 -137
  301. warp/tests/test_pinned.py +78 -78
  302. warp/tests/test_print.py +54 -54
  303. warp/tests/test_quat.py +2086 -2086
  304. warp/tests/test_rand.py +288 -288
  305. warp/tests/test_reload.py +217 -217
  306. warp/tests/test_rounding.py +179 -179
  307. warp/tests/test_runlength_encode.py +190 -190
  308. warp/tests/test_sim_grad.py +243 -0
  309. warp/tests/test_sim_kinematics.py +91 -97
  310. warp/tests/test_smoothstep.py +168 -168
  311. warp/tests/test_snippet.py +305 -266
  312. warp/tests/test_sparse.py +468 -460
  313. warp/tests/test_spatial.py +2148 -2148
  314. warp/tests/test_streams.py +486 -473
  315. warp/tests/test_struct.py +710 -675
  316. warp/tests/test_tape.py +173 -148
  317. warp/tests/test_torch.py +743 -743
  318. warp/tests/test_transient_module.py +87 -87
  319. warp/tests/test_types.py +556 -659
  320. warp/tests/test_utils.py +490 -499
  321. warp/tests/test_vec.py +1264 -1268
  322. warp/tests/test_vec_lite.py +73 -73
  323. warp/tests/test_vec_scalar_ops.py +2099 -2099
  324. warp/tests/test_verify_fp.py +94 -94
  325. warp/tests/test_volume.py +737 -736
  326. warp/tests/test_volume_write.py +255 -265
  327. warp/tests/unittest_serial.py +37 -37
  328. warp/tests/unittest_suites.py +363 -359
  329. warp/tests/unittest_utils.py +603 -578
  330. warp/tests/unused_test_misc.py +71 -71
  331. warp/tests/walkthrough_debug.py +85 -85
  332. warp/thirdparty/appdirs.py +598 -598
  333. warp/thirdparty/dlpack.py +143 -143
  334. warp/thirdparty/unittest_parallel.py +566 -561
  335. warp/torch.py +321 -295
  336. warp/types.py +4504 -4450
  337. warp/utils.py +1008 -821
  338. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/LICENSE.md +126 -126
  339. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/METADATA +338 -400
  340. warp_lang-1.1.0.dist-info/RECORD +352 -0
  341. warp/examples/assets/cube.usda +0 -42
  342. warp/examples/assets/sphere.usda +0 -56
  343. warp/examples/assets/torus.usda +0 -105
  344. warp_lang-1.0.1.dist-info/RECORD +0 -352
  345. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/WHEEL +0 -0
  346. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/top_level.txt +0 -0
warp/tests/test_quat.py CHANGED
@@ -1,2086 +1,2086 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
-
8
- import unittest
9
-
10
- import numpy as np
11
-
12
- import warp as wp
13
- from warp.tests.unittest_utils import *
14
- import warp.sim
15
-
16
- wp.init()
17
-
18
- np_float_types = [np.float32, np.float64, np.float16]
19
-
20
- kernel_cache = dict()
21
-
22
-
23
- def getkernel(func, suffix=""):
24
- key = func.__name__ + "_" + suffix
25
- if key not in kernel_cache:
26
- kernel_cache[key] = wp.Kernel(func=func, key=key)
27
- return kernel_cache[key]
28
-
29
-
30
- def get_select_kernel(dtype):
31
- def output_select_kernel_fn(
32
- input: wp.array(dtype=dtype),
33
- index: int,
34
- out: wp.array(dtype=dtype),
35
- ):
36
- out[0] = input[index]
37
-
38
- return getkernel(output_select_kernel_fn, suffix=dtype.__name__)
39
-
40
-
41
- ############################################################
42
-
43
-
44
- def test_constructors(test, device, dtype, register_kernels=False):
45
- rng = np.random.default_rng(123)
46
-
47
- tol = {
48
- np.float16: 5.0e-3,
49
- np.float32: 1.0e-6,
50
- np.float64: 1.0e-8,
51
- }.get(dtype, 0)
52
-
53
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
54
- vec3 = wp.types.vector(length=3, dtype=wptype)
55
- quat = wp.types.quaternion(dtype=wptype)
56
-
57
- def check_component_constructor(
58
- input: wp.array(dtype=wptype),
59
- q: wp.array(dtype=wptype),
60
- ):
61
- qresult = quat(input[0], input[1], input[2], input[3])
62
-
63
- # multiply the output by 2 so we've got something to backpropagate:
64
- q[0] = wptype(2) * qresult[0]
65
- q[1] = wptype(2) * qresult[1]
66
- q[2] = wptype(2) * qresult[2]
67
- q[3] = wptype(2) * qresult[3]
68
-
69
- def check_vector_constructor(
70
- input: wp.array(dtype=wptype),
71
- q: wp.array(dtype=wptype),
72
- ):
73
- qresult = quat(vec3(input[0], input[1], input[2]), input[3])
74
-
75
- # multiply the output by 2 so we've got something to backpropagate:
76
- q[0] = wptype(2) * qresult[0]
77
- q[1] = wptype(2) * qresult[1]
78
- q[2] = wptype(2) * qresult[2]
79
- q[3] = wptype(2) * qresult[3]
80
-
81
- kernel = getkernel(check_component_constructor, suffix=dtype.__name__)
82
- output_select_kernel = get_select_kernel(wptype)
83
- vec_kernel = getkernel(check_vector_constructor, suffix=dtype.__name__)
84
-
85
- if register_kernels:
86
- return
87
-
88
- input = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
89
- output = wp.zeros_like(input)
90
- wp.launch(kernel, dim=1, inputs=[input], outputs=[output], device=device)
91
-
92
- assert_np_equal(output.numpy(), 2 * input.numpy(), tol=tol)
93
-
94
- for i in range(4):
95
- cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
96
- tape = wp.Tape()
97
- with tape:
98
- wp.launch(kernel, dim=1, inputs=[input], outputs=[output], device=device)
99
- wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[cmp], device=device)
100
- tape.backward(loss=cmp)
101
- expectedgrads = np.zeros(len(input))
102
- expectedgrads[i] = 2
103
- assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
104
- tape.zero()
105
-
106
- input = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
107
- output = wp.zeros_like(input)
108
- wp.launch(vec_kernel, dim=1, inputs=[input], outputs=[output], device=device)
109
-
110
- assert_np_equal(output.numpy(), 2 * input.numpy(), tol=tol)
111
-
112
- for i in range(4):
113
- cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
114
- tape = wp.Tape()
115
- with tape:
116
- wp.launch(vec_kernel, dim=1, inputs=[input], outputs=[output], device=device)
117
- wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[cmp], device=device)
118
- tape.backward(loss=cmp)
119
- expectedgrads = np.zeros(len(input))
120
- expectedgrads[i] = 2
121
- assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
122
- tape.zero()
123
-
124
-
125
- def test_casting_constructors(test, device, dtype, register_kernels=False):
126
- np_type = np.dtype(dtype)
127
- wp_type = wp.types.np_dtype_to_warp_type[np_type]
128
- quat = wp.types.quaternion(dtype=wp_type)
129
-
130
- np16 = np.dtype(np.float16)
131
- wp16 = wp.types.np_dtype_to_warp_type[np16]
132
-
133
- np32 = np.dtype(np.float32)
134
- wp32 = wp.types.np_dtype_to_warp_type[np32]
135
-
136
- np64 = np.dtype(np.float64)
137
- wp64 = wp.types.np_dtype_to_warp_type[np64]
138
-
139
- def cast_float16(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp16, ndim=2)):
140
- tid = wp.tid()
141
-
142
- q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
143
- q2 = wp.quaternion(q1, dtype=wp16)
144
-
145
- b[tid, 0] = q2[0]
146
- b[tid, 1] = q2[1]
147
- b[tid, 2] = q2[2]
148
- b[tid, 3] = q2[3]
149
-
150
- def cast_float32(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp32, ndim=2)):
151
- tid = wp.tid()
152
-
153
- q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
154
- q2 = wp.quaternion(q1, dtype=wp32)
155
-
156
- b[tid, 0] = q2[0]
157
- b[tid, 1] = q2[1]
158
- b[tid, 2] = q2[2]
159
- b[tid, 3] = q2[3]
160
-
161
- def cast_float64(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp64, ndim=2)):
162
- tid = wp.tid()
163
-
164
- q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
165
- q2 = wp.quaternion(q1, dtype=wp64)
166
-
167
- b[tid, 0] = q2[0]
168
- b[tid, 1] = q2[1]
169
- b[tid, 2] = q2[2]
170
- b[tid, 3] = q2[3]
171
-
172
- kernel_16 = getkernel(cast_float16, suffix=dtype.__name__)
173
- kernel_32 = getkernel(cast_float32, suffix=dtype.__name__)
174
- kernel_64 = getkernel(cast_float64, suffix=dtype.__name__)
175
-
176
- if register_kernels:
177
- return
178
-
179
- # check casting to float 16
180
- a = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
181
- b = wp.array(np.zeros((1, 4), dtype=np16), dtype=wp16, requires_grad=True, device=device)
182
- b_result = np.ones((1, 4), dtype=np16)
183
- b_grad = wp.array(np.ones((1, 4), dtype=np16), dtype=wp16, device=device)
184
- a_grad = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, device=device)
185
-
186
- tape = wp.Tape()
187
- with tape:
188
- wp.launch(kernel=kernel_16, dim=1, inputs=[a, b], device=device)
189
-
190
- tape.backward(grads={b: b_grad})
191
- out = tape.gradients[a].numpy()
192
-
193
- assert_np_equal(b.numpy(), b_result)
194
- assert_np_equal(out, a_grad.numpy())
195
-
196
- # check casting to float 32
197
- a = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
198
- b = wp.array(np.zeros((1, 4), dtype=np32), dtype=wp32, requires_grad=True, device=device)
199
- b_result = np.ones((1, 4), dtype=np32)
200
- b_grad = wp.array(np.ones((1, 4), dtype=np32), dtype=wp32, device=device)
201
- a_grad = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, device=device)
202
-
203
- tape = wp.Tape()
204
- with tape:
205
- wp.launch(kernel=kernel_32, dim=1, inputs=[a, b], device=device)
206
-
207
- tape.backward(grads={b: b_grad})
208
- out = tape.gradients[a].numpy()
209
-
210
- assert_np_equal(b.numpy(), b_result)
211
- assert_np_equal(out, a_grad.numpy())
212
-
213
- # check casting to float 64
214
- a = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
215
- b = wp.array(np.zeros((1, 4), dtype=np64), dtype=wp64, requires_grad=True, device=device)
216
- b_result = np.ones((1, 4), dtype=np64)
217
- b_grad = wp.array(np.ones((1, 4), dtype=np64), dtype=wp64, device=device)
218
- a_grad = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, device=device)
219
-
220
- tape = wp.Tape()
221
- with tape:
222
- wp.launch(kernel=kernel_64, dim=1, inputs=[a, b], device=device)
223
-
224
- tape.backward(grads={b: b_grad})
225
- out = tape.gradients[a].numpy()
226
-
227
- assert_np_equal(b.numpy(), b_result)
228
- assert_np_equal(out, a_grad.numpy())
229
-
230
-
231
- def test_inverse(test, device, dtype, register_kernels=False):
232
- rng = np.random.default_rng(123)
233
-
234
- tol = {
235
- np.float16: 2.0e-3,
236
- np.float32: 1.0e-6,
237
- np.float64: 1.0e-8,
238
- }.get(dtype, 0)
239
-
240
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
241
- quat = wp.types.quaternion(dtype=wptype)
242
-
243
- output_select_kernel = get_select_kernel(wptype)
244
-
245
- def check_quat_inverse(
246
- input: wp.array(dtype=wptype),
247
- shouldbeidentity: wp.array(dtype=quat),
248
- q: wp.array(dtype=wptype),
249
- ):
250
- qread = quat(input[0], input[1], input[2], input[3])
251
- qresult = wp.quat_inverse(qread)
252
-
253
- # this inverse should work for normalized quaternions:
254
- shouldbeidentity[0] = wp.normalize(qread) * wp.quat_inverse(wp.normalize(qread))
255
-
256
- # multiply the output by 2 so we've got something to backpropagate:
257
- q[0] = wptype(2) * qresult[0]
258
- q[1] = wptype(2) * qresult[1]
259
- q[2] = wptype(2) * qresult[2]
260
- q[3] = wptype(2) * qresult[3]
261
-
262
- kernel = getkernel(check_quat_inverse, suffix=dtype.__name__)
263
-
264
- if register_kernels:
265
- return
266
-
267
- input = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
268
- shouldbeidentity = wp.array(np.zeros((1, 4)), dtype=quat, requires_grad=True, device=device)
269
- output = wp.zeros_like(input)
270
- wp.launch(kernel, dim=1, inputs=[input], outputs=[shouldbeidentity, output], device=device)
271
-
272
- assert_np_equal(shouldbeidentity.numpy(), np.array([0, 0, 0, 1]), tol=tol)
273
-
274
- for i in range(4):
275
- cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
276
- tape = wp.Tape()
277
- with tape:
278
- wp.launch(kernel, dim=1, inputs=[input], outputs=[shouldbeidentity, output], device=device)
279
- wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[cmp], device=device)
280
- tape.backward(loss=cmp)
281
- expectedgrads = np.zeros(len(input))
282
- expectedgrads[i] = -2 if i != 3 else 2
283
- assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
284
- tape.zero()
285
-
286
-
287
- def test_dotproduct(test, device, dtype, register_kernels=False):
288
- rng = np.random.default_rng(123)
289
-
290
- tol = {
291
- np.float16: 1.0e-2,
292
- np.float32: 1.0e-6,
293
- np.float64: 1.0e-8,
294
- }.get(dtype, 0)
295
-
296
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
297
- quat = wp.types.quaternion(dtype=wptype)
298
-
299
- def check_quat_dot(
300
- s: wp.array(dtype=quat),
301
- v: wp.array(dtype=quat),
302
- dot: wp.array(dtype=wptype),
303
- ):
304
- dot[0] = wptype(2) * wp.dot(v[0], s[0])
305
-
306
- dotkernel = getkernel(check_quat_dot, suffix=dtype.__name__)
307
- if register_kernels:
308
- return
309
-
310
- s = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
311
- v = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
312
- dot = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
313
-
314
- tape = wp.Tape()
315
- with tape:
316
- wp.launch(
317
- dotkernel,
318
- dim=1,
319
- inputs=[
320
- s,
321
- v,
322
- ],
323
- outputs=[dot],
324
- device=device,
325
- )
326
-
327
- assert_np_equal(dot.numpy()[0], 2.0 * (v.numpy() * s.numpy()).sum(), tol=tol)
328
-
329
- tape.backward(loss=dot)
330
- sgrads = tape.gradients[s].numpy()[0]
331
- expected_grads = 2.0 * v.numpy()[0]
332
- assert_np_equal(sgrads, expected_grads, tol=10 * tol)
333
-
334
- vgrads = tape.gradients[v].numpy()[0]
335
- expected_grads = 2.0 * s.numpy()[0]
336
- assert_np_equal(vgrads, expected_grads, tol=tol)
337
-
338
-
339
- def test_length(test, device, dtype, register_kernels=False):
340
- rng = np.random.default_rng(123)
341
-
342
- tol = {
343
- np.float16: 5.0e-3,
344
- np.float32: 1.0e-6,
345
- np.float64: 1.0e-7,
346
- }.get(dtype, 0)
347
-
348
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
349
- quat = wp.types.quaternion(dtype=wptype)
350
-
351
- def check_quat_length(
352
- q: wp.array(dtype=quat),
353
- l: wp.array(dtype=wptype),
354
- l2: wp.array(dtype=wptype),
355
- ):
356
- l[0] = wptype(2) * wp.length(q[0])
357
- l2[0] = wptype(2) * wp.length_sq(q[0])
358
-
359
- kernel = getkernel(check_quat_length, suffix=dtype.__name__)
360
-
361
- if register_kernels:
362
- return
363
-
364
- q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
365
- l = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
366
- l2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
367
-
368
- tape = wp.Tape()
369
- with tape:
370
- wp.launch(
371
- kernel,
372
- dim=1,
373
- inputs=[
374
- q,
375
- ],
376
- outputs=[l, l2],
377
- device=device,
378
- )
379
-
380
- assert_np_equal(l.numpy()[0], 2 * np.linalg.norm(q.numpy()), tol=10 * tol)
381
- assert_np_equal(l2.numpy()[0], 2 * np.linalg.norm(q.numpy()) ** 2, tol=10 * tol)
382
-
383
- tape.backward(loss=l)
384
- grad = tape.gradients[q].numpy()[0]
385
- expected_grad = 2 * q.numpy()[0] / np.linalg.norm(q.numpy())
386
- assert_np_equal(grad, expected_grad, tol=10 * tol)
387
- tape.zero()
388
-
389
- tape.backward(loss=l2)
390
- grad = tape.gradients[q].numpy()[0]
391
- expected_grad = 4 * q.numpy()[0]
392
- assert_np_equal(grad, expected_grad, tol=10 * tol)
393
- tape.zero()
394
-
395
-
396
- def test_normalize(test, device, dtype, register_kernels=False):
397
- rng = np.random.default_rng(123)
398
-
399
- tol = {
400
- np.float16: 5.0e-3,
401
- np.float32: 1.0e-6,
402
- np.float64: 1.0e-8,
403
- }.get(dtype, 0)
404
-
405
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
406
- quat = wp.types.quaternion(dtype=wptype)
407
-
408
- def check_normalize(
409
- q: wp.array(dtype=quat),
410
- n0: wp.array(dtype=wptype),
411
- n1: wp.array(dtype=wptype),
412
- n2: wp.array(dtype=wptype),
413
- n3: wp.array(dtype=wptype),
414
- ):
415
- n = wptype(2) * (wp.normalize(q[0]))
416
-
417
- n0[0] = n[0]
418
- n1[0] = n[1]
419
- n2[0] = n[2]
420
- n3[0] = n[3]
421
-
422
- def check_normalize_alt(
423
- q: wp.array(dtype=quat),
424
- n0: wp.array(dtype=wptype),
425
- n1: wp.array(dtype=wptype),
426
- n2: wp.array(dtype=wptype),
427
- n3: wp.array(dtype=wptype),
428
- ):
429
- n = wptype(2) * (q[0] / wp.length(q[0]))
430
-
431
- n0[0] = n[0]
432
- n1[0] = n[1]
433
- n2[0] = n[2]
434
- n3[0] = n[3]
435
-
436
- normalize_kernel = getkernel(check_normalize, suffix=dtype.__name__)
437
- normalize_alt_kernel = getkernel(check_normalize_alt, suffix=dtype.__name__)
438
-
439
- if register_kernels:
440
- return
441
-
442
- # I've already tested the things I'm using in check_normalize_alt, so I'll just
443
- # make sure the two are giving the same results/gradients
444
- q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
445
-
446
- n0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
447
- n1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
448
- n2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
449
- n3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
450
-
451
- n0_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
452
- n1_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
453
- n2_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
454
- n3_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
455
-
456
- outputs0 = [
457
- n0,
458
- n1,
459
- n2,
460
- n3,
461
- ]
462
- tape0 = wp.Tape()
463
- with tape0:
464
- wp.launch(normalize_kernel, dim=1, inputs=[q], outputs=outputs0, device=device)
465
-
466
- outputs1 = [
467
- n0_alt,
468
- n1_alt,
469
- n2_alt,
470
- n3_alt,
471
- ]
472
- tape1 = wp.Tape()
473
- with tape1:
474
- wp.launch(
475
- normalize_alt_kernel,
476
- dim=1,
477
- inputs=[
478
- q,
479
- ],
480
- outputs=outputs1,
481
- device=device,
482
- )
483
-
484
- assert_np_equal(n0.numpy()[0], n0_alt.numpy()[0], tol=tol)
485
- assert_np_equal(n1.numpy()[0], n1_alt.numpy()[0], tol=tol)
486
- assert_np_equal(n2.numpy()[0], n2_alt.numpy()[0], tol=tol)
487
- assert_np_equal(n3.numpy()[0], n3_alt.numpy()[0], tol=tol)
488
-
489
- for ncmp, ncmpalt in zip(outputs0, outputs1):
490
- tape0.backward(loss=ncmp)
491
- tape1.backward(loss=ncmpalt)
492
- assert_np_equal(tape0.gradients[q].numpy()[0], tape1.gradients[q].numpy()[0], tol=tol)
493
- tape0.zero()
494
- tape1.zero()
495
-
496
-
497
- def test_addition(test, device, dtype, register_kernels=False):
498
- rng = np.random.default_rng(123)
499
-
500
- tol = {
501
- np.float16: 5.0e-3,
502
- np.float32: 1.0e-6,
503
- np.float64: 1.0e-8,
504
- }.get(dtype, 0)
505
-
506
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
507
- quat = wp.types.quaternion(dtype=wptype)
508
-
509
- def check_quat_add(
510
- q: wp.array(dtype=quat),
511
- v: wp.array(dtype=quat),
512
- r0: wp.array(dtype=wptype),
513
- r1: wp.array(dtype=wptype),
514
- r2: wp.array(dtype=wptype),
515
- r3: wp.array(dtype=wptype),
516
- ):
517
- result = q[0] + v[0]
518
-
519
- r0[0] = wptype(2) * result[0]
520
- r1[0] = wptype(2) * result[1]
521
- r2[0] = wptype(2) * result[2]
522
- r3[0] = wptype(2) * result[3]
523
-
524
- kernel = getkernel(check_quat_add, suffix=dtype.__name__)
525
-
526
- if register_kernels:
527
- return
528
-
529
- q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
530
- v = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
531
-
532
- r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
533
- r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
534
- r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
535
- r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
536
-
537
- tape = wp.Tape()
538
- with tape:
539
- wp.launch(
540
- kernel,
541
- dim=1,
542
- inputs=[
543
- q,
544
- v,
545
- ],
546
- outputs=[r0, r1, r2, r3],
547
- device=device,
548
- )
549
-
550
- assert_np_equal(r0.numpy()[0], 2 * (v.numpy()[0, 0] + q.numpy()[0, 0]), tol=tol)
551
- assert_np_equal(r1.numpy()[0], 2 * (v.numpy()[0, 1] + q.numpy()[0, 1]), tol=tol)
552
- assert_np_equal(r2.numpy()[0], 2 * (v.numpy()[0, 2] + q.numpy()[0, 2]), tol=tol)
553
- assert_np_equal(r3.numpy()[0], 2 * (v.numpy()[0, 3] + q.numpy()[0, 3]), tol=tol)
554
-
555
- for i, l in enumerate([r0, r1, r2, r3]):
556
- tape.backward(loss=l)
557
- qgrads = tape.gradients[q].numpy()[0]
558
- expected_grads = np.zeros_like(qgrads)
559
-
560
- expected_grads[i] = 2
561
- assert_np_equal(qgrads, expected_grads, tol=10 * tol)
562
-
563
- vgrads = tape.gradients[v].numpy()[0]
564
- assert_np_equal(vgrads, expected_grads, tol=tol)
565
-
566
- tape.zero()
567
-
568
-
569
- def test_subtraction(test, device, dtype, register_kernels=False):
570
- rng = np.random.default_rng(123)
571
-
572
- tol = {
573
- np.float16: 5.0e-3,
574
- np.float32: 1.0e-6,
575
- np.float64: 1.0e-8,
576
- }.get(dtype, 0)
577
-
578
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
579
- quat = wp.types.quaternion(dtype=wptype)
580
-
581
- def check_quat_sub(
582
- q: wp.array(dtype=quat),
583
- v: wp.array(dtype=quat),
584
- r0: wp.array(dtype=wptype),
585
- r1: wp.array(dtype=wptype),
586
- r2: wp.array(dtype=wptype),
587
- r3: wp.array(dtype=wptype),
588
- ):
589
- result = v[0] - q[0]
590
-
591
- r0[0] = wptype(2) * result[0]
592
- r1[0] = wptype(2) * result[1]
593
- r2[0] = wptype(2) * result[2]
594
- r3[0] = wptype(2) * result[3]
595
-
596
- kernel = getkernel(check_quat_sub, suffix=dtype.__name__)
597
-
598
- if register_kernels:
599
- return
600
-
601
- q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
602
- v = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
603
-
604
- r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
605
- r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
606
- r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
607
- r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
608
-
609
- tape = wp.Tape()
610
- with tape:
611
- wp.launch(
612
- kernel,
613
- dim=1,
614
- inputs=[
615
- q,
616
- v,
617
- ],
618
- outputs=[r0, r1, r2, r3],
619
- device=device,
620
- )
621
-
622
- assert_np_equal(r0.numpy()[0], 2 * (v.numpy()[0, 0] - q.numpy()[0, 0]), tol=tol)
623
- assert_np_equal(r1.numpy()[0], 2 * (v.numpy()[0, 1] - q.numpy()[0, 1]), tol=tol)
624
- assert_np_equal(r2.numpy()[0], 2 * (v.numpy()[0, 2] - q.numpy()[0, 2]), tol=tol)
625
- assert_np_equal(r3.numpy()[0], 2 * (v.numpy()[0, 3] - q.numpy()[0, 3]), tol=tol)
626
-
627
- for i, l in enumerate([r0, r1, r2, r3]):
628
- tape.backward(loss=l)
629
- qgrads = tape.gradients[q].numpy()[0]
630
- expected_grads = np.zeros_like(qgrads)
631
-
632
- expected_grads[i] = -2
633
- assert_np_equal(qgrads, expected_grads, tol=10 * tol)
634
-
635
- vgrads = tape.gradients[v].numpy()[0]
636
- expected_grads[i] = 2
637
- assert_np_equal(vgrads, expected_grads, tol=tol)
638
-
639
- tape.zero()
640
-
641
-
642
- def test_scalar_multiplication(test, device, dtype, register_kernels=False):
643
- rng = np.random.default_rng(123)
644
-
645
- tol = {
646
- np.float16: 5.0e-3,
647
- np.float32: 1.0e-6,
648
- np.float64: 1.0e-8,
649
- }.get(dtype, 0)
650
-
651
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
652
- quat = wp.types.quaternion(dtype=wptype)
653
-
654
- def check_quat_scalar_mul(
655
- s: wp.array(dtype=wptype),
656
- q: wp.array(dtype=quat),
657
- l0: wp.array(dtype=wptype),
658
- l1: wp.array(dtype=wptype),
659
- l2: wp.array(dtype=wptype),
660
- l3: wp.array(dtype=wptype),
661
- r0: wp.array(dtype=wptype),
662
- r1: wp.array(dtype=wptype),
663
- r2: wp.array(dtype=wptype),
664
- r3: wp.array(dtype=wptype),
665
- ):
666
- lresult = s[0] * q[0]
667
- rresult = q[0] * s[0]
668
-
669
- # multiply outputs by 2 so we've got something to backpropagate:
670
- l0[0] = wptype(2) * lresult[0]
671
- l1[0] = wptype(2) * lresult[1]
672
- l2[0] = wptype(2) * lresult[2]
673
- l3[0] = wptype(2) * lresult[3]
674
-
675
- r0[0] = wptype(2) * rresult[0]
676
- r1[0] = wptype(2) * rresult[1]
677
- r2[0] = wptype(2) * rresult[2]
678
- r3[0] = wptype(2) * rresult[3]
679
-
680
- kernel = getkernel(check_quat_scalar_mul, suffix=dtype.__name__)
681
-
682
- if register_kernels:
683
- return
684
-
685
- s = wp.array(rng.standard_normal(size=1).astype(dtype), requires_grad=True, device=device)
686
- q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
687
-
688
- l0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
689
- l1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
690
- l2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
691
- l3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
692
-
693
- r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
694
- r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
695
- r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
696
- r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
697
-
698
- tape = wp.Tape()
699
- with tape:
700
- wp.launch(
701
- kernel,
702
- dim=1,
703
- inputs=[s, q],
704
- outputs=[
705
- l0,
706
- l1,
707
- l2,
708
- l3,
709
- r0,
710
- r1,
711
- r2,
712
- r3,
713
- ],
714
- device=device,
715
- )
716
-
717
- assert_np_equal(l0.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 0], tol=tol)
718
- assert_np_equal(l1.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 1], tol=tol)
719
- assert_np_equal(l2.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 2], tol=tol)
720
- assert_np_equal(l3.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 3], tol=tol)
721
-
722
- assert_np_equal(r0.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 0], tol=tol)
723
- assert_np_equal(r1.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 1], tol=tol)
724
- assert_np_equal(r2.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 2], tol=tol)
725
- assert_np_equal(r3.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 3], tol=tol)
726
-
727
- if dtype in np_float_types:
728
- for i, outputs in enumerate([(l0, r0), (l1, r1), (l2, r2), (l3, r3)]):
729
- for l in outputs:
730
- tape.backward(loss=l)
731
- sgrad = tape.gradients[s].numpy()[0]
732
- assert_np_equal(sgrad, 2 * q.numpy()[0, i], tol=tol)
733
- allgrads = tape.gradients[q].numpy()[0]
734
- expected_grads = np.zeros_like(allgrads)
735
- expected_grads[i] = s.numpy()[0] * 2
736
- assert_np_equal(allgrads, expected_grads, tol=10 * tol)
737
- tape.zero()
738
-
739
-
740
- def test_scalar_division(test, device, dtype, register_kernels=False):
741
- rng = np.random.default_rng(123)
742
-
743
- tol = {
744
- np.float16: 1.0e-3,
745
- np.float32: 1.0e-6,
746
- np.float64: 1.0e-8,
747
- }.get(dtype, 0)
748
-
749
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
750
- quat = wp.types.quaternion(dtype=wptype)
751
-
752
- def check_quat_scalar_div(
753
- s: wp.array(dtype=wptype),
754
- q: wp.array(dtype=quat),
755
- r0: wp.array(dtype=wptype),
756
- r1: wp.array(dtype=wptype),
757
- r2: wp.array(dtype=wptype),
758
- r3: wp.array(dtype=wptype),
759
- ):
760
- result = q[0] / s[0]
761
-
762
- # multiply outputs by 2 so we've got something to backpropagate:
763
- r0[0] = wptype(2) * result[0]
764
- r1[0] = wptype(2) * result[1]
765
- r2[0] = wptype(2) * result[2]
766
- r3[0] = wptype(2) * result[3]
767
-
768
- kernel = getkernel(check_quat_scalar_div, suffix=dtype.__name__)
769
-
770
- if register_kernels:
771
- return
772
-
773
- s = wp.array(rng.standard_normal(size=1).astype(dtype), requires_grad=True, device=device)
774
- q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
775
-
776
- r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
777
- r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
778
- r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
779
- r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
780
-
781
- tape = wp.Tape()
782
- with tape:
783
- wp.launch(
784
- kernel,
785
- dim=1,
786
- inputs=[s, q],
787
- outputs=[
788
- r0,
789
- r1,
790
- r2,
791
- r3,
792
- ],
793
- device=device,
794
- )
795
- assert_np_equal(r0.numpy()[0], 2 * q.numpy()[0, 0] / s.numpy()[0], tol=tol)
796
- assert_np_equal(r1.numpy()[0], 2 * q.numpy()[0, 1] / s.numpy()[0], tol=tol)
797
- assert_np_equal(r2.numpy()[0], 2 * q.numpy()[0, 2] / s.numpy()[0], tol=tol)
798
- assert_np_equal(r3.numpy()[0], 2 * q.numpy()[0, 3] / s.numpy()[0], tol=tol)
799
-
800
- if dtype in np_float_types:
801
- for i, r in enumerate([r0, r1, r2, r3]):
802
- tape.backward(loss=r)
803
- sgrad = tape.gradients[s].numpy()[0]
804
- assert_np_equal(sgrad, -2 * q.numpy()[0, i] / (s.numpy()[0] * s.numpy()[0]), tol=tol)
805
-
806
- allgrads = tape.gradients[q].numpy()[0]
807
- expected_grads = np.zeros_like(allgrads)
808
- expected_grads[i] = 2 / s.numpy()[0]
809
- assert_np_equal(allgrads, expected_grads, tol=10 * tol)
810
- tape.zero()
811
-
812
-
813
- def test_quat_multiplication(test, device, dtype, register_kernels=False):
814
- rng = np.random.default_rng(123)
815
-
816
- tol = {
817
- np.float16: 1.0e-2,
818
- np.float32: 1.0e-6,
819
- np.float64: 1.0e-8,
820
- }.get(dtype, 0)
821
-
822
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
823
- quat = wp.types.quaternion(dtype=wptype)
824
-
825
- def check_quat_mul(
826
- s: wp.array(dtype=quat),
827
- q: wp.array(dtype=quat),
828
- r0: wp.array(dtype=wptype),
829
- r1: wp.array(dtype=wptype),
830
- r2: wp.array(dtype=wptype),
831
- r3: wp.array(dtype=wptype),
832
- ):
833
- result = s[0] * q[0]
834
-
835
- # multiply outputs by 2 so we've got something to backpropagate:
836
- r0[0] = wptype(2) * result[0]
837
- r1[0] = wptype(2) * result[1]
838
- r2[0] = wptype(2) * result[2]
839
- r3[0] = wptype(2) * result[3]
840
-
841
- kernel = getkernel(check_quat_mul, suffix=dtype.__name__)
842
-
843
- if register_kernels:
844
- return
845
-
846
- s = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
847
- q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
848
-
849
- r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
850
- r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
851
- r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
852
- r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
853
-
854
- tape = wp.Tape()
855
- with tape:
856
- wp.launch(
857
- kernel,
858
- dim=1,
859
- inputs=[s, q],
860
- outputs=[
861
- r0,
862
- r1,
863
- r2,
864
- r3,
865
- ],
866
- device=device,
867
- )
868
-
869
- a = s.numpy()
870
- b = q.numpy()
871
- assert_np_equal(
872
- r0.numpy()[0], 2 * (a[0, 3] * b[0, 0] + b[0, 3] * a[0, 0] + a[0, 1] * b[0, 2] - b[0, 1] * a[0, 2]), tol=tol
873
- )
874
- assert_np_equal(
875
- r1.numpy()[0], 2 * (a[0, 3] * b[0, 1] + b[0, 3] * a[0, 1] + a[0, 2] * b[0, 0] - b[0, 2] * a[0, 0]), tol=tol
876
- )
877
- assert_np_equal(
878
- r2.numpy()[0], 2 * (a[0, 3] * b[0, 2] + b[0, 3] * a[0, 2] + a[0, 0] * b[0, 1] - b[0, 0] * a[0, 1]), tol=tol
879
- )
880
- assert_np_equal(
881
- r3.numpy()[0], 2 * (a[0, 3] * b[0, 3] - a[0, 0] * b[0, 0] - a[0, 1] * b[0, 1] - a[0, 2] * b[0, 2]), tol=tol
882
- )
883
-
884
- tape.backward(loss=r0)
885
- agrad = tape.gradients[s].numpy()[0]
886
- assert_np_equal(agrad, 2 * np.array([b[0, 3], b[0, 2], -b[0, 1], b[0, 0]]), tol=tol)
887
-
888
- bgrad = tape.gradients[q].numpy()[0]
889
- assert_np_equal(bgrad, 2 * np.array([a[0, 3], -a[0, 2], a[0, 1], a[0, 0]]), tol=tol)
890
- tape.zero()
891
-
892
- tape.backward(loss=r1)
893
- agrad = tape.gradients[s].numpy()[0]
894
- assert_np_equal(agrad, 2 * np.array([-b[0, 2], b[0, 3], b[0, 0], b[0, 1]]), tol=tol)
895
-
896
- bgrad = tape.gradients[q].numpy()[0]
897
- assert_np_equal(bgrad, 2 * np.array([a[0, 2], a[0, 3], -a[0, 0], a[0, 1]]), tol=tol)
898
- tape.zero()
899
-
900
- tape.backward(loss=r2)
901
- agrad = tape.gradients[s].numpy()[0]
902
- assert_np_equal(agrad, 2 * np.array([b[0, 1], -b[0, 0], b[0, 3], b[0, 2]]), tol=tol)
903
-
904
- bgrad = tape.gradients[q].numpy()[0]
905
- assert_np_equal(bgrad, 2 * np.array([-a[0, 1], a[0, 0], a[0, 3], a[0, 2]]), tol=tol)
906
- tape.zero()
907
-
908
- tape.backward(loss=r3)
909
- agrad = tape.gradients[s].numpy()[0]
910
- assert_np_equal(agrad, 2 * np.array([-b[0, 0], -b[0, 1], -b[0, 2], b[0, 3]]), tol=tol)
911
-
912
- bgrad = tape.gradients[q].numpy()[0]
913
- assert_np_equal(bgrad, 2 * np.array([-a[0, 0], -a[0, 1], -a[0, 2], a[0, 3]]), tol=tol)
914
- tape.zero()
915
-
916
-
917
- def test_indexing(test, device, dtype, register_kernels=False):
918
- rng = np.random.default_rng(123)
919
-
920
- tol = {
921
- np.float16: 5.0e-3,
922
- np.float32: 1.0e-6,
923
- np.float64: 1.0e-8,
924
- }.get(dtype, 0)
925
-
926
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
927
- quat = wp.types.quaternion(dtype=wptype)
928
-
929
- def check_quat_indexing(
930
- q: wp.array(dtype=quat),
931
- r0: wp.array(dtype=wptype),
932
- r1: wp.array(dtype=wptype),
933
- r2: wp.array(dtype=wptype),
934
- r3: wp.array(dtype=wptype),
935
- ):
936
- # multiply outputs by 2 so we've got something to backpropagate:
937
- r0[0] = wptype(2) * q[0][0]
938
- r1[0] = wptype(2) * q[0][1]
939
- r2[0] = wptype(2) * q[0][2]
940
- r3[0] = wptype(2) * q[0][3]
941
-
942
- kernel = getkernel(check_quat_indexing, suffix=dtype.__name__)
943
-
944
- if register_kernels:
945
- return
946
-
947
- q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
948
- r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
949
- r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
950
- r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
951
- r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
952
-
953
- tape = wp.Tape()
954
- with tape:
955
- wp.launch(kernel, dim=1, inputs=[q], outputs=[r0, r1, r2, r3], device=device)
956
-
957
- for i, l in enumerate([r0, r1, r2, r3]):
958
- tape.backward(loss=l)
959
- allgrads = tape.gradients[q].numpy()[0]
960
- expected_grads = np.zeros_like(allgrads)
961
- expected_grads[i] = 2
962
- assert_np_equal(allgrads, expected_grads, tol=tol)
963
- tape.zero()
964
-
965
- assert_np_equal(r0.numpy()[0], 2.0 * q.numpy()[0, 0], tol=tol)
966
- assert_np_equal(r1.numpy()[0], 2.0 * q.numpy()[0, 1], tol=tol)
967
- assert_np_equal(r2.numpy()[0], 2.0 * q.numpy()[0, 2], tol=tol)
968
- assert_np_equal(r3.numpy()[0], 2.0 * q.numpy()[0, 3], tol=tol)
969
-
970
-
971
- def test_quat_lerp(test, device, dtype, register_kernels=False):
972
- rng = np.random.default_rng(123)
973
-
974
- tol = {
975
- np.float16: 1.0e-2,
976
- np.float32: 1.0e-6,
977
- np.float64: 1.0e-8,
978
- }.get(dtype, 0)
979
-
980
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
981
- quat = wp.types.quaternion(dtype=wptype)
982
-
983
- def check_quat_lerp(
984
- s: wp.array(dtype=quat),
985
- q: wp.array(dtype=quat),
986
- t: wp.array(dtype=wptype),
987
- r0: wp.array(dtype=wptype),
988
- r1: wp.array(dtype=wptype),
989
- r2: wp.array(dtype=wptype),
990
- r3: wp.array(dtype=wptype),
991
- ):
992
- result = wp.lerp(s[0], q[0], t[0])
993
-
994
- # multiply outputs by 2 so we've got something to backpropagate:
995
- r0[0] = wptype(2) * result[0]
996
- r1[0] = wptype(2) * result[1]
997
- r2[0] = wptype(2) * result[2]
998
- r3[0] = wptype(2) * result[3]
999
-
1000
- kernel = getkernel(check_quat_lerp, suffix=dtype.__name__)
1001
-
1002
- if register_kernels:
1003
- return
1004
-
1005
- s = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
1006
- q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
1007
- t = wp.array(rng.uniform(size=1).astype(dtype), dtype=wptype, requires_grad=True, device=device)
1008
-
1009
- r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1010
- r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1011
- r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1012
- r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1013
-
1014
- tape = wp.Tape()
1015
- with tape:
1016
- wp.launch(
1017
- kernel,
1018
- dim=1,
1019
- inputs=[s, q, t],
1020
- outputs=[
1021
- r0,
1022
- r1,
1023
- r2,
1024
- r3,
1025
- ],
1026
- device=device,
1027
- )
1028
-
1029
- a = s.numpy()
1030
- b = q.numpy()
1031
- tt = t.numpy()
1032
- assert_np_equal(r0.numpy()[0], 2 * ((1 - tt) * a[0, 0] + tt * b[0, 0]), tol=tol)
1033
- assert_np_equal(r1.numpy()[0], 2 * ((1 - tt) * a[0, 1] + tt * b[0, 1]), tol=tol)
1034
- assert_np_equal(r2.numpy()[0], 2 * ((1 - tt) * a[0, 2] + tt * b[0, 2]), tol=tol)
1035
- assert_np_equal(r3.numpy()[0], 2 * ((1 - tt) * a[0, 3] + tt * b[0, 3]), tol=tol)
1036
-
1037
- for i, l in enumerate([r0, r1, r2, r3]):
1038
- tape.backward(loss=l)
1039
- agrad = tape.gradients[s].numpy()[0]
1040
- bgrad = tape.gradients[q].numpy()[0]
1041
- tgrad = tape.gradients[t].numpy()[0]
1042
- expected_grads = np.zeros_like(agrad)
1043
- expected_grads[i] = 2 * (1 - tt)
1044
- assert_np_equal(agrad, expected_grads, tol=tol)
1045
- expected_grads[i] = 2 * tt
1046
- assert_np_equal(bgrad, expected_grads, tol=tol)
1047
- assert_np_equal(tgrad, 2 * (b[0, i] - a[0, i]), tol=tol)
1048
-
1049
- tape.zero()
1050
-
1051
-
1052
- def test_quat_rotate(test, device, dtype, register_kernels=False):
1053
- rng = np.random.default_rng(123)
1054
-
1055
- tol = {
1056
- np.float16: 1.0e-2,
1057
- np.float32: 1.0e-6,
1058
- np.float64: 1.0e-8,
1059
- }.get(dtype, 0)
1060
-
1061
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1062
- quat = wp.types.quaternion(dtype=wptype)
1063
- vec3 = wp.types.vector(length=3, dtype=wptype)
1064
-
1065
- def check_quat_rotate(
1066
- q: wp.array(dtype=quat),
1067
- v: wp.array(dtype=vec3),
1068
- outputs: wp.array(dtype=wptype),
1069
- outputs_inv: wp.array(dtype=wptype),
1070
- outputs_manual: wp.array(dtype=wptype),
1071
- outputs_inv_manual: wp.array(dtype=wptype),
1072
- ):
1073
- result = wp.quat_rotate(q[0], v[0])
1074
- result_inv = wp.quat_rotate_inv(q[0], v[0])
1075
-
1076
- qv = vec3(q[0][0], q[0][1], q[0][2])
1077
- qw = q[0][3]
1078
-
1079
- result_manual = v[0] * (wptype(2) * qw * qw - wptype(1))
1080
- result_manual += wp.cross(qv, v[0]) * qw * wptype(2)
1081
- result_manual += qv * wp.dot(qv, v[0]) * wptype(2)
1082
-
1083
- result_inv_manual = v[0] * (wptype(2) * qw * qw - wptype(1))
1084
- result_inv_manual -= wp.cross(qv, v[0]) * qw * wptype(2)
1085
- result_inv_manual += qv * wp.dot(qv, v[0]) * wptype(2)
1086
-
1087
- for i in range(3):
1088
- # multiply outputs by 2 so we've got something to backpropagate:
1089
- outputs[i] = wptype(2) * result[i]
1090
- outputs_inv[i] = wptype(2) * result_inv[i]
1091
- outputs_manual[i] = wptype(2) * result_manual[i]
1092
- outputs_inv_manual[i] = wptype(2) * result_inv_manual[i]
1093
-
1094
- kernel = getkernel(check_quat_rotate, suffix=dtype.__name__)
1095
- output_select_kernel = get_select_kernel(wptype)
1096
-
1097
- if register_kernels:
1098
- return
1099
-
1100
- q = rng.standard_normal(size=(1, 4))
1101
- q /= np.linalg.norm(q)
1102
- q = wp.array(q.astype(dtype), dtype=quat, requires_grad=True, device=device)
1103
- v = wp.array(0.5 * rng.standard_normal(size=(1, 3)).astype(dtype), dtype=vec3, requires_grad=True, device=device)
1104
-
1105
- # test values against the manually computed result:
1106
- outputs = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
1107
- outputs_inv = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
1108
- outputs_manual = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
1109
- outputs_inv_manual = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
1110
-
1111
- wp.launch(
1112
- kernel,
1113
- dim=1,
1114
- inputs=[q, v],
1115
- outputs=[
1116
- outputs,
1117
- outputs_inv,
1118
- outputs_manual,
1119
- outputs_inv_manual,
1120
- ],
1121
- device=device,
1122
- )
1123
-
1124
- assert_np_equal(outputs.numpy(), outputs_manual.numpy(), tol=tol)
1125
- assert_np_equal(outputs_inv.numpy(), outputs_inv_manual.numpy(), tol=tol)
1126
-
1127
- # test gradients against the manually computed result:
1128
- for i in range(3):
1129
- cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1130
- cmp_inv = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1131
- cmp_manual = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1132
- cmp_inv_manual = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1133
- tape = wp.Tape()
1134
- with tape:
1135
- wp.launch(
1136
- kernel,
1137
- dim=1,
1138
- inputs=[q, v],
1139
- outputs=[
1140
- outputs,
1141
- outputs_inv,
1142
- outputs_manual,
1143
- outputs_inv_manual,
1144
- ],
1145
- device=device,
1146
- )
1147
- wp.launch(output_select_kernel, dim=1, inputs=[outputs, i], outputs=[cmp], device=device)
1148
- wp.launch(output_select_kernel, dim=1, inputs=[outputs_inv, i], outputs=[cmp_inv], device=device)
1149
- wp.launch(output_select_kernel, dim=1, inputs=[outputs_manual, i], outputs=[cmp_manual], device=device)
1150
- wp.launch(
1151
- output_select_kernel, dim=1, inputs=[outputs_inv_manual, i], outputs=[cmp_inv_manual], device=device
1152
- )
1153
-
1154
- tape.backward(loss=cmp)
1155
- qgrads = 1.0 * tape.gradients[q].numpy()
1156
- vgrads = 1.0 * tape.gradients[v].numpy()
1157
- tape.zero()
1158
- tape.backward(loss=cmp_inv)
1159
- qgrads_inv = 1.0 * tape.gradients[q].numpy()
1160
- vgrads_inv = 1.0 * tape.gradients[v].numpy()
1161
- tape.zero()
1162
- tape.backward(loss=cmp_manual)
1163
- qgrads_manual = 1.0 * tape.gradients[q].numpy()
1164
- vgrads_manual = 1.0 * tape.gradients[v].numpy()
1165
- tape.zero()
1166
- tape.backward(loss=cmp_inv_manual)
1167
- qgrads_inv_manual = 1.0 * tape.gradients[q].numpy()
1168
- vgrads_inv_manual = 1.0 * tape.gradients[v].numpy()
1169
- tape.zero()
1170
-
1171
- assert_np_equal(qgrads, qgrads_manual, tol=tol)
1172
- assert_np_equal(vgrads, vgrads_manual, tol=tol)
1173
-
1174
- assert_np_equal(qgrads_inv, qgrads_inv_manual, tol=tol)
1175
- assert_np_equal(vgrads_inv, vgrads_inv_manual, tol=tol)
1176
-
1177
-
1178
- def test_quat_to_matrix(test, device, dtype, register_kernels=False):
1179
- rng = np.random.default_rng(123)
1180
-
1181
- tol = {
1182
- np.float16: 1.0e-2,
1183
- np.float32: 1.0e-6,
1184
- np.float64: 1.0e-8,
1185
- }.get(dtype, 0)
1186
-
1187
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1188
- quat = wp.types.quaternion(dtype=wptype)
1189
- mat3 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1190
- vec3 = wp.types.vector(length=3, dtype=wptype)
1191
-
1192
- def check_quat_to_matrix(
1193
- q: wp.array(dtype=quat),
1194
- outputs: wp.array(dtype=wptype),
1195
- outputs_manual: wp.array(dtype=wptype),
1196
- ):
1197
- result = wp.quat_to_matrix(q[0])
1198
-
1199
- xaxis = wp.quat_rotate(
1200
- q[0],
1201
- vec3(
1202
- wptype(1),
1203
- wptype(0),
1204
- wptype(0),
1205
- ),
1206
- )
1207
- yaxis = wp.quat_rotate(
1208
- q[0],
1209
- vec3(
1210
- wptype(0),
1211
- wptype(1),
1212
- wptype(0),
1213
- ),
1214
- )
1215
- zaxis = wp.quat_rotate(
1216
- q[0],
1217
- vec3(
1218
- wptype(0),
1219
- wptype(0),
1220
- wptype(1),
1221
- ),
1222
- )
1223
- result_manual = mat3(xaxis, yaxis, zaxis)
1224
-
1225
- idx = 0
1226
- for i in range(3):
1227
- for j in range(3):
1228
- # multiply outputs by 2 so we've got something to backpropagate:
1229
- outputs[idx] = wptype(2) * result[i, j]
1230
- outputs_manual[idx] = wptype(2) * result_manual[i, j]
1231
-
1232
- idx = idx + 1
1233
-
1234
- kernel = getkernel(check_quat_to_matrix, suffix=dtype.__name__)
1235
- output_select_kernel = get_select_kernel(wptype)
1236
-
1237
- if register_kernels:
1238
- return
1239
-
1240
- q = rng.standard_normal(size=(1, 4))
1241
- q /= np.linalg.norm(q)
1242
- q = wp.array(q.astype(dtype), dtype=quat, requires_grad=True, device=device)
1243
-
1244
- # test values against the manually computed result:
1245
- outputs = wp.zeros(3 * 3, dtype=wptype, requires_grad=True, device=device)
1246
- outputs_manual = wp.zeros(3 * 3, dtype=wptype, requires_grad=True, device=device)
1247
-
1248
- wp.launch(
1249
- kernel,
1250
- dim=1,
1251
- inputs=[q],
1252
- outputs=[
1253
- outputs,
1254
- outputs_manual,
1255
- ],
1256
- device=device,
1257
- )
1258
-
1259
- assert_np_equal(outputs.numpy(), outputs_manual.numpy(), tol=tol)
1260
-
1261
- # sanity check: divide by 2 to remove that scale factor we put in there, and
1262
- # it should be a rotation matrix
1263
- R = 0.5 * outputs.numpy().reshape(3, 3)
1264
- assert_np_equal(np.matmul(R, R.T), np.eye(3), tol=tol)
1265
-
1266
- # test gradients against the manually computed result:
1267
- idx = 0
1268
- for i in range(3):
1269
- for j in range(3):
1270
- cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1271
- cmp_manual = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1272
- tape = wp.Tape()
1273
- with tape:
1274
- wp.launch(
1275
- kernel,
1276
- dim=1,
1277
- inputs=[q],
1278
- outputs=[
1279
- outputs,
1280
- outputs_manual,
1281
- ],
1282
- device=device,
1283
- )
1284
- wp.launch(output_select_kernel, dim=1, inputs=[outputs, idx], outputs=[cmp], device=device)
1285
- wp.launch(
1286
- output_select_kernel, dim=1, inputs=[outputs_manual, idx], outputs=[cmp_manual], device=device
1287
- )
1288
- tape.backward(loss=cmp)
1289
- qgrads = 1.0 * tape.gradients[q].numpy()
1290
- tape.zero()
1291
- tape.backward(loss=cmp_manual)
1292
- qgrads_manual = 1.0 * tape.gradients[q].numpy()
1293
- tape.zero()
1294
-
1295
- assert_np_equal(qgrads, qgrads_manual, tol=tol)
1296
- idx = idx + 1
1297
-
1298
-
1299
- ############################################################
1300
-
1301
-
1302
- def test_slerp_grad(test, device, dtype, register_kernels=False):
1303
- rng = np.random.default_rng(123)
1304
- seed = 42
1305
-
1306
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1307
- vec3 = wp.types.vector(3, wptype)
1308
- quat = wp.types.quaternion(wptype)
1309
-
1310
- def slerp_kernel(
1311
- q0: wp.array(dtype=quat),
1312
- q1: wp.array(dtype=quat),
1313
- t: wp.array(dtype=wptype),
1314
- loss: wp.array(dtype=wptype),
1315
- index: int,
1316
- ):
1317
- tid = wp.tid()
1318
-
1319
- q = wp.quat_slerp(q0[tid], q1[tid], t[tid])
1320
- wp.atomic_add(loss, 0, q[index])
1321
-
1322
- slerp_kernel = getkernel(slerp_kernel, suffix=dtype.__name__)
1323
-
1324
- def slerp_kernel_forward(
1325
- q0: wp.array(dtype=quat),
1326
- q1: wp.array(dtype=quat),
1327
- t: wp.array(dtype=wptype),
1328
- loss: wp.array(dtype=wptype),
1329
- index: int,
1330
- ):
1331
- tid = wp.tid()
1332
-
1333
- axis = vec3()
1334
- angle = wptype(0.0)
1335
-
1336
- wp.quat_to_axis_angle(wp.mul(wp.quat_inverse(q0[tid]), q1[tid]), axis, angle)
1337
- q = wp.mul(q0[tid], wp.quat_from_axis_angle(axis, t[tid] * angle))
1338
-
1339
- wp.atomic_add(loss, 0, q[index])
1340
-
1341
- slerp_kernel_forward = getkernel(slerp_kernel_forward, suffix=dtype.__name__)
1342
-
1343
- def quat_sampler_slerp(kernel_seed: int, quats: wp.array(dtype=quat)):
1344
- tid = wp.tid()
1345
-
1346
- state = wp.rand_init(kernel_seed, tid)
1347
-
1348
- angle = wp.randf(state, 0.0, 2.0 * 3.1415926535)
1349
- dir = wp.sample_unit_sphere_surface(state) * wp.sin(angle * 0.5)
1350
-
1351
- q = quat(wptype(dir[0]), wptype(dir[1]), wptype(dir[2]), wptype(wp.cos(angle * 0.5)))
1352
- qn = wp.normalize(q)
1353
-
1354
- quats[tid] = qn
1355
-
1356
- quat_sampler = getkernel(quat_sampler_slerp, suffix=dtype.__name__)
1357
-
1358
- if register_kernels:
1359
- return
1360
-
1361
- N = 50
1362
-
1363
- q0 = wp.zeros(N, dtype=quat, device=device, requires_grad=True)
1364
- q1 = wp.zeros(N, dtype=quat, device=device, requires_grad=True)
1365
-
1366
- wp.launch(kernel=quat_sampler, dim=N, inputs=[seed, q0], device=device)
1367
- wp.launch(kernel=quat_sampler, dim=N, inputs=[seed + 1, q1], device=device)
1368
-
1369
- t = rng.uniform(low=0.0, high=1.0, size=N)
1370
- t = wp.array(t, dtype=wptype, device=device, requires_grad=True)
1371
-
1372
- def compute_gradients(kernel, wrt, index):
1373
- loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
1374
- tape = wp.Tape()
1375
- with tape:
1376
- wp.launch(kernel=kernel, dim=N, inputs=[q0, q1, t, loss, index], device=device)
1377
-
1378
- tape.backward(loss)
1379
-
1380
- gradients = 1.0 * tape.gradients[wrt].numpy()
1381
- tape.zero()
1382
-
1383
- return loss.numpy()[0], gradients
1384
-
1385
- eps = {
1386
- np.float16: 2.0e-2,
1387
- np.float32: 1.0e-5,
1388
- np.float64: 1.0e-8,
1389
- }.get(dtype, 0)
1390
-
1391
- # wrt t
1392
-
1393
- # gather gradients from builtin adjoints
1394
- xcmp, gradients_x = compute_gradients(slerp_kernel, t, 0)
1395
- ycmp, gradients_y = compute_gradients(slerp_kernel, t, 1)
1396
- zcmp, gradients_z = compute_gradients(slerp_kernel, t, 2)
1397
- wcmp, gradients_w = compute_gradients(slerp_kernel, t, 3)
1398
-
1399
- # gather gradients from autodiff
1400
- xcmp_auto, gradients_x_auto = compute_gradients(slerp_kernel_forward, t, 0)
1401
- ycmp_auto, gradients_y_auto = compute_gradients(slerp_kernel_forward, t, 1)
1402
- zcmp_auto, gradients_z_auto = compute_gradients(slerp_kernel_forward, t, 2)
1403
- wcmp_auto, gradients_w_auto = compute_gradients(slerp_kernel_forward, t, 3)
1404
-
1405
- assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1406
- assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1407
- assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1408
- assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1409
- assert_np_equal(xcmp, xcmp_auto, tol=eps)
1410
- assert_np_equal(ycmp, ycmp_auto, tol=eps)
1411
- assert_np_equal(zcmp, zcmp_auto, tol=eps)
1412
- assert_np_equal(wcmp, wcmp_auto, tol=eps)
1413
-
1414
- # wrt q0
1415
-
1416
- # gather gradients from builtin adjoints
1417
- xcmp, gradients_x = compute_gradients(slerp_kernel, q0, 0)
1418
- ycmp, gradients_y = compute_gradients(slerp_kernel, q0, 1)
1419
- zcmp, gradients_z = compute_gradients(slerp_kernel, q0, 2)
1420
- wcmp, gradients_w = compute_gradients(slerp_kernel, q0, 3)
1421
-
1422
- # gather gradients from autodiff
1423
- xcmp_auto, gradients_x_auto = compute_gradients(slerp_kernel_forward, q0, 0)
1424
- ycmp_auto, gradients_y_auto = compute_gradients(slerp_kernel_forward, q0, 1)
1425
- zcmp_auto, gradients_z_auto = compute_gradients(slerp_kernel_forward, q0, 2)
1426
- wcmp_auto, gradients_w_auto = compute_gradients(slerp_kernel_forward, q0, 3)
1427
-
1428
- assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1429
- assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1430
- assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1431
- assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1432
- assert_np_equal(xcmp, xcmp_auto, tol=eps)
1433
- assert_np_equal(ycmp, ycmp_auto, tol=eps)
1434
- assert_np_equal(zcmp, zcmp_auto, tol=eps)
1435
- assert_np_equal(wcmp, wcmp_auto, tol=eps)
1436
-
1437
- # wrt q1
1438
-
1439
- # gather gradients from builtin adjoints
1440
- xcmp, gradients_x = compute_gradients(slerp_kernel, q1, 0)
1441
- ycmp, gradients_y = compute_gradients(slerp_kernel, q1, 1)
1442
- zcmp, gradients_z = compute_gradients(slerp_kernel, q1, 2)
1443
- wcmp, gradients_w = compute_gradients(slerp_kernel, q1, 3)
1444
-
1445
- # gather gradients from autodiff
1446
- xcmp_auto, gradients_x_auto = compute_gradients(slerp_kernel_forward, q1, 0)
1447
- ycmp_auto, gradients_y_auto = compute_gradients(slerp_kernel_forward, q1, 1)
1448
- zcmp_auto, gradients_z_auto = compute_gradients(slerp_kernel_forward, q1, 2)
1449
- wcmp_auto, gradients_w_auto = compute_gradients(slerp_kernel_forward, q1, 3)
1450
-
1451
- assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1452
- assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1453
- assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1454
- assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1455
- assert_np_equal(xcmp, xcmp_auto, tol=eps)
1456
- assert_np_equal(ycmp, ycmp_auto, tol=eps)
1457
- assert_np_equal(zcmp, zcmp_auto, tol=eps)
1458
- assert_np_equal(wcmp, wcmp_auto, tol=eps)
1459
-
1460
-
1461
- ############################################################
1462
-
1463
-
1464
- def test_quat_to_axis_angle_grad(test, device, dtype, register_kernels=False):
1465
- rng = np.random.default_rng(123)
1466
- seed = 42
1467
- num_rand = 50
1468
-
1469
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1470
- vec3 = wp.types.vector(3, wptype)
1471
- vec4 = wp.types.vector(4, wptype)
1472
- quat = wp.types.quaternion(wptype)
1473
-
1474
- def quat_to_axis_angle_kernel(quats: wp.array(dtype=quat), loss: wp.array(dtype=wptype), coord_idx: int):
1475
- tid = wp.tid()
1476
- axis = vec3()
1477
- angle = wptype(0.0)
1478
-
1479
- wp.quat_to_axis_angle(quats[tid], axis, angle)
1480
- a = vec4(axis[0], axis[1], axis[2], angle)
1481
-
1482
- wp.atomic_add(loss, 0, a[coord_idx])
1483
-
1484
- quat_to_axis_angle_kernel = getkernel(quat_to_axis_angle_kernel, suffix=dtype.__name__)
1485
-
1486
- def quat_to_axis_angle_kernel_forward(quats: wp.array(dtype=quat), loss: wp.array(dtype=wptype), coord_idx: int):
1487
- tid = wp.tid()
1488
- q = quats[tid]
1489
- axis = vec3()
1490
- angle = wptype(0.0)
1491
-
1492
- v = vec3(q[0], q[1], q[2])
1493
- if q[3] < wptype(0):
1494
- axis = -wp.normalize(v)
1495
- else:
1496
- axis = wp.normalize(v)
1497
-
1498
- angle = wptype(2) * wp.atan2(wp.length(v), wp.abs(q[3]))
1499
- a = vec4(axis[0], axis[1], axis[2], angle)
1500
-
1501
- wp.atomic_add(loss, 0, a[coord_idx])
1502
-
1503
- quat_to_axis_angle_kernel_forward = getkernel(quat_to_axis_angle_kernel_forward, suffix=dtype.__name__)
1504
-
1505
- def quat_sampler(kernel_seed: int, angles: wp.array(dtype=float), quats: wp.array(dtype=quat)):
1506
- tid = wp.tid()
1507
-
1508
- state = wp.rand_init(kernel_seed, tid)
1509
-
1510
- angle = angles[tid]
1511
- dir = wp.sample_unit_sphere_surface(state) * wp.sin(angle * 0.5)
1512
-
1513
- q = quat(wptype(dir[0]), wptype(dir[1]), wptype(dir[2]), wptype(wp.cos(angle * 0.5)))
1514
- qn = wp.normalize(q)
1515
-
1516
- quats[tid] = qn
1517
-
1518
- quat_sampler = getkernel(quat_sampler, suffix=dtype.__name__)
1519
-
1520
- if register_kernels:
1521
- return
1522
-
1523
- quats = wp.zeros(num_rand, dtype=quat, device=device, requires_grad=True)
1524
- angles = wp.array(
1525
- np.linspace(0.0, 2.0 * np.pi, num_rand, endpoint=False, dtype=np.float32), dtype=float, device=device
1526
- )
1527
- wp.launch(kernel=quat_sampler, dim=num_rand, inputs=[seed, angles, quats], device=device)
1528
-
1529
- edge_cases = np.array(
1530
- [(1.0, 0.0, 0.0, 0.0), (0.0, 1.0 / np.sqrt(3), 1.0 / np.sqrt(3), 1.0 / np.sqrt(3)), (0.0, 0.0, 0.0, 0.0)]
1531
- )
1532
- num_edge = len(edge_cases)
1533
- edge_cases = wp.array(edge_cases, dtype=quat, device=device, requires_grad=True)
1534
-
1535
- def compute_gradients(arr, kernel, dim, index):
1536
- loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
1537
- tape = wp.Tape()
1538
- with tape:
1539
- wp.launch(kernel=kernel, dim=dim, inputs=[arr, loss, index], device=device)
1540
-
1541
- tape.backward(loss)
1542
-
1543
- gradients = 1.0 * tape.gradients[arr].numpy()
1544
- tape.zero()
1545
-
1546
- return loss.numpy()[0], gradients
1547
-
1548
- # gather gradients from builtin adjoints
1549
- xcmp, gradients_x = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 0)
1550
- ycmp, gradients_y = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 1)
1551
- zcmp, gradients_z = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 2)
1552
- wcmp, gradients_w = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 3)
1553
-
1554
- # gather gradients from autodiff
1555
- xcmp_auto, gradients_x_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 0)
1556
- ycmp_auto, gradients_y_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 1)
1557
- zcmp_auto, gradients_z_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 2)
1558
- wcmp_auto, gradients_w_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 3)
1559
-
1560
- # edge cases: gather gradients from builtin adjoints
1561
- _, edge_gradients_x = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 0)
1562
- _, edge_gradients_y = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 1)
1563
- _, edge_gradients_z = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 2)
1564
- _, edge_gradients_w = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 3)
1565
-
1566
- # edge cases: gather gradients from autodiff
1567
- _, edge_gradients_x_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 0)
1568
- _, edge_gradients_y_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 1)
1569
- _, edge_gradients_z_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 2)
1570
- _, edge_gradients_w_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 3)
1571
-
1572
- eps = {
1573
- np.float16: 2.0e-1,
1574
- np.float32: 2.0e-4,
1575
- np.float64: 2.0e-7,
1576
- }.get(dtype, 0)
1577
-
1578
- assert_np_equal(xcmp, xcmp_auto, tol=eps)
1579
- assert_np_equal(ycmp, ycmp_auto, tol=eps)
1580
- assert_np_equal(zcmp, zcmp_auto, tol=eps)
1581
- assert_np_equal(wcmp, wcmp_auto, tol=eps)
1582
-
1583
- assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1584
- assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1585
- assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1586
- assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1587
-
1588
- assert_np_equal(edge_gradients_x, edge_gradients_x_auto, tol=eps)
1589
- assert_np_equal(edge_gradients_y, edge_gradients_y_auto, tol=eps)
1590
- assert_np_equal(edge_gradients_z, edge_gradients_z_auto, tol=eps)
1591
- assert_np_equal(edge_gradients_w, edge_gradients_w_auto, tol=eps)
1592
-
1593
-
1594
- ############################################################
1595
-
1596
-
1597
- def test_quat_rpy_grad(test, device, dtype, register_kernels=False):
1598
- rng = np.random.default_rng(123)
1599
- N = 3
1600
-
1601
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1602
-
1603
- vec3 = wp.types.vector(3, wptype)
1604
- quat = wp.types.quaternion(wptype)
1605
-
1606
- def rpy_to_quat_kernel(rpy_arr: wp.array(dtype=vec3), loss: wp.array(dtype=wptype), coord_idx: int):
1607
- tid = wp.tid()
1608
- rpy = rpy_arr[tid]
1609
- roll = rpy[0]
1610
- pitch = rpy[1]
1611
- yaw = rpy[2]
1612
-
1613
- q = wp.quat_rpy(roll, pitch, yaw)
1614
-
1615
- wp.atomic_add(loss, 0, q[coord_idx])
1616
-
1617
- rpy_to_quat_kernel = getkernel(rpy_to_quat_kernel, suffix=dtype.__name__)
1618
-
1619
- def rpy_to_quat_kernel_forward(rpy_arr: wp.array(dtype=vec3), loss: wp.array(dtype=wptype), coord_idx: int):
1620
- tid = wp.tid()
1621
- rpy = rpy_arr[tid]
1622
- roll = rpy[0]
1623
- pitch = rpy[1]
1624
- yaw = rpy[2]
1625
-
1626
- cy = wp.cos(yaw * wptype(0.5))
1627
- sy = wp.sin(yaw * wptype(0.5))
1628
- cr = wp.cos(roll * wptype(0.5))
1629
- sr = wp.sin(roll * wptype(0.5))
1630
- cp = wp.cos(pitch * wptype(0.5))
1631
- sp = wp.sin(pitch * wptype(0.5))
1632
-
1633
- w = cy * cr * cp + sy * sr * sp
1634
- x = cy * sr * cp - sy * cr * sp
1635
- y = cy * cr * sp + sy * sr * cp
1636
- z = sy * cr * cp - cy * sr * sp
1637
-
1638
- q = quat(x, y, z, w)
1639
-
1640
- wp.atomic_add(loss, 0, q[coord_idx])
1641
-
1642
- rpy_to_quat_kernel_forward = getkernel(rpy_to_quat_kernel_forward, suffix=dtype.__name__)
1643
-
1644
- if register_kernels:
1645
- return
1646
-
1647
- rpy_arr = rng.uniform(low=-np.pi, high=np.pi, size=(N, 3))
1648
- rpy_arr = wp.array(rpy_arr, dtype=vec3, device=device, requires_grad=True)
1649
-
1650
- def compute_gradients(kernel, wrt, index):
1651
- loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
1652
- tape = wp.Tape()
1653
- with tape:
1654
- wp.launch(kernel=kernel, dim=N, inputs=[wrt, loss, index], device=device)
1655
-
1656
- tape.backward(loss)
1657
-
1658
- gradients = 1.0 * tape.gradients[wrt].numpy()
1659
- tape.zero()
1660
-
1661
- return loss.numpy()[0], gradients
1662
-
1663
- # wrt rpy
1664
- # gather gradients from builtin adjoints
1665
- rcmp, gradients_r = compute_gradients(rpy_to_quat_kernel, rpy_arr, 0)
1666
- pcmp, gradients_p = compute_gradients(rpy_to_quat_kernel, rpy_arr, 1)
1667
- ycmp, gradients_y = compute_gradients(rpy_to_quat_kernel, rpy_arr, 2)
1668
-
1669
- # gather gradients from autodiff
1670
- rcmp_auto, gradients_r_auto = compute_gradients(rpy_to_quat_kernel_forward, rpy_arr, 0)
1671
- pcmp_auto, gradients_p_auto = compute_gradients(rpy_to_quat_kernel_forward, rpy_arr, 1)
1672
- ycmp_auto, gradients_y_auto = compute_gradients(rpy_to_quat_kernel_forward, rpy_arr, 2)
1673
-
1674
- eps = {
1675
- np.float16: 2.0e-2,
1676
- np.float32: 1.0e-5,
1677
- np.float64: 1.0e-8,
1678
- }.get(dtype, 0)
1679
-
1680
- assert_np_equal(rcmp, rcmp_auto, tol=eps)
1681
- assert_np_equal(pcmp, pcmp_auto, tol=eps)
1682
- assert_np_equal(ycmp, ycmp_auto, tol=eps)
1683
-
1684
- assert_np_equal(gradients_r, gradients_r_auto, tol=eps)
1685
- assert_np_equal(gradients_p, gradients_p_auto, tol=eps)
1686
- assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1687
-
1688
-
1689
- ############################################################
1690
-
1691
-
1692
- def test_quat_from_matrix(test, device, dtype, register_kernels=False):
1693
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1694
- mat33 = wp.types.matrix((3, 3), wptype)
1695
- quat = wp.types.quaternion(wptype)
1696
-
1697
- def quat_from_matrix(m: wp.array2d(dtype=wptype), loss: wp.array(dtype=wptype), idx: int):
1698
- tid = wp.tid()
1699
-
1700
- matrix = mat33(
1701
- m[tid, 0], m[tid, 1], m[tid, 2], m[tid, 3], m[tid, 4], m[tid, 5], m[tid, 6], m[tid, 7], m[tid, 8]
1702
- )
1703
-
1704
- q = wp.quat_from_matrix(matrix)
1705
-
1706
- wp.atomic_add(loss, 0, q[idx])
1707
-
1708
- def quat_from_matrix_forward(mats: wp.array2d(dtype=wptype), loss: wp.array(dtype=wptype), idx: int):
1709
- tid = wp.tid()
1710
-
1711
- m = mat33(
1712
- mats[tid, 0],
1713
- mats[tid, 1],
1714
- mats[tid, 2],
1715
- mats[tid, 3],
1716
- mats[tid, 4],
1717
- mats[tid, 5],
1718
- mats[tid, 6],
1719
- mats[tid, 7],
1720
- mats[tid, 8],
1721
- )
1722
-
1723
- tr = m[0][0] + m[1][1] + m[2][2]
1724
- x = wptype(0)
1725
- y = wptype(0)
1726
- z = wptype(0)
1727
- w = wptype(0)
1728
- h = wptype(0)
1729
-
1730
- if tr >= wptype(0):
1731
- h = wp.sqrt(tr + wptype(1))
1732
- w = wptype(0.5) * h
1733
- h = wptype(0.5) / h
1734
-
1735
- x = (m[2][1] - m[1][2]) * h
1736
- y = (m[0][2] - m[2][0]) * h
1737
- z = (m[1][0] - m[0][1]) * h
1738
- else:
1739
- max_diag = 0
1740
- if m[1][1] > m[0][0]:
1741
- max_diag = 1
1742
- if m[2][2] > m[max_diag][max_diag]:
1743
- max_diag = 2
1744
-
1745
- if max_diag == 0:
1746
- h = wp.sqrt((m[0][0] - (m[1][1] + m[2][2])) + wptype(1))
1747
- x = wptype(0.5) * h
1748
- h = wptype(0.5) / h
1749
-
1750
- y = (m[0][1] + m[1][0]) * h
1751
- z = (m[2][0] + m[0][2]) * h
1752
- w = (m[2][1] - m[1][2]) * h
1753
- elif max_diag == 1:
1754
- h = wp.sqrt((m[1][1] - (m[2][2] + m[0][0])) + wptype(1))
1755
- y = wptype(0.5) * h
1756
- h = wptype(0.5) / h
1757
-
1758
- z = (m[1][2] + m[2][1]) * h
1759
- x = (m[0][1] + m[1][0]) * h
1760
- w = (m[0][2] - m[2][0]) * h
1761
- if max_diag == 2:
1762
- h = wp.sqrt((m[2][2] - (m[0][0] + m[1][1])) + wptype(1))
1763
- z = wptype(0.5) * h
1764
- h = wptype(0.5) / h
1765
-
1766
- x = (m[2][0] + m[0][2]) * h
1767
- y = (m[1][2] + m[2][1]) * h
1768
- w = (m[1][0] - m[0][1]) * h
1769
-
1770
- q = wp.normalize(quat(x, y, z, w))
1771
-
1772
- wp.atomic_add(loss, 0, q[idx])
1773
-
1774
- quat_from_matrix = getkernel(quat_from_matrix, suffix=dtype.__name__)
1775
- quat_from_matrix_forward = getkernel(quat_from_matrix_forward, suffix=dtype.__name__)
1776
-
1777
- if register_kernels:
1778
- return
1779
-
1780
- m = np.array(
1781
- [
1782
- [1.0, 0.0, 0.0, 0.0, 0.5, 0.866, 0.0, -0.866, 0.5],
1783
- [0.866, 0.0, 0.25, -0.433, 0.5, 0.75, -0.25, -0.866, 0.433],
1784
- [0.866, -0.433, 0.25, 0.0, 0.5, 0.866, -0.5, -0.75, 0.433],
1785
- [-1.2, -1.6, -2.3, 0.25, -0.6, -0.33, 3.2, -1.0, -2.2],
1786
- ]
1787
- )
1788
- m = wp.array2d(m, dtype=wptype, device=device, requires_grad=True)
1789
-
1790
- N = m.shape[0]
1791
-
1792
- def compute_gradients(kernel, wrt, index):
1793
- loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
1794
- tape = wp.Tape()
1795
-
1796
- with tape:
1797
- wp.launch(kernel=kernel, dim=N, inputs=[m, loss, index], device=device)
1798
-
1799
- tape.backward(loss)
1800
-
1801
- gradients = 1.0 * tape.gradients[wrt].numpy()
1802
- tape.zero()
1803
-
1804
- return loss.numpy()[0], gradients
1805
-
1806
- # gather gradients from builtin adjoints
1807
- cmpx, gradients_x = compute_gradients(quat_from_matrix, m, 0)
1808
- cmpy, gradients_y = compute_gradients(quat_from_matrix, m, 1)
1809
- cmpz, gradients_z = compute_gradients(quat_from_matrix, m, 2)
1810
- cmpw, gradients_w = compute_gradients(quat_from_matrix, m, 3)
1811
-
1812
- # gather gradients from autodiff
1813
- cmpx_auto, gradients_x_auto = compute_gradients(quat_from_matrix_forward, m, 0)
1814
- cmpy_auto, gradients_y_auto = compute_gradients(quat_from_matrix_forward, m, 1)
1815
- cmpz_auto, gradients_z_auto = compute_gradients(quat_from_matrix_forward, m, 2)
1816
- cmpw_auto, gradients_w_auto = compute_gradients(quat_from_matrix_forward, m, 3)
1817
-
1818
- # compare
1819
- eps = 1.0e6
1820
-
1821
- eps = {
1822
- np.float16: 2.0e-2,
1823
- np.float32: 1.0e-5,
1824
- np.float64: 1.0e-8,
1825
- }.get(dtype, 0)
1826
-
1827
- assert_np_equal(cmpx, cmpx_auto, tol=eps)
1828
- assert_np_equal(cmpy, cmpy_auto, tol=eps)
1829
- assert_np_equal(cmpz, cmpz_auto, tol=eps)
1830
- assert_np_equal(cmpw, cmpw_auto, tol=eps)
1831
-
1832
- assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1833
- assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1834
- assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1835
- assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1836
-
1837
-
1838
- def test_quat_identity(test, device, dtype, register_kernels=False):
1839
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1840
-
1841
- def quat_identity_test(output: wp.array(dtype=wptype)):
1842
- q = wp.quat_identity(dtype=wptype)
1843
- output[0] = q[0]
1844
- output[1] = q[1]
1845
- output[2] = q[2]
1846
- output[3] = q[3]
1847
-
1848
- def quat_identity_test_default(output: wp.array(dtype=wp.float32)):
1849
- q = wp.quat_identity()
1850
- output[0] = q[0]
1851
- output[1] = q[1]
1852
- output[2] = q[2]
1853
- output[3] = q[3]
1854
-
1855
- quat_identity_kernel = getkernel(quat_identity_test, suffix=dtype.__name__)
1856
- quat_identity_default_kernel = getkernel(quat_identity_test_default, suffix=np.float32.__name__)
1857
-
1858
- if register_kernels:
1859
- return
1860
-
1861
- output = wp.zeros(4, dtype=wptype, device=device)
1862
- wp.launch(quat_identity_kernel, dim=1, inputs=[], outputs=[output], device=device)
1863
- expected = np.zeros_like(output.numpy())
1864
- expected[3] = 1
1865
- assert_np_equal(output.numpy(), expected)
1866
-
1867
- # let's just test that it defaults to float32:
1868
- output = wp.zeros(4, dtype=wp.float32, device=device)
1869
- wp.launch(quat_identity_default_kernel, dim=1, inputs=[], outputs=[output], device=device)
1870
- expected = np.zeros_like(output.numpy())
1871
- expected[3] = 1
1872
- assert_np_equal(output.numpy(), expected)
1873
-
1874
-
1875
- ############################################################
1876
-
1877
-
1878
- def test_quat_euler_conversion(test, device, dtype, register_kernels=False):
1879
- rng = np.random.default_rng(123)
1880
- N = 3
1881
-
1882
- rpy_arr = rng.uniform(low=-np.pi, high=np.pi, size=(N, 3))
1883
-
1884
- quats_from_euler = [list(wp.sim.quat_from_euler(wp.vec3(*rpy), 0, 1, 2)) for rpy in rpy_arr]
1885
- quats_from_rpy = [list(wp.quat_rpy(rpy[0], rpy[1], rpy[2])) for rpy in rpy_arr]
1886
-
1887
- assert_np_equal(np.array(quats_from_euler), np.array(quats_from_rpy), tol=1e-4)
1888
-
1889
-
1890
- def test_anon_type_instance(test, device, dtype, register_kernels=False):
1891
- rng = np.random.default_rng(123)
1892
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1893
-
1894
- def quat_create_test(input: wp.array(dtype=wptype), output: wp.array(dtype=wptype)):
1895
- # component constructor:
1896
- q = wp.quaternion(input[0], input[1], input[2], input[3])
1897
- output[0] = wptype(2) * q[0]
1898
- output[1] = wptype(2) * q[1]
1899
- output[2] = wptype(2) * q[2]
1900
- output[3] = wptype(2) * q[3]
1901
-
1902
- # vector / scalar constructor:
1903
- q2 = wp.quaternion(wp.vector(input[4], input[5], input[6]), input[7])
1904
- output[4] = wptype(2) * q2[0]
1905
- output[5] = wptype(2) * q2[1]
1906
- output[6] = wptype(2) * q2[2]
1907
- output[7] = wptype(2) * q2[3]
1908
-
1909
- quat_create_kernel = getkernel(quat_create_test, suffix=dtype.__name__)
1910
- output_select_kernel = get_select_kernel(wptype)
1911
-
1912
- if register_kernels:
1913
- return
1914
-
1915
- input = wp.array(rng.standard_normal(size=8).astype(dtype), requires_grad=True, device=device)
1916
- output = wp.zeros(8, dtype=wptype, requires_grad=True, device=device)
1917
- wp.launch(quat_create_kernel, dim=1, inputs=[input], outputs=[output], device=device)
1918
- assert_np_equal(output.numpy(), 2 * input.numpy())
1919
-
1920
- for i in range(len(input)):
1921
- cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1922
- tape = wp.Tape()
1923
- with tape:
1924
- wp.launch(quat_create_kernel, dim=1, inputs=[input], outputs=[output], device=device)
1925
- wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[cmp], device=device)
1926
- tape.backward(loss=cmp)
1927
- expectedgrads = np.zeros(len(input))
1928
- expectedgrads[i] = 2
1929
- assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
1930
- tape.zero()
1931
-
1932
-
1933
- # Same as above but with a default (float) type
1934
- # which tests some different code paths that
1935
- # need to ensure types are correctly canonicalized
1936
- # during codegen
1937
- @wp.kernel
1938
- def test_constructor_default():
1939
- qzero = wp.quat()
1940
- wp.expect_eq(qzero[0], 0.0)
1941
- wp.expect_eq(qzero[1], 0.0)
1942
- wp.expect_eq(qzero[2], 0.0)
1943
- wp.expect_eq(qzero[3], 0.0)
1944
-
1945
- qval = wp.quat(1.0, 2.0, 3.0, 4.0)
1946
- wp.expect_eq(qval[0], 1.0)
1947
- wp.expect_eq(qval[1], 2.0)
1948
- wp.expect_eq(qval[2], 3.0)
1949
- wp.expect_eq(qval[3], 4.0)
1950
-
1951
- qeye = wp.quat_identity()
1952
- wp.expect_eq(qeye[0], 0.0)
1953
- wp.expect_eq(qeye[1], 0.0)
1954
- wp.expect_eq(qeye[2], 0.0)
1955
- wp.expect_eq(qeye[3], 1.0)
1956
-
1957
-
1958
- def test_py_arithmetic_ops(test, device, dtype):
1959
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1960
-
1961
- def make_quat(*args):
1962
- if wptype in wp.types.int_types:
1963
- # Cast to the correct integer type to simulate wrapping.
1964
- return tuple(wptype._type_(x).value for x in args)
1965
-
1966
- return args
1967
-
1968
- quat_cls = wp.types.quaternion(wptype)
1969
-
1970
- v = quat_cls(1, -2, 3, -4)
1971
- test.assertSequenceEqual(+v, make_quat(1, -2, 3, -4))
1972
- test.assertSequenceEqual(-v, make_quat(-1, 2, -3, 4))
1973
- test.assertSequenceEqual(v + quat_cls(5, 5, 5, 5), make_quat(6, 3, 8, 1))
1974
- test.assertSequenceEqual(v - quat_cls(5, 5, 5, 5), make_quat(-4, -7, -2, -9))
1975
-
1976
- v = quat_cls(2, 4, 6, 8)
1977
- test.assertSequenceEqual(v * wptype(2), make_quat(4, 8, 12, 16))
1978
- test.assertSequenceEqual(wptype(2) * v, make_quat(4, 8, 12, 16))
1979
- test.assertSequenceEqual(v / wptype(2), make_quat(1, 2, 3, 4))
1980
- test.assertSequenceEqual(wptype(24) / v, make_quat(12, 6, 4, 3))
1981
-
1982
-
1983
- devices = get_test_devices()
1984
-
1985
-
1986
- class TestQuat(unittest.TestCase):
1987
- pass
1988
-
1989
-
1990
- add_kernel_test(TestQuat, test_constructor_default, dim=1, devices=devices)
1991
-
1992
- for dtype in np_float_types:
1993
- add_function_test_register_kernel(
1994
- TestQuat, f"test_constructors_{dtype.__name__}", test_constructors, devices=devices, dtype=dtype
1995
- )
1996
- add_function_test_register_kernel(
1997
- TestQuat,
1998
- f"test_casting_constructors_{dtype.__name__}",
1999
- test_casting_constructors,
2000
- devices=devices,
2001
- dtype=dtype,
2002
- )
2003
- add_function_test_register_kernel(
2004
- TestQuat, f"test_anon_type_instance_{dtype.__name__}", test_anon_type_instance, devices=devices, dtype=dtype
2005
- )
2006
- add_function_test_register_kernel(
2007
- TestQuat, f"test_inverse_{dtype.__name__}", test_inverse, devices=devices, dtype=dtype
2008
- )
2009
- add_function_test_register_kernel(
2010
- TestQuat, f"test_quat_identity_{dtype.__name__}", test_quat_identity, devices=devices, dtype=dtype
2011
- )
2012
- add_function_test_register_kernel(
2013
- TestQuat, f"test_dotproduct_{dtype.__name__}", test_dotproduct, devices=devices, dtype=dtype
2014
- )
2015
- add_function_test_register_kernel(
2016
- TestQuat, f"test_length_{dtype.__name__}", test_length, devices=devices, dtype=dtype
2017
- )
2018
- add_function_test_register_kernel(
2019
- TestQuat, f"test_normalize_{dtype.__name__}", test_normalize, devices=devices, dtype=dtype
2020
- )
2021
- add_function_test_register_kernel(
2022
- TestQuat, f"test_addition_{dtype.__name__}", test_addition, devices=devices, dtype=dtype
2023
- )
2024
- add_function_test_register_kernel(
2025
- TestQuat, f"test_subtraction_{dtype.__name__}", test_subtraction, devices=devices, dtype=dtype
2026
- )
2027
- add_function_test_register_kernel(
2028
- TestQuat,
2029
- f"test_scalar_multiplication_{dtype.__name__}",
2030
- test_scalar_multiplication,
2031
- devices=devices,
2032
- dtype=dtype,
2033
- )
2034
- add_function_test_register_kernel(
2035
- TestQuat, f"test_scalar_division_{dtype.__name__}", test_scalar_division, devices=devices, dtype=dtype
2036
- )
2037
- add_function_test_register_kernel(
2038
- TestQuat,
2039
- f"test_quat_multiplication_{dtype.__name__}",
2040
- test_quat_multiplication,
2041
- devices=devices,
2042
- dtype=dtype,
2043
- )
2044
- add_function_test_register_kernel(
2045
- TestQuat, f"test_indexing_{dtype.__name__}", test_indexing, devices=devices, dtype=dtype
2046
- )
2047
- add_function_test_register_kernel(
2048
- TestQuat, f"test_quat_lerp_{dtype.__name__}", test_quat_lerp, devices=devices, dtype=dtype
2049
- )
2050
- add_function_test_register_kernel(
2051
- TestQuat,
2052
- f"test_quat_to_axis_angle_grad_{dtype.__name__}",
2053
- test_quat_to_axis_angle_grad,
2054
- devices=devices,
2055
- dtype=dtype,
2056
- )
2057
- add_function_test_register_kernel(
2058
- TestQuat, f"test_slerp_grad_{dtype.__name__}", test_slerp_grad, devices=devices, dtype=dtype
2059
- )
2060
- add_function_test_register_kernel(
2061
- TestQuat, f"test_quat_rpy_grad_{dtype.__name__}", test_quat_rpy_grad, devices=devices, dtype=dtype
2062
- )
2063
- add_function_test_register_kernel(
2064
- TestQuat, f"test_quat_from_matrix_{dtype.__name__}", test_quat_from_matrix, devices=devices, dtype=dtype
2065
- )
2066
- add_function_test_register_kernel(
2067
- TestQuat, f"test_quat_rotate_{dtype.__name__}", test_quat_rotate, devices=devices, dtype=dtype
2068
- )
2069
- add_function_test_register_kernel(
2070
- TestQuat, f"test_quat_to_matrix_{dtype.__name__}", test_quat_to_matrix, devices=devices, dtype=dtype
2071
- )
2072
- add_function_test_register_kernel(
2073
- TestQuat,
2074
- f"test_quat_euler_conversion_{dtype.__name__}",
2075
- test_quat_euler_conversion,
2076
- devices=devices,
2077
- dtype=dtype,
2078
- )
2079
- add_function_test(
2080
- TestQuat, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype
2081
- )
2082
-
2083
-
2084
- if __name__ == "__main__":
2085
- wp.build.clear_kernel_cache()
2086
- unittest.main(verbosity=2)
1
+ # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import unittest
9
+
10
+ import numpy as np
11
+
12
+ import warp as wp
13
+ import warp.sim
14
+ from warp.tests.unittest_utils import *
15
+
16
+ wp.init()
17
+
18
+ np_float_types = [np.float32, np.float64, np.float16]
19
+
20
+ kernel_cache = {}
21
+
22
+
23
+ def getkernel(func, suffix=""):
24
+ key = func.__name__ + "_" + suffix
25
+ if key not in kernel_cache:
26
+ kernel_cache[key] = wp.Kernel(func=func, key=key)
27
+ return kernel_cache[key]
28
+
29
+
30
+ def get_select_kernel(dtype):
31
+ def output_select_kernel_fn(
32
+ input: wp.array(dtype=dtype),
33
+ index: int,
34
+ out: wp.array(dtype=dtype),
35
+ ):
36
+ out[0] = input[index]
37
+
38
+ return getkernel(output_select_kernel_fn, suffix=dtype.__name__)
39
+
40
+
41
+ ############################################################
42
+
43
+
44
+ def test_constructors(test, device, dtype, register_kernels=False):
45
+ rng = np.random.default_rng(123)
46
+
47
+ tol = {
48
+ np.float16: 5.0e-3,
49
+ np.float32: 1.0e-6,
50
+ np.float64: 1.0e-8,
51
+ }.get(dtype, 0)
52
+
53
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
54
+ vec3 = wp.types.vector(length=3, dtype=wptype)
55
+ quat = wp.types.quaternion(dtype=wptype)
56
+
57
+ def check_component_constructor(
58
+ input: wp.array(dtype=wptype),
59
+ q: wp.array(dtype=wptype),
60
+ ):
61
+ qresult = quat(input[0], input[1], input[2], input[3])
62
+
63
+ # multiply the output by 2 so we've got something to backpropagate:
64
+ q[0] = wptype(2) * qresult[0]
65
+ q[1] = wptype(2) * qresult[1]
66
+ q[2] = wptype(2) * qresult[2]
67
+ q[3] = wptype(2) * qresult[3]
68
+
69
+ def check_vector_constructor(
70
+ input: wp.array(dtype=wptype),
71
+ q: wp.array(dtype=wptype),
72
+ ):
73
+ qresult = quat(vec3(input[0], input[1], input[2]), input[3])
74
+
75
+ # multiply the output by 2 so we've got something to backpropagate:
76
+ q[0] = wptype(2) * qresult[0]
77
+ q[1] = wptype(2) * qresult[1]
78
+ q[2] = wptype(2) * qresult[2]
79
+ q[3] = wptype(2) * qresult[3]
80
+
81
+ kernel = getkernel(check_component_constructor, suffix=dtype.__name__)
82
+ output_select_kernel = get_select_kernel(wptype)
83
+ vec_kernel = getkernel(check_vector_constructor, suffix=dtype.__name__)
84
+
85
+ if register_kernels:
86
+ return
87
+
88
+ input = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
89
+ output = wp.zeros_like(input)
90
+ wp.launch(kernel, dim=1, inputs=[input], outputs=[output], device=device)
91
+
92
+ assert_np_equal(output.numpy(), 2 * input.numpy(), tol=tol)
93
+
94
+ for i in range(4):
95
+ cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
96
+ tape = wp.Tape()
97
+ with tape:
98
+ wp.launch(kernel, dim=1, inputs=[input], outputs=[output], device=device)
99
+ wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[cmp], device=device)
100
+ tape.backward(loss=cmp)
101
+ expectedgrads = np.zeros(len(input))
102
+ expectedgrads[i] = 2
103
+ assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
104
+ tape.zero()
105
+
106
+ input = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
107
+ output = wp.zeros_like(input)
108
+ wp.launch(vec_kernel, dim=1, inputs=[input], outputs=[output], device=device)
109
+
110
+ assert_np_equal(output.numpy(), 2 * input.numpy(), tol=tol)
111
+
112
+ for i in range(4):
113
+ cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
114
+ tape = wp.Tape()
115
+ with tape:
116
+ wp.launch(vec_kernel, dim=1, inputs=[input], outputs=[output], device=device)
117
+ wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[cmp], device=device)
118
+ tape.backward(loss=cmp)
119
+ expectedgrads = np.zeros(len(input))
120
+ expectedgrads[i] = 2
121
+ assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
122
+ tape.zero()
123
+
124
+
125
+ def test_casting_constructors(test, device, dtype, register_kernels=False):
126
+ np_type = np.dtype(dtype)
127
+ wp_type = wp.types.np_dtype_to_warp_type[np_type]
128
+ quat = wp.types.quaternion(dtype=wp_type)
129
+
130
+ np16 = np.dtype(np.float16)
131
+ wp16 = wp.types.np_dtype_to_warp_type[np16]
132
+
133
+ np32 = np.dtype(np.float32)
134
+ wp32 = wp.types.np_dtype_to_warp_type[np32]
135
+
136
+ np64 = np.dtype(np.float64)
137
+ wp64 = wp.types.np_dtype_to_warp_type[np64]
138
+
139
+ def cast_float16(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp16, ndim=2)):
140
+ tid = wp.tid()
141
+
142
+ q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
143
+ q2 = wp.quaternion(q1, dtype=wp16)
144
+
145
+ b[tid, 0] = q2[0]
146
+ b[tid, 1] = q2[1]
147
+ b[tid, 2] = q2[2]
148
+ b[tid, 3] = q2[3]
149
+
150
+ def cast_float32(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp32, ndim=2)):
151
+ tid = wp.tid()
152
+
153
+ q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
154
+ q2 = wp.quaternion(q1, dtype=wp32)
155
+
156
+ b[tid, 0] = q2[0]
157
+ b[tid, 1] = q2[1]
158
+ b[tid, 2] = q2[2]
159
+ b[tid, 3] = q2[3]
160
+
161
+ def cast_float64(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp64, ndim=2)):
162
+ tid = wp.tid()
163
+
164
+ q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
165
+ q2 = wp.quaternion(q1, dtype=wp64)
166
+
167
+ b[tid, 0] = q2[0]
168
+ b[tid, 1] = q2[1]
169
+ b[tid, 2] = q2[2]
170
+ b[tid, 3] = q2[3]
171
+
172
+ kernel_16 = getkernel(cast_float16, suffix=dtype.__name__)
173
+ kernel_32 = getkernel(cast_float32, suffix=dtype.__name__)
174
+ kernel_64 = getkernel(cast_float64, suffix=dtype.__name__)
175
+
176
+ if register_kernels:
177
+ return
178
+
179
+ # check casting to float 16
180
+ a = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
181
+ b = wp.array(np.zeros((1, 4), dtype=np16), dtype=wp16, requires_grad=True, device=device)
182
+ b_result = np.ones((1, 4), dtype=np16)
183
+ b_grad = wp.array(np.ones((1, 4), dtype=np16), dtype=wp16, device=device)
184
+ a_grad = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, device=device)
185
+
186
+ tape = wp.Tape()
187
+ with tape:
188
+ wp.launch(kernel=kernel_16, dim=1, inputs=[a, b], device=device)
189
+
190
+ tape.backward(grads={b: b_grad})
191
+ out = tape.gradients[a].numpy()
192
+
193
+ assert_np_equal(b.numpy(), b_result)
194
+ assert_np_equal(out, a_grad.numpy())
195
+
196
+ # check casting to float 32
197
+ a = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
198
+ b = wp.array(np.zeros((1, 4), dtype=np32), dtype=wp32, requires_grad=True, device=device)
199
+ b_result = np.ones((1, 4), dtype=np32)
200
+ b_grad = wp.array(np.ones((1, 4), dtype=np32), dtype=wp32, device=device)
201
+ a_grad = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, device=device)
202
+
203
+ tape = wp.Tape()
204
+ with tape:
205
+ wp.launch(kernel=kernel_32, dim=1, inputs=[a, b], device=device)
206
+
207
+ tape.backward(grads={b: b_grad})
208
+ out = tape.gradients[a].numpy()
209
+
210
+ assert_np_equal(b.numpy(), b_result)
211
+ assert_np_equal(out, a_grad.numpy())
212
+
213
+ # check casting to float 64
214
+ a = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
215
+ b = wp.array(np.zeros((1, 4), dtype=np64), dtype=wp64, requires_grad=True, device=device)
216
+ b_result = np.ones((1, 4), dtype=np64)
217
+ b_grad = wp.array(np.ones((1, 4), dtype=np64), dtype=wp64, device=device)
218
+ a_grad = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, device=device)
219
+
220
+ tape = wp.Tape()
221
+ with tape:
222
+ wp.launch(kernel=kernel_64, dim=1, inputs=[a, b], device=device)
223
+
224
+ tape.backward(grads={b: b_grad})
225
+ out = tape.gradients[a].numpy()
226
+
227
+ assert_np_equal(b.numpy(), b_result)
228
+ assert_np_equal(out, a_grad.numpy())
229
+
230
+
231
+ def test_inverse(test, device, dtype, register_kernels=False):
232
+ rng = np.random.default_rng(123)
233
+
234
+ tol = {
235
+ np.float16: 2.0e-3,
236
+ np.float32: 1.0e-6,
237
+ np.float64: 1.0e-8,
238
+ }.get(dtype, 0)
239
+
240
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
241
+ quat = wp.types.quaternion(dtype=wptype)
242
+
243
+ output_select_kernel = get_select_kernel(wptype)
244
+
245
+ def check_quat_inverse(
246
+ input: wp.array(dtype=wptype),
247
+ shouldbeidentity: wp.array(dtype=quat),
248
+ q: wp.array(dtype=wptype),
249
+ ):
250
+ qread = quat(input[0], input[1], input[2], input[3])
251
+ qresult = wp.quat_inverse(qread)
252
+
253
+ # this inverse should work for normalized quaternions:
254
+ shouldbeidentity[0] = wp.normalize(qread) * wp.quat_inverse(wp.normalize(qread))
255
+
256
+ # multiply the output by 2 so we've got something to backpropagate:
257
+ q[0] = wptype(2) * qresult[0]
258
+ q[1] = wptype(2) * qresult[1]
259
+ q[2] = wptype(2) * qresult[2]
260
+ q[3] = wptype(2) * qresult[3]
261
+
262
+ kernel = getkernel(check_quat_inverse, suffix=dtype.__name__)
263
+
264
+ if register_kernels:
265
+ return
266
+
267
+ input = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
268
+ shouldbeidentity = wp.array(np.zeros((1, 4)), dtype=quat, requires_grad=True, device=device)
269
+ output = wp.zeros_like(input)
270
+ wp.launch(kernel, dim=1, inputs=[input], outputs=[shouldbeidentity, output], device=device)
271
+
272
+ assert_np_equal(shouldbeidentity.numpy(), np.array([0, 0, 0, 1]), tol=tol)
273
+
274
+ for i in range(4):
275
+ cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
276
+ tape = wp.Tape()
277
+ with tape:
278
+ wp.launch(kernel, dim=1, inputs=[input], outputs=[shouldbeidentity, output], device=device)
279
+ wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[cmp], device=device)
280
+ tape.backward(loss=cmp)
281
+ expectedgrads = np.zeros(len(input))
282
+ expectedgrads[i] = -2 if i != 3 else 2
283
+ assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
284
+ tape.zero()
285
+
286
+
287
+ def test_dotproduct(test, device, dtype, register_kernels=False):
288
+ rng = np.random.default_rng(123)
289
+
290
+ tol = {
291
+ np.float16: 1.0e-2,
292
+ np.float32: 1.0e-6,
293
+ np.float64: 1.0e-8,
294
+ }.get(dtype, 0)
295
+
296
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
297
+ quat = wp.types.quaternion(dtype=wptype)
298
+
299
+ def check_quat_dot(
300
+ s: wp.array(dtype=quat),
301
+ v: wp.array(dtype=quat),
302
+ dot: wp.array(dtype=wptype),
303
+ ):
304
+ dot[0] = wptype(2) * wp.dot(v[0], s[0])
305
+
306
+ dotkernel = getkernel(check_quat_dot, suffix=dtype.__name__)
307
+ if register_kernels:
308
+ return
309
+
310
+ s = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
311
+ v = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
312
+ dot = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
313
+
314
+ tape = wp.Tape()
315
+ with tape:
316
+ wp.launch(
317
+ dotkernel,
318
+ dim=1,
319
+ inputs=[
320
+ s,
321
+ v,
322
+ ],
323
+ outputs=[dot],
324
+ device=device,
325
+ )
326
+
327
+ assert_np_equal(dot.numpy()[0], 2.0 * (v.numpy() * s.numpy()).sum(), tol=tol)
328
+
329
+ tape.backward(loss=dot)
330
+ sgrads = tape.gradients[s].numpy()[0]
331
+ expected_grads = 2.0 * v.numpy()[0]
332
+ assert_np_equal(sgrads, expected_grads, tol=10 * tol)
333
+
334
+ vgrads = tape.gradients[v].numpy()[0]
335
+ expected_grads = 2.0 * s.numpy()[0]
336
+ assert_np_equal(vgrads, expected_grads, tol=tol)
337
+
338
+
339
+ def test_length(test, device, dtype, register_kernels=False):
340
+ rng = np.random.default_rng(123)
341
+
342
+ tol = {
343
+ np.float16: 5.0e-3,
344
+ np.float32: 1.0e-6,
345
+ np.float64: 1.0e-7,
346
+ }.get(dtype, 0)
347
+
348
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
349
+ quat = wp.types.quaternion(dtype=wptype)
350
+
351
+ def check_quat_length(
352
+ q: wp.array(dtype=quat),
353
+ l: wp.array(dtype=wptype),
354
+ l2: wp.array(dtype=wptype),
355
+ ):
356
+ l[0] = wptype(2) * wp.length(q[0])
357
+ l2[0] = wptype(2) * wp.length_sq(q[0])
358
+
359
+ kernel = getkernel(check_quat_length, suffix=dtype.__name__)
360
+
361
+ if register_kernels:
362
+ return
363
+
364
+ q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
365
+ l = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
366
+ l2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
367
+
368
+ tape = wp.Tape()
369
+ with tape:
370
+ wp.launch(
371
+ kernel,
372
+ dim=1,
373
+ inputs=[
374
+ q,
375
+ ],
376
+ outputs=[l, l2],
377
+ device=device,
378
+ )
379
+
380
+ assert_np_equal(l.numpy()[0], 2 * np.linalg.norm(q.numpy()), tol=10 * tol)
381
+ assert_np_equal(l2.numpy()[0], 2 * np.linalg.norm(q.numpy()) ** 2, tol=10 * tol)
382
+
383
+ tape.backward(loss=l)
384
+ grad = tape.gradients[q].numpy()[0]
385
+ expected_grad = 2 * q.numpy()[0] / np.linalg.norm(q.numpy())
386
+ assert_np_equal(grad, expected_grad, tol=10 * tol)
387
+ tape.zero()
388
+
389
+ tape.backward(loss=l2)
390
+ grad = tape.gradients[q].numpy()[0]
391
+ expected_grad = 4 * q.numpy()[0]
392
+ assert_np_equal(grad, expected_grad, tol=10 * tol)
393
+ tape.zero()
394
+
395
+
396
+ def test_normalize(test, device, dtype, register_kernels=False):
397
+ rng = np.random.default_rng(123)
398
+
399
+ tol = {
400
+ np.float16: 5.0e-3,
401
+ np.float32: 1.0e-6,
402
+ np.float64: 1.0e-8,
403
+ }.get(dtype, 0)
404
+
405
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
406
+ quat = wp.types.quaternion(dtype=wptype)
407
+
408
+ def check_normalize(
409
+ q: wp.array(dtype=quat),
410
+ n0: wp.array(dtype=wptype),
411
+ n1: wp.array(dtype=wptype),
412
+ n2: wp.array(dtype=wptype),
413
+ n3: wp.array(dtype=wptype),
414
+ ):
415
+ n = wptype(2) * (wp.normalize(q[0]))
416
+
417
+ n0[0] = n[0]
418
+ n1[0] = n[1]
419
+ n2[0] = n[2]
420
+ n3[0] = n[3]
421
+
422
+ def check_normalize_alt(
423
+ q: wp.array(dtype=quat),
424
+ n0: wp.array(dtype=wptype),
425
+ n1: wp.array(dtype=wptype),
426
+ n2: wp.array(dtype=wptype),
427
+ n3: wp.array(dtype=wptype),
428
+ ):
429
+ n = wptype(2) * (q[0] / wp.length(q[0]))
430
+
431
+ n0[0] = n[0]
432
+ n1[0] = n[1]
433
+ n2[0] = n[2]
434
+ n3[0] = n[3]
435
+
436
+ normalize_kernel = getkernel(check_normalize, suffix=dtype.__name__)
437
+ normalize_alt_kernel = getkernel(check_normalize_alt, suffix=dtype.__name__)
438
+
439
+ if register_kernels:
440
+ return
441
+
442
+ # I've already tested the things I'm using in check_normalize_alt, so I'll just
443
+ # make sure the two are giving the same results/gradients
444
+ q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
445
+
446
+ n0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
447
+ n1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
448
+ n2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
449
+ n3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
450
+
451
+ n0_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
452
+ n1_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
453
+ n2_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
454
+ n3_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
455
+
456
+ outputs0 = [
457
+ n0,
458
+ n1,
459
+ n2,
460
+ n3,
461
+ ]
462
+ tape0 = wp.Tape()
463
+ with tape0:
464
+ wp.launch(normalize_kernel, dim=1, inputs=[q], outputs=outputs0, device=device)
465
+
466
+ outputs1 = [
467
+ n0_alt,
468
+ n1_alt,
469
+ n2_alt,
470
+ n3_alt,
471
+ ]
472
+ tape1 = wp.Tape()
473
+ with tape1:
474
+ wp.launch(
475
+ normalize_alt_kernel,
476
+ dim=1,
477
+ inputs=[
478
+ q,
479
+ ],
480
+ outputs=outputs1,
481
+ device=device,
482
+ )
483
+
484
+ assert_np_equal(n0.numpy()[0], n0_alt.numpy()[0], tol=tol)
485
+ assert_np_equal(n1.numpy()[0], n1_alt.numpy()[0], tol=tol)
486
+ assert_np_equal(n2.numpy()[0], n2_alt.numpy()[0], tol=tol)
487
+ assert_np_equal(n3.numpy()[0], n3_alt.numpy()[0], tol=tol)
488
+
489
+ for ncmp, ncmpalt in zip(outputs0, outputs1):
490
+ tape0.backward(loss=ncmp)
491
+ tape1.backward(loss=ncmpalt)
492
+ assert_np_equal(tape0.gradients[q].numpy()[0], tape1.gradients[q].numpy()[0], tol=tol)
493
+ tape0.zero()
494
+ tape1.zero()
495
+
496
+
497
+ def test_addition(test, device, dtype, register_kernels=False):
498
+ rng = np.random.default_rng(123)
499
+
500
+ tol = {
501
+ np.float16: 5.0e-3,
502
+ np.float32: 1.0e-6,
503
+ np.float64: 1.0e-8,
504
+ }.get(dtype, 0)
505
+
506
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
507
+ quat = wp.types.quaternion(dtype=wptype)
508
+
509
+ def check_quat_add(
510
+ q: wp.array(dtype=quat),
511
+ v: wp.array(dtype=quat),
512
+ r0: wp.array(dtype=wptype),
513
+ r1: wp.array(dtype=wptype),
514
+ r2: wp.array(dtype=wptype),
515
+ r3: wp.array(dtype=wptype),
516
+ ):
517
+ result = q[0] + v[0]
518
+
519
+ r0[0] = wptype(2) * result[0]
520
+ r1[0] = wptype(2) * result[1]
521
+ r2[0] = wptype(2) * result[2]
522
+ r3[0] = wptype(2) * result[3]
523
+
524
+ kernel = getkernel(check_quat_add, suffix=dtype.__name__)
525
+
526
+ if register_kernels:
527
+ return
528
+
529
+ q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
530
+ v = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
531
+
532
+ r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
533
+ r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
534
+ r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
535
+ r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
536
+
537
+ tape = wp.Tape()
538
+ with tape:
539
+ wp.launch(
540
+ kernel,
541
+ dim=1,
542
+ inputs=[
543
+ q,
544
+ v,
545
+ ],
546
+ outputs=[r0, r1, r2, r3],
547
+ device=device,
548
+ )
549
+
550
+ assert_np_equal(r0.numpy()[0], 2 * (v.numpy()[0, 0] + q.numpy()[0, 0]), tol=tol)
551
+ assert_np_equal(r1.numpy()[0], 2 * (v.numpy()[0, 1] + q.numpy()[0, 1]), tol=tol)
552
+ assert_np_equal(r2.numpy()[0], 2 * (v.numpy()[0, 2] + q.numpy()[0, 2]), tol=tol)
553
+ assert_np_equal(r3.numpy()[0], 2 * (v.numpy()[0, 3] + q.numpy()[0, 3]), tol=tol)
554
+
555
+ for i, l in enumerate([r0, r1, r2, r3]):
556
+ tape.backward(loss=l)
557
+ qgrads = tape.gradients[q].numpy()[0]
558
+ expected_grads = np.zeros_like(qgrads)
559
+
560
+ expected_grads[i] = 2
561
+ assert_np_equal(qgrads, expected_grads, tol=10 * tol)
562
+
563
+ vgrads = tape.gradients[v].numpy()[0]
564
+ assert_np_equal(vgrads, expected_grads, tol=tol)
565
+
566
+ tape.zero()
567
+
568
+
569
+ def test_subtraction(test, device, dtype, register_kernels=False):
570
+ rng = np.random.default_rng(123)
571
+
572
+ tol = {
573
+ np.float16: 5.0e-3,
574
+ np.float32: 1.0e-6,
575
+ np.float64: 1.0e-8,
576
+ }.get(dtype, 0)
577
+
578
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
579
+ quat = wp.types.quaternion(dtype=wptype)
580
+
581
+ def check_quat_sub(
582
+ q: wp.array(dtype=quat),
583
+ v: wp.array(dtype=quat),
584
+ r0: wp.array(dtype=wptype),
585
+ r1: wp.array(dtype=wptype),
586
+ r2: wp.array(dtype=wptype),
587
+ r3: wp.array(dtype=wptype),
588
+ ):
589
+ result = v[0] - q[0]
590
+
591
+ r0[0] = wptype(2) * result[0]
592
+ r1[0] = wptype(2) * result[1]
593
+ r2[0] = wptype(2) * result[2]
594
+ r3[0] = wptype(2) * result[3]
595
+
596
+ kernel = getkernel(check_quat_sub, suffix=dtype.__name__)
597
+
598
+ if register_kernels:
599
+ return
600
+
601
+ q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
602
+ v = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
603
+
604
+ r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
605
+ r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
606
+ r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
607
+ r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
608
+
609
+ tape = wp.Tape()
610
+ with tape:
611
+ wp.launch(
612
+ kernel,
613
+ dim=1,
614
+ inputs=[
615
+ q,
616
+ v,
617
+ ],
618
+ outputs=[r0, r1, r2, r3],
619
+ device=device,
620
+ )
621
+
622
+ assert_np_equal(r0.numpy()[0], 2 * (v.numpy()[0, 0] - q.numpy()[0, 0]), tol=tol)
623
+ assert_np_equal(r1.numpy()[0], 2 * (v.numpy()[0, 1] - q.numpy()[0, 1]), tol=tol)
624
+ assert_np_equal(r2.numpy()[0], 2 * (v.numpy()[0, 2] - q.numpy()[0, 2]), tol=tol)
625
+ assert_np_equal(r3.numpy()[0], 2 * (v.numpy()[0, 3] - q.numpy()[0, 3]), tol=tol)
626
+
627
+ for i, l in enumerate([r0, r1, r2, r3]):
628
+ tape.backward(loss=l)
629
+ qgrads = tape.gradients[q].numpy()[0]
630
+ expected_grads = np.zeros_like(qgrads)
631
+
632
+ expected_grads[i] = -2
633
+ assert_np_equal(qgrads, expected_grads, tol=10 * tol)
634
+
635
+ vgrads = tape.gradients[v].numpy()[0]
636
+ expected_grads[i] = 2
637
+ assert_np_equal(vgrads, expected_grads, tol=tol)
638
+
639
+ tape.zero()
640
+
641
+
642
+ def test_scalar_multiplication(test, device, dtype, register_kernels=False):
643
+ rng = np.random.default_rng(123)
644
+
645
+ tol = {
646
+ np.float16: 5.0e-3,
647
+ np.float32: 1.0e-6,
648
+ np.float64: 1.0e-8,
649
+ }.get(dtype, 0)
650
+
651
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
652
+ quat = wp.types.quaternion(dtype=wptype)
653
+
654
+ def check_quat_scalar_mul(
655
+ s: wp.array(dtype=wptype),
656
+ q: wp.array(dtype=quat),
657
+ l0: wp.array(dtype=wptype),
658
+ l1: wp.array(dtype=wptype),
659
+ l2: wp.array(dtype=wptype),
660
+ l3: wp.array(dtype=wptype),
661
+ r0: wp.array(dtype=wptype),
662
+ r1: wp.array(dtype=wptype),
663
+ r2: wp.array(dtype=wptype),
664
+ r3: wp.array(dtype=wptype),
665
+ ):
666
+ lresult = s[0] * q[0]
667
+ rresult = q[0] * s[0]
668
+
669
+ # multiply outputs by 2 so we've got something to backpropagate:
670
+ l0[0] = wptype(2) * lresult[0]
671
+ l1[0] = wptype(2) * lresult[1]
672
+ l2[0] = wptype(2) * lresult[2]
673
+ l3[0] = wptype(2) * lresult[3]
674
+
675
+ r0[0] = wptype(2) * rresult[0]
676
+ r1[0] = wptype(2) * rresult[1]
677
+ r2[0] = wptype(2) * rresult[2]
678
+ r3[0] = wptype(2) * rresult[3]
679
+
680
+ kernel = getkernel(check_quat_scalar_mul, suffix=dtype.__name__)
681
+
682
+ if register_kernels:
683
+ return
684
+
685
+ s = wp.array(rng.standard_normal(size=1).astype(dtype), requires_grad=True, device=device)
686
+ q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
687
+
688
+ l0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
689
+ l1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
690
+ l2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
691
+ l3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
692
+
693
+ r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
694
+ r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
695
+ r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
696
+ r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
697
+
698
+ tape = wp.Tape()
699
+ with tape:
700
+ wp.launch(
701
+ kernel,
702
+ dim=1,
703
+ inputs=[s, q],
704
+ outputs=[
705
+ l0,
706
+ l1,
707
+ l2,
708
+ l3,
709
+ r0,
710
+ r1,
711
+ r2,
712
+ r3,
713
+ ],
714
+ device=device,
715
+ )
716
+
717
+ assert_np_equal(l0.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 0], tol=tol)
718
+ assert_np_equal(l1.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 1], tol=tol)
719
+ assert_np_equal(l2.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 2], tol=tol)
720
+ assert_np_equal(l3.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 3], tol=tol)
721
+
722
+ assert_np_equal(r0.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 0], tol=tol)
723
+ assert_np_equal(r1.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 1], tol=tol)
724
+ assert_np_equal(r2.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 2], tol=tol)
725
+ assert_np_equal(r3.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 3], tol=tol)
726
+
727
+ if dtype in np_float_types:
728
+ for i, outputs in enumerate([(l0, r0), (l1, r1), (l2, r2), (l3, r3)]):
729
+ for l in outputs:
730
+ tape.backward(loss=l)
731
+ sgrad = tape.gradients[s].numpy()[0]
732
+ assert_np_equal(sgrad, 2 * q.numpy()[0, i], tol=tol)
733
+ allgrads = tape.gradients[q].numpy()[0]
734
+ expected_grads = np.zeros_like(allgrads)
735
+ expected_grads[i] = s.numpy()[0] * 2
736
+ assert_np_equal(allgrads, expected_grads, tol=10 * tol)
737
+ tape.zero()
738
+
739
+
740
+ def test_scalar_division(test, device, dtype, register_kernels=False):
741
+ rng = np.random.default_rng(123)
742
+
743
+ tol = {
744
+ np.float16: 1.0e-3,
745
+ np.float32: 1.0e-6,
746
+ np.float64: 1.0e-8,
747
+ }.get(dtype, 0)
748
+
749
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
750
+ quat = wp.types.quaternion(dtype=wptype)
751
+
752
+ def check_quat_scalar_div(
753
+ s: wp.array(dtype=wptype),
754
+ q: wp.array(dtype=quat),
755
+ r0: wp.array(dtype=wptype),
756
+ r1: wp.array(dtype=wptype),
757
+ r2: wp.array(dtype=wptype),
758
+ r3: wp.array(dtype=wptype),
759
+ ):
760
+ result = q[0] / s[0]
761
+
762
+ # multiply outputs by 2 so we've got something to backpropagate:
763
+ r0[0] = wptype(2) * result[0]
764
+ r1[0] = wptype(2) * result[1]
765
+ r2[0] = wptype(2) * result[2]
766
+ r3[0] = wptype(2) * result[3]
767
+
768
+ kernel = getkernel(check_quat_scalar_div, suffix=dtype.__name__)
769
+
770
+ if register_kernels:
771
+ return
772
+
773
+ s = wp.array(rng.standard_normal(size=1).astype(dtype), requires_grad=True, device=device)
774
+ q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
775
+
776
+ r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
777
+ r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
778
+ r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
779
+ r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
780
+
781
+ tape = wp.Tape()
782
+ with tape:
783
+ wp.launch(
784
+ kernel,
785
+ dim=1,
786
+ inputs=[s, q],
787
+ outputs=[
788
+ r0,
789
+ r1,
790
+ r2,
791
+ r3,
792
+ ],
793
+ device=device,
794
+ )
795
+ assert_np_equal(r0.numpy()[0], 2 * q.numpy()[0, 0] / s.numpy()[0], tol=tol)
796
+ assert_np_equal(r1.numpy()[0], 2 * q.numpy()[0, 1] / s.numpy()[0], tol=tol)
797
+ assert_np_equal(r2.numpy()[0], 2 * q.numpy()[0, 2] / s.numpy()[0], tol=tol)
798
+ assert_np_equal(r3.numpy()[0], 2 * q.numpy()[0, 3] / s.numpy()[0], tol=tol)
799
+
800
+ if dtype in np_float_types:
801
+ for i, r in enumerate([r0, r1, r2, r3]):
802
+ tape.backward(loss=r)
803
+ sgrad = tape.gradients[s].numpy()[0]
804
+ assert_np_equal(sgrad, -2 * q.numpy()[0, i] / (s.numpy()[0] * s.numpy()[0]), tol=tol)
805
+
806
+ allgrads = tape.gradients[q].numpy()[0]
807
+ expected_grads = np.zeros_like(allgrads)
808
+ expected_grads[i] = 2 / s.numpy()[0]
809
+ assert_np_equal(allgrads, expected_grads, tol=10 * tol)
810
+ tape.zero()
811
+
812
+
813
+ def test_quat_multiplication(test, device, dtype, register_kernels=False):
814
+ rng = np.random.default_rng(123)
815
+
816
+ tol = {
817
+ np.float16: 1.0e-2,
818
+ np.float32: 1.0e-6,
819
+ np.float64: 1.0e-8,
820
+ }.get(dtype, 0)
821
+
822
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
823
+ quat = wp.types.quaternion(dtype=wptype)
824
+
825
+ def check_quat_mul(
826
+ s: wp.array(dtype=quat),
827
+ q: wp.array(dtype=quat),
828
+ r0: wp.array(dtype=wptype),
829
+ r1: wp.array(dtype=wptype),
830
+ r2: wp.array(dtype=wptype),
831
+ r3: wp.array(dtype=wptype),
832
+ ):
833
+ result = s[0] * q[0]
834
+
835
+ # multiply outputs by 2 so we've got something to backpropagate:
836
+ r0[0] = wptype(2) * result[0]
837
+ r1[0] = wptype(2) * result[1]
838
+ r2[0] = wptype(2) * result[2]
839
+ r3[0] = wptype(2) * result[3]
840
+
841
+ kernel = getkernel(check_quat_mul, suffix=dtype.__name__)
842
+
843
+ if register_kernels:
844
+ return
845
+
846
+ s = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
847
+ q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
848
+
849
+ r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
850
+ r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
851
+ r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
852
+ r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
853
+
854
+ tape = wp.Tape()
855
+ with tape:
856
+ wp.launch(
857
+ kernel,
858
+ dim=1,
859
+ inputs=[s, q],
860
+ outputs=[
861
+ r0,
862
+ r1,
863
+ r2,
864
+ r3,
865
+ ],
866
+ device=device,
867
+ )
868
+
869
+ a = s.numpy()
870
+ b = q.numpy()
871
+ assert_np_equal(
872
+ r0.numpy()[0], 2 * (a[0, 3] * b[0, 0] + b[0, 3] * a[0, 0] + a[0, 1] * b[0, 2] - b[0, 1] * a[0, 2]), tol=tol
873
+ )
874
+ assert_np_equal(
875
+ r1.numpy()[0], 2 * (a[0, 3] * b[0, 1] + b[0, 3] * a[0, 1] + a[0, 2] * b[0, 0] - b[0, 2] * a[0, 0]), tol=tol
876
+ )
877
+ assert_np_equal(
878
+ r2.numpy()[0], 2 * (a[0, 3] * b[0, 2] + b[0, 3] * a[0, 2] + a[0, 0] * b[0, 1] - b[0, 0] * a[0, 1]), tol=tol
879
+ )
880
+ assert_np_equal(
881
+ r3.numpy()[0], 2 * (a[0, 3] * b[0, 3] - a[0, 0] * b[0, 0] - a[0, 1] * b[0, 1] - a[0, 2] * b[0, 2]), tol=tol
882
+ )
883
+
884
+ tape.backward(loss=r0)
885
+ agrad = tape.gradients[s].numpy()[0]
886
+ assert_np_equal(agrad, 2 * np.array([b[0, 3], b[0, 2], -b[0, 1], b[0, 0]]), tol=tol)
887
+
888
+ bgrad = tape.gradients[q].numpy()[0]
889
+ assert_np_equal(bgrad, 2 * np.array([a[0, 3], -a[0, 2], a[0, 1], a[0, 0]]), tol=tol)
890
+ tape.zero()
891
+
892
+ tape.backward(loss=r1)
893
+ agrad = tape.gradients[s].numpy()[0]
894
+ assert_np_equal(agrad, 2 * np.array([-b[0, 2], b[0, 3], b[0, 0], b[0, 1]]), tol=tol)
895
+
896
+ bgrad = tape.gradients[q].numpy()[0]
897
+ assert_np_equal(bgrad, 2 * np.array([a[0, 2], a[0, 3], -a[0, 0], a[0, 1]]), tol=tol)
898
+ tape.zero()
899
+
900
+ tape.backward(loss=r2)
901
+ agrad = tape.gradients[s].numpy()[0]
902
+ assert_np_equal(agrad, 2 * np.array([b[0, 1], -b[0, 0], b[0, 3], b[0, 2]]), tol=tol)
903
+
904
+ bgrad = tape.gradients[q].numpy()[0]
905
+ assert_np_equal(bgrad, 2 * np.array([-a[0, 1], a[0, 0], a[0, 3], a[0, 2]]), tol=tol)
906
+ tape.zero()
907
+
908
+ tape.backward(loss=r3)
909
+ agrad = tape.gradients[s].numpy()[0]
910
+ assert_np_equal(agrad, 2 * np.array([-b[0, 0], -b[0, 1], -b[0, 2], b[0, 3]]), tol=tol)
911
+
912
+ bgrad = tape.gradients[q].numpy()[0]
913
+ assert_np_equal(bgrad, 2 * np.array([-a[0, 0], -a[0, 1], -a[0, 2], a[0, 3]]), tol=tol)
914
+ tape.zero()
915
+
916
+
917
+ def test_indexing(test, device, dtype, register_kernels=False):
918
+ rng = np.random.default_rng(123)
919
+
920
+ tol = {
921
+ np.float16: 5.0e-3,
922
+ np.float32: 1.0e-6,
923
+ np.float64: 1.0e-8,
924
+ }.get(dtype, 0)
925
+
926
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
927
+ quat = wp.types.quaternion(dtype=wptype)
928
+
929
+ def check_quat_indexing(
930
+ q: wp.array(dtype=quat),
931
+ r0: wp.array(dtype=wptype),
932
+ r1: wp.array(dtype=wptype),
933
+ r2: wp.array(dtype=wptype),
934
+ r3: wp.array(dtype=wptype),
935
+ ):
936
+ # multiply outputs by 2 so we've got something to backpropagate:
937
+ r0[0] = wptype(2) * q[0][0]
938
+ r1[0] = wptype(2) * q[0][1]
939
+ r2[0] = wptype(2) * q[0][2]
940
+ r3[0] = wptype(2) * q[0][3]
941
+
942
+ kernel = getkernel(check_quat_indexing, suffix=dtype.__name__)
943
+
944
+ if register_kernels:
945
+ return
946
+
947
+ q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
948
+ r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
949
+ r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
950
+ r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
951
+ r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
952
+
953
+ tape = wp.Tape()
954
+ with tape:
955
+ wp.launch(kernel, dim=1, inputs=[q], outputs=[r0, r1, r2, r3], device=device)
956
+
957
+ for i, l in enumerate([r0, r1, r2, r3]):
958
+ tape.backward(loss=l)
959
+ allgrads = tape.gradients[q].numpy()[0]
960
+ expected_grads = np.zeros_like(allgrads)
961
+ expected_grads[i] = 2
962
+ assert_np_equal(allgrads, expected_grads, tol=tol)
963
+ tape.zero()
964
+
965
+ assert_np_equal(r0.numpy()[0], 2.0 * q.numpy()[0, 0], tol=tol)
966
+ assert_np_equal(r1.numpy()[0], 2.0 * q.numpy()[0, 1], tol=tol)
967
+ assert_np_equal(r2.numpy()[0], 2.0 * q.numpy()[0, 2], tol=tol)
968
+ assert_np_equal(r3.numpy()[0], 2.0 * q.numpy()[0, 3], tol=tol)
969
+
970
+
971
+ def test_quat_lerp(test, device, dtype, register_kernels=False):
972
+ rng = np.random.default_rng(123)
973
+
974
+ tol = {
975
+ np.float16: 1.0e-2,
976
+ np.float32: 1.0e-6,
977
+ np.float64: 1.0e-8,
978
+ }.get(dtype, 0)
979
+
980
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
981
+ quat = wp.types.quaternion(dtype=wptype)
982
+
983
+ def check_quat_lerp(
984
+ s: wp.array(dtype=quat),
985
+ q: wp.array(dtype=quat),
986
+ t: wp.array(dtype=wptype),
987
+ r0: wp.array(dtype=wptype),
988
+ r1: wp.array(dtype=wptype),
989
+ r2: wp.array(dtype=wptype),
990
+ r3: wp.array(dtype=wptype),
991
+ ):
992
+ result = wp.lerp(s[0], q[0], t[0])
993
+
994
+ # multiply outputs by 2 so we've got something to backpropagate:
995
+ r0[0] = wptype(2) * result[0]
996
+ r1[0] = wptype(2) * result[1]
997
+ r2[0] = wptype(2) * result[2]
998
+ r3[0] = wptype(2) * result[3]
999
+
1000
+ kernel = getkernel(check_quat_lerp, suffix=dtype.__name__)
1001
+
1002
+ if register_kernels:
1003
+ return
1004
+
1005
+ s = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
1006
+ q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
1007
+ t = wp.array(rng.uniform(size=1).astype(dtype), dtype=wptype, requires_grad=True, device=device)
1008
+
1009
+ r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1010
+ r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1011
+ r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1012
+ r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1013
+
1014
+ tape = wp.Tape()
1015
+ with tape:
1016
+ wp.launch(
1017
+ kernel,
1018
+ dim=1,
1019
+ inputs=[s, q, t],
1020
+ outputs=[
1021
+ r0,
1022
+ r1,
1023
+ r2,
1024
+ r3,
1025
+ ],
1026
+ device=device,
1027
+ )
1028
+
1029
+ a = s.numpy()
1030
+ b = q.numpy()
1031
+ tt = t.numpy()
1032
+ assert_np_equal(r0.numpy()[0], 2 * ((1 - tt) * a[0, 0] + tt * b[0, 0]), tol=tol)
1033
+ assert_np_equal(r1.numpy()[0], 2 * ((1 - tt) * a[0, 1] + tt * b[0, 1]), tol=tol)
1034
+ assert_np_equal(r2.numpy()[0], 2 * ((1 - tt) * a[0, 2] + tt * b[0, 2]), tol=tol)
1035
+ assert_np_equal(r3.numpy()[0], 2 * ((1 - tt) * a[0, 3] + tt * b[0, 3]), tol=tol)
1036
+
1037
+ for i, l in enumerate([r0, r1, r2, r3]):
1038
+ tape.backward(loss=l)
1039
+ agrad = tape.gradients[s].numpy()[0]
1040
+ bgrad = tape.gradients[q].numpy()[0]
1041
+ tgrad = tape.gradients[t].numpy()[0]
1042
+ expected_grads = np.zeros_like(agrad)
1043
+ expected_grads[i] = 2 * (1 - tt)
1044
+ assert_np_equal(agrad, expected_grads, tol=tol)
1045
+ expected_grads[i] = 2 * tt
1046
+ assert_np_equal(bgrad, expected_grads, tol=tol)
1047
+ assert_np_equal(tgrad, 2 * (b[0, i] - a[0, i]), tol=tol)
1048
+
1049
+ tape.zero()
1050
+
1051
+
1052
def test_quat_rotate(test, device, dtype, register_kernels=False):
    """Check wp.quat_rotate / wp.quat_rotate_inv against a manual expansion of
    the quaternion rotation formula, comparing both values and gradients."""
    rng = np.random.default_rng(123)

    # per-dtype absolute tolerances; dtypes not listed compare exactly (tol=0)
    tol = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp.types.quaternion(dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)

    def check_quat_rotate(
        q: wp.array(dtype=quat),
        v: wp.array(dtype=vec3),
        outputs: wp.array(dtype=wptype),
        outputs_inv: wp.array(dtype=wptype),
        outputs_manual: wp.array(dtype=wptype),
        outputs_inv_manual: wp.array(dtype=wptype),
    ):
        result = wp.quat_rotate(q[0], v[0])
        result_inv = wp.quat_rotate_inv(q[0], v[0])

        # manual rotation: v' = v*(2*w^2 - 1) + 2*w*(qv x v) + 2*qv*(qv . v)
        qv = vec3(q[0][0], q[0][1], q[0][2])
        qw = q[0][3]

        result_manual = v[0] * (wptype(2) * qw * qw - wptype(1))
        result_manual += wp.cross(qv, v[0]) * qw * wptype(2)
        result_manual += qv * wp.dot(qv, v[0]) * wptype(2)

        # inverse rotation only flips the sign of the cross-product term
        result_inv_manual = v[0] * (wptype(2) * qw * qw - wptype(1))
        result_inv_manual -= wp.cross(qv, v[0]) * qw * wptype(2)
        result_inv_manual += qv * wp.dot(qv, v[0]) * wptype(2)

        for i in range(3):
            # multiply outputs by 2 so we've got something to backpropagate:
            outputs[i] = wptype(2) * result[i]
            outputs_inv[i] = wptype(2) * result_inv[i]
            outputs_manual[i] = wptype(2) * result_manual[i]
            outputs_inv_manual[i] = wptype(2) * result_inv_manual[i]

    kernel = getkernel(check_quat_rotate, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel(wptype)

    if register_kernels:
        return

    # random unit quaternion and a random vector to rotate
    q = rng.standard_normal(size=(1, 4))
    q /= np.linalg.norm(q)
    q = wp.array(q.astype(dtype), dtype=quat, requires_grad=True, device=device)
    v = wp.array(0.5 * rng.standard_normal(size=(1, 3)).astype(dtype), dtype=vec3, requires_grad=True, device=device)

    # test values against the manually computed result:
    outputs = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
    outputs_inv = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
    outputs_manual = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
    outputs_inv_manual = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)

    wp.launch(
        kernel,
        dim=1,
        inputs=[q, v],
        outputs=[
            outputs,
            outputs_inv,
            outputs_manual,
            outputs_inv_manual,
        ],
        device=device,
    )

    assert_np_equal(outputs.numpy(), outputs_manual.numpy(), tol=tol)
    assert_np_equal(outputs_inv.numpy(), outputs_inv_manual.numpy(), tol=tol)

    # test gradients against the manually computed result:
    for i in range(3):
        cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        cmp_inv = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        cmp_manual = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        cmp_inv_manual = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        tape = wp.Tape()
        with tape:
            wp.launch(
                kernel,
                dim=1,
                inputs=[q, v],
                outputs=[
                    outputs,
                    outputs_inv,
                    outputs_manual,
                    outputs_inv_manual,
                ],
                device=device,
            )
            wp.launch(output_select_kernel, dim=1, inputs=[outputs, i], outputs=[cmp], device=device)
            wp.launch(output_select_kernel, dim=1, inputs=[outputs_inv, i], outputs=[cmp_inv], device=device)
            wp.launch(output_select_kernel, dim=1, inputs=[outputs_manual, i], outputs=[cmp_manual], device=device)
            wp.launch(
                output_select_kernel, dim=1, inputs=[outputs_inv_manual, i], outputs=[cmp_inv_manual], device=device
            )

        # multiply by 1.0 to copy each gradient array before tape.zero() clears it
        tape.backward(loss=cmp)
        qgrads = 1.0 * tape.gradients[q].numpy()
        vgrads = 1.0 * tape.gradients[v].numpy()
        tape.zero()
        tape.backward(loss=cmp_inv)
        qgrads_inv = 1.0 * tape.gradients[q].numpy()
        vgrads_inv = 1.0 * tape.gradients[v].numpy()
        tape.zero()
        tape.backward(loss=cmp_manual)
        qgrads_manual = 1.0 * tape.gradients[q].numpy()
        vgrads_manual = 1.0 * tape.gradients[v].numpy()
        tape.zero()
        tape.backward(loss=cmp_inv_manual)
        qgrads_inv_manual = 1.0 * tape.gradients[q].numpy()
        vgrads_inv_manual = 1.0 * tape.gradients[v].numpy()
        tape.zero()

        assert_np_equal(qgrads, qgrads_manual, tol=tol)
        assert_np_equal(vgrads, vgrads_manual, tol=tol)

        assert_np_equal(qgrads_inv, qgrads_inv_manual, tol=tol)
        assert_np_equal(vgrads_inv, vgrads_inv_manual, tol=tol)
1176
+
1177
+
1178
def test_quat_to_matrix(test, device, dtype, register_kernels=False):
    """Check wp.quat_to_matrix against a matrix assembled from the rotated
    basis axes, comparing forward values, orthonormality, and gradients."""
    rng = np.random.default_rng(123)

    # per-dtype absolute tolerances; dtypes not listed compare exactly (tol=0)
    tol = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp.types.quaternion(dtype=wptype)
    mat3 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)

    def check_quat_to_matrix(
        q: wp.array(dtype=quat),
        outputs: wp.array(dtype=wptype),
        outputs_manual: wp.array(dtype=wptype),
    ):
        result = wp.quat_to_matrix(q[0])

        # reference: build the matrix from the rotated unit basis vectors
        xaxis = wp.quat_rotate(
            q[0],
            vec3(
                wptype(1),
                wptype(0),
                wptype(0),
            ),
        )
        yaxis = wp.quat_rotate(
            q[0],
            vec3(
                wptype(0),
                wptype(1),
                wptype(0),
            ),
        )
        zaxis = wp.quat_rotate(
            q[0],
            vec3(
                wptype(0),
                wptype(0),
                wptype(1),
            ),
        )
        result_manual = mat3(xaxis, yaxis, zaxis)

        idx = 0
        for i in range(3):
            for j in range(3):
                # multiply outputs by 2 so we've got something to backpropagate:
                outputs[idx] = wptype(2) * result[i, j]
                outputs_manual[idx] = wptype(2) * result_manual[i, j]

                idx = idx + 1

    kernel = getkernel(check_quat_to_matrix, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel(wptype)

    if register_kernels:
        return

    # random unit quaternion input
    q = rng.standard_normal(size=(1, 4))
    q /= np.linalg.norm(q)
    q = wp.array(q.astype(dtype), dtype=quat, requires_grad=True, device=device)

    # test values against the manually computed result:
    outputs = wp.zeros(3 * 3, dtype=wptype, requires_grad=True, device=device)
    outputs_manual = wp.zeros(3 * 3, dtype=wptype, requires_grad=True, device=device)

    wp.launch(
        kernel,
        dim=1,
        inputs=[q],
        outputs=[
            outputs,
            outputs_manual,
        ],
        device=device,
    )

    assert_np_equal(outputs.numpy(), outputs_manual.numpy(), tol=tol)

    # sanity check: divide by 2 to remove that scale factor we put in there, and
    # it should be a rotation matrix
    R = 0.5 * outputs.numpy().reshape(3, 3)
    assert_np_equal(np.matmul(R, R.T), np.eye(3), tol=tol)

    # test gradients against the manually computed result:
    idx = 0
    for _i in range(3):
        for _j in range(3):
            cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
            cmp_manual = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
            tape = wp.Tape()
            with tape:
                wp.launch(
                    kernel,
                    dim=1,
                    inputs=[q],
                    outputs=[
                        outputs,
                        outputs_manual,
                    ],
                    device=device,
                )
                wp.launch(output_select_kernel, dim=1, inputs=[outputs, idx], outputs=[cmp], device=device)
                wp.launch(
                    output_select_kernel, dim=1, inputs=[outputs_manual, idx], outputs=[cmp_manual], device=device
                )
            # multiply by 1.0 to copy each gradient array before tape.zero() clears it
            tape.backward(loss=cmp)
            qgrads = 1.0 * tape.gradients[q].numpy()
            tape.zero()
            tape.backward(loss=cmp_manual)
            qgrads_manual = 1.0 * tape.gradients[q].numpy()
            tape.zero()

            assert_np_equal(qgrads, qgrads_manual, tol=tol)
            idx = idx + 1
1297
+
1298
+
1299
+ ############################################################
1300
+
1301
+
1302
def test_slerp_grad(test, device, dtype, register_kernels=False):
    """Compare the built-in adjoint of wp.quat_slerp against autodiff through
    an equivalent forward implementation (axis-angle decomposition), w.r.t.
    the interpolation parameter t and both endpoint quaternions."""
    rng = np.random.default_rng(123)
    seed = 42

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec3 = wp.types.vector(3, wptype)
    quat = wp.types.quaternion(wptype)

    def slerp_kernel(
        q0: wp.array(dtype=quat),
        q1: wp.array(dtype=quat),
        t: wp.array(dtype=wptype),
        loss: wp.array(dtype=wptype),
        index: int,
    ):
        tid = wp.tid()

        # built-in slerp; exercises warp's hand-written adjoint
        q = wp.quat_slerp(q0[tid], q1[tid], t[tid])
        wp.atomic_add(loss, 0, q[index])

    slerp_kernel = getkernel(slerp_kernel, suffix=dtype.__name__)

    def slerp_kernel_forward(
        q0: wp.array(dtype=quat),
        q1: wp.array(dtype=quat),
        t: wp.array(dtype=wptype),
        loss: wp.array(dtype=wptype),
        index: int,
    ):
        tid = wp.tid()

        axis = vec3()
        angle = wptype(0.0)

        # slerp(q0, q1, t) == q0 * axis_angle(axis, t * angle), where
        # (axis, angle) describe the relative rotation q0^-1 * q1
        wp.quat_to_axis_angle(wp.mul(wp.quat_inverse(q0[tid]), q1[tid]), axis, angle)
        q = wp.mul(q0[tid], wp.quat_from_axis_angle(axis, t[tid] * angle))

        wp.atomic_add(loss, 0, q[index])

    slerp_kernel_forward = getkernel(slerp_kernel_forward, suffix=dtype.__name__)

    def quat_sampler_slerp(kernel_seed: int, quats: wp.array(dtype=quat)):
        # sample one random unit quaternion per thread
        tid = wp.tid()

        state = wp.rand_init(kernel_seed, tid)

        angle = wp.randf(state, 0.0, 2.0 * 3.1415926535)
        dir = wp.sample_unit_sphere_surface(state) * wp.sin(angle * 0.5)

        q = quat(wptype(dir[0]), wptype(dir[1]), wptype(dir[2]), wptype(wp.cos(angle * 0.5)))
        qn = wp.normalize(q)

        quats[tid] = qn

    quat_sampler = getkernel(quat_sampler_slerp, suffix=dtype.__name__)

    if register_kernels:
        return

    N = 50

    q0 = wp.zeros(N, dtype=quat, device=device, requires_grad=True)
    q1 = wp.zeros(N, dtype=quat, device=device, requires_grad=True)

    wp.launch(kernel=quat_sampler, dim=N, inputs=[seed, q0], device=device)
    wp.launch(kernel=quat_sampler, dim=N, inputs=[seed + 1, q1], device=device)

    t = rng.uniform(low=0.0, high=1.0, size=N)
    t = wp.array(t, dtype=wptype, device=device, requires_grad=True)

    def compute_gradients(kernel, wrt, index):
        # Accumulate component `index` of all slerped quaternions into `loss`,
        # then return (loss value, d loss / d wrt).
        loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
        tape = wp.Tape()
        with tape:
            wp.launch(kernel=kernel, dim=N, inputs=[q0, q1, t, loss, index], device=device)

        tape.backward(loss)

        # multiply by 1.0 to copy the gradients before tape.zero() clears them
        gradients = 1.0 * tape.gradients[wrt].numpy()
        tape.zero()

        return loss.numpy()[0], gradients

    # per-dtype absolute tolerances; dtypes not listed compare exactly (tol=0)
    eps = {
        np.float16: 2.0e-2,
        np.float32: 1.0e-5,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    # wrt t

    # gather gradients from builtin adjoints
    xcmp, gradients_x = compute_gradients(slerp_kernel, t, 0)
    ycmp, gradients_y = compute_gradients(slerp_kernel, t, 1)
    zcmp, gradients_z = compute_gradients(slerp_kernel, t, 2)
    wcmp, gradients_w = compute_gradients(slerp_kernel, t, 3)

    # gather gradients from autodiff
    xcmp_auto, gradients_x_auto = compute_gradients(slerp_kernel_forward, t, 0)
    ycmp_auto, gradients_y_auto = compute_gradients(slerp_kernel_forward, t, 1)
    zcmp_auto, gradients_z_auto = compute_gradients(slerp_kernel_forward, t, 2)
    wcmp_auto, gradients_w_auto = compute_gradients(slerp_kernel_forward, t, 3)

    assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
    assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
    assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
    assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
    assert_np_equal(xcmp, xcmp_auto, tol=eps)
    assert_np_equal(ycmp, ycmp_auto, tol=eps)
    assert_np_equal(zcmp, zcmp_auto, tol=eps)
    assert_np_equal(wcmp, wcmp_auto, tol=eps)

    # wrt q0

    # gather gradients from builtin adjoints
    xcmp, gradients_x = compute_gradients(slerp_kernel, q0, 0)
    ycmp, gradients_y = compute_gradients(slerp_kernel, q0, 1)
    zcmp, gradients_z = compute_gradients(slerp_kernel, q0, 2)
    wcmp, gradients_w = compute_gradients(slerp_kernel, q0, 3)

    # gather gradients from autodiff
    xcmp_auto, gradients_x_auto = compute_gradients(slerp_kernel_forward, q0, 0)
    ycmp_auto, gradients_y_auto = compute_gradients(slerp_kernel_forward, q0, 1)
    zcmp_auto, gradients_z_auto = compute_gradients(slerp_kernel_forward, q0, 2)
    wcmp_auto, gradients_w_auto = compute_gradients(slerp_kernel_forward, q0, 3)

    assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
    assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
    assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
    assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
    assert_np_equal(xcmp, xcmp_auto, tol=eps)
    assert_np_equal(ycmp, ycmp_auto, tol=eps)
    assert_np_equal(zcmp, zcmp_auto, tol=eps)
    assert_np_equal(wcmp, wcmp_auto, tol=eps)

    # wrt q1

    # gather gradients from builtin adjoints
    xcmp, gradients_x = compute_gradients(slerp_kernel, q1, 0)
    ycmp, gradients_y = compute_gradients(slerp_kernel, q1, 1)
    zcmp, gradients_z = compute_gradients(slerp_kernel, q1, 2)
    wcmp, gradients_w = compute_gradients(slerp_kernel, q1, 3)

    # gather gradients from autodiff
    xcmp_auto, gradients_x_auto = compute_gradients(slerp_kernel_forward, q1, 0)
    ycmp_auto, gradients_y_auto = compute_gradients(slerp_kernel_forward, q1, 1)
    zcmp_auto, gradients_z_auto = compute_gradients(slerp_kernel_forward, q1, 2)
    wcmp_auto, gradients_w_auto = compute_gradients(slerp_kernel_forward, q1, 3)

    assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
    assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
    assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
    assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
    assert_np_equal(xcmp, xcmp_auto, tol=eps)
    assert_np_equal(ycmp, ycmp_auto, tol=eps)
    assert_np_equal(zcmp, zcmp_auto, tol=eps)
    assert_np_equal(wcmp, wcmp_auto, tol=eps)
1459
+
1460
+
1461
+ ############################################################
1462
+
1463
+
1464
def test_quat_to_axis_angle_grad(test, device, dtype, register_kernels=False):
    """Compare the built-in adjoint of wp.quat_to_axis_angle against autodiff
    through an equivalent forward implementation, on evenly spaced rotation
    angles and on degenerate edge-case quaternions.

    Fix: removed an unused `rng = np.random.default_rng(123)` local — the
    angles come from np.linspace and the quaternions from the seeded warp
    sampler kernel, so the generator was never consumed.
    """
    seed = 42
    num_rand = 50

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec3 = wp.types.vector(3, wptype)
    vec4 = wp.types.vector(4, wptype)
    quat = wp.types.quaternion(wptype)

    def quat_to_axis_angle_kernel(quats: wp.array(dtype=quat), loss: wp.array(dtype=wptype), coord_idx: int):
        tid = wp.tid()
        axis = vec3()
        angle = wptype(0.0)

        wp.quat_to_axis_angle(quats[tid], axis, angle)
        # pack (axis, angle) so a single kernel can select any of 4 coordinates
        a = vec4(axis[0], axis[1], axis[2], angle)

        wp.atomic_add(loss, 0, a[coord_idx])

    quat_to_axis_angle_kernel = getkernel(quat_to_axis_angle_kernel, suffix=dtype.__name__)

    def quat_to_axis_angle_kernel_forward(quats: wp.array(dtype=quat), loss: wp.array(dtype=wptype), coord_idx: int):
        tid = wp.tid()
        q = quats[tid]
        axis = vec3()
        angle = wptype(0.0)

        # reference: axis = +/- normalize(qv), angle from atan2(|qv|, |w|)
        v = vec3(q[0], q[1], q[2])
        if q[3] < wptype(0):
            axis = -wp.normalize(v)
        else:
            axis = wp.normalize(v)

        angle = wptype(2) * wp.atan2(wp.length(v), wp.abs(q[3]))
        a = vec4(axis[0], axis[1], axis[2], angle)

        wp.atomic_add(loss, 0, a[coord_idx])

    quat_to_axis_angle_kernel_forward = getkernel(quat_to_axis_angle_kernel_forward, suffix=dtype.__name__)

    def quat_sampler(kernel_seed: int, angles: wp.array(dtype=float), quats: wp.array(dtype=quat)):
        # build a unit quaternion with the prescribed angle and a random axis
        tid = wp.tid()

        state = wp.rand_init(kernel_seed, tid)

        angle = angles[tid]
        dir = wp.sample_unit_sphere_surface(state) * wp.sin(angle * 0.5)

        q = quat(wptype(dir[0]), wptype(dir[1]), wptype(dir[2]), wptype(wp.cos(angle * 0.5)))
        qn = wp.normalize(q)

        quats[tid] = qn

    quat_sampler = getkernel(quat_sampler, suffix=dtype.__name__)

    if register_kernels:
        return

    quats = wp.zeros(num_rand, dtype=quat, device=device, requires_grad=True)
    angles = wp.array(
        np.linspace(0.0, 2.0 * np.pi, num_rand, endpoint=False, dtype=np.float32), dtype=float, device=device
    )
    wp.launch(kernel=quat_sampler, dim=num_rand, inputs=[seed, angles, quats], device=device)

    # degenerate inputs: pure-x quaternion (angle pi), zero-w quaternion, and
    # the all-zero quaternion
    edge_cases = np.array(
        [(1.0, 0.0, 0.0, 0.0), (0.0, 1.0 / np.sqrt(3), 1.0 / np.sqrt(3), 1.0 / np.sqrt(3)), (0.0, 0.0, 0.0, 0.0)]
    )
    num_edge = len(edge_cases)
    edge_cases = wp.array(edge_cases, dtype=quat, device=device, requires_grad=True)

    def compute_gradients(arr, kernel, dim, index):
        # Accumulate coordinate `index` for all quaternions into `loss`, then
        # return (loss value, d loss / d arr).
        loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
        tape = wp.Tape()
        with tape:
            wp.launch(kernel=kernel, dim=dim, inputs=[arr, loss, index], device=device)

        tape.backward(loss)

        # multiply by 1.0 to copy the gradients before tape.zero() clears them
        gradients = 1.0 * tape.gradients[arr].numpy()
        tape.zero()

        return loss.numpy()[0], gradients

    # gather gradients from builtin adjoints
    xcmp, gradients_x = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 0)
    ycmp, gradients_y = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 1)
    zcmp, gradients_z = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 2)
    wcmp, gradients_w = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 3)

    # gather gradients from autodiff
    xcmp_auto, gradients_x_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 0)
    ycmp_auto, gradients_y_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 1)
    zcmp_auto, gradients_z_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 2)
    wcmp_auto, gradients_w_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 3)

    # edge cases: gather gradients from builtin adjoints
    _, edge_gradients_x = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 0)
    _, edge_gradients_y = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 1)
    _, edge_gradients_z = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 2)
    _, edge_gradients_w = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 3)

    # edge cases: gather gradients from autodiff
    _, edge_gradients_x_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 0)
    _, edge_gradients_y_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 1)
    _, edge_gradients_z_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 2)
    _, edge_gradients_w_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 3)

    # per-dtype absolute tolerances; dtypes not listed compare exactly (tol=0)
    eps = {
        np.float16: 2.0e-1,
        np.float32: 2.0e-4,
        np.float64: 2.0e-7,
    }.get(dtype, 0)

    assert_np_equal(xcmp, xcmp_auto, tol=eps)
    assert_np_equal(ycmp, ycmp_auto, tol=eps)
    assert_np_equal(zcmp, zcmp_auto, tol=eps)
    assert_np_equal(wcmp, wcmp_auto, tol=eps)

    assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
    assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
    assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
    assert_np_equal(gradients_w, gradients_w_auto, tol=eps)

    assert_np_equal(edge_gradients_x, edge_gradients_x_auto, tol=eps)
    assert_np_equal(edge_gradients_y, edge_gradients_y_auto, tol=eps)
    assert_np_equal(edge_gradients_z, edge_gradients_z_auto, tol=eps)
    assert_np_equal(edge_gradients_w, edge_gradients_w_auto, tol=eps)
1592
+
1593
+
1594
+ ############################################################
1595
+
1596
+
1597
def test_quat_rpy_grad(test, device, dtype, register_kernels=False):
    """Compare the built-in adjoint of wp.quat_rpy against autodiff through an
    equivalent forward implementation of the roll/pitch/yaw-to-quaternion
    conversion."""
    rng = np.random.default_rng(123)
    N = 3

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    vec3 = wp.types.vector(3, wptype)
    quat = wp.types.quaternion(wptype)

    def rpy_to_quat_kernel(rpy_arr: wp.array(dtype=vec3), loss: wp.array(dtype=wptype), coord_idx: int):
        tid = wp.tid()
        rpy = rpy_arr[tid]
        roll = rpy[0]
        pitch = rpy[1]
        yaw = rpy[2]

        # built-in conversion; exercises warp's hand-written adjoint
        q = wp.quat_rpy(roll, pitch, yaw)

        wp.atomic_add(loss, 0, q[coord_idx])

    rpy_to_quat_kernel = getkernel(rpy_to_quat_kernel, suffix=dtype.__name__)

    def rpy_to_quat_kernel_forward(rpy_arr: wp.array(dtype=vec3), loss: wp.array(dtype=wptype), coord_idx: int):
        tid = wp.tid()
        rpy = rpy_arr[tid]
        roll = rpy[0]
        pitch = rpy[1]
        yaw = rpy[2]

        # reference: expand the half-angle product formula explicitly
        cy = wp.cos(yaw * wptype(0.5))
        sy = wp.sin(yaw * wptype(0.5))
        cr = wp.cos(roll * wptype(0.5))
        sr = wp.sin(roll * wptype(0.5))
        cp = wp.cos(pitch * wptype(0.5))
        sp = wp.sin(pitch * wptype(0.5))

        w = cy * cr * cp + sy * sr * sp
        x = cy * sr * cp - sy * cr * sp
        y = cy * cr * sp + sy * sr * cp
        z = sy * cr * cp - cy * sr * sp

        q = quat(x, y, z, w)

        wp.atomic_add(loss, 0, q[coord_idx])

    rpy_to_quat_kernel_forward = getkernel(rpy_to_quat_kernel_forward, suffix=dtype.__name__)

    if register_kernels:
        return

    rpy_arr = rng.uniform(low=-np.pi, high=np.pi, size=(N, 3))
    rpy_arr = wp.array(rpy_arr, dtype=vec3, device=device, requires_grad=True)

    def compute_gradients(kernel, wrt, index):
        # Accumulate quaternion component `index` over all inputs into `loss`,
        # then return (loss value, d loss / d wrt).
        loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
        tape = wp.Tape()
        with tape:
            wp.launch(kernel=kernel, dim=N, inputs=[wrt, loss, index], device=device)

        tape.backward(loss)

        # multiply by 1.0 to copy the gradients before tape.zero() clears them
        gradients = 1.0 * tape.gradients[wrt].numpy()
        tape.zero()

        return loss.numpy()[0], gradients

    # wrt rpy
    # gather gradients from builtin adjoints
    rcmp, gradients_r = compute_gradients(rpy_to_quat_kernel, rpy_arr, 0)
    pcmp, gradients_p = compute_gradients(rpy_to_quat_kernel, rpy_arr, 1)
    ycmp, gradients_y = compute_gradients(rpy_to_quat_kernel, rpy_arr, 2)

    # gather gradients from autodiff
    rcmp_auto, gradients_r_auto = compute_gradients(rpy_to_quat_kernel_forward, rpy_arr, 0)
    pcmp_auto, gradients_p_auto = compute_gradients(rpy_to_quat_kernel_forward, rpy_arr, 1)
    ycmp_auto, gradients_y_auto = compute_gradients(rpy_to_quat_kernel_forward, rpy_arr, 2)

    # per-dtype absolute tolerances; dtypes not listed compare exactly (tol=0)
    eps = {
        np.float16: 2.0e-2,
        np.float32: 1.0e-5,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    assert_np_equal(rcmp, rcmp_auto, tol=eps)
    assert_np_equal(pcmp, pcmp_auto, tol=eps)
    assert_np_equal(ycmp, ycmp_auto, tol=eps)

    assert_np_equal(gradients_r, gradients_r_auto, tol=eps)
    assert_np_equal(gradients_p, gradients_p_auto, tol=eps)
    assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1687
+
1688
+
1689
+ ############################################################
1690
+
1691
+
1692
def test_quat_from_matrix(test, device, dtype, register_kernels=False):
    """Compare the built-in adjoint of wp.quat_from_matrix against autodiff
    through an equivalent forward implementation (Shepperd's method), on a mix
    of rotation matrices and one arbitrary non-orthonormal matrix.

    Fixes: removed a dead `eps = 1.0e6` assignment that was immediately
    overwritten by the tolerance dict, and made the `max_diag == 2` branch an
    `elif` for consistency with its mutually exclusive siblings (no behavior
    change, since `max_diag` is not modified inside the branches).
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat33 = wp.types.matrix((3, 3), wptype)
    quat = wp.types.quaternion(wptype)

    def quat_from_matrix(m: wp.array2d(dtype=wptype), loss: wp.array(dtype=wptype), idx: int):
        tid = wp.tid()

        matrix = mat33(
            m[tid, 0], m[tid, 1], m[tid, 2], m[tid, 3], m[tid, 4], m[tid, 5], m[tid, 6], m[tid, 7], m[tid, 8]
        )

        # built-in conversion; exercises warp's hand-written adjoint
        q = wp.quat_from_matrix(matrix)

        wp.atomic_add(loss, 0, q[idx])

    def quat_from_matrix_forward(mats: wp.array2d(dtype=wptype), loss: wp.array(dtype=wptype), idx: int):
        # reference: Shepperd's method — branch on the trace, then on the
        # largest diagonal element, for numerical stability
        tid = wp.tid()

        m = mat33(
            mats[tid, 0],
            mats[tid, 1],
            mats[tid, 2],
            mats[tid, 3],
            mats[tid, 4],
            mats[tid, 5],
            mats[tid, 6],
            mats[tid, 7],
            mats[tid, 8],
        )

        tr = m[0][0] + m[1][1] + m[2][2]
        x = wptype(0)
        y = wptype(0)
        z = wptype(0)
        w = wptype(0)
        h = wptype(0)

        if tr >= wptype(0):
            h = wp.sqrt(tr + wptype(1))
            w = wptype(0.5) * h
            h = wptype(0.5) / h

            x = (m[2][1] - m[1][2]) * h
            y = (m[0][2] - m[2][0]) * h
            z = (m[1][0] - m[0][1]) * h
        else:
            max_diag = 0
            if m[1][1] > m[0][0]:
                max_diag = 1
            if m[2][2] > m[max_diag][max_diag]:
                max_diag = 2

            if max_diag == 0:
                h = wp.sqrt((m[0][0] - (m[1][1] + m[2][2])) + wptype(1))
                x = wptype(0.5) * h
                h = wptype(0.5) / h

                y = (m[0][1] + m[1][0]) * h
                z = (m[2][0] + m[0][2]) * h
                w = (m[2][1] - m[1][2]) * h
            elif max_diag == 1:
                h = wp.sqrt((m[1][1] - (m[2][2] + m[0][0])) + wptype(1))
                y = wptype(0.5) * h
                h = wptype(0.5) / h

                z = (m[1][2] + m[2][1]) * h
                x = (m[0][1] + m[1][0]) * h
                w = (m[0][2] - m[2][0]) * h
            elif max_diag == 2:
                h = wp.sqrt((m[2][2] - (m[0][0] + m[1][1])) + wptype(1))
                z = wptype(0.5) * h
                h = wptype(0.5) / h

                x = (m[2][0] + m[0][2]) * h
                y = (m[1][2] + m[2][1]) * h
                w = (m[1][0] - m[0][1]) * h

        q = wp.normalize(quat(x, y, z, w))

        wp.atomic_add(loss, 0, q[idx])

    quat_from_matrix = getkernel(quat_from_matrix, suffix=dtype.__name__)
    quat_from_matrix_forward = getkernel(quat_from_matrix_forward, suffix=dtype.__name__)

    if register_kernels:
        return

    # three rotation matrices plus one arbitrary (non-orthonormal) matrix
    m = np.array(
        [
            [1.0, 0.0, 0.0, 0.0, 0.5, 0.866, 0.0, -0.866, 0.5],
            [0.866, 0.0, 0.25, -0.433, 0.5, 0.75, -0.25, -0.866, 0.433],
            [0.866, -0.433, 0.25, 0.0, 0.5, 0.866, -0.5, -0.75, 0.433],
            [-1.2, -1.6, -2.3, 0.25, -0.6, -0.33, 3.2, -1.0, -2.2],
        ]
    )
    m = wp.array2d(m, dtype=wptype, device=device, requires_grad=True)

    N = m.shape[0]

    def compute_gradients(kernel, wrt, index):
        # Accumulate quaternion component `index` over all matrices into
        # `loss`, then return (loss value, d loss / d wrt).
        loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
        tape = wp.Tape()

        with tape:
            wp.launch(kernel=kernel, dim=N, inputs=[m, loss, index], device=device)

        tape.backward(loss)

        # multiply by 1.0 to copy the gradients before tape.zero() clears them
        gradients = 1.0 * tape.gradients[wrt].numpy()
        tape.zero()

        return loss.numpy()[0], gradients

    # gather gradients from builtin adjoints
    cmpx, gradients_x = compute_gradients(quat_from_matrix, m, 0)
    cmpy, gradients_y = compute_gradients(quat_from_matrix, m, 1)
    cmpz, gradients_z = compute_gradients(quat_from_matrix, m, 2)
    cmpw, gradients_w = compute_gradients(quat_from_matrix, m, 3)

    # gather gradients from autodiff
    cmpx_auto, gradients_x_auto = compute_gradients(quat_from_matrix_forward, m, 0)
    cmpy_auto, gradients_y_auto = compute_gradients(quat_from_matrix_forward, m, 1)
    cmpz_auto, gradients_z_auto = compute_gradients(quat_from_matrix_forward, m, 2)
    cmpw_auto, gradients_w_auto = compute_gradients(quat_from_matrix_forward, m, 3)

    # compare; per-dtype absolute tolerances, dtypes not listed compare exactly
    eps = {
        np.float16: 2.0e-2,
        np.float32: 1.0e-5,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    assert_np_equal(cmpx, cmpx_auto, tol=eps)
    assert_np_equal(cmpy, cmpy_auto, tol=eps)
    assert_np_equal(cmpz, cmpz_auto, tol=eps)
    assert_np_equal(cmpw, cmpw_auto, tol=eps)

    assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
    assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
    assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
    assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1836
+
1837
+
1838
def test_quat_identity(test, device, dtype, register_kernels=False):
    """Verify that wp.quat_identity() yields (0, 0, 0, 1), both with an
    explicit dtype argument and with the default (float32) overload."""
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def quat_identity_test(output: wp.array(dtype=wptype)):
        # explicit-dtype identity quaternion
        q = wp.quat_identity(dtype=wptype)
        output[0] = q[0]
        output[1] = q[1]
        output[2] = q[2]
        output[3] = q[3]

    def quat_identity_test_default(output: wp.array(dtype=wp.float32)):
        # no dtype argument: should default to float32
        q = wp.quat_identity()
        output[0] = q[0]
        output[1] = q[1]
        output[2] = q[2]
        output[3] = q[3]

    quat_identity_kernel = getkernel(quat_identity_test, suffix=dtype.__name__)
    quat_identity_default_kernel = getkernel(quat_identity_test_default, suffix=np.float32.__name__)

    if register_kernels:
        return

    # run the explicit-dtype kernel first, then the default-dtype one (the
    # latter also checks that wp.quat_identity() defaults to float32):
    for kern, out_dtype in ((quat_identity_kernel, wptype), (quat_identity_default_kernel, wp.float32)):
        result = wp.zeros(4, dtype=out_dtype, device=device)
        wp.launch(kern, dim=1, inputs=[], outputs=[result], device=device)
        expected = np.zeros_like(result.numpy())
        expected[3] = 1
        assert_np_equal(result.numpy(), expected)
1873
+
1874
+
1875
+ ############################################################
1876
+
1877
+
1878
def test_quat_euler_conversion(test, device, dtype, register_kernels=False):
    """Check that wp.sim.quat_from_euler with XYZ axis order agrees with
    wp.quat_rpy on random roll/pitch/yaw triples."""
    rng = np.random.default_rng(123)
    num_samples = 3

    angles = rng.uniform(low=-np.pi, high=np.pi, size=(num_samples, 3))

    via_euler = []
    via_rpy = []
    for roll, pitch, yaw in angles:
        # axis indices (0, 1, 2) select the intrinsic XYZ convention
        via_euler.append(list(wp.sim.quat_from_euler(wp.vec3(roll, pitch, yaw), 0, 1, 2)))
        via_rpy.append(list(wp.quat_rpy(roll, pitch, yaw)))

    assert_np_equal(np.array(via_euler), np.array(via_rpy), tol=1e-4)
1888
+
1889
+
1890
def test_anon_type_instance(test, device, dtype, register_kernels=False):
    """Exercise the anonymous wp.quaternion() constructors (component-wise and
    vector + scalar) and verify both their values and their gradients."""
    rng = np.random.default_rng(123)
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def quat_create_test(input: wp.array(dtype=wptype), output: wp.array(dtype=wptype)):
        # component constructor:
        q = wp.quaternion(input[0], input[1], input[2], input[3])
        output[0] = wptype(2) * q[0]
        output[1] = wptype(2) * q[1]
        output[2] = wptype(2) * q[2]
        output[3] = wptype(2) * q[3]

        # vector / scalar constructor:
        q2 = wp.quaternion(wp.vector(input[4], input[5], input[6]), input[7])
        output[4] = wptype(2) * q2[0]
        output[5] = wptype(2) * q2[1]
        output[6] = wptype(2) * q2[2]
        output[7] = wptype(2) * q2[3]

    quat_create_kernel = getkernel(quat_create_test, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel(wptype)

    if register_kernels:
        return

    values = wp.array(rng.standard_normal(size=8).astype(dtype), requires_grad=True, device=device)
    doubled = wp.zeros(8, dtype=wptype, requires_grad=True, device=device)
    wp.launch(quat_create_kernel, dim=1, inputs=[values], outputs=[doubled], device=device)
    # forward pass: every output component is exactly twice its input
    assert_np_equal(doubled.numpy(), 2 * values.numpy())

    # backward pass: d(output[i]) / d(input[j]) is 2 when i == j, else 0
    num_components = len(values)
    for component in range(num_components):
        selected = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        tape = wp.Tape()
        with tape:
            wp.launch(quat_create_kernel, dim=1, inputs=[values], outputs=[doubled], device=device)
            wp.launch(output_select_kernel, dim=1, inputs=[doubled, component], outputs=[selected], device=device)
        tape.backward(loss=selected)
        expected = np.zeros(num_components)
        expected[component] = 2
        assert_np_equal(tape.gradients[values].numpy(), expected)
        tape.zero()
1931
+
1932
+
1933
# Same as above but with a default (float) type
# which tests some different code paths that
# need to ensure types are correctly canonicalized
# during codegen
@wp.kernel
def test_constructor_default():
    # Default constructor: all four components must be zero-initialized.
    qzero = wp.quat()
    wp.expect_eq(qzero[0], 0.0)
    wp.expect_eq(qzero[1], 0.0)
    wp.expect_eq(qzero[2], 0.0)
    wp.expect_eq(qzero[3], 0.0)

    # Component constructor with untyped float literals (canonicalized to wp.float32).
    qval = wp.quat(1.0, 2.0, 3.0, 4.0)
    wp.expect_eq(qval[0], 1.0)
    wp.expect_eq(qval[1], 2.0)
    wp.expect_eq(qval[2], 3.0)
    wp.expect_eq(qval[3], 4.0)

    # Identity quaternion: zero imaginary part (x, y, z), unit real part (w).
    qeye = wp.quat_identity()
    wp.expect_eq(qeye[0], 0.0)
    wp.expect_eq(qeye[1], 0.0)
    wp.expect_eq(qeye[2], 0.0)
    wp.expect_eq(qeye[3], 1.0)
1956
+
1957
+
1958
def test_py_arithmetic_ops(test, device, dtype):
    """Check Python-side (non-kernel) arithmetic operators on quaternion instances."""
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def expected(*components):
        if wptype in wp.types.int_types:
            # Cast each component through the underlying ctypes scalar so the
            # expected values reproduce integer wrapping.
            return tuple(wptype._type_(c).value for c in components)

        return components

    quat_cls = wp.types.quaternion(wptype)

    q = quat_cls(1, -2, 3, -4)
    test.assertSequenceEqual(+q, expected(1, -2, 3, -4))
    test.assertSequenceEqual(-q, expected(-1, 2, -3, 4))
    test.assertSequenceEqual(q + quat_cls(5, 5, 5, 5), expected(6, 3, 8, 1))
    test.assertSequenceEqual(q - quat_cls(5, 5, 5, 5), expected(-4, -7, -2, -9))

    q = quat_cls(2, 4, 6, 8)
    test.assertSequenceEqual(q * wptype(2), expected(4, 8, 12, 16))
    test.assertSequenceEqual(wptype(2) * q, expected(4, 8, 12, 16))
    test.assertSequenceEqual(q / wptype(2), expected(1, 2, 3, 4))
    test.assertSequenceEqual(wptype(24) / q, expected(12, 6, 4, 3))
1981
+
1982
+
1983
# Devices (CPU/GPU) that the generated test variants are parameterized over.
devices = get_test_devices()


class TestQuat(unittest.TestCase):
    # Test methods are attached dynamically by the add_* helper calls below.
    pass
1988
+
1989
+
1990
add_kernel_test(TestQuat, test_constructor_default, dim=1, devices=devices)

# Per-dtype tests that require a kernel-registration pass; each is registered
# under "<function name>_<dtype name>", preserving the original naming scheme.
_kernel_test_funcs = (
    test_constructors,
    test_casting_constructors,
    test_anon_type_instance,
    test_inverse,
    test_quat_identity,
    test_dotproduct,
    test_length,
    test_normalize,
    test_addition,
    test_subtraction,
    test_scalar_multiplication,
    test_scalar_division,
    test_quat_multiplication,
    test_indexing,
    test_quat_lerp,
    test_quat_to_axis_angle_grad,
    test_slerp_grad,
    test_quat_rpy_grad,
    test_quat_from_matrix,
    test_quat_rotate,
    test_quat_to_matrix,
    test_quat_euler_conversion,
)

for dtype in np_float_types:
    for test_func in _kernel_test_funcs:
        add_function_test_register_kernel(
            TestQuat, f"{test_func.__name__}_{dtype.__name__}", test_func, devices=devices, dtype=dtype
        )

    # Pure-Python operator test: no kernels launched, so no device parameterization.
    add_function_test(
        TestQuat, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype
    )
2082
+
2083
+
2084
if __name__ == "__main__":
    # Start from a clean kernel cache so the codegen paths are fully exercised.
    wp.build.clear_kernel_cache()
    unittest.main(verbosity=2)