warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (300) hide show
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/tests/test_fp16.py CHANGED
@@ -1,11 +1,16 @@
1
- import warp as wp
2
- import numpy as np
1
+ # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
3
7
 
4
8
  import unittest
5
9
 
6
- import warp as wp
7
- from warp.tests.test_base import *
10
+ import numpy as np
8
11
 
12
+ import warp as wp
13
+ from warp.tests.unittest_utils import *
9
14
 
10
15
  wp.init()
11
16
 
@@ -42,6 +47,34 @@ def test_fp16_conversion(test, device):
42
47
  assert_np_equal(np_f16, wp_f16.numpy())
43
48
 
44
49
 
50
+ @wp.kernel
51
+ def value_load_store_half(f16_value: wp.float16, f16_array: wp.array(dtype=wp.float16)):
52
+ wp.expect_eq(f16_value, f16_array[0])
53
+
54
+ # check stores
55
+ f16_array[0] = f16_value
56
+
57
+
58
+ def test_fp16_kernel_parameter(test, device):
59
+ """Test the ability to pass in fp16 into kernels as parameters"""
60
+
61
+ s = [1.0, 2.0, 3.0, -3.14159]
62
+
63
+ for test_val in s:
64
+ np_f16 = np.array([test_val], dtype=np.float16)
65
+ wp_f16 = wp.array([test_val], dtype=wp.float16, device=device)
66
+
67
+ wp.launch(value_load_store_half, (1,), inputs=[wp.float16(test_val), wp_f16], device=device)
68
+
69
+ # check that stores worked
70
+ assert_np_equal(np_f16, wp_f16.numpy())
71
+
72
+ # Do the same thing but pass in test_val as a Python float to test automatic conversion
73
+ wp_f16 = wp.array([test_val], dtype=wp.float16, device=device)
74
+ wp.launch(value_load_store_half, (1,), inputs=[test_val, wp_f16], device=device)
75
+ assert_np_equal(np_f16, wp_f16.numpy())
76
+
77
+
45
78
  @wp.kernel
46
79
  def mul_half(input: wp.array(dtype=wp.float16), output: wp.array(dtype=wp.float16)):
47
80
  tid = wp.tid()
