warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package registry page for more details.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -5,9 +5,13 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
+ import math
9
+ import unittest
10
+
8
11
  import numpy as np
12
+
9
13
  import warp as wp
10
- from warp.tests.test_base import *
14
+ from warp.tests.unittest_utils import *
11
15
 
12
16
  wp.init()
13
17
 
@@ -34,22 +38,21 @@ np_float_types = [np.float16, np.float32, np.float64]
34
38
  np_scalar_types = np_int_types + np_float_types
35
39
 
36
40
 
37
- def randvals(shape, dtype):
41
+ def randvals(rng, shape, dtype):
38
42
  if dtype in np_float_types:
39
- return np.random.randn(*shape).astype(dtype)
43
+ return rng.standard_normal(size=shape).astype(dtype)
40
44
  elif dtype in [np.int8, np.uint8, np.byte, np.ubyte]:
41
- return np.random.randint(1, 3, size=shape, dtype=dtype)
42
- return np.random.randint(1, 5, size=shape, dtype=dtype)
45
+ return rng.integers(1, high=3, size=shape, dtype=dtype)
46
+ return rng.integers(1, high=5, size=shape, dtype=dtype)
43
47
 
44
48
 
45
49
  kernel_cache = dict()
46
50
 
47
51
 
48
52
  def getkernel(func, suffix=""):
49
- module = wp.get_module(func.__module__)
50
53
  key = func.__name__ + "_" + suffix
51
54
  if key not in kernel_cache:
52
- kernel_cache[key] = wp.Kernel(func=func, key=key, module=module)
55
+ kernel_cache[key] = wp.Kernel(func=func, key=key)
53
56
  return kernel_cache[key]
54
57
 
55
58
 
@@ -77,7 +80,7 @@ def get_select_kernel2(dtype):
77
80
 
78
81
 
79
82
  def test_arrays(test, device, dtype):
80
- np.random.seed(123)
83
+ rng = np.random.default_rng(123)
81
84
 
82
85
  tol = {
83
86
  np.float16: 1.0e-3,
@@ -86,14 +89,14 @@ def test_arrays(test, device, dtype):
86
89
  }.get(dtype, 0)
87
90
 
88
91
  wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
89
- arr_np = randvals((10, 5), dtype)
92
+ arr_np = randvals(rng, (10, 5), dtype)
90
93
  arr = wp.array(arr_np, dtype=wptype, requires_grad=True, device=device)
91
94
 
92
95
  assert_np_equal(arr.numpy(), arr_np, tol=tol)
93
96
 
94
97
 
95
98
  def test_unary_ops(test, device, dtype, register_kernels=False):
96
- np.random.seed(123)
99
+ rng = np.random.default_rng(123)
97
100
 
