warp-lang 1.0.0b2__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (269)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.so +0 -0
  57. warp/bin/warp.so +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/fem/field/discrete_field.py +0 -80
  257. warp/fem/space/nodal_function_space.py +0 -233
  258. warp/tests/test_all.py +0 -223
  259. warp/tests/test_array_scan.py +0 -60
  260. warp/tests/test_base.py +0 -208
  261. warp/tests/test_unresolved_func.py +0 -7
  262. warp/tests/test_unresolved_symbol.py +0 -7
  263. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  264. warp_lang-1.0.0b2.dist-info/RECORD +0 -378
  265. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  266. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  267. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  268. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  269. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/tests/test_quat.py CHANGED
@@ -1,8 +1,16 @@
1
+ # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import unittest
9
+
1
10
  import numpy as np
2
- import os
3
11
 
4
12
  import warp as wp
5
- from warp.tests.test_base import *
13
+ from warp.tests.unittest_utils import *
6
14
 
7
15
  wp.init()
8
16
 
@@ -12,10 +20,9 @@ kernel_cache = dict()
12
20
 
13
21
 
14
22
  def getkernel(func, suffix=""):
15
- module = wp.get_module(func.__module__)
16
23
  key = func.__name__ + "_" + suffix
17
24
  if key not in kernel_cache:
18
- kernel_cache[key] = wp.Kernel(func=func, key=key, module=module)
25
+ kernel_cache[key] = wp.Kernel(func=func, key=key)
19
26
  return kernel_cache[key]
20
27
 
21
28
 
@@ -34,7 +41,7 @@ def get_select_kernel(dtype):
34
41
 
35
42
 
36
43
  def test_constructors(test, device, dtype, register_kernels=False):
37
- np.random.seed(123)
44
+ rng = np.random.default_rng(123)
38
45
 
