warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315) hide show
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/tests/test_torch.py CHANGED
@@ -5,13 +5,12 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- # include parent path
9
- import numpy as np
10
8
  import unittest
11
- import sys
9
+
10
+ import numpy as np
12
11
 
13
12
  import warp as wp
14
- from warp.tests.test_base import *
13
+ from warp.tests.unittest_utils import *
15
14
 
16
15
  wp.init()
17
16
 
@@ -103,7 +102,7 @@ def test_from_torch(test, device):
103
102
  wrap_scalar_tensor_implicit(torch.int16, wp.int16)
104
103
  wrap_scalar_tensor_implicit(torch.int8, wp.int8)
105
104
  wrap_scalar_tensor_implicit(torch.uint8, wp.uint8)
106
- wrap_scalar_tensor_implicit(torch.bool, wp.uint8)
105
+ wrap_scalar_tensor_implicit(torch.bool, wp.bool)
107
106
 
108
107
  # explicitly specify warp dtype
109
108
  def wrap_scalar_tensor_explicit(torch_dtype, expected_warp_dtype):
@@ -127,6 +126,7 @@ def test_from_torch(test, device):
127
126
  wrap_scalar_tensor_explicit(torch.uint8, wp.int8)
128
127
  wrap_scalar_tensor_explicit(torch.bool, wp.uint8)
129
128
  wrap_scalar_tensor_explicit(torch.bool, wp.int8)
129
+ wrap_scalar_tensor_explicit(torch.bool, wp.bool)
130
130
 
131
131
  def wrap_vec_tensor(n, desired_warp_dtype):
132
132
  t = torch.zeros((10, n), dtype=torch.float32, device=torch_device)
@@ -151,6 +151,29 @@ def test_from_torch(test, device):
151
151
  wrap_mat_tensor(4, 4, wp.mat44)
152
152
  wrap_mat_tensor(6, 6, wp.spatial_matrix)
153
153
 
154
+ def wrap_vec_tensor_with_grad(n, desired_warp_dtype):
155
+ t = torch.zeros((10, n), dtype=torch.float32, device=torch_device)
156
+ a = wp.from_torch(t, desired_warp_dtype, requires_grad=True)
157
+ assert a.dtype == desired_warp_dtype
158
+ assert a.shape == (10,)
159
+
160
+ wrap_vec_tensor_with_grad(2, wp.vec2)
161
+ wrap_vec_tensor_with_grad(3, wp.vec3)
162
+ wrap_vec_tensor_with_grad(4, wp.vec4)
163
+ wrap_vec_tensor_with_grad(6, wp.spatial_vector)
164
+ wrap_vec_tensor_with_grad(7, wp.transform)
165
+
166
+ def wrap_mat_tensor_with_grad(n, m, desired_warp_dtype):
167
+ t = torch.zeros((10, n, m), dtype=torch.float32, device=torch_device)
168
+ a = wp.from_torch(t, desired_warp_dtype, requires_grad=True)
169
+ assert a.dtype == desired_warp_dtype
170
+ assert a.shape == (10,)
171
+
172
+ wrap_mat_tensor_with_grad(2, 2, wp.mat22)
173
+ wrap_mat_tensor_with_grad(3, 3, wp.mat33)
174
+ wrap_mat_tensor_with_grad(4, 4, wp.mat44)
175
+ wrap_mat_tensor_with_grad(6, 6, wp.spatial_matrix)
176
+
154
177
 
155
178
  def test_to_torch(test, device):
156
179
  import torch
@@ -169,6 +192,7 @@ def test_to_torch(test, device):
169
192
  wrap_scalar_array(wp.int16, torch.int16)
170
193
  wrap_scalar_array(wp.int8, torch.int8)
171
194
  wrap_scalar_array(wp.uint8, torch.uint8)
195
+ wrap_scalar_array(wp.bool, torch.bool)
172
196
 
173
197
  # not supported by torch