98
101
  tol = {
99
102
  np.float16: 5.0e-3,
@@ -128,10 +131,12 @@ def test_unary_ops(test, device, dtype, register_kernels=False):
128
131
  return
129
132
 
130
133
  if dtype in np_float_types:
131
- inputs = wp.array(np.random.randn(5, 10).astype(dtype), dtype=wptype, requires_grad=True, device=device)
134
+ inputs = wp.array(
135
+ rng.standard_normal(size=(5, 10)).astype(dtype), dtype=wptype, requires_grad=True, device=device
136
+ )
132
137
  else:
133
138
  inputs = wp.array(
134
- np.random.randint(-2, 3, size=(5, 10), dtype=dtype), dtype=wptype, requires_grad=True, device=device
139
+ rng.integers(-2, high=3, size=(5, 10), dtype=dtype), dtype=wptype, requires_grad=True, device=device
135
140
  )
136
141
  outputs = wp.zeros_like(inputs)
137
142
 
@@ -207,7 +212,7 @@ def test_unary_ops(test, device, dtype, register_kernels=False):
207
212
 
208
213
 
209
214
  def test_nonzero(test, device, dtype, register_kernels=False):
210
- np.random.seed(123)
215
+ rng = np.random.default_rng(123)
211
216
 
212
217
  tol = {
213
218
  np.float16: 5.0e-3,
@@ -231,7 +236,7 @@ def test_nonzero(test, device, dtype, register_kernels=False):
231
236
  if register_kernels:
232
237
  return
233
238
 
234
- inputs = wp.array(np.random.randint(-2, 3, size=10).astype(dtype), dtype=wptype, requires_grad=True, device=device)
239
+ inputs = wp.array(rng.integers(-2, high=3, size=10).astype(dtype), dtype=wptype, requires_grad=True, device=device)
235
240
  outputs = wp.zeros_like(inputs)
236
241
 
237
242
  wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
@@ -253,10 +258,10 @@ def test_nonzero(test, device, dtype, register_kernels=False):
253
258
 
254
259
 
255
260
  def test_binary_ops(test, device, dtype, register_kernels=False):
256
- np.random.seed(123)
261
+ rng = np.random.default_rng(123)
257
262
 
258
263
  tol = {
259
- np.float16: 1.0e-2,
264
+ np.float16: 5.0e-2,
260
265
  np.float32: 1.0e-6,
261
266
  np.float64: 1.0e-8,
262
267
  }.get(dtype, 0)
@@ -302,11 +307,11 @@ def test_binary_ops(test, device, dtype, register_kernels=False):
302
307
  if register_kernels:
303
308
  return
304
309
 
305
- vals1 = randvals([8, 10], dtype)
310
+ vals1 = randvals(rng, [8, 10], dtype)
306
311
  if dtype in [np_unsigned_int_types]:
307
- vals2 = vals1 + randvals([8, 10], dtype)
312
+ vals2 = vals1 + randvals(rng, [8, 10], dtype)
308
313
  else:
309
- vals2 = np.abs(randvals([8, 10], dtype))
314
+ vals2 = np.abs(randvals(rng, [8, 10], dtype))
310
315
 
311
316
  in1 = wp.array(vals1, dtype=wptype, requires_grad=True, device=device)
312
317
  in2 = wp.array(vals2, dtype=wptype, requires_grad=True, device=device)
@@ -458,7 +463,7 @@ def test_binary_ops(test, device, dtype, register_kernels=False):
458
463
 
459
464
 
460
465
  def test_special_funcs(test, device, dtype, register_kernels=False):
461
- np.random.seed(123)
466
+ rng = np.random.default_rng(123)
462
467
 
463
468
  tol = {
464
469
  np.float16: 1.0e-2,
@@ -488,6 +493,7 @@ def test_special_funcs(test, device, dtype, register_kernels=False):
488
493
  outputs[11, i] = wptype(2) * wp.tanh(inputs[11, i])
489
494
  outputs[12, i] = wptype(2) * wp.acos(inputs[12, i])
490
495
  outputs[13, i] = wptype(2) * wp.asin(inputs[13, i])
496
+ outputs[14, i] = wptype(2) * wp.cbrt(inputs[14, i])
491
497
 
492
498
  kernel = getkernel(check_special_funcs, suffix=dtype.__name__)
493
499
  output_select_kernel = get_select_kernel2(wptype)
@@ -495,8 +501,8 @@ def test_special_funcs(test, device, dtype, register_kernels=False):
495
501
  if register_kernels:
496
502
  return
497
503
 
498
- invals = np.random.randn(14, 10).astype(dtype)
499
- invals[[0, 1, 2, 7]] = 0.1 + np.abs(invals[[0, 1, 2, 7]])
504
+ invals = rng.normal(size=(15, 10)).astype(dtype)
505
+ invals[[0, 1, 2, 7, 14]] = 0.1 + np.abs(invals[[0, 1, 2, 7, 14]])
500
506
  invals[12] = np.clip(invals[12], -0.9, 0.9)
501
507
  invals[13] = np.clip(invals[13], -0.9, 0.9)
502
508
  inputs = wp.array(invals, dtype=wptype, requires_grad=True, device=device)
@@ -518,6 +524,7 @@ def test_special_funcs(test, device, dtype, register_kernels=False):
518
524
  assert_np_equal(outputs.numpy()[11], 2 * np.tanh(inputs.numpy()[11]), tol=tol)
519
525
  assert_np_equal(outputs.numpy()[12], 2 * np.arccos(inputs.numpy()[12]), tol=tol)
520
526
  assert_np_equal(outputs.numpy()[13], 2 * np.arcsin(inputs.numpy()[13]), tol=tol)
527
+ assert_np_equal(outputs.numpy()[14], 2 * np.cbrt(inputs.numpy()[14]), tol=tol)
521
528
 
522
529
  out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
523
530
  if dtype in np_float_types:
@@ -694,9 +701,22 @@ def test_special_funcs(test, device, dtype, register_kernels=False):
694
701
  assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=6 * tol)
695
702
  tape.zero()
696
703
 
704
+ # cbrt:
705
+ tape = wp.Tape()
706
+ with tape:
707
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
708
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 14, i], outputs=[out], device=device)
709
+
710
+ tape.backward(loss=out)
711
+ expected = np.zeros_like(inputs.numpy())
712
+ cbrt = np.cbrt(inputs.numpy()[14, i], dtype=np.dtype(dtype))
713
+ expected[14, i] = (2.0 / 3.0) * (1.0 / (cbrt * cbrt))
714
+ assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=tol)
715
+ tape.zero()
716
+
697
717
 
