warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's registry page for more details.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/tests/test_tape.py CHANGED
@@ -5,9 +5,12 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
+ import unittest
9
+
8
10
  import numpy as np
11
+
9
12
  import warp as wp
10
- from warp.tests.test_base import *
13
+ from warp.tests.unittest_utils import *
11
14
 
12
15
  wp.init()
13
16
 
@@ -19,11 +22,17 @@ def mul_constant(x: wp.array(dtype=float), y: wp.array(dtype=float)):
19
22
  y[tid] = x[tid] * 2.0
20
23
 
21
24
 
25
+ @wp.struct
26
+ class Multiplicands:
27
+ x: wp.array(dtype=float)
28
+ y: wp.array(dtype=float)
29
+
30
+
22
31
  @wp.kernel
23
- def mul_variable(x: wp.array(dtype=float), y: wp.array(dtype=float), z: wp.array(dtype=float)):
32
+ def mul_variable(mutiplicands: Multiplicands, z: wp.array(dtype=float)):
24
33
  tid = wp.tid()
25
34
 
26
- z[tid] = x[tid] * y[tid]
35
+ z[tid] = mutiplicands.x[tid] * mutiplicands.y[tid]
27
36
 
28
37
 
29
38
  @wp.kernel
@@ -65,12 +74,13 @@ def test_tape_mul_variable(test, device):
65
74
 
66
75
  # record onto tape
67
76
  with tape:
68
- # input data
69
- x = wp.array(np.ones(dim) * 16.0, dtype=wp.float32, device=device, requires_grad=True)
70
- y = wp.array(np.ones(dim) * 32.0, dtype=wp.float32, device=device, requires_grad=True)
71
- z = wp.zeros_like(x)
77
+ # input data (Note: We're intentionally testing structs in tapes here)
78
+ multiplicands = Multiplicands()
79
+ multiplicands.x = wp.array(np.ones(dim) * 16.0, dtype=wp.float32, device=device, requires_grad=True)
80
+ multiplicands.y = wp.array(np.ones(dim) * 32.0, dtype=wp.float32, device=device, requires_grad=True)
81
+ z = wp.zeros_like(multiplicands.x)
72
82
 
73
- wp.launch(kernel=mul_variable, dim=dim, inputs=[x, y], outputs=[z], device=device)
83
+ wp.launch(kernel=mul_variable, dim=dim, inputs=[multiplicands], outputs=[z], device=device)
74
84
 
75
85
  # loss = wp.sum(x)
76
86
  z.grad = wp.array(np.ones(dim), device=device, dtype=wp.float32)
@@ -79,16 +89,21 @@ def test_tape_mul_variable(test, device):
79
89
  tape.backward()
80
90
 
81
91
  # grad_x=y, grad_y=x
82
- assert_np_equal(tape.gradients[x].numpy(), y.numpy())
83
- assert_np_equal(tape.gradients[y].numpy(), x.numpy())
92
+ assert_np_equal(tape.gradients[multiplicands].x.numpy(), multiplicands.y.numpy())
93
+ assert_np_equal(tape.gradients[multiplicands].y.numpy(), multiplicands.x.numpy())
84
94
 
85
95
  # run backward again with different incoming gradient
86
96
  # should accumulate the same gradients again onto output
87
97
  # so gradients = 2.0*prev
88
98
  tape.backward()
89
99
 
90
- assert_np_equal(tape.gradients[x].numpy(), y.numpy() * 2.0)
91
- assert_np_equal(tape.gradients[y].numpy(), x.numpy() * 2.0)
100
+ assert_np_equal(tape.gradients[multiplicands].x.numpy(), multiplicands.y.numpy() * 2.0)
101
+ assert_np_equal(tape.gradients[multiplicands].y.numpy(), multiplicands.x.numpy() * 2.0)
102
+
103
+ # Clear launches and zero out the gradients
104
+ tape.reset()
105
+ assert_np_equal(tape.gradients[multiplicands].x.numpy(), np.zeros_like(tape.gradients[multiplicands].x.numpy()))
106
+ test.assertFalse(tape.launches)
92
107
 
93
108
 
94
109
  def test_tape_dot_product(test, device):