174
198
  # wrap_scalar_array(wp.uint64, torch.int64)
@@ -445,6 +469,8 @@ def test_torch_autograd(test, device):
445
469
  def test_torch_graph_torch_stream(test, device):
446
470
  """Capture Torch graph on Torch stream"""
447
471
 
472
+ wp.load_module(device=device)
473
+
448
474
  import torch
449
475
 
450
476
  torch_device = wp.device_to_torch(device)
@@ -526,12 +552,14 @@ def test_warp_graph_warp_stream(test, device):
526
552
 
527
553
  # capture graph
528
554
  with wp.ScopedDevice(device), torch.cuda.stream(torch_stream):
529
- wp.capture_begin()
530
- t += 1.0
531
- wp.launch(inc, dim=n, inputs=[a])
532
- t += 1.0
533
- wp.launch(inc, dim=n, inputs=[a])
534
- g = wp.capture_end()
555
+ wp.capture_begin(force_module_load=False)
556
+ try:
557
+ t += 1.0
558
+ wp.launch(inc, dim=n, inputs=[a])
559
+ t += 1.0
560
+ wp.launch(inc, dim=n, inputs=[a])
561
+ finally:
562
+ g = wp.capture_end()
535
563
 
536
564
  # replay graph
537
565
  num_iters = 10
@@ -545,6 +573,8 @@ def test_warp_graph_warp_stream(test, device):
545
573
  def test_warp_graph_torch_stream(test, device):
546
574
  """Capture Warp graph on Torch stream"""
547
575
 
576
+ wp.load_module(device=device)
577
+
548
578
  import torch
549
579
 
550
580
  torch_device = wp.device_to_torch(device)
@@ -562,12 +592,14 @@ def test_warp_graph_torch_stream(test, device):
562
592
 
563
593
  # capture graph
564
594
  with wp.ScopedStream(warp_stream), torch.cuda.stream(torch_stream):
565
- wp.capture_begin()
566
- t += 1.0
567
- wp.launch(inc, dim=n, inputs=[a])
568
- t += 1.0
569
- wp.launch(inc, dim=n, inputs=[a])
570
- g = wp.capture_end()
595
+ wp.capture_begin(force_module_load=False)
596
+ try:
597
+ t += 1.0
598
+ wp.launch(inc, dim=n, inputs=[a])
599
+ t += 1.0
600
+ wp.launch(inc, dim=n, inputs=[a])
601
+ finally:
602
+ g = wp.capture_end()
571
603
 
572
604
  # replay graph
573
605
  num_iters = 10
@@ -578,82 +610,79 @@ def test_warp_graph_torch_stream(test, device):
578
610
  assert passed.item()
579
611
 
580
612
 
