warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/tests/test_rand.py CHANGED
@@ -5,12 +5,15 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
+ import unittest
9
+
8
10
  import numpy as np
9
11
 
12
+ import warp as wp
13
+ from warp.tests.unittest_utils import *
14
+
10
15
  # import matplotlib.pyplot as plt
11
16
 
12
- import warp as wp
13
- from warp.tests.test_base import *
14
17
 
15
18
  wp.init()
16
19
 
@@ -60,7 +63,7 @@ def test_rand(test, device):
60
63
  wp.copy(int_ab_host, int_ab_device)
61
64
  wp.copy(float_01_host, float_01_device)
62
65
  wp.copy(float_ab_host, float_ab_device)
63
- wp.synchronize()
66
+ wp.synchronize_device(device)
64
67
 
65
68
  int_a = int_a_host.numpy()
66
69
  int_ab = int_ab_host.numpy()
@@ -243,10 +246,10 @@ def test_poisson(test, device):
243
246
  # _ = plt.hist(poisson_high.numpy(), bins)
244
247
  # plt.show()
245
248
 
246
- np.random.default_rng(seed)
249
+ rng = np.random.default_rng(seed)
247
250
 
248
- np_poisson_low = np.random.poisson(3.0, N)
249
- np_poisson_high = np.random.poisson(42.0, N)
251
+ np_poisson_low = rng.poisson(lam=3.0, size=N)
252
+ np_poisson_high = rng.poisson(lam=42.0, size=N)
250
253
 
251
254
  poisson_low_mean = np.mean(poisson_low.numpy())
252
255
  np_poisson_low_mean = np.mean(np_poisson_low)
@@ -267,20 +270,19 @@ def test_poisson(test, device):
267
270
  test.assertTrue(np.abs(poisson_high_std - np_poisson_high_std) <= 2e-1)
268
271
 
269
272
 
270
- def register(parent):
271
- devices = get_test_devices()
273
+ devices = get_test_devices()
274
+
272
275
 
273
- class TestNoise(parent):
274
- pass
276
+ class TestRand(unittest.TestCase):
277
+ pass
275
278
 
276
- add_function_test(TestNoise, "test_rand", test_rand, devices=devices)
277
- add_function_test(TestNoise, "test_sample_cdf", test_sample_cdf, devices=devices)
278
- add_function_test(TestNoise, "test_sampling_methods", test_sampling_methods, devices=devices)
279
- add_function_test(TestNoise, "test_poisson", test_poisson, devices=devices)
280
279
 
281
- return TestNoise
280
+ add_function_test(TestRand, "test_rand", test_rand, devices=devices)
281
+ add_function_test(TestRand, "test_sample_cdf", test_sample_cdf, devices=devices)
282
+ add_function_test(TestRand, "test_sampling_methods", test_sampling_methods, devices=devices)
283
+ add_function_test(TestRand, "test_poisson", test_poisson, devices=devices)
282
284
 
283
285
 
284
286
  if __name__ == "__main__":
285
- c = register(unittest.TestCase)
287
+ wp.build.clear_kernel_cache()
286
288
  unittest.main(verbosity=2)
warp/tests/test_reload.py CHANGED
@@ -5,29 +5,34 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- import numpy as np
9
- import warp as wp
8
+ import importlib
9
+ import os
10
+ import unittest
10
11
 
11
- import math
12
+ import numpy as np
12
13
 
13
14
  import warp as wp
14
- from warp.tests.test_base import *
15
15
 
16
- import unittest
17
- import importlib
18
- import os
16
+ # dummy modules used for testing reload with dependencies
17
+ import warp.tests.aux_test_dependent as test_dependent
18
+ import warp.tests.aux_test_reference as test_reference
19
+ import warp.tests.aux_test_reference_reference as test_reference_reference
19
20
 
20
21
  # dummy module used for testing reload
21
- import warp.tests.test_square as test_square
22
-
23
- # dummy modules used for testing reload with dependencies
24
- import warp.tests.test_dependent as test_dependent
25
- import warp.tests.test_reference as test_reference
26
- import warp.tests.test_reference_reference as test_reference_reference
22
+ import warp.tests.aux_test_square as test_square
23
+ from warp.tests.unittest_utils import *
27
24
 