698
718
  def test_special_funcs_2arg(test, device, dtype, register_kernels=False):
699
- np.random.seed(123)
719
+ rng = np.random.default_rng(123)
700
720
 
701
721
  tol = {
702
722
  np.float16: 1.0e-2,
@@ -722,8 +742,8 @@ def test_special_funcs_2arg(test, device, dtype, register_kernels=False):
722
742
  if register_kernels:
723
743
  return
724
744
 
725
- in1 = wp.array(np.abs(randvals([2, 10], dtype)), dtype=wptype, requires_grad=True, device=device)
726
- in2 = wp.array(randvals([2, 10], dtype), dtype=wptype, requires_grad=True, device=device)
745
+ in1 = wp.array(np.abs(randvals(rng, [2, 10], dtype)), dtype=wptype, requires_grad=True, device=device)
746
+ in2 = wp.array(randvals(rng, [2, 10], dtype), dtype=wptype, requires_grad=True, device=device)
727
747
  outputs = wp.zeros_like(in1)
728
748
 
729
749
  wp.launch(kernel, dim=1, inputs=[in1, in2], outputs=[outputs], device=device)
@@ -763,7 +783,7 @@ def test_special_funcs_2arg(test, device, dtype, register_kernels=False):
763
783
 
764
784
 
765
785
  def test_float_to_int(test, device, dtype, register_kernels=False):
766
- np.random.seed(123)
786
+ rng = np.random.default_rng(123)
767
787
 
768
788
  tol = {
769
789
  np.float16: 5.0e-3,
@@ -783,6 +803,7 @@ def test_float_to_int(test, device, dtype, register_kernels=False):
783
803
  outputs[2, i] = wp.trunc(inputs[2, i])
784
804
  outputs[3, i] = wp.floor(inputs[3, i])
785
805
  outputs[4, i] = wp.ceil(inputs[4, i])
806
+ outputs[5, i] = wp.frac(inputs[5, i])
786
807
 
787
808
  kernel = getkernel(check_float_to_int, suffix=dtype.__name__)
788
809
  output_select_kernel = get_select_kernel2(wptype)
@@ -790,7 +811,7 @@ def test_float_to_int(test, device, dtype, register_kernels=False):
790
811
  if register_kernels:
791
812
  return
792
813
 
793
- inputs = wp.array(np.random.randn(5, 10).astype(dtype), dtype=wptype, requires_grad=True, device=device)
814
+ inputs = wp.array(rng.standard_normal(size=(6, 10)).astype(dtype), dtype=wptype, requires_grad=True, device=device)
794
815
  outputs = wp.zeros_like(inputs)
795
816
 
796
817
  wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
@@ -800,6 +821,7 @@ def test_float_to_int(test, device, dtype, register_kernels=False):
800
821
  assert_np_equal(outputs.numpy()[2], np.trunc(inputs.numpy()[2]))
801
822
  assert_np_equal(outputs.numpy()[3], np.floor(inputs.numpy()[3]))
802
823
  assert_np_equal(outputs.numpy()[4], np.ceil(inputs.numpy()[4]))
824
+ assert_np_equal(outputs.numpy()[5], np.modf(inputs.numpy()[5])[0])
803
825
 
804
826
  # all the gradients should be zero as these functions are piecewise constant:
805
827
 
@@ -816,8 +838,38 @@ def test_float_to_int(test, device, dtype, register_kernels=False):
816
838
  tape.zero()
817
839
 
818
840
 
841
+ def test_infinity(test, device, dtype, register_kernels=False):
842
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
843
+
844
+ def check_infinity(
845
+ outputs: wp.array(dtype=wptype),
846
+ ):
847
+ outputs[0] = wptype(wp.inf)
848
+ outputs[1] = wptype(-wp.inf)
849
+ outputs[2] = wptype(2.0 * wp.inf)
850
+ outputs[3] = wptype(-2.0 * wp.inf)
851
+ outputs[4] = wptype(2.0 / 0.0)
852
+ outputs[5] = wptype(-2.0 / 0.0)
853
+
854
+ kernel = getkernel(check_infinity, suffix=dtype.__name__)
855
+
856
+ if register_kernels:
857
+ return
858
+
859
+ outputs = wp.zeros(6, dtype=wptype, device=device)
860
+
861
+ wp.launch(kernel, dim=1, inputs=[], outputs=[outputs], device=device)
862
+
863
+ test.assertEqual(outputs.numpy()[0], math.inf)
864
+ test.assertEqual(outputs.numpy()[1], -math.inf)
865
+ test.assertEqual(outputs.numpy()[2], math.inf)
866
+ test.assertEqual(outputs.numpy()[3], -math.inf)
867
+ test.assertEqual(outputs.numpy()[4], math.inf)
868
+ test.assertEqual(outputs.numpy()[5], -math.inf)
869
+
870
+
819
871
  def test_interp(test, device, dtype, register_kernels=False):
820
- np.random.seed(123)
872
+ rng = np.random.default_rng(123)
821
873
 
822
874
  tol = {
823
875
  np.float16: 1.0e-2,
@@ -844,11 +896,11 @@ def test_interp(test, device, dtype, register_kernels=False):
844
896
  if register_kernels:
845
897
  return
846
898
 
847
- e0 = randvals([2, 10], dtype)
848
- e1 = e0 + randvals([2, 10], dtype) + 0.1
899
+ e0 = randvals(rng, [2, 10], dtype)
900
+ e1 = e0 + randvals(rng, [2, 10], dtype) + 0.1
849
901
  in1 = wp.array(e0, dtype=wptype, requires_grad=True, device=device)
850
902
  in2 = wp.array(e1, dtype=wptype, requires_grad=True, device=device)
851
- in3 = wp.array(randvals([2, 10], dtype), dtype=wptype, requires_grad=True, device=device)
903
+ in3 = wp.array(randvals(rng, [2, 10], dtype), dtype=wptype, requires_grad=True, device=device)
852
904
 
853
905
  outputs = wp.zeros_like(in1)
854
906
 
@@ -948,7 +1000,7 @@ def test_interp(test, device, dtype, register_kernels=False):
948
1000
 
949
1001
 
950
1002
  def test_clamp(test, device, dtype, register_kernels=False):
951
- np.random.seed(123)
1003
+ rng = np.random.default_rng(123)
952
1004
 
953
1005
  tol = {
954
1006
  np.float16: 5.0e-3,
@@ -974,9 +1026,9 @@ def test_clamp(test, device, dtype, register_kernels=False):
974
1026
  if register_kernels:
975
1027
  return
976
1028
 
977
- in1 = wp.array(randvals([100], dtype), dtype=wptype, requires_grad=True, device=device)
978
- starts = randvals([100], dtype)
979
- diffs = np.abs(randvals([100], dtype))
1029
+ in1 = wp.array(randvals(rng, [100], dtype), dtype=wptype, requires_grad=True, device=device)
1030
+ starts = randvals(rng, [100], dtype)
1031
+ diffs = np.abs(randvals(rng, [100], dtype))
980
1032
  in2 = wp.array(starts, dtype=wptype, requires_grad=True, device=device)
981
1033
  in3 = wp.array(starts + diffs, dtype=wptype, requires_grad=True, device=device)
982
1034
  outputs = wp.zeros_like(in1)
@@ -1020,51 +1072,53 @@ def test_clamp(test, device, dtype, register_kernels=False):
1020
1072
  tape.zero()
1021
1073
 
1022
1074
 
1023
- def register(parent):
1024
- devices = get_test_devices()
1075
+ devices = get_test_devices()
1025
1076
 
1026
- class TestArithmetic(parent):
1027
- pass
1028
1077
 
1029
- # these unary ops only make sense for signed values:
1030
- for dtype in np_signed_int_types + np_float_types:
1031
- add_function_test_register_kernel(
1032
- TestArithmetic, f"test_unary_ops_{dtype.__name__}", test_unary_ops, devices=devices, dtype=dtype
1033
- )
1078
+ class TestArithmetic(unittest.TestCase):
1079
+ pass
1034
1080
 
1035
- for dtype in np_float_types:
1036
- add_function_test_register_kernel(
1037
- TestArithmetic, f"test_special_funcs_{dtype.__name__}", test_special_funcs, devices=devices, dtype=dtype
1038
- )
1039
- add_function_test_register_kernel(
1040
- TestArithmetic,
1041
- f"test_special_funcs_2arg_{dtype.__name__}",
1042
- test_special_funcs_2arg,
1043
- devices=devices,
1044
- dtype=dtype,
1045
- )
1046
- add_function_test_register_kernel(
1047
- TestArithmetic, f"test_interp_{dtype.__name__}", test_interp, devices=devices, dtype=dtype
1048
- )
1049
- add_function_test_register_kernel(
1050
- TestArithmetic, f"test_float_to_int_{dtype.__name__}", test_float_to_int, devices=devices, dtype=dtype
1051
- )
1052
1081
 
1053
- for dtype in np_scalar_types:
1054
- add_function_test_register_kernel(
1055
- TestArithmetic, f"test_clamp_{dtype.__name__}", test_clamp, devices=devices, dtype=dtype
1056
- )
1057
- add_function_test_register_kernel(
1058
- TestArithmetic, f"test_nonzero_{dtype.__name__}", test_nonzero, devices=devices, dtype=dtype
1059
- )
1060
- add_function_test(TestArithmetic, f"test_arrays_{dtype.__name__}", test_arrays, devices=devices, dtype=dtype)
1061
- add_function_test_register_kernel(
1062
- TestArithmetic, f"test_binary_ops_{dtype.__name__}", test_binary_ops, devices=devices, dtype=dtype
1063
- )
1082
+ # these unary ops only make sense for signed values:
1083
+ for dtype in np_signed_int_types + np_float_types:
1084
+ add_function_test_register_kernel(
1085
+ TestArithmetic, f"test_unary_ops_{dtype.__name__}", test_unary_ops, devices=devices, dtype=dtype
1086
+ )
1064
1087
 
1065
- return TestArithmetic
1088
+ for dtype in np_float_types:
1089
+ add_function_test_register_kernel(
1090
+ TestArithmetic, f"test_special_funcs_{dtype.__name__}", test_special_funcs, devices=devices, dtype=dtype
1091
+ )
1092
+ add_function_test_register_kernel(
1093
+ TestArithmetic,
1094
+ f"test_special_funcs_2arg_{dtype.__name__}",
1095
+ test_special_funcs_2arg,
1096
+ devices=devices,
1097
+ dtype=dtype,
1098
+ )
1099
+ add_function_test_register_kernel(
1100
+ TestArithmetic, f"test_interp_{dtype.__name__}", test_interp, devices=devices, dtype=dtype
1101
+ )
1102
+ add_function_test_register_kernel(
1103
+ TestArithmetic, f"test_float_to_int_{dtype.__name__}", test_float_to_int, devices=devices, dtype=dtype
1104
+ )
1105
+ add_function_test_register_kernel(
1106
+ TestArithmetic, f"test_infinity_{dtype.__name__}", test_infinity, devices=devices, dtype=dtype
1107
+ )
1108
+
1109
+ for dtype in np_scalar_types:
1110
+ add_function_test_register_kernel(
1111
+ TestArithmetic, f"test_clamp_{dtype.__name__}", test_clamp, devices=devices, dtype=dtype
1112
+ )
1113
+ add_function_test_register_kernel(
1114
+ TestArithmetic, f"test_nonzero_{dtype.__name__}", test_nonzero, devices=devices, dtype=dtype
1115
+ )
1116
+ add_function_test(TestArithmetic, f"test_arrays_{dtype.__name__}", test_arrays, devices=devices, dtype=dtype)
1117
+ add_function_test_register_kernel(
1118
+ TestArithmetic, f"test_binary_ops_{dtype.__name__}", test_binary_ops, devices=devices, dtype=dtype
1119
+ )
1066
1120
 
1067
1121
 
1068
1122
  if __name__ == "__main__":
1069
- c = register(unittest.TestCase)
1123
+ wp.build.clear_kernel_cache()
1070
1124
  unittest.main(verbosity=2, failfast=False)