39
46
  tol = {
40
47
  np.float16: 5.0e-3,
@@ -77,7 +84,7 @@ def test_constructors(test, device, dtype, register_kernels=False):
77
84
  if register_kernels:
78
85
  return
79
86
 
80
- input = wp.array(np.random.randn(4).astype(dtype), requires_grad=True, device=device)
87
+ input = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
81
88
  output = wp.zeros_like(input)
82
89
  wp.launch(kernel, dim=1, inputs=[input], outputs=[output], device=device)
83
90
 
@@ -95,7 +102,7 @@ def test_constructors(test, device, dtype, register_kernels=False):
95
102
  assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
96
103
  tape.zero()
97
104
 
98
- input = wp.array(np.random.randn(4).astype(dtype), requires_grad=True, device=device)
105
+ input = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
99
106
  output = wp.zeros_like(input)
100
107
  wp.launch(vec_kernel, dim=1, inputs=[input], outputs=[output], device=device)
101
108
 
@@ -114,8 +121,114 @@ def test_constructors(test, device, dtype, register_kernels=False):
114
121
  tape.zero()
115
122
 
116
123
 
124
+ def test_casting_constructors(test, device, dtype, register_kernels=False):
125
+ np_type = np.dtype(dtype)
126
+ wp_type = wp.types.np_dtype_to_warp_type[np_type]
127
+ quat = wp.types.quaternion(dtype=wp_type)
128
+
129
+ np16 = np.dtype(np.float16)
130
+ wp16 = wp.types.np_dtype_to_warp_type[np16]
131
+
132
+ np32 = np.dtype(np.float32)
133
+ wp32 = wp.types.np_dtype_to_warp_type[np32]
134
+
135
+ np64 = np.dtype(np.float64)
136
+ wp64 = wp.types.np_dtype_to_warp_type[np64]
137
+
138
+ def cast_float16(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp16, ndim=2)):
139
+ tid = wp.tid()
140
+
141
+ q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
142
+ q2 = wp.quaternion(q1, dtype=wp16)
143
+
144
+ b[tid, 0] = q2[0]
145
+ b[tid, 1] = q2[1]
146
+ b[tid, 2] = q2[2]
147
+ b[tid, 3] = q2[3]
148
+
149
+ def cast_float32(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp32, ndim=2)):
150
+ tid = wp.tid()
151
+
152
+ q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
153
+ q2 = wp.quaternion(q1, dtype=wp32)
154
+
155
+ b[tid, 0] = q2[0]
156
+ b[tid, 1] = q2[1]
157
+ b[tid, 2] = q2[2]
158
+ b[tid, 3] = q2[3]
159
+
160
+ def cast_float64(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp64, ndim=2)):
161
+ tid = wp.tid()
162
+
163
+ q1 = quat(a[tid, 0], a[tid, 1], a[tid, 2], a[tid, 3])
164
+ q2 = wp.quaternion(q1, dtype=wp64)
165
+
166
+ b[tid, 0] = q2[0]
167
+ b[tid, 1] = q2[1]
168
+ b[tid, 2] = q2[2]
169
+ b[tid, 3] = q2[3]
170
+
171
+ kernel_16 = getkernel(cast_float16, suffix=dtype.__name__)
172
+ kernel_32 = getkernel(cast_float32, suffix=dtype.__name__)
173
+ kernel_64 = getkernel(cast_float64, suffix=dtype.__name__)
174
+
175
+ if register_kernels:
176
+ return
177
+
178
+ # check casting to float 16
179
+ a = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
180
+ b = wp.array(np.zeros((1, 4), dtype=np16), dtype=wp16, requires_grad=True, device=device)
181
+ b_result = np.ones((1, 4), dtype=np16)
182
+ b_grad = wp.array(np.ones((1, 4), dtype=np16), dtype=wp16, device=device)
183
+ a_grad = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, device=device)
184
+
185
+ tape = wp.Tape()
186
+ with tape:
187
+ wp.launch(kernel=kernel_16, dim=1, inputs=[a, b], device=device)
188
+
189
+ tape.backward(grads={b: b_grad})
190
+ out = tape.gradients[a].numpy()
191
+
192
+ assert_np_equal(b.numpy(), b_result)
193
+ assert_np_equal(out, a_grad.numpy())
194
+
195
+ # check casting to float 32
196
+ a = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
197
+ b = wp.array(np.zeros((1, 4), dtype=np32), dtype=wp32, requires_grad=True, device=device)
198
+ b_result = np.ones((1, 4), dtype=np32)
199
+ b_grad = wp.array(np.ones((1, 4), dtype=np32), dtype=wp32, device=device)
200
+ a_grad = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, device=device)
201
+
202
+ tape = wp.Tape()
203
+ with tape:
204
+ wp.launch(kernel=kernel_32, dim=1, inputs=[a, b], device=device)
205
+
206
+ tape.backward(grads={b: b_grad})
207
+ out = tape.gradients[a].numpy()
208
+
209
+ assert_np_equal(b.numpy(), b_result)
210
+ assert_np_equal(out, a_grad.numpy())
211
+
212
+ # check casting to float 64
213
+ a = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
214
+ b = wp.array(np.zeros((1, 4), dtype=np64), dtype=wp64, requires_grad=True, device=device)
215
+ b_result = np.ones((1, 4), dtype=np64)
216
+ b_grad = wp.array(np.ones((1, 4), dtype=np64), dtype=wp64, device=device)
217
+ a_grad = wp.array(np.ones((1, 4), dtype=np_type), dtype=wp_type, device=device)
218
+
219
+ tape = wp.Tape()
220
+ with tape:
221
+ wp.launch(kernel=kernel_64, dim=1, inputs=[a, b], device=device)
222
+
223
+ tape.backward(grads={b: b_grad})
224
+ out = tape.gradients[a].numpy()
225
+
226
+ assert_np_equal(b.numpy(), b_result)
227
+ assert_np_equal(out, a_grad.numpy())
228
+
229
+
117
230
  def test_inverse(test, device, dtype, register_kernels=False):
118
- np.random.seed(123)
231
+ rng = np.random.default_rng(123)
119
232
 