@@ -54,11 +87,13 @@ def mul_half(input: wp.array(dtype=wp.float16), output: wp.array(dtype=wp.float1
54
87
 
55
88
 
56
89
  def test_fp16_grad(test, device):
90
+ rng = np.random.default_rng(123)
91
+
57
92
  # checks that gradients are correctly propagated for
58
- # fp16 arrays, even when intermediate calcualtions
93
+ # fp16 arrays, even when intermediate calculations
59
94
  # are performed in e.g.: fp32
60
95
 
61
- s = np.random.rand(15).astype(np.float16)
96
+ s = rng.random(size=15).astype(np.float16)
62
97
 
63
98
  input = wp.array(s, dtype=wp.float16, device=device, requires_grad=True)
64
99
  output = wp.zeros_like(input)
@@ -74,23 +109,22 @@ def test_fp16_grad(test, device):
74
109
  assert_np_equal(input.grad.numpy(), np.ones(len(s)) * 2.0)
75
110
 
76
111
 
77
- def register(parent):
78
- class TestFp16(parent):
79
- pass
112
+ class TestFp16(unittest.TestCase):
113
+ pass
80
114
 
81
- devices = []
82
- if wp.is_cpu_available():
83
- devices.append("cpu")
84
- for cuda_device in wp.get_cuda_devices():
85
- if cuda_device.arch >= 70:
86
- devices.append(cuda_device)
87
115
 
88
- add_function_test(TestFp16, "test_fp16_conversion", test_fp16_conversion, devices=devices)
89
- add_function_test(TestFp16, "test_fp16_grad", test_fp16_grad, devices=devices)
116
+ devices = []
117
+ if wp.is_cpu_available():
118
+ devices.append("cpu")
119
+ for cuda_device in get_unique_cuda_test_devices():
120
+ if cuda_device.arch >= 70:
121
+ devices.append(cuda_device)
90
122
 
91
- return TestFp16
123
+ add_function_test(TestFp16, "test_fp16_conversion", test_fp16_conversion, devices=devices)
124
+ add_function_test(TestFp16, "test_fp16_grad", test_fp16_grad, devices=devices)
125
+ add_function_test(TestFp16, "test_fp16_kernel_parameter", test_fp16_kernel_parameter, devices=devices)
92
126
 
93
127
 
94
128
  if __name__ == "__main__":
95
- c = register(unittest.TestCase)
129
+ wp.build.clear_kernel_cache()
96
130
  unittest.main(verbosity=2)
warp/tests/test_func.py CHANGED
@@ -5,14 +5,13 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- # include parent path
9
- import numpy as np
10
8
  import math
9
+ import unittest
11
10
 
12
- import warp as wp
13
- from warp.tests.test_base import *
11
+ import numpy as np
14
12
 
15
- import unittest
13
+ import warp as wp
14
+ from warp.tests.unittest_utils import *
16
15
 
17
16
  wp.init()
18
17
 
@@ -83,76 +82,13 @@ def test_override_func():
83
82
  wp.expect_eq(i, 3)
84
83
 
85
84
 
86
- def test_native_func_export(test, device):
87
- # tests calling native functions from Python
88
-
89
- q = wp.quat(0.0, 0.0, 0.0, 1.0)
90
- assert_np_equal(np.array([*q]), np.array([0.0, 0.0, 0.0, 1.0]))
91
-
92
- r = wp.quat_from_axis_angle((1.0, 0.0, 0.0), 2.0)
93
- assert_np_equal(np.array([*r]), np.array([0.8414709568023682, 0.0, 0.0, 0.5403022170066833]), tol=1.0e-3)
94
-
95
- q = wp.quat(1.0, 2.0, 3.0, 4.0)
96
- q = wp.normalize(q) * 2.0
97
- assert_np_equal(
98
- np.array([*q]),
99
- np.array([0.18257418274879456, 0.3651483654975891, 0.547722578048706, 0.7302967309951782]) * 2.0,
100
- tol=1.0e-3,
101
- )
102
-
103
- v2 = wp.vec2(1.0, 2.0)
104
- v2 = wp.normalize(v2) * 2.0
105
- assert_np_equal(np.array([*v2]), np.array([0.4472135901451111, 0.8944271802902222]) * 2.0, tol=1.0e-3)
106
-
107
- v3 = wp.vec3(1.0, 2.0, 3.0)
108
- v3 = wp.normalize(v3) * 2.0
109
- assert_np_equal(
110
- np.array([*v3]), np.array([0.26726123690605164, 0.5345224738121033, 0.8017836809158325]) * 2.0, tol=1.0e-3
111
- )
112
-
113
- v4 = wp.vec4(1.0, 2.0, 3.0, 4.0)
114
- v4 = wp.normalize(v4) * 2.0
115
- assert_np_equal(
116
- np.array([*v4]),
117
- np.array([0.18257418274879456, 0.3651483654975891, 0.547722578048706, 0.7302967309951782]) * 2.0,
118
- tol=1.0e-3,
119
- )
120
-
121
- m22 = wp.mat22(1.0, 2.0, 3.0, 4.0)
122
- m22 = m22 + m22
123
-
124
- test.assertEqual(m22[1, 1], 8.0)
125
- test.assertEqual(str(m22), "[[2.0, 4.0],\n [6.0, 8.0]]")
126
-
127
- t = wp.transform(
128
- wp.vec3(0.0, 0.0, 0.0),
129
- wp.quat(0.0, 0.0, 0.0, 1.0),
130
- )
131
- assert_np_equal(np.array([*t]), np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]))
132
-
133
- f = wp.sin(math.pi * 0.5)
134
- test.assertAlmostEqual(f, 1.0, places=3)
135
-
136
-
137
- def test_user_func_export(test, device):
138
- # tests calling overloaded user-defined functions from Python
139
- i = custom(1)
140
- f = custom(1.0)
141
- v = custom(wp.vec3(1.0, 0.0, 0.0))
142
-
143
- test.assertEqual(i, 2)
144
- test.assertEqual(f, 2.0)
145
- assert_np_equal(np.array([*v]), np.array([2.0, 0.0, 0.0]))
146
-
147
-
148
85
  def test_func_closure_capture(test, device):