581
- def register(parent):
582
- class TestTorch(parent):
583
- pass
584
-
585
- try:
586
- import torch
587
-
588
- # check which Warp devices work with Torch
589
- # CUDA devices may fail if Torch was not compiled with CUDA support
590
- test_devices = get_test_devices()
591
- torch_compatible_devices = []
592
- torch_compatible_cuda_devices = []
593
-
594
- for d in test_devices:
595
- try:
596
- t = torch.arange(10, device=wp.device_to_torch(d))
597
- t += 1
598
- torch_compatible_devices.append(d)
599
- if d.is_cuda:
600
- torch_compatible_cuda_devices.append(d)
601
- except Exception as e:
602
- print(f"Skipping Torch tests on device '{d}' due to exception: {e}")
603
-
604
- if torch_compatible_devices:
605
- add_function_test(TestTorch, "test_from_torch", test_from_torch, devices=torch_compatible_devices)
606
- add_function_test(
607
- TestTorch, "test_from_torch_slices", test_from_torch_slices, devices=torch_compatible_devices
608
- )
609
- add_function_test(
610
- TestTorch,
611
- "test_from_torch_zero_strides",
612
- test_from_torch_zero_strides,
613
- devices=torch_compatible_devices,
614
- )
615
- add_function_test(TestTorch, "test_to_torch", test_to_torch, devices=torch_compatible_devices)
616
- add_function_test(TestTorch, "test_torch_zerocopy", test_torch_zerocopy, devices=torch_compatible_devices)
617
- add_function_test(TestTorch, "test_torch_autograd", test_torch_autograd, devices=torch_compatible_devices)
618
-
619
- if torch_compatible_cuda_devices:
620
- add_function_test(
621
- TestTorch,
622
- "test_torch_graph_torch_stream",
623
- test_torch_graph_torch_stream,
624
- devices=torch_compatible_cuda_devices,
625
- )
626
- add_function_test(
627
- TestTorch,
628
- "test_torch_graph_warp_stream",
629
- test_torch_graph_warp_stream,
630
- devices=torch_compatible_cuda_devices,
631
- )
632
- add_function_test(
633
- TestTorch,
634
- "test_warp_graph_warp_stream",
635
- test_warp_graph_warp_stream,
636
- devices=torch_compatible_cuda_devices,
637
- )
638
- add_function_test(
639
- TestTorch,
640
- "test_warp_graph_torch_stream",
641
- test_warp_graph_torch_stream,
642
- devices=torch_compatible_cuda_devices,
643
- )
613
+ class TestTorch(unittest.TestCase):
614
+ pass
615
+
644
616
 
645
- # multi-GPU tests
646
- if len(torch_compatible_cuda_devices) > 1:
647
- add_function_test(TestTorch, "test_torch_mgpu_from_torch", test_torch_mgpu_from_torch)
648
- add_function_test(TestTorch, "test_torch_mgpu_to_torch", test_torch_mgpu_to_torch)
649
- add_function_test(TestTorch, "test_torch_mgpu_interop", test_torch_mgpu_interop)
617
+ test_devices = get_test_devices()
650
618
 
651
- except Exception as e:
652
- print(f"Skipping Torch tests due to exception: {e}")
619
+ try:
620
+ import torch
653
621
 
654
- return TestTorch
622
+ # check which Warp devices work with Torch
623
+ # CUDA devices may fail if Torch was not compiled with CUDA support
624
+ torch_compatible_devices = []
625
+ torch_compatible_cuda_devices = []
626
+
627
+ for d in test_devices:
628
+ try:
629
+ t = torch.arange(10, device=wp.device_to_torch(d))
630
+ t += 1
631
+ torch_compatible_devices.append(d)
632
+ if d.is_cuda:
633
+ torch_compatible_cuda_devices.append(d)
634
+ except Exception as e:
635
+ print(f"Skipping Torch tests on device '{d}' due to exception: {e}")
636
+
637
+ if torch_compatible_devices:
638
+ add_function_test(TestTorch, "test_from_torch", test_from_torch, devices=torch_compatible_devices)
639
+ add_function_test(TestTorch, "test_from_torch_slices", test_from_torch_slices, devices=torch_compatible_devices)
640
+ add_function_test(
641
+ TestTorch,
642
+ "test_from_torch_zero_strides",
643
+ test_from_torch_zero_strides,
644
+ devices=torch_compatible_devices,
645
+ )
646
+ add_function_test(TestTorch, "test_to_torch", test_to_torch, devices=torch_compatible_devices)
647
+ add_function_test(TestTorch, "test_torch_zerocopy", test_torch_zerocopy, devices=torch_compatible_devices)
648
+ add_function_test(TestTorch, "test_torch_autograd", test_torch_autograd, devices=torch_compatible_devices)
649
+
650
+ if torch_compatible_cuda_devices:
651
+ add_function_test(
652
+ TestTorch,
653
+ "test_torch_graph_torch_stream",
654
+ test_torch_graph_torch_stream,
655
+ devices=torch_compatible_cuda_devices,
656
+ )
657
+ add_function_test(
658
+ TestTorch,
659
+ "test_torch_graph_warp_stream",
660
+ test_torch_graph_warp_stream,
661
+ devices=torch_compatible_cuda_devices,
662
+ )
663
+ add_function_test(
664
+ TestTorch,
665
+ "test_warp_graph_warp_stream",
666
+ test_warp_graph_warp_stream,
667
+ devices=torch_compatible_cuda_devices,
668
+ )
669
+ add_function_test(
670
+ TestTorch,
671
+ "test_warp_graph_torch_stream",
672
+ test_warp_graph_torch_stream,
673
+ devices=torch_compatible_cuda_devices,
674
+ )
675
+
676
+ # multi-GPU tests
677
+ if len(torch_compatible_cuda_devices) > 1:
678
+ add_function_test(TestTorch, "test_torch_mgpu_from_torch", test_torch_mgpu_from_torch)
679
+ add_function_test(TestTorch, "test_torch_mgpu_to_torch", test_torch_mgpu_to_torch)
680
+ add_function_test(TestTorch, "test_torch_mgpu_interop", test_torch_mgpu_interop)
681
+
682
+ except Exception as e:
683
+ print(f"Skipping Torch tests due to exception: {e}")
655
684
 
