warp-lang 1.0.2__py3-none-win_amd64.whl → 1.2.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry's advisory page for more details.

Files changed (356)
  1. warp/__init__.py +108 -97
  2. warp/__init__.pyi +1 -1
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +88 -113
  6. warp/build_dll.py +383 -375
  7. warp/builtins.py +3693 -3354
  8. warp/codegen.py +2925 -2792
  9. warp/config.py +40 -36
  10. warp/constants.py +49 -45
  11. warp/context.py +5409 -5102
  12. warp/dlpack.py +442 -442
  13. warp/examples/__init__.py +16 -16
  14. warp/examples/assets/bear.usd +0 -0
  15. warp/examples/assets/bunny.usd +0 -0
  16. warp/examples/assets/cartpole.urdf +110 -110
  17. warp/examples/assets/crazyflie.usd +0 -0
  18. warp/examples/assets/cube.usd +0 -0
  19. warp/examples/assets/nv_ant.xml +92 -92
  20. warp/examples/assets/nv_humanoid.xml +183 -183
  21. warp/examples/assets/quadruped.urdf +267 -267
  22. warp/examples/assets/rocks.nvdb +0 -0
  23. warp/examples/assets/rocks.usd +0 -0
  24. warp/examples/assets/sphere.usd +0 -0
  25. warp/examples/benchmarks/benchmark_api.py +381 -383
  26. warp/examples/benchmarks/benchmark_cloth.py +278 -277
  27. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
  28. warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
  29. warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
  30. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
  31. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
  32. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
  33. warp/examples/benchmarks/benchmark_cloth_warp.py +145 -146
  34. warp/examples/benchmarks/benchmark_launches.py +293 -295
  35. warp/examples/browse.py +29 -29
  36. warp/examples/core/example_dem.py +232 -219
  37. warp/examples/core/example_fluid.py +291 -267
  38. warp/examples/core/example_graph_capture.py +142 -126
  39. warp/examples/core/example_marching_cubes.py +186 -174
  40. warp/examples/core/example_mesh.py +172 -155
  41. warp/examples/core/example_mesh_intersect.py +203 -193
  42. warp/examples/core/example_nvdb.py +174 -170
  43. warp/examples/core/example_raycast.py +103 -90
  44. warp/examples/core/example_raymarch.py +197 -178
  45. warp/examples/core/example_render_opengl.py +183 -141
  46. warp/examples/core/example_sph.py +403 -387
  47. warp/examples/core/example_torch.py +219 -181
  48. warp/examples/core/example_wave.py +261 -248
  49. warp/examples/fem/bsr_utils.py +378 -380
  50. warp/examples/fem/example_apic_fluid.py +432 -389
  51. warp/examples/fem/example_burgers.py +262 -0
  52. warp/examples/fem/example_convection_diffusion.py +180 -168
  53. warp/examples/fem/example_convection_diffusion_dg.py +217 -209
  54. warp/examples/fem/example_deformed_geometry.py +175 -159
  55. warp/examples/fem/example_diffusion.py +199 -173
  56. warp/examples/fem/example_diffusion_3d.py +178 -152
  57. warp/examples/fem/example_diffusion_mgpu.py +219 -214
  58. warp/examples/fem/example_mixed_elasticity.py +242 -222
  59. warp/examples/fem/example_navier_stokes.py +257 -243
  60. warp/examples/fem/example_stokes.py +218 -192
  61. warp/examples/fem/example_stokes_transfer.py +263 -249
  62. warp/examples/fem/mesh_utils.py +133 -109
  63. warp/examples/fem/plot_utils.py +292 -287
  64. warp/examples/optim/example_bounce.py +258 -246
  65. warp/examples/optim/example_cloth_throw.py +220 -209
  66. warp/examples/optim/example_diffray.py +564 -536
  67. warp/examples/optim/example_drone.py +862 -835
  68. warp/examples/optim/example_inverse_kinematics.py +174 -168
  69. warp/examples/optim/example_inverse_kinematics_torch.py +183 -169
  70. warp/examples/optim/example_spring_cage.py +237 -231
  71. warp/examples/optim/example_trajectory.py +221 -199
  72. warp/examples/optim/example_walker.py +304 -293
  73. warp/examples/sim/example_cartpole.py +137 -129
  74. warp/examples/sim/example_cloth.py +194 -186
  75. warp/examples/sim/example_granular.py +122 -111
  76. warp/examples/sim/example_granular_collision_sdf.py +195 -186
  77. warp/examples/sim/example_jacobian_ik.py +234 -214
  78. warp/examples/sim/example_particle_chain.py +116 -105
  79. warp/examples/sim/example_quadruped.py +191 -180
  80. warp/examples/sim/example_rigid_chain.py +195 -187
  81. warp/examples/sim/example_rigid_contact.py +187 -177
  82. warp/examples/sim/example_rigid_force.py +125 -125
  83. warp/examples/sim/example_rigid_gyroscopic.py +107 -95
  84. warp/examples/sim/example_rigid_soft_contact.py +132 -122
  85. warp/examples/sim/example_soft_body.py +188 -177
  86. warp/fabric.py +337 -335
  87. warp/fem/__init__.py +61 -27
  88. warp/fem/cache.py +403 -388
  89. warp/fem/dirichlet.py +178 -179
  90. warp/fem/domain.py +262 -263
  91. warp/fem/field/__init__.py +100 -101
  92. warp/fem/field/field.py +148 -149
  93. warp/fem/field/nodal_field.py +298 -299
  94. warp/fem/field/restriction.py +22 -21
  95. warp/fem/field/test.py +180 -181
  96. warp/fem/field/trial.py +183 -183
  97. warp/fem/geometry/__init__.py +16 -19
  98. warp/fem/geometry/closest_point.py +69 -70
  99. warp/fem/geometry/deformed_geometry.py +270 -271
  100. warp/fem/geometry/element.py +748 -744
  101. warp/fem/geometry/geometry.py +184 -186
  102. warp/fem/geometry/grid_2d.py +380 -373
  103. warp/fem/geometry/grid_3d.py +437 -435
  104. warp/fem/geometry/hexmesh.py +953 -953
  105. warp/fem/geometry/nanogrid.py +455 -0
  106. warp/fem/geometry/partition.py +374 -376
  107. warp/fem/geometry/quadmesh_2d.py +532 -532
  108. warp/fem/geometry/tetmesh.py +840 -840
  109. warp/fem/geometry/trimesh_2d.py +577 -577
  110. warp/fem/integrate.py +1684 -1615
  111. warp/fem/operator.py +190 -191
  112. warp/fem/polynomial.py +214 -213
  113. warp/fem/quadrature/__init__.py +2 -2
  114. warp/fem/quadrature/pic_quadrature.py +243 -245
  115. warp/fem/quadrature/quadrature.py +295 -294
  116. warp/fem/space/__init__.py +179 -292
  117. warp/fem/space/basis_space.py +522 -489
  118. warp/fem/space/collocated_function_space.py +100 -105
  119. warp/fem/space/dof_mapper.py +236 -236
  120. warp/fem/space/function_space.py +148 -145
  121. warp/fem/space/grid_2d_function_space.py +148 -267
  122. warp/fem/space/grid_3d_function_space.py +167 -306
  123. warp/fem/space/hexmesh_function_space.py +253 -352
  124. warp/fem/space/nanogrid_function_space.py +202 -0
  125. warp/fem/space/partition.py +350 -350
  126. warp/fem/space/quadmesh_2d_function_space.py +261 -369
  127. warp/fem/space/restriction.py +161 -160
  128. warp/fem/space/shape/__init__.py +90 -15
  129. warp/fem/space/shape/cube_shape_function.py +728 -738
  130. warp/fem/space/shape/shape_function.py +102 -103
  131. warp/fem/space/shape/square_shape_function.py +611 -611
  132. warp/fem/space/shape/tet_shape_function.py +565 -567
  133. warp/fem/space/shape/triangle_shape_function.py +429 -429
  134. warp/fem/space/tetmesh_function_space.py +224 -292
  135. warp/fem/space/topology.py +297 -295
  136. warp/fem/space/trimesh_2d_function_space.py +153 -221
  137. warp/fem/types.py +77 -77
  138. warp/fem/utils.py +495 -495
  139. warp/jax.py +166 -141
  140. warp/jax_experimental.py +341 -339
  141. warp/native/array.h +1081 -1025
  142. warp/native/builtin.h +1603 -1560
  143. warp/native/bvh.cpp +402 -398
  144. warp/native/bvh.cu +533 -525
  145. warp/native/bvh.h +430 -429
  146. warp/native/clang/clang.cpp +496 -464
  147. warp/native/crt.cpp +42 -32
  148. warp/native/crt.h +352 -335
  149. warp/native/cuda_crt.h +1049 -1049
  150. warp/native/cuda_util.cpp +549 -540
  151. warp/native/cuda_util.h +288 -203
  152. warp/native/cutlass_gemm.cpp +34 -34
  153. warp/native/cutlass_gemm.cu +372 -372
  154. warp/native/error.cpp +66 -66
  155. warp/native/error.h +27 -27
  156. warp/native/exports.h +187 -0
  157. warp/native/fabric.h +228 -228
  158. warp/native/hashgrid.cpp +301 -278
  159. warp/native/hashgrid.cu +78 -77
  160. warp/native/hashgrid.h +227 -227
  161. warp/native/initializer_array.h +32 -32
  162. warp/native/intersect.h +1204 -1204
  163. warp/native/intersect_adj.h +365 -365
  164. warp/native/intersect_tri.h +322 -322
  165. warp/native/marching.cpp +2 -2
  166. warp/native/marching.cu +497 -497
  167. warp/native/marching.h +2 -2
  168. warp/native/mat.h +1545 -1498
  169. warp/native/matnn.h +333 -333
  170. warp/native/mesh.cpp +203 -203
  171. warp/native/mesh.cu +292 -293
  172. warp/native/mesh.h +1887 -1887
  173. warp/native/nanovdb/GridHandle.h +366 -0
  174. warp/native/nanovdb/HostBuffer.h +590 -0
  175. warp/native/nanovdb/NanoVDB.h +6624 -4782
  176. warp/native/nanovdb/PNanoVDB.h +3390 -2553
  177. warp/native/noise.h +850 -850
  178. warp/native/quat.h +1112 -1085
  179. warp/native/rand.h +303 -299
  180. warp/native/range.h +108 -108
  181. warp/native/reduce.cpp +156 -156
  182. warp/native/reduce.cu +348 -348
  183. warp/native/runlength_encode.cpp +61 -61
  184. warp/native/runlength_encode.cu +46 -46
  185. warp/native/scan.cpp +30 -30
  186. warp/native/scan.cu +36 -36
  187. warp/native/scan.h +7 -7
  188. warp/native/solid_angle.h +442 -442
  189. warp/native/sort.cpp +94 -94
  190. warp/native/sort.cu +97 -97
  191. warp/native/sort.h +14 -14
  192. warp/native/sparse.cpp +337 -337
  193. warp/native/sparse.cu +544 -544
  194. warp/native/spatial.h +630 -630
  195. warp/native/svd.h +562 -562
  196. warp/native/temp_buffer.h +30 -30
  197. warp/native/vec.h +1177 -1133
  198. warp/native/volume.cpp +529 -297
  199. warp/native/volume.cu +58 -32
  200. warp/native/volume.h +960 -538
  201. warp/native/volume_builder.cu +446 -425
  202. warp/native/volume_builder.h +34 -19
  203. warp/native/volume_impl.h +61 -0
  204. warp/native/warp.cpp +1057 -1052
  205. warp/native/warp.cu +2949 -2828
  206. warp/native/warp.h +321 -305
  207. warp/optim/__init__.py +9 -9
  208. warp/optim/adam.py +120 -120
  209. warp/optim/linear.py +1104 -939
  210. warp/optim/sgd.py +104 -92
  211. warp/render/__init__.py +10 -10
  212. warp/render/render_opengl.py +3356 -3204
  213. warp/render/render_usd.py +768 -749
  214. warp/render/utils.py +152 -150
  215. warp/sim/__init__.py +52 -59
  216. warp/sim/articulation.py +685 -685
  217. warp/sim/collide.py +1594 -1590
  218. warp/sim/import_mjcf.py +489 -481
  219. warp/sim/import_snu.py +220 -221
  220. warp/sim/import_urdf.py +536 -516
  221. warp/sim/import_usd.py +887 -881
  222. warp/sim/inertia.py +316 -317
  223. warp/sim/integrator.py +234 -233
  224. warp/sim/integrator_euler.py +1956 -1956
  225. warp/sim/integrator_featherstone.py +1917 -1991
  226. warp/sim/integrator_xpbd.py +3288 -3312
  227. warp/sim/model.py +4473 -4314
  228. warp/sim/particles.py +113 -112
  229. warp/sim/render.py +417 -403
  230. warp/sim/utils.py +413 -410
  231. warp/sparse.py +1289 -1227
  232. warp/stubs.py +2192 -2469
  233. warp/tape.py +1162 -225
  234. warp/tests/__init__.py +1 -1
  235. warp/tests/__main__.py +4 -4
  236. warp/tests/assets/test_index_grid.nvdb +0 -0
  237. warp/tests/assets/torus.usda +105 -105
  238. warp/tests/aux_test_class_kernel.py +26 -26
  239. warp/tests/aux_test_compile_consts_dummy.py +10 -10
  240. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
  241. warp/tests/aux_test_dependent.py +20 -22
  242. warp/tests/aux_test_grad_customs.py +21 -23
  243. warp/tests/aux_test_reference.py +9 -11
  244. warp/tests/aux_test_reference_reference.py +8 -10
  245. warp/tests/aux_test_square.py +15 -17
  246. warp/tests/aux_test_unresolved_func.py +14 -14
  247. warp/tests/aux_test_unresolved_symbol.py +14 -14
  248. warp/tests/disabled_kinematics.py +237 -239
  249. warp/tests/run_coverage_serial.py +31 -31
  250. warp/tests/test_adam.py +155 -157
  251. warp/tests/test_arithmetic.py +1088 -1124
  252. warp/tests/test_array.py +2415 -2326
  253. warp/tests/test_array_reduce.py +148 -150
  254. warp/tests/test_async.py +666 -656
  255. warp/tests/test_atomic.py +139 -141
  256. warp/tests/test_bool.py +212 -149
  257. warp/tests/test_builtins_resolution.py +1290 -1292
  258. warp/tests/test_bvh.py +162 -171
  259. warp/tests/test_closest_point_edge_edge.py +227 -228
  260. warp/tests/test_codegen.py +562 -553
  261. warp/tests/test_compile_consts.py +217 -101
  262. warp/tests/test_conditional.py +244 -246
  263. warp/tests/test_copy.py +230 -215
  264. warp/tests/test_ctypes.py +630 -632
  265. warp/tests/test_dense.py +65 -67
  266. warp/tests/test_devices.py +89 -98
  267. warp/tests/test_dlpack.py +528 -529
  268. warp/tests/test_examples.py +403 -378
  269. warp/tests/test_fabricarray.py +952 -955
  270. warp/tests/test_fast_math.py +60 -54
  271. warp/tests/test_fem.py +1298 -1278
  272. warp/tests/test_fp16.py +128 -130
  273. warp/tests/test_func.py +336 -337
  274. warp/tests/test_generics.py +596 -571
  275. warp/tests/test_grad.py +885 -640
  276. warp/tests/test_grad_customs.py +331 -336
  277. warp/tests/test_hash_grid.py +208 -164
  278. warp/tests/test_import.py +37 -39
  279. warp/tests/test_indexedarray.py +1132 -1134
  280. warp/tests/test_intersect.py +65 -67
  281. warp/tests/test_jax.py +305 -307
  282. warp/tests/test_large.py +169 -164
  283. warp/tests/test_launch.py +352 -354
  284. warp/tests/test_lerp.py +217 -261
  285. warp/tests/test_linear_solvers.py +189 -171
  286. warp/tests/test_lvalue.py +419 -493
  287. warp/tests/test_marching_cubes.py +63 -65
  288. warp/tests/test_mat.py +1799 -1827
  289. warp/tests/test_mat_lite.py +113 -115
  290. warp/tests/test_mat_scalar_ops.py +2905 -2889
  291. warp/tests/test_math.py +124 -193
  292. warp/tests/test_matmul.py +498 -499
  293. warp/tests/test_matmul_lite.py +408 -410
  294. warp/tests/test_mempool.py +186 -190
  295. warp/tests/test_mesh.py +281 -324
  296. warp/tests/test_mesh_query_aabb.py +226 -241
  297. warp/tests/test_mesh_query_point.py +690 -702
  298. warp/tests/test_mesh_query_ray.py +290 -303
  299. warp/tests/test_mlp.py +274 -276
  300. warp/tests/test_model.py +108 -110
  301. warp/tests/test_module_hashing.py +111 -0
  302. warp/tests/test_modules_lite.py +36 -39
  303. warp/tests/test_multigpu.py +161 -163
  304. warp/tests/test_noise.py +244 -248
  305. warp/tests/test_operators.py +248 -250
  306. warp/tests/test_options.py +121 -125
  307. warp/tests/test_peer.py +131 -137
  308. warp/tests/test_pinned.py +76 -78
  309. warp/tests/test_print.py +52 -54
  310. warp/tests/test_quat.py +2084 -2086
  311. warp/tests/test_rand.py +324 -288
  312. warp/tests/test_reload.py +207 -217
  313. warp/tests/test_rounding.py +177 -179
  314. warp/tests/test_runlength_encode.py +188 -190
  315. warp/tests/test_sim_grad.py +241 -0
  316. warp/tests/test_sim_kinematics.py +89 -97
  317. warp/tests/test_smoothstep.py +166 -168
  318. warp/tests/test_snippet.py +303 -266
  319. warp/tests/test_sparse.py +466 -460
  320. warp/tests/test_spatial.py +2146 -2148
  321. warp/tests/test_special_values.py +362 -0
  322. warp/tests/test_streams.py +484 -473
  323. warp/tests/test_struct.py +708 -675
  324. warp/tests/test_tape.py +171 -148
  325. warp/tests/test_torch.py +741 -743
  326. warp/tests/test_transient_module.py +85 -87
  327. warp/tests/test_types.py +554 -659
  328. warp/tests/test_utils.py +488 -499
  329. warp/tests/test_vec.py +1262 -1268
  330. warp/tests/test_vec_lite.py +71 -73
  331. warp/tests/test_vec_scalar_ops.py +2097 -2099
  332. warp/tests/test_verify_fp.py +92 -94
  333. warp/tests/test_volume.py +961 -736
  334. warp/tests/test_volume_write.py +338 -265
  335. warp/tests/unittest_serial.py +38 -37
  336. warp/tests/unittest_suites.py +367 -359
  337. warp/tests/unittest_utils.py +434 -578
  338. warp/tests/unused_test_misc.py +69 -71
  339. warp/tests/walkthrough_debug.py +85 -85
  340. warp/thirdparty/appdirs.py +598 -598
  341. warp/thirdparty/dlpack.py +143 -143
  342. warp/thirdparty/unittest_parallel.py +563 -561
  343. warp/torch.py +321 -295
  344. warp/types.py +4941 -4450
  345. warp/utils.py +1008 -821
  346. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/LICENSE.md +126 -126
  347. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/METADATA +365 -400
  348. warp_lang-1.2.0.dist-info/RECORD +359 -0
  349. warp/examples/assets/cube.usda +0 -42
  350. warp/examples/assets/sphere.usda +0 -56
  351. warp/examples/assets/torus.usda +0 -105
  352. warp/examples/fem/example_convection_diffusion_dg0.py +0 -194
  353. warp/native/nanovdb/PNanoVDBWrite.h +0 -295
  354. warp_lang-1.0.2.dist-info/RECORD +0 -352
  355. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/WHEEL +0 -0
  356. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/top_level.txt +0 -0