149
86
  def make_closure_kernel(func):
150
87
  def closure_kernel_fn(data: wp.array(dtype=float), expected: float):
151
88
  f = func(data[wp.tid()])
152
89
  wp.expect_eq(f, expected)
153
90
 
154
- key = f"test_func_closure_capture_{func.key}"
155
- return wp.Kernel(func=closure_kernel_fn, key=key, module=wp.get_module(closure_kernel_fn.__module__))
91
+ return wp.Kernel(func=closure_kernel_fn)
156
92
 
157
93
  sqr_closure = make_closure_kernel(sqr)
158
94
  cube_closure = make_closure_kernel(cube)
@@ -211,26 +147,191 @@ def test_func_defaults():
211
147
  wp.expect_near(1.0, 1.1, 0.5)
212
148
 
213
149
 
214
- def register(parent):
215
- devices = get_test_devices()
216
-
217
- class TestFunc(parent):
218
- pass
150
+ @wp.func
151
+ def sign(x: float):
152
+ return 123.0
219
153
 
220
- add_kernel_test(TestFunc, kernel=test_overload_func, name="test_overload_func", dim=1, devices=devices)
221
- add_function_test(TestFunc, func=test_return_func, name="test_return_func", devices=devices)
222
- add_kernel_test(TestFunc, kernel=test_override_func, name="test_override_func", dim=1, devices=devices)
223
- add_function_test(TestFunc, func=test_native_func_export, name="test_native_func_export", devices=["cpu"])
224
- add_function_test(TestFunc, func=test_user_func_export, name="test_user_func_export", devices=["cpu"])
225
- add_function_test(TestFunc, func=test_func_closure_capture, name="test_func_closure_capture", devices=devices)
226
- add_function_test(TestFunc, func=test_multi_valued_func, name="test_multi_valued_func", devices=devices)
227
- add_kernel_test(TestFunc, kernel=test_func_defaults, name="test_func_defaults", dim=1, devices=devices)
228
154
 