@@ -112,19 +127,22 @@ def test_tape_dot_product(test, device):
112
127
  assert_np_equal(tape.gradients[y].numpy(), x.numpy())
113
128
 
114
129
 
115
- def register(parent):
116
- devices = get_test_devices()
130
+ devices = get_test_devices()
131
+
117
132
 
118
- class TestTape(parent):
119
- pass
133
+ class TestTape(unittest.TestCase):
134
+ def test_tape_no_nested_tapes(self):
135
+ with self.assertRaises(RuntimeError):
136
+ with wp.Tape():
137
+ with wp.Tape():
138
+ pass
120
139
 
121
- add_function_test(TestTape, "test_tape_mul_constant", test_tape_mul_constant, devices=devices)
122
- add_function_test(TestTape, "test_tape_mul_variable", test_tape_mul_variable, devices=devices)
123
- add_function_test(TestTape, "test_tape_dot_product", test_tape_dot_product, devices=devices)
124
140
 
125
- return TestTape
141
+ add_function_test(TestTape, "test_tape_mul_constant", test_tape_mul_constant, devices=devices)
142
+ add_function_test(TestTape, "test_tape_mul_variable", test_tape_mul_variable, devices=devices)
143
+ add_function_test(TestTape, "test_tape_dot_product", test_tape_dot_product, devices=devices)
126
144
 
127
145
 
128
146
  if __name__ == "__main__":
129
- c = register(unittest.TestCase)
147
+ wp.build.clear_kernel_cache()
130
148
  unittest.main(verbosity=2)
warp/tests/test_torch.py CHANGED
@@ -5,13 +5,12 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- # include parent path
9
- import numpy as np
10
8
  import unittest
11
- import sys
9
+
10
+ import numpy as np
12
11
 
13
12
  import warp as wp
14
- from warp.tests.test_base import *
13
+ from warp.tests.unittest_utils import *
15
14
 
16
15
  wp.init()
17
16
 
@@ -103,7 +102,7 @@ def test_from_torch(test, device):
103
102
  wrap_scalar_tensor_implicit(torch.int16, wp.int16)
104
103
  wrap_scalar_tensor_implicit(torch.int8, wp.int8)
105
104
  wrap_scalar_tensor_implicit(torch.uint8, wp.uint8)
106
- wrap_scalar_tensor_implicit(torch.bool, wp.uint8)
105
+ wrap_scalar_tensor_implicit(torch.bool, wp.bool)
107
106
 
108
107
  # explicitly specify warp dtype
109
108
  def wrap_scalar_tensor_explicit(torch_dtype, expected_warp_dtype):
@@ -127,6 +126,7 @@ def test_from_torch(test, device):
127
126
  wrap_scalar_tensor_explicit(torch.uint8, wp.int8)
128
127
  wrap_scalar_tensor_explicit(torch.bool, wp.uint8)
129
128
  wrap_scalar_tensor_explicit(torch.bool, wp.int8)
129
+ wrap_scalar_tensor_explicit(torch.bool, wp.bool)
130
130
 
131
131
  def wrap_vec_tensor(n, desired_warp_dtype):
132
132
  t = torch.zeros((10, n), dtype=torch.float32, device=torch_device)
@@ -151,6 +151,29 @@ def test_from_torch(test, device):
151
151
  wrap_mat_tensor(4, 4, wp.mat44)
152
152
  wrap_mat_tensor(6, 6, wp.spatial_matrix)
153
153
 
154
+ def wrap_vec_tensor_with_grad(n, desired_warp_dtype):
155
+ t = torch.zeros((10, n), dtype=torch.float32, device=torch_device)
156
+ a = wp.from_torch(t, desired_warp_dtype, requires_grad=True)
157
+ assert a.dtype == desired_warp_dtype
158
+ assert a.shape == (10,)
159
+
160
+ wrap_vec_tensor_with_grad(2, wp.vec2)
161
+ wrap_vec_tensor_with_grad(3, wp.vec3)
162
+ wrap_vec_tensor_with_grad(4, wp.vec4)
163
+ wrap_vec_tensor_with_grad(6, wp.spatial_vector)
164
+ wrap_vec_tensor_with_grad(7, wp.transform)
165
+
166
+ def wrap_mat_tensor_with_grad(n, m, desired_warp_dtype):
167
+ t = torch.zeros((10, n, m), dtype=torch.float32, device=torch_device)
168
+ a = wp.from_torch(t, desired_warp_dtype, requires_grad=True)
169
+ assert a.dtype == desired_warp_dtype
170
+ assert a.shape == (10,)
171
+
172
+ wrap_mat_tensor_with_grad(2, 2, wp.mat22)
173
+ wrap_mat_tensor_with_grad(3, 3, wp.mat33)
174
+ wrap_mat_tensor_with_grad(4, 4, wp.mat44)
175
+ wrap_mat_tensor_with_grad(6, 6, wp.spatial_matrix)
176
+
154
177
 