warp/tests/test_quat.py CHANGED
@@ -1,2086 +1,2084 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
-
8
- import unittest
9
-
10
- import numpy as np
11
-
12
- import warp as wp
13
- from warp.tests.unittest_utils import *
14
- import warp.sim
15
-
16
- wp.init()
17
-
18
- np_float_types = [np.float32, np.float64, np.float16]
19
-
20
- kernel_cache = dict()
21
-
22
-
23
- def getkernel(func, suffix=""):
24
- key = func.__name__ + "_" + suffix
25
- if key not in kernel_cache:
26
- kernel_cache[key] = wp.Kernel(func=func, key=key)
27
- return kernel_cache[key]
28
-
29
-
30
- def get_select_kernel(dtype):
31
- def output_select_kernel_fn(
32
- input: wp.array(dtype=dtype),
33
- index: int,
34
- out: wp.array(dtype=dtype),
35
- ):
36
- out[0] = input[index]
37
-
38
- return getkernel(output_select_kernel_fn, suffix=dtype.__name__)
39
-
40
-
41
- ############################################################
42
-
43
-
44
- def test_constructors(test, device, dtype, register_kernels=False):
45
- rng = np.random.default_rng(123)
46
-
47
- tol = {
48
- np.float16: 5.0e-3,
49
- np.float32: 1.0e-6,
50
- np.float64: 1.0e-8,
51
- }.get(dtype, 0)
52
-
53
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
54
- vec3 = wp.types.vector(length=3, dtype=wptype)
55
- quat = wp.types.quaternion(dtype=wptype)
56
-
57
- def check_component_constructor(
58
- input: wp.array(dtype=wptype),
59
- q: wp.array(dtype=wptype),
60
- ):
61
- qresult = quat(input[0], input[1], input[2], input[3])
62
-
63
- # multiply the output by 2 so we've got something to backpropagate:
64
- q[0] = wptype(2) * qresult[0]
65
- q[1] = wptype(2) * qresult[1]
66
- q[2] = wptype(2) * qresult[2]
67
- q[3] = wptype(2) * qresult[3]
68
-
69
- def check_vector_constructor(
70
- input: wp.array(dtype=wptype),
71
- q: wp.array(dtype=wptype),
72
- ):
73
- qresult = quat(vec3(input[0], input[1], input[2]), input[3])
74
-
75
- # multiply the output by 2 so we've got something to backpropagate:
76
- q[0] = wptype(2) * qresult[0]
77
- q[1] = wptype(2) * qresult[1]
78
- q[2] = wptype(2) * qresult[2]
79
- q[3] = wptype(2) * qresult[3]
80
-
81
- kernel = getkernel(check_component_constructor, suffix=dtype.__name__)
82
- output_select_kernel = get_select_kernel(wptype)
83
- vec_kernel = getkernel(check_vector_constructor, suffix=dtype.__name__)
84
-
85
- if register_kernels:
86
- return
87
-
88
- input = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
89
- output = wp.zeros_like(input)
90
- wp.launch(kernel, dim=1, inputs=[input], outputs=[output], device=device)
91
-
92
- assert_np_equal(output.numpy(), 2 * input.numpy(), tol=tol)
93
-
94
- for i in range(4):
95
- cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
96
- tape = wp.Tape()
97
- with tape:
98
- wp.launch(kernel, dim=1, inputs=[input], outputs=[output], device=device)
99
- wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[cmp], device=device)
100
- tape.backward(loss=cmp)
101
- expectedgrads = np.zeros(len(input))
102
- expectedgrads[i] = 2
103
- assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
104
- tape.zero()
105
-
106
- input = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
107
- output = wp.zeros_like(input)
108
- wp.launch(vec_kernel, dim=1, inputs=[input], outputs=[output], device=device)
109
-
110
- assert_np_equal(output.numpy(), 2 * input.numpy(), tol=tol)
111
-
112
- for i in range(4):
113
- cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
114
- tape = wp.Tape()
115
- with tape:
116
- wp.launch(vec_kernel, dim=1, inputs=[input], outputs=[output], device=device)
117
- wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[cmp], device=device)
118
- tape.backward(loss=cmp)
119
- expectedgrads = np.zeros(len(input))
120
- expectedgrads[i] = 2
121
- assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
122
- tape.zero()
123
-
124
-
125
- def test_casting_constructors(test, device, dtype, register_kernels=False):
126
- np_type = np.dtype(dtype)
127
- wp_type = wp.types.np_dtype_to_warp_type[np_type]
128
- quat = wp.types.quaternion(dtype=wp_type)
129
-
130
- np16 = np.dtype(np.float16)
131
- wp16 = wp.types.np_dtype_to_warp_type[np16]
132
-
133
- np32 = np.dtype(np.float32)
134
- wp32 = wp.types.np_dtype_to_warp_type[np32]
135
-
136
- np64 = np.dtype(np.float64)
137
- wp64 = wp.types.np_dtype_to_warp_type[np64]
138
-
139
- def cast_float16(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp16, ndim=2)):
140
- tid = wp.tid()
141
-
142
- q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
143
- q2 = wp.quaternion(q1, dtype=wp16)
144
-
145
- b[tid, 0] = q2[0]
146
- b[tid, 1] = q2[1]
147
- b[tid, 2] = q2[2]
148
- b[tid, 3] = q2[3]
149
-
150
- def cast_float32(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp32, ndim=2)):
151
- tid = wp.tid()
152
-
153
- q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
154
- q2 = wp.quaternion(q1, dtype=wp32)
155
-
156
- b[tid, 0] = q2[0]
157
- b[tid, 1] = q2[1]
158
- b[tid, 2] = q2[2]
159
- b[tid, 3] = q2[3]
160
-
161
- def cast_float64(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp64, ndim=2)):
162
- tid = wp.tid()
163
-
164
- q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
165
- q2 = wp.quaternion(q1, dtype=wp64)
166
-
167
- b[tid, 0] = q2[0]
168
- b[tid, 1] = q2[1]
169
- b[tid, 2] = q2[2]
170
- b[tid, 3] = q2[3]
171
-
172
- kernel_16 = getkernel(cast_float16, suffix=dtype.__name__)
173
- kernel_32 = getkernel(cast_float32, suffix=dtype.__name__)
174
- kernel_64 = getkernel(cast_float64, suffix=dtype.__name__)
175
-
176
- if register_kernels:
177
- return
178
-
179
- # check casting to float 16
180
- a = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
181
- b = wp.array(np.zeros((1, 4), dtype=np16), dtype=wp16, requires_grad=True, device=device)
182
- b_result = np.ones((1, 4), dtype=np16)
183
- b_grad = wp.array(np.ones((1, 4), dtype=np16), dtype=wp16, device=device)
184
- a_grad = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, device=device)
185
-
186
- tape = wp.Tape()
187
- with tape:
188
- wp.launch(kernel=kernel_16, dim=1, inputs=[a, b], device=device)
189
-
190
- tape.backward(grads={b: b_grad})
191
- out = tape.gradients[a].numpy()
192
-
193
- assert_np_equal(b.numpy(), b_result)
194
- assert_np_equal(out, a_grad.numpy())
195
-
196
- # check casting to float 32
197
- a = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
198
- b = wp.array(np.zeros((1, 4), dtype=np32), dtype=wp32, requires_grad=True, device=device)
199
- b_result = np.ones((1, 4), dtype=np32)
200
- b_grad = wp.array(np.ones((1, 4), dtype=np32), dtype=wp32, device=device)
201
- a_grad = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, device=device)
202
-
203
- tape = wp.Tape()
204
- with tape:
205
- wp.launch(kernel=kernel_32, dim=1, inputs=[a, b], device=device)
206
-
207
- tape.backward(grads={b: b_grad})
208
- out = tape.gradients[a].numpy()
209
-
210
- assert_np_equal(b.numpy(), b_result)
211
- assert_np_equal(out, a_grad.numpy())
212
-
213
- # check casting to float 64
214
- a = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
215
- b = wp.array(np.zeros((1, 4), dtype=np64), dtype=wp64, requires_grad=True, device=device)
216
- b_result = np.ones((1, 4), dtype=np64)
217
- b_grad = wp.array(np.ones((1, 4), dtype=np64), dtype=wp64, device=device)
218
- a_grad = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, device=device)
219
-
220
- tape = wp.Tape()
221
- with tape:
222
- wp.launch(kernel=kernel_64, dim=1, inputs=[a, b], device=device)
223
-
224
- tape.backward(grads={b: b_grad})
225
- out = tape.gradients[a].numpy()
226
-
227
- assert_np_equal(b.numpy(), b_result)
228
- assert_np_equal(out, a_grad.numpy())
229
-
230
-
231
- def test_inverse(test, device, dtype, register_kernels=False):
232
- rng = np.random.default_rng(123)
233
-
234
- tol = {
235
- np.float16: 2.0e-3,
236
- np.float32: 1.0e-6,
237
- np.float64: 1.0e-8,
238
- }.get(dtype, 0)
239
-
240
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
241
- quat = wp.types.quaternion(dtype=wptype)
242
-
243
- output_select_kernel = get_select_kernel(wptype)
244
-
245
- def check_quat_inverse(
246
- input: wp.array(dtype=wptype),
247
- shouldbeidentity: wp.array(dtype=quat),
248
- q: wp.array(dtype=wptype),
249
- ):
250
- qread = quat(input[0], input[1], input[2], input[3])
251
- qresult = wp.quat_inverse(qread)
252
-
253
- # this inverse should work for normalized quaternions:
254
- shouldbeidentity[0] = wp.normalize(qread) * wp.quat_inverse(wp.normalize(qread))
255
-
256
- # multiply the output by 2 so we've got something to backpropagate:
257
- q[0] = wptype(2) * qresult[0]
258
- q[1] = wptype(2) * qresult[1]
259
- q[2] = wptype(2) * qresult[2]
260
- q[3] = wptype(2) * qresult[3]
261
-
262
- kernel = getkernel(check_quat_inverse, suffix=dtype.__name__)
263
-
264
- if register_kernels:
265
- return
266
-
267
- input = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
268
- shouldbeidentity = wp.array(np.zeros((1, 4)), dtype=quat, requires_grad=True, device=device)
269
- output = wp.zeros_like(input)
270
- wp.launch(kernel, dim=1, inputs=[input], outputs=[shouldbeidentity, output], device=device)
271
-
272
- assert_np_equal(shouldbeidentity.numpy(), np.array([0, 0, 0, 1]), tol=tol)
273
-
274
- for i in range(4):
275
- cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
276
- tape = wp.Tape()
277
- with tape:
278
- wp.launch(kernel, dim=1, inputs=[input], outputs=[shouldbeidentity, output], device=device)
279
- wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[cmp], device=device)
280
- tape.backward(loss=cmp)
281
- expectedgrads = np.zeros(len(input))
282
- expectedgrads[i] = -2 if i != 3 else 2
283
- assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
284
- tape.zero()
285
-
286
-
287
- def test_dotproduct(test, device, dtype, register_kernels=False):
288
- rng = np.random.default_rng(123)
289
-
290
- tol = {
291
- np.float16: 1.0e-2,
292
- np.float32: 1.0e-6,
293
- np.float64: 1.0e-8,
294
- }.get(dtype, 0)
295
-
296
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
297
- quat = wp.types.quaternion(dtype=wptype)
298
-
299
- def check_quat_dot(
300
- s: wp.array(dtype=quat),
301
- v: wp.array(dtype=quat),
302
- dot: wp.array(dtype=wptype),
303
- ):
304
- dot[0] = wptype(2) * wp.dot(v[0], s[0])
305
-
306
- dotkernel = getkernel(check_quat_dot, suffix=dtype.__name__)
307
- if register_kernels:
308
- return
309
-
310
- s = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
311
- v = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
312
- dot = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
313
-
314
- tape = wp.Tape()
315
- with tape:
316
- wp.launch(
317
- dotkernel,
318
- dim=1,
319
- inputs=[
320
- s,
321
- v,
322
- ],
323
- outputs=[dot],
324
- device=device,
325
- )
326
-
327
- assert_np_equal(dot.numpy()[0], 2.0 * (v.numpy() * s.numpy()).sum(), tol=tol)
328
-
329
- tape.backward(loss=dot)
330
- sgrads = tape.gradients[s].numpy()[0]
331
- expected_grads = 2.0 * v.numpy()[0]
332
- assert_np_equal(sgrads, expected_grads, tol=10 * tol)
333
-
334
- vgrads = tape.gradients[v].numpy()[0]
335
- expected_grads = 2.0 * s.numpy()[0]
336
- assert_np_equal(vgrads, expected_grads, tol=tol)
337
-
338
-
339
- def test_length(test, device, dtype, register_kernels=False):
340
- rng = np.random.default_rng(123)
341
-
342
- tol = {
343
- np.float16: 5.0e-3,
344
- np.float32: 1.0e-6,
345
- np.float64: 1.0e-7,
346
- }.get(dtype, 0)
347
-
348
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
349
- quat = wp.types.quaternion(dtype=wptype)
350
-
351
- def check_quat_length(
352
- q: wp.array(dtype=quat),
353
- l: wp.array(dtype=wptype),
354
- l2: wp.array(dtype=wptype),
355
- ):
356
- l[0] = wptype(2) * wp.length(q[0])
357
- l2[0] = wptype(2) * wp.length_sq(q[0])
358
-
359
- kernel = getkernel(check_quat_length, suffix=dtype.__name__)
360
-
361
- if register_kernels:
362
- return
363
-
364
- q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
365
- l = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
366
- l2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
367
-
368
- tape = wp.Tape()
369
- with tape:
370
- wp.launch(
371
- kernel,
372
- dim=1,
373
- inputs=[
374
- q,
375
- ],
376
- outputs=[l, l2],
377
- device=device,
378
- )
379
-
380
- assert_np_equal(l.numpy()[0], 2 * np.linalg.norm(q.numpy()), tol=10 * tol)
381
- assert_np_equal(l2.numpy()[0], 2 * np.linalg.norm(q.numpy()) ** 2, tol=10 * tol)
382
-
383
- tape.backward(loss=l)
384
- grad = tape.gradients[q].numpy()[0]
385
- expected_grad = 2 * q.numpy()[0] / np.linalg.norm(q.numpy())
386
- assert_np_equal(grad, expected_grad, tol=10 * tol)
387
- tape.zero()
388
-
389
- tape.backward(loss=l2)
390
- grad = tape.gradients[q].numpy()[0]
391
- expected_grad = 4 * q.numpy()[0]
392
- assert_np_equal(grad, expected_grad, tol=10 * tol)
393
- tape.zero()
394
-
395
-
396
- def test_normalize(test, device, dtype, register_kernels=False):
397
- rng = np.random.default_rng(123)
398
-
399
- tol = {
400
- np.float16: 5.0e-3,
401
- np.float32: 1.0e-6,
402
- np.float64: 1.0e-8,
403
- }.get(dtype, 0)
404
-
405
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
406
- quat = wp.types.quaternion(dtype=wptype)
407
-
408
- def check_normalize(
409
- q: wp.array(dtype=quat),
410
- n0: wp.array(dtype=wptype),
411
- n1: wp.array(dtype=wptype),
412
- n2: wp.array(dtype=wptype),
413
- n3: wp.array(dtype=wptype),
414
- ):
415
- n = wptype(2) * (wp.normalize(q[0]))
416
-
417
- n0[0] = n[0]
418
- n1[0] = n[1]
419
- n2[0] = n[2]
420
- n3[0] = n[3]
421
-
422
- def check_normalize_alt(
423
- q: wp.array(dtype=quat),
424
- n0: wp.array(dtype=wptype),
425
- n1: wp.array(dtype=wptype),
426
- n2: wp.array(dtype=wptype),
427
- n3: wp.array(dtype=wptype),
428
- ):
429
- n = wptype(2) * (q[0] / wp.length(q[0]))
430
-
431
- n0[0] = n[0]
432
- n1[0] = n[1]
433
- n2[0] = n[2]
434
- n3[0] = n[3]
435
-
436
- normalize_kernel = getkernel(check_normalize, suffix=dtype.__name__)
437
- normalize_alt_kernel = getkernel(check_normalize_alt, suffix=dtype.__name__)
438
-
439
- if register_kernels:
440
- return
441
-
442
- # I've already tested the things I'm using in check_normalize_alt, so I'll just
443
- # make sure the two are giving the same results/gradients
444
- q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
445
-
446
- n0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
447
- n1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
448
- n2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
449
- n3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
450
-
451
- n0_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
452
- n1_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
453
- n2_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
454
- n3_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
455
-
456
- outputs0 = [
457
- n0,
458
- n1,
459
- n2,
460
- n3,
461
- ]
462
- tape0 = wp.Tape()
463
- with tape0:
464
- wp.launch(normalize_kernel, dim=1, inputs=[q], outputs=outputs0, device=device)
465
-
466
- outputs1 = [
467
- n0_alt,
468
- n1_alt,
469
- n2_alt,
470
- n3_alt,
471
- ]
472
- tape1 = wp.Tape()
473
- with tape1:
474
- wp.launch(
475
- normalize_alt_kernel,
476
- dim=1,
477
- inputs=[
478
- q,
479
- ],
480
- outputs=outputs1,
481
- device=device,
482
- )
483
-
484
- assert_np_equal(n0.numpy()[0], n0_alt.numpy()[0], tol=tol)
485
- assert_np_equal(n1.numpy()[0], n1_alt.numpy()[0], tol=tol)
486
- assert_np_equal(n2.numpy()[0], n2_alt.numpy()[0], tol=tol)
487
- assert_np_equal(n3.numpy()[0], n3_alt.numpy()[0], tol=tol)
488
-
489
- for ncmp, ncmpalt in zip(outputs0, outputs1):
490
- tape0.backward(loss=ncmp)
491
- tape1.backward(loss=ncmpalt)
492
- assert_np_equal(tape0.gradients[q].numpy()[0], tape1.gradients[q].numpy()[0], tol=tol)
493
- tape0.zero()
494
- tape1.zero()
495
-
496
-
497
- def test_addition(test, device, dtype, register_kernels=False):
498
- rng = np.random.default_rng(123)
499
-
500
- tol = {
501
- np.float16: 5.0e-3,
502
- np.float32: 1.0e-6,
503
- np.float64: 1.0e-8,
504
- }.get(dtype, 0)
505
-
506
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
507
- quat = wp.types.quaternion(dtype=wptype)
508
-
509
- def check_quat_add(
510
- q: wp.array(dtype=quat),
511
- v: wp.array(dtype=quat),
512
- r0: wp.array(dtype=wptype),
513
- r1: wp.array(dtype=wptype),
514
- r2: wp.array(dtype=wptype),
515
- r3: wp.array(dtype=wptype),
516
- ):
517
- result = q[0] + v[0]
518
-
519
- r0[0] = wptype(2) * result[0]
520
- r1[0] = wptype(2) * result[1]
521
- r2[0] = wptype(2) * result[2]
522
- r3[0] = wptype(2) * result[3]
523
-
524
- kernel = getkernel(check_quat_add, suffix=dtype.__name__)
525
-
526
- if register_kernels:
527
- return
528
-
529
- q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
530
- v = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
531
-
532
- r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
533
- r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
534
- r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
535
- r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
536
-
537
- tape = wp.Tape()
538
- with tape:
539
- wp.launch(
540
- kernel,
541
- dim=1,
542
- inputs=[
543
- q,
544
- v,
545
- ],
546
- outputs=[r0, r1, r2, r3],
547
- device=device,
548
- )
549
-
550
- assert_np_equal(r0.numpy()[0], 2 * (v.numpy()[0, 0] + q.numpy()[0, 0]), tol=tol)
551
- assert_np_equal(r1.numpy()[0], 2 * (v.numpy()[0, 1] + q.numpy()[0, 1]), tol=tol)
552
- assert_np_equal(r2.numpy()[0], 2 * (v.numpy()[0, 2] + q.numpy()[0, 2]), tol=tol)
553
- assert_np_equal(r3.numpy()[0], 2 * (v.numpy()[0, 3] + q.numpy()[0, 3]), tol=tol)
554
-
555
- for i, l in enumerate([r0, r1, r2, r3]):
556
- tape.backward(loss=l)
557
- qgrads = tape.gradients[q].numpy()[0]
558
- expected_grads = np.zeros_like(qgrads)
559
-
560
- expected_grads[i] = 2
561
- assert_np_equal(qgrads, expected_grads, tol=10 * tol)
562
-
563
- vgrads = tape.gradients[v].numpy()[0]
564
- assert_np_equal(vgrads, expected_grads, tol=tol)
565
-
566
- tape.zero()
567
-
568
-
569
- def test_subtraction(test, device, dtype, register_kernels=False):
570
- rng = np.random.default_rng(123)
571
-
572
- tol = {
573
- np.float16: 5.0e-3,
574
- np.float32: 1.0e-6,
575
- np.float64: 1.0e-8,
576
- }.get(dtype, 0)
577
-
578
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
579
- quat = wp.types.quaternion(dtype=wptype)
580
-
581
- def check_quat_sub(
582
- q: wp.array(dtype=quat),
583
- v: wp.array(dtype=quat),
584
- r0: wp.array(dtype=wptype),
585
- r1: wp.array(dtype=wptype),
586
- r2: wp.array(dtype=wptype),
587
- r3: wp.array(dtype=wptype),
588
- ):
589
- result = v[0] - q[0]
590
-
591
- r0[0] = wptype(2) * result[0]
592
- r1[0] = wptype(2) * result[1]
593
- r2[0] = wptype(2) * result[2]
594
- r3[0] = wptype(2) * result[3]
595
-
596
- kernel = getkernel(check_quat_sub, suffix=dtype.__name__)
597
-
598
- if register_kernels:
599
- return
600
-
601
- q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
602
- v = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
603
-
604
- r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
605
- r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
606
- r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
607
- r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
608
-
609
- tape = wp.Tape()
610
- with tape:
611
- wp.launch(
612
- kernel,
613
- dim=1,
614
- inputs=[
615
- q,
616
- v,
617
- ],
618
- outputs=[r0, r1, r2, r3],
619
- device=device,
620
- )
621
-
622
- assert_np_equal(r0.numpy()[0], 2 * (v.numpy()[0, 0] - q.numpy()[0, 0]), tol=tol)
623
- assert_np_equal(r1.numpy()[0], 2 * (v.numpy()[0, 1] - q.numpy()[0, 1]), tol=tol)
624
- assert_np_equal(r2.numpy()[0], 2 * (v.numpy()[0, 2] - q.numpy()[0, 2]), tol=tol)
625
- assert_np_equal(r3.numpy()[0], 2 * (v.numpy()[0, 3] - q.numpy()[0, 3]), tol=tol)
626
-
627
- for i, l in enumerate([r0, r1, r2, r3]):
628
- tape.backward(loss=l)
629
- qgrads = tape.gradients[q].numpy()[0]
630
- expected_grads = np.zeros_like(qgrads)
631
-
632
- expected_grads[i] = -2
633
- assert_np_equal(qgrads, expected_grads, tol=10 * tol)
634
-
635
- vgrads = tape.gradients[v].numpy()[0]
636
- expected_grads[i] = 2
637
- assert_np_equal(vgrads, expected_grads, tol=tol)
638
-
639
- tape.zero()
640
-
641
-
642
- def test_scalar_multiplication(test, device, dtype, register_kernels=False):
643
- rng = np.random.default_rng(123)
644
-
645
- tol = {
646
- np.float16: 5.0e-3,
647
- np.float32: 1.0e-6,
648
- np.float64: 1.0e-8,
649
- }.get(dtype, 0)
650
-
651
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
652
- quat = wp.types.quaternion(dtype=wptype)
653
-
654
- def check_quat_scalar_mul(
655
- s: wp.array(dtype=wptype),
656
- q: wp.array(dtype=quat),
657
- l0: wp.array(dtype=wptype),
658
- l1: wp.array(dtype=wptype),
659
- l2: wp.array(dtype=wptype),
660
- l3: wp.array(dtype=wptype),
661
- r0: wp.array(dtype=wptype),
662
- r1: wp.array(dtype=wptype),
663
- r2: wp.array(dtype=wptype),
664
- r3: wp.array(dtype=wptype),
665
- ):
666
- lresult = s[0] * q[0]
667
- rresult = q[0] * s[0]
668
-
669
- # multiply outputs by 2 so we've got something to backpropagate:
670
- l0[0] = wptype(2) * lresult[0]
671
- l1[0] = wptype(2) * lresult[1]
672
- l2[0] = wptype(2) * lresult[2]
673
- l3[0] = wptype(2) * lresult[3]
674
-
675
- r0[0] = wptype(2) * rresult[0]
676
- r1[0] = wptype(2) * rresult[1]
677
- r2[0] = wptype(2) * rresult[2]
678
- r3[0] = wptype(2) * rresult[3]
679
-
680
- kernel = getkernel(check_quat_scalar_mul, suffix=dtype.__name__)
681
-
682
- if register_kernels:
683
- return
684
-
685
- s = wp.array(rng.standard_normal(size=1).astype(dtype), requires_grad=True, device=device)
686
- q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
687
-
688
- l0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
689
- l1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
690
- l2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
691
- l3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
692
-
693
- r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
694
- r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
695
- r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
696
- r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
697
-
698
- tape = wp.Tape()
699
- with tape:
700
- wp.launch(
701
- kernel,
702
- dim=1,
703
- inputs=[s, q],
704
- outputs=[
705
- l0,
706
- l1,
707
- l2,
708
- l3,
709
- r0,
710
- r1,
711
- r2,
712
- r3,
713
- ],
714
- device=device,
715
- )
716
-
717
- assert_np_equal(l0.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 0], tol=tol)
718
- assert_np_equal(l1.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 1], tol=tol)
719
- assert_np_equal(l2.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 2], tol=tol)
720
- assert_np_equal(l3.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 3], tol=tol)
721
-
722
- assert_np_equal(r0.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 0], tol=tol)
723
- assert_np_equal(r1.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 1], tol=tol)
724
- assert_np_equal(r2.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 2], tol=tol)
725
- assert_np_equal(r3.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 3], tol=tol)
726
-
727
- if dtype in np_float_types:
728
- for i, outputs in enumerate([(l0, r0), (l1, r1), (l2, r2), (l3, r3)]):
729
- for l in outputs:
730
- tape.backward(loss=l)
731
- sgrad = tape.gradients[s].numpy()[0]
732
- assert_np_equal(sgrad, 2 * q.numpy()[0, i], tol=tol)
733
- allgrads = tape.gradients[q].numpy()[0]
734
- expected_grads = np.zeros_like(allgrads)
735
- expected_grads[i] = s.numpy()[0] * 2
736
- assert_np_equal(allgrads, expected_grads, tol=10 * tol)
737
- tape.zero()
738
-
739
-
740
- def test_scalar_division(test, device, dtype, register_kernels=False):
741
- rng = np.random.default_rng(123)
742
-
743
- tol = {
744
- np.float16: 1.0e-3,
745
- np.float32: 1.0e-6,
746
- np.float64: 1.0e-8,
747
- }.get(dtype, 0)
748
-
749
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
750
- quat = wp.types.quaternion(dtype=wptype)
751
-
752
- def check_quat_scalar_div(
753
- s: wp.array(dtype=wptype),
754
- q: wp.array(dtype=quat),
755
- r0: wp.array(dtype=wptype),
756
- r1: wp.array(dtype=wptype),
757
- r2: wp.array(dtype=wptype),
758
- r3: wp.array(dtype=wptype),
759
- ):
760
- result = q[0] / s[0]
761
-
762
- # multiply outputs by 2 so we've got something to backpropagate:
763
- r0[0] = wptype(2) * result[0]
764
- r1[0] = wptype(2) * result[1]
765
- r2[0] = wptype(2) * result[2]
766
- r3[0] = wptype(2) * result[3]
767
-
768
- kernel = getkernel(check_quat_scalar_div, suffix=dtype.__name__)
769
-
770
- if register_kernels:
771
- return
772
-
773
- s = wp.array(rng.standard_normal(size=1).astype(dtype), requires_grad=True, device=device)
774
- q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
775
-
776
- r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
777
- r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
778
- r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
779
- r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
780
-
781
- tape = wp.Tape()
782
- with tape:
783
- wp.launch(
784
- kernel,
785
- dim=1,
786
- inputs=[s, q],
787
- outputs=[
788
- r0,
789
- r1,
790
- r2,
791
- r3,
792
- ],
793
- device=device,
794
- )
795
- assert_np_equal(r0.numpy()[0], 2 * q.numpy()[0, 0] / s.numpy()[0], tol=tol)
796
- assert_np_equal(r1.numpy()[0], 2 * q.numpy()[0, 1] / s.numpy()[0], tol=tol)
797
- assert_np_equal(r2.numpy()[0], 2 * q.numpy()[0, 2] / s.numpy()[0], tol=tol)
798
- assert_np_equal(r3.numpy()[0], 2 * q.numpy()[0, 3] / s.numpy()[0], tol=tol)
799
-
800
- if dtype in np_float_types:
801
- for i, r in enumerate([r0, r1, r2, r3]):
802
- tape.backward(loss=r)
803
- sgrad = tape.gradients[s].numpy()[0]
804
- assert_np_equal(sgrad, -2 * q.numpy()[0, i] / (s.numpy()[0] * s.numpy()[0]), tol=tol)
805
-
806
- allgrads = tape.gradients[q].numpy()[0]
807
- expected_grads = np.zeros_like(allgrads)
808
- expected_grads[i] = 2 / s.numpy()[0]
809
- assert_np_equal(allgrads, expected_grads, tol=10 * tol)
810
- tape.zero()
811
-
812
-
813
- def test_quat_multiplication(test, device, dtype, register_kernels=False):
814
- rng = np.random.default_rng(123)
815
-
816
- tol = {
817
- np.float16: 1.0e-2,
818
- np.float32: 1.0e-6,
819
- np.float64: 1.0e-8,
820
- }.get(dtype, 0)
821
-
822
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
823
- quat = wp.types.quaternion(dtype=wptype)
824
-
825
- def check_quat_mul(
826
- s: wp.array(dtype=quat),
827
- q: wp.array(dtype=quat),
828
- r0: wp.array(dtype=wptype),
829
- r1: wp.array(dtype=wptype),
830
- r2: wp.array(dtype=wptype),
831
- r3: wp.array(dtype=wptype),
832
- ):
833
- result = s[0] * q[0]
834
-
835
- # multiply outputs by 2 so we've got something to backpropagate:
836
- r0[0] = wptype(2) * result[0]
837
- r1[0] = wptype(2) * result[1]
838
- r2[0] = wptype(2) * result[2]
839
- r3[0] = wptype(2) * result[3]
840
-
841
- kernel = getkernel(check_quat_mul, suffix=dtype.__name__)
842
-
843
- if register_kernels:
844
- return
845
-
846
- s = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
847
- q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
848
-
849
- r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
850
- r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
851
- r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
852
- r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
853
-
854
- tape = wp.Tape()
855
- with tape:
856
- wp.launch(
857
- kernel,
858
- dim=1,
859
- inputs=[s, q],
860
- outputs=[
861
- r0,
862
- r1,
863
- r2,
864
- r3,
865
- ],
866
- device=device,
867
- )
868
-
869
- a = s.numpy()
870
- b = q.numpy()
871
- assert_np_equal(
872
- r0.numpy()[0], 2 * (a[0, 3] * b[0, 0] + b[0, 3] * a[0, 0] + a[0, 1] * b[0, 2] - b[0, 1] * a[0, 2]), tol=tol
873
- )
874
- assert_np_equal(
875
- r1.numpy()[0], 2 * (a[0, 3] * b[0, 1] + b[0, 3] * a[0, 1] + a[0, 2] * b[0, 0] - b[0, 2] * a[0, 0]), tol=tol
876
- )
877
- assert_np_equal(
878
- r2.numpy()[0], 2 * (a[0, 3] * b[0, 2] + b[0, 3] * a[0, 2] + a[0, 0] * b[0, 1] - b[0, 0] * a[0, 1]), tol=tol
879
- )
880
- assert_np_equal(
881
- r3.numpy()[0], 2 * (a[0, 3] * b[0, 3] - a[0, 0] * b[0, 0] - a[0, 1] * b[0, 1] - a[0, 2] * b[0, 2]), tol=tol
882
- )
883
-
884
- tape.backward(loss=r0)
885
- agrad = tape.gradients[s].numpy()[0]
886
- assert_np_equal(agrad, 2 * np.array([b[0, 3], b[0, 2], -b[0, 1], b[0, 0]]), tol=tol)
887
-
888
- bgrad = tape.gradients[q].numpy()[0]
889
- assert_np_equal(bgrad, 2 * np.array([a[0, 3], -a[0, 2], a[0, 1], a[0, 0]]), tol=tol)
890
- tape.zero()
891
-
892
- tape.backward(loss=r1)
893
- agrad = tape.gradients[s].numpy()[0]
894
- assert_np_equal(agrad, 2 * np.array([-b[0, 2], b[0, 3], b[0, 0], b[0, 1]]), tol=tol)
895
-
896
- bgrad = tape.gradients[q].numpy()[0]
897
- assert_np_equal(bgrad, 2 * np.array([a[0, 2], a[0, 3], -a[0, 0], a[0, 1]]), tol=tol)
898
- tape.zero()
899
-
900
- tape.backward(loss=r2)
901
- agrad = tape.gradients[s].numpy()[0]
902
- assert_np_equal(agrad, 2 * np.array([b[0, 1], -b[0, 0], b[0, 3], b[0, 2]]), tol=tol)
903
-
904
- bgrad = tape.gradients[q].numpy()[0]
905
- assert_np_equal(bgrad, 2 * np.array([-a[0, 1], a[0, 0], a[0, 3], a[0, 2]]), tol=tol)
906
- tape.zero()
907
-
908
- tape.backward(loss=r3)
909
- agrad = tape.gradients[s].numpy()[0]
910
- assert_np_equal(agrad, 2 * np.array([-b[0, 0], -b[0, 1], -b[0, 2], b[0, 3]]), tol=tol)
911
-
912
- bgrad = tape.gradients[q].numpy()[0]
913
- assert_np_equal(bgrad, 2 * np.array([-a[0, 0], -a[0, 1], -a[0, 2], a[0, 3]]), tol=tol)
914
- tape.zero()
915
-
916
-
917
- def test_indexing(test, device, dtype, register_kernels=False):
918
- rng = np.random.default_rng(123)
919
-
920
- tol = {
921
- np.float16: 5.0e-3,
922
- np.float32: 1.0e-6,
923
- np.float64: 1.0e-8,
924
- }.get(dtype, 0)
925
-
926
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
927
- quat = wp.types.quaternion(dtype=wptype)
928
-
929
- def check_quat_indexing(
930
- q: wp.array(dtype=quat),
931
- r0: wp.array(dtype=wptype),
932
- r1: wp.array(dtype=wptype),
933
- r2: wp.array(dtype=wptype),
934
- r3: wp.array(dtype=wptype),
935
- ):
936
- # multiply outputs by 2 so we've got something to backpropagate:
937
- r0[0] = wptype(2) * q[0][0]
938
- r1[0] = wptype(2) * q[0][1]
939
- r2[0] = wptype(2) * q[0][2]
940
- r3[0] = wptype(2) * q[0][3]
941
-
942
- kernel = getkernel(check_quat_indexing, suffix=dtype.__name__)
943
-
944
- if register_kernels:
945
- return
946
-
947
- q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
948
- r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
949
- r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
950
- r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
951
- r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
952
-
953
- tape = wp.Tape()
954
- with tape:
955
- wp.launch(kernel, dim=1, inputs=[q], outputs=[r0, r1, r2, r3], device=device)
956
-
957
- for i, l in enumerate([r0, r1, r2, r3]):
958
- tape.backward(loss=l)
959
- allgrads = tape.gradients[q].numpy()[0]
960
- expected_grads = np.zeros_like(allgrads)
961
- expected_grads[i] = 2
962
- assert_np_equal(allgrads, expected_grads, tol=tol)
963
- tape.zero()
964
-
965
- assert_np_equal(r0.numpy()[0], 2.0 * q.numpy()[0, 0], tol=tol)
966
- assert_np_equal(r1.numpy()[0], 2.0 * q.numpy()[0, 1], tol=tol)
967
- assert_np_equal(r2.numpy()[0], 2.0 * q.numpy()[0, 2], tol=tol)
968
- assert_np_equal(r3.numpy()[0], 2.0 * q.numpy()[0, 3], tol=tol)
969
-
970
-
971
- def test_quat_lerp(test, device, dtype, register_kernels=False):
972
- rng = np.random.default_rng(123)
973
-
974
- tol = {
975
- np.float16: 1.0e-2,
976
- np.float32: 1.0e-6,
977
- np.float64: 1.0e-8,
978
- }.get(dtype, 0)
979
-
980
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
981
- quat = wp.types.quaternion(dtype=wptype)
982
-
983
- def check_quat_lerp(
984
- s: wp.array(dtype=quat),
985
- q: wp.array(dtype=quat),
986
- t: wp.array(dtype=wptype),
987
- r0: wp.array(dtype=wptype),
988
- r1: wp.array(dtype=wptype),
989
- r2: wp.array(dtype=wptype),
990
- r3: wp.array(dtype=wptype),
991
- ):
992
- result = wp.lerp(s[0], q[0], t[0])
993
-
994
- # multiply outputs by 2 so we've got something to backpropagate:
995
- r0[0] = wptype(2) * result[0]
996
- r1[0] = wptype(2) * result[1]
997
- r2[0] = wptype(2) * result[2]
998
- r3[0] = wptype(2) * result[3]
999
-
1000
- kernel = getkernel(check_quat_lerp, suffix=dtype.__name__)
1001
-
1002
- if register_kernels:
1003
- return
1004
-
1005
- s = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
1006
- q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
1007
- t = wp.array(rng.uniform(size=1).astype(dtype), dtype=wptype, requires_grad=True, device=device)
1008
-
1009
- r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1010
- r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1011
- r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1012
- r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1013
-
1014
- tape = wp.Tape()
1015
- with tape:
1016
- wp.launch(
1017
- kernel,
1018
- dim=1,
1019
- inputs=[s, q, t],
1020
- outputs=[
1021
- r0,
1022
- r1,
1023
- r2,
1024
- r3,
1025
- ],
1026
- device=device,
1027
- )
1028
-
1029
- a = s.numpy()
1030
- b = q.numpy()
1031
- tt = t.numpy()
1032
- assert_np_equal(r0.numpy()[0], 2 * ((1 - tt) * a[0, 0] + tt * b[0, 0]), tol=tol)
1033
- assert_np_equal(r1.numpy()[0], 2 * ((1 - tt) * a[0, 1] + tt * b[0, 1]), tol=tol)
1034
- assert_np_equal(r2.numpy()[0], 2 * ((1 - tt) * a[0, 2] + tt * b[0, 2]), tol=tol)
1035
- assert_np_equal(r3.numpy()[0], 2 * ((1 - tt) * a[0, 3] + tt * b[0, 3]), tol=tol)
1036
-
1037
- for i, l in enumerate([r0, r1, r2, r3]):
1038
- tape.backward(loss=l)
1039
- agrad = tape.gradients[s].numpy()[0]
1040
- bgrad = tape.gradients[q].numpy()[0]
1041
- tgrad = tape.gradients[t].numpy()[0]
1042
- expected_grads = np.zeros_like(agrad)
1043
- expected_grads[i] = 2 * (1 - tt)
1044
- assert_np_equal(agrad, expected_grads, tol=tol)
1045
- expected_grads[i] = 2 * tt
1046
- assert_np_equal(bgrad, expected_grads, tol=tol)
1047
- assert_np_equal(tgrad, 2 * (b[0, i] - a[0, i]), tol=tol)
1048
-
1049
- tape.zero()
1050
-
1051
-
1052
- def test_quat_rotate(test, device, dtype, register_kernels=False):
1053
- rng = np.random.default_rng(123)
1054
-
1055
- tol = {
1056
- np.float16: 1.0e-2,
1057
- np.float32: 1.0e-6,
1058
- np.float64: 1.0e-8,
1059
- }.get(dtype, 0)
1060
-
1061
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1062
- quat = wp.types.quaternion(dtype=wptype)
1063
- vec3 = wp.types.vector(length=3, dtype=wptype)
1064
-
1065
- def check_quat_rotate(
1066
- q: wp.array(dtype=quat),
1067
- v: wp.array(dtype=vec3),
1068
- outputs: wp.array(dtype=wptype),
1069
- outputs_inv: wp.array(dtype=wptype),
1070
- outputs_manual: wp.array(dtype=wptype),
1071
- outputs_inv_manual: wp.array(dtype=wptype),
1072
- ):
1073
- result = wp.quat_rotate(q[0], v[0])
1074
- result_inv = wp.quat_rotate_inv(q[0], v[0])
1075
-
1076
- qv = vec3(q[0][0], q[0][1], q[0][2])
1077
- qw = q[0][3]
1078
-
1079
- result_manual = v[0] * (wptype(2) * qw * qw - wptype(1))
1080
- result_manual += wp.cross(qv, v[0]) * qw * wptype(2)
1081
- result_manual += qv * wp.dot(qv, v[0]) * wptype(2)
1082
-
1083
- result_inv_manual = v[0] * (wptype(2) * qw * qw - wptype(1))
1084
- result_inv_manual -= wp.cross(qv, v[0]) * qw * wptype(2)
1085
- result_inv_manual += qv * wp.dot(qv, v[0]) * wptype(2)
1086
-
1087
- for i in range(3):
1088
- # multiply outputs by 2 so we've got something to backpropagate:
1089
- outputs[i] = wptype(2) * result[i]
1090
- outputs_inv[i] = wptype(2) * result_inv[i]
1091
- outputs_manual[i] = wptype(2) * result_manual[i]
1092
- outputs_inv_manual[i] = wptype(2) * result_inv_manual[i]
1093
-
1094
- kernel = getkernel(check_quat_rotate, suffix=dtype.__name__)
1095
- output_select_kernel = get_select_kernel(wptype)
1096
-
1097
- if register_kernels:
1098
- return
1099
-
1100
- q = rng.standard_normal(size=(1, 4))
1101
- q /= np.linalg.norm(q)
1102
- q = wp.array(q.astype(dtype), dtype=quat, requires_grad=True, device=device)
1103
- v = wp.array(0.5 * rng.standard_normal(size=(1, 3)).astype(dtype), dtype=vec3, requires_grad=True, device=device)
1104
-
1105
- # test values against the manually computed result:
1106
- outputs = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
1107
- outputs_inv = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
1108
- outputs_manual = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
1109
- outputs_inv_manual = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
1110
-
1111
- wp.launch(
1112
- kernel,
1113
- dim=1,
1114
- inputs=[q, v],
1115
- outputs=[
1116
- outputs,
1117
- outputs_inv,
1118
- outputs_manual,
1119
- outputs_inv_manual,
1120
- ],
1121
- device=device,
1122
- )
1123
-
1124
- assert_np_equal(outputs.numpy(), outputs_manual.numpy(), tol=tol)
1125
- assert_np_equal(outputs_inv.numpy(), outputs_inv_manual.numpy(), tol=tol)
1126
-
1127
- # test gradients against the manually computed result:
1128
- for i in range(3):
1129
- cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1130
- cmp_inv = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1131
- cmp_manual = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1132
- cmp_inv_manual = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1133
- tape = wp.Tape()
1134
- with tape:
1135
- wp.launch(
1136
- kernel,
1137
- dim=1,
1138
- inputs=[q, v],
1139
- outputs=[
1140
- outputs,
1141
- outputs_inv,
1142
- outputs_manual,
1143
- outputs_inv_manual,
1144
- ],
1145
- device=device,
1146
- )
1147
- wp.launch(output_select_kernel, dim=1, inputs=[outputs, i], outputs=[cmp], device=device)
1148
- wp.launch(output_select_kernel, dim=1, inputs=[outputs_inv, i], outputs=[cmp_inv], device=device)
1149
- wp.launch(output_select_kernel, dim=1, inputs=[outputs_manual, i], outputs=[cmp_manual], device=device)
1150
- wp.launch(
1151
- output_select_kernel, dim=1, inputs=[outputs_inv_manual, i], outputs=[cmp_inv_manual], device=device
1152
- )
1153
-
1154
- tape.backward(loss=cmp)
1155
- qgrads = 1.0 * tape.gradients[q].numpy()
1156
- vgrads = 1.0 * tape.gradients[v].numpy()
1157
- tape.zero()
1158
- tape.backward(loss=cmp_inv)
1159
- qgrads_inv = 1.0 * tape.gradients[q].numpy()
1160
- vgrads_inv = 1.0 * tape.gradients[v].numpy()
1161
- tape.zero()
1162
- tape.backward(loss=cmp_manual)
1163
- qgrads_manual = 1.0 * tape.gradients[q].numpy()
1164
- vgrads_manual = 1.0 * tape.gradients[v].numpy()
1165
- tape.zero()
1166
- tape.backward(loss=cmp_inv_manual)
1167
- qgrads_inv_manual = 1.0 * tape.gradients[q].numpy()
1168
- vgrads_inv_manual = 1.0 * tape.gradients[v].numpy()
1169
- tape.zero()
1170
-
1171
- assert_np_equal(qgrads, qgrads_manual, tol=tol)
1172
- assert_np_equal(vgrads, vgrads_manual, tol=tol)
1173
-
1174
- assert_np_equal(qgrads_inv, qgrads_inv_manual, tol=tol)
1175
- assert_np_equal(vgrads_inv, vgrads_inv_manual, tol=tol)
1176
-
1177
-
1178
- def test_quat_to_matrix(test, device, dtype, register_kernels=False):
1179
- rng = np.random.default_rng(123)
1180
-
1181
- tol = {
1182
- np.float16: 1.0e-2,
1183
- np.float32: 1.0e-6,
1184
- np.float64: 1.0e-8,
1185
- }.get(dtype, 0)
1186
-
1187
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1188
- quat = wp.types.quaternion(dtype=wptype)
1189
- mat3 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1190
- vec3 = wp.types.vector(length=3, dtype=wptype)
1191
-
1192
- def check_quat_to_matrix(
1193
- q: wp.array(dtype=quat),
1194
- outputs: wp.array(dtype=wptype),
1195
- outputs_manual: wp.array(dtype=wptype),
1196
- ):
1197
- result = wp.quat_to_matrix(q[0])
1198
-
1199
- xaxis = wp.quat_rotate(
1200
- q[0],
1201
- vec3(
1202
- wptype(1),
1203
- wptype(0),
1204
- wptype(0),
1205
- ),
1206
- )
1207
- yaxis = wp.quat_rotate(
1208
- q[0],
1209
- vec3(
1210
- wptype(0),
1211
- wptype(1),
1212
- wptype(0),
1213
- ),
1214
- )
1215
- zaxis = wp.quat_rotate(
1216
- q[0],
1217
- vec3(
1218
- wptype(0),
1219
- wptype(0),
1220
- wptype(1),
1221
- ),
1222
- )
1223
- result_manual = mat3(xaxis, yaxis, zaxis)
1224
-
1225
- idx = 0
1226
- for i in range(3):
1227
- for j in range(3):
1228
- # multiply outputs by 2 so we've got something to backpropagate:
1229
- outputs[idx] = wptype(2) * result[i, j]
1230
- outputs_manual[idx] = wptype(2) * result_manual[i, j]
1231
-
1232
- idx = idx + 1
1233
-
1234
- kernel = getkernel(check_quat_to_matrix, suffix=dtype.__name__)
1235
- output_select_kernel = get_select_kernel(wptype)
1236
-
1237
- if register_kernels:
1238
- return
1239
-
1240
- q = rng.standard_normal(size=(1, 4))
1241
- q /= np.linalg.norm(q)
1242
- q = wp.array(q.astype(dtype), dtype=quat, requires_grad=True, device=device)
1243
-
1244
- # test values against the manually computed result:
1245
- outputs = wp.zeros(3 * 3, dtype=wptype, requires_grad=True, device=device)
1246
- outputs_manual = wp.zeros(3 * 3, dtype=wptype, requires_grad=True, device=device)
1247
-
1248
- wp.launch(
1249
- kernel,
1250
- dim=1,
1251
- inputs=[q],
1252
- outputs=[
1253
- outputs,
1254
- outputs_manual,
1255
- ],
1256
- device=device,
1257
- )
1258
-
1259
- assert_np_equal(outputs.numpy(), outputs_manual.numpy(), tol=tol)
1260
-
1261
- # sanity check: divide by 2 to remove that scale factor we put in there, and
1262
- # it should be a rotation matrix
1263
- R = 0.5 * outputs.numpy().reshape(3, 3)
1264
- assert_np_equal(np.matmul(R, R.T), np.eye(3), tol=tol)
1265
-
1266
- # test gradients against the manually computed result:
1267
- idx = 0
1268
- for i in range(3):
1269
- for j in range(3):
1270
- cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1271
- cmp_manual = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1272
- tape = wp.Tape()
1273
- with tape:
1274
- wp.launch(
1275
- kernel,
1276
- dim=1,
1277
- inputs=[q],
1278
- outputs=[
1279
- outputs,
1280
- outputs_manual,
1281
- ],
1282
- device=device,
1283
- )
1284
- wp.launch(output_select_kernel, dim=1, inputs=[outputs, idx], outputs=[cmp], device=device)
1285
- wp.launch(
1286
- output_select_kernel, dim=1, inputs=[outputs_manual, idx], outputs=[cmp_manual], device=device
1287
- )
1288
- tape.backward(loss=cmp)
1289
- qgrads = 1.0 * tape.gradients[q].numpy()
1290
- tape.zero()
1291
- tape.backward(loss=cmp_manual)
1292
- qgrads_manual = 1.0 * tape.gradients[q].numpy()
1293
- tape.zero()
1294
-
1295
- assert_np_equal(qgrads, qgrads_manual, tol=tol)
1296
- idx = idx + 1
1297
-
1298
-
1299
- ############################################################
1300
-
1301
-
1302
- def test_slerp_grad(test, device, dtype, register_kernels=False):
1303
- rng = np.random.default_rng(123)
1304
- seed = 42
1305
-
1306
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1307
- vec3 = wp.types.vector(3, wptype)
1308
- quat = wp.types.quaternion(wptype)
1309
-
1310
- def slerp_kernel(
1311
- q0: wp.array(dtype=quat),
1312
- q1: wp.array(dtype=quat),
1313
- t: wp.array(dtype=wptype),
1314
- loss: wp.array(dtype=wptype),
1315
- index: int,
1316
- ):
1317
- tid = wp.tid()
1318
-
1319
- q = wp.quat_slerp(q0[tid], q1[tid], t[tid])
1320
- wp.atomic_add(loss, 0, q[index])
1321
-
1322
- slerp_kernel = getkernel(slerp_kernel, suffix=dtype.__name__)
1323
-
1324
- def slerp_kernel_forward(
1325
- q0: wp.array(dtype=quat),
1326
- q1: wp.array(dtype=quat),
1327
- t: wp.array(dtype=wptype),
1328
- loss: wp.array(dtype=wptype),
1329
- index: int,
1330
- ):
1331
- tid = wp.tid()
1332
-
1333
- axis = vec3()
1334
- angle = wptype(0.0)
1335
-
1336
- wp.quat_to_axis_angle(wp.mul(wp.quat_inverse(q0[tid]), q1[tid]), axis, angle)
1337
- q = wp.mul(q0[tid], wp.quat_from_axis_angle(axis, t[tid] * angle))
1338
-
1339
- wp.atomic_add(loss, 0, q[index])
1340
-
1341
- slerp_kernel_forward = getkernel(slerp_kernel_forward, suffix=dtype.__name__)
1342
-
1343
- def quat_sampler_slerp(kernel_seed: int, quats: wp.array(dtype=quat)):
1344
- tid = wp.tid()
1345
-
1346
- state = wp.rand_init(kernel_seed, tid)
1347
-
1348
- angle = wp.randf(state, 0.0, 2.0 * 3.1415926535)
1349
- dir = wp.sample_unit_sphere_surface(state) * wp.sin(angle * 0.5)
1350
-
1351
- q = quat(wptype(dir[0]), wptype(dir[1]), wptype(dir[2]), wptype(wp.cos(angle * 0.5)))
1352
- qn = wp.normalize(q)
1353
-
1354
- quats[tid] = qn
1355
-
1356
- quat_sampler = getkernel(quat_sampler_slerp, suffix=dtype.__name__)
1357
-
1358
- if register_kernels:
1359
- return
1360
-
1361
- N = 50
1362
-
1363
- q0 = wp.zeros(N, dtype=quat, device=device, requires_grad=True)
1364
- q1 = wp.zeros(N, dtype=quat, device=device, requires_grad=True)
1365
-
1366
- wp.launch(kernel=quat_sampler, dim=N, inputs=[seed, q0], device=device)
1367
- wp.launch(kernel=quat_sampler, dim=N, inputs=[seed + 1, q1], device=device)
1368
-
1369
- t = rng.uniform(low=0.0, high=1.0, size=N)
1370
- t = wp.array(t, dtype=wptype, device=device, requires_grad=True)
1371
-
1372
- def compute_gradients(kernel, wrt, index):
1373
- loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
1374
- tape = wp.Tape()
1375
- with tape:
1376
- wp.launch(kernel=kernel, dim=N, inputs=[q0, q1, t, loss, index], device=device)
1377
-
1378
- tape.backward(loss)
1379
-
1380
- gradients = 1.0 * tape.gradients[wrt].numpy()
1381
- tape.zero()
1382
-
1383
- return loss.numpy()[0], gradients
1384
-
1385
- eps = {
1386
- np.float16: 2.0e-2,
1387
- np.float32: 1.0e-5,
1388
- np.float64: 1.0e-8,
1389
- }.get(dtype, 0)
1390
-
1391
- # wrt t
1392
-
1393
- # gather gradients from builtin adjoints
1394
- xcmp, gradients_x = compute_gradients(slerp_kernel, t, 0)
1395
- ycmp, gradients_y = compute_gradients(slerp_kernel, t, 1)
1396
- zcmp, gradients_z = compute_gradients(slerp_kernel, t, 2)
1397
- wcmp, gradients_w = compute_gradients(slerp_kernel, t, 3)
1398
-
1399
- # gather gradients from autodiff
1400
- xcmp_auto, gradients_x_auto = compute_gradients(slerp_kernel_forward, t, 0)
1401
- ycmp_auto, gradients_y_auto = compute_gradients(slerp_kernel_forward, t, 1)
1402
- zcmp_auto, gradients_z_auto = compute_gradients(slerp_kernel_forward, t, 2)
1403
- wcmp_auto, gradients_w_auto = compute_gradients(slerp_kernel_forward, t, 3)
1404
-
1405
- assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1406
- assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1407
- assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1408
- assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1409
- assert_np_equal(xcmp, xcmp_auto, tol=eps)
1410
- assert_np_equal(ycmp, ycmp_auto, tol=eps)
1411
- assert_np_equal(zcmp, zcmp_auto, tol=eps)
1412
- assert_np_equal(wcmp, wcmp_auto, tol=eps)
1413
-
1414
- # wrt q0
1415
-
1416
- # gather gradients from builtin adjoints
1417
- xcmp, gradients_x = compute_gradients(slerp_kernel, q0, 0)
1418
- ycmp, gradients_y = compute_gradients(slerp_kernel, q0, 1)
1419
- zcmp, gradients_z = compute_gradients(slerp_kernel, q0, 2)
1420
- wcmp, gradients_w = compute_gradients(slerp_kernel, q0, 3)
1421
-
1422
- # gather gradients from autodiff
1423
- xcmp_auto, gradients_x_auto = compute_gradients(slerp_kernel_forward, q0, 0)
1424
- ycmp_auto, gradients_y_auto = compute_gradients(slerp_kernel_forward, q0, 1)
1425
- zcmp_auto, gradients_z_auto = compute_gradients(slerp_kernel_forward, q0, 2)
1426
- wcmp_auto, gradients_w_auto = compute_gradients(slerp_kernel_forward, q0, 3)
1427
-
1428
- assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1429
- assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1430
- assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1431
- assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1432
- assert_np_equal(xcmp, xcmp_auto, tol=eps)
1433
- assert_np_equal(ycmp, ycmp_auto, tol=eps)
1434
- assert_np_equal(zcmp, zcmp_auto, tol=eps)
1435
- assert_np_equal(wcmp, wcmp_auto, tol=eps)
1436
-
1437
- # wrt q1
1438
-
1439
- # gather gradients from builtin adjoints
1440
- xcmp, gradients_x = compute_gradients(slerp_kernel, q1, 0)
1441
- ycmp, gradients_y = compute_gradients(slerp_kernel, q1, 1)
1442
- zcmp, gradients_z = compute_gradients(slerp_kernel, q1, 2)
1443
- wcmp, gradients_w = compute_gradients(slerp_kernel, q1, 3)
1444
-
1445
- # gather gradients from autodiff
1446
- xcmp_auto, gradients_x_auto = compute_gradients(slerp_kernel_forward, q1, 0)
1447
- ycmp_auto, gradients_y_auto = compute_gradients(slerp_kernel_forward, q1, 1)
1448
- zcmp_auto, gradients_z_auto = compute_gradients(slerp_kernel_forward, q1, 2)
1449
- wcmp_auto, gradients_w_auto = compute_gradients(slerp_kernel_forward, q1, 3)
1450
-
1451
- assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1452
- assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1453
- assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1454
- assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1455
- assert_np_equal(xcmp, xcmp_auto, tol=eps)
1456
- assert_np_equal(ycmp, ycmp_auto, tol=eps)
1457
- assert_np_equal(zcmp, zcmp_auto, tol=eps)
1458
- assert_np_equal(wcmp, wcmp_auto, tol=eps)
1459
-
1460
-
1461
- ############################################################
1462
-
1463
-
1464
- def test_quat_to_axis_angle_grad(test, device, dtype, register_kernels=False):
1465
- rng = np.random.default_rng(123)
1466
- seed = 42
1467
- num_rand = 50
1468
-
1469
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1470
- vec3 = wp.types.vector(3, wptype)
1471
- vec4 = wp.types.vector(4, wptype)
1472
- quat = wp.types.quaternion(wptype)
1473
-
1474
- def quat_to_axis_angle_kernel(quats: wp.array(dtype=quat), loss: wp.array(dtype=wptype), coord_idx: int):
1475
- tid = wp.tid()
1476
- axis = vec3()
1477
- angle = wptype(0.0)
1478
-
1479
- wp.quat_to_axis_angle(quats[tid], axis, angle)
1480
- a = vec4(axis[0], axis[1], axis[2], angle)
1481
-
1482
- wp.atomic_add(loss, 0, a[coord_idx])
1483
-
1484
- quat_to_axis_angle_kernel = getkernel(quat_to_axis_angle_kernel, suffix=dtype.__name__)
1485
-
1486
- def quat_to_axis_angle_kernel_forward(quats: wp.array(dtype=quat), loss: wp.array(dtype=wptype), coord_idx: int):
1487
- tid = wp.tid()
1488
- q = quats[tid]
1489
- axis = vec3()
1490
- angle = wptype(0.0)
1491
-
1492
- v = vec3(q[0], q[1], q[2])
1493
- if q[3] < wptype(0):
1494
- axis = -wp.normalize(v)
1495
- else:
1496
- axis = wp.normalize(v)
1497
-
1498
- angle = wptype(2) * wp.atan2(wp.length(v), wp.abs(q[3]))
1499
- a = vec4(axis[0], axis[1], axis[2], angle)
1500
-
1501
- wp.atomic_add(loss, 0, a[coord_idx])
1502
-
1503
- quat_to_axis_angle_kernel_forward = getkernel(quat_to_axis_angle_kernel_forward, suffix=dtype.__name__)
1504
-
1505
- def quat_sampler(kernel_seed: int, angles: wp.array(dtype=float), quats: wp.array(dtype=quat)):
1506
- tid = wp.tid()
1507
-
1508
- state = wp.rand_init(kernel_seed, tid)
1509
-
1510
- angle = angles[tid]
1511
- dir = wp.sample_unit_sphere_surface(state) * wp.sin(angle * 0.5)
1512
-
1513
- q = quat(wptype(dir[0]), wptype(dir[1]), wptype(dir[2]), wptype(wp.cos(angle * 0.5)))
1514
- qn = wp.normalize(q)
1515
-
1516
- quats[tid] = qn
1517
-
1518
- quat_sampler = getkernel(quat_sampler, suffix=dtype.__name__)
1519
-
1520
- if register_kernels:
1521
- return
1522
-
1523
- quats = wp.zeros(num_rand, dtype=quat, device=device, requires_grad=True)
1524
- angles = wp.array(
1525
- np.linspace(0.0, 2.0 * np.pi, num_rand, endpoint=False, dtype=np.float32), dtype=float, device=device
1526
- )
1527
- wp.launch(kernel=quat_sampler, dim=num_rand, inputs=[seed, angles, quats], device=device)
1528
-
1529
- edge_cases = np.array(
1530
- [(1.0, 0.0, 0.0, 0.0), (0.0, 1.0 / np.sqrt(3), 1.0 / np.sqrt(3), 1.0 / np.sqrt(3)), (0.0, 0.0, 0.0, 0.0)]
1531
- )
1532
- num_edge = len(edge_cases)
1533
- edge_cases = wp.array(edge_cases, dtype=quat, device=device, requires_grad=True)
1534
-
1535
- def compute_gradients(arr, kernel, dim, index):
1536
- loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
1537
- tape = wp.Tape()
1538
- with tape:
1539
- wp.launch(kernel=kernel, dim=dim, inputs=[arr, loss, index], device=device)
1540
-
1541
- tape.backward(loss)
1542
-
1543
- gradients = 1.0 * tape.gradients[arr].numpy()
1544
- tape.zero()
1545
-
1546
- return loss.numpy()[0], gradients
1547
-
1548
- # gather gradients from builtin adjoints
1549
- xcmp, gradients_x = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 0)
1550
- ycmp, gradients_y = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 1)
1551
- zcmp, gradients_z = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 2)
1552
- wcmp, gradients_w = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 3)
1553
-
1554
- # gather gradients from autodiff
1555
- xcmp_auto, gradients_x_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 0)
1556
- ycmp_auto, gradients_y_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 1)
1557
- zcmp_auto, gradients_z_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 2)
1558
- wcmp_auto, gradients_w_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 3)
1559
-
1560
- # edge cases: gather gradients from builtin adjoints
1561
- _, edge_gradients_x = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 0)
1562
- _, edge_gradients_y = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 1)
1563
- _, edge_gradients_z = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 2)
1564
- _, edge_gradients_w = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 3)
1565
-
1566
- # edge cases: gather gradients from autodiff
1567
- _, edge_gradients_x_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 0)
1568
- _, edge_gradients_y_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 1)
1569
- _, edge_gradients_z_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 2)
1570
- _, edge_gradients_w_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 3)
1571
-
1572
- eps = {
1573
- np.float16: 2.0e-1,
1574
- np.float32: 2.0e-4,
1575
- np.float64: 2.0e-7,
1576
- }.get(dtype, 0)
1577
-
1578
- assert_np_equal(xcmp, xcmp_auto, tol=eps)
1579
- assert_np_equal(ycmp, ycmp_auto, tol=eps)
1580
- assert_np_equal(zcmp, zcmp_auto, tol=eps)
1581
- assert_np_equal(wcmp, wcmp_auto, tol=eps)
1582
-
1583
- assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1584
- assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1585
- assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1586
- assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1587
-
1588
- assert_np_equal(edge_gradients_x, edge_gradients_x_auto, tol=eps)
1589
- assert_np_equal(edge_gradients_y, edge_gradients_y_auto, tol=eps)
1590
- assert_np_equal(edge_gradients_z, edge_gradients_z_auto, tol=eps)
1591
- assert_np_equal(edge_gradients_w, edge_gradients_w_auto, tol=eps)
1592
-
1593
-
1594
- ############################################################
1595
-
1596
-
1597
- def test_quat_rpy_grad(test, device, dtype, register_kernels=False):
1598
- rng = np.random.default_rng(123)
1599
- N = 3
1600
-
1601
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1602
-
1603
- vec3 = wp.types.vector(3, wptype)
1604
- quat = wp.types.quaternion(wptype)
1605
-
1606
- def rpy_to_quat_kernel(rpy_arr: wp.array(dtype=vec3), loss: wp.array(dtype=wptype), coord_idx: int):
1607
- tid = wp.tid()
1608
- rpy = rpy_arr[tid]
1609
- roll = rpy[0]
1610
- pitch = rpy[1]
1611
- yaw = rpy[2]
1612
-
1613
- q = wp.quat_rpy(roll, pitch, yaw)
1614
-
1615
- wp.atomic_add(loss, 0, q[coord_idx])
1616
-
1617
- rpy_to_quat_kernel = getkernel(rpy_to_quat_kernel, suffix=dtype.__name__)
1618
-
1619
- def rpy_to_quat_kernel_forward(rpy_arr: wp.array(dtype=vec3), loss: wp.array(dtype=wptype), coord_idx: int):
1620
- tid = wp.tid()
1621
- rpy = rpy_arr[tid]
1622
- roll = rpy[0]
1623
- pitch = rpy[1]
1624
- yaw = rpy[2]
1625
-
1626
- cy = wp.cos(yaw * wptype(0.5))
1627
- sy = wp.sin(yaw * wptype(0.5))
1628
- cr = wp.cos(roll * wptype(0.5))
1629
- sr = wp.sin(roll * wptype(0.5))
1630
- cp = wp.cos(pitch * wptype(0.5))
1631
- sp = wp.sin(pitch * wptype(0.5))
1632
-
1633
- w = cy * cr * cp + sy * sr * sp
1634
- x = cy * sr * cp - sy * cr * sp
1635
- y = cy * cr * sp + sy * sr * cp
1636
- z = sy * cr * cp - cy * sr * sp
1637
-
1638
- q = quat(x, y, z, w)
1639
-
1640
- wp.atomic_add(loss, 0, q[coord_idx])
1641
-
1642
- rpy_to_quat_kernel_forward = getkernel(rpy_to_quat_kernel_forward, suffix=dtype.__name__)
1643
-
1644
- if register_kernels:
1645
- return
1646
-
1647
- rpy_arr = rng.uniform(low=-np.pi, high=np.pi, size=(N, 3))
1648
- rpy_arr = wp.array(rpy_arr, dtype=vec3, device=device, requires_grad=True)
1649
-
1650
- def compute_gradients(kernel, wrt, index):
1651
- loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
1652
- tape = wp.Tape()
1653
- with tape:
1654
- wp.launch(kernel=kernel, dim=N, inputs=[wrt, loss, index], device=device)
1655
-
1656
- tape.backward(loss)
1657
-
1658
- gradients = 1.0 * tape.gradients[wrt].numpy()
1659
- tape.zero()
1660
-
1661
- return loss.numpy()[0], gradients
1662
-
1663
- # wrt rpy
1664
- # gather gradients from builtin adjoints
1665
- rcmp, gradients_r = compute_gradients(rpy_to_quat_kernel, rpy_arr, 0)
1666
- pcmp, gradients_p = compute_gradients(rpy_to_quat_kernel, rpy_arr, 1)
1667
- ycmp, gradients_y = compute_gradients(rpy_to_quat_kernel, rpy_arr, 2)
1668
-
1669
- # gather gradients from autodiff
1670
- rcmp_auto, gradients_r_auto = compute_gradients(rpy_to_quat_kernel_forward, rpy_arr, 0)
1671
- pcmp_auto, gradients_p_auto = compute_gradients(rpy_to_quat_kernel_forward, rpy_arr, 1)
1672
- ycmp_auto, gradients_y_auto = compute_gradients(rpy_to_quat_kernel_forward, rpy_arr, 2)
1673
-
1674
- eps = {
1675
- np.float16: 2.0e-2,
1676
- np.float32: 1.0e-5,
1677
- np.float64: 1.0e-8,
1678
- }.get(dtype, 0)
1679
-
1680
- assert_np_equal(rcmp, rcmp_auto, tol=eps)
1681
- assert_np_equal(pcmp, pcmp_auto, tol=eps)
1682
- assert_np_equal(ycmp, ycmp_auto, tol=eps)
1683
-
1684
- assert_np_equal(gradients_r, gradients_r_auto, tol=eps)
1685
- assert_np_equal(gradients_p, gradients_p_auto, tol=eps)
1686
- assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1687
-
1688
-
1689
- ############################################################
1690
-
1691
-
1692
- def test_quat_from_matrix(test, device, dtype, register_kernels=False):
1693
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1694
- mat33 = wp.types.matrix((3, 3), wptype)
1695
- quat = wp.types.quaternion(wptype)
1696
-
1697
- def quat_from_matrix(m: wp.array2d(dtype=wptype), loss: wp.array(dtype=wptype), idx: int):
1698
- tid = wp.tid()
1699
-
1700
- matrix = mat33(
1701
- m[tid, 0], m[tid, 1], m[tid, 2], m[tid, 3], m[tid, 4], m[tid, 5], m[tid, 6], m[tid, 7], m[tid, 8]
1702
- )
1703
-
1704
- q = wp.quat_from_matrix(matrix)
1705
-
1706
- wp.atomic_add(loss, 0, q[idx])
1707
-
1708
- def quat_from_matrix_forward(mats: wp.array2d(dtype=wptype), loss: wp.array(dtype=wptype), idx: int):
1709
- tid = wp.tid()
1710
-
1711
- m = mat33(
1712
- mats[tid, 0],
1713
- mats[tid, 1],
1714
- mats[tid, 2],
1715
- mats[tid, 3],
1716
- mats[tid, 4],
1717
- mats[tid, 5],
1718
- mats[tid, 6],
1719
- mats[tid, 7],
1720
- mats[tid, 8],
1721
- )
1722
-
1723
- tr = m[0][0] + m[1][1] + m[2][2]
1724
- x = wptype(0)
1725
- y = wptype(0)
1726
- z = wptype(0)
1727
- w = wptype(0)
1728
- h = wptype(0)
1729
-
1730
- if tr >= wptype(0):
1731
- h = wp.sqrt(tr + wptype(1))
1732
- w = wptype(0.5) * h
1733
- h = wptype(0.5) / h
1734
-
1735
- x = (m[2][1] - m[1][2]) * h
1736
- y = (m[0][2] - m[2][0]) * h
1737
- z = (m[1][0] - m[0][1]) * h
1738
- else:
1739
- max_diag = 0
1740
- if m[1][1] > m[0][0]:
1741
- max_diag = 1
1742
- if m[2][2] > m[max_diag][max_diag]:
1743
- max_diag = 2
1744
-
1745
- if max_diag == 0:
1746
- h = wp.sqrt((m[0][0] - (m[1][1] + m[2][2])) + wptype(1))
1747
- x = wptype(0.5) * h
1748
- h = wptype(0.5) / h
1749
-
1750
- y = (m[0][1] + m[1][0]) * h
1751
- z = (m[2][0] + m[0][2]) * h
1752
- w = (m[2][1] - m[1][2]) * h
1753
- elif max_diag == 1:
1754
- h = wp.sqrt((m[1][1] - (m[2][2] + m[0][0])) + wptype(1))
1755
- y = wptype(0.5) * h
1756
- h = wptype(0.5) / h
1757
-
1758
- z = (m[1][2] + m[2][1]) * h
1759
- x = (m[0][1] + m[1][0]) * h
1760
- w = (m[0][2] - m[2][0]) * h
1761
- if max_diag == 2:
1762
- h = wp.sqrt((m[2][2] - (m[0][0] + m[1][1])) + wptype(1))
1763
- z = wptype(0.5) * h
1764
- h = wptype(0.5) / h
1765
-
1766
- x = (m[2][0] + m[0][2]) * h
1767
- y = (m[1][2] + m[2][1]) * h
1768
- w = (m[1][0] - m[0][1]) * h
1769
-
1770
- q = wp.normalize(quat(x, y, z, w))
1771
-
1772
- wp.atomic_add(loss, 0, q[idx])
1773
-
1774
- quat_from_matrix = getkernel(quat_from_matrix, suffix=dtype.__name__)
1775
- quat_from_matrix_forward = getkernel(quat_from_matrix_forward, suffix=dtype.__name__)
1776
-
1777
- if register_kernels:
1778
- return
1779
-
1780
- m = np.array(
1781
- [
1782
- [1.0, 0.0, 0.0, 0.0, 0.5, 0.866, 0.0, -0.866, 0.5],
1783
- [0.866, 0.0, 0.25, -0.433, 0.5, 0.75, -0.25, -0.866, 0.433],
1784
- [0.866, -0.433, 0.25, 0.0, 0.5, 0.866, -0.5, -0.75, 0.433],
1785
- [-1.2, -1.6, -2.3, 0.25, -0.6, -0.33, 3.2, -1.0, -2.2],
1786
- ]
1787
- )
1788
- m = wp.array2d(m, dtype=wptype, device=device, requires_grad=True)
1789
-
1790
- N = m.shape[0]
1791
-
1792
- def compute_gradients(kernel, wrt, index):
1793
- loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
1794
- tape = wp.Tape()
1795
-
1796
- with tape:
1797
- wp.launch(kernel=kernel, dim=N, inputs=[m, loss, index], device=device)
1798
-
1799
- tape.backward(loss)
1800
-
1801
- gradients = 1.0 * tape.gradients[wrt].numpy()
1802
- tape.zero()
1803
-
1804
- return loss.numpy()[0], gradients
1805
-
1806
- # gather gradients from builtin adjoints
1807
- cmpx, gradients_x = compute_gradients(quat_from_matrix, m, 0)
1808
- cmpy, gradients_y = compute_gradients(quat_from_matrix, m, 1)
1809
- cmpz, gradients_z = compute_gradients(quat_from_matrix, m, 2)
1810
- cmpw, gradients_w = compute_gradients(quat_from_matrix, m, 3)
1811
-
1812
- # gather gradients from autodiff
1813
- cmpx_auto, gradients_x_auto = compute_gradients(quat_from_matrix_forward, m, 0)
1814
- cmpy_auto, gradients_y_auto = compute_gradients(quat_from_matrix_forward, m, 1)
1815
- cmpz_auto, gradients_z_auto = compute_gradients(quat_from_matrix_forward, m, 2)
1816
- cmpw_auto, gradients_w_auto = compute_gradients(quat_from_matrix_forward, m, 3)
1817
-
1818
- # compare
1819
- eps = 1.0e6
1820
-
1821
- eps = {
1822
- np.float16: 2.0e-2,
1823
- np.float32: 1.0e-5,
1824
- np.float64: 1.0e-8,
1825
- }.get(dtype, 0)
1826
-
1827
- assert_np_equal(cmpx, cmpx_auto, tol=eps)
1828
- assert_np_equal(cmpy, cmpy_auto, tol=eps)
1829
- assert_np_equal(cmpz, cmpz_auto, tol=eps)
1830
- assert_np_equal(cmpw, cmpw_auto, tol=eps)
1831
-
1832
- assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1833
- assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1834
- assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1835
- assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1836
-
1837
-
1838
- def test_quat_identity(test, device, dtype, register_kernels=False):
1839
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1840
-
1841
- def quat_identity_test(output: wp.array(dtype=wptype)):
1842
- q = wp.quat_identity(dtype=wptype)
1843
- output[0] = q[0]
1844
- output[1] = q[1]
1845
- output[2] = q[2]
1846
- output[3] = q[3]
1847
-
1848
- def quat_identity_test_default(output: wp.array(dtype=wp.float32)):
1849
- q = wp.quat_identity()
1850
- output[0] = q[0]
1851
- output[1] = q[1]
1852
- output[2] = q[2]
1853
- output[3] = q[3]
1854
-
1855
- quat_identity_kernel = getkernel(quat_identity_test, suffix=dtype.__name__)
1856
- quat_identity_default_kernel = getkernel(quat_identity_test_default, suffix=np.float32.__name__)
1857
-
1858
- if register_kernels:
1859
- return
1860
-
1861
- output = wp.zeros(4, dtype=wptype, device=device)
1862
- wp.launch(quat_identity_kernel, dim=1, inputs=[], outputs=[output], device=device)
1863
- expected = np.zeros_like(output.numpy())
1864
- expected[3] = 1
1865
- assert_np_equal(output.numpy(), expected)
1866
-
1867
- # let's just test that it defaults to float32:
1868
- output = wp.zeros(4, dtype=wp.float32, device=device)
1869
- wp.launch(quat_identity_default_kernel, dim=1, inputs=[], outputs=[output], device=device)
1870
- expected = np.zeros_like(output.numpy())
1871
- expected[3] = 1
1872
- assert_np_equal(output.numpy(), expected)
1873
-
1874
-
1875
- ############################################################
1876
-
1877
-
1878
- def test_quat_euler_conversion(test, device, dtype, register_kernels=False):
1879
- rng = np.random.default_rng(123)
1880
- N = 3
1881
-
1882
- rpy_arr = rng.uniform(low=-np.pi, high=np.pi, size=(N, 3))
1883
-
1884
- quats_from_euler = [list(wp.sim.quat_from_euler(wp.vec3(*rpy), 0, 1, 2)) for rpy in rpy_arr]
1885
- quats_from_rpy = [list(wp.quat_rpy(rpy[0], rpy[1], rpy[2])) for rpy in rpy_arr]
1886
-
1887
- assert_np_equal(np.array(quats_from_euler), np.array(quats_from_rpy), tol=1e-4)
1888
-
1889
-
1890
- def test_anon_type_instance(test, device, dtype, register_kernels=False):
1891
- rng = np.random.default_rng(123)
1892
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1893
-
1894
- def quat_create_test(input: wp.array(dtype=wptype), output: wp.array(dtype=wptype)):
1895
- # component constructor:
1896
- q = wp.quaternion(input[0], input[1], input[2], input[3])
1897
- output[0] = wptype(2) * q[0]
1898
- output[1] = wptype(2) * q[1]
1899
- output[2] = wptype(2) * q[2]
1900
- output[3] = wptype(2) * q[3]
1901
-
1902
- # vector / scalar constructor:
1903
- q2 = wp.quaternion(wp.vector(input[4], input[5], input[6]), input[7])
1904
- output[4] = wptype(2) * q2[0]
1905
- output[5] = wptype(2) * q2[1]
1906
- output[6] = wptype(2) * q2[2]
1907
- output[7] = wptype(2) * q2[3]
1908
-
1909
- quat_create_kernel = getkernel(quat_create_test, suffix=dtype.__name__)
1910
- output_select_kernel = get_select_kernel(wptype)
1911
-
1912
- if register_kernels:
1913
- return
1914
-
1915
- input = wp.array(rng.standard_normal(size=8).astype(dtype), requires_grad=True, device=device)
1916
- output = wp.zeros(8, dtype=wptype, requires_grad=True, device=device)
1917
- wp.launch(quat_create_kernel, dim=1, inputs=[input], outputs=[output], device=device)
1918
- assert_np_equal(output.numpy(), 2 * input.numpy())
1919
-
1920
- for i in range(len(input)):
1921
- cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1922
- tape = wp.Tape()
1923
- with tape:
1924
- wp.launch(quat_create_kernel, dim=1, inputs=[input], outputs=[output], device=device)
1925
- wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[cmp], device=device)
1926
- tape.backward(loss=cmp)
1927
- expectedgrads = np.zeros(len(input))
1928
- expectedgrads[i] = 2
1929
- assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
1930
- tape.zero()
1931
-
1932
-
1933
- # Same as above but with a default (float) type
1934
- # which tests some different code paths that
1935
- # need to ensure types are correctly canonicalized
1936
- # during codegen
1937
- @wp.kernel
1938
- def test_constructor_default():
1939
- qzero = wp.quat()
1940
- wp.expect_eq(qzero[0], 0.0)
1941
- wp.expect_eq(qzero[1], 0.0)
1942
- wp.expect_eq(qzero[2], 0.0)
1943
- wp.expect_eq(qzero[3], 0.0)
1944
-
1945
- qval = wp.quat(1.0, 2.0, 3.0, 4.0)
1946
- wp.expect_eq(qval[0], 1.0)
1947
- wp.expect_eq(qval[1], 2.0)
1948
- wp.expect_eq(qval[2], 3.0)
1949
- wp.expect_eq(qval[3], 4.0)
1950
-
1951
- qeye = wp.quat_identity()
1952
- wp.expect_eq(qeye[0], 0.0)
1953
- wp.expect_eq(qeye[1], 0.0)
1954
- wp.expect_eq(qeye[2], 0.0)
1955
- wp.expect_eq(qeye[3], 1.0)
1956
-
1957
-
1958
- def test_py_arithmetic_ops(test, device, dtype):
1959
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1960
-
1961
- def make_quat(*args):
1962
- if wptype in wp.types.int_types:
1963
- # Cast to the correct integer type to simulate wrapping.
1964
- return tuple(wptype._type_(x).value for x in args)
1965
-
1966
- return args
1967
-
1968
- quat_cls = wp.types.quaternion(wptype)
1969
-
1970
- v = quat_cls(1, -2, 3, -4)
1971
- test.assertSequenceEqual(+v, make_quat(1, -2, 3, -4))
1972
- test.assertSequenceEqual(-v, make_quat(-1, 2, -3, 4))
1973
- test.assertSequenceEqual(v + quat_cls(5, 5, 5, 5), make_quat(6, 3, 8, 1))
1974
- test.assertSequenceEqual(v - quat_cls(5, 5, 5, 5), make_quat(-4, -7, -2, -9))
1975
-
1976
- v = quat_cls(2, 4, 6, 8)
1977
- test.assertSequenceEqual(v * wptype(2), make_quat(4, 8, 12, 16))
1978
- test.assertSequenceEqual(wptype(2) * v, make_quat(4, 8, 12, 16))
1979
- test.assertSequenceEqual(v / wptype(2), make_quat(1, 2, 3, 4))
1980
- test.assertSequenceEqual(wptype(24) / v, make_quat(12, 6, 4, 3))
1981
-
1982
-
1983
- devices = get_test_devices()
1984
-
1985
-
1986
- class TestQuat(unittest.TestCase):
1987
- pass
1988
-
1989
-
1990
- add_kernel_test(TestQuat, test_constructor_default, dim=1, devices=devices)
1991
-
1992
- for dtype in np_float_types:
1993
- add_function_test_register_kernel(
1994
- TestQuat, f"test_constructors_{dtype.__name__}", test_constructors, devices=devices, dtype=dtype
1995
- )
1996
- add_function_test_register_kernel(
1997
- TestQuat,
1998
- f"test_casting_constructors_{dtype.__name__}",
1999
- test_casting_constructors,
2000
- devices=devices,
2001
- dtype=dtype,
2002
- )
2003
- add_function_test_register_kernel(
2004
- TestQuat, f"test_anon_type_instance_{dtype.__name__}", test_anon_type_instance, devices=devices, dtype=dtype
2005
- )
2006
- add_function_test_register_kernel(
2007
- TestQuat, f"test_inverse_{dtype.__name__}", test_inverse, devices=devices, dtype=dtype
2008
- )
2009
- add_function_test_register_kernel(
2010
- TestQuat, f"test_quat_identity_{dtype.__name__}", test_quat_identity, devices=devices, dtype=dtype
2011
- )
2012
- add_function_test_register_kernel(
2013
- TestQuat, f"test_dotproduct_{dtype.__name__}", test_dotproduct, devices=devices, dtype=dtype
2014
- )
2015
- add_function_test_register_kernel(
2016
- TestQuat, f"test_length_{dtype.__name__}", test_length, devices=devices, dtype=dtype
2017
- )
2018
- add_function_test_register_kernel(
2019
- TestQuat, f"test_normalize_{dtype.__name__}", test_normalize, devices=devices, dtype=dtype
2020
- )
2021
- add_function_test_register_kernel(
2022
- TestQuat, f"test_addition_{dtype.__name__}", test_addition, devices=devices, dtype=dtype
2023
- )
2024
- add_function_test_register_kernel(
2025
- TestQuat, f"test_subtraction_{dtype.__name__}", test_subtraction, devices=devices, dtype=dtype
2026
- )
2027
- add_function_test_register_kernel(
2028
- TestQuat,
2029
- f"test_scalar_multiplication_{dtype.__name__}",
2030
- test_scalar_multiplication,
2031
- devices=devices,
2032
- dtype=dtype,
2033
- )
2034
- add_function_test_register_kernel(
2035
- TestQuat, f"test_scalar_division_{dtype.__name__}", test_scalar_division, devices=devices, dtype=dtype
2036
- )
2037
- add_function_test_register_kernel(
2038
- TestQuat,
2039
- f"test_quat_multiplication_{dtype.__name__}",
2040
- test_quat_multiplication,
2041
- devices=devices,
2042
- dtype=dtype,
2043
- )
2044
- add_function_test_register_kernel(
2045
- TestQuat, f"test_indexing_{dtype.__name__}", test_indexing, devices=devices, dtype=dtype
2046
- )
2047
- add_function_test_register_kernel(
2048
- TestQuat, f"test_quat_lerp_{dtype.__name__}", test_quat_lerp, devices=devices, dtype=dtype
2049
- )
2050
- add_function_test_register_kernel(
2051
- TestQuat,
2052
- f"test_quat_to_axis_angle_grad_{dtype.__name__}",
2053
- test_quat_to_axis_angle_grad,
2054
- devices=devices,
2055
- dtype=dtype,
2056
- )
2057
- add_function_test_register_kernel(
2058
- TestQuat, f"test_slerp_grad_{dtype.__name__}", test_slerp_grad, devices=devices, dtype=dtype
2059
- )
2060
- add_function_test_register_kernel(
2061
- TestQuat, f"test_quat_rpy_grad_{dtype.__name__}", test_quat_rpy_grad, devices=devices, dtype=dtype
2062
- )
2063
- add_function_test_register_kernel(
2064
- TestQuat, f"test_quat_from_matrix_{dtype.__name__}", test_quat_from_matrix, devices=devices, dtype=dtype
2065
- )
2066
- add_function_test_register_kernel(
2067
- TestQuat, f"test_quat_rotate_{dtype.__name__}", test_quat_rotate, devices=devices, dtype=dtype
2068
- )
2069
- add_function_test_register_kernel(
2070
- TestQuat, f"test_quat_to_matrix_{dtype.__name__}", test_quat_to_matrix, devices=devices, dtype=dtype
2071
- )
2072
- add_function_test_register_kernel(
2073
- TestQuat,
2074
- f"test_quat_euler_conversion_{dtype.__name__}",
2075
- test_quat_euler_conversion,
2076
- devices=devices,
2077
- dtype=dtype,
2078
- )
2079
- add_function_test(
2080
- TestQuat, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype
2081
- )
2082
-
2083
-
2084
- if __name__ == "__main__":
2085
- wp.build.clear_kernel_cache()
2086
- unittest.main(verbosity=2)
1
+ # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import unittest
9
+
10
+ import numpy as np
11
+
12
+ import warp as wp
13
+ import warp.sim
14
+ from warp.tests.unittest_utils import *
15
+
16
+ np_float_types = [np.float32, np.float64, np.float16]
17
+
18
+ kernel_cache = {}
19
+
20
+
21
+ def getkernel(func, suffix=""):
22
+ key = func.__name__ + "_" + suffix
23
+ if key not in kernel_cache:
24
+ kernel_cache[key] = wp.Kernel(func=func, key=key)
25
+ return kernel_cache[key]
26
+
27
+
28
+ def get_select_kernel(dtype):
29
+ def output_select_kernel_fn(
30
+ input: wp.array(dtype=dtype),
31
+ index: int,
32
+ out: wp.array(dtype=dtype),
33
+ ):
34
+ out[0] = input[index]
35
+
36
+ return getkernel(output_select_kernel_fn, suffix=dtype.__name__)
37
+
38
+
39
+ ############################################################
40
+
41
+
42
+ def test_constructors(test, device, dtype, register_kernels=False):
43
+ rng = np.random.default_rng(123)
44
+
45
+ tol = {
46
+ np.float16: 5.0e-3,
47
+ np.float32: 1.0e-6,
48
+ np.float64: 1.0e-8,
49
+ }.get(dtype, 0)
50
+
51
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
52
+ vec3 = wp.types.vector(length=3, dtype=wptype)
53
+ quat = wp.types.quaternion(dtype=wptype)
54
+
55
+ def check_component_constructor(
56
+ input: wp.array(dtype=wptype),
57
+ q: wp.array(dtype=wptype),
58
+ ):
59
+ qresult = quat(input[0], input[1], input[2], input[3])
60
+
61
+ # multiply the output by 2 so we've got something to backpropagate:
62
+ q[0] = wptype(2) * qresult[0]
63
+ q[1] = wptype(2) * qresult[1]
64
+ q[2] = wptype(2) * qresult[2]
65
+ q[3] = wptype(2) * qresult[3]
66
+
67
+ def check_vector_constructor(
68
+ input: wp.array(dtype=wptype),
69
+ q: wp.array(dtype=wptype),
70
+ ):
71
+ qresult = quat(vec3(input[0], input[1], input[2]), input[3])
72
+
73
+ # multiply the output by 2 so we've got something to backpropagate:
74
+ q[0] = wptype(2) * qresult[0]
75
+ q[1] = wptype(2) * qresult[1]
76
+ q[2] = wptype(2) * qresult[2]
77
+ q[3] = wptype(2) * qresult[3]
78
+
79
+ kernel = getkernel(check_component_constructor, suffix=dtype.__name__)
80
+ output_select_kernel = get_select_kernel(wptype)
81
+ vec_kernel = getkernel(check_vector_constructor, suffix=dtype.__name__)
82
+
83
+ if register_kernels:
84
+ return
85
+
86
+ input = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
87
+ output = wp.zeros_like(input)
88
+ wp.launch(kernel, dim=1, inputs=[input], outputs=[output], device=device)
89
+
90
+ assert_np_equal(output.numpy(), 2 * input.numpy(), tol=tol)
91
+
92
+ for i in range(4):
93
+ cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
94
+ tape = wp.Tape()
95
+ with tape:
96
+ wp.launch(kernel, dim=1, inputs=[input], outputs=[output], device=device)
97
+ wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[cmp], device=device)
98
+ tape.backward(loss=cmp)
99
+ expectedgrads = np.zeros(len(input))
100
+ expectedgrads[i] = 2
101
+ assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
102
+ tape.zero()
103
+
104
+ input = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
105
+ output = wp.zeros_like(input)
106
+ wp.launch(vec_kernel, dim=1, inputs=[input], outputs=[output], device=device)
107
+
108
+ assert_np_equal(output.numpy(), 2 * input.numpy(), tol=tol)
109
+
110
+ for i in range(4):
111
+ cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
112
+ tape = wp.Tape()
113
+ with tape:
114
+ wp.launch(vec_kernel, dim=1, inputs=[input], outputs=[output], device=device)
115
+ wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[cmp], device=device)
116
+ tape.backward(loss=cmp)
117
+ expectedgrads = np.zeros(len(input))
118
+ expectedgrads[i] = 2
119
+ assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
120
+ tape.zero()
121
+
122
+
123
def test_casting_constructors(test, device, dtype, register_kernels=False):
    """Check the dtype-casting quaternion constructor ``wp.quaternion(q, dtype=...)``.

    For a source quaternion of ``dtype``, casts to float16/float32/float64 inside a
    kernel and verifies that (a) component values survive the cast and (b) gradients
    of an identity-like mapping flow back as ones.

    Args:
        test: unittest.TestCase-like object (unused here; kept for the harness signature).
        device: warp device to launch on.
        dtype: numpy scalar type of the source quaternion components.
        register_kernels: when True, only build/register the kernels and return
            without launching (two-phase registration used by this test suite).
    """
    np_type = np.dtype(dtype)
    wp_type = wp.types.np_dtype_to_warp_type[np_type]
    quat = wp.types.quaternion(dtype=wp_type)

    # numpy/warp dtype pairs for each cast target
    np16 = np.dtype(np.float16)
    wp16 = wp.types.np_dtype_to_warp_type[np16]

    np32 = np.dtype(np.float32)
    wp32 = wp.types.np_dtype_to_warp_type[np32]

    np64 = np.dtype(np.float64)
    wp64 = wp.types.np_dtype_to_warp_type[np64]

    def cast_float16(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp16, ndim=2)):
        # build a quaternion from the source row, cast it to float16, write it out
        tid = wp.tid()

        q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
        q2 = wp.quaternion(q1, dtype=wp16)

        b[tid, 0] = q2[0]
        b[tid, 1] = q2[1]
        b[tid, 2] = q2[2]
        b[tid, 3] = q2[3]

    def cast_float32(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp32, ndim=2)):
        # same as cast_float16 but targeting float32
        tid = wp.tid()

        q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
        q2 = wp.quaternion(q1, dtype=wp32)

        b[tid, 0] = q2[0]
        b[tid, 1] = q2[1]
        b[tid, 2] = q2[2]
        b[tid, 3] = q2[3]

    def cast_float64(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp64, ndim=2)):
        # same as cast_float16 but targeting float64
        tid = wp.tid()

        q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
        q2 = wp.quaternion(q1, dtype=wp64)

        b[tid, 0] = q2[0]
        b[tid, 1] = q2[1]
        b[tid, 2] = q2[2]
        b[tid, 3] = q2[3]

    kernel_16 = getkernel(cast_float16, suffix=dtype.__name__)
    kernel_32 = getkernel(cast_float32, suffix=dtype.__name__)
    kernel_64 = getkernel(cast_float64, suffix=dtype.__name__)

    if register_kernels:
        return

    # check casting to float 16
    # input is all ones, so both the cast output and the pulled-back gradient
    # (seeded with ones) are expected to be exactly ones
    a = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
    b = wp.array(np.zeros((1, 4), dtype=np16), dtype=wp16, requires_grad=True, device=device)
    b_result = np.ones((1, 4), dtype=np16)
    b_grad = wp.array(np.ones((1, 4), dtype=np16), dtype=wp16, device=device)
    a_grad = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(kernel=kernel_16, dim=1, inputs=[a, b], device=device)

    tape.backward(grads={b: b_grad})
    out = tape.gradients[a].numpy()

    assert_np_equal(b.numpy(), b_result)
    assert_np_equal(out, a_grad.numpy())

    # check casting to float 32
    a = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
    b = wp.array(np.zeros((1, 4), dtype=np32), dtype=wp32, requires_grad=True, device=device)
    b_result = np.ones((1, 4), dtype=np32)
    b_grad = wp.array(np.ones((1, 4), dtype=np32), dtype=wp32, device=device)
    a_grad = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(kernel=kernel_32, dim=1, inputs=[a, b], device=device)

    tape.backward(grads={b: b_grad})
    out = tape.gradients[a].numpy()

    assert_np_equal(b.numpy(), b_result)
    assert_np_equal(out, a_grad.numpy())

    # check casting to float 64
    a = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
    b = wp.array(np.zeros((1, 4), dtype=np64), dtype=wp64, requires_grad=True, device=device)
    b_result = np.ones((1, 4), dtype=np64)
    b_grad = wp.array(np.ones((1, 4), dtype=np64), dtype=wp64, device=device)
    a_grad = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(kernel=kernel_64, dim=1, inputs=[a, b], device=device)

    tape.backward(grads={b: b_grad})
    out = tape.gradients[a].numpy()

    assert_np_equal(b.numpy(), b_result)
    assert_np_equal(out, a_grad.numpy())