229
- return TestFunc
155
+ @wp.kernel
156
+ def test_builtin_shadowing():
157
+ wp.expect_eq(sign(1.23), 123.0)
158
+
159
+
160
+ devices = get_test_devices()
161
+
162
+
163
+ class TestFunc(unittest.TestCase):
164
+ def test_user_func_export(self):
165
+ # tests calling overloaded user-defined functions from Python
166
+ i = custom(1)
167
+ f = custom(1.0)
168
+ v = custom(wp.vec3(1.0, 0.0, 0.0))
169
+
170
+ self.assertEqual(i, 2)
171
+ self.assertEqual(f, 2.0)
172
+ assert_np_equal(np.array([*v]), np.array([2.0, 0.0, 0.0]))
173
+
174
+ def test_native_func_export(self):
175
+ # tests calling native functions from Python
176
+
177
+ q = wp.quat(0.0, 0.0, 0.0, 1.0)
178
+ assert_np_equal(np.array([*q]), np.array([0.0, 0.0, 0.0, 1.0]))
179
+
180
+ r = wp.quat_from_axis_angle(wp.vec3(1.0, 0.0, 0.0), 2.0)
181
+ assert_np_equal(np.array([*r]), np.array([0.8414709568023682, 0.0, 0.0, 0.5403022170066833]), tol=1.0e-3)
182
+
183
+ q = wp.quat(1.0, 2.0, 3.0, 4.0)
184
+ q = wp.normalize(q) * 2.0
185
+ assert_np_equal(
186
+ np.array([*q]),
187
+ np.array([0.18257418274879456, 0.3651483654975891, 0.547722578048706, 0.7302967309951782]) * 2.0,
188
+ tol=1.0e-3,
189
+ )
190
+
191
+ v2 = wp.vec2(1.0, 2.0)
192
+ v2 = wp.normalize(v2) * 2.0
193
+ assert_np_equal(np.array([*v2]), np.array([0.4472135901451111, 0.8944271802902222]) * 2.0, tol=1.0e-3)
194
+
195
+ v3 = wp.vec3(1.0, 2.0, 3.0)
196
+ v3 = wp.normalize(v3) * 2.0
197
+ assert_np_equal(
198
+ np.array([*v3]), np.array([0.26726123690605164, 0.5345224738121033, 0.8017836809158325]) * 2.0, tol=1.0e-3
199
+ )
200
+
201
+ v4 = wp.vec4(1.0, 2.0, 3.0, 4.0)
202
+ v4 = wp.normalize(v4) * 2.0
203
+ assert_np_equal(
204
+ np.array([*v4]),
205
+ np.array([0.18257418274879456, 0.3651483654975891, 0.547722578048706, 0.7302967309951782]) * 2.0,
206
+ tol=1.0e-3,
207
+ )
208
+
209
+ v = wp.vec2(0.0)
210
+ v += wp.vec2(1.0, 1.0)
211
+ assert v == wp.vec2(1.0, 1.0)
212
+ v -= wp.vec2(1.0, 1.0)
213
+ assert v == wp.vec2(0.0, 0.0)
214
+ v = wp.vec2(2.0, 2.0) - wp.vec2(1.0, 1.0)
215
+ assert v == wp.vec2(1.0, 1.0)
216
+ v *= 2.0
217
+ assert v == wp.vec2(2.0, 2.0)
218
+ v = v * 2.0
219
+ assert v == wp.vec2(4.0, 4.0)
220
+ v = v / 2.0
221
+ assert v == wp.vec2(2.0, 2.0)
222
+ v /= 2.0
223
+ assert v == wp.vec2(1.0, 1.0)
224
+ v = -v
225
+ assert v == wp.vec2(-1.0, -1.0)
226
+ v = +v
227
+ assert v == wp.vec2(-1.0, -1.0)
228
+
229
+ m22 = wp.mat22(1.0, 2.0, 3.0, 4.0)
230
+ m22 = m22 + m22
231
+
232
+ self.assertEqual(m22[1, 1], 8.0)
233
+ self.assertEqual(str(m22), "[[2.0, 4.0],\n [6.0, 8.0]]")
234
+
235
+ t = wp.transform(
236
+ wp.vec3(1.0, 2.0, 3.0),
237
+ wp.quat(4.0, 5.0, 6.0, 7.0),
238
+ )
239
+ self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))
240
+ self.assertSequenceEqual(t * wp.transform(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0), (396.0, 432.0, 720.0, 56.0, 70.0, 84.0, -28.0))
241
+ self.assertSequenceEqual(
242
+ t * wp.transform((1.0, 2.0, 3.0), (4.0, 5.0, 6.0, 7.0)), (396.0, 432.0, 720.0, 56.0, 70.0, 84.0, -28.0)
243
+ )
244
+
245
+ t = wp.transform()
246
+ self.assertSequenceEqual(t, (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0))
247
+
248
+ t = wp.transform(p=(1.0, 2.0, 3.0), q=(4.0, 5.0, 6.0, 7.0))
249
+ self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))
250
+
251
+ t = wp.transform(q=(4.0, 5.0, 6.0, 7.0), p=(1.0, 2.0, 3.0))
252
+ self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))
253
+
254
+ t = wp.transform((1.0, 2.0, 3.0), q=(4.0, 5.0, 6.0, 7.0))
255
+ self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))
256
+
257
+ t = wp.transform(p=(1.0, 2.0, 3.0))
258
+ self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 1.0))
259
+
260
+ t = wp.transform(q=(4.0, 5.0, 6.0, 7.0))
261
+ self.assertSequenceEqual(t, (0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0))
262
+
263
+ t = wp.transform((1.0, 2.0, 3.0), (4.0, 5.0, 6.0, 7.0))
264
+ self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))
265
+
266
+ t = wp.transform(p=wp.vec3(1.0, 2.0, 3.0), q=wp.quat(4.0, 5.0, 6.0, 7.0))
267
+ self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))
268
+
269
+ t = wp.transform(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
270
+ self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))
271
+
272
+ t = wp.transform(wp.transform(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))
273
+ self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))
274
+
275
+ t = wp.transform(*wp.transform(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))
276
+ self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))
277
+
278
+ transformf = wp.types.transformation(dtype=float)
279
+
280
+ t = wp.transformf((1.0, 2.0, 3.0), (4.0, 5.0, 6.0, 7.0))
281
+ self.assertSequenceEqual(
282
+ t + transformf((2.0, 3.0, 4.0), (5.0, 6.0, 7.0, 8.0)),
283
+ (3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0),
284
+ )
285
+ self.assertSequenceEqual(
286
+ t - transformf((2.0, 3.0, 4.0), (5.0, 6.0, 7.0, 8.0)),
287
+ (-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0),
288
+ )
289
+
290
+ f = wp.sin(math.pi * 0.5)
291
+ self.assertAlmostEqual(f, 1.0, places=3)
292
+
293
+ m = wp.mat22(0.0, 0.0, 0.0, 0.0)
294
+ m += wp.mat22(1.0, 1.0, 1.0, 1.0)
295
+ assert m == wp.mat22(1.0, 1.0, 1.0, 1.0)
296
+ m -= wp.mat22(1.0, 1.0, 1.0, 1.0)
297
+ assert m == wp.mat22(0.0, 0.0, 0.0, 0.0)
298
+ m = wp.mat22(2.0, 2.0, 2.0, 2.0) - wp.mat22(1.0, 1.0, 1.0, 1.0)
299
+ assert m == wp.mat22(1.0, 1.0, 1.0, 1.0)
300
+ m *= 2.0
301
+ assert m == wp.mat22(2.0, 2.0, 2.0, 2.0)
302
+ m = m * 2.0
303
+ assert m == wp.mat22(4.0, 4.0, 4.0, 4.0)
304
+ m = m / 2.0
305
+ assert m == wp.mat22(2.0, 2.0, 2.0, 2.0)
306
+ m /= 2.0
307
+ assert m == wp.mat22(1.0, 1.0, 1.0, 1.0)
308
+ m = -m
309
+ assert m == wp.mat22(-1.0, -1.0, -1.0, -1.0)
310
+ m = +m
311
+ assert m == wp.mat22(-1.0, -1.0, -1.0, -1.0)
312
+ m = m * m
313
+ assert m == wp.mat22(2.0, 2.0, 2.0, 2.0)
314
+
315
+
316
+ def test_native_function_error_resolution(self):
317
+ a = wp.mat22f(1.0, 2.0, 3.0, 4.0)
318
+ b = wp.mat22d(1.0, 2.0, 3.0, 4.0)
319
+ with self.assertRaisesRegex(
320
+ RuntimeError,
321
+ r"^Couldn't find a function 'mul' compatible with " r"the arguments 'mat22f, mat22d'$",
322
+ ):
323
+ a * b
324
+
325
+
326
+ add_kernel_test(TestFunc, kernel=test_overload_func, name="test_overload_func", dim=1, devices=devices)
327
+ add_function_test(TestFunc, func=test_return_func, name="test_return_func", devices=devices)
328
+ add_kernel_test(TestFunc, kernel=test_override_func, name="test_override_func", dim=1, devices=devices)
329
+ add_function_test(TestFunc, func=test_func_closure_capture, name="test_func_closure_capture", devices=devices)
330
+ add_function_test(TestFunc, func=test_multi_valued_func, name="test_multi_valued_func", devices=devices)
331
+ add_kernel_test(TestFunc, kernel=test_func_defaults, name="test_func_defaults", dim=1, devices=devices)
332
+ add_kernel_test(TestFunc, kernel=test_builtin_shadowing, name="test_builtin_shadowing", dim=1, devices=devices)
230
333
 