120
233
  tol = {
121
234
  np.float16: 2.0e-3,
@@ -150,7 +263,7 @@ def test_inverse(test, device, dtype, register_kernels=False):
150
263
  if register_kernels:
151
264
  return
152
265
 
153
- input = wp.array(np.random.randn(4).astype(dtype), requires_grad=True, device=device)
266
+ input = wp.array(rng.standard_normal(size=4).astype(dtype), requires_grad=True, device=device)
154
267
  shouldbeidentity = wp.array(np.zeros((1, 4)), dtype=quat, requires_grad=True, device=device)
155
268
  output = wp.zeros_like(input)
156
269
  wp.launch(kernel, dim=1, inputs=[input], outputs=[shouldbeidentity, output], device=device)
@@ -171,7 +284,7 @@ def test_inverse(test, device, dtype, register_kernels=False):
171
284
 
172
285
 
173
286
  def test_dotproduct(test, device, dtype, register_kernels=False):
174
- np.random.seed(123)
287
+ rng = np.random.default_rng(123)
175
288
 
176
289
  tol = {
177
290
  np.float16: 1.0e-2,
@@ -193,8 +306,8 @@ def test_dotproduct(test, device, dtype, register_kernels=False):
193
306
  if register_kernels:
194
307
  return
195
308
 
196
- s = wp.array(np.random.randn(4).astype(dtype), dtype=quat, requires_grad=True, device=device)
197
- v = wp.array(np.random.randn(4).astype(dtype), dtype=quat, requires_grad=True, device=device)
309
+ s = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
310
+ v = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
198
311
  dot = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
199
312
 
200
313
  tape = wp.Tape()
@@ -223,7 +336,7 @@ def test_dotproduct(test, device, dtype, register_kernels=False):
223
336
 
224
337
 
225
338
  def test_length(test, device, dtype, register_kernels=False):
226
- np.random.seed(123)
339
+ rng = np.random.default_rng(123)
227
340
 
228
341
  tol = {
229
342
  np.float16: 5.0e-3,
@@ -247,7 +360,7 @@ def test_length(test, device, dtype, register_kernels=False):
247
360
  if register_kernels:
248
361
  return
249
362
 
250
- q = wp.array(np.random.randn(4).astype(dtype), dtype=quat, requires_grad=True, device=device)
363
+ q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
251
364
  l = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
252
365
  l2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
253
366
 
@@ -280,7 +393,7 @@ def test_length(test, device, dtype, register_kernels=False):
280
393
 
281
394
 
282
395
  def test_normalize(test, device, dtype, register_kernels=False):
283
- np.random.seed(123)
396
+ rng = np.random.default_rng(123)
284
397
 
285
398
  tol = {
286
399
  np.float16: 5.0e-3,
@@ -327,7 +440,7 @@ def test_normalize(test, device, dtype, register_kernels=False):
327
440
 
328
441
  # I've already tested the things I'm using in check_normalize_alt, so I'll just
329
442
  # make sure the two are giving the same results/gradients
330
- q = wp.array(np.random.randn(4).astype(dtype), dtype=quat, requires_grad=True, device=device)
443
+ q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
331
444
 
332
445
  n0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
333
446
  n1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
@@ -381,7 +494,7 @@ def test_normalize(test, device, dtype, register_kernels=False):
381
494
 
382
495
 
383
496
  def test_addition(test, device, dtype, register_kernels=False):
384
- np.random.seed(123)
497
+ rng = np.random.default_rng(123)
385
498
 
386
499
  tol = {
387
500
  np.float16: 5.0e-3,
@@ -412,8 +525,8 @@ def test_addition(test, device, dtype, register_kernels=False):
412
525
  if register_kernels:
413
526
  return
414
527
 
415
- q = wp.array(np.random.randn(4).astype(dtype), dtype=quat, requires_grad=True, device=device)
416
- v = wp.array(np.random.randn(4).astype(dtype), dtype=quat, requires_grad=True, device=device)
528
+ q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
529
+ v = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
417
530
 
418
531
  r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
419
532
  r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
@@ -453,7 +566,7 @@ def test_addition(test, device, dtype, register_kernels=False):
453
566
 
454
567
 
455
568
  def test_subtraction(test, device, dtype, register_kernels=False):
456
- np.random.seed(123)
569
+ rng = np.random.default_rng(123)
457
570
 
458
571
  tol = {
459
572
  np.float16: 5.0e-3,
@@ -484,8 +597,8 @@ def test_subtraction(test, device, dtype, register_kernels=False):
484
597
  if register_kernels:
485
598
  return
486
599
 
487
- q = wp.array(np.random.randn(4).astype(dtype), dtype=quat, requires_grad=True, device=device)
488
- v = wp.array(np.random.randn(4).astype(dtype), dtype=quat, requires_grad=True, device=device)
600
+ q = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
601
+ v = wp.array(rng.standard_normal(size=4).astype(dtype), dtype=quat, requires_grad=True, device=device)
489
602
 
490
603
  r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
491
604
  r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
@@ -526,7 +639,7 @@ def test_subtraction(test, device, dtype, register_kernels=False):
526
639
 
527
640
 
528
641
  def test_scalar_multiplication(test, device, dtype, register_kernels=False):
529
- np.random.seed(123)
642
+ rng = np.random.default_rng(123)
530
643
 
531
644
  tol = {
532
645
  np.float16: 5.0e-3,
@@ -568,8 +681,8 @@ def test_scalar_multiplication(test, device, dtype, register_kernels=False):
568
681
  if register_kernels:
569
682
  return
570
683
 
571
- s = wp.array(np.random.randn(1).astype(dtype), requires_grad=True, device=device)
572
- q = wp.array(np.random.randn(1, 4).astype(dtype), dtype=quat, requires_grad=True, device=device)
684
+ s = wp.array(rng.standard_normal(size=1).astype(dtype), requires_grad=True, device=device)
685
+ q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
573
686
 
574
687
  l0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
575
688
  l1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
@@ -624,7 +737,7 @@ def test_scalar_multiplication(test, device, dtype, register_kernels=False):
624
737
 
625
738
 
626
739
  def test_scalar_division(test, device, dtype, register_kernels=False):
627
- np.random.seed(123)
740
+ rng = np.random.default_rng(123)
628
741
 
629
742
  tol = {
630
743
  np.float16: 1.0e-3,
@@ -656,8 +769,8 @@ def test_scalar_division(test, device, dtype, register_kernels=False):
656
769
  if register_kernels:
657
770
  return
658
771
 
659
- s = wp.array(np.random.randn(1).astype(dtype), requires_grad=True, device=device)
660
- q = wp.array(np.random.randn(1, 4).astype(dtype), dtype=quat, requires_grad=True, device=device)
772
+ s = wp.array(rng.standard_normal(size=1).astype(dtype), requires_grad=True, device=device)
773
+ q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
661
774
 
662
775
  r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
663
776
  r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
@@ -697,7 +810,7 @@ def test_scalar_division(test, device, dtype, register_kernels=False):
697
810
 
698
811
 
699
812
  def test_quat_multiplication(test, device, dtype, register_kernels=False):
700
- np.random.seed(123)
813
+ rng = np.random.default_rng(123)
701
814
 
702
815
  tol = {
703
816
  np.float16: 1.0e-2,
@@ -729,8 +842,8 @@ def test_quat_multiplication(test, device, dtype, register_kernels=False):
729
842
  if register_kernels:
730
843
  return
731
844
 
732
- s = wp.array(np.random.randn(1, 4).astype(dtype), dtype=quat, requires_grad=True, device=device)
733
- q = wp.array(np.random.randn(1, 4).astype(dtype), dtype=quat, requires_grad=True, device=device)
845
+ s = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
846
+ q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
734
847
 
735
848
  r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
736
849
  r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
@@ -801,7 +914,7 @@ def test_quat_multiplication(test, device, dtype, register_kernels=False):
801
914
 
802
915
 
803
916
  def test_indexing(test, device, dtype, register_kernels=False):
804
- np.random.seed(123)
917
+ rng = np.random.default_rng(123)
805
918
 
806
919
  tol = {
807
920
  np.float16: 5.0e-3,
@@ -830,7 +943,7 @@ def test_indexing(test, device, dtype, register_kernels=False):
830
943
  if register_kernels:
831
944
  return
832
945
 
833
- q = wp.array(np.random.randn(1, 4).astype(dtype), dtype=quat, requires_grad=True, device=device)
946
+ q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
834
947
  r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
835
948
  r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
836
949
  r2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
@@ -855,7 +968,7 @@ def test_indexing(test, device, dtype, register_kernels=False):
855
968
 
856
969
 
857
970
  def test_quat_lerp(test, device, dtype, register_kernels=False):
858
- np.random.seed(123)
971
+ rng = np.random.default_rng(123)
859
972
 
860
973
  tol = {
861
974
  np.float16: 1.0e-2,
@@ -888,9 +1001,9 @@ def test_quat_lerp(test, device, dtype, register_kernels=False):
888
1001
  if register_kernels:
889
1002
  return
890
1003
 
891
- s = wp.array(np.random.randn(1, 4).astype(dtype), dtype=quat, requires_grad=True, device=device)
892
- q = wp.array(np.random.randn(1, 4).astype(dtype), dtype=quat, requires_grad=True, device=device)
893
- t = wp.array(np.random.uniform(size=1).astype(dtype), dtype=wptype, requires_grad=True, device=device)
1004
+ s = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
1005
+ q = wp.array(rng.standard_normal(size=(1, 4)).astype(dtype), dtype=quat, requires_grad=True, device=device)
1006
+ t = wp.array(rng.uniform(size=1).astype(dtype), dtype=wptype, requires_grad=True, device=device)
894
1007
 
895
1008
  r0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
896
1009
  r1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
@@ -936,7 +1049,7 @@ def test_quat_lerp(test, device, dtype, register_kernels=False):
936
1049
 
937
1050
 
938
1051
  def test_quat_rotate(test, device, dtype, register_kernels=False):
939
- np.random.seed(123)
1052
+ rng = np.random.default_rng(123)
940
1053
 
941
1054
  tol = {
942
1055
  np.float16: 1.0e-2,
@@ -983,10 +1096,10 @@ def test_quat_rotate(test, device, dtype, register_kernels=False):
983
1096
  if register_kernels:
984
1097
  return
985
1098
 
986
- q = np.random.randn(1, 4)
1099
+ q = rng.standard_normal(size=(1, 4))
987
1100
  q /= np.linalg.norm(q)
988
1101
  q = wp.array(q.astype(dtype), dtype=quat, requires_grad=True, device=device)
989
- v = wp.array(0.5 * np.random.randn(1, 3).astype(dtype), dtype=vec3, requires_grad=True, device=device)
1102
+ v = wp.array(0.5 * rng.standard_normal(size=(1, 3)).astype(dtype), dtype=vec3, requires_grad=True, device=device)
990
1103
 
991
1104
  # test values against the manually computed result:
992
1105
  outputs = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
@@ -1062,7 +1175,7 @@ def test_quat_rotate(test, device, dtype, register_kernels=False):
1062
1175
 
1063
1176
 
1064
1177
  def test_quat_to_matrix(test, device, dtype, register_kernels=False):
1065
- np.random.seed(123)
1178
+ rng = np.random.default_rng(123)
1066
1179
 
1067
1180
  tol = {
1068
1181
  np.float16: 1.0e-2,
@@ -1123,7 +1236,7 @@ def test_quat_to_matrix(test, device, dtype, register_kernels=False):
1123
1236
  if register_kernels:
1124
1237
  return
1125
1238
 
1126
- q = np.random.randn(1, 4)
1239
+ q = rng.standard_normal(size=(1, 4))
1127
1240
  q /= np.linalg.norm(q)
1128
1241
  q = wp.array(q.astype(dtype), dtype=quat, requires_grad=True, device=device)
1129
1242
 
@@ -1186,8 +1299,8 @@ def test_quat_to_matrix(test, device, dtype, register_kernels=False):
1186
1299
 
1187
1300
 
1188
1301
  def test_slerp_grad(test, device, dtype, register_kernels=False):
1302
+ rng = np.random.default_rng(123)
1189
1303
  seed = 42
1190
- np.random.seed(seed)
1191
1304
 
1192
1305
  wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1193
1306
  vec3 = wp.types.vector(3, wptype)
@@ -1252,7 +1365,7 @@ def test_slerp_grad(test, device, dtype, register_kernels=False):
1252
1365
  wp.launch(kernel=quat_sampler, dim=N, inputs=[seed, q0], device=device)
1253
1366
  wp.launch(kernel=quat_sampler, dim=N, inputs=[seed + 1, q1], device=device)
1254
1367
 
1255
- t = np.random.uniform(0.0, 1.0, N)
1368
+ t = rng.uniform(low=0.0, high=1.0, size=N)
1256
1369
  t = wp.array(t, dtype=wptype, device=device, requires_grad=True)
1257
1370
 
1258
1371
  def compute_gradients(kernel, wrt, index):
@@ -1348,8 +1461,8 @@ def test_slerp_grad(test, device, dtype, register_kernels=False):
1348
1461
 
1349
1462
 
1350
1463
  def test_quat_to_axis_angle_grad(test, device, dtype, register_kernels=False):
1464
+ rng = np.random.default_rng(123)
1351
1465
  seed = 42
1352
- rng = np.random.default_rng(seed)
1353
1466
  num_rand = 50
1354
1467
 
1355
1468
  wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
@@ -1481,8 +1594,7 @@ def test_quat_to_axis_angle_grad(test, device, dtype, register_kernels=False):
1481
1594
 
1482
1595
 
1483
1596
  def test_quat_rpy_grad(test, device, dtype, register_kernels=False):
1484
- seed = 42
1485
- np.random.seed(seed)
1597
+ rng = np.random.default_rng(123)
1486
1598
  N = 3
1487
1599
 
1488
1600
  wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
@@ -1531,7 +1643,7 @@ def test_quat_rpy_grad(test, device, dtype, register_kernels=False):
1531
1643
  if register_kernels:
1532
1644
  return
1533
1645
 
1534
- rpy_arr = np.random.uniform(-np.pi, np.pi, (N, 3))
1646
+ rpy_arr = rng.uniform(low=-np.pi, high=np.pi, size=(N, 3))
1535
1647
  rpy_arr = wp.array(rpy_arr, dtype=vec3, device=device, requires_grad=True)
1536
1648
 
1537
1649
  def compute_gradients(kernel, wrt, index):
@@ -1760,6 +1872,7 @@ def test_quat_identity(test, device, dtype, register_kernels=False):
1760
1872
 
1761
1873
 
1762
1874
  def test_anon_type_instance(test, device, dtype, register_kernels=False):
1875
+ rng = np.random.default_rng(123)
1763
1876
  wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1764
1877
 
1765
1878
  def quat_create_test(input: wp.array(dtype=wptype), output: wp.array(dtype=wptype)):
@@ -1783,7 +1896,7 @@ def test_anon_type_instance(test, device, dtype, register_kernels=False):
1783
1896
  if register_kernels:
1784
1897
  return
1785
1898
 
1786
- input = wp.array(np.random.randn(8).astype(dtype), requires_grad=True, device=device)
1899
+ input = wp.array(rng.standard_normal(size=8).astype(dtype), requires_grad=True, device=device)
1787
1900
  output = wp.zeros(8, dtype=wptype, requires_grad=True, device=device)
1788
1901
  wp.launch(quat_create_kernel, dim=1, inputs=[input], outputs=[output], device=device)
1789
1902
  assert_np_equal(output.numpy(), 2 * input.numpy())
@@ -1826,92 +1939,125 @@ def test_constructor_default():
1826
1939
  wp.expect_eq(qeye[3], 1.0)
1827
1940
 
1828
1941
 
1829
- def register(parent):
1830
- devices = get_test_devices()
1942
+ def test_py_arithmetic_ops(test, device, dtype):
1943
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1831
1944
 
1832
- class TestQuat(parent):
1833
- pass
1945
+ def make_quat(*args):
1946
+ if wptype in wp.types.int_types:
1947
+ # Cast to the correct integer type to simulate wrapping.
1948
+ return tuple(wptype._type_(x).value for x in args)
1834
1949
 
1835
- add_kernel_test(TestQuat, test_constructor_default, dim=1, devices=devices)
1950
+ return args
1951
+
1952
+ quat_cls = wp.types.quaternion(wptype)
1953
+
1954
+ v = quat_cls(1, -2, 3, -4)
1955
+ test.assertSequenceEqual(+v, make_quat(1, -2, 3, -4))
1956
+ test.assertSequenceEqual(-v, make_quat(-1, 2, -3, 4))
1957
+ test.assertSequenceEqual(v + quat_cls(5, 5, 5, 5), make_quat(6, 3, 8, 1))
1958
+ test.assertSequenceEqual(v - quat_cls(5, 5, 5, 5), make_quat(-4, -7, -2, -9))
1959
+
1960
+ v = quat_cls(2, 4, 6, 8)
1961
+ test.assertSequenceEqual(v * wptype(2), make_quat(4, 8, 12, 16))
1962
+ test.assertSequenceEqual(wptype(2) * v, make_quat(4, 8, 12, 16))
1963
+ test.assertSequenceEqual(v / wptype(2), make_quat(1, 2, 3, 4))
1964
+ test.assertSequenceEqual(wptype(24) / v, make_quat(12, 6, 4, 3))
1965
+
1966
+
1967
+ devices = get_test_devices()
1968
+
1969
+
1970
+ class TestQuat(unittest.TestCase):
1971
+ pass
1836
1972
 
1837
- for dtype in np_float_types:
1838
- add_function_test_register_kernel(
1839
- TestQuat, f"test_constructors_{dtype.__name__}", test_constructors, devices=devices, dtype=dtype
1840
- )
1841
- add_function_test_register_kernel(
1842
- TestQuat, f"test_anon_type_instance_{dtype.__name__}", test_anon_type_instance, devices=devices, dtype=dtype
1843
- )
1844
- add_function_test_register_kernel(
1845
- TestQuat, f"test_inverse_{dtype.__name__}", test_inverse, devices=devices, dtype=dtype
1846
- )
1847
- add_function_test_register_kernel(
1848
- TestQuat, f"test_quat_identity_{dtype.__name__}", test_quat_identity, devices=devices, dtype=dtype
1849
- )
1850
- add_function_test_register_kernel(
1851
- TestQuat, f"test_dotproduct_{dtype.__name__}", test_dotproduct, devices=devices, dtype=dtype
1852
- )
1853
- add_function_test_register_kernel(
1854
- TestQuat, f"test_length_{dtype.__name__}", test_length, devices=devices, dtype=dtype
1855
- )
1856
- add_function_test_register_kernel(
1857
- TestQuat, f"test_normalize_{dtype.__name__}", test_normalize, devices=devices, dtype=dtype
1858
- )
1859
- add_function_test_register_kernel(
1860
- TestQuat, f"test_addition_{dtype.__name__}", test_addition, devices=devices, dtype=dtype
1861
- )
1862
- add_function_test_register_kernel(
1863
- TestQuat, f"test_subtraction_{dtype.__name__}", test_subtraction, devices=devices, dtype=dtype
1864
- )
1865
- add_function_test_register_kernel(
1866
- TestQuat,
1867
- f"test_scalar_multiplication_{dtype.__name__}",
1868
- test_scalar_multiplication,
1869
- devices=devices,
1870
- dtype=dtype,
1871
- )
1872
- add_function_test_register_kernel(
1873
- TestQuat, f"test_scalar_division_{dtype.__name__}", test_scalar_division, devices=devices, dtype=dtype
1874
- )
1875
- add_function_test_register_kernel(
1876
- TestQuat,
1877
- f"test_quat_multiplication_{dtype.__name__}",
1878
- test_quat_multiplication,
1879
- devices=devices,
1880
- dtype=dtype,
1881
- )
1882
- add_function_test_register_kernel(
1883
- TestQuat, f"test_indexing_{dtype.__name__}", test_indexing, devices=devices, dtype=dtype
1884
- )
1885
- add_function_test_register_kernel(
1886
- TestQuat, f"test_quat_lerp_{dtype.__name__}", test_quat_lerp, devices=devices, dtype=dtype
1887
- )
1888
- add_function_test_register_kernel(
1889
- TestQuat,
1890
- f"test_quat_to_axis_angle_grad_{dtype.__name__}",
1891
- test_quat_to_axis_angle_grad,
1892
- devices=devices,
1893
- dtype=dtype,
1894
- )
1895
- add_function_test_register_kernel(
1896
- TestQuat, f"test_slerp_grad_{dtype.__name__}", test_slerp_grad, devices=devices, dtype=dtype
1897
- )
1898
- add_function_test_register_kernel(
1899
- TestQuat, f"test_quat_rpy_grad_{dtype.__name__}", test_quat_rpy_grad, devices=devices, dtype=dtype
1900
- )
1901
- add_function_test_register_kernel(
1902
- TestQuat, f"test_quat_from_matrix_{dtype.__name__}", test_quat_from_matrix, devices=devices, dtype=dtype
1903
- )
1904
- add_function_test_register_kernel(
1905
- TestQuat, f"test_quat_rotate_{dtype.__name__}", test_quat_rotate, devices=devices, dtype=dtype
1906
- )
1907
- add_function_test_register_kernel(
1908
- TestQuat, f"test_quat_to_matrix_{dtype.__name__}", test_quat_to_matrix, devices=devices, dtype=dtype
1909
- )
1910
1973
 
1911
- return TestQuat
1974
+ add_kernel_test(TestQuat, test_constructor_default, dim=1, devices=devices)
1975
+
1976
+ for dtype in np_float_types:
1977
+ add_function_test_register_kernel(
1978
+ TestQuat, f"test_constructors_{dtype.__name__}", test_constructors, devices=devices, dtype=dtype
1979
+ )
1980
+ add_function_test_register_kernel(
1981
+ TestQuat,
1982
+ f"test_casting_constructors_{dtype.__name__}",
1983
+ test_casting_constructors,
1984
+ devices=devices,
1985
+ dtype=dtype,
1986
+ )
1987
+ add_function_test_register_kernel(
1988
+ TestQuat, f"test_anon_type_instance_{dtype.__name__}", test_anon_type_instance, devices=devices, dtype=dtype
1989
+ )
1990
+ add_function_test_register_kernel(
1991
+ TestQuat, f"test_inverse_{dtype.__name__}", test_inverse, devices=devices, dtype=dtype
1992
+ )
1993
+ add_function_test_register_kernel(
1994
+ TestQuat, f"test_quat_identity_{dtype.__name__}", test_quat_identity, devices=devices, dtype=dtype
1995
+ )
1996
+ add_function_test_register_kernel(
1997
+ TestQuat, f"test_dotproduct_{dtype.__name__}", test_dotproduct, devices=devices, dtype=dtype
1998
+ )
1999
+ add_function_test_register_kernel(
2000
+ TestQuat, f"test_length_{dtype.__name__}", test_length, devices=devices, dtype=dtype
2001
+ )
2002
+ add_function_test_register_kernel(
2003
+ TestQuat, f"test_normalize_{dtype.__name__}", test_normalize, devices=devices, dtype=dtype
2004
+ )
2005
+ add_function_test_register_kernel(
2006
+ TestQuat, f"test_addition_{dtype.__name__}", test_addition, devices=devices, dtype=dtype
2007
+ )
2008
+ add_function_test_register_kernel(
2009
+ TestQuat, f"test_subtraction_{dtype.__name__}", test_subtraction, devices=devices, dtype=dtype
2010
+ )
2011
+ add_function_test_register_kernel(
2012
+ TestQuat,
2013
+ f"test_scalar_multiplication_{dtype.__name__}",
2014
+ test_scalar_multiplication,
2015
+ devices=devices,
2016
+ dtype=dtype,
2017
+ )
2018
+ add_function_test_register_kernel(
2019
+ TestQuat, f"test_scalar_division_{dtype.__name__}", test_scalar_division, devices=devices, dtype=dtype
2020
+ )
2021
+ add_function_test_register_kernel(
2022
+ TestQuat,
2023
+ f"test_quat_multiplication_{dtype.__name__}",
2024
+ test_quat_multiplication,
2025
+ devices=devices,
2026
+ dtype=dtype,
2027
+ )
2028
+ add_function_test_register_kernel(
2029
+ TestQuat, f"test_indexing_{dtype.__name__}", test_indexing, devices=devices, dtype=dtype
2030
+ )
2031
+ add_function_test_register_kernel(
2032
+ TestQuat, f"test_quat_lerp_{dtype.__name__}", test_quat_lerp, devices=devices, dtype=dtype
2033
+ )
2034
+ add_function_test_register_kernel(
2035
+ TestQuat,
2036
+ f"test_quat_to_axis_angle_grad_{dtype.__name__}",
2037
+ test_quat_to_axis_angle_grad,
2038
+ devices=devices,
2039
+ dtype=dtype,
2040
+ )
2041
+ add_function_test_register_kernel(
2042
+ TestQuat, f"test_slerp_grad_{dtype.__name__}", test_slerp_grad, devices=devices, dtype=dtype
2043
+ )
2044
+ add_function_test_register_kernel(
2045
+ TestQuat, f"test_quat_rpy_grad_{dtype.__name__}", test_quat_rpy_grad, devices=devices, dtype=dtype
2046
+ )
2047
+ add_function_test_register_kernel(
2048
+ TestQuat, f"test_quat_from_matrix_{dtype.__name__}", test_quat_from_matrix, devices=devices, dtype=dtype
2049
+ )
2050
+ add_function_test_register_kernel(
2051
+ TestQuat, f"test_quat_rotate_{dtype.__name__}", test_quat_rotate, devices=devices, dtype=dtype
2052
+ )
2053
+ add_function_test_register_kernel(
2054
+ TestQuat, f"test_quat_to_matrix_{dtype.__name__}", test_quat_to_matrix, devices=devices, dtype=dtype
2055
+ )
2056
+ add_function_test(
2057
+ TestQuat, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype
2058
+ )
1912
2059
 
1913
2060
 
1914
2061
  if __name__ == "__main__":
1915
2062
  wp.build.clear_kernel_cache()
1916
- c = register(unittest.TestCase)
1917
2063
  unittest.main(verbosity=2)