229
def test_inverse(test, device, dtype, register_kernels=False):
    """Check ``wp.quat_inverse`` values and gradients.

    Verifies that q * inverse(q) is the identity quaternion for a normalized q,
    and that the gradient of 2*inverse(q) w.r.t. the raw components is
    (-2, -2, -2, +2) on the diagonal (the inverse negates x, y, z and keeps w).
    """
    rng = np.random.default_rng(123)

    # per-dtype absolute tolerances for the value checks
    tol = {
        np.float16: 2.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp.types.quaternion(dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_quat_inverse(
        input: wp.array(dtype=wptype),
        shouldbeidentity: wp.array(dtype=quat),
        q: wp.array(dtype=wptype),
    ):
        qread = quat(input[0], input[1], input[2], input[3])
        qresult = wp.quat_inverse(qread)

        # this inverse should work for normalized quaternions:
        shouldbeidentity[0] = wp.normalize(qread) * wp.quat_inverse(wp.normalize(qread))

        # multiply the output by 2 so we've got something to backpropagate:
        q[0] = wptype(2) * qresult[0]
        q[1] = wptype(2) * qresult[1]
        q[2] = wptype(2) * qresult[2]
        q[3] = wptype(2) * qresult[3]

    kernel = getkernel(check_quat_inverse, suffix=dtype.__name__)

    if register_kernels:
        return

    input = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
    shouldbeidentity = wp.array(np.zeros((1, 4)), dtype=quat, requires_grad=True, device=device)
    output = wp.zeros_like(input)
    wp.launch(kernel, dim=1, inputs=[input], outputs=[shouldbeidentity, output], device=device)

    # normalized q times its inverse must be the identity quaternion (0, 0, 0, 1)
    assert_np_equal(shouldbeidentity.numpy(), np.array([0, 0, 0, 1]), tol=tol)

    # backpropagate each output component separately via the select kernel
    for i in range(4):
        cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        tape = wp.Tape()
        with tape:
            wp.launch(kernel, dim=1, inputs=[input], outputs=[shouldbeidentity, output], device=device)
            wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[cmp], device=device)
        tape.backward(loss=cmp)
        expectedgrads = np.zeros(len(input))
        # x, y, z are negated by the inverse, so their gradients are -2; w stays +2
        expectedgrads[i] = -2 if i != 3 else 2
        assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
        tape.zero()
285
def test_dotproduct(test, device, dtype, register_kernels=False):
    """Check ``wp.dot`` on quaternions: value equals the component-wise product
    sum, and the gradients of 2*dot(v, s) are 2*v w.r.t. s and 2*s w.r.t. v."""
    rng = np.random.default_rng(123)

    # per-dtype absolute tolerances
    tolerances = {np.float16: 1.0e-2, np.float32: 1.0e-6, np.float64: 1.0e-8}
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp.types.quaternion(dtype=wptype)

    def check_quat_dot(
        s: wp.array(dtype=quat),
        v: wp.array(dtype=quat),
        dot: wp.array(dtype=wptype),
    ):
        dot[0] = wptype(2) * wp.dot(v[0], s[0])

    dotkernel = getkernel(check_quat_dot, suffix=dtype.__name__)
    if register_kernels:
        return

    s = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
    v = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
    dot = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(dotkernel, dim=1, inputs=[s, v], outputs=[dot], device=device)

    # forward value: 2 * sum_i v_i * s_i
    expected_value = 2.0 * (v.numpy() * s.numpy()).sum()
    assert_np_equal(dot.numpy()[0], expected_value, tol=tol)

    tape.backward(loss=dot)

    # d(2*dot)/ds = 2*v
    grad_s = tape.gradients[s].numpy()[0]
    assert_np_equal(grad_s, 2.0 * v.numpy()[0], tol=10 * tol)

    # d(2*dot)/dv = 2*s
    grad_v = tape.gradients[v].numpy()[0]
    assert_np_equal(grad_v, 2.0 * s.numpy()[0], tol=tol)
337
def test_length(test, device, dtype, register_kernels=False):
    """Check ``wp.length`` and ``wp.length_sq`` on quaternions.

    Values are compared against numpy's norm; gradients of 2*|q| and 2*|q|^2 are
    compared against 2*q/|q| and 4*q respectively.
    """
    rng = np.random.default_rng(123)

    # per-dtype absolute tolerances
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-7,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp.types.quaternion(dtype=wptype)

    def check_quat_length(
        q: wp.array(dtype=quat),
        l: wp.array(dtype=wptype),
        l2: wp.array(dtype=wptype),
    ):
        # scaled by 2 so there's something nontrivial to backpropagate
        l[0] = wptype(2) * wp.length(q[0])
        l2[0] = wptype(2) * wp.length_sq(q[0])

    kernel = getkernel(check_quat_length, suffix=dtype.__name__)

    if register_kernels:
        return

    q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
    l = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    l2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[
                q,
            ],
            outputs=[l, l2],
            device=device,
        )

    # forward values against numpy
    assert_np_equal(l.numpy()[0], 2 * np.linalg.norm(q.numpy()), tol=10 * tol)
    assert_np_equal(l2.numpy()[0], 2 * np.linalg.norm(q.numpy()) ** 2, tol=10 * tol)

    # d(2*|q|)/dq = 2*q/|q|
    tape.backward(loss=l)
    grad = tape.gradients[q].numpy()[0]
    expected_grad = 2 * q.numpy()[0] / np.linalg.norm(q.numpy())
    assert_np_equal(grad, expected_grad, tol=10 * tol)
    tape.zero()

    # d(2*|q|^2)/dq = 4*q
    tape.backward(loss=l2)
    grad = tape.gradients[q].numpy()[0]
    expected_grad = 4 * q.numpy()[0]
    assert_np_equal(grad, expected_grad, tol=10 * tol)
    tape.zero()