231
334
 
232
335
  if __name__ == "__main__":
233
- c = register(unittest.TestCase)
234
- wp.force_load()
235
-
336
+ wp.build.clear_kernel_cache()
236
337
  unittest.main(verbosity=2)
@@ -5,12 +5,13 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- import numpy as np
9
8
  import unittest
10
9
  from typing import Any
11
10
 
11
+ import numpy as np
12
+
12
13
  import warp as wp
13
- from warp.tests.test_base import *
14
+ from warp.tests.unittest_utils import *
14
15
 
15
16
  wp.init()
16
17
 
@@ -363,54 +364,198 @@ wp.overload(test_generic_struct_kernel, [Foo])
363
364
  wp.overload(test_generic_struct_kernel, [Bar])
364
365
 
365
366
 
366
- def register(parent):
367
- class TestGenerics(parent):
368
- pass
369
-
370
- devices = get_test_devices()
371
-
372
- add_kernel_test(TestGenerics, name="test_generic_adder", kernel=test_generic_adder, dim=1, devices=devices)
373
- add_kernel_test(TestGenerics, name="test_specialized_func", kernel=test_specialized_func, dim=1, devices=devices)
374
-
375
- add_function_test(TestGenerics, "test_generic_array_kernel", test_generic_array_kernel, devices=devices)
376
- add_function_test(TestGenerics, "test_generic_accumulator_kernel", test_generic_accumulator_kernel, devices=devices)
377
- add_function_test(TestGenerics, "test_generic_fill", test_generic_fill, devices=devices)
378
- add_function_test(TestGenerics, "test_generic_fill_overloads", test_generic_fill_overloads, devices=devices)
379
- add_function_test(TestGenerics, "test_generic_transform_kernel", test_generic_transform_kernel, devices=devices)
380
- add_function_test(
381
- TestGenerics, "test_generic_transform_array_kernel", test_generic_transform_array_kernel, devices=devices
382
- )
383
-
384
- foo = Foo()
385
- foo.x = 17.0
386
- foo.y = 25.0
387
- foo.z = 42.0
388
-
389
- bar = Bar()
390
- bar.x = wp.vec3(1, 2, 3)
391
- bar.y = wp.vec3(10, 20, 30)
392
- bar.z = wp.vec3(11, 22, 33)
393
-
394
- add_kernel_test(
395
- TestGenerics,
396
- name="test_generic_struct_kernel",
397
- kernel=test_generic_struct_kernel,
398
- dim=1,
399
- inputs=[foo],
400
- devices=devices,
401
- )
402
- add_kernel_test(
403
- TestGenerics,
404
- name="test_generic_struct_kernel",
405
- kernel=test_generic_struct_kernel,
406
- dim=1,
407
- inputs=[bar],
408
- devices=devices,
409
- )
410
-
411
- return TestGenerics
367
+ @wp.kernel
368
+ def test_generic_type_cast_kernel(a: Any, b: Any):
369
+ a = type(a)(b)
370
+ c = type(generic_adder(b, b))(a)
371
+ wp.expect_eq(b, c)
372
+
373
+
374
+ wp.overload(test_generic_type_cast_kernel, [wp.float32, wp.float64])
375
+ wp.overload(test_generic_type_cast_kernel, [wp.float32, wp.int32])
376
+ wp.overload(test_generic_type_cast_kernel, [wp.vec3f, wp.vec3d])
377
+ wp.overload(test_generic_type_cast_kernel, [wp.mat22f, wp.mat22d])
378
+
379
+
380
+ def test_generic_type_cast(test, device):
381
+ with wp.ScopedDevice(device):
382
+ wp.launch(test_generic_type_cast_kernel, dim=1, inputs=[1.0, 2.0])
383
+ wp.launch(test_generic_type_cast_kernel, dim=1, inputs=[2.0, -5])
384
+ wp.launch(test_generic_type_cast_kernel, dim=1, inputs=[wp.vec3f(1.0, 2.0, 3.0), wp.vec3d(4.0, 5.0, 6.0)])
385
+ wp.launch(test_generic_type_cast_kernel, dim=1, inputs=[wp.mat22f(0.0), wp.mat22d(np.eye(2))])
386
+
387
+ wp.synchronize()
388
+
389
+
390
+ @wp.kernel
391
+ def test_generic_scalar_construction_kernel(a: wp.array(dtype=Any)):
392
+ zero = type(a[0])(0)
393
+ copy = a.dtype(a[0])
394
+ copy += zero
395
+ wp.expect_eq(copy, a[0])
396
+
397
+
398
+ wp.overload(test_generic_scalar_construction_kernel, [wp.array(dtype=wp.int32)])
399
+ wp.overload(test_generic_scalar_construction_kernel, [wp.array(dtype=wp.float64)])
400
+
401
+
402
+ def test_generic_scalar_construction(test, device):
403
+ with wp.ScopedDevice(device):
404
+ wp.launch(test_generic_scalar_construction_kernel, dim=1, inputs=[wp.array([1.0], dtype=wp.int32)])
405
+ wp.launch(test_generic_scalar_construction_kernel, dim=1, inputs=[wp.array([-5], dtype=wp.float64)])
406
+
407
+ wp.synchronize()
408
+
409
+
410
+ @wp.kernel
411
+ def test_generic_type_construction_kernel(a: wp.array(dtype=Any)):
412
+ zero = type(a[0])()
413
+ copy = type(a).dtype(a[0]) * a.dtype.dtype(1.0)
414
+ copy += zero
415
+ wp.expect_eq(copy, a[0])
416
+
417
+
418
+ wp.overload(test_generic_type_construction_kernel, [wp.array(dtype=wp.vec3f)])
419
+ wp.overload(test_generic_type_construction_kernel, [wp.array(dtype=wp.mat22d)])
420
+
421
+
422
+ def test_generic_type_construction(test, device):
423
+ with wp.ScopedDevice(device):
424
+ wp.launch(test_generic_type_construction_kernel, dim=1, inputs=[wp.array([1.0, 2.0, 3.0], dtype=wp.vec3f)])
425
+ wp.launch(test_generic_type_construction_kernel, dim=1, inputs=[wp.array([np.eye(2)], dtype=wp.mat22d)])
426
+
427
+ wp.synchronize()
428
+
429
+
430
+ @wp.kernel
431
+ def test_generic_struct_construction_kernel(a: Any):
432
+ b = type(a)(a.x, a.y, a.z)
433
+ wp.expect_eq(a.x, b.x)
434
+ wp.expect_eq(a.y, b.y)
435
+ wp.expect_eq(a.z, b.z)
436
+
437
+
438
+ wp.overload(test_generic_struct_construction_kernel, [Foo])
439
+ wp.overload(test_generic_struct_construction_kernel, [Bar])
440
+
441
+
442
+ @wp.kernel
443
+ def test_generic_type_as_argument_kernel(a: Any):
444
+ vec = wp.vector(length=2, dtype=type(a))
445
+ matrix = wp.identity(n=vec.length, dtype=vec.dtype) * a
446
+ wp.expect_eq(wp.trace(matrix), type(a)(2.0) * a)
447
+
448
+
449
+ wp.overload(test_generic_type_as_argument_kernel, [wp.float32])
450
+ wp.overload(test_generic_type_as_argument_kernel, [wp.float64])
451
+
452
+
453
+ def test_generic_type_as_argument(test, device):
454
+ with wp.ScopedDevice(device):
455
+ wp.launch(test_generic_type_as_argument_kernel, dim=1, inputs=[2.0])
456
+ wp.launch(test_generic_type_as_argument_kernel, dim=1, inputs=[-1.0])
457
+
458
+ wp.synchronize()
459
+
412
460
 