155
178
  def test_to_torch(test, device):
156
179
  import torch
@@ -169,6 +192,7 @@ def test_to_torch(test, device):
169
192
  wrap_scalar_array(wp.int16, torch.int16)
170
193
  wrap_scalar_array(wp.int8, torch.int8)
171
194
  wrap_scalar_array(wp.uint8, torch.uint8)
195
+ wrap_scalar_array(wp.bool, torch.bool)
172
196
 
173
197
  # not supported by torch
174
198
  # wrap_scalar_array(wp.uint64, torch.int64)
@@ -445,6 +469,8 @@ def test_torch_autograd(test, device):
445
469
  def test_torch_graph_torch_stream(test, device):
446
470
  """Capture Torch graph on Torch stream"""
447
471
 
472
+ wp.load_module(device=device)
473
+
448
474
  import torch
449
475
 
450
476
  torch_device = wp.device_to_torch(device)
@@ -526,12 +552,14 @@ def test_warp_graph_warp_stream(test, device):
526
552
 
527
553
  # capture graph
528
554
  with wp.ScopedDevice(device), torch.cuda.stream(torch_stream):
529
- wp.capture_begin()
530
- t += 1.0
531
- wp.launch(inc, dim=n, inputs=[a])
532
- t += 1.0
533
- wp.launch(inc, dim=n, inputs=[a])
534
- g = wp.capture_end()
555
+ wp.capture_begin(force_module_load=False)
556
+ try:
557
+ t += 1.0
558
+ wp.launch(inc, dim=n, inputs=[a])
559
+ t += 1.0
560
+ wp.launch(inc, dim=n, inputs=[a])
561
+ finally:
562
+ g = wp.capture_end()
535
563
 
536
564
  # replay graph
537
565
  num_iters = 10
@@ -545,6 +573,8 @@ def test_warp_graph_warp_stream(test, device):
545
573
  def test_warp_graph_torch_stream(test, device):
546
574
  """Capture Warp graph on Torch stream"""
547
575
 
576
+ wp.load_module(device=device)
577
+
548
578
  import torch
549
579
 
550
580
  torch_device = wp.device_to_torch(device)
@@ -562,12 +592,14 @@ def test_warp_graph_torch_stream(test, device):
562
592
 
563
593
  # capture graph
564
594
  with wp.ScopedStream(warp_stream), torch.cuda.stream(torch_stream):
565
- wp.capture_begin()
566
- t += 1.0
567
- wp.launch(inc, dim=n, inputs=[a])
568
- t += 1.0
569
- wp.launch(inc, dim=n, inputs=[a])
570
- g = wp.capture_end()
595
+ wp.capture_begin(force_module_load=False)
596
+ try:
597
+ t += 1.0
598
+ wp.launch(inc, dim=n, inputs=[a])
599
+ t += 1.0
600
+ wp.launch(inc, dim=n, inputs=[a])
601
+ finally:
602
+ g = wp.capture_end()
571
603
 
572
604
  # replay graph
573
605
  num_iters = 10
@@ -578,82 +610,79 @@ def test_warp_graph_torch_stream(test, device):
578
610
  assert passed.item()
579
611
 
580
612
 