656
685
 
657
686
  if __name__ == "__main__":
658
- c = register(unittest.TestCase)
687
+ wp.build.clear_kernel_cache()
659
688
  unittest.main(verbosity=2)
warp/tests/test_transient_module.py CHANGED
@@ -5,13 +5,13 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- import importlib
9
8
  import os
10
9
  import tempfile
11
10
  import unittest
11
+ from importlib import util
12
12
 
13
13
  import warp as wp
14
- from warp.tests.test_base import *
14
+ from warp.tests.unittest_utils import *
15
15
 
16
16
  CODE = """# -*- coding: utf-8 -*-
17
17
 
@@ -45,8 +45,8 @@ def load_code_as_module(code, name):
45
45
  with os.fdopen(file, "w") as f:
46
46
  f.write(code)
47
47
 
48
- spec = importlib.util.spec_from_file_location(name, file_path)
49
- module = importlib.util.module_from_spec(spec)
48
+ spec = util.spec_from_file_location(name, file_path)
49
+ module = util.module_from_spec(spec)
50
50
  spec.loader.exec_module(module)
51
51
  finally:
52
52
  os.remove(file_path)
@@ -63,26 +63,25 @@ def test_transient_module(test, device):
63
63
  assert len(module.compute.module.functions) == 1
64
64
 
65
65
  data = module.Data()
66
- data.x = wp.array(123, dtype=int)
66
+ data.x = wp.array([123], dtype=int, device=device)
67
67
 
68
68
  wp.set_module_options({"foo": "bar"}, module=module)
69
69
  assert wp.get_module_options(module=module).get("foo") == "bar"
70
70
  assert module.compute.module.options.get("foo") == "bar"
71
71
 
72
- wp.launch(module.compute, dim=1, inputs=[data])
72
+ wp.launch(module.compute, dim=1, inputs=[data], device=device)
73
73
  assert_np_equal(data.x.numpy(), np.array([124]))
74
74
 
75
75
 
76
- def register(parent):
77
- devices = get_test_devices()
76
+ devices = get_test_devices()
78
77
 
79
- class TestTransientModule(parent):
80
- pass
81
78
 
82
- add_function_test(TestTransientModule, "test_transient_module", test_transient_module, devices=devices)
83
- return TestTransientModule
79
+ class TestTransientModule(unittest.TestCase):
80
+ pass
84
81
 
85
82
 
83
+ add_function_test(TestTransientModule, "test_transient_module", test_transient_module, devices=devices)
84
+
86
85
  if __name__ == "__main__":
87
- _ = register(unittest.TestCase)
86
+ wp.build.clear_kernel_cache()
88
87
  unittest.main(verbosity=2)