28
25
  wp.init()
29
26
 
30
27
 
28
+ def reload_module(module):
29
+ # Clearing the .pyc file associated with a module is a necessary workaround
30
+ # for `importlib.reload` to work as expected when run from within Kit.
31
+ cache_file = importlib.util.cache_from_source(module.__file__)
32
+ os.remove(cache_file)
33
+ importlib.reload(module)
34
+
35
+
31
36
  def test_redefine(test, device):
32
37
  # --------------------------------------------
33
38
  # first pass
@@ -102,32 +107,32 @@ def run(expect, device):
102
107
 
103
108
  def test_reload(test, device):
104
109
  # write out the module python and import it
105
- f = open(os.path.abspath(os.path.join(os.path.dirname(__file__), "test_square.py")), "w")
110
+ f = open(os.path.abspath(os.path.join(os.path.dirname(__file__), "aux_test_square.py")), "w")
106
111
  f.writelines(square_two)
107
112
  f.flush()
108
113
  f.close()
109
114
 
110
- importlib.reload(test_square)
115
+ reload_module(test_square)
111
116
  test_square.run(expect=4.0, device=device) # 2*2=4
112
117
 
113
- f = open(os.path.abspath(os.path.join(os.path.dirname(__file__), "test_square.py")), "w")
118
+ f = open(os.path.abspath(os.path.join(os.path.dirname(__file__), "aux_test_square.py")), "w")
114
119
  f.writelines(square_four)
115
120
  f.flush()
116
121
  f.close()
117
122
 
118
123
  # reload module, this should trigger all of the funcs / kernels to be updated
119
- importlib.reload(test_square)
124
+ reload_module(test_square)
120
125
  test_square.run(expect=16.0, device=device) # 4*4 = 16
121
126
 
122
127
 
123
128
  def test_reload_class(test, device):
124
129
  def test_func():
125
- import warp.tests.test_class_kernel
126
- from warp.tests.test_class_kernel import ClassKernelTest
127
-
128
130
  import importlib as imp
129
131
 
130
- imp.reload(warp.tests.test_class_kernel)
132
+ import warp.tests.aux_test_class_kernel
133
+ from warp.tests.aux_test_class_kernel import ClassKernelTest
134
+
135
+ imp.reload(warp.tests.aux_test_class_kernel)
131
136
 
132
137
  ctest = ClassKernelTest(device)
133
138
  expected = np.zeros((10, 3, 3), dtype=np.float32)
