warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315) hide show
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -5,11 +5,12 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
+ import unittest
9
+
8
10
  import numpy as np
9
- import warp as wp
10
- from warp.tests.test_base import *
11
11
 
12
- import unittest
12
+ import warp as wp
13
+ from warp.tests.unittest_utils import *
13
14
 
14
15
  wp.init()
15
16
 
@@ -45,6 +46,11 @@ def test_stream_arg_implicit_sync(test, device):
45
46
 
46
47
  new_stream = wp.Stream(device)
47
48
 
49
+ # Exercise code path
50
+ wp.set_stream(new_stream, device)
51
+
52
+ test.assertTrue(wp.get_device(device).has_stream)
53
+
48
54
  # launch work on new stream
49
55
  wp.launch(inc, dim=a.size, inputs=[a], stream=new_stream)
50
56
  wp.copy(b, a, stream=new_stream)
@@ -278,119 +284,138 @@ def test_stream_scope_wait_stream(test, device):
278
284
  assert_np_equal(d.numpy(), np.full(N, fill_value=4.0))
279
285
 
280
286
 
281
- def test_stream_arg_graph_mgpu(test, device):
282
- # resources on GPU 0
283
- stream0 = wp.get_stream("cuda:0")
284
- a0 = wp.zeros(N, dtype=float, device="cuda:0")
285
- b0 = wp.empty(N, dtype=float, device="cuda:0")
286
- c0 = wp.empty(N, dtype=float, device="cuda:0")
287
+ devices = get_unique_cuda_test_devices()
287
288
 
288
- # resources on GPU 1
289
- stream1 = wp.get_stream("cuda:1")
290
- a1 = wp.zeros(N, dtype=float, device="cuda:1")
291
289
 
292
- # start recording on stream0
293
- wp.capture_begin(stream=stream0)
290
+ class TestStreams(unittest.TestCase):
291
+ def test_stream_exceptions(self):
292
+ cpu_device = wp.get_device("cpu")
294
293
 
295
- # branch into stream1
296
- stream1.wait_stream(stream0)
294
+ # Can't set the stream on a CPU device
295
+ with self.assertRaises(RuntimeError):
296
+ stream0 = wp.Stream()
297
+ cpu_device.stream = stream0
297
298
 
298
- # launch concurrent kernels on each stream
299
- wp.launch(inc, dim=N, inputs=[a0], stream=stream0)
300
- wp.launch(inc, dim=N, inputs=[a1], stream=stream1)
299
+ # Can't create a stream on the CPU
300
+ with self.assertRaises(RuntimeError):
301
+ wp.Stream(device="cpu")
301
302
 
302
- # wait for stream1 to finish
303
- stream0.wait_stream(stream1)
303
+ # Can't create an event with CPU device
304
+ with self.assertRaises(RuntimeError):
305
+ wp.Event(device=cpu_device)
304
306
 
305
- # copy values from stream1
306
- wp.copy(b0, a1, stream=stream0)
307
+ # Can't get the stream on a CPU device
308
+ with self.assertRaises(RuntimeError):
309
+ cpu_stream = cpu_device.stream # noqa: F841
307
310
 
308
- # compute sum
309
- wp.launch(sum, dim=N, inputs=[a0, b0, c0], stream=stream0)
311
+ @unittest.skipUnless(len(wp.get_cuda_devices()) > 1, "Requires at least two CUDA devices")
312
+ def test_stream_arg_graph_mgpu(self):
313
+ wp.load_module(device="cuda:0")
314
+ wp.load_module(device="cuda:1")
310
315
 
311
- # finish recording on stream0
312
- g = wp.capture_end(stream=stream0)
316
+ # resources on GPU 0
317
+ stream0 = wp.get_stream("cuda:0")
318
+ a0 = wp.zeros(N, dtype=float, device="cuda:0")
319
+ b0 = wp.empty(N, dtype=float, device="cuda:0")
320
+ c0 = wp.empty(N, dtype=float, device="cuda:0")
313
321
 