394
def test_normalize(test, device, dtype, register_kernels=False):
    """Check ``wp.normalize`` on quaternions against the reference q / |q|.

    Runs two kernels — one using wp.normalize, one computing q[0] / wp.length(q[0])
    — and asserts that both the forward values and the per-component gradients agree.
    """
    rng = np.random.default_rng(123)

    # per-dtype absolute tolerances
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp.types.quaternion(dtype=wptype)

    def check_normalize(
        q: wp.array(dtype=quat),
        n0: wp.array(dtype=wptype),
        n1: wp.array(dtype=wptype),
        n2: wp.array(dtype=wptype),
        n3: wp.array(dtype=wptype),
    ):
        # scaled by 2 so there's something nontrivial to backpropagate
        n = wptype(2) * (wp.normalize(q[0]))

        n0[0] = n[0]
        n1[0] = n[1]
        n2[0] = n[2]
        n3[0] = n[3]

    def check_normalize_alt(
        q: wp.array(dtype=quat),
        n0: wp.array(dtype=wptype),
        n1: wp.array(dtype=wptype),
        n2: wp.array(dtype=wptype),
        n3: wp.array(dtype=wptype),
    ):
        # reference implementation: explicit division by the length
        n = wptype(2) * (q[0] / wp.length(q[0]))

        n0[0] = n[0]
        n1[0] = n[1]
        n2[0] = n[2]
        n3[0] = n[3]

    normalize_kernel = getkernel(check_normalize, suffix=dtype.__name__)
    normalize_alt_kernel = getkernel(check_normalize_alt, suffix=dtype.__name__)

    if register_kernels:
        return

    # I've already tested the things I'm using in check_normalize_alt, so I'll just
    # make sure the two are giving the same results/gradients
    q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)

    n0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    n1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    n2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    n3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    n0_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    n1_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    n2_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    n3_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    outputs0 = [
        n0,
        n1,
        n2,
        n3,
    ]
    tape0 = wp.Tape()
    with tape0:
        wp.launch(normalize_kernel, dim=1, inputs=[q], outputs=outputs0, device=device)

    outputs1 = [
        n0_alt,
        n1_alt,
        n2_alt,
        n3_alt,
    ]
    tape1 = wp.Tape()
    with tape1:
        wp.launch(
            normalize_alt_kernel,
            dim=1,
            inputs=[
                q,
            ],
            outputs=outputs1,
            device=device,
        )

    # forward values must match between the two implementations
    assert_np_equal(n0.numpy()[0], n0_alt.numpy()[0], tol=tol)
    assert_np_equal(n1.numpy()[0], n1_alt.numpy()[0], tol=tol)
    assert_np_equal(n2.numpy()[0], n2_alt.numpy()[0], tol=tol)
    assert_np_equal(n3.numpy()[0], n3_alt.numpy()[0], tol=tol)

    # per-component gradients must match between the two implementations
    for ncmp, ncmpalt in zip(outputs0, outputs1):
        tape0.backward(loss=ncmp)
        tape1.backward(loss=ncmpalt)
        assert_np_equal(tape0.gradients[q].numpy()[0], tape1.gradients[q].numpy()[0], tol=tol)
        tape0.zero()
        tape1.zero()