461
+ def test_type_operator_mispell(test, device):
462
+ @wp.kernel
463
+ def kernel():
464
+ i = wp.tid()
465
+ _ = typez(i)(0)
466
+
467
+ with test.assertRaisesRegex(RuntimeError, r"Unknown function or operator: 'typez'$"):
468
+ wp.launch(
469
+ kernel,
470
+ dim=1,
471
+ inputs=[],
472
+ device=device,
473
+ )
474
+
475
+
476
+ def test_type_attribute_error(test, device):
477
+ @wp.kernel
478
+ def kernel():
479
+ a = wp.vec3(0.0)
480
+ _ = a.dtype.shape
481
+
482
+ with test.assertRaisesRegex(AttributeError, r"`shape` is not an attribute of '<class 'warp.types.float32'>'"):
483
+ wp.launch(
484
+ kernel,
485
+ dim=1,
486
+ inputs=[],
487
+ device=device,
488
+ )
489
+
490
+
491
+ class TestGenerics(unittest.TestCase):
492
+ pass
493
+
494
+
495
+ devices = get_test_devices()
496
+
497
+ add_kernel_test(TestGenerics, name="test_generic_adder", kernel=test_generic_adder, dim=1, devices=devices)
498
+ add_kernel_test(TestGenerics, name="test_specialized_func", kernel=test_specialized_func, dim=1, devices=devices)
499
+
500
+ add_function_test(TestGenerics, "test_generic_array_kernel", test_generic_array_kernel, devices=devices)
501
+ add_function_test(TestGenerics, "test_generic_accumulator_kernel", test_generic_accumulator_kernel, devices=devices)
502
+ add_function_test(TestGenerics, "test_generic_fill", test_generic_fill, devices=devices)
503
+ add_function_test(TestGenerics, "test_generic_fill_overloads", test_generic_fill_overloads, devices=devices)
504
+ add_function_test(TestGenerics, "test_generic_transform_kernel", test_generic_transform_kernel, devices=devices)
505
+ add_function_test(
506
+ TestGenerics, "test_generic_transform_array_kernel", test_generic_transform_array_kernel, devices=devices
507
+ )
508
+ add_function_test(TestGenerics, "test_generic_type_cast", test_generic_type_cast, devices=devices)
509
+ add_function_test(TestGenerics, "test_generic_type_construction", test_generic_type_construction, devices=devices)
510
+ add_function_test(TestGenerics, "test_generic_scalar_construction", test_generic_scalar_construction, devices=devices)
511
+ add_function_test(TestGenerics, "test_generic_type_as_argument", test_generic_type_as_argument, devices=devices)
512
+
513
+ foo = Foo()
514
+ foo.x = 17.0
515
+ foo.y = 25.0
516
+ foo.z = 42.0
517
+
518
+ bar = Bar()
519
+ bar.x = wp.vec3(1, 2, 3)
520
+ bar.y = wp.vec3(10, 20, 30)
521
+ bar.z = wp.vec3(11, 22, 33)
522
+
523
+ add_kernel_test(
524
+ TestGenerics,
525
+ name="test_generic_struct_kernel",
526
+ kernel=test_generic_struct_kernel,
527
+ dim=1,
528
+ inputs=[foo],
529
+ devices=devices,
530
+ )
531
+ add_kernel_test(
532
+ TestGenerics,
533
+ name="test_generic_struct_kernel",
534
+ kernel=test_generic_struct_kernel,
535
+ dim=1,
536
+ inputs=[bar],
537
+ devices=devices,
538
+ )
539
+
540
+ add_kernel_test(
541
+ TestGenerics,
542
+ name="test_generic_struct_construction_kernel",
543
+ kernel=test_generic_struct_construction_kernel,
544
+ dim=1,
545
+ inputs=[foo],
546
+ devices=devices,
547
+ )
548
+ add_kernel_test(
549
+ TestGenerics,
550
+ name="test_generic_struct_construction_kernel",
551
+ kernel=test_generic_struct_construction_kernel,
552
+ dim=1,
553
+ inputs=[bar],
554
+ devices=devices,
555
+ )
556
+ add_function_test(TestGenerics, "test_type_operator_mispell", test_type_operator_mispell, devices=devices)
557
+ add_function_test(TestGenerics, "test_type_attribute_error", test_type_attribute_error, devices=devices)
413
558
 
414
559
  if __name__ == "__main__":
415
- c = register(unittest.TestCase)
560
+ wp.build.clear_kernel_cache()
416
561
  unittest.main(verbosity=2)