314
- # replay
315
- num_iters = 10
316
- for _ in range(num_iters):
317
- wp.capture_launch(g, stream=stream0)
322
+ # resources on GPU 1
323
+ stream1 = wp.get_stream("cuda:1")
324
+ a1 = wp.zeros(N, dtype=float, device="cuda:1")
318
325
 
319
- # check results
320
- assert_np_equal(c0.numpy(), np.full(N, fill_value=2 * num_iters))
326
+ # start recording on stream0
327
+ wp.capture_begin(stream=stream0, force_module_load=False)
328
+ try:
329
+ # branch into stream1
330
+ stream1.wait_stream(stream0)
321
331
 
332
+ # launch concurrent kernels on each stream
333
+ wp.launch(inc, dim=N, inputs=[a0], stream=stream0)
334
+ wp.launch(inc, dim=N, inputs=[a1], stream=stream1)
322
335
 
323
- def test_stream_scope_graph_mgpu(test, device):
324
- # resources on GPU 0
325
- with wp.ScopedDevice("cuda:0"):
326
- stream0 = wp.get_stream()
327
- a0 = wp.zeros(N, dtype=float)
328
- b0 = wp.empty(N, dtype=float)
329
- c0 = wp.empty(N, dtype=float)
336
+ # wait for stream1 to finish
337
+ stream0.wait_stream(stream1)
330
338
 
331
- # resources on GPU 1
332
- with wp.ScopedDevice("cuda:1"):
333
- stream1 = wp.get_stream()
334
- a1 = wp.zeros(N, dtype=float)
339
+ # copy values from stream1
340
+ wp.copy(b0, a1, stream=stream0)
335
341
 
336
- # capture graph
337
- with wp.ScopedDevice("cuda:0"):
338
- # start recording
339
- wp.capture_begin()
342
+ # compute sum
343
+ wp.launch(sum, dim=N, inputs=[a0, b0, c0], stream=stream0)
344
+ finally:
345
+ # finish recording on stream0
346
+ g = wp.capture_end(stream=stream0)
340
347
 
341
- with wp.ScopedDevice("cuda:1"):
342
- # branch into stream1
343
- wp.wait_stream(stream0)
348
+ # replay
349
+ num_iters = 10
350
+ for _ in range(num_iters):
351
+ wp.capture_launch(g, stream=stream0)
344
352
 
345
- wp.launch(inc, dim=N, inputs=[a1])
353
+ # check results
354
+ assert_np_equal(c0.numpy(), np.full(N, fill_value=2 * num_iters))
346
355
 
347
- wp.launch(inc, dim=N, inputs=[a0])
356
+ @unittest.skipUnless(len(wp.get_cuda_devices()) > 1, "Requires at least two CUDA devices")
357
+ def test_stream_scope_graph_mgpu(self):
358
+ wp.load_module(device="cuda:0")
359
+ wp.load_module(device="cuda:1")
348
360
 
349
- # wait for stream1 to finish
350
- wp.wait_stream(stream1)
361
+ # resources on GPU 0
362
+ with wp.ScopedDevice("cuda:0"):
363
+ stream0 = wp.get_stream()
364
+ a0 = wp.zeros(N, dtype=float)
365
+ b0 = wp.empty(N, dtype=float)
366
+ c0 = wp.empty(N, dtype=float)
351
367
 
352
- # copy values from stream1
353
- wp.copy(b0, a1)
368
+ # resources on GPU 1
369
+ with wp.ScopedDevice("cuda:1"):
370
+ stream1 = wp.get_stream()
371
+ a1 = wp.zeros(N, dtype=float)
354
372
 
355
- # compute sum
356
- wp.launch(sum, dim=N, inputs=[a0, b0, c0])
373
+ # capture graph
374
+ with wp.ScopedDevice("cuda:0"):
375
+ # start recording
376
+ wp.capture_begin(force_module_load=False)
377
+ try:
378
+ with wp.ScopedDevice("cuda:1"):
379
+ # branch into stream1
380
+ wp.wait_stream(stream0)
357
381
 