495
def test_addition(test, device, dtype, register_kernels=False):
    """Check quaternion addition: component values are the element-wise sums,
    and the gradient of 2*(q + v)[i] is 2 on component i for both operands."""
    rng = np.random.default_rng(123)

    # per-dtype absolute tolerances
    tolerances = {np.float16: 5.0e-3, np.float32: 1.0e-6, np.float64: 1.0e-8}
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp.types.quaternion(dtype=wptype)

    def check_quat_add(
        q: wp.array(dtype=quat),
        v: wp.array(dtype=quat),
        r0: wp.array(dtype=wptype),
        r1: wp.array(dtype=wptype),
        r2: wp.array(dtype=wptype),
        r3: wp.array(dtype=wptype),
    ):
        result = q[0] + v[0]

        r0[0] = wptype(2) * result[0]
        r1[0] = wptype(2) * result[1]
        r2[0] = wptype(2) * result[2]
        r3[0] = wptype(2) * result[3]

    kernel = getkernel(check_quat_add, suffix=dtype.__name__)

    if register_kernels:
        return

    def make_quat_input():
        return wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)

    def make_scalar_output():
        return wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    q = make_quat_input()
    v = make_quat_input()

    r0 = make_scalar_output()
    r1 = make_scalar_output()
    r2 = make_scalar_output()
    r3 = make_scalar_output()

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[q, v], outputs=[r0, r1, r2, r3], device=device)

    # forward: each output is twice the component-wise sum
    q_np = q.numpy()
    v_np = v.numpy()
    for idx, out in enumerate((r0, r1, r2, r3)):
        assert_np_equal(out.numpy()[0], 2 * (v_np[0, idx] + q_np[0, idx]), tol=tol)

    # backward: gradient is 2 on the selected component for both operands
    for idx, loss in enumerate((r0, r1, r2, r3)):
        tape.backward(loss=loss)

        grad_q = tape.gradients[q].numpy()[0]
        expected = np.zeros_like(grad_q)
        expected[idx] = 2
        assert_np_equal(grad_q, expected, tol=10 * tol)

        grad_v = tape.gradients[v].numpy()[0]
        assert_np_equal(grad_v, expected, tol=tol)

        tape.zero()
