warp-lang 0.9.0-py3-none-win_amd64.whl → 0.11.0-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -5,37 +5,27 @@
  # distribution of this software and related documentation without an express
  # license agreement from NVIDIA CORPORATION is strictly prohibited.

- import math
- import os
- import sys
- import hashlib
+ import ast
  import ctypes
+ import gc
+ import hashlib
+ import inspect
+ import io
+ import os
  import platform
- import ast
+ import sys
  import types
- import inspect
-
- from typing import Tuple
- from typing import List
- from typing import Dict
- from typing import Any
- from typing import Callable
- from typing import Union
- from typing import Mapping
- from typing import Optional
-
+ from copy import copy as shallowcopy
  from types import ModuleType
+ from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

- from copy import copy as shallowcopy
+ import numpy as np

  import warp
- import warp.utils
- import warp.codegen
  import warp.build
+ import warp.codegen
  import warp.config

- import numpy as np
-
  # represents either a built-in or user-defined function


@@ -46,6 +36,18 @@ def create_value_func(type):
      return value_func


+ def get_function_args(func):
+     """Ensures that all function arguments are annotated and returns a dictionary mapping from argument name to its type."""
+     import inspect
+
+     argspec = inspect.getfullargspec(func)
+
+     # use source-level argument annotations
+     if len(argspec.annotations) < len(argspec.args):
+         raise RuntimeError(f"Incomplete argument annotations on function {func.__qualname__}")
+     return argspec.annotations
+
+
  class Function:
      def __init__(
          self,
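
A minimal usage sketch for the get_function_args() helper introduced in the hunk above (the annotated function scale below is hypothetical and not part of the diff; the helper is assumed to be importable from warp.context):

    import warp as wp
    from warp.context import get_function_args

    def scale(x: wp.vec3, s: float) -> wp.vec3:
        ...

    # Maps every annotated name (including "return", if annotated) to its type;
    # raises RuntimeError if any positional argument lacks an annotation.
    print(get_function_args(scale))
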
@@ -67,6 +69,17 @@ class Function:
          generic=False,
          native_func=None,
          defaults=None,
+         custom_replay_func=None,
+         native_snippet=None,
+         adj_native_snippet=None,
+         skip_forward_codegen=False,
+         skip_reverse_codegen=False,
+         custom_reverse_num_input_args=-1,
+         custom_reverse_mode=False,
+         overloaded_annotations=None,
+         code_transformers=[],
+         skip_adding_overload=False,
+         require_original_output_arg=False,
      ):
          self.func = func # points to Python function decorated with @wp.func, may be None for builtins
          self.key = key
@@ -80,6 +93,12 @@ class Function:
          self.module = module
          self.variadic = variadic # function can take arbitrary number of inputs, e.g.: printf()
          self.defaults = defaults
+         # Function instance for a custom implementation of the replay pass
+         self.custom_replay_func = custom_replay_func
+         self.native_snippet = native_snippet
+         self.adj_native_snippet = adj_native_snippet
+         self.custom_grad_func = None
+         self.require_original_output_arg = require_original_output_arg

          if initializer_list_func is None:
              self.initializer_list_func = lambda x, y: False
@@ -108,7 +127,16 @@ class Function:
              self.user_overloads = {}

              # user defined (Python) function
-             self.adj = warp.codegen.Adjoint(func)
+             self.adj = warp.codegen.Adjoint(
+                 func,
+                 is_user_function=True,
+                 skip_forward_codegen=skip_forward_codegen,
+                 skip_reverse_codegen=skip_reverse_codegen,
+                 custom_reverse_num_input_args=custom_reverse_num_input_args,
+                 custom_reverse_mode=custom_reverse_mode,
+                 overload_annotations=overloaded_annotations,
+                 transformers=code_transformers,
+             )

              # record input types
              for name, type in self.adj.arg_types.items():
@@ -136,11 +164,12 @@ class Function:
              else:
                  self.mangled_name = None

-             self.add_overload(self)
+             if not skip_adding_overload:
+                 self.add_overload(self)

          # add to current module
          if module:
-             module.register_function(self)
+             module.register_function(self, skip_adding_overload)

      def __call__(self, *args, **kwargs):
          # handles calling a builtin (native) function
@@ -149,124 +178,52 @@ class Function:
          # from within a kernel (experimental).

          if self.is_builtin() and self.mangled_name:
-             # store last error during overload resolution
-             error = None
-
-             for f in self.overloads:
-                 if f.generic:
+             # For each of this function's existing overloads, we attempt to pack
+             # the given arguments into the C types expected by the corresponding
+             # parameters, and we rinse and repeat until we get a match.
+             for overload in self.overloads:
+                 if overload.generic:
                      continue

-                 # try and find builtin in the warp.dll
-                 if hasattr(warp.context.runtime.core, f.mangled_name) == False:
-                     raise RuntimeError(
-                         f"Couldn't find function {self.key} with mangled name {f.mangled_name} in the Warp native library"
-                     )
-
-                 try:
-                     # try and pack args into what the function expects
-                     params = []
-                     for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
-                         a = args[i]
-
-                         # try to convert to a value type (vec3, mat33, etc)
-                         if issubclass(arg_type, ctypes.Array):
-                             # wrap the arg_type (which is an ctypes.Array) in a structure
-                             # to ensure parameter is passed to the .dll by value rather than reference
-                             class ValueArg(ctypes.Structure):
-                                 _fields_ = [("value", arg_type)]
-
-                             x = ValueArg()
-
-                             # force conversion to ndarray first (handles tuple / list, Gf.Vec3 case)
-                             if isinstance(a, ctypes.Array) == False:
-                                 # assume you want the float32 version of the function so it doesn't just
-                                 # grab an override for a random data type:
-                                 if arg_type._type_ != ctypes.c_float:
-                                     raise RuntimeError(
-                                         f"Error calling function '{f.key}', parameter for argument '{arg_name}' does not have c_float type."
-                                     )
-
-                                 a = np.array(a)
-
-                                 # flatten to 1D array
-                                 v = a.flatten()
-                                 if len(v) != arg_type._length_:
-                                     raise RuntimeError(
-                                         f"Error calling function '{f.key}', parameter for argument '{arg_name}' has length {len(v)}, but expected {arg_type._length_}. Could not convert parameter to {arg_type}."
-                                     )
-
-                                 for i in range(arg_type._length_):
-                                     x.value[i] = v[i]
-
-                             else:
-                                 # already a built-in type, check it matches
-                                 if not warp.types.types_equal(type(a), arg_type):
-                                     raise RuntimeError(
-                                         f"Error calling function '{f.key}', parameter for argument '{arg_name}' has type '{type(a)}' but expected '{arg_type}'"
-                                     )
-
-                                 x.value = a
-
-                             params.append(x)
-
-                         else:
-                             try:
-                                 # try to pack as a scalar type
-                                 params.append(arg_type._type_(a))
-                             except:
-                                 raise RuntimeError(
-                                     f"Error calling function {f.key}, unable to pack function parameter type {type(a)} for param {arg_name}, expected {arg_type}"
-                                 )
-
-                     # returns the corresponding ctype for a scalar or vector warp type
-                     def type_ctype(dtype):
-                         if dtype == float:
-                             return ctypes.c_float
-                         elif dtype == int:
-                             return ctypes.c_int32
-                         elif issubclass(dtype, ctypes.Array):
-                             return dtype
-                         elif issubclass(dtype, ctypes.Structure):
-                             return dtype
-                         else:
-                             # scalar type
-                             return dtype._type_
-
-                     value_type = type_ctype(f.value_func(None, None, None))
-
-                     # construct return value (passed by address)
-                     ret = value_type()
-                     ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
+                 success, return_value = call_builtin(overload, *args)
+                 if success:
+                     return return_value

-                     params.append(ret_addr)
+             # overload resolution or call failed
+             raise RuntimeError(
+                 f"Couldn't find a function '{self.key}' compatible with "
+                 f"the arguments '{', '.join(type(x).__name__ for x in args)}'"
+             )

-                     c_func = getattr(warp.context.runtime.core, f.mangled_name)
-                     c_func(*params)
+         if hasattr(self, "user_overloads") and len(self.user_overloads):
+             # user-defined function with overloads

-                     if issubclass(value_type, ctypes.Array) or issubclass(value_type, ctypes.Structure):
-                         # return vector types as ctypes
-                         return ret
-                     else:
-                         # return scalar types as int/float
-                         return ret.value
+             if len(kwargs):
+                 raise RuntimeError(
+                     f"Error calling function '{self.key}', keyword arguments are not supported for user-defined overloads."
+                 )

-                 except Exception as e:
-                     # couldn't pack values to match this overload
-                     # store error and move onto the next one
-                     error = e
+             # try and find a matching overload
+             for overload in self.user_overloads.values():
+                 if len(overload.input_types) != len(args):
+                     continue
+                 template_types = list(overload.input_types.values())
+                 arg_names = list(overload.input_types.keys())
+                 try:
+                     # attempt to unify argument types with function template types
+                     warp.types.infer_argument_types(args, template_types, arg_names)
+                     return overload.func(*args)
+                 except Exception:
                      continue

-             # overload resolution or call failed
-             # raise the last exception encountered
-             if error:
-                 raise error
-             else:
-                 raise RuntimeError(f"Error calling function '{f.key}'.")
+             raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")

-         else:
-             raise RuntimeError(
-                 f"Error, functions decorated with @wp.func can only be called from within Warp kernels (trying to call {self.key}())"
-             )
+         # user-defined function with no overloads
+         if self.func is None:
+             raise RuntimeError(f"Error calling function '{self.key}', function is undefined")
+
+         # this function has no overloads, call it like a plain Python function
+         return self.func(*args, **kwargs)

      def is_builtin(self):
          return self.func is None
@@ -286,7 +243,7 @@ class Function:
              # todo: construct a default value for each of the functions args
              # so we can generate the return type for overloaded functions
              return_type = type_str(self.value_func(None, None, None))
-         except:
+         except Exception:
              return False

          if return_type.startswith("Tuple"):
@@ -379,10 +336,187 @@ class Function:
          return None

      def __repr__(self):
-         inputs_str = ", ".join([f"{k}: {v.__name__}" for k, v in self.input_types.items()])
+         inputs_str = ", ".join([f"{k}: {warp.types.type_repr(v)}" for k, v in self.input_types.items()])
          return f"<Function {self.key}({inputs_str})>"


+ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
+     uses_non_warp_array_type = False
+
+     # Retrieve the built-in function from Warp's dll.
+     c_func = getattr(warp.context.runtime.core, func.mangled_name)
+
+     # Try gathering the parameters that the function expects and pack them
+     # into their corresponding C types.
+     c_params = []
+     for i, (_, arg_type) in enumerate(func.input_types.items()):
+         param = params[i]
+
+         try:
+             iter(param)
+         except TypeError:
+             is_array = False
+         else:
+             is_array = True
+
+         if is_array:
+             if not issubclass(arg_type, ctypes.Array):
+                 return (False, None)
+
+             # The argument expects a built-in Warp type like a vector or a matrix.
+
+             c_param = None
+
+             if isinstance(param, ctypes.Array):
+                 # The given parameter is also a built-in Warp type, so we only need
+                 # to make sure that it matches with the argument.
+                 if not warp.types.types_equal(type(param), arg_type):
+                     return (False, None)
+
+                 if isinstance(param, arg_type):
+                     c_param = param
+                 else:
+                     # Cast the value to its argument type to make sure that it
+                     # can be assigned to the field of the `Param` struct.
+                     # This could error otherwise when, for example, the field type
+                     # is set to `vec3i` while the value is of type `vector(length=3, dtype=int)`,
+                     # even though both types are semantically identical.
+                     c_param = arg_type(param)
+             else:
+                 # Flatten the parameter values into a flat 1-D array.
+                 arr = []
+                 ndim = 1
+                 stack = [(0, param)]
+                 while stack:
+                     depth, elem = stack.pop(0)
+                     try:
+                         # If `elem` is a sequence, then it should be possible
+                         # to add its elements to the stack for later processing.
+                         stack.extend((depth + 1, x) for x in elem)
+                     except TypeError:
+                         # Since `elem` doesn't seem to be a sequence,
+                         # we must have a leaf value that we need to add to our
+                         # resulting array.
+                         arr.append(elem)
+                         ndim = max(depth, ndim)
+
+                 assert ndim > 0
+
+                 # Ensure that if the given parameter value is, say, a 2-D array,
+                 # then we try to resolve it against a matrix argument rather than
+                 # a vector.
+                 if ndim > len(arg_type._shape_):
+                     return (False, None)
+
+                 elem_count = len(arr)
+                 if elem_count != arg_type._length_:
+                     return (False, None)
+
+                 # Retrieve the element type of the sequence while ensuring
+                 # that it's homogeneous.
+                 elem_type = type(arr[0])
+                 for i in range(1, elem_count):
+                     if type(arr[i]) is not elem_type:
+                         raise ValueError("All array elements must share the same type.")
+
+                 expected_elem_type = arg_type._wp_scalar_type_
+                 if not (
+                     elem_type is expected_elem_type
+                     or (elem_type is float and expected_elem_type is warp.types.float32)
+                     or (elem_type is int and expected_elem_type is warp.types.int32)
+                     or (
+                         issubclass(elem_type, np.number)
+                         and warp.types.np_dtype_to_warp_type[np.dtype(elem_type)] is expected_elem_type
+                     )
+                 ):
+                     # The parameter value has a type not matching the type defined
+                     # for the corresponding argument.
+                     return (False, None)
+
+                 if elem_type in warp.types.int_types:
+                     # Pass the value through the expected integer type
+                     # in order to evaluate any integer wrapping.
+                     # For example `uint8(-1)` should result in the value `-255`.
+                     arr = tuple(elem_type._type_(x.value).value for x in arr)
+                 elif elem_type in warp.types.float_types:
+                     # Extract the floating-point values.
+                     arr = tuple(x.value for x in arr)
+
+                 c_param = arg_type()
+                 if warp.types.type_is_matrix(arg_type):
+                     rows, cols = arg_type._shape_
+                     for i in range(rows):
+                         idx_start = i * cols
+                         idx_end = idx_start + cols
+                         c_param[i] = arr[idx_start:idx_end]
+                 else:
+                     c_param[:] = arr
+
+                 uses_non_warp_array_type = True
+
+             c_params.append(ctypes.byref(c_param))
+         else:
+             if issubclass(arg_type, ctypes.Array):
+                 return (False, None)
+
+             if not (
+                 isinstance(param, arg_type)
+                 or (type(param) is float and arg_type is warp.types.float32)
+                 or (type(param) is int and arg_type is warp.types.int32)
+                 or warp.types.np_dtype_to_warp_type.get(getattr(param, "dtype", None)) is arg_type
+             ):
+                 return (False, None)
+
+             if type(param) in warp.types.scalar_types:
+                 param = param.value
+
+             # try to pack as a scalar type
+             if arg_type == warp.types.float16:
+                 c_params.append(arg_type._type_(warp.types.float_to_half_bits(param)))
+             else:
+                 c_params.append(arg_type._type_(param))
+
+     # returns the corresponding ctype for a scalar or vector warp type
+     value_type = func.value_func(None, None, None)
+     if value_type == float:
+         value_ctype = ctypes.c_float
+     elif value_type == int:
+         value_ctype = ctypes.c_int32
+     elif issubclass(value_type, (ctypes.Array, ctypes.Structure)):
+         value_ctype = value_type
+     else:
+         # scalar type
+         value_ctype = value_type._type_
+
+     # construct return value (passed by address)
+     ret = value_ctype()
+     ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
+     c_params.append(ret_addr)
+
+     # Call the built-in function from Warp's dll.
+     c_func(*c_params)
+
+     if uses_non_warp_array_type:
+         warp.utils.warn(
+             "Support for built-in functions called with non-Warp array types, "
+             "such as lists, tuples, NumPy arrays, and others, will be dropped "
+             "in the future. Use a Warp type such as `wp.vec`, `wp.mat`, "
+             "`wp.quat`, or `wp.transform`.",
+             DeprecationWarning,
+             stacklevel=3,
+         )
+
+     if issubclass(value_ctype, ctypes.Array) or issubclass(value_ctype, ctypes.Structure):
+         # return vector types as ctypes
+         return (True, ret)
+
+     if value_type == warp.types.float16:
+         return (True, warp.types.half_bits_to_float(ret.value))
+
+     # return scalar types as int/float
+     return (True, ret.value)
+
+
  class KernelHooks:
      def __init__(self, forward, backward):
          self.forward = forward
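
A short sketch of what the new call_builtin() path above means when exported builtins are called directly from Python (not part of the diff; wp.length and wp.dot are existing exported builtins used here purely for illustration):

    import numpy as np
    import warp as wp

    wp.init()

    # Warp value types resolve an overload directly.
    v = wp.vec3(3.0, 4.0, 0.0)
    print(wp.length(v))  # 5.0

    # Tuples, lists, and NumPy arrays are still flattened and packed, but
    # call_builtin() now emits a DeprecationWarning for these non-Warp types.
    print(wp.length((3.0, 4.0, 0.0)))
    print(wp.dot(np.array([1.0, 0.0, 0.0], dtype=np.float32), wp.vec3(0.0, 1.0, 0.0)))
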
@@ -391,13 +525,23 @@ class KernelHooks:

  # caches source and compiled entry points for a kernel (will be populated after module loads)
  class Kernel:
-     def __init__(self, func, key, module, options=None):
+     def __init__(self, func, key=None, module=None, options=None, code_transformers=[]):
          self.func = func
-         self.module = module
-         self.key = key
+
+         if module is None:
+             self.module = get_module(func.__module__)
+         else:
+             self.module = module
+
+         if key is None:
+             unique_key = self.module.generate_unique_kernel_key(func.__name__)
+             self.key = unique_key
+         else:
+             self.key = key
+
          self.options = {} if options is None else options

-         self.adj = warp.codegen.Adjoint(func)
+         self.adj = warp.codegen.Adjoint(func, transformers=code_transformers)

          # check if generic
          self.is_generic = False
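
Because module and key are now optional, kernels can also be constructed programmatically; a hedged sketch (the @wp.kernel decorator remains the usual spelling, and Kernel is reached here via warp.context):

    import warp as wp

    wp.init()

    def scale_kernel(a: wp.array(dtype=float), s: float):
        i = wp.tid()
        a[i] = a[i] * s

    # The module is looked up from scale_kernel.__module__ and the key is
    # auto-generated (e.g. "scale_kernel_0") via generate_unique_kernel_key().
    k = wp.context.Kernel(func=scale_kernel)

    a = wp.zeros(8, dtype=float)
    wp.launch(k, dim=a.shape[0], inputs=[a, 2.0])
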
@@ -415,8 +559,8 @@ class Kernel:
          # argument indices by name
          self.arg_indices = dict((a.label, i) for i, a in enumerate(self.adj.args))

-         if module:
-             module.register_kernel(self)
+         if self.module:
+             self.module.register_kernel(self)

      def infer_argument_types(self, args):
          template_types = list(self.adj.arg_types.values())
@@ -425,44 +569,8 @@ class Kernel:
              raise RuntimeError(f"Invalid number of arguments for kernel {self.key}")

          arg_names = list(self.adj.arg_types.keys())
-         arg_types = []
-
-         for i in range(len(args)):
-             arg = args[i]
-             arg_type = type(arg)
-             if arg_type in warp.types.array_types:
-                 arg_types.append(arg_type(dtype=arg.dtype, ndim=arg.ndim))
-             elif arg_type in warp.types.scalar_types:
-                 arg_types.append(arg_type)
-             elif arg_type in [int, float]:
-                 # canonicalize type
-                 arg_types.append(warp.types.type_to_warp(arg_type))
-             elif hasattr(arg_type, "_wp_scalar_type_"):
-                 # vector/matrix type
-                 arg_types.append(arg_type)
-             elif issubclass(arg_type, warp.codegen.StructInstance):
-                 # a struct
-                 arg_types.append(arg._struct_)
-             # elif arg_type in [warp.types.launch_bounds_t, warp.types.shape_t, warp.types.range_t]:
-             # arg_types.append(arg_type)
-             # elif arg_type in [warp.hash_grid_query_t, warp.mesh_query_aabb_t, warp.bvh_query_t]:
-             # arg_types.append(arg_type)
-             elif arg is None:
-                 # allow passing None for arrays
-                 t = template_types[i]
-                 if warp.types.is_array(t):
-                     arg_types.append(type(t)(dtype=t.dtype, ndim=t.ndim))
-                 else:
-                     raise TypeError(
-                         f"Unable to infer the type of argument '{arg_names[i]}' for kernel {self.key}, got None"
-                     )
-             else:
-                 # TODO: attempt to figure out if it's a vector/matrix type given as a numpy array, list, etc.
-                 raise TypeError(
-                     f"Unable to infer the type of argument '{arg_names[i]}' for kernel {self.key}, got {arg_type}"
-                 )

-         return arg_types
+         return warp.types.infer_argument_types(args, template_types, arg_names)

      def add_overload(self, arg_types):
          if len(arg_types) != len(self.adj.arg_types):
@@ -529,7 +637,7 @@ def func(f):
      name = warp.codegen.make_full_qualified_name(f)

      m = get_module(f.__module__)
-     func = Function(
+     Function(
          func=f, key=name, namespace="", module=m, value_func=None
      ) # value_type not known yet, will be inferred during Adjoint.build()

@@ -537,6 +645,167 @@ def func(f):
      return m.functions[name]


+ def func_native(snippet, adj_snippet=None):
+     """
+     Decorator to register native code snippet, @func_native
+     """
+
+     def snippet_func(f):
+         name = warp.codegen.make_full_qualified_name(f)
+
+         m = get_module(f.__module__)
+         func = Function(
+             func=f, key=name, namespace="", module=m, native_snippet=snippet, adj_native_snippet=adj_snippet
+         ) # cuda snippets do not have a return value_type
+
+         return m.functions[name]
+
+     return snippet_func
+
+
+ def func_grad(forward_fn):
+     """
+     Decorator to register a custom gradient function for a given forward function.
+     The function signature must correspond to one of the function overloads in the following way:
+     the first part of the input arguments are the original input variables with the same types as their
+     corresponding arguments in the original function, and the second part of the input arguments are the
+     adjoint variables of the output variables (if available) of the original function with the same types as the
+     output variables. The function must not return anything.
+     """
+
+     def wrapper(grad_fn):
+         generic = any(warp.types.type_is_generic(x) for x in forward_fn.input_types.values())
+         if generic:
+             raise RuntimeError(
+                 f"Cannot define custom grad definition for {forward_fn.key} since functions with generic input arguments are not yet supported."
+             )
+
+         reverse_args = {}
+         reverse_args.update(forward_fn.input_types)
+
+         # create temporary Adjoint instance to analyze the function signature
+         adj = warp.codegen.Adjoint(
+             grad_fn, skip_forward_codegen=True, skip_reverse_codegen=False, transformers=forward_fn.adj.transformers
+         )
+
+         from warp.types import types_equal
+
+         grad_args = adj.args
+         grad_sig = warp.types.get_signature([arg.type for arg in grad_args], func_name=forward_fn.key)
+
+         generic = any(warp.types.type_is_generic(x.type) for x in grad_args)
+         if generic:
+             raise RuntimeError(
+                 f"Cannot define custom grad definition for {forward_fn.key} since the provided grad function has generic input arguments."
+             )
+
+         def match_function(f):
+             # check whether the function overload f matches the signature of the provided gradient function
+             if not hasattr(f.adj, "return_var"):
+                 f.adj.build(None)
+             expected_args = list(f.input_types.items())
+             if f.adj.return_var is not None:
+                 expected_args += [(f"adj_ret_{var.label}", var.type) for var in f.adj.return_var]
+             if len(grad_args) != len(expected_args):
+                 return False
+             if any(not types_equal(a.type, exp_type) for a, (_, exp_type) in zip(grad_args, expected_args)):
+                 return False
+             return True
+
+         def add_custom_grad(f: Function):
+             # register custom gradient function
+             f.custom_grad_func = Function(
+                 grad_fn,
+                 key=f.key,
+                 namespace=f.namespace,
+                 input_types=reverse_args,
+                 value_func=None,
+                 module=f.module,
+                 template_func=f.template_func,
+                 skip_forward_codegen=True,
+                 custom_reverse_mode=True,
+                 custom_reverse_num_input_args=len(f.input_types),
+                 skip_adding_overload=False,
+                 code_transformers=f.adj.transformers,
+             )
+             f.adj.skip_reverse_codegen = True
+
+         if hasattr(forward_fn, "user_overloads") and len(forward_fn.user_overloads):
+             # find matching overload for which this grad function is defined
+             for sig, f in forward_fn.user_overloads.items():
+                 if not grad_sig.startswith(sig):
+                     continue
+                 if match_function(f):
+                     add_custom_grad(f)
+                     return
+             raise RuntimeError(
+                 f"No function overload found for gradient function {grad_fn.__qualname__} for function {forward_fn.key}"
+             )
+         else:
+             # resolve return variables
+             forward_fn.adj.build(None)
+
+             expected_args = list(forward_fn.input_types.items())
+             if forward_fn.adj.return_var is not None:
+                 expected_args += [(f"adj_ret_{var.label}", var.type) for var in forward_fn.adj.return_var]
+
+             # check if the signature matches this function
+             if match_function(forward_fn):
+                 add_custom_grad(forward_fn)
+             else:
+                 raise RuntimeError(
+                     f"Gradient function {grad_fn.__qualname__} for function {forward_fn.key} has an incorrect signature. The arguments must match the "
+                     "forward function arguments plus the adjoint variables corresponding to the return variables:"
+                     f"\n{', '.join(map(lambda nt: f'{nt[0]}: {nt[1].__name__}', expected_args))}"
+                 )
+
+     return wrapper
+
+
+ def func_replay(forward_fn):
+     """
+     Decorator to register a custom replay function for a given forward function.
+     The replay function is the function version that is called in the forward phase of the backward pass (replay mode) and corresponds to the forward function by default.
+     The provided function has to match the signature of one of the original forward function overloads.
+     """
+
+     def wrapper(replay_fn):
+         generic = any(warp.types.type_is_generic(x) for x in forward_fn.input_types.values())
+         if generic:
+             raise RuntimeError(
+                 f"Cannot define custom replay definition for {forward_fn.key} since functions with generic input arguments are not yet supported."
+             )
+
+         args = get_function_args(replay_fn)
+         arg_types = list(args.values())
+         generic = any(warp.types.type_is_generic(x) for x in arg_types)
+         if generic:
+             raise RuntimeError(
+                 f"Cannot define custom replay definition for {forward_fn.key} since the provided replay function has generic input arguments."
+             )
+
+         f = forward_fn.get_overload(arg_types)
+         if f is None:
+             inputs_str = ", ".join([f"{k}: {v.__name__}" for k, v in args.items()])
+             raise RuntimeError(
+                 f"Could not find forward definition of function {forward_fn.key} that matches custom replay definition with arguments:\n{inputs_str}"
+             )
+         f.custom_replay_func = Function(
+             replay_fn,
+             key=f"replay_{f.key}",
+             namespace=f.namespace,
+             input_types=f.input_types,
+             value_func=f.value_func,
+             module=f.module,
+             template_func=f.template_func,
+             skip_reverse_codegen=True,
+             skip_adding_overload=True,
+             code_transformers=f.adj.transformers,
+         )
+
+     return wrapper
+
+
  # decorator to register kernel, @kernel, custom_name may be a string
  # that creates a kernel with a different name from the actual function
  def kernel(f=None, *, enable_backward=None):
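
A usage sketch for the @func_grad decorator added above (not part of the diff; it assumes the decorators are exposed as wp.func_grad and friends, matching the warp/__init__.py and warp/stubs.py changes listed in this release). The gradient function receives the forward inputs followed by the adjoint of the return value and must not return anything:

    import warp as wp

    @wp.func
    def safe_sqrt(x: float):
        return wp.sqrt(x)

    # forward inputs first, then the adjoint of the return value
    @wp.func_grad(safe_sqrt)
    def adj_safe_sqrt(x: float, adj_ret: float):
        if x > 0.0:
            wp.adjoint[x] += adj_ret / (2.0 * wp.sqrt(x))

    @wp.kernel
    def sqrt_kernel(xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
        i = wp.tid()
        ys[i] = safe_sqrt(xs[i])
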
@@ -664,6 +933,7 @@ def add_builtin(
      missing_grad=False,
      native_func=None,
      defaults=None,
+     require_original_output_arg=False,
  ):
      # wrap simple single-type functions with a value_func()
      if value_func is None:
@@ -676,7 +946,7 @@ def add_builtin(
          def initializer_list_func(args, templates):
              return False

-     if defaults == None:
+     if defaults is None:
          defaults = {}

      # Add specialized versions of this builtin if it's generic by matching arguments against
@@ -757,8 +1027,8 @@ def add_builtin(
          # on the generated argument list and skip generation if it fails.
          # This also gives us the return type, which we keep for later:
          try:
-             return_type = value_func([warp.codegen.Var("", t) for t in argtypes], {}, [])
-         except Exception as e:
+             return_type = value_func(argtypes, {}, [])
+         except Exception:
              continue

          # The return_type might just be vector_t(length=3,dtype=wp.float32), so we've got to match that
@@ -788,6 +1058,7 @@ def add_builtin(
              hidden=True,
              skip_replay=skip_replay,
              missing_grad=missing_grad,
+             require_original_output_arg=require_original_output_arg,
          )

          func = Function(
@@ -808,6 +1079,7 @@ def add_builtin(
          generic=generic,
          native_func=native_func,
          defaults=defaults,
+         require_original_output_arg=require_original_output_arg,
      )

      if key in builtin_functions:
@@ -817,7 +1089,7 @@ def add_builtin(

      # export means the function will be added to the `warp` module namespace
      # so that users can call it directly from the Python interpreter
-     if export == True:
+     if export:
          if hasattr(warp, key):
              # check that we haven't already created something at this location
              # if it's just an overload stub for auto-complete then overwrite it
@@ -884,6 +1156,8 @@ class ModuleBuilder:
          for func in module.functions.values():
              for f in func.user_overloads.values():
                  self.build_function(f)
+                 if f.custom_replay_func is not None:
+                     self.build_function(f.custom_replay_func)

          # build all kernel entry points
          for kernel in module.kernels.values():
@@ -900,12 +1174,13 @@ class ModuleBuilder:
          while stack:
              s = stack.pop()

-             if not s in structs:
-                 structs.append(s)
+             structs.append(s)

              for var in s.vars.values():
                  if isinstance(var.type, warp.codegen.Struct):
                      stack.append(var.type)
+                 elif isinstance(var.type, warp.types.array) and isinstance(var.type.dtype, warp.codegen.Struct):
+                     stack.append(var.type.dtype)

          # Build them in reverse to generate a correct dependency order.
          for s in reversed(structs):
@@ -931,7 +1206,7 @@ class ModuleBuilder:
          if not func.value_func:

              def wrap(adj):
-                 def value_type(args, kwds, templates):
+                 def value_type(arg_types, kwds, templates):
                      if adj.return_var is None or len(adj.return_var) == 0:
                          return None
                      if len(adj.return_var) == 1:
@@ -946,56 +1221,41 @@ class ModuleBuilder:
          # use dict to preserve import order
          self.functions[func] = None

-     def codegen_cpu(self):
-         cpp_source = ""
+     def codegen(self, device):
+         source = ""

          # code-gen structs
          for struct in self.structs.keys():
-             cpp_source += warp.codegen.codegen_struct(struct)
+             source += warp.codegen.codegen_struct(struct)

          # code-gen all imported functions
          for func in self.functions.keys():
-             cpp_source += warp.codegen.codegen_func(func.adj, device="cpu")
-
-         for kernel in self.module.kernels.values():
-             # each kernel gets an entry point in the module
-             if not kernel.is_generic:
-                 cpp_source += warp.codegen.codegen_kernel(kernel, device="cpu", options=self.options)
-                 cpp_source += warp.codegen.codegen_module(kernel, device="cpu")
+             if func.native_snippet is None:
+                 source += warp.codegen.codegen_func(
+                     func.adj, c_func_name=func.native_func, device=device, options=self.options
+                 )
              else:
-                 for k in kernel.overloads.values():
-                     cpp_source += warp.codegen.codegen_kernel(k, device="cpu", options=self.options)
-                     cpp_source += warp.codegen.codegen_module(k, device="cpu")
-
-         # add headers
-         cpp_source = warp.codegen.cpu_module_header + cpp_source
-
-         return cpp_source
-
-     def codegen_cuda(self):
-         cu_source = ""
-
-         # code-gen structs
-         for struct in self.structs.keys():
-             cu_source += warp.codegen.codegen_struct(struct)
-
-         # code-gen all imported functions
-         for func in self.functions.keys():
-             cu_source += warp.codegen.codegen_func(func.adj, device="cuda")
+                 source += warp.codegen.codegen_snippet(
+                     func.adj, name=func.key, snippet=func.native_snippet, adj_snippet=func.adj_native_snippet
+                 )

          for kernel in self.module.kernels.values():
+             # each kernel gets an entry point in the module
              if not kernel.is_generic:
-                 cu_source += warp.codegen.codegen_kernel(kernel, device="cuda", options=self.options)
-                 cu_source += warp.codegen.codegen_module(kernel, device="cuda")
+                 source += warp.codegen.codegen_kernel(kernel, device=device, options=self.options)
+                 source += warp.codegen.codegen_module(kernel, device=device)
              else:
                  for k in kernel.overloads.values():
-                     cu_source += warp.codegen.codegen_kernel(k, device="cuda", options=self.options)
-                     cu_source += warp.codegen.codegen_module(k, device="cuda")
+                     source += warp.codegen.codegen_kernel(k, device=device, options=self.options)
+                     source += warp.codegen.codegen_module(k, device=device)

          # add headers
-         cu_source = warp.codegen.cuda_module_header + cu_source
+         if device == "cpu":
+             source = warp.codegen.cpu_module_header + source
+         else:
+             source = warp.codegen.cuda_module_header + source

-         return cu_source
+         return source


  # -----------------------------------------------------
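
The codegen_cpu()/codegen_cuda() pair above is folded into a single codegen(device) method; a minimal sketch of how Module.load() drives it for both targets (simplified from the load() hunks further below):

    builder = ModuleBuilder(module, module.options)

    # one entry point for both backends; the device string selects the kernel
    # codegen and the cpu_module_header / cuda_module_header prepended at the end
    cpp_source = builder.codegen("cpu")
    cu_source = builder.codegen("cuda")
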
@@ -1014,7 +1274,6 @@ class Module:
          self.constants = []
          self.structs = {}

-         self.dll = None
          self.cpu_module = None
          self.cuda_modules = {} # module lookup by CUDA context

@@ -1058,6 +1317,10 @@ class Module:

          self.content_hash = None

+         # number of times module auto-generates kernel key for user
+         # used to ensure unique kernel keys
+         self.count = 0
+
      def register_struct(self, struct):
          self.structs[struct.key] = struct

@@ -1072,7 +1335,7 @@ class Module:
          # for a reload of module on next launch
          self.unload()

-     def register_function(self, func):
+     def register_function(self, func, skip_adding_overload=False):
          if func.key not in self.functions:
              self.functions[func.key] = func
          else:
@@ -1092,7 +1355,7 @@ class Module:
                  )
                  if sig == sig_existing:
                      self.functions[func.key] = func
-                 else:
+                 elif not skip_adding_overload:
                      func_existing.add_overload(func)

          self.find_references(func.adj)
@@ -1100,6 +1363,11 @@ class Module:
          # for a reload of module on next launch
          self.unload()

+     def generate_unique_kernel_key(self, key):
+         unique_key = f"{key}_{self.count}"
+         self.count += 1
+         return unique_key
+
      # collect all referenced functions / structs
      # given the AST of a function or kernel
      def find_references(self, adj):
@@ -1113,13 +1381,13 @@ class Module:
              if isinstance(node, ast.Call):
                  try:
                      # try to resolve the function
-                     func, _ = adj.resolve_path(node.func)
+                     func, _ = adj.resolve_static_expression(node.func, eval_types=False)

                      # if this is a user-defined function, add a module reference
                      if isinstance(func, warp.context.Function) and func.module is not None:
                          add_ref(func.module)

-                 except:
+                 except Exception:
                      # Lookups may fail for builtins, but that's ok.
                      # Lookups may also fail for functions in this module that haven't been imported yet,
                      # and that's ok too (not an external reference).
@@ -1139,6 +1407,11 @@ class Module:

          return getattr(obj, "__annotations__", {})

+     def get_type_name(type_hint):
+         if isinstance(type_hint, warp.codegen.Struct):
+             return get_type_name(type_hint.cls)
+         return type_hint
+
      def hash_recursive(module, visited):
          # Hash this module, including all referenced modules recursively.
          # The visited set tracks modules already visited to avoid circular references.
@@ -1151,7 +1424,8 @@ class Module:
              # struct source
              for struct in module.structs.values():
                  s = ",".join(
-                     "{}: {}".format(name, type_hint) for name, type_hint in get_annotations(struct.cls).items()
+                     "{}: {}".format(name, get_type_name(type_hint))
+                     for name, type_hint in get_annotations(struct.cls).items()
                  )
                  ch.update(bytes(s, "utf-8"))

@@ -1160,13 +1434,29 @@ class Module:
                  s = func.adj.source
                  ch.update(bytes(s, "utf-8"))

+                 if func.custom_grad_func:
+                     s = func.custom_grad_func.adj.source
+                     ch.update(bytes(s, "utf-8"))
+                 if func.custom_replay_func:
+                     s = func.custom_replay_func.adj.source
+
+                 # cache func arg types
+                 for arg, arg_type in func.adj.arg_types.items():
+                     s = f"{arg}: {get_type_name(arg_type)}"
+                     ch.update(bytes(s, "utf-8"))
+
              # kernel source
              for kernel in module.kernels.values():
-                 if not kernel.is_generic:
-                     ch.update(bytes(kernel.adj.source, "utf-8"))
-                 else:
-                     for k in kernel.overloads.values():
-                         ch.update(bytes(k.adj.source, "utf-8"))
+                 ch.update(bytes(kernel.adj.source, "utf-8"))
+                 # cache kernel arg types
+                 for arg, arg_type in kernel.adj.arg_types.items():
+                     s = f"{arg}: {get_type_name(arg_type)}"
+                     ch.update(bytes(s, "utf-8"))
+                 # for generic kernels the Python source is always the same,
+                 # but we hash the type signatures of all the overloads
+                 if kernel.is_generic:
+                     for sig in sorted(kernel.overloads.keys()):
+                         ch.update(bytes(sig, "utf-8"))


              module.content_hash = ch.digest()

@@ -1204,12 +1494,12 @@ class Module:
          return hash_recursive(self, visited=set())

      def load(self, device):
+         from warp.utils import ScopedTimer
+
          device = get_device(device)

          if device.is_cpu:
              # check if already loaded
-             if self.dll:
-                 return True
              if self.cpu_module:
                  return True
              # avoid repeated build attempts
@@ -1227,7 +1517,7 @@ class Module:
              if not warp.is_cuda_available():
                  raise RuntimeError("Failed to build CUDA module because CUDA is not available")

-         with warp.utils.ScopedTimer(f"Module {self.name} load on device '{device}'", active=not warp.config.quiet):
+         with ScopedTimer(f"Module {self.name} load on device '{device}'", active=not warp.config.quiet):
              build_path = warp.build.kernel_bin_dir
              gen_path = warp.build.kernel_gen_dir

@@ -1238,89 +1528,54 @@ class Module:

              module_name = "wp_" + self.name
              module_path = os.path.join(build_path, module_name)
-             obj_path = os.path.join(gen_path, module_name)
              module_hash = self.hash_module()

              builder = ModuleBuilder(self, self.options)

              if device.is_cpu:
-                 if runtime.llvm:
-                     if os.name == "nt":
-                         dll_path = obj_path + ".cpp.obj"
-                     else:
-                         dll_path = obj_path + ".cpp.o"
-                 else:
-                     if os.name == "nt":
-                         dll_path = module_path + ".dll"
-                     else:
-                         dll_path = module_path + ".so"
-
+                 obj_path = os.path.join(build_path, module_name)
+                 obj_path = obj_path + ".o"
                  cpu_hash_path = module_path + ".cpu.hash"

                  # check cache
-                 if warp.config.cache_kernels and os.path.isfile(cpu_hash_path) and os.path.isfile(dll_path):
+                 if warp.config.cache_kernels and os.path.isfile(cpu_hash_path) and os.path.isfile(obj_path):
                      with open(cpu_hash_path, "rb") as f:
                          cache_hash = f.read()

                      if cache_hash == module_hash:
-                         if runtime.llvm:
-                             runtime.llvm.load_obj(dll_path.encode("utf-8"), module_name.encode("utf-8"))
-                             self.cpu_module = module_name
-                             return True
-                         else:
-                             self.dll = warp.build.load_dll(dll_path)
-                             if self.dll is not None:
-                                 return True
+                         runtime.llvm.load_obj(obj_path.encode("utf-8"), module_name.encode("utf-8"))
+                         self.cpu_module = module_name
+                         return True

                  # build
                  try:
                      cpp_path = os.path.join(gen_path, module_name + ".cpp")

                      # write cpp sources
-                     cpp_source = builder.codegen_cpu()
+                     cpp_source = builder.codegen("cpu")

                      cpp_file = open(cpp_path, "w")
                      cpp_file.write(cpp_source)
                      cpp_file.close()

-                     bin_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bin")
-                     if os.name == "nt":
-                         libs = ["warp.lib", f'/LIBPATH:"{bin_path}"']
-                         libs.append("/NOENTRY")
-                         libs.append("/NODEFAULTLIB")
-                     elif sys.platform == "darwin":
-                         libs = [f"-lwarp", f"-L{bin_path}", f"-Wl,-rpath,'{bin_path}'"]
-                     else:
-                         libs = ["-l:warp.so", f"-L{bin_path}", f"-Wl,-rpath,'{bin_path}'"]
-
-                     # build DLL or object code
-                     with warp.utils.ScopedTimer("Compile x86", active=warp.config.verbose):
-                         warp.build.build_dll(
-                             dll_path,
-                             [cpp_path],
-                             None,
-                             libs,
+                     # build object code
+                     with ScopedTimer("Compile x86", active=warp.config.verbose):
+                         warp.build.build_cpu(
+                             obj_path,
+                             cpp_path,
                              mode=self.options["mode"],
                              fast_math=self.options["fast_math"],
                              verify_fp=warp.config.verify_fp,
  )
1307
1570
 
1308
- if runtime.llvm:
1309
- # load the object code
1310
- obj_ext = ".obj" if os.name == "nt" else ".o"
1311
- obj_path = cpp_path + obj_ext
1312
- runtime.llvm.load_obj(obj_path.encode("utf-8"), module_name.encode("utf-8"))
1313
- self.cpu_module = module_name
1314
- else:
1315
- # load the DLL
1316
- self.dll = warp.build.load_dll(dll_path)
1317
- if self.dll is None:
1318
- raise Exception("Failed to load CPU module")
1319
-
1320
1571
  # update cpu hash
1321
1572
  with open(cpu_hash_path, "wb") as f:
1322
1573
  f.write(module_hash)
1323
1574
 
1575
+ # load the object code
1576
+ runtime.llvm.load_obj(obj_path.encode("utf-8"), module_name.encode("utf-8"))
1577
+ self.cpu_module = module_name
1578
+
1324
1579
  except Exception as e:
1325
1580
  self.cpu_build_failed = True
1326
1581
  raise (e)
@@ -1365,14 +1620,14 @@ class Module:
1365
1620
  cu_path = os.path.join(gen_path, module_name + ".cu")
1366
1621
 
1367
1622
  # write cuda sources
1368
- cu_source = builder.codegen_cuda()
1623
+ cu_source = builder.codegen("cuda")
1369
1624
 
1370
1625
  cu_file = open(cu_path, "w")
1371
1626
  cu_file.write(cu_source)
1372
1627
  cu_file.close()
1373
1628
 
1374
1629
  # generate PTX or CUBIN
1375
- with warp.utils.ScopedTimer("Compile CUDA", active=warp.config.verbose):
1630
+ with ScopedTimer("Compile CUDA", active=warp.config.verbose):
1376
1631
  warp.build.build_cuda(
1377
1632
  cu_path,
1378
1633
  output_arch,
@@ -1382,6 +1637,10 @@ class Module:
1382
1637
  verify_fp=warp.config.verify_fp,
1383
1638
  )
1384
1639
 
1640
+ # update cuda hash
1641
+ with open(cuda_hash_path, "wb") as f:
1642
+ f.write(module_hash)
1643
+
1385
1644
  # load the module
1386
1645
  cuda_module = warp.build.load_cuda(output_path, device)
1387
1646
  if cuda_module is not None:
@@ -1389,10 +1648,6 @@ class Module:
1389
1648
  else:
1390
1649
  raise Exception("Failed to load CUDA module")
1391
1650
 
1392
- # update cuda hash
1393
- with open(cuda_hash_path, "wb") as f:
1394
- f.write(module_hash)
1395
-
1396
1651
  except Exception as e:
1397
1652
  self.cuda_build_failed = True
1398
1653
  raise (e)
@@ -1400,10 +1655,6 @@ class Module:
1400
1655
  return True
1401
1656
 
1402
1657
  def unload(self):
1403
- if self.dll:
1404
- warp.build.unload_dll(self.dll)
1405
- self.dll = None
1406
-
1407
1658
  if self.cpu_module:
1408
1659
  runtime.llvm.unload_obj(self.cpu_module.encode("utf-8"))
1409
1660
  self.cpu_module = None
@@ -1438,17 +1689,13 @@ class Module:
1438
1689
  name = kernel.get_mangled_name()
1439
1690
 
1440
1691
  if device.is_cpu:
1441
- if self.cpu_module:
1442
- func = ctypes.CFUNCTYPE(None)
1443
- forward = func(
1444
- runtime.llvm.lookup(self.cpu_module.encode("utf-8"), (name + "_cpu_forward").encode("utf-8"))
1445
- )
1446
- backward = func(
1447
- runtime.llvm.lookup(self.cpu_module.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))
1448
- )
1449
- else:
1450
- forward = eval("self.dll." + name + "_cpu_forward")
1451
- backward = eval("self.dll." + name + "_cpu_backward")
1692
+ func = ctypes.CFUNCTYPE(None)
1693
+ forward = func(
1694
+ runtime.llvm.lookup(self.cpu_module.encode("utf-8"), (name + "_cpu_forward").encode("utf-8"))
1695
+ )
1696
+ backward = func(
1697
+ runtime.llvm.lookup(self.cpu_module.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))
1698
+ )
1452
1699
  else:
1453
1700
  cu_module = self.cuda_modules[device.context]
1454
1701
  forward = runtime.core.cuda_get_kernel(
@@ -1475,6 +1722,8 @@ class Allocator:
1475
1722
 
1476
1723
  def alloc(self, size_in_bytes, pinned=False):
1477
1724
  if self.device.is_cuda:
1725
+ if self.device.is_capturing:
1726
+ raise RuntimeError(f"Cannot allocate memory on device {self} while graph capture is active")
1478
1727
  return runtime.core.alloc_device(self.device.context, size_in_bytes)
1479
1728
  elif self.device.is_cpu:
1480
1729
  if pinned:
@@ -1484,6 +1733,8 @@ class Allocator:
1484
1733
 
1485
1734
  def free(self, ptr, size_in_bytes, pinned=False):
1486
1735
  if self.device.is_cuda:
1736
+ if self.device.is_capturing:
1737
+ raise RuntimeError(f"Cannot free memory on device {self} while graph capture is active")
1487
1738
  return runtime.core.free_device(self.device.context, ptr)
1488
1739
  elif self.device.is_cpu:
1489
1740
  if pinned:
@@ -1499,13 +1750,13 @@ class ContextGuard:
1499
1750
  def __enter__(self):
1500
1751
  if self.device.is_cuda:
1501
1752
  runtime.core.cuda_context_push_current(self.device.context)
1502
- elif is_cuda_available():
1753
+ elif is_cuda_driver_initialized():
1503
1754
  self.saved_context = runtime.core.cuda_context_get_current()
1504
1755
 
1505
1756
  def __exit__(self, exc_type, exc_value, traceback):
1506
1757
  if self.device.is_cuda:
1507
1758
  runtime.core.cuda_context_pop_current()
1508
- elif is_cuda_available():
1759
+ elif is_cuda_driver_initialized():
1509
1760
  runtime.core.cuda_context_set_current(self.saved_context)
1510
1761
 
1511
1762
 
@@ -1596,6 +1847,29 @@ class Event:
1596
1847
 
1597
1848
 
1598
1849
  class Device:
1850
+ """A device to allocate Warp arrays and to launch kernels on.
1851
+
1852
+ Attributes:
1853
+ ordinal: A Warp-specific integer label for the device. ``-1`` for CPU devices.
1854
+ name: A string label for the device. By default, CPU devices will be named according to the processor name,
1855
+ or ``"CPU"`` if the processor name cannot be determined.
1856
+ arch: An integer representing the compute capability version number calculated as
1857
+ ``10 * major + minor``. ``0`` for CPU devices.
1858
+ is_uva: A boolean indicating whether or not the device supports unified addressing.
1859
+ ``False`` for CPU devices.
1860
+ is_cubin_supported: A boolean indicating whether or not Warp's version of NVRTC can directly
1861
+ generate CUDA binary files (cubin) for this device's architecture. ``False`` for CPU devices.
1862
+ is_mempool_supported: A boolean indicating whether or not the device supports using the
1863
+ ``cuMemAllocAsync`` and ``cuMemPool`` family of APIs for stream-ordered memory allocations. ``False`` for
1864
+ CPU devices.
1865
+ is_primary: A boolean indicating whether or not this device's CUDA context is also the
1866
+ device's primary context.
1867
+ uuid: A string representing the UUID of the CUDA device. The UUID is in the same format used by
1868
+ ``nvidia-smi -L``. ``None`` for CPU devices.
1869
+ pci_bus_id: A string identifier for the CUDA device in the format ``[domain]:[bus]:[device]``, in which
1870
+ ``domain``, ``bus``, and ``device`` are all hexadecimal values. ``None`` for CPU devices.
1871
+ """
1872
+
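A short usage sketch of the attributes documented above (assumes Warp is installed and initialized; the printed values depend on the machine):

    import warp as wp

    wp.init()
    device = wp.get_device()    # current default device
    print(device.name, device.arch, device.is_cuda)
    if device.is_cuda:
        # same UUID format as reported by `nvidia-smi -L`
        print(device.uuid, device.pci_bus_id, device.is_mempool_supported)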
1599
1873
  def __init__(self, runtime, alias, ordinal=-1, is_primary=False, context=None):
1600
1874
  self.runtime = runtime
1601
1875
  self.alias = alias
@@ -1625,6 +1899,9 @@ class Device:
1625
1899
  self.arch = 0
1626
1900
  self.is_uva = False
1627
1901
  self.is_cubin_supported = False
1902
+ self.is_mempool_supported = False
1903
+ self.uuid = None
1904
+ self.pci_bus_id = None
1628
1905
 
1629
1906
  # TODO: add more device-specific dispatch functions
1630
1907
  self.memset = runtime.core.memset_host
@@ -1637,6 +1914,26 @@ class Device:
1637
1914
  self.is_uva = runtime.core.cuda_device_is_uva(ordinal)
1638
1915
  # check whether our NVRTC can generate CUBINs for this architecture
1639
1916
  self.is_cubin_supported = self.arch in runtime.nvrtc_supported_archs
1917
+ self.is_mempool_supported = runtime.core.cuda_device_is_memory_pool_supported(ordinal)
1918
+
1919
+ uuid_buffer = (ctypes.c_char * 16)()
1920
+ runtime.core.cuda_device_get_uuid(ordinal, uuid_buffer)
1921
+ uuid_byte_str = bytes(uuid_buffer).hex()
1922
+ self.uuid = f"GPU-{uuid_byte_str[0:8]}-{uuid_byte_str[8:12]}-{uuid_byte_str[12:16]}-{uuid_byte_str[16:20]}-{uuid_byte_str[20:]}"
1923
+
1924
+ pci_domain_id = runtime.core.cuda_device_get_pci_domain_id(ordinal)
1925
+ pci_bus_id = runtime.core.cuda_device_get_pci_bus_id(ordinal)
1926
+ pci_device_id = runtime.core.cuda_device_get_pci_device_id(ordinal)
1927
+ # This is (mis)named to correspond to the naming of cudaDeviceGetPCIBusId
1928
+ self.pci_bus_id = f"{pci_domain_id:08X}:{pci_bus_id:02X}:{pci_device_id:02X}"
1929
+
1930
+ # Warn the user of a possible misconfiguration of their system
1931
+ if not self.is_mempool_supported:
1932
+ warp.utils.warn(
1933
+ f"Support for stream ordered memory allocators was not detected on device {ordinal}. "
1934
+ "This can prevent the use of graphs and/or result in poor performance. "
1935
+ "Is the UVM driver enabled?"
1936
+ )
1640
1937
 
1641
1938
  # initialize streams unless context acquisition is postponed
1642
1939
  if self._context is not None:
@@ -1660,14 +1957,17 @@ class Device:
1660
1957
 
1661
1958
  @property
1662
1959
  def is_cpu(self):
1960
+ """A boolean indicating whether or not the device is a CPU device."""
1663
1961
  return self.ordinal < 0
1664
1962
 
1665
1963
  @property
1666
1964
  def is_cuda(self):
1965
+ """A boolean indicating whether or not the device is a CUDA device."""
1667
1966
  return self.ordinal >= 0
1668
1967
 
1669
1968
  @property
1670
1969
  def context(self):
1970
+ """The context associated with the device."""
1671
1971
  if self._context is not None:
1672
1972
  return self._context
1673
1973
  elif self.is_primary:
@@ -1682,10 +1982,16 @@ class Device:
1682
1982
 
1683
1983
  @property
1684
1984
  def has_context(self):
1985
+ """A boolean indicating whether or not the device has a CUDA context associated with it."""
1685
1986
  return self._context is not None
1686
1987
 
1687
1988
  @property
1688
1989
  def stream(self):
1990
+ """The stream associated with a CUDA device.
1991
+
1992
+ Raises:
1993
+ RuntimeError: The device is not a CUDA device.
1994
+ """
1689
1995
  if self.context:
1690
1996
  return self._stream
1691
1997
  else:
@@ -1703,6 +2009,7 @@ class Device:
1703
2009
 
1704
2010
  @property
1705
2011
  def has_stream(self):
2012
+ """A boolean indicating whether or not the device has a stream associated with it."""
1706
2013
  return self._stream is not None
1707
2014
 
1708
2015
  def __str__(self):
@@ -1778,10 +2085,10 @@ class Runtime:
1778
2085
  warp_lib = os.path.join(bin_path, "warp.so")
1779
2086
  llvm_lib = os.path.join(bin_path, "warp-clang.so")
1780
2087
 
1781
- self.core = warp.build.load_dll(warp_lib)
2088
+ self.core = self.load_dll(warp_lib)
1782
2089
 
1783
- if llvm_lib and os.path.exists(llvm_lib):
1784
- self.llvm = warp.build.load_dll(llvm_lib)
2090
+ if os.path.exists(llvm_lib):
2091
+ self.llvm = self.load_dll(llvm_lib)
1785
2092
  # setup c-types for warp-clang.dll
1786
2093
  self.llvm.lookup.restype = ctypes.c_uint64
1787
2094
  else:
@@ -1852,11 +2159,106 @@ class Runtime:
1852
2159
  ]
1853
2160
  self.core.array_copy_device.restype = ctypes.c_size_t
1854
2161
 
2162
+ self.core.array_fill_host.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_int]
2163
+ self.core.array_fill_host.restype = None
2164
+ self.core.array_fill_device.argtypes = [
2165
+ ctypes.c_void_p,
2166
+ ctypes.c_void_p,
2167
+ ctypes.c_int,
2168
+ ctypes.c_void_p,
2169
+ ctypes.c_int,
2170
+ ]
2171
+ self.core.array_fill_device.restype = None
2172
+
2173
+ self.core.array_sum_double_host.argtypes = [
2174
+ ctypes.c_uint64,
2175
+ ctypes.c_uint64,
2176
+ ctypes.c_int,
2177
+ ctypes.c_int,
2178
+ ctypes.c_int,
2179
+ ]
2180
+ self.core.array_sum_float_host.argtypes = [
2181
+ ctypes.c_uint64,
2182
+ ctypes.c_uint64,
2183
+ ctypes.c_int,
2184
+ ctypes.c_int,
2185
+ ctypes.c_int,
2186
+ ]
2187
+ self.core.array_sum_double_device.argtypes = [
2188
+ ctypes.c_uint64,
2189
+ ctypes.c_uint64,
2190
+ ctypes.c_int,
2191
+ ctypes.c_int,
2192
+ ctypes.c_int,
2193
+ ]
2194
+ self.core.array_sum_float_device.argtypes = [
2195
+ ctypes.c_uint64,
2196
+ ctypes.c_uint64,
2197
+ ctypes.c_int,
2198
+ ctypes.c_int,
2199
+ ctypes.c_int,
2200
+ ]
2201
+
2202
+ self.core.array_inner_double_host.argtypes = [
2203
+ ctypes.c_uint64,
2204
+ ctypes.c_uint64,
2205
+ ctypes.c_uint64,
2206
+ ctypes.c_int,
2207
+ ctypes.c_int,
2208
+ ctypes.c_int,
2209
+ ctypes.c_int,
2210
+ ]
2211
+ self.core.array_inner_float_host.argtypes = [
2212
+ ctypes.c_uint64,
2213
+ ctypes.c_uint64,
2214
+ ctypes.c_uint64,
2215
+ ctypes.c_int,
2216
+ ctypes.c_int,
2217
+ ctypes.c_int,
2218
+ ctypes.c_int,
2219
+ ]
2220
+ self.core.array_inner_double_device.argtypes = [
2221
+ ctypes.c_uint64,
2222
+ ctypes.c_uint64,
2223
+ ctypes.c_uint64,
2224
+ ctypes.c_int,
2225
+ ctypes.c_int,
2226
+ ctypes.c_int,
2227
+ ctypes.c_int,
2228
+ ]
2229
+ self.core.array_inner_float_device.argtypes = [
2230
+ ctypes.c_uint64,
2231
+ ctypes.c_uint64,
2232
+ ctypes.c_uint64,
2233
+ ctypes.c_int,
2234
+ ctypes.c_int,
2235
+ ctypes.c_int,
2236
+ ctypes.c_int,
2237
+ ]
2238
+
1855
2239
  self.core.array_scan_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
1856
2240
  self.core.array_scan_float_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
1857
2241
  self.core.array_scan_int_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
1858
2242
  self.core.array_scan_float_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
1859
2243
 
2244
+ self.core.radix_sort_pairs_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
2245
+ self.core.radix_sort_pairs_int_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
2246
+
2247
+ self.core.runlength_encode_int_host.argtypes = [
2248
+ ctypes.c_uint64,
2249
+ ctypes.c_uint64,
2250
+ ctypes.c_uint64,
2251
+ ctypes.c_uint64,
2252
+ ctypes.c_int,
2253
+ ]
2254
+ self.core.runlength_encode_int_device.argtypes = [
2255
+ ctypes.c_uint64,
2256
+ ctypes.c_uint64,
2257
+ ctypes.c_uint64,
2258
+ ctypes.c_uint64,
2259
+ ctypes.c_int,
2260
+ ]
2261
+
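These bindings back the array reduction, sort, and run-length-encode utilities. A usage sketch, assuming they are exposed through ``warp.utils`` as ``array_sum`` and ``array_inner`` (values chosen so the expected results are easy to check):

    import warp as wp
    import warp.utils

    wp.init()
    a = wp.full(1024, 2.0, dtype=float)
    b = wp.full(1024, 3.0, dtype=float)

    total = warp.utils.array_sum(a)     # expected 2048.0
    dot = warp.utils.array_inner(a, b)  # expected 6144.0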
1860
2262
  self.core.bvh_create_host.restype = ctypes.c_uint64
1861
2263
  self.core.bvh_create_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
1862
2264
 
@@ -1876,6 +2278,7 @@ class Runtime:
1876
2278
  warp.types.array_t,
1877
2279
  ctypes.c_int,
1878
2280
  ctypes.c_int,
2281
+ ctypes.c_int,
1879
2282
  ]
1880
2283
 
1881
2284
  self.core.mesh_create_device.restype = ctypes.c_uint64
@@ -1886,6 +2289,7 @@ class Runtime:
1886
2289
  warp.types.array_t,
1887
2290
  ctypes.c_int,
1888
2291
  ctypes.c_int,
2292
+ ctypes.c_int,
1889
2293
  ]
1890
2294
 
1891
2295
  self.core.mesh_destroy_host.argtypes = [ctypes.c_uint64]
@@ -1998,6 +2402,46 @@ class Runtime:
1998
2402
  ctypes.POINTER(ctypes.c_float),
1999
2403
  ]
2000
2404
 
2405
+ bsr_matrix_from_triplets_argtypes = [
2406
+ ctypes.c_int,
2407
+ ctypes.c_int,
2408
+ ctypes.c_int,
2409
+ ctypes.c_int,
2410
+ ctypes.c_uint64,
2411
+ ctypes.c_uint64,
2412
+ ctypes.c_uint64,
2413
+ ctypes.c_uint64,
2414
+ ctypes.c_uint64,
2415
+ ctypes.c_uint64,
2416
+ ]
2417
+ self.core.bsr_matrix_from_triplets_float_host.argtypes = bsr_matrix_from_triplets_argtypes
2418
+ self.core.bsr_matrix_from_triplets_double_host.argtypes = bsr_matrix_from_triplets_argtypes
2419
+ self.core.bsr_matrix_from_triplets_float_device.argtypes = bsr_matrix_from_triplets_argtypes
2420
+ self.core.bsr_matrix_from_triplets_double_device.argtypes = bsr_matrix_from_triplets_argtypes
2421
+
2422
+ self.core.bsr_matrix_from_triplets_float_host.restype = ctypes.c_int
2423
+ self.core.bsr_matrix_from_triplets_double_host.restype = ctypes.c_int
2424
+ self.core.bsr_matrix_from_triplets_float_device.restype = ctypes.c_int
2425
+ self.core.bsr_matrix_from_triplets_double_device.restype = ctypes.c_int
2426
+
2427
+ bsr_transpose_argtypes = [
2428
+ ctypes.c_int,
2429
+ ctypes.c_int,
2430
+ ctypes.c_int,
2431
+ ctypes.c_int,
2432
+ ctypes.c_int,
2433
+ ctypes.c_uint64,
2434
+ ctypes.c_uint64,
2435
+ ctypes.c_uint64,
2436
+ ctypes.c_uint64,
2437
+ ctypes.c_uint64,
2438
+ ctypes.c_uint64,
2439
+ ]
2440
+ self.core.bsr_transpose_float_host.argtypes = bsr_transpose_argtypes
2441
+ self.core.bsr_transpose_double_host.argtypes = bsr_transpose_argtypes
2442
+ self.core.bsr_transpose_float_device.argtypes = bsr_transpose_argtypes
2443
+ self.core.bsr_transpose_double_device.argtypes = bsr_transpose_argtypes
2444
+
2001
2445
  self.core.is_cuda_enabled.argtypes = None
2002
2446
  self.core.is_cuda_enabled.restype = ctypes.c_int
2003
2447
  self.core.is_cuda_compatibility_enabled.argtypes = None
@@ -2009,6 +2453,8 @@ class Runtime:
2009
2453
  self.core.cuda_driver_version.restype = ctypes.c_int
2010
2454
  self.core.cuda_toolkit_version.argtypes = None
2011
2455
  self.core.cuda_toolkit_version.restype = ctypes.c_int
2456
+ self.core.cuda_driver_is_initialized.argtypes = None
2457
+ self.core.cuda_driver_is_initialized.restype = ctypes.c_bool
2012
2458
 
2013
2459
  self.core.nvrtc_supported_arch_count.argtypes = None
2014
2460
  self.core.nvrtc_supported_arch_count.restype = ctypes.c_int
@@ -2025,6 +2471,14 @@ class Runtime:
2025
2471
  self.core.cuda_device_get_arch.restype = ctypes.c_int
2026
2472
  self.core.cuda_device_is_uva.argtypes = [ctypes.c_int]
2027
2473
  self.core.cuda_device_is_uva.restype = ctypes.c_int
2474
+ self.core.cuda_device_get_uuid.argtypes = [ctypes.c_int, ctypes.c_char * 16]
2475
+ self.core.cuda_device_get_uuid.restype = None
2476
+ self.core.cuda_device_get_pci_domain_id.argtypes = [ctypes.c_int]
2477
+ self.core.cuda_device_get_pci_domain_id.restype = ctypes.c_int
2478
+ self.core.cuda_device_get_pci_bus_id.argtypes = [ctypes.c_int]
2479
+ self.core.cuda_device_get_pci_bus_id.restype = ctypes.c_int
2480
+ self.core.cuda_device_get_pci_device_id.argtypes = [ctypes.c_int]
2481
+ self.core.cuda_device_get_pci_device_id.restype = ctypes.c_int
2028
2482
 
2029
2483
  self.core.cuda_context_get_current.argtypes = None
2030
2484
  self.core.cuda_context_get_current.restype = ctypes.c_void_p
@@ -2111,6 +2565,7 @@ class Runtime:
2111
2565
  ctypes.c_void_p,
2112
2566
  ctypes.c_void_p,
2113
2567
  ctypes.c_size_t,
2568
+ ctypes.c_int,
2114
2569
  ctypes.POINTER(ctypes.c_void_p),
2115
2570
  ]
2116
2571
  self.core.cuda_launch_kernel.restype = ctypes.c_size_t
@@ -2140,7 +2595,6 @@ class Runtime:
2140
2595
 
2141
2596
  self.device_map = {} # device lookup by alias
2142
2597
  self.context_map = {} # device lookup by context
2143
- self.graph_capture_map = {} # indicates whether graph capture is active for a given device
2144
2598
 
2145
2599
  # register CPU device
2146
2600
  cpu_name = platform.processor()
@@ -2149,7 +2603,6 @@ class Runtime:
2149
2603
  self.cpu_device = Device(self, "cpu")
2150
2604
  self.device_map["cpu"] = self.cpu_device
2151
2605
  self.context_map[None] = self.cpu_device
2152
- self.graph_capture_map[None] = False
2153
2606
 
2154
2607
  cuda_device_count = self.core.cuda_device_get_count()
2155
2608
 
@@ -2183,12 +2636,9 @@ class Runtime:
2183
2636
  self.set_default_device("cuda")
2184
2637
  else:
2185
2638
  self.set_default_device("cuda:0")
2186
- # save the initial CUDA device for backward compatibility with ScopedCudaGuard
2187
- self.initial_cuda_device = self.default_device
2188
2639
  else:
2189
2640
  # CUDA not available
2190
2641
  self.set_default_device("cpu")
2191
- self.initial_cuda_device = None
2192
2642
 
2193
2643
  # initialize kernel cache
2194
2644
  warp.build.init_kernel_cache(warp.config.kernel_cache_dir)
@@ -2230,6 +2680,23 @@ class Runtime:
2230
2680
  # global tape
2231
2681
  self.tape = None
2232
2682
 
2683
+ def load_dll(self, dll_path):
2684
+ try:
2685
+ if sys.version_info[0] > 3 or sys.version_info[0] == 3 and sys.version_info[1] >= 8:
2686
+ dll = ctypes.CDLL(dll_path, winmode=0)
2687
+ else:
2688
+ dll = ctypes.CDLL(dll_path)
2689
+ except OSError as e:
2690
+ if "GLIBCXX" in str(e):
2691
+ raise RuntimeError(
2692
+ f"Failed to load the shared library '{dll_path}'.\n"
2693
+ "The execution environment's libstdc++ runtime is older than the version the Warp library was built for.\n"
2694
+ "See https://nvidia.github.io/warp/_build/html/installation.html#conda-environments for details."
2695
+ ) from e
2696
+ else:
2697
+ raise RuntimeError(f"Failed to load the shared library '{dll_path}'") from e
2698
+ return dll
2699
+
2233
2700
  def get_device(self, ident: Devicelike = None) -> Device:
2234
2701
  if isinstance(ident, Device):
2235
2702
  return ident
@@ -2345,15 +2812,7 @@ def assert_initialized():
2345
2812
 
2346
2813
  # global entry points
2347
2814
  def is_cpu_available():
2348
- if runtime.llvm:
2349
- return True
2350
-
2351
- # initialize host build env (do this lazily) since
2352
- # it takes 5secs to run all the batch files to locate MSVC
2353
- if warp.config.host_compiler is None:
2354
- warp.config.host_compiler = warp.build.find_host_compiler()
2355
-
2356
- return warp.config.host_compiler != ""
2815
+ return runtime.llvm
2357
2816
 
2358
2817
 
2359
2818
  def is_cuda_available():
@@ -2364,6 +2823,21 @@ def is_device_available(device):
2364
2823
  return device in get_devices()
2365
2824
 
2366
2825
 
2826
+ def is_cuda_driver_initialized() -> bool:
2827
+ """Returns ``True`` if the CUDA driver is initialized.
2828
+
2829
+ This is a stricter test than ``is_cuda_available()`` since a CUDA driver
2830
+ call to ``cuCtxGetCurrent`` is made, and the result is compared to
2831
+ `CUDA_SUCCESS`. Note that `CUDA_SUCCESS` is returned by ``cuCtxGetCurrent``
2832
+ even if there is no context bound to the calling CPU thread.
2833
+
2834
+ This can be helpful in cases in which ``cuInit()`` was called before a fork.
2835
+ """
2836
+ assert_initialized()
2837
+
2838
+ return runtime.core.cuda_driver_is_initialized()
2839
+
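A short usage sketch of the new check; since it is defined in ``warp.context``, it is called through that module here (whether it is also re-exported at the package level is not shown in this diff):

    import warp as wp

    wp.init()
    # stricter than wp.is_cuda_available(): verifies that a cuCtxGetCurrent call
    # succeeds, which matters when cuInit() may have been called before a fork
    if wp.context.is_cuda_driver_initialized():
        wp.synchronize()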
2840
+
2367
2841
  def get_devices() -> List[Device]:
2368
2842
  """Returns a list of devices supported in this environment."""
2369
2843
 
@@ -2590,63 +3064,53 @@ def zeros(
2590
3064
  A warp.array object representing the allocation
2591
3065
  """
2592
3066
 
2593
- # backwards compatibility for case where users did wp.zeros(n, dtype=..), or wp.zeros(n=length, dtype=..)
2594
- if isinstance(shape, int):
2595
- shape = (shape,)
2596
- elif "n" in kwargs:
2597
- shape = (kwargs["n"],)
3067
+ arr = empty(shape=shape, dtype=dtype, device=device, requires_grad=requires_grad, pinned=pinned, **kwargs)
2598
3068
 
2599
- # compute num els
2600
- num_elements = 1
2601
- for d in shape:
2602
- num_elements *= d
3069
+ # use the CUDA default stream for synchronous behaviour with other streams
3070
+ with warp.ScopedStream(arr.device.null_stream):
3071
+ arr.zero_()
2603
3072
 
2604
- num_bytes = num_elements * warp.types.type_size_in_bytes(dtype)
3073
+ return arr
2605
3074
 
2606
- device = get_device(device)
2607
3075
 
2608
- ptr = None
2609
- grad_ptr = None
3076
+ def zeros_like(
3077
+ src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
3078
+ ) -> warp.array:
3079
+ """Return a zero-initialized array with the same type and dimensions as another array
2610
3080
 
2611
- if num_bytes > 0:
2612
- if device.is_capturing:
2613
- raise RuntimeError(f"Cannot allocate memory while graph capture is active on device {device}.")
2614
-
2615
- ptr = device.allocator.alloc(num_bytes, pinned=pinned)
2616
- if ptr is None:
2617
- raise RuntimeError("Memory allocation failed on device: {} for {} bytes".format(device, num_bytes))
2618
-
2619
- # use the CUDA default stream for synchronous behaviour with other streams
2620
- with warp.ScopedStream(device.null_stream):
2621
- device.memset(ptr, 0, num_bytes)
2622
-
2623
- if requires_grad:
2624
- # allocate gradient array
2625
- grad_ptr = device.allocator.alloc(num_bytes, pinned=pinned)
2626
- if grad_ptr is None:
2627
- raise RuntimeError("Memory allocation failed on device: {} for {} bytes".format(device, num_bytes))
2628
- with warp.ScopedStream(device.null_stream):
2629
- device.memset(grad_ptr, 0, num_bytes)
2630
-
2631
- # construct array
2632
- return warp.types.array(
2633
- dtype=dtype,
2634
- shape=shape,
2635
- capacity=num_bytes,
2636
- ptr=ptr,
2637
- grad_ptr=grad_ptr,
2638
- device=device,
2639
- owner=True,
2640
- requires_grad=requires_grad,
2641
- pinned=pinned,
2642
- )
3081
+ Args:
3082
+ src: The template array to use for shape, data type, and device
3083
+ device: The device where the new array will be created (defaults to src.device)
3084
+ requires_grad: Whether the array will be tracked for back propagation
3085
+ pinned: Whether the array uses pinned host memory (only applicable to CPU arrays)
3086
+
3087
+ Returns:
3088
+ A warp.array object representing the allocation
3089
+ """
2643
3090
 
3091
+ arr = empty_like(src, device=device, requires_grad=requires_grad, pinned=pinned)
2644
3092
 
2645
- def zeros_like(src: warp.array, requires_grad: bool = None, pinned: bool = None) -> warp.array:
2646
- """Return a zero-initialized array with the same type and dimension of another array
3093
+ arr.zero_()
3094
+
3095
+ return arr
3096
+
3097
+
3098
+ def full(
3099
+ shape: Tuple = None,
3100
+ value=0,
3101
+ dtype=Any,
3102
+ device: Devicelike = None,
3103
+ requires_grad: bool = False,
3104
+ pinned: bool = False,
3105
+ **kwargs,
3106
+ ) -> warp.array:
3107
+ """Return an array with all elements initialized to the given value
2647
3108
 
2648
3109
  Args:
2649
- src: The template array to use for length, data type, and device
3110
+ shape: Array dimensions
3111
+ value: Element value
3112
+ dtype: Type of each element, e.g.: float, warp.vec3, warp.mat33, etc
3113
+ device: Device that array will live on
2650
3114
  requires_grad: Whether the array will be tracked for back propagation
2651
3115
  pinned: Whether the array uses pinned host memory (only applicable to CPU arrays)
2652
3116
 
@@ -2654,24 +3118,78 @@ def zeros_like(src: warp.array, requires_grad: bool = None, pinned: bool = None)
2654
3118
  A warp.array object representing the allocation
2655
3119
  """
2656
3120
 
2657
- if requires_grad is None:
2658
- if hasattr(src, "requires_grad"):
2659
- requires_grad = src.requires_grad
3121
+ if dtype == Any:
3122
+ # determine dtype from value
3123
+ value_type = type(value)
3124
+ if value_type == int:
3125
+ dtype = warp.int32
3126
+ elif value_type == float:
3127
+ dtype = warp.float32
3128
+ elif value_type in warp.types.scalar_types or hasattr(value_type, "_wp_scalar_type_"):
3129
+ dtype = value_type
3130
+ elif isinstance(value, warp.codegen.StructInstance):
3131
+ dtype = value._cls
3132
+ elif hasattr(value, "__len__"):
3133
+ # a sequence, assume it's a vector or matrix value
3134
+ try:
3135
+ # try to convert to a numpy array first
3136
+ na = np.array(value, copy=False)
3137
+ except Exception as e:
3138
+ raise ValueError(f"Failed to interpret the value as a vector or matrix: {e}")
3139
+
3140
+ # determine the scalar type
3141
+ scalar_type = warp.types.np_dtype_to_warp_type.get(na.dtype)
3142
+ if scalar_type is None:
3143
+ raise ValueError(f"Failed to convert {na.dtype} to a Warp data type")
3144
+
3145
+ # determine if vector or matrix
3146
+ if na.ndim == 1:
3147
+ dtype = warp.types.vector(na.size, scalar_type)
3148
+ elif na.ndim == 2:
3149
+ dtype = warp.types.matrix(na.shape, scalar_type)
3150
+ else:
3151
+ raise ValueError("Values with more than two dimensions are not supported")
2660
3152
  else:
2661
- requires_grad = False
3153
+ raise ValueError(f"Invalid value type for Warp array: {value_type}")
2662
3154
 
2663
- if pinned is None:
2664
- pinned = src.pinned
3155
+ arr = empty(shape=shape, dtype=dtype, device=device, requires_grad=requires_grad, pinned=pinned, **kwargs)
3156
+
3157
+ # use the CUDA default stream for synchronous behaviour with other streams
3158
+ with warp.ScopedStream(arr.device.null_stream):
3159
+ arr.fill_(value)
3160
+
3161
+ return arr
3162
+
3163
+
3164
+ def full_like(
3165
+ src: warp.array, value: Any, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
3166
+ ) -> warp.array:
3167
+ """Return an array with all elements initialized to the given value, with the same type and dimensions as another array
3168
+
3169
+ Args:
3170
+ src: The template array to use for shape, data type, and device
3171
+ value: Element value
3172
+ device: The device where the new array will be created (defaults to src.device)
3173
+ requires_grad: Whether the array will be tracked for back propagation
3174
+ pinned: Whether the array uses pinned host memory (only applicable to CPU arrays)
3175
+
3176
+ Returns:
3177
+ A warp.array object representing the allocation
3178
+ """
3179
+
3180
+ arr = empty_like(src, device=device, requires_grad=requires_grad, pinned=pinned)
3181
+
3182
+ arr.fill_(value)
2665
3183
 
2666
- arr = zeros(shape=src.shape, dtype=src.dtype, device=src.device, requires_grad=requires_grad, pinned=pinned)
2667
3184
  return arr
2668
3185
 
2669
3186
 
2670
- def clone(src: warp.array, requires_grad: bool = None, pinned: bool = None) -> warp.array:
3187
+ def clone(src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None) -> warp.array:
2671
3188
  """Clone an existing array, allocates a copy of the src memory
2672
3189
 
2673
3190
  Args:
2674
3191
  src: The source array to copy
3192
+ device: The device where the new array will be created (defaults to src.device)
2675
3193
  requires_grad: Whether the array will be tracked for back propagation
2676
3194
  pinned: Whether the array uses pinned host memory (only applicable to CPU arrays)
2677
3195
 
@@ -2679,19 +3197,11 @@ def clone(src: warp.array, requires_grad: bool = None, pinned: bool = None) -> w
2679
3197
  A warp.array object representing the allocation
2680
3198
  """
2681
3199
 
2682
- if requires_grad is None:
2683
- if hasattr(src, "requires_grad"):
2684
- requires_grad = src.requires_grad
2685
- else:
2686
- requires_grad = False
2687
-
2688
- if pinned is None:
2689
- pinned = src.pinned
3200
+ arr = empty_like(src, device=device, requires_grad=requires_grad, pinned=pinned)
2690
3201
 
2691
- dest = empty(shape=src.shape, dtype=src.dtype, device=src.device, requires_grad=requires_grad, pinned=pinned)
2692
- copy(dest, src)
3202
+ warp.copy(arr, src)
2693
3203
 
2694
- return dest
3204
+ return arr
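``clone()``, ``empty_like()``, ``zeros_like()``, and ``full_like()`` now accept a ``device`` argument, so the copy can be materialized on a different device than the source. A short sketch (device names are illustrative and assume a CUDA device is present):

    import warp as wp

    wp.init()
    a = wp.array([1.0, 2.0, 3.0], dtype=float, device="cpu")
    b = wp.clone(a, device="cuda:0")       # copy of `a` allocated on the GPU
    c = wp.empty_like(a, device="cuda:0")  # uninitialized, same shape and dtype
    d = wp.zeros_like(a, device="cuda:0")  # zero-initialized, same shape and dtype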
2695
3205
 
2696
3206
 
2697
3207
  def empty(
@@ -2705,7 +3215,7 @@ def empty(
2705
3215
  """Returns an uninitialized array
2706
3216
 
2707
3217
  Args:
2708
- n: Number of elements
3218
+ shape: Array dimensions
2709
3219
  dtype: Type of each element, e.g.: `warp.vec3`, `warp.mat33`, etc
2710
3220
  device: Device that array will live on
2711
3221
  requires_grad: Whether the array will be tracked for back propagation
@@ -2715,15 +3225,26 @@ def empty(
2715
3225
  A warp.array object representing the allocation
2716
3226
  """
2717
3227
 
2718
- # todo: implement uninitialized allocation
2719
- return zeros(shape, dtype, device, requires_grad=requires_grad, pinned=pinned, **kwargs)
3228
+ # backwards compatibility for case where users called wp.empty(n=length, ...)
3229
+ if "n" in kwargs:
3230
+ shape = (kwargs["n"],)
3231
+ del kwargs["n"]
3232
+
3233
+ # ensure shape is specified, even if creating a zero-sized array
3234
+ if shape is None:
3235
+ shape = 0
3236
+
3237
+ return warp.array(shape=shape, dtype=dtype, device=device, requires_grad=requires_grad, pinned=pinned, **kwargs)
2720
3238
 
2721
3239
 
2722
- def empty_like(src: warp.array, requires_grad: bool = None, pinned: bool = None) -> warp.array:
3240
+ def empty_like(
3241
+ src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
3242
+ ) -> warp.array:
2723
3243
  """Return an uninitialized array with the same type and dimension of another array
2724
3244
 
2725
3245
  Args:
2726
- src: The template array to use for length, data type, and device
3246
+ src: The template array to use for shape, data type, and device
3247
+ device: The device where the new array will be created (defaults to src.device)
2727
3248
  requires_grad: Whether the array will be tracked for back propagation
2728
3249
  pinned: Whether the array uses pinned host memory (only applicable to CPU arrays)
2729
3250
 
@@ -2731,6 +3252,9 @@ def empty_like(src: warp.array, requires_grad: bool = None, pinned: bool = None)
2731
3252
  A warp.array object representing the allocation
2732
3253
  """
2733
3254
 
3255
+ if device is None:
3256
+ device = src.device
3257
+
2734
3258
  if requires_grad is None:
2735
3259
  if hasattr(src, "requires_grad"):
2736
3260
  requires_grad = src.requires_grad
@@ -2738,14 +3262,246 @@ def empty_like(src: warp.array, requires_grad: bool = None, pinned: bool = None)
2738
3262
  requires_grad = False
2739
3263
 
2740
3264
  if pinned is None:
2741
- pinned = src.pinned
3265
+ if hasattr(src, "pinned"):
3266
+ pinned = src.pinned
3267
+ else:
3268
+ pinned = False
2742
3269
 
2743
- arr = empty(shape=src.shape, dtype=src.dtype, device=src.device, requires_grad=requires_grad, pinned=pinned)
3270
+ arr = empty(shape=src.shape, dtype=src.dtype, device=device, requires_grad=requires_grad, pinned=pinned)
2744
3271
  return arr
2745
3272
 
2746
3273
 
2747
- def from_numpy(arr, dtype, device: Devicelike = None, requires_grad=False):
2748
- return warp.array(data=arr, dtype=dtype, device=device, requires_grad=requires_grad)
3274
+ def from_numpy(
3275
+ arr: np.ndarray,
3276
+ dtype: Optional[type] = None,
3277
+ shape: Optional[Sequence[int]] = None,
3278
+ device: Optional[Devicelike] = None,
3279
+ requires_grad: bool = False,
3280
+ ) -> warp.array:
3281
+ if dtype is None:
3282
+ base_type = warp.types.np_dtype_to_warp_type.get(arr.dtype)
3283
+ if base_type is None:
3284
+ raise RuntimeError("Unsupported NumPy data type '{}'.".format(arr.dtype))
3285
+
3286
+ dim_count = len(arr.shape)
3287
+ if dim_count == 2:
3288
+ dtype = warp.types.vector(length=arr.shape[1], dtype=base_type)
3289
+ elif dim_count == 3:
3290
+ dtype = warp.types.matrix(shape=(arr.shape[1], arr.shape[2]), dtype=base_type)
3291
+ else:
3292
+ dtype = base_type
3293
+
3294
+ return warp.array(
3295
+ data=arr,
3296
+ dtype=dtype,
3297
+ shape=shape,
3298
+ owner=False,
3299
+ device=device,
3300
+ requires_grad=requires_grad,
3301
+ )
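``from_numpy()`` now infers a Warp dtype from the NumPy array when none is given: 2-D input maps to a vector dtype and 3-D input to a matrix dtype, with the scalar type taken from the NumPy dtype. A short sketch, assuming the function is re-exported as ``wp.from_numpy`` (otherwise it is reachable as ``warp.context.from_numpy``):

    import numpy as np
    import warp as wp

    wp.init()
    pts = np.zeros((100, 3), dtype=np.float32)
    v = wp.from_numpy(pts)                                     # vector(length=3, dtype=float32)
    m = wp.from_numpy(np.zeros((10, 3, 3), dtype=np.float32))  # matrix(shape=(3, 3), dtype=float32)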
3302
+
3303
+
3304
+ # given a kernel destination argument type and a value, convert
3305
+ # to a c-type that can be passed to a kernel
3306
+ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
3307
+ if warp.types.is_array(arg_type):
3308
+ if value is None:
3309
+ # allow for NULL arrays
3310
+ return arg_type.__ctype__()
3311
+
3312
+ else:
3313
+ # check for array type
3314
+ # - in forward passes, array types have to match
3315
+ # - in backward passes, indexed array gradients are regular arrays
3316
+ if adjoint:
3317
+ array_matches = isinstance(value, warp.array)
3318
+ else:
3319
+ array_matches = type(value) is type(arg_type)
3320
+
3321
+ if not array_matches:
3322
+ adj = "adjoint " if adjoint else ""
3323
+ raise RuntimeError(
3324
+ f"Error launching kernel '{kernel.key}', {adj}argument '{arg_name}' expects an array of type {type(arg_type)}, but passed value has type {type(value)}."
3325
+ )
3326
+
3327
+ # check subtype
3328
+ if not warp.types.types_equal(value.dtype, arg_type.dtype):
3329
+ adj = "adjoint " if adjoint else ""
3330
+ raise RuntimeError(
3331
+ f"Error launching kernel '{kernel.key}', {adj}argument '{arg_name}' expects an array with dtype={arg_type.dtype} but passed array has dtype={value.dtype}."
3332
+ )
3333
+
3334
+ # check dimensions
3335
+ if value.ndim != arg_type.ndim:
3336
+ adj = "adjoint " if adjoint else ""
3337
+ raise RuntimeError(
3338
+ f"Error launching kernel '{kernel.key}', {adj}argument '{arg_name}' expects an array with {arg_type.ndim} dimension(s) but the passed array has {value.ndim} dimension(s)."
3339
+ )
3340
+
3341
+ # check device
3342
+ # if a.device != device and not device.can_access(a.device):
3343
+ if value.device != device:
3344
+ raise RuntimeError(
3345
+ f"Error launching kernel '{kernel.key}', trying to launch on device='{device}', but input array for argument '{arg_name}' is on device={value.device}."
3346
+ )
3347
+
3348
+ return value.__ctype__()
3349
+
3350
+ elif isinstance(arg_type, warp.codegen.Struct):
3351
+ assert value is not None
3352
+ return value.__ctype__()
3353
+
3354
+ # try to convert to a value type (vec3, mat33, etc)
3355
+ elif issubclass(arg_type, ctypes.Array):
3356
+ if warp.types.types_equal(type(value), arg_type):
3357
+ return value
3358
+ else:
3359
+ # try constructing the required value from the argument (handles tuple / list, Gf.Vec3 case)
3360
+ try:
3361
+ return arg_type(value)
3362
+ except Exception:
3363
+ raise ValueError(f"Failed to convert argument for param {arg_name} to {type_str(arg_type)}")
3364
+
3365
+ elif isinstance(value, bool):
3366
+ return ctypes.c_bool(value)
3367
+
3368
+ elif isinstance(value, arg_type):
3369
+ try:
3370
+ # try to pack as a scalar type
3371
+ if arg_type is warp.types.float16:
3372
+ return arg_type._type_(warp.types.float_to_half_bits(value.value))
3373
+ else:
3374
+ return arg_type._type_(value.value)
3375
+ except Exception:
3376
+ raise RuntimeError(
3377
+ "Error launching kernel, unable to pack kernel parameter type "
3378
+ f"{type(value)} for param {arg_name}, expected {arg_type}"
3379
+ )
3380
+
3381
+ else:
3382
+ try:
3383
+ # try to pack as a scalar type
3384
+ if arg_type is warp.types.float16:
3385
+ return arg_type._type_(warp.types.float_to_half_bits(value))
3386
+ else:
3387
+ return arg_type._type_(value)
3388
+ except Exception as e:
3389
+ print(e)
3390
+ raise RuntimeError(
3391
+ "Error launching kernel, unable to pack kernel parameter type "
3392
+ f"{type(value)} for param {arg_name}, expected {arg_type}"
3393
+ )
3394
+
3395
+
3396
+ # represents all data required for a kernel launch
3397
+ # so that launches can be replayed quickly; to create one, use `wp.launch(..., record_cmd=True)`
3398
+ class Launch:
3399
+ def __init__(self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0):
3400
+ # if not specified look up hooks
3401
+ if not hooks:
3402
+ module = kernel.module
3403
+ if not module.load(device):
3404
+ return
3405
+
3406
+ hooks = module.get_kernel_hooks(kernel, device)
3407
+
3408
+ # if not specified set a zero bound
3409
+ if not bounds:
3410
+ bounds = warp.types.launch_bounds_t(0)
3411
+
3412
+ # if not specified then build a list of default value params for args
3413
+ if not params:
3414
+ params = []
3415
+ params.append(bounds)
3416
+
3417
+ for a in kernel.adj.args:
3418
+ if isinstance(a.type, warp.types.array):
3419
+ params.append(a.type.__ctype__())
3420
+ elif isinstance(a.type, warp.codegen.Struct):
3421
+ params.append(a.type().__ctype__())
3422
+ else:
3423
+ params.append(pack_arg(kernel, a.type, a.label, 0, device, False))
3424
+
3425
+ kernel_args = [ctypes.c_void_p(ctypes.addressof(x)) for x in params]
3426
+ kernel_params = (ctypes.c_void_p * len(kernel_args))(*kernel_args)
3427
+
3428
+ params_addr = kernel_params
3429
+
3430
+ self.kernel = kernel
3431
+ self.hooks = hooks
3432
+ self.params = params
3433
+ self.params_addr = params_addr
3434
+ self.device = device
3435
+ self.bounds = bounds
3436
+ self.max_blocks = max_blocks
3437
+
3438
+ def set_dim(self, dim):
3439
+ self.bounds = warp.types.launch_bounds_t(dim)
3440
+
3441
+ # launch bounds always at index 0
3442
+ self.params[0] = self.bounds
3443
+
3444
+ # for CUDA kernels we need to update the address to each arg
3445
+ if self.params_addr:
3446
+ self.params_addr[0] = ctypes.c_void_p(ctypes.addressof(self.bounds))
3447
+
3448
+ # set kernel param at an index, will convert to ctype as necessary
3449
+ def set_param_at_index(self, index, value):
3450
+ arg_type = self.kernel.adj.args[index].type
3451
+ arg_name = self.kernel.adj.args[index].label
3452
+
3453
+ carg = pack_arg(self.kernel, arg_type, arg_name, value, self.device, False)
3454
+
3455
+ self.params[index + 1] = carg
3456
+
3457
+ # for CUDA kernels we need to update the address to each arg
3458
+ if self.params_addr:
3459
+ self.params_addr[index + 1] = ctypes.c_void_p(ctypes.addressof(carg))
3460
+
3461
+ # set kernel param at an index without any type conversion
3462
+ # args must be passed as ctypes or basic int / float types
3463
+ def set_param_at_index_from_ctype(self, index, value):
3464
+ if isinstance(value, ctypes.Structure):
3465
+ # not sure how to directly assign struct->struct without reallocating using ctypes
3466
+ self.params[index + 1] = value
3467
+
3468
+ # for CUDA kernels we need to update the address to each arg
3469
+ if self.params_addr:
3470
+ self.params_addr[index + 1] = ctypes.c_void_p(ctypes.addressof(value))
3471
+
3472
+ else:
3473
+ self.params[index + 1].__init__(value)
3474
+
3475
+ # set kernel param by argument name
3476
+ def set_param_by_name(self, name, value):
3477
+ for i, arg in enumerate(self.kernel.adj.args):
3478
+ if arg.label == name:
3479
+ self.set_param_at_index(i, value)
3480
+
3481
+ # set kernel param by argument name with no type conversions
3482
+ def set_param_by_name_from_ctype(self, name, value):
3483
+ # lookup argument index
3484
+ for i, arg in enumerate(self.kernel.adj.args):
3485
+ if arg.label == name:
3486
+ self.set_param_at_index_from_ctype(i, value)
3487
+
3488
+ # set all params
3489
+ def set_params(self, values):
3490
+ for i, v in enumerate(values):
3491
+ self.set_param_at_index(i, v)
3492
+
3493
+ # set all params without performing type-conversions
3494
+ def set_params_from_ctypes(self, values):
3495
+ for i, v in enumerate(values):
3496
+ self.set_param_at_index_from_ctype(i, v)
3497
+
3498
+ def launch(self) -> Any:
3499
+ if self.device.is_cpu:
3500
+ self.hooks.forward(*self.params)
3501
+ else:
3502
+ runtime.core.cuda_launch_kernel(
3503
+ self.device.context, self.hooks.forward, self.bounds.size, self.max_blocks, self.params_addr
3504
+ )
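A usage sketch of the new ``record_cmd=True`` path, which returns a ``Launch`` object that can be re-issued and re-parameterized without repeating argument packing (the ``scale`` kernel is a made-up example, not part of Warp):

    import warp as wp

    wp.init()

    @wp.kernel
    def scale(a: wp.array(dtype=float), s: float):
        i = wp.tid()
        a[i] = a[i] * s

    a = wp.zeros(1024, dtype=float)
    cmd = wp.launch(scale, dim=1024, inputs=[a, 2.0], record_cmd=True)
    cmd.launch()                     # replay the recorded launch
    cmd.set_param_by_name("s", 3.0)  # update a single argument between replays
    cmd.launch()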
2749
3505
 
2750
3506
 
2751
3507
  def launch(
@@ -2759,6 +3515,8 @@ def launch(
2759
3515
  stream: Stream = None,
2760
3516
  adjoint=False,
2761
3517
  record_tape=True,
3518
+ record_cmd=False,
3519
+ max_blocks=0,
2762
3520
  ):
2763
3521
  """Launch a Warp kernel on the target device
2764
3522
 
@@ -2774,6 +3532,10 @@ def launch(
2774
3532
  device: The device to launch on (optional)
2775
3533
  stream: The stream to launch on (optional)
2776
3534
  adjoint: Whether to run forward or backward pass (typically use False)
3535
+ record_tape: When True, the launch will be recorded onto the global wp.Tape() object when present
3536
+ record_cmd: When True, the launch is returned as a ``Launch`` command object; the launch does not occur until the user calls ``cmd.launch()``
3537
+ max_blocks: The maximum number of CUDA thread blocks to use. Only has an effect for CUDA kernel launches.
3538
+ If negative or zero, the maximum hardware value will be used.
2777
3539
  """
2778
3540
 
2779
3541
  assert_initialized()
@@ -2785,7 +3547,7 @@ def launch(
2785
3547
  device = runtime.get_device(device)
2786
3548
 
2787
3549
  # check function is a Kernel
2788
- if isinstance(kernel, Kernel) == False:
3550
+ if not isinstance(kernel, Kernel):
2789
3551
  raise RuntimeError("Error launching kernel, can only launch functions decorated with @wp.kernel.")
2790
3552
 
2791
3553
  # debugging aid
@@ -2806,85 +3568,7 @@ def launch(
2806
3568
  arg_type = kernel.adj.args[i].type
2807
3569
  arg_name = kernel.adj.args[i].label
2808
3570
 
2809
- if warp.types.is_array(arg_type):
2810
- if a is None:
2811
- # allow for NULL arrays
2812
- params.append(arg_type.__ctype__())
2813
-
2814
- else:
2815
- # check for array type
2816
- # - in forward passes, array types have to match
2817
- # - in backward passes, indexed array gradients are regular arrays
2818
- if adjoint:
2819
- array_matches = type(a) == warp.array
2820
- else:
2821
- array_matches = type(a) == type(arg_type)
2822
-
2823
- if not array_matches:
2824
- adj = "adjoint " if adjoint else ""
2825
- raise RuntimeError(
2826
- f"Error launching kernel '{kernel.key}', {adj}argument '{arg_name}' expects an array of type {type(arg_type)}, but passed value has type {type(a)}."
2827
- )
2828
-
2829
- # check subtype
2830
- if not warp.types.types_equal(a.dtype, arg_type.dtype):
2831
- adj = "adjoint " if adjoint else ""
2832
- raise RuntimeError(
2833
- f"Error launching kernel '{kernel.key}', {adj}argument '{arg_name}' expects an array with dtype={arg_type.dtype} but passed array has dtype={a.dtype}."
2834
- )
2835
-
2836
- # check dimensions
2837
- if a.ndim != arg_type.ndim:
2838
- adj = "adjoint " if adjoint else ""
2839
- raise RuntimeError(
2840
- f"Error launching kernel '{kernel.key}', {adj}argument '{arg_name}' expects an array with {arg_type.ndim} dimension(s) but the passed array has {a.ndim} dimension(s)."
2841
- )
2842
-
2843
- # check device
2844
- # if a.device != device and not device.can_access(a.device):
2845
- if a.device != device:
2846
- raise RuntimeError(
2847
- f"Error launching kernel '{kernel.key}', trying to launch on device='{device}', but input array for argument '{arg_name}' is on device={a.device}."
2848
- )
2849
-
2850
- params.append(a.__ctype__())
2851
-
2852
- elif isinstance(arg_type, warp.codegen.Struct):
2853
- assert a is not None
2854
- params.append(a.__ctype__())
2855
-
2856
- # try to convert to a value type (vec3, mat33, etc)
2857
- elif issubclass(arg_type, ctypes.Array):
2858
- if warp.types.types_equal(type(a), arg_type):
2859
- params.append(a)
2860
- else:
2861
- # try constructing the required value from the argument (handles tuple / list, Gf.Vec3 case)
2862
- try:
2863
- params.append(arg_type(a))
2864
- except:
2865
- raise ValueError(f"Failed to convert argument for param {arg_name} to {type_str(arg_type)}")
2866
-
2867
- elif isinstance(a, bool):
2868
- params.append(ctypes.c_bool(a))
2869
-
2870
- elif isinstance(a, arg_type):
2871
- try:
2872
- # try to pack as a scalar type
2873
- params.append(arg_type._type_(a.value))
2874
- except:
2875
- raise RuntimeError(
2876
- f"Error launching kernel, unable to pack kernel parameter type {type(a)} for param {arg_name}, expected {arg_type}"
2877
- )
2878
-
2879
- else:
2880
- try:
2881
- # try to pack as a scalar type
2882
- params.append(arg_type._type_(a))
2883
- except Exception as e:
2884
- print(e)
2885
- raise RuntimeError(
2886
- f"Error launching kernel, unable to pack kernel parameter type {type(a)} for param {arg_name}, expected {arg_type}"
2887
- )
3571
+ params.append(pack_arg(kernel, arg_type, arg_name, a, device, adjoint))
2888
3572
 
2889
3573
  fwd_args = inputs + outputs
2890
3574
  adj_args = adj_inputs + adj_outputs
@@ -2926,7 +3610,13 @@ def launch(
2926
3610
  f"Failed to find forward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
2927
3611
  )
2928
3612
 
2929
- hooks.forward(*params)
3613
+ if record_cmd:
3614
+ launch = Launch(
3615
+ kernel=kernel, hooks=hooks, params=params, params_addr=None, bounds=bounds, device=device
3616
+ )
3617
+ return launch
3618
+ else:
3619
+ hooks.forward(*params)
2930
3620
 
2931
3621
  else:
2932
3622
  kernel_args = [ctypes.c_void_p(ctypes.addressof(x)) for x in params]
@@ -2939,7 +3629,9 @@ def launch(
2939
3629
  f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
2940
3630
  )
2941
3631
 
2942
- runtime.core.cuda_launch_kernel(device.context, hooks.backward, bounds.size, kernel_params)
3632
+ runtime.core.cuda_launch_kernel(
3633
+ device.context, hooks.backward, bounds.size, max_blocks, kernel_params
3634
+ )
2943
3635
 
2944
3636
  else:
2945
3637
  if hooks.forward is None:
@@ -2947,7 +3639,22 @@ def launch(
2947
3639
  f"Failed to find forward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
2948
3640
  )
2949
3641
 
2950
- runtime.core.cuda_launch_kernel(device.context, hooks.forward, bounds.size, kernel_params)
3642
+ if record_cmd:
3643
+ launch = Launch(
3644
+ kernel=kernel,
3645
+ hooks=hooks,
3646
+ params=params,
3647
+ params_addr=kernel_params,
3648
+ bounds=bounds,
3649
+ device=device,
3650
+ )
3651
+ return launch
3652
+
3653
+ else:
3654
+ # launch
3655
+ runtime.core.cuda_launch_kernel(
3656
+ device.context, hooks.forward, bounds.size, max_blocks, kernel_params
3657
+ )
2951
3658
 
2952
3659
  try:
2953
3660
  runtime.verify_cuda_device(device)
@@ -2957,7 +3664,7 @@ def launch(
2957
3664
 
2958
3665
  # record on tape if one is active
2959
3666
  if runtime.tape and record_tape:
2960
- runtime.tape.record_launch(kernel, dim, inputs, outputs, device)
3667
+ runtime.tape.record_launch(kernel, dim, max_blocks, inputs, outputs, device)
2961
3668
 
2962
3669
 
2963
3670
  def synchronize():
@@ -2967,7 +3674,7 @@ def synchronize():
2967
3674
  or memory copies have completed.
2968
3675
  """
2969
3676
 
2970
- if is_cuda_available():
3677
+ if is_cuda_driver_initialized():
2971
3678
  # save the original context to avoid side effects
2972
3679
  saved_context = runtime.core.cuda_context_get_current()
2973
3680
 
@@ -3017,7 +3724,7 @@ def synchronize_stream(stream_or_device=None):
3017
3724
  runtime.core.cuda_stream_synchronize(stream.device.context, stream.cuda_stream)
3018
3725
 
3019
3726
 
3020
- def force_load(device: Union[Device, str] = None, modules: List[Module] = None):
3727
+ def force_load(device: Union[Device, str, List[Device], List[str]] = None, modules: List[Module] = None):
3021
3728
  """Force user-defined kernels to be compiled and loaded
3022
3729
 
3023
3730
  Args:
@@ -3025,12 +3732,14 @@ def force_load(device: Union[Device, str] = None, modules: List[Module] = None):
3025
3732
  modules: List of modules to load. If None, load all imported modules.
3026
3733
  """
3027
3734
 
3028
- if is_cuda_available():
3735
+ if is_cuda_driver_initialized():
3029
3736
  # save original context to avoid side effects
3030
3737
  saved_context = runtime.core.cuda_context_get_current()
3031
3738
 
3032
3739
  if device is None:
3033
3740
  devices = get_devices()
3741
+ elif isinstance(device, list):
3742
+ devices = [get_device(device_item) for device_item in device]
3034
3743
  else:
3035
3744
  devices = [get_device(device)]
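``force_load()`` now also accepts a list of devices. A short usage sketch (assumes a CUDA device is present):

    import warp as wp

    wp.init()
    # compile and load all imported modules for both the CPU and the first CUDA device
    wp.force_load(["cpu", "cuda:0"])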
3036
3745
 
@@ -3122,7 +3831,7 @@ def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
3122
3831
  return get_module(m.__name__).options
3123
3832
 
3124
3833
 
3125
- def capture_begin(device: Devicelike = None, stream=None, force_module_load=True):
3834
+ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None):
3126
3835
  """Begin capture of a CUDA graph
3127
3836
 
3128
3837
  Captures all subsequent kernel launches and memory operations on CUDA devices.
@@ -3136,7 +3845,10 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=True
3136
3845
 
3137
3846
  """
3138
3847
 
3139
- if warp.config.verify_cuda == True:
3848
+ if force_module_load is None:
3849
+ force_module_load = warp.config.graph_capture_module_load_default
3850
+
3851
+ if warp.config.verify_cuda:
3140
3852
  raise RuntimeError("Cannot use CUDA error verification during graph capture")
3141
3853
 
3142
3854
  if stream is not None:
@@ -3151,6 +3863,9 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=True
3151
3863
 
3152
3864
  device.is_capturing = True
3153
3865
 
3866
+ # disable garbage collection to avoid older allocations getting collected during graph capture
3867
+ gc.disable()
3868
+
3154
3869
  with warp.ScopedStream(stream):
3155
3870
  runtime.core.cuda_graph_begin_capture(device.context)
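A usage sketch of graph capture with these changes: modules are loaded according to ``warp.config.graph_capture_module_load_default`` unless ``force_module_load`` is passed, garbage collection is paused while capturing, and allocations are disallowed on a capturing device, so arrays should be created up front (requires a CUDA device; the ``inc`` kernel is a made-up example):

    import warp as wp

    wp.init()

    @wp.kernel
    def inc(a: wp.array(dtype=float)):
        i = wp.tid()
        a[i] = a[i] + 1.0

    # allocate before capture: alloc/free on a capturing device raises RuntimeError
    a = wp.zeros(1024, dtype=float, device="cuda:0")

    wp.capture_begin(device="cuda:0")
    wp.launch(inc, dim=1024, inputs=[a], device="cuda:0")
    graph = wp.capture_end(device="cuda:0")
    wp.capture_launch(graph)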
3156
3871
 
@@ -3174,6 +3889,9 @@ def capture_end(device: Devicelike = None, stream=None) -> Graph:
3174
3889
 
3175
3890
  device.is_capturing = False
3176
3891
 
3892
+ # re-enable GC
3893
+ gc.enable()
3894
+
3177
3895
  if graph is None:
3178
3896
  raise RuntimeError(
3179
3897
  "Error occurred during CUDA graph capture. This could be due to an unintended allocation or CPU/GPU synchronization event."
@@ -3226,7 +3944,14 @@ def copy(
3226
3944
  if count == 0:
3227
3945
  return
3228
3946
 
3229
- has_grad = hasattr(src, "grad_ptr") and hasattr(dest, "grad_ptr") and src.grad_ptr and dest.grad_ptr
3947
+ # copying non-contiguous arrays requires that they are on the same device
3948
+ if not (src.is_contiguous and dest.is_contiguous) and src.device != dest.device:
3949
+ if dest.is_contiguous:
3950
+ # make a contiguous copy of the source array
3951
+ src = src.contiguous()
3952
+ else:
3953
+ # make a copy of the source array on the destination device
3954
+ src = src.to(dest.device)
3230
3955
 
3231
3956
  if src.is_contiguous and dest.is_contiguous:
3232
3957
  bytes_to_copy = count * warp.types.type_size_in_bytes(src.dtype)
@@ -3240,10 +3965,6 @@ def copy(
3240
3965
  src_ptr = src.ptr + src_offset_in_bytes
3241
3966
  dst_ptr = dest.ptr + dst_offset_in_bytes
3242
3967
 
3243
- if has_grad:
3244
- src_grad_ptr = src.grad_ptr + src_offset_in_bytes
3245
- dst_grad_ptr = dest.grad_ptr + dst_offset_in_bytes
3246
-
3247
3968
  if src_offset_in_bytes + bytes_to_copy > src_size_in_bytes:
3248
3969
  raise RuntimeError(
3249
3970
  f"Trying to copy source buffer with size ({bytes_to_copy}) from offset ({src_offset_in_bytes}) is larger than source size ({src_size_in_bytes})"
@@ -3256,8 +3977,6 @@ def copy(

         if src.device.is_cpu and dest.device.is_cpu:
             runtime.core.memcpy_h2h(dst_ptr, src_ptr, bytes_to_copy)
-            if has_grad:
-                runtime.core.memcpy_h2h(dst_grad_ptr, src_grad_ptr, bytes_to_copy)
         else:
             # figure out the CUDA context/stream for the copy
             if stream is not None:
@@ -3270,32 +3989,19 @@ def copy(
             with warp.ScopedStream(stream):
                 if src.device.is_cpu and dest.device.is_cuda:
                     runtime.core.memcpy_h2d(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
-                    if has_grad:
-                        runtime.core.memcpy_h2d(copy_device.context, dst_grad_ptr, src_grad_ptr, bytes_to_copy)
                 elif src.device.is_cuda and dest.device.is_cpu:
                     runtime.core.memcpy_d2h(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
-                    if has_grad:
-                        runtime.core.memcpy_d2h(copy_device.context, dst_grad_ptr, src_grad_ptr, bytes_to_copy)
                 elif src.device.is_cuda and dest.device.is_cuda:
                     if src.device == dest.device:
                         runtime.core.memcpy_d2d(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
-                        if has_grad:
-                            runtime.core.memcpy_d2d(copy_device.context, dst_grad_ptr, src_grad_ptr, bytes_to_copy)
                     else:
                         runtime.core.memcpy_peer(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
-                        if has_grad:
-                            runtime.core.memcpy_peer(copy_device.context, dst_grad_ptr, src_grad_ptr, bytes_to_copy)
                 else:
                     raise RuntimeError("Unexpected source and destination combination")

     else:
         # handle non-contiguous and indexed arrays

-        if src.device != dest.device:
-            raise RuntimeError(
-                f"Copies between non-contiguous arrays must be on the same device, got {dest.device} and {src.device}"
-            )
-
         if src.shape != dest.shape:
             raise RuntimeError("Incompatible array shapes")

@@ -3305,18 +4011,22 @@ def copy(
         if src_elem_size != dst_elem_size:
             raise RuntimeError("Incompatible array data types")

-        def array_type(a):
-            if isinstance(a, warp.types.array):
-                return warp.types.ARRAY_TYPE_REGULAR
-            elif isinstance(a, warp.types.indexedarray):
-                return warp.types.ARRAY_TYPE_INDEXED
+        # can't copy to/from fabric arrays of arrays, because they are jagged arrays of arbitrary lengths
+        # TODO?
+        if (
+            isinstance(src, (warp.fabricarray, warp.indexedfabricarray))
+            and src.ndim > 1
+            or isinstance(dest, (warp.fabricarray, warp.indexedfabricarray))
+            and dest.ndim > 1
+        ):
+            raise RuntimeError("Copying to/from Fabric arrays of arrays is not supported")

         src_desc = src.__ctype__()
         dst_desc = dest.__ctype__()
         src_ptr = ctypes.pointer(src_desc)
         dst_ptr = ctypes.pointer(dst_desc)
-        src_type = array_type(src)
-        dst_type = array_type(dest)
+        src_type = warp.types.array_type_id(src)
+        dst_type = warp.types.array_type_id(dest)

         if src.device.is_cuda:
             with warp.ScopedStream(stream):
@@ -3324,6 +4034,10 @@ def copy(
         else:
             runtime.core.array_copy_host(dst_ptr, src_ptr, dst_type, src_type, src_elem_size)

+    # copy gradient, if needed
+    if hasattr(src, "grad") and src.grad is not None and hasattr(dest, "grad") and dest.grad is not None:
+        copy(dest.grad, src.grad, stream=stream)
+

 def type_str(t):
     if t is None:
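Taken together, the copy() changes above drop the old grad_ptr plumbing in favour of a recursive gradient copy, stage non-contiguous sources through contiguous() or a device transfer instead of rejecting cross-device copies, and reject Fabric arrays of arrays. A short sketch of the user-visible effect (sizes and devices are illustrative):

import warp as wp

wp.init()

# arrays with gradients living on different devices
src = wp.zeros(1024, dtype=float, device="cuda:0", requires_grad=True)
dest = wp.zeros(1024, dtype=float, device="cpu", requires_grad=True)

# copies the data and, because both sides expose a non-None .grad array,
# recursively copies dest.grad <- src.grad as well
wp.copy(dest, src)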
@@ -3342,6 +4056,10 @@ def type_str(t):
         return f"Array[{type_str(t.dtype)}]"
     elif isinstance(t, warp.indexedarray):
         return f"IndexedArray[{type_str(t.dtype)}]"
+    elif isinstance(t, warp.fabricarray):
+        return f"FabricArray[{type_str(t.dtype)}]"
+    elif isinstance(t, warp.indexedfabricarray):
+        return f"IndexedFabricArray[{type_str(t.dtype)}]"
     elif hasattr(t, "_wp_generic_type_str_"):
         generic_type = t._wp_generic_type_str_

@@ -3368,7 +4086,7 @@ def type_str(t):
     return t.__name__


-def print_function(f, file, noentry=False):
+def print_function(f, file, noentry=False):  # pragma: no cover
     """Writes a function definition to a file for use in reST documentation

     Args:
@@ -3392,7 +4110,7 @@ def print_function(f, file, noentry=False):
         # todo: construct a default value for each of the functions args
         # so we can generate the return type for overloaded functions
         return_type = " -> " + type_str(f.value_func(None, None, None))
-    except:
+    except Exception:
         pass

     print(f".. function:: {f.key}({args}){return_type}", file=file)
@@ -3413,7 +4131,7 @@ def print_function(f, file, noentry=False):
     return True


-def print_builtins(file):
+def export_functions_rst(file):  # pragma: no cover
     header = (
         "..\n"
         "   Autogenerated File - Do not edit. Run build_docs.py to generate.\n"
@@ -3433,6 +4151,8 @@ def print_builtins(file):

     for t in warp.types.scalar_types:
         print(f".. class:: {t.__name__}", file=file)
+    # Manually add wp.bool since it's inconvenient to add to wp.types.scalar_types:
+    print(f".. class:: {warp.types.bool.__name__}", file=file)

     print("\n\nVector Types", file=file)
     print("------------", file=file)
@@ -3443,14 +4163,22 @@ def print_builtins(file):
     print("\nGeneric Types", file=file)
     print("-------------", file=file)

-    print(f".. class:: Int", file=file)
-    print(f".. class:: Float", file=file)
-    print(f".. class:: Scalar", file=file)
-    print(f".. class:: Vector", file=file)
-    print(f".. class:: Matrix", file=file)
-    print(f".. class:: Quaternion", file=file)
-    print(f".. class:: Transformation", file=file)
-    print(f".. class:: Array", file=file)
+    print(".. class:: Int", file=file)
+    print(".. class:: Float", file=file)
+    print(".. class:: Scalar", file=file)
+    print(".. class:: Vector", file=file)
+    print(".. class:: Matrix", file=file)
+    print(".. class:: Quaternion", file=file)
+    print(".. class:: Transformation", file=file)
+    print(".. class:: Array", file=file)
+
+    print("\nQuery Types", file=file)
+    print("-------------", file=file)
+    print(".. autoclass:: bvh_query_t", file=file)
+    print(".. autoclass:: hash_grid_query_t", file=file)
+    print(".. autoclass:: mesh_query_aabb_t", file=file)
+    print(".. autoclass:: mesh_query_point_t", file=file)
+    print(".. autoclass:: mesh_query_ray_t", file=file)

     # build dictionary of all functions by group
     groups = {}
@@ -3485,7 +4213,7 @@ def print_builtins(file):
     print(".. [1] Note: function gradients not implemented for backpropagation.", file=file)


-def export_stubs(file):
+def export_stubs(file):  # pragma: no cover
     """Generates stub file for auto-complete of builtin functions"""

     import textwrap
@@ -3517,6 +4245,8 @@ def export_stubs(file):
     print("Quaternion = Generic[Float]", file=file)
     print("Transformation = Generic[Float]", file=file)
     print("Array = Generic[DType]", file=file)
+    print("FabricArray = Generic[DType]", file=file)
+    print("IndexedFabricArray = Generic[DType]", file=file)

     # prepend __init__.py
     with open(os.path.join(os.path.dirname(file.name), "__init__.py")) as header_file:
@@ -3533,7 +4263,7 @@ def export_stubs(file):

             return_str = ""

-            if f.export == False or f.hidden == True:  # or f.generic:
+            if not f.export or f.hidden:  # or f.generic:
                 continue

             try:
@@ -3543,29 +4273,42 @@ def export_stubs(file):
                 if return_type:
                     return_str = " -> " + type_str(return_type)

-            except:
+            except Exception:
                 pass

             print("@over", file=file)
             print(f"def {f.key}({args}){return_str}:", file=file)
-            print(f'    """', file=file)
+            print('    """', file=file)
             print(textwrap.indent(text=f.doc, prefix="    "), file=file)
-            print(f'    """', file=file)
-            print(f"    ...\n\n", file=file)
+            print('    """', file=file)
+            print("    ...\n\n", file=file)


-def export_builtins(file):
-    def ctype_str(t):
+def export_builtins(file: io.TextIOBase):  # pragma: no cover
+    def ctype_arg_str(t):
         if isinstance(t, int):
             return "int"
         elif isinstance(t, float):
             return "float"
+        elif t in warp.types.vector_types:
+            return f"{t.__name__}&"
         else:
             return t.__name__

+    def ctype_ret_str(t):
+        if isinstance(t, int):
+            return "int"
+        elif isinstance(t, float):
+            return "float"
+        else:
+            return t.__name__
+
+    file.write("namespace wp {\n\n")
+    file.write('extern "C" {\n\n')
+
     for k, g in builtin_functions.items():
         for f in g.overloads:
-            if f.export == False or f.generic:
+            if not f.export or f.generic:
                 continue

             simple = True
@@ -3579,7 +4322,7 @@ def export_builtins(file):
             if not simple or f.variadic:
                 continue

-            args = ", ".join(f"{ctype_str(v)} {k}" for k, v in f.input_types.items())
+            args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in f.input_types.items())
             params = ", ".join(f.input_types.keys())

             return_type = ""
@@ -3587,25 +4330,25 @@ def export_builtins(file):
             try:
                 # todo: construct a default value for each of the functions args
                 # so we can generate the return type for overloaded functions
-                return_type = ctype_str(f.value_func(None, None, None))
-            except:
+                return_type = ctype_ret_str(f.value_func(None, None, None))
+            except Exception:
                 continue

             if return_type.startswith("Tuple"):
                 continue

             if args == "":
-                print(
-                    f"WP_API void {f.mangled_name}({return_type}* ret) {{ *ret = wp::{f.key}({params}); }}", file=file
-                )
+                file.write(f"WP_API void {f.mangled_name}({return_type}* ret) {{ *ret = wp::{f.key}({params}); }}\n")
             elif return_type == "None":
-                print(f"WP_API void {f.mangled_name}({args}) {{ wp::{f.key}({params}); }}", file=file)
+                file.write(f"WP_API void {f.mangled_name}({args}) {{ wp::{f.key}({params}); }}\n")
             else:
-                print(
-                    f"WP_API void {f.mangled_name}({args}, {return_type}* ret) {{ *ret = wp::{f.key}({params}); }}",
-                    file=file,
+                file.write(
+                    f"WP_API void {f.mangled_name}({args}, {return_type}* ret) {{ *ret = wp::{f.key}({params}); }}\n"
                 )

+    file.write('\n} // extern "C"\n\n')
+    file.write("} // namespace wp\n")
+

  # initialize global runtime
  runtime = None