581
- def register(parent):
582
- class TestTorch(parent):
583
- pass
584
-
585
- try:
586
- import torch
587
-
588
- # check which Warp devices work with Torch
589
- # CUDA devices may fail if Torch was not compiled with CUDA support
590
- test_devices = get_test_devices()
591
- torch_compatible_devices = []
592
- torch_compatible_cuda_devices = []
593
-
594
- for d in test_devices:
595
- try:
596
- t = torch.arange(10, device=wp.device_to_torch(d))
597
- t += 1
598
- torch_compatible_devices.append(d)
599
- if d.is_cuda:
600
- torch_compatible_cuda_devices.append(d)
601
- except Exception as e:
602
- print(f"Skipping Torch tests on device '{d}' due to exception: {e}")
603
-
604
- if torch_compatible_devices:
605
- add_function_test(TestTorch, "test_from_torch", test_from_torch, devices=torch_compatible_devices)
606
- add_function_test(
607
- TestTorch, "test_from_torch_slices", test_from_torch_slices, devices=torch_compatible_devices
608
- )
609
- add_function_test(
610
- TestTorch,
611
- "test_from_torch_zero_strides",
612
- test_from_torch_zero_strides,
613
- devices=torch_compatible_devices,
614
- )
615
- add_function_test(TestTorch, "test_to_torch", test_to_torch, devices=torch_compatible_devices)
616
- add_function_test(TestTorch, "test_torch_zerocopy", test_torch_zerocopy, devices=torch_compatible_devices)
617
- add_function_test(TestTorch, "test_torch_autograd", test_torch_autograd, devices=torch_compatible_devices)
618
-
619
- if torch_compatible_cuda_devices:
620
- add_function_test(
621
- TestTorch,
622
- "test_torch_graph_torch_stream",
623
- test_torch_graph_torch_stream,
624
- devices=torch_compatible_cuda_devices,
625
- )
626
- add_function_test(
627
- TestTorch,
628
- "test_torch_graph_warp_stream",
629
- test_torch_graph_warp_stream,
630
- devices=torch_compatible_cuda_devices,
631
- )
632
- add_function_test(
633
- TestTorch,
634
- "test_warp_graph_warp_stream",
635
- test_warp_graph_warp_stream,
636
- devices=torch_compatible_cuda_devices,
637
- )
638
- add_function_test(
639
- TestTorch,
640
- "test_warp_graph_torch_stream",
641
- test_warp_graph_torch_stream,
642
- devices=torch_compatible_cuda_devices,
643
- )
613
+ class TestTorch(unittest.TestCase):
614
+ pass
615
+
644
616
 
645
- # multi-GPU tests
646
- if len(torch_compatible_cuda_devices) > 1:
647
- add_function_test(TestTorch, "test_torch_mgpu_from_torch", test_torch_mgpu_from_torch)
648
- add_function_test(TestTorch, "test_torch_mgpu_to_torch", test_torch_mgpu_to_torch)
649
- add_function_test(TestTorch, "test_torch_mgpu_interop", test_torch_mgpu_interop)
617
+ test_devices = get_test_devices()
650
618
 
651
- except Exception as e:
652
- print(f"Skipping Torch tests due to exception: {e}")
619
+ try:
620
+ import torch
653
621
 