358
- # finish recording
359
- g = wp.capture_end()
382
+ wp.launch(inc, dim=N, inputs=[a1])
360
383
 
361
- # replay
362
- with wp.ScopedDevice("cuda:0"):
363
- num_iters = 10
364
- for _ in range(num_iters):
365
- wp.capture_launch(g)
384
+ wp.launch(inc, dim=N, inputs=[a0])
366
385
 
367
- # check results
368
- assert_np_equal(c0.numpy(), np.full(N, fill_value=2 * num_iters))
386
+ # wait for stream1 to finish
387
+ wp.wait_stream(stream1)
369
388
 
389
+ # copy values from stream1
390
+ wp.copy(b0, a1)
370
391
 
371
- def register(parent):
372
- devices = wp.get_cuda_devices()
392
+ # compute sum
393
+ wp.launch(sum, dim=N, inputs=[a0, b0, c0])
394
+ finally:
395
+ # finish recording
396
+ g = wp.capture_end()
373
397
 
374
- class TestStreams(parent):
375
- pass
398
+ # replay
399
+ with wp.ScopedDevice("cuda:0"):
400
+ num_iters = 10
401
+ for _ in range(num_iters):
402
+ wp.capture_launch(g)
376
403
 
377
- add_function_test(TestStreams, "test_stream_arg_implicit_sync", test_stream_arg_implicit_sync, devices=devices)
378
- add_function_test(TestStreams, "test_stream_scope_implicit_sync", test_stream_scope_implicit_sync, devices=devices)
404
+ # check results
405
+ assert_np_equal(c0.numpy(), np.full(N, fill_value=2 * num_iters))
379
406
 
380
- add_function_test(TestStreams, "test_stream_arg_synchronize", test_stream_arg_synchronize, devices=devices)
381
- add_function_test(TestStreams, "test_stream_arg_wait_event", test_stream_arg_wait_event, devices=devices)
382
- add_function_test(TestStreams, "test_stream_arg_wait_stream", test_stream_arg_wait_stream, devices=devices)
383
- add_function_test(TestStreams, "test_stream_scope_synchronize", test_stream_scope_synchronize, devices=devices)
384
- add_function_test(TestStreams, "test_stream_scope_wait_event", test_stream_scope_wait_event, devices=devices)
385
- add_function_test(TestStreams, "test_stream_scope_wait_stream", test_stream_scope_wait_stream, devices=devices)
386
407
 
387
- if len(devices) > 1:
388
- add_function_test(TestStreams, "test_stream_arg_graph_mgpu", test_stream_arg_graph_mgpu)
389
- add_function_test(TestStreams, "test_stream_scope_graph_mgpu", test_stream_scope_graph_mgpu)
408
+ add_function_test(TestStreams, "test_stream_arg_implicit_sync", test_stream_arg_implicit_sync, devices=devices)
409
+ add_function_test(TestStreams, "test_stream_scope_implicit_sync", test_stream_scope_implicit_sync, devices=devices)
390
410
 
391
- return TestStreams
411
+ add_function_test(TestStreams, "test_stream_arg_synchronize", test_stream_arg_synchronize, devices=devices)
412
+ add_function_test(TestStreams, "test_stream_arg_wait_event", test_stream_arg_wait_event, devices=devices)
413
+ add_function_test(TestStreams, "test_stream_arg_wait_stream", test_stream_arg_wait_stream, devices=devices)
414
+ add_function_test(TestStreams, "test_stream_scope_synchronize", test_stream_scope_synchronize, devices=devices)
415
+ add_function_test(TestStreams, "test_stream_scope_wait_event", test_stream_scope_wait_event, devices=devices)
416
+ add_function_test(TestStreams, "test_stream_scope_wait_stream", test_stream_scope_wait_stream, devices=devices)
392
417
 
393
418
 
394
419
  if __name__ == "__main__":
395
- c = register(unittest.TestCase)
420
+ wp.build.clear_kernel_cache()
396
421
  unittest.main(verbosity=2)