@@ -141,7 +146,7 @@ def test_reload_class(test, device):
141
146
  template_ref = """# This file is used to test reloading module references.
142
147
 
143
148
  import warp as wp
144
- import warp.tests.test_reference_reference as refref
149
+ import warp.tests.aux_test_reference_reference as refref
145
150
 
146
151
  wp.init()
147
152
 
@@ -165,8 +170,8 @@ def more_magic():
165
170
 
166
171
 
167
172
  def test_reload_references(test, device):
168
- path_ref = os.path.abspath(os.path.join(os.path.dirname(__file__), "test_reference.py"))
169
- path_refref = os.path.abspath(os.path.join(os.path.dirname(__file__), "test_reference_reference.py"))
173
+ path_ref = os.path.abspath(os.path.join(os.path.dirname(__file__), "aux_test_reference.py"))
174
+ path_refref = os.path.abspath(os.path.join(os.path.dirname(__file__), "aux_test_reference_reference.py"))
170
175
 
171
176
  # rewrite both dependency modules and reload them
172
177
  with open(path_ref, "w") as f:
@@ -194,20 +199,19 @@ def test_reload_references(test, device):
194
199
  test_dependent.run(expect=4.0, device=device) # 2 * 2 = 4
195
200
 
196
201
 
197
- def register(parent):
198
- devices = get_test_devices()
202
+ devices = get_test_devices()
203
+
199
204
 
200
- class TestReload(parent):
201
- pass
205
+ class TestReload(unittest.TestCase):
206
+ pass
202
207
 
203
- add_function_test(TestReload, "test_redefine", test_redefine, devices=devices)
204
- add_function_test(TestReload, "test_reload", test_reload, devices=devices)
205
- add_function_test(TestReload, "test_reload_class", test_reload_class, devices=devices)
206
- add_function_test(TestReload, "test_reload_references", test_reload_references, devices=devices)
207
208
 
208
- return TestReload
209
+ add_function_test(TestReload, "test_redefine", test_redefine, devices=devices)
210
+ add_function_test(TestReload, "test_reload", test_reload, devices=devices)
211
+ add_function_test(TestReload, "test_reload_class", test_reload_class, devices=devices)
212
+ add_function_test(TestReload, "test_reload_references", test_reload_references, devices=devices)
209
213
 
210
214
 
211
215
  if __name__ == "__main__":
212
- c = register(unittest.TestCase)
216
+ wp.build.clear_kernel_cache()
213
217
  unittest.main(verbosity=2, failfast=False)
warp/tests/test_rounding.py CHANGED
@@ -5,11 +5,13 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- import warp as wp
9
- from warp.tests.test_base import *
8
+ import unittest
10
9
 
11
10
  import numpy as np
12
11
 
12
+ import warp as wp
13
+ from warp.tests.unittest_utils import *
14
+
13
15
  compare_to_numpy = False
14
16
  print_results = False
15
17
 
@@ -25,6 +27,7 @@ def test_kernel(
25
27
  x_cast: wp.array(dtype=float),
26
28
  x_floor: wp.array(dtype=float),
27
29
  x_ceil: wp.array(dtype=float),
30
+ x_frac: wp.array(dtype=float),
28
31
  ):
29
32
  tid = wp.tid()
30
33
 
@@ -34,6 +37,7 @@ def test_kernel(
34
37
  x_cast[tid] = float(int(x[tid]))
35
38
  x_floor[tid] = wp.floor(x[tid])
36
39
  x_ceil[tid] = wp.ceil(x[tid])
40
+ x_frac[tid] = wp.frac(x[tid])
37
41
 
38
42
 
39
43
  def test_rounding(test, device):
@@ -82,8 +86,11 @@ def test_rounding(test, device):
82
86
  x_cast = wp.empty(N, dtype=float, device=device)
83
87
  x_floor = wp.empty(N, dtype=float, device=device)
84
88
  x_ceil = wp.empty(N, dtype=float, device=device)
89
+ x_frac = wp.empty(N, dtype=float, device=device)
85
90
 
86
- wp.launch(kernel=test_kernel, dim=N, inputs=[x, x_round, x_rint, x_trunc, x_cast, x_floor, x_ceil], device=device)
91
+ wp.launch(
92
+ kernel=test_kernel, dim=N, inputs=[x, x_round, x_rint, x_trunc, x_cast, x_floor, x_ceil, x_frac], device=device
93
+ )
87
94
 
88
95
  wp.synchronize()
89
96
 
@@ -93,46 +100,47 @@ def test_rounding(test, device):
93
100
  nx_cast = x_cast.numpy().reshape(N)
94
101
  nx_floor = x_floor.numpy().reshape(N)
95
102
  nx_ceil = x_ceil.numpy().reshape(N)
103
+ nx_frac = x_frac.numpy().reshape(N)
96
104
 
97
- tab = np.stack([nx, nx_round, nx_rint, nx_trunc, nx_cast, nx_floor, nx_ceil], axis=1)
105
+ tab = np.stack([nx, nx_round, nx_rint, nx_trunc, nx_cast, nx_floor, nx_ceil, nx_frac], axis=1)
98
106
 
99
107
  golden = np.array(
100
108
  [
101
- [4.9, 5.0, 5.0, 4.0, 4.0, 4.0, 5.0],
102
- [4.5, 5.0, 4.0, 4.0, 4.0, 4.0, 5.0],
103
- [4.1, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0],
104
- [3.9, 4.0, 4.0, 3.0, 3.0, 3.0, 4.0],
105
- [3.5, 4.0, 4.0, 3.0, 3.0, 3.0, 4.0],
106
- [3.1, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0],
107
- [2.9, 3.0, 3.0, 2.0, 2.0, 2.0, 3.0],
108
- [2.5, 3.0, 2.0, 2.0, 2.0, 2.0, 3.0],
109
- [2.1, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0],
110
- [1.9, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0],
111
- [1.5, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0],
112
- [1.1, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0],
113
- [0.9, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0],
114
- [0.5, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0],
115
- [0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
116
- [-0.1, -0.0, -0.0, -0.0, 0.0, -1.0, -0.0],
117
- [-0.5, -1.0, -0.0, -0.0, 0.0, -1.0, -0.0],
118
- [-0.9, -1.0, -1.0, -0.0, 0.0, -1.0, -0.0],
119
- [-1.1, -1.0, -1.0, -1.0, -1.0, -2.0, -1.0],
120
- [-1.5, -2.0, -2.0, -1.0, -1.0, -2.0, -1.0],
121
- [-1.9, -2.0, -2.0, -1.0, -1.0, -2.0, -1.0],
122
- [-2.1, -2.0, -2.0, -2.0, -2.0, -3.0, -2.0],
123
- [-2.5, -3.0, -2.0, -2.0, -2.0, -3.0, -2.0],
124
- [-2.9, -3.0, -3.0, -2.0, -2.0, -3.0, -2.0],
125
- [-3.1, -3.0, -3.0, -3.0, -3.0, -4.0, -3.0],
126
- [-3.5, -4.0, -4.0, -3.0, -3.0, -4.0, -3.0],
127
- [-3.9, -4.0, -4.0, -3.0, -3.0, -4.0, -3.0],
128
- [-4.1, -4.0, -4.0, -4.0, -4.0, -5.0, -4.0],
129
- [-4.5, -5.0, -4.0, -4.0, -4.0, -5.0, -4.0],
130
- [-4.9, -5.0, -5.0, -4.0, -4.0, -5.0, -4.0],
109
+ [4.9, 5.0, 5.0, 4.0, 4.0, 4.0, 5.0, 0.9],
110
+ [4.5, 5.0, 4.0, 4.0, 4.0, 4.0, 5.0, 0.5],
111
+ [4.1, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 0.1],
112
+ [3.9, 4.0, 4.0, 3.0, 3.0, 3.0, 4.0, 0.9],
113
+ [3.5, 4.0, 4.0, 3.0, 3.0, 3.0, 4.0, 0.5],
114
+ [3.1, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 0.1],
115
+ [2.9, 3.0, 3.0, 2.0, 2.0, 2.0, 3.0, 0.9],
116
+ [2.5, 3.0, 2.0, 2.0, 2.0, 2.0, 3.0, 0.5],
117
+ [2.1, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 0.1],
118
+ [1.9, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 0.9],
119
+ [1.5, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 0.5],
120
+ [1.1, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 0.1],
121
+ [0.9, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.9],
122
+ [0.5, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.5],
123
+ [0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.1],
124
+ [-0.1, -0.0, -0.0, -0.0, 0.0, -1.0, -0.0, -0.1],
125
+ [-0.5, -1.0, -0.0, -0.0, 0.0, -1.0, -0.0, -0.5],
126
+ [-0.9, -1.0, -1.0, -0.0, 0.0, -1.0, -0.0, -0.9],
127
+ [-1.1, -1.0, -1.0, -1.0, -1.0, -2.0, -1.0, -0.1],
128
+ [-1.5, -2.0, -2.0, -1.0, -1.0, -2.0, -1.0, -0.5],
129
+ [-1.9, -2.0, -2.0, -1.0, -1.0, -2.0, -1.0, -0.9],
130
+ [-2.1, -2.0, -2.0, -2.0, -2.0, -3.0, -2.0, -0.1],
131
+ [-2.5, -3.0, -2.0, -2.0, -2.0, -3.0, -2.0, -0.5],
132
+ [-2.9, -3.0, -3.0, -2.0, -2.0, -3.0, -2.0, -0.9],
133
+ [-3.1, -3.0, -3.0, -3.0, -3.0, -4.0, -3.0, -0.1],
134
+ [-3.5, -4.0, -4.0, -3.0, -3.0, -4.0, -3.0, -0.5],
135
+ [-3.9, -4.0, -4.0, -3.0, -3.0, -4.0, -3.0, -0.9],
136
+ [-4.1, -4.0, -4.0, -4.0, -4.0, -5.0, -4.0, -0.1],
137
+ [-4.5, -5.0, -4.0, -4.0, -4.0, -5.0, -4.0, -0.5],
138
+ [-4.9, -5.0, -5.0, -4.0, -4.0, -5.0, -4.0, -0.9],
131
139
  ],
132
140
  dtype=np.float32,
133
141
  )
134
142
 
135
- assert_np_equal(tab, golden)
143
+ assert_np_equal(tab, golden, tol=1e-6)
136
144
 
137
145
  if print_results:
138
146
  np.set_printoptions(formatter={"float": lambda x: "{:6.1f}".format(x).replace(".0", ".")})
@@ -149,24 +157,23 @@ def test_rounding(test, device):
149
157
  nx_fix = np.fix(nx)
150
158
  nx_floor = np.floor(nx)
151
159
  nx_ceil = np.ceil(nx)
160
+ nx_frac = np.modf(nx)[0]
152
161
 
153
- tab = np.stack([nx, nx_round, nx_rint, nx_trunc, nx_fix, nx_floor, nx_ceil], axis=1)
162
+ tab = np.stack([nx, nx_round, nx_rint, nx_trunc, nx_fix, nx_floor, nx_ceil, nx_frac], axis=1)
154
163
  print(" %5s %5s %5s %5s %5s %5s %5s" % ("x ", "round", "rint", "trunc", "fix", "floor", "ceil"))
155
164
  print(tab)
156
165
  print("----------------------------------------------")
157
166
 
158
167
 
159
- def register(parent):
160
- class TestRounding(parent):
161
- pass
168
+ class TestRounding(unittest.TestCase):
169
+ pass
162
170
 
163
- devices = get_test_devices()
164
171
 
165
- add_function_test(TestRounding, "test_rounding", test_rounding, devices=devices)
172
+ devices = get_test_devices()
166
173
 
167
- return TestRounding
174
+ add_function_test(TestRounding, "test_rounding", test_rounding, devices=devices)
168
175
 
169
176
 
170
177
  if __name__ == "__main__":
171
- c = register(unittest.TestCase)
178
+ wp.build.clear_kernel_cache()
172
179
  unittest.main(verbosity=2)
@@ -0,0 +1,190 @@
1
+ # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import unittest
9
+ from functools import partial
10
+
11
+ import numpy as np
12
+
13
+ import warp as wp
14
+ from warp.tests.unittest_utils import *
15
+ from warp.utils import runlength_encode
16
+
17
+ wp.init()
18
+
19
+
20
+ def test_runlength_encode_int(test, device, n):
21
+ rng = np.random.default_rng(123)
22
+
23
+ values_np = np.sort(rng.integers(-10, high=10, size=n, dtype=int))
24
+
25
+ unique_values_np, unique_counts_np = np.unique(values_np, return_counts=True)
26
+
27
+ values = wp.array(values_np, device=device, dtype=int)
28
+
29
+ unique_values = wp.empty_like(values)
30
+ unique_counts = wp.empty_like(values)
31
+
32
+ run_count = runlength_encode(values, unique_values, unique_counts)
33
+
34
+ test.assertEqual(run_count, len(unique_values_np))
35
+ assert_np_equal(unique_values.numpy()[:run_count], unique_values_np[:run_count])
36
+ assert_np_equal(unique_counts.numpy()[:run_count], unique_counts_np[:run_count])
37
+
38
+
39
+ def test_runlength_encode_error_insufficient_storage(test, device):
40
+ values = wp.zeros(123, dtype=int, device=device)
41
+ run_values = wp.empty(1, dtype=int, device=device)
42
+ run_lengths = wp.empty(123, dtype=int, device=device)
43
+ with test.assertRaisesRegex(
44
+ RuntimeError,
45
+ r"Output array storage sizes must be at least equal to value_count$",
46
+ ):
47
+ runlength_encode(values, run_values, run_lengths)
48
+
49
+ values = wp.zeros(123, dtype=int, device="cpu")
50
+ run_values = wp.empty(123, dtype=int, device="cpu")
51
+ run_lengths = wp.empty(1, dtype=int, device="cpu")
52
+ with test.assertRaisesRegex(
53
+ RuntimeError,
54
+ r"Output array storage sizes must be at least equal to value_count$",
55
+ ):
56
+ runlength_encode(values, run_values, run_lengths)
57
+
58
+
59
+ def test_runlength_encode_error_dtypes_mismatch(test, device):
60
+ values = wp.zeros(123, dtype=int, device=device)
61
+ run_values = wp.empty(123, dtype=float, device=device)
62
+ run_lengths = wp.empty_like(values, device=device)
63
+ with test.assertRaisesRegex(
64
+ RuntimeError,
65
+ r"values and run_values data types do not match$",
66
+ ):
67
+ runlength_encode(values, run_values, run_lengths)
68
+
69
+
70
+ def test_runlength_encode_error_run_length_unsupported_dtype(test, device):
71
+ values = wp.zeros(123, dtype=int, device=device)
72
+ run_values = wp.empty(123, dtype=int, device=device)
73
+ run_lengths = wp.empty(123, dtype=float, device=device)
74
+ with test.assertRaisesRegex(
75
+ RuntimeError,
76
+ r"run_lengths array must be of type int32$",
77
+ ):
78
+ runlength_encode(values, run_values, run_lengths)
79
+
80
+
81
+ def test_runlength_encode_error_run_count_unsupported_dtype(test, device):
82
+ values = wp.zeros(123, dtype=int, device=device)
83
+ run_values = wp.empty_like(values, device=device)
84
+ run_lengths = wp.empty_like(values, device=device)
85
+ run_count = wp.empty(shape=(1,), dtype=float, device=device)
86
+ with test.assertRaisesRegex(
87
+ RuntimeError,
88
+ r"run_count array must be of type int32$",
89
+ ):
90
+ runlength_encode(values, run_values, run_lengths, run_count=run_count)
91
+
92
+
93
+ def test_runlength_encode_error_unsupported_dtype(test, device):
94
+ values = wp.zeros(123, dtype=float, device=device)
95
+ run_values = wp.empty(123, dtype=float, device=device)
96
+ run_lengths = wp.empty(123, dtype=int, device=device)
97
+ with test.assertRaisesRegex(
98
+ RuntimeError,
99
+ r"Unsupported data type$",
100
+ ):
101
+ runlength_encode(values, run_values, run_lengths)
102
+
103
+
104
+ devices = get_test_devices()
105
+
106
+
107
+ class TestRunlengthEncode(unittest.TestCase):
108
+ @unittest.skipUnless(wp.is_cuda_available(), "Requires CUDA")
109
+ def test_runlength_encode_error_devices_mismatch(self):
110
+ values = wp.zeros(123, dtype=int, device="cpu")
111
+ run_values = wp.empty_like(values, device="cuda:0")
112
+ run_lengths = wp.empty_like(values, device="cuda:0")
113
+ with self.assertRaisesRegex(
114
+ RuntimeError,
115
+ r"Array storage devices do not match$",
116
+ ):
117
+ runlength_encode(values, run_values, run_lengths)
118
+
119
+ values = wp.zeros(123, dtype=int, device="cpu")
120
+ run_values = wp.empty_like(values, device="cpu")
121
+ run_lengths = wp.empty_like(values, device="cuda:0")
122
+ with self.assertRaisesRegex(
123
+ RuntimeError,
124
+ r"Array storage devices do not match$",
125
+ ):
126
+ runlength_encode(values, run_values, run_lengths)
127
+
128
+ values = wp.zeros(123, dtype=int, device="cpu")
129
+ run_values = wp.empty_like(values, device="cuda:0")
130
+ run_lengths = wp.empty_like(values, device="cpu")
131
+ with self.assertRaisesRegex(
132
+ RuntimeError,
133
+ r"Array storage devices do not match$",
134
+ ):
135
+ runlength_encode(values, run_values, run_lengths)
136
+
137
+ @unittest.skipUnless(wp.is_cuda_available(), "Requires CUDA")
138
+ def test_runlength_encode_error_run_count_device_mismatch(self):
139
+ values = wp.zeros(123, dtype=int, device="cpu")
140
+ run_values = wp.empty_like(values, device="cpu")
141
+ run_lengths = wp.empty_like(values, device="cpu")
142
+ run_count = wp.empty(shape=(1,), dtype=int, device="cuda:0")
143
+ with self.assertRaisesRegex(
144
+ RuntimeError,
145
+ r"run_count storage device does not match other arrays$",
146
+ ):
147
+ runlength_encode(values, run_values, run_lengths, run_count=run_count)
148
+
149
+
150
+ add_function_test(
151
+ TestRunlengthEncode, "test_runlength_encode_int", partial(test_runlength_encode_int, n=100), devices=devices
152
+ )
153
+ add_function_test(
154
+ TestRunlengthEncode, "test_runlength_encode_empty", partial(test_runlength_encode_int, n=0), devices=devices
155
+ )
156
+ add_function_test(
157
+ TestRunlengthEncode,
158
+ "test_runlength_encode_error_insufficient_storage",
159
+ test_runlength_encode_error_insufficient_storage,
160
+ devices=devices,
161
+ )
162
+ add_function_test(
163
+ TestRunlengthEncode,
164
+ "test_runlength_encode_error_dtypes_mismatch",
165
+ test_runlength_encode_error_dtypes_mismatch,
166
+ devices=devices,
167
+ )
168
+ add_function_test(
169
+ TestRunlengthEncode,
170
+ "test_runlength_encode_error_run_length_unsupported_dtype",
171
+ test_runlength_encode_error_run_length_unsupported_dtype,
172
+ devices=devices,
173
+ )
174
+ add_function_test(
175
+ TestRunlengthEncode,
176
+ "test_runlength_encode_error_run_count_unsupported_dtype",
177
+ test_runlength_encode_error_run_count_unsupported_dtype,
178
+ devices=devices,
179
+ )
180
+ add_function_test(
181
+ TestRunlengthEncode,
182
+ "test_runlength_encode_error_unsupported_dtype",
183
+ test_runlength_encode_error_unsupported_dtype,
184
+ devices=devices,
185
+ )
186
+
187
+
188
+ if __name__ == "__main__":
189
+ wp.build.clear_kernel_cache()
190
+ unittest.main(verbosity=2)
@@ -5,14 +5,14 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
+ import unittest
8
9
  from dataclasses import dataclass
9
10
  from typing import Any
10
- import unittest
11
11
 
12
12
  import numpy as np
13
13
 
14
14
  import warp as wp
15
- from warp.tests.test_base import *
15
+ from warp.tests.unittest_utils import *
16
16
 
17
17
 
18
18
  @dataclass
@@ -87,11 +87,9 @@ def test_smoothstep(test, device):
87
87
 
88
88
  for data_type in TEST_DATA:
89
89
  kernel_fn = make_kernel_fn(data_type)
90
- module = wp.get_module(kernel_fn.__module__)
91
90
  kernel = wp.Kernel(
92
91
  func=kernel_fn,
93
92
  key=f"test_smoothstep{data_type.__name__}_kernel",
94
- module=module,
95
93
  )
96
94
 
97
95
  for test_data in TEST_DATA[data_type]:
@@ -155,16 +153,16 @@ def test_smoothstep(test, device):
155
153
  )
156
154
 
157
155
 
158
- def register(parent):
159
- devices = get_test_devices()
156
+ devices = get_test_devices()
157
+
158
+
159
+ class TestSmoothstep(unittest.TestCase):
160
+ pass
160
161
 
161
- class TestSmoothstep(parent):
162
- pass
163
162
 
164
- add_function_test(TestSmoothstep, "test_smoothstep", test_smoothstep, devices=devices)
165
- return TestSmoothstep
163
+ add_function_test(TestSmoothstep, "test_smoothstep", test_smoothstep, devices=devices)
166
164
 
167
165
 
168
166
  if __name__ == "__main__":
169
- _ = register(unittest.TestCase)
167
+ wp.build.clear_kernel_cache()
170
168
  unittest.main(verbosity=2)