654
- return TestTorch
622
+ # check which Warp devices work with Torch
623
+ # CUDA devices may fail if Torch was not compiled with CUDA support
624
+ torch_compatible_devices = []
625
+ torch_compatible_cuda_devices = []
626
+
627
+ for d in test_devices:
628
+ try:
629
+ t = torch.arange(10, device=wp.device_to_torch(d))
630
+ t += 1
631
+ torch_compatible_devices.append(d)
632
+ if d.is_cuda:
633
+ torch_compatible_cuda_devices.append(d)
634
+ except Exception as e:
635
+ print(f"Skipping Torch tests on device '{d}' due to exception: {e}")
636
+
637
+ if torch_compatible_devices:
638
+ add_function_test(TestTorch, "test_from_torch", test_from_torch, devices=torch_compatible_devices)
639
+ add_function_test(TestTorch, "test_from_torch_slices", test_from_torch_slices, devices=torch_compatible_devices)
640
+ add_function_test(
641
+ TestTorch,
642
+ "test_from_torch_zero_strides",
643
+ test_from_torch_zero_strides,
644
+ devices=torch_compatible_devices,
645
+ )
646
+ add_function_test(TestTorch, "test_to_torch", test_to_torch, devices=torch_compatible_devices)
647
+ add_function_test(TestTorch, "test_torch_zerocopy", test_torch_zerocopy, devices=torch_compatible_devices)
648
+ add_function_test(TestTorch, "test_torch_autograd", test_torch_autograd, devices=torch_compatible_devices)
649
+
650
+ if torch_compatible_cuda_devices:
651
+ add_function_test(
652
+ TestTorch,
653
+ "test_torch_graph_torch_stream",
654
+ test_torch_graph_torch_stream,
655
+ devices=torch_compatible_cuda_devices,
656
+ )
657
+ add_function_test(
658
+ TestTorch,
659
+ "test_torch_graph_warp_stream",
660
+ test_torch_graph_warp_stream,
661
+ devices=torch_compatible_cuda_devices,
662
+ )
663
+ add_function_test(
664
+ TestTorch,
665
+ "test_warp_graph_warp_stream",
666
+ test_warp_graph_warp_stream,
667
+ devices=torch_compatible_cuda_devices,
668
+ )
669
+ add_function_test(
670
+ TestTorch,
671
+ "test_warp_graph_torch_stream",
672
+ test_warp_graph_torch_stream,
673
+ devices=torch_compatible_cuda_devices,
674
+ )
675
+
676
+ # multi-GPU tests
677
+ if len(torch_compatible_cuda_devices) > 1:
678
+ add_function_test(TestTorch, "test_torch_mgpu_from_torch", test_torch_mgpu_from_torch)
679
+ add_function_test(TestTorch, "test_torch_mgpu_to_torch", test_torch_mgpu_to_torch)
680
+ add_function_test(TestTorch, "test_torch_mgpu_interop", test_torch_mgpu_interop)
681
+
682
+ except Exception as e:
683
+ print(f"Skipping Torch tests due to exception: {e}")
655
684
 
656
685
 
657
686
  if __name__ == "__main__":
658
- c = register(unittest.TestCase)
687
+ wp.build.clear_kernel_cache()
659
688
  unittest.main(verbosity=2)
@@ -5,13 +5,13 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- import importlib
9
8
  import os
10
9
  import tempfile
11
10
  import unittest
11
+ from importlib import util
12
12
 
13
13
  import warp as wp
14
- from warp.tests.test_base import *
14
+ from warp.tests.unittest_utils import *
15
15
 
16
16
  CODE = """# -*- coding: utf-8 -*-
17
17
 
@@ -45,8 +45,8 @@ def load_code_as_module(code, name):
45
45
  with os.fdopen(file, "w") as f:
46
46
  f.write(code)
47
47
 
48
- spec = importlib.util.spec_from_file_location(name, file_path)
49
- module = importlib.util.module_from_spec(spec)
48
+ spec = util.spec_from_file_location(name, file_path)
49
+ module = util.module_from_spec(spec)
50
50
  spec.loader.exec_module(module)
51
51
  finally:
52
52
  os.remove(file_path)
@@ -63,26 +63,25 @@ def test_transient_module(test, device):
63
63
  assert len(module.compute.module.functions) == 1
64
64
 
65
65
  data = module.Data()
66
- data.x = wp.array([123], dtype=int)
66
+ data.x = wp.array([123], dtype=int, device=device)
67
67
 
68
68
  wp.set_module_options({"foo": "bar"}, module=module)
69
69
  assert wp.get_module_options(module=module).get("foo") == "bar"
70
70
  assert module.compute.module.options.get("foo") == "bar"
71
71
 
72
- wp.launch(module.compute, dim=1, inputs=[data])
72
+ wp.launch(module.compute, dim=1, inputs=[data], device=device)
73
73
  assert_np_equal(data.x.numpy(), np.array([124]))
74
74
 
75
75
 
76
- def register(parent):
77
- devices = get_test_devices()
76
+ devices = get_test_devices()
78
77
 
79
- class TestTransientModule(parent):
80
- pass
81
78
 
82
- add_function_test(TestTransientModule, "test_transient_module", test_transient_module, devices=devices)
83
- return TestTransientModule
79
+ class TestTransientModule(unittest.TestCase):
80
+ pass
84
81
 
85
82
 
83
+ add_function_test(TestTransientModule, "test_transient_module", test_transient_module, devices=devices)
84
+
86
85
  if __name__ == "__main__":
87
- _ = register(unittest.TestCase)
86
+ wp.build.clear_kernel_cache()
88
87
  unittest.main(verbosity=2)