567
def test_subtraction(test, device, dtype, register_kernels=False):
    """Check quaternion subtraction (v - q): component values are the element-wise
    differences, and the gradient of 2*(v - q)[i] is -2 w.r.t. q and +2 w.r.t. v
    on component i."""
    rng = np.random.default_rng(123)

    # per-dtype absolute tolerances
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp.types.quaternion(dtype=wptype)

    def check_quat_sub(
        q: wp.array(dtype=quat),
        v: wp.array(dtype=quat),
        r0: wp.array(dtype=wptype),
        r1: wp.array(dtype=wptype),
        r2: wp.array(dtype=wptype),
        r3: wp.array(dtype=wptype),
    ):
        # note the operand order: v minus q
        result = v[0] - q[0]

        # scaled by 2 so there's something nontrivial to backpropagate
        r0[0] = wptype(2) * result[0]
        r1[0] = wptype(2) * result[1]
        r2[0] = wptype(2) * result[2]
        r3[0] = wptype(2) * result[3]

    kernel = getkernel(check_quat_sub, suffix=dtype.__name__)

    if register_kernels:
        return

    q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
    v = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)

    r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[
                q,
                v,
            ],
            outputs=[r0, r1, r2, r3],
            device=device,
        )

    # forward values
    assert_np_equal(r0.numpy()[0], 2 * (v.numpy()[0, 0] - q.numpy()[0, 0]), tol=tol)
    assert_np_equal(r1.numpy()[0], 2 * (v.numpy()[0, 1] - q.numpy()[0, 1]), tol=tol)
    assert_np_equal(r2.numpy()[0], 2 * (v.numpy()[0, 2] - q.numpy()[0, 2]), tol=tol)
    assert_np_equal(r3.numpy()[0], 2 * (v.numpy()[0, 3] - q.numpy()[0, 3]), tol=tol)

    for i, l in enumerate([r0, r1, r2, r3]):
        tape.backward(loss=l)
        qgrads = tape.gradients[q].numpy()[0]
        expected_grads = np.zeros_like(qgrads)

        # q enters with a minus sign, so its gradient is -2 on component i
        expected_grads[i] = -2
        assert_np_equal(qgrads, expected_grads, tol=10 * tol)

        vgrads = tape.gradients[v].numpy()[0]
        # v enters with a plus sign: +2 on component i (reusing the same buffer)
        expected_grads[i] = 2
        assert_np_equal(vgrads, expected_grads, tol=tol)

        tape.zero()
640
def test_scalar_multiplication(test, device, dtype, register_kernels=False):
    """Check scalar-quaternion multiplication from both sides (s*q and q*s).

    Forward values must equal 2*s*q component-wise for both orders, and the
    gradients of 2*(s*q)[i] are 2*q_i w.r.t. s and 2*s on component i w.r.t. q.
    """
    rng = np.random.default_rng(123)

    # per-dtype absolute tolerances
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp.types.quaternion(dtype=wptype)

    def check_quat_scalar_mul(
        s: wp.array(dtype=wptype),
        q: wp.array(dtype=quat),
        l0: wp.array(dtype=wptype),
        l1: wp.array(dtype=wptype),
        l2: wp.array(dtype=wptype),
        l3: wp.array(dtype=wptype),
        r0: wp.array(dtype=wptype),
        r1: wp.array(dtype=wptype),
        r2: wp.array(dtype=wptype),
        r3: wp.array(dtype=wptype),
    ):
        # exercise both operand orders
        lresult = s[0] * q[0]
        rresult = q[0] * s[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        l0[0] = wptype(2) * lresult[0]
        l1[0] = wptype(2) * lresult[1]
        l2[0] = wptype(2) * lresult[2]
        l3[0] = wptype(2) * lresult[3]

        r0[0] = wptype(2) * rresult[0]
        r1[0] = wptype(2) * rresult[1]
        r2[0] = wptype(2) * rresult[2]
        r3[0] = wptype(2) * rresult[3]

    kernel = getkernel(check_quat_scalar_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    s = wp.array(rng.standard_normal(size=1).astype(dtype), requires_grad=True, device=device)
    q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)

    # l* hold s*q outputs, r* hold q*s outputs
    l0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    l1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    l2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    l3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[s, q],
            outputs=[
                l0,
                l1,
                l2,
                l3,
                r0,
                r1,
                r2,
                r3,
            ],
            device=device,
        )

    # forward values for both operand orders
    assert_np_equal(l0.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 0], tol=tol)
    assert_np_equal(l1.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 1], tol=tol)
    assert_np_equal(l2.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 2], tol=tol)
    assert_np_equal(l3.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 3], tol=tol)

    assert_np_equal(r0.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 0], tol=tol)
    assert_np_equal(r1.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 1], tol=tol)
    assert_np_equal(r2.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 2], tol=tol)
    assert_np_equal(r3.numpy()[0], 2 * s.numpy()[0] * q.numpy()[0, 3], tol=tol)

    if dtype in np_float_types:
        # both orders must produce identical gradients
        for i, outputs in enumerate([(l0, r0), (l1, r1), (l2, r2), (l3, r3)]):
            for l in outputs:
                tape.backward(loss=l)
                sgrad = tape.gradients[s].numpy()[0]
                assert_np_equal(sgrad, 2 * q.numpy()[0, i], tol=tol)
                allgrads = tape.gradients[q].numpy()[0]
                expected_grads = np.zeros_like(allgrads)
                expected_grads[i] = s.numpy()[0] * 2
                assert_np_equal(allgrads, expected_grads, tol=10 * tol)
                tape.zero()
738
def test_scalar_division(test, device, dtype, register_kernels=False):
    """Check quaternion-by-scalar division q / s.

    Forward values must equal 2*q_i/s; the gradient of 2*q_i/s is -2*q_i/s^2
    w.r.t. s and 2/s on component i w.r.t. q.
    """
    rng = np.random.default_rng(123)

    # per-dtype absolute tolerances
    tol = {
        np.float16: 1.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp.types.quaternion(dtype=wptype)

    def check_quat_scalar_div(
        s: wp.array(dtype=wptype),
        q: wp.array(dtype=quat),
        r0: wp.array(dtype=wptype),
        r1: wp.array(dtype=wptype),
        r2: wp.array(dtype=wptype),
        r3: wp.array(dtype=wptype),
    ):
        result = q[0] / s[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        r0[0] = wptype(2) * result[0]
        r1[0] = wptype(2) * result[1]
        r2[0] = wptype(2) * result[2]
        r3[0] = wptype(2) * result[3]

    kernel = getkernel(check_quat_scalar_div, suffix=dtype.__name__)

    if register_kernels:
        return

    s = wp.array(rng.standard_normal(size=1).astype(dtype), requires_grad=True, device=device)
    q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)

    r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[s, q],
            outputs=[
                r0,
                r1,
                r2,
                r3,
            ],
            device=device,
        )
    # forward values
    assert_np_equal(r0.numpy()[0], 2 * q.numpy()[0, 0] / s.numpy()[0], tol=tol)
    assert_np_equal(r1.numpy()[0], 2 * q.numpy()[0, 1] / s.numpy()[0], tol=tol)
    assert_np_equal(r2.numpy()[0], 2 * q.numpy()[0, 2] / s.numpy()[0], tol=tol)
    assert_np_equal(r3.numpy()[0], 2 * q.numpy()[0, 3] / s.numpy()[0], tol=tol)

    if dtype in np_float_types:
        for i, r in enumerate([r0, r1, r2, r3]):
            tape.backward(loss=r)
            # d(2*q_i/s)/ds = -2*q_i/s^2
            sgrad = tape.gradients[s].numpy()[0]
            assert_np_equal(sgrad, -2 * q.numpy()[0, i] / (s.numpy()[0] * s.numpy()[0]), tol=tol)

            # d(2*q_i/s)/dq = 2/s on component i
            allgrads = tape.gradients[q].numpy()[0]
            expected_grads = np.zeros_like(allgrads)
            expected_grads[i] = 2 / s.numpy()[0]
            assert_np_equal(allgrads, expected_grads, tol=10 * tol)
            tape.zero()
811
def test_quat_multiplication(test, device, dtype, register_kernels=False):
    """Check Hamilton-product quaternion multiplication s*q.

    Forward components are compared against the explicit Hamilton product
    formula (x, y, z, w layout with w last), and the gradients of each output
    component are compared against hand-derived expressions in the other
    operand's components.
    """
    rng = np.random.default_rng(123)

    # per-dtype absolute tolerances
    tol = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp.types.quaternion(dtype=wptype)

    def check_quat_mul(
        s: wp.array(dtype=quat),
        q: wp.array(dtype=quat),
        r0: wp.array(dtype=wptype),
        r1: wp.array(dtype=wptype),
        r2: wp.array(dtype=wptype),
        r3: wp.array(dtype=wptype),
    ):
        result = s[0] * q[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        r0[0] = wptype(2) * result[0]
        r1[0] = wptype(2) * result[1]
        r2[0] = wptype(2) * result[2]
        r3[0] = wptype(2) * result[3]

    kernel = getkernel(check_quat_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    s = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
    q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)

    r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[s, q],
            outputs=[
                r0,
                r1,
                r2,
                r3,
            ],
            device=device,
        )

    # forward values against the explicit Hamilton product (indices 0..2 = x,y,z; 3 = w)
    a = s.numpy()
    b = q.numpy()
    assert_np_equal(
        r0.numpy()[0], 2 * (a[0, 3] * b[0, 0] + b[0, 3] * a[0, 0] + a[0, 1] * b[0, 2] - b[0, 1] * a[0, 2]), tol=tol
    )
    assert_np_equal(
        r1.numpy()[0], 2 * (a[0, 3] * b[0, 1] + b[0, 3] * a[0, 1] + a[0, 2] * b[0, 0] - b[0, 2] * a[0, 0]), tol=tol
    )
    assert_np_equal(
        r2.numpy()[0], 2 * (a[0, 3] * b[0, 2] + b[0, 3] * a[0, 2] + a[0, 0] * b[0, 1] - b[0, 0] * a[0, 1]), tol=tol
    )
    assert_np_equal(
        r3.numpy()[0], 2 * (a[0, 3] * b[0, 3] - a[0, 0] * b[0, 0] - a[0, 1] * b[0, 1] - a[0, 2] * b[0, 2]), tol=tol
    )

    # gradients of the x output w.r.t. each operand
    tape.backward(loss=r0)
    agrad = tape.gradients[s].numpy()[0]
    assert_np_equal(agrad, 2 * np.array([b[0, 3], b[0, 2], -b[0, 1], b[0, 0]]), tol=tol)

    bgrad = tape.gradients[q].numpy()[0]
    assert_np_equal(bgrad, 2 * np.array([a[0, 3], -a[0, 2], a[0, 1], a[0, 0]]), tol=tol)
    tape.zero()

    # gradients of the y output
    tape.backward(loss=r1)
    agrad = tape.gradients[s].numpy()[0]
    assert_np_equal(agrad, 2 * np.array([-b[0, 2], b[0, 3], b[0, 0], b[0, 1]]), tol=tol)

    bgrad = tape.gradients[q].numpy()[0]
    assert_np_equal(bgrad, 2 * np.array([a[0, 2], a[0, 3], -a[0, 0], a[0, 1]]), tol=tol)
    tape.zero()

    # gradients of the z output
    tape.backward(loss=r2)
    agrad = tape.gradients[s].numpy()[0]
    assert_np_equal(agrad, 2 * np.array([b[0, 1], -b[0, 0], b[0, 3], b[0, 2]]), tol=tol)

    bgrad = tape.gradients[q].numpy()[0]
    assert_np_equal(bgrad, 2 * np.array([-a[0, 1], a[0, 0], a[0, 3], a[0, 2]]), tol=tol)
    tape.zero()

    # gradients of the w output
    tape.backward(loss=r3)
    agrad = tape.gradients[s].numpy()[0]
    assert_np_equal(agrad, 2 * np.array([-b[0, 0], -b[0, 1], -b[0, 2], b[0, 3]]), tol=tol)

    bgrad = tape.gradients[q].numpy()[0]
    assert_np_equal(bgrad, 2 * np.array([-a[0, 0], -a[0, 1], -a[0, 2], a[0, 3]]), tol=tol)
    tape.zero()
915
def test_indexing(test, device, dtype, register_kernels=False):
    """Check quaternion component indexing q[0][i]: values round-trip, and the
    gradient of 2*q[0][i] is 2 on component i and 0 elsewhere."""
    rng = np.random.default_rng(123)

    # per-dtype absolute tolerances
    tolerances = {np.float16: 5.0e-3, np.float32: 1.0e-6, np.float64: 1.0e-8}
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp.types.quaternion(dtype=wptype)

    def check_quat_indexing(
        q: wp.array(dtype=quat),
        r0: wp.array(dtype=wptype),
        r1: wp.array(dtype=wptype),
        r2: wp.array(dtype=wptype),
        r3: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        r0[0] = wptype(2) * q[0][0]
        r1[0] = wptype(2) * q[0][1]
        r2[0] = wptype(2) * q[0][2]
        r3[0] = wptype(2) * q[0][3]

    kernel = getkernel(check_quat_indexing, suffix=dtype.__name__)

    if register_kernels:
        return

    q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)

    def make_output():
        return wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    r0 = make_output()
    r1 = make_output()
    r2 = make_output()
    r3 = make_output()

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[q], outputs=[r0, r1, r2, r3], device=device)

    # backward: each output selects exactly one component, so the gradient is a
    # one-hot vector scaled by 2
    for idx, loss in enumerate((r0, r1, r2, r3)):
        tape.backward(loss=loss)
        grads = tape.gradients[q].numpy()[0]
        expected = np.zeros_like(grads)
        expected[idx] = 2
        assert_np_equal(grads, expected, tol=tol)
        tape.zero()

    # forward: outputs are twice the corresponding input components
    q_np = q.numpy()
    for idx, out in enumerate((r0, r1, r2, r3)):
        assert_np_equal(out.numpy()[0], 2.0 * q_np[0, idx], tol=tol)
969
def test_quat_lerp(test, device, dtype, register_kernels=False):
    """Check ``wp.lerp`` on quaternions.

    Forward values must equal (1-t)*a + t*b component-wise; gradients of
    2*lerp(a, b, t)[i] are 2*(1-t) w.r.t. a, 2*t w.r.t. b, and 2*(b_i - a_i)
    w.r.t. t.
    """
    rng = np.random.default_rng(123)

    # per-dtype absolute tolerances
    tol = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    quat = wp.types.quaternion(dtype=wptype)

    def check_quat_lerp(
        s: wp.array(dtype=quat),
        q: wp.array(dtype=quat),
        t: wp.array(dtype=wptype),
        r0: wp.array(dtype=wptype),
        r1: wp.array(dtype=wptype),
        r2: wp.array(dtype=wptype),
        r3: wp.array(dtype=wptype),
    ):
        result = wp.lerp(s[0], q[0], t[0])

        # multiply outputs by 2 so we've got something to backpropagate:
        r0[0] = wptype(2) * result[0]
        r1[0] = wptype(2) * result[1]
        r2[0] = wptype(2) * result[2]
        r3[0] = wptype(2) * result[3]

    kernel = getkernel(check_quat_lerp, suffix=dtype.__name__)

    if register_kernels:
        return

    s = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
    q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
    # interpolation parameter drawn uniformly from [0, 1)
    t = wp.array(rng.uniform(size=1).astype(dtype), dtype=wptype, requires_grad=True, device=device)

    r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    r3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[s, q, t],
            outputs=[
                r0,
                r1,
                r2,
                r3,
            ],
            device=device,
        )

    # forward values against (1-t)*a + t*b
    a = s.numpy()
    b = q.numpy()
    tt = t.numpy()
    assert_np_equal(r0.numpy()[0], 2 * ((1 - tt) * a[0, 0] + tt * b[0, 0]), tol=tol)
    assert_np_equal(r1.numpy()[0], 2 * ((1 - tt) * a[0, 1] + tt * b[0, 1]), tol=tol)
    assert_np_equal(r2.numpy()[0], 2 * ((1 - tt) * a[0, 2] + tt * b[0, 2]), tol=tol)
    assert_np_equal(r3.numpy()[0], 2 * ((1 - tt) * a[0, 3] + tt * b[0, 3]), tol=tol)

    for i, l in enumerate([r0, r1, r2, r3]):
        tape.backward(loss=l)
        agrad = tape.gradients[s].numpy()[0]
        bgrad = tape.gradients[q].numpy()[0]
        tgrad = tape.gradients[t].numpy()[0]
        expected_grads = np.zeros_like(agrad)
        # d/da = 2*(1-t) on component i
        expected_grads[i] = 2 * (1 - tt)
        assert_np_equal(agrad, expected_grads, tol=tol)
        # d/db = 2*t on component i (reusing the same buffer)
        expected_grads[i] = 2 * tt
        assert_np_equal(bgrad, expected_grads, tol=tol)
        # d/dt = 2*(b_i - a_i)
        assert_np_equal(tgrad, 2 * (b[0, i] - a[0, i]), tol=tol)

        tape.zero()
1050
+ def test_quat_rotate(test, device, dtype, register_kernels=False):
1051
+ rng = np.random.default_rng(123)
1052
+
1053
+ tol = {
1054
+ np.float16: 1.0e-2,
1055
+ np.float32: 1.0e-6,
1056
+ np.float64: 1.0e-8,
1057
+ }.get(dtype, 0)
1058
+
1059
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1060
+ quat = wp.types.quaternion(dtype=wptype)
1061
+ vec3 = wp.types.vector(length=3, dtype=wptype)
1062
+
1063
+ def check_quat_rotate(
1064
+ q: wp.array(dtype=quat),
1065
+ v: wp.array(dtype=vec3),
1066
+ outputs: wp.array(dtype=wptype),
1067
+ outputs_inv: wp.array(dtype=wptype),
1068
+ outputs_manual: wp.array(dtype=wptype),
1069
+ outputs_inv_manual: wp.array(dtype=wptype),
1070
+ ):
1071
+ result = wp.quat_rotate(q[0], v[0])
1072
+ result_inv = wp.quat_rotate_inv(q[0], v[0])
1073
+
1074
+ qv = vec3(q[0][0], q[0][1], q[0][2])
1075
+ qw = q[0][3]
1076
+
1077
+ result_manual = v[0] * (wptype(2) * qw * qw - wptype(1))
1078
+ result_manual += wp.cross(qv, v[0]) * qw * wptype(2)
1079
+ result_manual += qv * wp.dot(qv, v[0]) * wptype(2)
1080
+
1081
+ result_inv_manual = v[0] * (wptype(2) * qw * qw - wptype(1))
1082
+ result_inv_manual -= wp.cross(qv, v[0]) * qw * wptype(2)
1083
+ result_inv_manual += qv * wp.dot(qv, v[0]) * wptype(2)
1084
+
1085
+ for i in range(3):
1086
+ # multiply outputs by 2 so we've got something to backpropagate:
1087
+ outputs[i] = wptype(2) * result[i]
1088
+ outputs_inv[i] = wptype(2) * result_inv[i]
1089
+ outputs_manual[i] = wptype(2) * result_manual[i]
1090
+ outputs_inv_manual[i] = wptype(2) * result_inv_manual[i]
1091
+
1092
+ kernel = getkernel(check_quat_rotate, suffix=dtype.__name__)
1093
+ output_select_kernel = get_select_kernel(wptype)
1094
+
1095
+ if register_kernels:
1096
+ return
1097
+
1098
+ q = rng.standard_normal(size=(1, 4))
1099
+ q /= np.linalg.norm(q)
1100
+ q = wp.array(q.astype(dtype), dtype=quat, requires_grad=True, device=device)
1101
+ v = wp.array(0.5 * rng.standard_normal(size=(1, 3)).astype(dtype), dtype=vec3, requires_grad=True, device=device)
1102
+
1103
+ # test values against the manually computed result:
1104
+ outputs = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
1105
+ outputs_inv = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
1106
+ outputs_manual = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
1107
+ outputs_inv_manual = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
1108
+
1109
+ wp.launch(
1110
+ kernel,
1111
+ dim=1,
1112
+ inputs=[q, v],
1113
+ outputs=[
1114
+ outputs,
1115
+ outputs_inv,
1116
+ outputs_manual,
1117
+ outputs_inv_manual,
1118
+ ],
1119
+ device=device,
1120
+ )
1121
+
1122
+ assert_np_equal(outputs.numpy(), outputs_manual.numpy(), tol=tol)
1123
+ assert_np_equal(outputs_inv.numpy(), outputs_inv_manual.numpy(), tol=tol)
1124
+
1125
+ # test gradients against the manually computed result:
1126
+ for i in range(3):
1127
+ cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1128
+ cmp_inv = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1129
+ cmp_manual = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1130
+ cmp_inv_manual = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1131
+ tape = wp.Tape()
1132
+ with tape:
1133
+ wp.launch(
1134
+ kernel,
1135
+ dim=1,
1136
+ inputs=[q, v],
1137
+ outputs=[
1138
+ outputs,
1139
+ outputs_inv,
1140
+ outputs_manual,
1141
+ outputs_inv_manual,
1142
+ ],
1143
+ device=device,
1144
+ )
1145
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, i], outputs=[cmp], device=device)
1146
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs_inv, i], outputs=[cmp_inv], device=device)
1147
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs_manual, i], outputs=[cmp_manual], device=device)
1148
+ wp.launch(
1149
+ output_select_kernel, dim=1, inputs=[outputs_inv_manual, i], outputs=[cmp_inv_manual], device=device
1150
+ )
1151
+
1152
+ tape.backward(loss=cmp)
1153
+ qgrads = 1.0 * tape.gradients[q].numpy()
1154
+ vgrads = 1.0 * tape.gradients[v].numpy()
1155
+ tape.zero()
1156
+ tape.backward(loss=cmp_inv)
1157
+ qgrads_inv = 1.0 * tape.gradients[q].numpy()
1158
+ vgrads_inv = 1.0 * tape.gradients[v].numpy()
1159
+ tape.zero()
1160
+ tape.backward(loss=cmp_manual)
1161
+ qgrads_manual = 1.0 * tape.gradients[q].numpy()
1162
+ vgrads_manual = 1.0 * tape.gradients[v].numpy()
1163
+ tape.zero()
1164
+ tape.backward(loss=cmp_inv_manual)
1165
+ qgrads_inv_manual = 1.0 * tape.gradients[q].numpy()
1166
+ vgrads_inv_manual = 1.0 * tape.gradients[v].numpy()
1167
+ tape.zero()
1168
+
1169
+ assert_np_equal(qgrads, qgrads_manual, tol=tol)
1170
+ assert_np_equal(vgrads, vgrads_manual, tol=tol)
1171
+
1172
+ assert_np_equal(qgrads_inv, qgrads_inv_manual, tol=tol)
1173
+ assert_np_equal(vgrads_inv, vgrads_inv_manual, tol=tol)
1174
+
1175
+
1176
+ def test_quat_to_matrix(test, device, dtype, register_kernels=False):
1177
+ rng = np.random.default_rng(123)
1178
+
1179
+ tol = {
1180
+ np.float16: 1.0e-2,
1181
+ np.float32: 1.0e-6,
1182
+ np.float64: 1.0e-8,
1183
+ }.get(dtype, 0)
1184
+
1185
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1186
+ quat = wp.types.quaternion(dtype=wptype)
1187
+ mat3 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1188
+ vec3 = wp.types.vector(length=3, dtype=wptype)
1189
+
1190
+ def check_quat_to_matrix(
1191
+ q: wp.array(dtype=quat),
1192
+ outputs: wp.array(dtype=wptype),
1193
+ outputs_manual: wp.array(dtype=wptype),
1194
+ ):
1195
+ result = wp.quat_to_matrix(q[0])
1196
+
1197
+ xaxis = wp.quat_rotate(
1198
+ q[0],
1199
+ vec3(
1200
+ wptype(1),
1201
+ wptype(0),
1202
+ wptype(0),
1203
+ ),
1204
+ )
1205
+ yaxis = wp.quat_rotate(
1206
+ q[0],
1207
+ vec3(
1208
+ wptype(0),
1209
+ wptype(1),
1210
+ wptype(0),
1211
+ ),
1212
+ )
1213
+ zaxis = wp.quat_rotate(
1214
+ q[0],
1215
+ vec3(
1216
+ wptype(0),
1217
+ wptype(0),
1218
+ wptype(1),
1219
+ ),
1220
+ )
1221
+ result_manual = mat3(xaxis, yaxis, zaxis)
1222
+
1223
+ idx = 0
1224
+ for i in range(3):
1225
+ for j in range(3):
1226
+ # multiply outputs by 2 so we've got something to backpropagate:
1227
+ outputs[idx] = wptype(2) * result[i, j]
1228
+ outputs_manual[idx] = wptype(2) * result_manual[i, j]
1229
+
1230
+ idx = idx + 1
1231
+
1232
+ kernel = getkernel(check_quat_to_matrix, suffix=dtype.__name__)
1233
+ output_select_kernel = get_select_kernel(wptype)
1234
+
1235
+ if register_kernels:
1236
+ return
1237
+
1238
+ q = rng.standard_normal(size=(1, 4))
1239
+ q /= np.linalg.norm(q)
1240
+ q = wp.array(q.astype(dtype), dtype=quat, requires_grad=True, device=device)
1241
+
1242
+ # test values against the manually computed result:
1243
+ outputs = wp.zeros(3 * 3, dtype=wptype, requires_grad=True, device=device)
1244
+ outputs_manual = wp.zeros(3 * 3, dtype=wptype, requires_grad=True, device=device)
1245
+
1246
+ wp.launch(
1247
+ kernel,
1248
+ dim=1,
1249
+ inputs=[q],
1250
+ outputs=[
1251
+ outputs,
1252
+ outputs_manual,
1253
+ ],
1254
+ device=device,
1255
+ )
1256
+
1257
+ assert_np_equal(outputs.numpy(), outputs_manual.numpy(), tol=tol)
1258
+
1259
+ # sanity check: divide by 2 to remove that scale factor we put in there, and
1260
+ # it should be a rotation matrix
1261
+ R = 0.5 * outputs.numpy().reshape(3, 3)
1262
+ assert_np_equal(np.matmul(R, R.T), np.eye(3), tol=tol)
1263
+
1264
+ # test gradients against the manually computed result:
1265
+ idx = 0
1266
+ for _i in range(3):
1267
+ for _j in range(3):
1268
+ cmp = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1269
+ cmp_manual = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1270
+ tape = wp.Tape()
1271
+ with tape:
1272
+ wp.launch(
1273
+ kernel,
1274
+ dim=1,
1275
+ inputs=[q],
1276
+ outputs=[
1277
+ outputs,
1278
+ outputs_manual,
1279
+ ],
1280
+ device=device,
1281
+ )
1282
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, idx], outputs=[cmp], device=device)
1283
+ wp.launch(
1284
+ output_select_kernel, dim=1, inputs=[outputs_manual, idx], outputs=[cmp_manual], device=device
1285
+ )
1286
+ tape.backward(loss=cmp)
1287
+ qgrads = 1.0 * tape.gradients[q].numpy()
1288
+ tape.zero()
1289
+ tape.backward(loss=cmp_manual)
1290
+ qgrads_manual = 1.0 * tape.gradients[q].numpy()
1291
+ tape.zero()
1292
+
1293
+ assert_np_equal(qgrads, qgrads_manual, tol=tol)
1294
+ idx = idx + 1
1295
+
1296
+
1297
+ ############################################################
1298
+
1299
+
1300
+ def test_slerp_grad(test, device, dtype, register_kernels=False):
1301
+ rng = np.random.default_rng(123)
1302
+ seed = 42
1303
+
1304
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1305
+ vec3 = wp.types.vector(3, wptype)
1306
+ quat = wp.types.quaternion(wptype)
1307
+
1308
+ def slerp_kernel(
1309
+ q0: wp.array(dtype=quat),
1310
+ q1: wp.array(dtype=quat),
1311
+ t: wp.array(dtype=wptype),
1312
+ loss: wp.array(dtype=wptype),
1313
+ index: int,
1314
+ ):
1315
+ tid = wp.tid()
1316
+
1317
+ q = wp.quat_slerp(q0[tid], q1[tid], t[tid])
1318
+ wp.atomic_add(loss, 0, q[index])
1319
+
1320
+ slerp_kernel = getkernel(slerp_kernel, suffix=dtype.__name__)
1321
+
1322
+ def slerp_kernel_forward(
1323
+ q0: wp.array(dtype=quat),
1324
+ q1: wp.array(dtype=quat),
1325
+ t: wp.array(dtype=wptype),
1326
+ loss: wp.array(dtype=wptype),
1327
+ index: int,
1328
+ ):
1329
+ tid = wp.tid()
1330
+
1331
+ axis = vec3()
1332
+ angle = wptype(0.0)
1333
+
1334
+ wp.quat_to_axis_angle(wp.mul(wp.quat_inverse(q0[tid]), q1[tid]), axis, angle)
1335
+ q = wp.mul(q0[tid], wp.quat_from_axis_angle(axis, t[tid] * angle))
1336
+
1337
+ wp.atomic_add(loss, 0, q[index])
1338
+
1339
+ slerp_kernel_forward = getkernel(slerp_kernel_forward, suffix=dtype.__name__)
1340
+
1341
+ def quat_sampler_slerp(kernel_seed: int, quats: wp.array(dtype=quat)):
1342
+ tid = wp.tid()
1343
+
1344
+ state = wp.rand_init(kernel_seed, tid)
1345
+
1346
+ angle = wp.randf(state, 0.0, 2.0 * 3.1415926535)
1347
+ dir = wp.sample_unit_sphere_surface(state) * wp.sin(angle * 0.5)
1348
+
1349
+ q = quat(wptype(dir[0]), wptype(dir[1]), wptype(dir[2]), wptype(wp.cos(angle * 0.5)))
1350
+ qn = wp.normalize(q)
1351
+
1352
+ quats[tid] = qn
1353
+
1354
+ quat_sampler = getkernel(quat_sampler_slerp, suffix=dtype.__name__)
1355
+
1356
+ if register_kernels:
1357
+ return
1358
+
1359
+ N = 50
1360
+
1361
+ q0 = wp.zeros(N, dtype=quat, device=device, requires_grad=True)
1362
+ q1 = wp.zeros(N, dtype=quat, device=device, requires_grad=True)
1363
+
1364
+ wp.launch(kernel=quat_sampler, dim=N, inputs=[seed, q0], device=device)
1365
+ wp.launch(kernel=quat_sampler, dim=N, inputs=[seed + 1, q1], device=device)
1366
+
1367
+ t = rng.uniform(low=0.0, high=1.0, size=N)
1368
+ t = wp.array(t, dtype=wptype, device=device, requires_grad=True)
1369
+
1370
+ def compute_gradients(kernel, wrt, index):
1371
+ loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
1372
+ tape = wp.Tape()
1373
+ with tape:
1374
+ wp.launch(kernel=kernel, dim=N, inputs=[q0, q1, t, loss, index], device=device)
1375
+
1376
+ tape.backward(loss)
1377
+
1378
+ gradients = 1.0 * tape.gradients[wrt].numpy()
1379
+ tape.zero()
1380
+
1381
+ return loss.numpy()[0], gradients
1382
+
1383
+ eps = {
1384
+ np.float16: 2.0e-2,
1385
+ np.float32: 1.0e-5,
1386
+ np.float64: 1.0e-8,
1387
+ }.get(dtype, 0)
1388
+
1389
+ # wrt t
1390
+
1391
+ # gather gradients from builtin adjoints
1392
+ xcmp, gradients_x = compute_gradients(slerp_kernel, t, 0)
1393
+ ycmp, gradients_y = compute_gradients(slerp_kernel, t, 1)
1394
+ zcmp, gradients_z = compute_gradients(slerp_kernel, t, 2)
1395
+ wcmp, gradients_w = compute_gradients(slerp_kernel, t, 3)
1396
+
1397
+ # gather gradients from autodiff
1398
+ xcmp_auto, gradients_x_auto = compute_gradients(slerp_kernel_forward, t, 0)
1399
+ ycmp_auto, gradients_y_auto = compute_gradients(slerp_kernel_forward, t, 1)
1400
+ zcmp_auto, gradients_z_auto = compute_gradients(slerp_kernel_forward, t, 2)
1401
+ wcmp_auto, gradients_w_auto = compute_gradients(slerp_kernel_forward, t, 3)
1402
+
1403
+ assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1404
+ assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1405
+ assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1406
+ assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1407
+ assert_np_equal(xcmp, xcmp_auto, tol=eps)
1408
+ assert_np_equal(ycmp, ycmp_auto, tol=eps)
1409
+ assert_np_equal(zcmp, zcmp_auto, tol=eps)
1410
+ assert_np_equal(wcmp, wcmp_auto, tol=eps)
1411
+
1412
+ # wrt q0
1413
+
1414
+ # gather gradients from builtin adjoints
1415
+ xcmp, gradients_x = compute_gradients(slerp_kernel, q0, 0)
1416
+ ycmp, gradients_y = compute_gradients(slerp_kernel, q0, 1)
1417
+ zcmp, gradients_z = compute_gradients(slerp_kernel, q0, 2)
1418
+ wcmp, gradients_w = compute_gradients(slerp_kernel, q0, 3)
1419
+
1420
+ # gather gradients from autodiff
1421
+ xcmp_auto, gradients_x_auto = compute_gradients(slerp_kernel_forward, q0, 0)
1422
+ ycmp_auto, gradients_y_auto = compute_gradients(slerp_kernel_forward, q0, 1)
1423
+ zcmp_auto, gradients_z_auto = compute_gradients(slerp_kernel_forward, q0, 2)
1424
+ wcmp_auto, gradients_w_auto = compute_gradients(slerp_kernel_forward, q0, 3)
1425
+
1426
+ assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1427
+ assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1428
+ assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1429
+ assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1430
+ assert_np_equal(xcmp, xcmp_auto, tol=eps)
1431
+ assert_np_equal(ycmp, ycmp_auto, tol=eps)
1432
+ assert_np_equal(zcmp, zcmp_auto, tol=eps)
1433
+ assert_np_equal(wcmp, wcmp_auto, tol=eps)
1434
+
1435
+ # wrt q1
1436
+
1437
+ # gather gradients from builtin adjoints
1438
+ xcmp, gradients_x = compute_gradients(slerp_kernel, q1, 0)
1439
+ ycmp, gradients_y = compute_gradients(slerp_kernel, q1, 1)
1440
+ zcmp, gradients_z = compute_gradients(slerp_kernel, q1, 2)
1441
+ wcmp, gradients_w = compute_gradients(slerp_kernel, q1, 3)
1442
+
1443
+ # gather gradients from autodiff
1444
+ xcmp_auto, gradients_x_auto = compute_gradients(slerp_kernel_forward, q1, 0)
1445
+ ycmp_auto, gradients_y_auto = compute_gradients(slerp_kernel_forward, q1, 1)
1446
+ zcmp_auto, gradients_z_auto = compute_gradients(slerp_kernel_forward, q1, 2)
1447
+ wcmp_auto, gradients_w_auto = compute_gradients(slerp_kernel_forward, q1, 3)
1448
+
1449
+ assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1450
+ assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1451
+ assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1452
+ assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1453
+ assert_np_equal(xcmp, xcmp_auto, tol=eps)
1454
+ assert_np_equal(ycmp, ycmp_auto, tol=eps)
1455
+ assert_np_equal(zcmp, zcmp_auto, tol=eps)
1456
+ assert_np_equal(wcmp, wcmp_auto, tol=eps)
1457
+
1458
+
1459
+ ############################################################
1460
+
1461
+
1462
+ def test_quat_to_axis_angle_grad(test, device, dtype, register_kernels=False):
1463
+ rng = np.random.default_rng(123)
1464
+ seed = 42
1465
+ num_rand = 50
1466
+
1467
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1468
+ vec3 = wp.types.vector(3, wptype)
1469
+ vec4 = wp.types.vector(4, wptype)
1470
+ quat = wp.types.quaternion(wptype)
1471
+
1472
+ def quat_to_axis_angle_kernel(quats: wp.array(dtype=quat), loss: wp.array(dtype=wptype), coord_idx: int):
1473
+ tid = wp.tid()
1474
+ axis = vec3()
1475
+ angle = wptype(0.0)
1476
+
1477
+ wp.quat_to_axis_angle(quats[tid], axis, angle)
1478
+ a = vec4(axis[0], axis[1], axis[2], angle)
1479
+
1480
+ wp.atomic_add(loss, 0, a[coord_idx])
1481
+
1482
+ quat_to_axis_angle_kernel = getkernel(quat_to_axis_angle_kernel, suffix=dtype.__name__)
1483
+
1484
+ def quat_to_axis_angle_kernel_forward(quats: wp.array(dtype=quat), loss: wp.array(dtype=wptype), coord_idx: int):
1485
+ tid = wp.tid()
1486
+ q = quats[tid]
1487
+ axis = vec3()
1488
+ angle = wptype(0.0)
1489
+
1490
+ v = vec3(q[0], q[1], q[2])
1491
+ if q[3] < wptype(0):
1492
+ axis = -wp.normalize(v)
1493
+ else:
1494
+ axis = wp.normalize(v)
1495
+
1496
+ angle = wptype(2) * wp.atan2(wp.length(v), wp.abs(q[3]))
1497
+ a = vec4(axis[0], axis[1], axis[2], angle)
1498
+
1499
+ wp.atomic_add(loss, 0, a[coord_idx])
1500
+
1501
+ quat_to_axis_angle_kernel_forward = getkernel(quat_to_axis_angle_kernel_forward, suffix=dtype.__name__)
1502
+
1503
+ def quat_sampler(kernel_seed: int, angles: wp.array(dtype=float), quats: wp.array(dtype=quat)):
1504
+ tid = wp.tid()
1505
+
1506
+ state = wp.rand_init(kernel_seed, tid)
1507
+
1508
+ angle = angles[tid]
1509
+ dir = wp.sample_unit_sphere_surface(state) * wp.sin(angle * 0.5)
1510
+
1511
+ q = quat(wptype(dir[0]), wptype(dir[1]), wptype(dir[2]), wptype(wp.cos(angle * 0.5)))
1512
+ qn = wp.normalize(q)
1513
+
1514
+ quats[tid] = qn
1515
+
1516
+ quat_sampler = getkernel(quat_sampler, suffix=dtype.__name__)
1517
+
1518
+ if register_kernels:
1519
+ return
1520
+
1521
+ quats = wp.zeros(num_rand, dtype=quat, device=device, requires_grad=True)
1522
+ angles = wp.array(
1523
+ np.linspace(0.0, 2.0 * np.pi, num_rand, endpoint=False, dtype=np.float32), dtype=float, device=device
1524
+ )
1525
+ wp.launch(kernel=quat_sampler, dim=num_rand, inputs=[seed, angles, quats], device=device)
1526
+
1527
+ edge_cases = np.array(
1528
+ [(1.0, 0.0, 0.0, 0.0), (0.0, 1.0 / np.sqrt(3), 1.0 / np.sqrt(3), 1.0 / np.sqrt(3)), (0.0, 0.0, 0.0, 0.0)]
1529
+ )
1530
+ num_edge = len(edge_cases)
1531
+ edge_cases = wp.array(edge_cases, dtype=quat, device=device, requires_grad=True)
1532
+
1533
+ def compute_gradients(arr, kernel, dim, index):
1534
+ loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
1535
+ tape = wp.Tape()
1536
+ with tape:
1537
+ wp.launch(kernel=kernel, dim=dim, inputs=[arr, loss, index], device=device)
1538
+
1539
+ tape.backward(loss)
1540
+
1541
+ gradients = 1.0 * tape.gradients[arr].numpy()
1542
+ tape.zero()
1543
+
1544
+ return loss.numpy()[0], gradients
1545
+
1546
+ # gather gradients from builtin adjoints
1547
+ xcmp, gradients_x = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 0)
1548
+ ycmp, gradients_y = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 1)
1549
+ zcmp, gradients_z = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 2)
1550
+ wcmp, gradients_w = compute_gradients(quats, quat_to_axis_angle_kernel, num_rand, 3)
1551
+
1552
+ # gather gradients from autodiff
1553
+ xcmp_auto, gradients_x_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 0)
1554
+ ycmp_auto, gradients_y_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 1)
1555
+ zcmp_auto, gradients_z_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 2)
1556
+ wcmp_auto, gradients_w_auto = compute_gradients(quats, quat_to_axis_angle_kernel_forward, num_rand, 3)
1557
+
1558
+ # edge cases: gather gradients from builtin adjoints
1559
+ _, edge_gradients_x = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 0)
1560
+ _, edge_gradients_y = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 1)
1561
+ _, edge_gradients_z = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 2)
1562
+ _, edge_gradients_w = compute_gradients(edge_cases, quat_to_axis_angle_kernel, num_edge, 3)
1563
+
1564
+ # edge cases: gather gradients from autodiff
1565
+ _, edge_gradients_x_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 0)
1566
+ _, edge_gradients_y_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 1)
1567
+ _, edge_gradients_z_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 2)
1568
+ _, edge_gradients_w_auto = compute_gradients(edge_cases, quat_to_axis_angle_kernel_forward, num_edge, 3)
1569
+
1570
+ eps = {
1571
+ np.float16: 2.0e-1,
1572
+ np.float32: 2.0e-4,
1573
+ np.float64: 2.0e-7,
1574
+ }.get(dtype, 0)
1575
+
1576
+ assert_np_equal(xcmp, xcmp_auto, tol=eps)
1577
+ assert_np_equal(ycmp, ycmp_auto, tol=eps)
1578
+ assert_np_equal(zcmp, zcmp_auto, tol=eps)
1579
+ assert_np_equal(wcmp, wcmp_auto, tol=eps)
1580
+
1581
+ assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1582
+ assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1583
+ assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1584
+ assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1585
+
1586
+ assert_np_equal(edge_gradients_x, edge_gradients_x_auto, tol=eps)
1587
+ assert_np_equal(edge_gradients_y, edge_gradients_y_auto, tol=eps)
1588
+ assert_np_equal(edge_gradients_z, edge_gradients_z_auto, tol=eps)
1589
+ assert_np_equal(edge_gradients_w, edge_gradients_w_auto, tol=eps)
1590
+
1591
+
1592
+ ############################################################
1593
+
1594
+
1595
+ def test_quat_rpy_grad(test, device, dtype, register_kernels=False):
1596
+ rng = np.random.default_rng(123)
1597
+ N = 3
1598
+
1599
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1600
+
1601
+ vec3 = wp.types.vector(3, wptype)
1602
+ quat = wp.types.quaternion(wptype)
1603
+
1604
+ def rpy_to_quat_kernel(rpy_arr: wp.array(dtype=vec3), loss: wp.array(dtype=wptype), coord_idx: int):
1605
+ tid = wp.tid()
1606
+ rpy = rpy_arr[tid]
1607
+ roll = rpy[0]
1608
+ pitch = rpy[1]
1609
+ yaw = rpy[2]
1610
+
1611
+ q = wp.quat_rpy(roll, pitch, yaw)
1612
+
1613
+ wp.atomic_add(loss, 0, q[coord_idx])
1614
+
1615
+ rpy_to_quat_kernel = getkernel(rpy_to_quat_kernel, suffix=dtype.__name__)
1616
+
1617
+ def rpy_to_quat_kernel_forward(rpy_arr: wp.array(dtype=vec3), loss: wp.array(dtype=wptype), coord_idx: int):
1618
+ tid = wp.tid()
1619
+ rpy = rpy_arr[tid]
1620
+ roll = rpy[0]
1621
+ pitch = rpy[1]
1622
+ yaw = rpy[2]
1623
+
1624
+ cy = wp.cos(yaw * wptype(0.5))
1625
+ sy = wp.sin(yaw * wptype(0.5))
1626
+ cr = wp.cos(roll * wptype(0.5))
1627
+ sr = wp.sin(roll * wptype(0.5))
1628
+ cp = wp.cos(pitch * wptype(0.5))
1629
+ sp = wp.sin(pitch * wptype(0.5))
1630
+
1631
+ w = cy * cr * cp + sy * sr * sp
1632
+ x = cy * sr * cp - sy * cr * sp
1633
+ y = cy * cr * sp + sy * sr * cp
1634
+ z = sy * cr * cp - cy * sr * sp
1635
+
1636
+ q = quat(x, y, z, w)
1637
+
1638
+ wp.atomic_add(loss, 0, q[coord_idx])
1639
+
1640
+ rpy_to_quat_kernel_forward = getkernel(rpy_to_quat_kernel_forward, suffix=dtype.__name__)
1641
+
1642
+ if register_kernels:
1643
+ return
1644
+
1645
+ rpy_arr = rng.uniform(low=-np.pi, high=np.pi, size=(N, 3))
1646
+ rpy_arr = wp.array(rpy_arr, dtype=vec3, device=device, requires_grad=True)
1647
+
1648
+ def compute_gradients(kernel, wrt, index):
1649
+ loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
1650
+ tape = wp.Tape()
1651
+ with tape:
1652
+ wp.launch(kernel=kernel, dim=N, inputs=[wrt, loss, index], device=device)
1653
+
1654
+ tape.backward(loss)
1655
+
1656
+ gradients = 1.0 * tape.gradients[wrt].numpy()
1657
+ tape.zero()
1658
+
1659
+ return loss.numpy()[0], gradients
1660
+
1661
+ # wrt rpy
1662
+ # gather gradients from builtin adjoints
1663
+ rcmp, gradients_r = compute_gradients(rpy_to_quat_kernel, rpy_arr, 0)
1664
+ pcmp, gradients_p = compute_gradients(rpy_to_quat_kernel, rpy_arr, 1)
1665
+ ycmp, gradients_y = compute_gradients(rpy_to_quat_kernel, rpy_arr, 2)
1666
+
1667
+ # gather gradients from autodiff
1668
+ rcmp_auto, gradients_r_auto = compute_gradients(rpy_to_quat_kernel_forward, rpy_arr, 0)
1669
+ pcmp_auto, gradients_p_auto = compute_gradients(rpy_to_quat_kernel_forward, rpy_arr, 1)
1670
+ ycmp_auto, gradients_y_auto = compute_gradients(rpy_to_quat_kernel_forward, rpy_arr, 2)
1671
+
1672
+ eps = {
1673
+ np.float16: 2.0e-2,
1674
+ np.float32: 1.0e-5,
1675
+ np.float64: 1.0e-8,
1676
+ }.get(dtype, 0)
1677
+
1678
+ assert_np_equal(rcmp, rcmp_auto, tol=eps)
1679
+ assert_np_equal(pcmp, pcmp_auto, tol=eps)
1680
+ assert_np_equal(ycmp, ycmp_auto, tol=eps)
1681
+
1682
+ assert_np_equal(gradients_r, gradients_r_auto, tol=eps)
1683
+ assert_np_equal(gradients_p, gradients_p_auto, tol=eps)
1684
+ assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1685
+
1686
+
1687
+ ############################################################
1688
+
1689
+
1690
+ def test_quat_from_matrix(test, device, dtype, register_kernels=False):
1691
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1692
+ mat33 = wp.types.matrix((3, 3), wptype)
1693
+ quat = wp.types.quaternion(wptype)
1694
+
1695
+ def quat_from_matrix(m: wp.array2d(dtype=wptype), loss: wp.array(dtype=wptype), idx: int):
1696
+ tid = wp.tid()
1697
+
1698
+ matrix = mat33(
1699
+ m[tid, 0], m[tid, 1], m[tid, 2], m[tid, 3], m[tid, 4], m[tid, 5], m[tid, 6], m[tid, 7], m[tid, 8]
1700
+ )
1701
+
1702
+ q = wp.quat_from_matrix(matrix)
1703
+
1704
+ wp.atomic_add(loss, 0, q[idx])
1705
+
1706
+ def quat_from_matrix_forward(mats: wp.array2d(dtype=wptype), loss: wp.array(dtype=wptype), idx: int):
1707
+ tid = wp.tid()
1708
+
1709
+ m = mat33(
1710
+ mats[tid, 0],
1711
+ mats[tid, 1],
1712
+ mats[tid, 2],
1713
+ mats[tid, 3],
1714
+ mats[tid, 4],
1715
+ mats[tid, 5],
1716
+ mats[tid, 6],
1717
+ mats[tid, 7],
1718
+ mats[tid, 8],
1719
+ )
1720
+
1721
+ tr = m[0][0] + m[1][1] + m[2][2]
1722
+ x = wptype(0)
1723
+ y = wptype(0)
1724
+ z = wptype(0)
1725
+ w = wptype(0)
1726
+ h = wptype(0)
1727
+
1728
+ if tr >= wptype(0):
1729
+ h = wp.sqrt(tr + wptype(1))
1730
+ w = wptype(0.5) * h
1731
+ h = wptype(0.5) / h
1732
+
1733
+ x = (m[2][1] - m[1][2]) * h
1734
+ y = (m[0][2] - m[2][0]) * h
1735
+ z = (m[1][0] - m[0][1]) * h
1736
+ else:
1737
+ max_diag = 0
1738
+ if m[1][1] > m[0][0]:
1739
+ max_diag = 1
1740
+ if m[2][2] > m[max_diag][max_diag]:
1741
+ max_diag = 2
1742
+
1743
+ if max_diag == 0:
1744
+ h = wp.sqrt((m[0][0] - (m[1][1] + m[2][2])) + wptype(1))
1745
+ x = wptype(0.5) * h
1746
+ h = wptype(0.5) / h
1747
+
1748
+ y = (m[0][1] + m[1][0]) * h
1749
+ z = (m[2][0] + m[0][2]) * h
1750
+ w = (m[2][1] - m[1][2]) * h
1751
+ elif max_diag == 1:
1752
+ h = wp.sqrt((m[1][1] - (m[2][2] + m[0][0])) + wptype(1))
1753
+ y = wptype(0.5) * h
1754
+ h = wptype(0.5) / h
1755
+
1756
+ z = (m[1][2] + m[2][1]) * h
1757
+ x = (m[0][1] + m[1][0]) * h
1758
+ w = (m[0][2] - m[2][0]) * h
1759
+ if max_diag == 2:
1760
+ h = wp.sqrt((m[2][2] - (m[0][0] + m[1][1])) + wptype(1))
1761
+ z = wptype(0.5) * h
1762
+ h = wptype(0.5) / h
1763
+
1764
+ x = (m[2][0] + m[0][2]) * h
1765
+ y = (m[1][2] + m[2][1]) * h
1766
+ w = (m[1][0] - m[0][1]) * h
1767
+
1768
+ q = wp.normalize(quat(x, y, z, w))
1769
+
1770
+ wp.atomic_add(loss, 0, q[idx])
1771
+
1772
+ quat_from_matrix = getkernel(quat_from_matrix, suffix=dtype.__name__)
1773
+ quat_from_matrix_forward = getkernel(quat_from_matrix_forward, suffix=dtype.__name__)
1774
+
1775
+ if register_kernels:
1776
+ return
1777
+
1778
+ m = np.array(
1779
+ [
1780
+ [1.0, 0.0, 0.0, 0.0, 0.5, 0.866, 0.0, -0.866, 0.5],
1781
+ [0.866, 0.0, 0.25, -0.433, 0.5, 0.75, -0.25, -0.866, 0.433],
1782
+ [0.866, -0.433, 0.25, 0.0, 0.5, 0.866, -0.5, -0.75, 0.433],
1783
+ [-1.2, -1.6, -2.3, 0.25, -0.6, -0.33, 3.2, -1.0, -2.2],
1784
+ ]
1785
+ )
1786
+ m = wp.array2d(m, dtype=wptype, device=device, requires_grad=True)
1787
+
1788
+ N = m.shape[0]
1789
+
1790
+ def compute_gradients(kernel, wrt, index):
1791
+ loss = wp.zeros(1, dtype=wptype, device=device, requires_grad=True)
1792
+ tape = wp.Tape()
1793
+
1794
+ with tape:
1795
+ wp.launch(kernel=kernel, dim=N, inputs=[m, loss, index], device=device)
1796
+
1797
+ tape.backward(loss)
1798
+
1799
+ gradients = 1.0 * tape.gradients[wrt].numpy()
1800
+ tape.zero()
1801
+
1802
+ return loss.numpy()[0], gradients
1803
+
1804
+ # gather gradients from builtin adjoints
1805
+ cmpx, gradients_x = compute_gradients(quat_from_matrix, m, 0)
1806
+ cmpy, gradients_y = compute_gradients(quat_from_matrix, m, 1)
1807
+ cmpz, gradients_z = compute_gradients(quat_from_matrix, m, 2)
1808
+ cmpw, gradients_w = compute_gradients(quat_from_matrix, m, 3)
1809
+
1810
+ # gather gradients from autodiff
1811
+ cmpx_auto, gradients_x_auto = compute_gradients(quat_from_matrix_forward, m, 0)
1812
+ cmpy_auto, gradients_y_auto = compute_gradients(quat_from_matrix_forward, m, 1)
1813
+ cmpz_auto, gradients_z_auto = compute_gradients(quat_from_matrix_forward, m, 2)
1814
+ cmpw_auto, gradients_w_auto = compute_gradients(quat_from_matrix_forward, m, 3)
1815
+
1816
+ # compare
1817
+ eps = 1.0e6
1818
+
1819
+ eps = {
1820
+ np.float16: 2.0e-2,
1821
+ np.float32: 1.0e-5,
1822
+ np.float64: 1.0e-8,
1823
+ }.get(dtype, 0)
1824
+
1825
+ assert_np_equal(cmpx, cmpx_auto, tol=eps)
1826
+ assert_np_equal(cmpy, cmpy_auto, tol=eps)
1827
+ assert_np_equal(cmpz, cmpz_auto, tol=eps)
1828
+ assert_np_equal(cmpw, cmpw_auto, tol=eps)
1829
+
1830
+ assert_np_equal(gradients_x, gradients_x_auto, tol=eps)
1831
+ assert_np_equal(gradients_y, gradients_y_auto, tol=eps)
1832
+ assert_np_equal(gradients_z, gradients_z_auto, tol=eps)
1833
+ assert_np_equal(gradients_w, gradients_w_auto, tol=eps)
1834
+
1835
+
1836
+ def test_quat_identity(test, device, dtype, register_kernels=False):
1837
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1838
+
1839
+ def quat_identity_test(output: wp.array(dtype=wptype)):
1840
+ q = wp.quat_identity(dtype=wptype)
1841
+ output[0] = q[0]
1842
+ output[1] = q[1]
1843
+ output[2] = q[2]
1844
+ output[3] = q[3]
1845
+
1846
+ def quat_identity_test_default(output: wp.array(dtype=wp.float32)):
1847
+ q = wp.quat_identity()
1848
+ output[0] = q[0]
1849
+ output[1] = q[1]
1850
+ output[2] = q[2]
1851
+ output[3] = q[3]
1852
+
1853
+ quat_identity_kernel = getkernel(quat_identity_test, suffix=dtype.__name__)
1854
+ quat_identity_default_kernel = getkernel(quat_identity_test_default, suffix=np.float32.__name__)
1855
+
1856
+ if register_kernels:
1857
+ return
1858
+
1859
+ output = wp.zeros(4, dtype=wptype, device=device)
1860
+ wp.launch(quat_identity_kernel, dim=1, inputs=[], outputs=[output], device=device)
1861
+ expected = np.zeros_like(output.numpy())
1862
+ expected[3] = 1
1863
+ assert_np_equal(output.numpy(), expected)
1864
+
1865
+ # let's just test that it defaults to float32:
1866
+ output = wp.zeros(4, dtype=wp.float32, device=device)
1867
+ wp.launch(quat_identity_default_kernel, dim=1, inputs=[], outputs=[output], device=device)
1868
+ expected = np.zeros_like(output.numpy())
1869
+ expected[3] = 1
1870
+ assert_np_equal(output.numpy(), expected)
1871
+
1872
+
1873
+ ############################################################
1874
+
1875
+
1876
+ def test_quat_euler_conversion(test, device, dtype, register_kernels=False):
1877
+ rng = np.random.default_rng(123)
1878
+ N = 3
1879
+
1880
+ rpy_arr = rng.uniform(low=-np.pi, high=np.pi, size=(N, 3))
1881
+
1882
+ quats_from_euler = [list(wp.sim.quat_from_euler(wp.vec3(*rpy), 0, 1, 2)) for rpy in rpy_arr]
1883
+ quats_from_rpy = [list(wp.quat_rpy(rpy[0], rpy[1], rpy[2])) for rpy in rpy_arr]
1884
+
1885
+ assert_np_equal(np.array(quats_from_euler), np.array(quats_from_rpy), tol=1e-4)
1886
+
1887
+
1888
def test_anon_type_instance(test, device, dtype, register_kernels=False):
    """Exercise anonymous wp.quaternion()/wp.vector() construction inside a kernel.

    Checks the forward result (each output is twice the matching input) and the
    adjoint (selecting output i backpropagates a gradient of 2 onto input i only).
    """
    rng = np.random.default_rng(123)
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def quat_create_test(src: wp.array(dtype=wptype), dst: wp.array(dtype=wptype)):
        # construct from four scalar components
        qa = wp.quaternion(src[0], src[1], src[2], src[3])
        dst[0] = wptype(2) * qa[0]
        dst[1] = wptype(2) * qa[1]
        dst[2] = wptype(2) * qa[2]
        dst[3] = wptype(2) * qa[3]

        # construct from an anonymous vector plus a scalar
        qb = wp.quaternion(wp.vector(src[4], src[5], src[6]), src[7])
        dst[4] = wptype(2) * qb[0]
        dst[5] = wptype(2) * qb[1]
        dst[6] = wptype(2) * qb[2]
        dst[7] = wptype(2) * qb[3]

    quat_create_kernel = getkernel(quat_create_test, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel(wptype)

    if register_kernels:
        return

    values = wp.array(rng.standard_normal(size=8).astype(dtype), requires_grad=True, device=device)
    doubled = wp.zeros(8, dtype=wptype, requires_grad=True, device=device)
    wp.launch(quat_create_kernel, dim=1, inputs=[values], outputs=[doubled], device=device)
    assert_np_equal(doubled.numpy(), 2 * values.numpy())

    for idx in range(len(values)):
        selected = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        tape = wp.Tape()
        with tape:
            wp.launch(quat_create_kernel, dim=1, inputs=[values], outputs=[doubled], device=device)
            wp.launch(output_select_kernel, dim=1, inputs=[doubled, idx], outputs=[selected], device=device)
        tape.backward(loss=selected)
        expected_grads = np.zeros(len(values))
        expected_grads[idx] = 2
        assert_np_equal(tape.gradients[values].numpy(), expected_grads)
        tape.zero()
1929
+
1930
+
1931
# Same as above but with a default (float) type
# which tests some different code paths that
# need to ensure types are correctly canonicalized
# during codegen
@wp.kernel
def test_constructor_default():
    # default constructor: all four components are zero
    qzero = wp.quat()
    wp.expect_eq(qzero[0], 0.0)
    wp.expect_eq(qzero[1], 0.0)
    wp.expect_eq(qzero[2], 0.0)
    wp.expect_eq(qzero[3], 0.0)

    # component constructor: arguments stored in order
    qval = wp.quat(1.0, 2.0, 3.0, 4.0)
    wp.expect_eq(qval[0], 1.0)
    wp.expect_eq(qval[1], 2.0)
    wp.expect_eq(qval[2], 3.0)
    wp.expect_eq(qval[3], 4.0)

    # identity: zero for the first three components, one for the last
    qeye = wp.quat_identity()
    wp.expect_eq(qeye[0], 0.0)
    wp.expect_eq(qeye[1], 0.0)
    wp.expect_eq(qeye[2], 0.0)
    wp.expect_eq(qeye[3], 1.0)
1954
+
1955
+
1956
def test_py_arithmetic_ops(test, device, dtype):
    """Check Python-side (host) arithmetic operators on quaternion instances."""
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def make_quat(*args):
        if wptype not in wp.types.int_types:
            return args
        # Cast to the correct integer type to simulate wrapping.
        return tuple(wptype._type_(x).value for x in args)

    quat_cls = wp.types.quaternion(wptype)

    q = quat_cls(1, -2, 3, -4)
    test.assertSequenceEqual(+q, make_quat(1, -2, 3, -4))
    test.assertSequenceEqual(-q, make_quat(-1, 2, -3, 4))
    test.assertSequenceEqual(q + quat_cls(5, 5, 5, 5), make_quat(6, 3, 8, 1))
    test.assertSequenceEqual(q - quat_cls(5, 5, 5, 5), make_quat(-4, -7, -2, -9))

    q = quat_cls(2, 4, 6, 8)
    test.assertSequenceEqual(q * wptype(2), make_quat(4, 8, 12, 16))
    test.assertSequenceEqual(wptype(2) * q, make_quat(4, 8, 12, 16))
    test.assertSequenceEqual(q / wptype(2), make_quat(1, 2, 3, 4))
    test.assertSequenceEqual(wptype(24) / q, make_quat(12, 6, 4, 3))
1979
+
1980
+
1981
# Device list used when registering the per-device tests below.
devices = get_test_devices()
1982
+
1983
+
1984
class TestQuat(unittest.TestCase):
    """Empty container case; quaternion tests are attached to it dynamically below."""
1986
+
1987
+
1988
add_kernel_test(TestQuat, test_constructor_default, dim=1, devices=devices)

# Tests that need kernel registration, instantiated once per floating-point dtype.
# Each registered test is named "<function name>_<dtype name>".
_kernel_tests = [
    test_constructors,
    test_casting_constructors,
    test_anon_type_instance,
    test_inverse,
    test_quat_identity,
    test_dotproduct,
    test_length,
    test_normalize,
    test_addition,
    test_subtraction,
    test_scalar_multiplication,
    test_scalar_division,
    test_quat_multiplication,
    test_indexing,
    test_quat_lerp,
    test_quat_to_axis_angle_grad,
    test_slerp_grad,
    test_quat_rpy_grad,
    test_quat_from_matrix,
    test_quat_rotate,
    test_quat_to_matrix,
    test_quat_euler_conversion,
]

for dtype in np_float_types:
    for func in _kernel_tests:
        add_function_test_register_kernel(
            TestQuat, f"{func.__name__}_{dtype.__name__}", func, devices=devices, dtype=dtype
        )
    # Host-side arithmetic runs in Python only, so no device parameterization.
    add_function_test(
        TestQuat, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype
    )
2080
+
2081
+
2082
+ if __name__ == "__main__":
2083
+ wp.build.clear_kernel_cache()
2084
+ unittest.main(verbosity=2)