warp-lang 1.6.1-py3-none-win_amd64.whl → 1.7.0-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (401)
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -1,15 +1,22 @@
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
- # NVIDIA CORPORATION and its licensors retain all intellectual property
- # and proprietary rights in and to this software, related documentation
- # and any modifications thereto. Any use, reproduction, disclosure or
- # distribution of this software and related documentation without an express
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.

  from __future__ import annotations

  import ast
  import ctypes
- import errno
  import functools
  import hashlib
  import inspect
@@ -20,13 +27,27 @@ import operator
  import os
  import platform
  import sys
- import time
  import types
  import typing
  import weakref
  from copy import copy as shallowcopy
  from pathlib import Path
- from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+ from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Literal,
+ Mapping,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ TypeVar,
+ Union,
+ get_args,
+ get_origin,
+ )

  import numpy as np

@@ -34,7 +55,7 @@ import warp
  import warp.build
  import warp.codegen
  import warp.config
- from warp.types import launch_bounds_t
+ from warp.types import Array, launch_bounds_t

  # represents either a built-in or user-defined function

@@ -63,10 +84,10 @@ def get_function_args(func):
  complex_type_hints = (Any, Callable, Tuple)
  sequence_types = (list, tuple)

- function_key_counts = {}
+ function_key_counts: Dict[str, int] = {}


- def generate_unique_function_identifier(key):
+ def generate_unique_function_identifier(key: str) -> str:
  # Generate unique identifiers for user-defined functions in native code.
  # - Prevents conflicts when a function is redefined and old versions are still in use.
  # - Prevents conflicts between multiple closures returned from the same function.
@@ -99,40 +120,40 @@ def generate_unique_function_identifier(key):
  class Function:
  def __init__(
  self,
- func,
- key,
- namespace,
- input_types=None,
- value_type=None,
- value_func=None,
- export_func=None,
- dispatch_func=None,
- lto_dispatch_func=None,
- module=None,
- variadic=False,
- initializer_list_func=None,
- export=False,
- doc="",
- group="",
- hidden=False,
- skip_replay=False,
- missing_grad=False,
- generic=False,
- native_func=None,
- defaults=None,
- custom_replay_func=None,
- native_snippet=None,
- adj_native_snippet=None,
- replay_snippet=None,
- skip_forward_codegen=False,
- skip_reverse_codegen=False,
- custom_reverse_num_input_args=-1,
- custom_reverse_mode=False,
- overloaded_annotations=None,
- code_transformers=None,
- skip_adding_overload=False,
- require_original_output_arg=False,
- scope_locals=None, # the locals() where the function is defined, used for overload management
+ func: Optional[Callable],
+ key: str,
+ namespace: str,
+ input_types: Optional[Dict[str, Union[type, TypeVar]]] = None,
+ value_type: Optional[type] = None,
+ value_func: Optional[Callable[[Mapping[str, type], Mapping[str, Any]], type]] = None,
+ export_func: Optional[Callable[[Dict[str, type]], Dict[str, type]]] = None,
+ dispatch_func: Optional[Callable] = None,
+ lto_dispatch_func: Optional[Callable] = None,
+ module: Optional[Module] = None,
+ variadic: bool = False,
+ initializer_list_func: Optional[Callable[[Dict[str, Any], type], bool]] = None,
+ export: bool = False,
+ doc: str = "",
+ group: str = "",
+ hidden: bool = False,
+ skip_replay: bool = False,
+ missing_grad: bool = False,
+ generic: bool = False,
+ native_func: Optional[str] = None,
+ defaults: Optional[Dict[str, Any]] = None,
+ custom_replay_func: Optional[Function] = None,
+ native_snippet: Optional[str] = None,
+ adj_native_snippet: Optional[str] = None,
+ replay_snippet: Optional[str] = None,
+ skip_forward_codegen: bool = False,
+ skip_reverse_codegen: bool = False,
+ custom_reverse_num_input_args: int = -1,
+ custom_reverse_mode: bool = False,
+ overloaded_annotations: Optional[Dict[str, type]] = None,
+ code_transformers: Optional[List[ast.NodeTransformer]] = None,
+ skip_adding_overload: bool = False,
+ require_original_output_arg: bool = False,
+ scope_locals: Optional[Dict[str, Any]] = None,
  ):
  if code_transformers is None:
  code_transformers = []
@@ -157,7 +178,7 @@ class Function:
  self.native_snippet = native_snippet
  self.adj_native_snippet = adj_native_snippet
  self.replay_snippet = replay_snippet
- self.custom_grad_func = None
+ self.custom_grad_func: Optional[Function] = None
  self.require_original_output_arg = require_original_output_arg
  self.generic_parent = None # generic function that was used to instantiate this overload

@@ -173,6 +194,7 @@
  )
  self.missing_grad = missing_grad # whether builtin is missing a corresponding adjoint
  self.generic = generic
+ self.mangled_name: Optional[str] = None

  # allow registering functions with a different name in Python and native code
  if native_func is None:
@@ -189,8 +211,8 @@ class Function:
  # user-defined function

  # generic and concrete overload lookups by type signature
- self.user_templates = {}
- self.user_overloads = {}
+ self.user_templates: Dict[str, Function] = {}
+ self.user_overloads: Dict[str, Function] = {}

  # user defined (Python) function
  self.adj = warp.codegen.Adjoint(
@@ -221,19 +243,17 @@ class Function:
  # builtin function

  # embedded linked list of all overloads
- # the builtin_functions dictionary holds
- # the list head for a given key (func name)
- self.overloads = []
+ # the builtin_functions dictionary holds the list head for a given key (func name)
+ self.overloads: List[Function] = []

  # builtin (native) function, canonicalize argument types
- for k, v in input_types.items():
- self.input_types[k] = warp.types.type_to_warp(v)
+ if input_types is not None:
+ for k, v in input_types.items():
+ self.input_types[k] = warp.types.type_to_warp(v)

  # cache mangled name
  if self.export and self.is_simple():
  self.mangled_name = self.mangle()
- else:
- self.mangled_name = None

  if not skip_adding_overload:
  self.add_overload(self)
@@ -264,7 +284,7 @@ class Function:
  signature_params.append(param)
  self.signature = inspect.Signature(signature_params)

- # scope for resolving overloads
+ # scope for resolving overloads, the locals() where the function is defined
  if scope_locals is None:
  scope_locals = inspect.currentframe().f_back.f_locals

@@ -326,10 +346,10 @@ class Function:
  # this function has no overloads, call it like a plain Python function
  return self.func(*args, **kwargs)

- def is_builtin(self):
+ def is_builtin(self) -> bool:
  return self.func is None

- def is_simple(self):
+ def is_simple(self) -> bool:
  if self.variadic:
  return False

@@ -343,9 +363,8 @@ class Function:

  return True

- def mangle(self):
- # builds a mangled name for the C-exported
- # function, e.g.: builtin_normalize_vec3()
+ def mangle(self) -> str:
+ """Build a mangled name for the C-exported function, e.g.: `builtin_normalize_vec3()`."""

  name = "builtin_" + self.key

@@ -361,7 +380,7 @@

  return "_".join([name, *types])

- def add_overload(self, f):
+ def add_overload(self, f: Function) -> None:
  if self.is_builtin():
  # todo: note that it is an error to add two functions
  # with the exact same signature as this would cause compile
@@ -376,7 +395,7 @@ class Function:
  else:
  # get function signature based on the input types
  sig = warp.types.get_signature(
- f.input_types.values(), func_name=f.key, arg_names=list(f.input_types.keys())
+ list(f.input_types.values()), func_name=f.key, arg_names=list(f.input_types.keys())
  )

  # check if generic
@@ -385,7 +404,7 @@ class Function:
  else:
  self.user_overloads[sig] = f

- def get_overload(self, arg_types, kwarg_types):
+ def get_overload(self, arg_types: List[type], kwarg_types: Mapping[str, type]) -> Optional[Function]:
  assert not self.is_builtin()

  for f in self.user_overloads.values():
@@ -438,7 +457,7 @@ class Function:
  return f"<Function {self.key}({inputs_str})>"


- def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
+ def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
  uses_non_warp_array_type = False

  init()
@@ -755,37 +774,51 @@ class Kernel:


  # decorator to register function, @func
- def func(f):
- name = warp.codegen.make_full_qualified_name(f)
-
- scope_locals = inspect.currentframe().f_back.f_locals
-
- m = get_module(f.__module__)
- doc = getattr(f, "__doc__", "") or ""
- Function(
- func=f,
- key=name,
- namespace="",
- module=m,
- value_func=None,
- scope_locals=scope_locals,
- doc=doc.strip(),
- ) # value_type not known yet, will be inferred during Adjoint.build()
-
- # use the top of the list of overloads for this key
- g = m.functions[name]
- # copy over the function attributes, including docstring
- return functools.update_wrapper(g, f)
-
-
- def func_native(snippet, adj_snippet=None, replay_snippet=None):
+ def func(f: Optional[Callable] = None, *, name: Optional[str] = None):
+ def wrapper(f, *args, **kwargs):
+ if name is None:
+ key = warp.codegen.make_full_qualified_name(f)
+ else:
+ key = name
+
+ scope_locals = inspect.currentframe().f_back.f_back.f_locals
+
+ m = get_module(f.__module__)
+ doc = getattr(f, "__doc__", "") or ""
+ Function(
+ func=f,
+ key=key,
+ namespace="",
+ module=m,
+ value_func=None,
+ scope_locals=scope_locals,
+ doc=doc.strip(),
+ ) # value_type not known yet, will be inferred during Adjoint.build()
+
+ # use the top of the list of overloads for this key
+ g = m.functions[key]
+ # copy over the function attributes, including docstring
+ return functools.update_wrapper(g, f)
+
+ if f is None:
+ # Arguments were passed to the decorator.
+ return wrapper
+
+ return wrapper(f)
+
+
+ def func_native(snippet: str, adj_snippet: Optional[str] = None, replay_snippet: Optional[str] = None):
  """
  Decorator to register native code snippet, @func_native
  """

- scope_locals = inspect.currentframe().f_back.f_locals
+ frame = inspect.currentframe()
+ if frame is None or frame.f_back is None:
+ scope_locals = {}
+ else:
+ scope_locals = frame.f_back.f_locals

- def snippet_func(f):
+ def snippet_func(f: Callable) -> Callable:
  name = warp.codegen.make_full_qualified_name(f)

  m = get_module(f.__module__)
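
For reference, the hunk above makes `@wp.func` usable with decorator arguments, in particular a `name` keyword that overrides the key the function is registered under. A minimal sketch of both forms (the function names below are illustrative and not taken from the package):

    import warp as wp

    @wp.func
    def twice(x: float) -> float:
        return 2.0 * x

    # new in 1.7.0: the decorator accepts arguments, e.g. a custom registration key
    @wp.func(name="thrice")
    def thrice_impl(x: float) -> float:
        return 3.0 * x
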
@@ -957,22 +990,71 @@ def func_replay(forward_fn):
  return wrapper


- # decorator to register kernel, @kernel, custom_name may be a string
- # that creates a kernel with a different name from the actual function
- def kernel(f=None, *, enable_backward=None):
+ def kernel(
+ f: Optional[Callable] = None,
+ *,
+ enable_backward: Optional[bool] = None,
+ module: Optional[Union[Module, Literal["unique"]]] = None,
+ ):
+ """
+ Decorator to register a Warp kernel from a Python function.
+ The function must be defined with type annotations for all arguments.
+ The function must not return anything.
+
+ Example::
+
+ @wp.kernel
+ def my_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)):
+ tid = wp.tid()
+ b[tid] = a[tid] + 1.0
+
+
+ @wp.kernel(enable_backward=False)
+ def my_kernel_no_backward(a: wp.array(dtype=float, ndim=2), x: float):
+ # the backward pass will not be generated
+ i, j = wp.tid()
+ a[i, j] = x
+
+
+ @wp.kernel(module="unique")
+ def my_kernel_unique_module(a: wp.array(dtype=float), b: wp.array(dtype=float)):
+ # the kernel will be registered in new unique module created just for this
+ # kernel and its dependent functions and structs
+ tid = wp.tid()
+ b[tid] = a[tid] + 1.0
+
+ Args:
+ f: The function to be registered as a kernel.
+ enable_backward: If False, the backward pass will not be generated.
+ module: The :class:`warp.context.Module` to which the kernel belongs. Alternatively, if a string `"unique"` is provided, the kernel is assigned to a new module named after the kernel name and hash. If None, the module is inferred from the function's module.
+
+ Returns:
+ The registered kernel.
+ """
+
  def wrapper(f, *args, **kwargs):
  options = {}

  if enable_backward is not None:
  options["enable_backward"] = enable_backward

- m = get_module(f.__module__)
+ if module is None:
+ m = get_module(f.__module__)
+ elif module == "unique":
+ m = Module(f.__name__, None)
+ else:
+ m = module
  k = Kernel(
  func=f,
  key=warp.codegen.make_full_qualified_name(f),
  module=m,
  options=options,
  )
+ if module == "unique":
+ # add the hash to the module name
+ hasher = warp.context.ModuleHasher(m)
+ k.module.name = f"{k.key}_{hasher.module_hash.hex()[:8]}"
+
  k = functools.update_wrapper(k, f)
  return k

@@ -984,7 +1066,7 @@ def kernel(f=None, *, enable_backward=None):


  # decorator to register struct, @struct
- def struct(c):
+ def struct(c: type):
  m = get_module(c.__module__)
  s = warp.codegen.Struct(cls=c, key=warp.codegen.make_full_qualified_name(c), module=m)
  s = functools.update_wrapper(s, c)
@@ -1097,47 +1179,47 @@ scalar_types.update({x: x._wp_scalar_type_ for x in warp.types.vector_types})


  def add_builtin(
- key,
- input_types=None,
- constraint=None,
- value_type=None,
- value_func=None,
- export_func=None,
- dispatch_func=None,
- lto_dispatch_func=None,
- doc="",
- namespace="wp::",
- variadic=False,
+ key: str,
+ input_types: Optional[Dict[str, Union[type, TypeVar]]] = None,
+ constraint: Optional[Callable[[Mapping[str, type]], bool]] = None,
+ value_type: Optional[type] = None,
+ value_func: Optional[Callable] = None,
+ export_func: Optional[Callable] = None,
+ dispatch_func: Optional[Callable] = None,
+ lto_dispatch_func: Optional[Callable] = None,
+ doc: str = "",
+ namespace: str = "wp::",
+ variadic: bool = False,
  initializer_list_func=None,
- export=True,
- group="Other",
- hidden=False,
- skip_replay=False,
- missing_grad=False,
- native_func=None,
- defaults=None,
- require_original_output_arg=False,
+ export: bool = True,
+ group: str = "Other",
+ hidden: bool = False,
+ skip_replay: bool = False,
+ missing_grad: bool = False,
+ native_func: Optional[str] = None,
+ defaults: Optional[Dict[str, Any]] = None,
+ require_original_output_arg: bool = False,
  ):
  """Main entry point to register a new built-in function.

  Args:
- key (str): Function name. Multiple overloaded functions can be registered
+ key: Function name. Multiple overloaded functions can be registered
  under the same name as long as their signature differ.
- input_types (Mapping[str, Any]): Signature of the user-facing function.
+ input_types: Signature of the user-facing function.
  Variadic arguments are supported by prefixing the parameter names
  with asterisks as in `*args` and `**kwargs`. Generic arguments are
  supported with types such as `Any`, `Float`, `Scalar`, etc.
- constraint (Callable): For functions that define generic arguments and
+ constraint: For functions that define generic arguments and
  are to be exported, this callback is used to specify whether some
  combination of inferred arguments are valid or not.
- value_type (Any): Type returned by the function.
- value_func (Callable): Callback used to specify the return type when
+ value_type: Type returned by the function.
+ value_func: Callback used to specify the return type when
  `value_type` isn't enough.
- export_func (Callable): Callback used during the context stage to specify
+ export_func: Callback used during the context stage to specify
  the signature of the underlying C++ function, not accounting for
  the template parameters.
  If not provided, `input_types` is used.
- dispatch_func (Callable): Callback used during the codegen stage to specify
+ dispatch_func: Callback used during the codegen stage to specify
  the runtime and template arguments to be passed to the underlying C++
  function. In other words, this allows defining a mapping between
  the signatures of the user-facing and the C++ functions, and even to
@@ -1145,27 +1227,26 @@ def add_builtin(
  The arguments returned must be of type `codegen.Var`.
  If not provided, all arguments passed by the users when calling
  the built-in are passed as-is as runtime arguments to the C++ function.
- lto_dispatch_func (Callable): Same as dispatch_func, but takes an 'option' dict
+ lto_dispatch_func: Same as dispatch_func, but takes an 'option' dict
  as extra argument (indicating tile_size and target architecture) and returns
  an LTO-IR buffer as extra return value
- doc (str): Used to generate the Python's docstring and the HTML documentation.
+ doc: Used to generate the Python's docstring and the HTML documentation.
  namespace: Namespace for the underlying C++ function.
- variadic (bool): Whether the function declares variadic arguments.
- initializer_list_func (bool): Whether to use the initializer list syntax
- when passing the arguments to the underlying C++ function.
- export (bool): Whether the function is to be exposed to the Python
+ variadic: Whether the function declares variadic arguments.
+ initializer_list_func: Callback to determine whether to use the
+ initializer list syntax when passing the arguments to the underlying
+ C++ function.
+ export: Whether the function is to be exposed to the Python
  interpreter so that it becomes available from within the `warp`
  module.
- group (str): Classification used for the documentation.
- hidden (bool): Whether to add that function into the documentation.
- skip_replay (bool): Whether operation will be performed during
+ group: Classification used for the documentation.
+ hidden: Whether to add that function into the documentation.
+ skip_replay: Whether operation will be performed during
  the forward replay in the backward pass.
- missing_grad (bool): Whether the function is missing a corresponding
- adjoint.
- native_func (str): Name of the underlying C++ function.
- defaults (Mapping[str, Any]): Default values for the parameters defined
- in `input_types`.
- require_original_output_arg (bool): Used during the codegen stage to
+ missing_grad: Whether the function is missing a corresponding adjoint.
+ native_func: Name of the underlying C++ function.
+ defaults: Default values for the parameters defined in `input_types`.
+ require_original_output_arg: Used during the codegen stage to
  specify whether an adjoint parameter corresponding to the return
  value should be included in the signature of the backward function.
  """
@@ -1347,19 +1428,14 @@ def add_builtin(
  def register_api_function(
  function: Function,
  group: str = "Other",
- hidden=False,
+ hidden: bool = False,
  ):
  """Main entry point to register a Warp Python function to be part of the Warp API and appear in the documentation.

  Args:
- function (Function): Warp function to be registered.
- group (str): Classification used for the documentation.
- input_types (Mapping[str, Any]): Signature of the user-facing function.
- Variadic arguments are supported by prefixing the parameter names
- with asterisks as in `*args` and `**kwargs`. Generic arguments are
- supported with types such as `Any`, `Float`, `Scalar`, etc.
- value_type (Any): Type returned by the function.
- hidden (bool): Whether to add that function into the documentation.
+ function: Warp function to be registered.
+ group: Classification used for the documentation.
+ hidden: Whether to add that function into the documentation.
  """
  function.group = group
  function.hidden = hidden
@@ -1367,10 +1443,10 @@ def register_api_function(


  # global dictionary of modules
- user_modules = {}
+ user_modules: Dict[str, Module] = {}


- def get_module(name):
+ def get_module(name: str) -> Module:
  # some modules might be manually imported using `importlib` without being
  # registered into `sys.modules`
  parent = sys.modules.get(name, None)
@@ -1452,13 +1528,16 @@ class ModuleHasher:
  if warp.config.verify_fp:
  ch.update(bytes("verify_fp", "utf-8"))

+ # line directives, e.g. for Nsight Compute
+ ch.update(bytes(ctypes.c_int(warp.config.line_directives)))
+
  # build config
  ch.update(bytes(warp.config.mode, "utf-8"))

  # save the module hash
  self.module_hash = ch.digest()

- def hash_kernel(self, kernel):
+ def hash_kernel(self, kernel: Kernel) -> bytes:
  # NOTE: We only hash non-generic kernels, so we don't traverse kernel overloads here.

  ch = hashlib.sha256()
@@ -1472,7 +1551,7 @@

  return h

- def hash_function(self, func):
+ def hash_function(self, func: Function) -> bytes:
  # NOTE: This method hashes all possible overloads that a function call could resolve to.
  # The exact overload will be resolved at build time, when the argument types are known.

@@ -1487,7 +1566,7 @@ class ModuleHasher:
  ch.update(bytes(func.key, "utf-8"))

  # include all concrete and generic overloads
- overloads = {**func.user_overloads, **func.user_templates}
+ overloads: Dict[str, Function] = {**func.user_overloads, **func.user_templates}
  for sig in sorted(overloads.keys()):
  ovl = overloads[sig]

@@ -1518,7 +1597,7 @@

  return h

- def hash_adjoint(self, adj):
+ def hash_adjoint(self, adj: warp.codegen.Adjoint) -> bytes:
  # NOTE: We don't cache adjoint hashes, because adjoints are always unique.
  # Even instances of generic kernels and functions have unique adjoints with
  # different argument types.
@@ -1567,7 +1646,7 @@ class ModuleHasher:

  return ch.digest()

- def get_constant_bytes(self, value):
+ def get_constant_bytes(self, value) -> bytes:
  if isinstance(value, int):
  # this also handles builtins.bool
  return bytes(ctypes.c_int(value))
@@ -1585,7 +1664,7 @@ class ModuleHasher:
  else:
  raise TypeError(f"Invalid constant type: {type(value)}")

- def get_module_hash(self):
+ def get_module_hash(self) -> bytes:
  return self.module_hash

  def get_unique_kernels(self):
@@ -1602,6 +1681,7 @@ class ModuleBuilder:
  self.fatbins = {} # map from <some identifier> to fatbins, to add at link time
  self.ltoirs = {} # map from lto symbol to lto binary
  self.ltoirs_decl = {} # map from lto symbol to lto forward declaration
+ self.shared_memory_bytes = {} # map from lto symbol to shared memory requirements

  if hasher is None:
  hasher = ModuleHasher(module)
@@ -1718,9 +1798,9 @@ class ModuleBuilder:

  # add headers
  if device == "cpu":
- source = warp.codegen.cpu_module_header.format(tile_size=self.options["block_dim"]) + source
+ source = warp.codegen.cpu_module_header.format(block_dim=self.options["block_dim"]) + source
  else:
- source = warp.codegen.cuda_module_header.format(tile_size=self.options["block_dim"]) + source
+ source = warp.codegen.cuda_module_header.format(block_dim=self.options["block_dim"]) + source

  return source

@@ -1757,7 +1837,7 @@ class ModuleExec:
  runtime.llvm.unload_obj(self.handle.encode("utf-8"))

  # lookup and cache kernel entry points
- def get_kernel_hooks(self, kernel):
+ def get_kernel_hooks(self, kernel) -> KernelHooks:
  # Use kernel.adj as a unique key for cache lookups instead of the kernel itself.
  # This avoids holding a reference to the kernel and is faster than using
  # a WeakKeyDictionary with kernels as keys.
@@ -1830,7 +1910,7 @@ class ModuleExec:
  # creates a hash of the function to use for checking
  # build cache
  class Module:
- def __init__(self, name, loader):
+ def __init__(self, name: Optional[str], loader=None):
  self.name = name if name is not None else "None"

  self.loader = loader
@@ -1870,7 +1950,7 @@ class Module:
  "enable_backward": warp.config.enable_backward,
  "fast_math": False,
  "fuse_fp": True,
- "lineinfo": False,
+ "lineinfo": warp.config.lineinfo,
  "cuda_output": None, # supported values: "ptx", "cubin", or None (automatic)
  "mode": warp.config.mode,
  "block_dim": 256,
@@ -2073,7 +2153,11 @@ class Module:
  use_ptx = True

  if use_ptx:
- output_arch = min(device.arch, warp.config.ptx_target_arch)
+ # use the default PTX arch if the device supports it
+ if warp.config.ptx_target_arch is not None:
+ output_arch = min(device.arch, warp.config.ptx_target_arch)
+ else:
+ output_arch = min(device.arch, runtime.default_ptx_arch)
  output_name = f"{module_name_short}.sm{output_arch}.ptx"
  else:
  output_arch = device.arch
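
For reference, several of the hunks above touch compilation settings: `warp.config.line_directives` is now folded into the module hash, the per-module `lineinfo` option defaults to `warp.config.lineinfo`, and `warp.config.ptx_target_arch` may be left as `None` to fall back to the runtime's default PTX architecture. A small sketch of overriding these flags before initialization (the values below are illustrative only):

    import warp as wp

    wp.config.lineinfo = True         # embed line information in generated CUDA modules
    wp.config.line_directives = True  # emit line directives, e.g. for Nsight Compute (hashed into modules)
    wp.config.ptx_target_arch = 86    # pin the PTX target arch; leave as None to use the runtime default

    wp.init()
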
@@ -2186,34 +2270,8 @@ class Module:
  # -----------------------------------------------------------
  # update cache

- def safe_rename(src, dst, attempts=5, delay=0.1):
- for i in range(attempts):
- try:
- os.rename(src, dst)
- return
- except FileExistsError:
- return
- except OSError as e:
- if e.errno == errno.ENOTEMPTY:
- # if directory exists we assume another process
- # got there first, in which case we will copy
- # our output to the directory manually in second step
- return
- else:
- # otherwise assume directory creation failed e.g.: access denied
- # on Windows we see occasional failures to rename directories due to
- # some process holding a lock on a file to be moved to workaround
- # this we make multiple attempts to rename with some delay
- if i < attempts - 1:
- time.sleep(delay)
- else:
- print(
- f"Could not update Warp cache with module binaries, trying to rename {build_dir} to {module_dir}, error {e}"
- )
- raise e
-
  # try to move process outputs to cache
- safe_rename(build_dir, module_dir)
+ warp.build.safe_rename(build_dir, module_dir)

  if os.path.exists(module_dir):
  if not os.path.exists(binary_path):
@@ -2286,7 +2344,7 @@ class Module:
  self.failed_builds = set()

  # lookup kernel entry points based on name, called after compilation / module load
- def get_kernel_hooks(self, kernel, device):
+ def get_kernel_hooks(self, kernel, device: Device) -> KernelHooks:
  module_exec = self.execs.get((device.context, self.options["block_dim"]))
  if module_exec is not None:
  return module_exec.get_kernel_hooks(kernel)
@@ -2441,6 +2499,7 @@ class Event:
  raise RuntimeError(f"Device {device} is not a CUDA device")

  self.device = device
+ self.enable_timing = enable_timing

  if cuda_event is not None:
  self.cuda_event = cuda_event
@@ -2490,6 +2549,17 @@ class Event:
  else:
  raise RuntimeError(f"Device {self.device} does not support IPC.")

+ @property
+ def is_complete(self) -> bool:
+ """A boolean indicating whether all work on the stream when the event was recorded has completed.
+
+ This property may not be accessed during a graph capture on any stream.
+ """
+
+ result_code = runtime.core.cuda_event_query(self.cuda_event)
+
+ return result_code == 0
+
  def __del__(self):
  if not self.owner:
  return
@@ -2504,7 +2574,7 @@ class Stream:
  instance.owner = False
  return instance

- def __init__(self, device: Optional[Union["Device", str]] = None, priority: int = 0, **kwargs):
+ def __init__(self, device: Union["Device", str, None] = None, priority: int = 0, **kwargs):
  """Initialize the stream on a device with an optional specified priority.

  Args:
@@ -2520,7 +2590,7 @@ class Stream:
  Raises:
  RuntimeError: If function is called before Warp has completed
  initialization with a ``device`` that is not an instance of
- :class:`Device``.
+ :class:`Device <warp.context.Device>`.
  RuntimeError: ``device`` is not a CUDA Device.
  RuntimeError: The stream could not be created on the device.
  TypeError: The requested stream priority is not an integer.
@@ -2588,7 +2658,7 @@ class Stream:
  f"Event from device {event.device} cannot be recorded on stream from device {self.device}"
  )

- runtime.core.cuda_event_record(event.cuda_event, self.cuda_stream)
+ runtime.core.cuda_event_record(event.cuda_event, self.cuda_stream, event.enable_timing)

  return event

@@ -2622,6 +2692,17 @@

  runtime.core.cuda_stream_wait_stream(self.cuda_stream, other_stream.cuda_stream, event.cuda_event)

+ @property
+ def is_complete(self) -> bool:
+ """A boolean indicating whether all work on the stream has completed.
+
+ This property may not be accessed during a graph capture on any stream.
+ """
+
+ result_code = runtime.core.cuda_stream_query(self.cuda_stream)
+
+ return result_code == 0
+
  @property
  def is_capturing(self) -> bool:
  """A boolean indicating whether a graph capture is currently ongoing on this stream."""
@@ -2944,18 +3025,14 @@ Devicelike = Union[Device, str, None]


  class Graph:
- def __new__(cls, *args, **kwargs):
- instance = super(Graph, cls).__new__(cls)
- instance.graph_exec = None
- return instance
-
  def __init__(self, device: Device, capture_id: int):
  self.device = device
  self.capture_id = capture_id
- self.module_execs = set()
+ self.module_execs: Set[ModuleExec] = set()
+ self.graph_exec: Optional[ctypes.c_void_p] = None

  def __del__(self):
- if not self.graph_exec:
+ if not hasattr(self, "graph_exec") or not hasattr(self, "device") or not self.graph_exec:
  return

  # use CUDA context guard to avoid side effects during garbage collection
@@ -3197,6 +3274,43 @@ class Runtime:
  self.core.radix_sort_pairs_float_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
  self.core.radix_sort_pairs_float_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]

+ self.core.radix_sort_pairs_int64_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
+ self.core.radix_sort_pairs_int64_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
+
+ self.core.segmented_sort_pairs_int_host.argtypes = [
+ ctypes.c_uint64,
+ ctypes.c_uint64,
+ ctypes.c_int,
+ ctypes.c_uint64,
+ ctypes.c_uint64,
+ ctypes.c_int,
+ ]
+ self.core.segmented_sort_pairs_int_device.argtypes = [
+ ctypes.c_uint64,
+ ctypes.c_uint64,
+ ctypes.c_int,
+ ctypes.c_uint64,
+ ctypes.c_uint64,
+ ctypes.c_int,
+ ]
+
+ self.core.segmented_sort_pairs_float_host.argtypes = [
+ ctypes.c_uint64,
+ ctypes.c_uint64,
+ ctypes.c_int,
+ ctypes.c_uint64,
+ ctypes.c_uint64,
+ ctypes.c_int,
+ ]
+ self.core.segmented_sort_pairs_float_device.argtypes = [
+ ctypes.c_uint64,
+ ctypes.c_uint64,
+ ctypes.c_int,
+ ctypes.c_uint64,
+ ctypes.c_uint64,
+ ctypes.c_int,
+ ]
+
  self.core.runlength_encode_int_host.argtypes = [
  ctypes.c_uint64,
  ctypes.c_uint64,
@@ -3277,26 +3391,6 @@ class Runtime:
  self.core.hash_grid_update_device.argtypes = [ctypes.c_uint64, ctypes.c_float, ctypes.c_void_p]
  self.core.hash_grid_reserve_device.argtypes = [ctypes.c_uint64, ctypes.c_int]

- self.core.cutlass_gemm.argtypes = [
- ctypes.c_void_p,
- ctypes.c_int,
- ctypes.c_int,
- ctypes.c_int,
- ctypes.c_int,
- ctypes.c_char_p,
- ctypes.c_void_p,
- ctypes.c_void_p,
- ctypes.c_void_p,
- ctypes.c_void_p,
- ctypes.c_float,
- ctypes.c_float,
- ctypes.c_bool,
- ctypes.c_bool,
- ctypes.c_bool,
- ctypes.c_int,
- ]
- self.core.cutlass_gemm.restype = ctypes.c_bool
-
  self.core.volume_create_host.argtypes = [ctypes.c_void_p, ctypes.c_uint64, ctypes.c_bool, ctypes.c_bool]
  self.core.volume_create_host.restype = ctypes.c_uint64
  self.core.volume_get_tiles_host.argtypes = [
@@ -3327,36 +3421,18 @@ class Runtime:
  ]
  self.core.volume_destroy_device.argtypes = [ctypes.c_uint64]

- self.core.volume_f_from_tiles_device.argtypes = [
+ self.core.volume_from_tiles_device.argtypes = [
  ctypes.c_void_p,
  ctypes.c_void_p,
  ctypes.c_int,
  ctypes.c_float * 9,
  ctypes.c_float * 3,
  ctypes.c_bool,
- ctypes.c_float,
- ]
- self.core.volume_f_from_tiles_device.restype = ctypes.c_uint64
- self.core.volume_v_from_tiles_device.argtypes = [
  ctypes.c_void_p,
- ctypes.c_void_p,
- ctypes.c_int,
- ctypes.c_float * 9,
- ctypes.c_float * 3,
- ctypes.c_bool,
- ctypes.c_float * 3,
- ]
- self.core.volume_v_from_tiles_device.restype = ctypes.c_uint64
- self.core.volume_i_from_tiles_device.argtypes = [
- ctypes.c_void_p,
- ctypes.c_void_p,
- ctypes.c_int,
- ctypes.c_float * 9,
- ctypes.c_float * 3,
- ctypes.c_bool,
- ctypes.c_int,
+ ctypes.c_uint32,
+ ctypes.c_char_p,
  ]
- self.core.volume_i_from_tiles_device.restype = ctypes.c_uint64
+ self.core.volume_from_tiles_device.restype = ctypes.c_uint64
  self.core.volume_index_from_tiles_device.argtypes = [
  ctypes.c_void_p,
  ctypes.c_void_p,
@@ -3425,6 +3501,7 @@ class Runtime:
3425
3501
  ctypes.POINTER(ctypes.c_int), # tpl_cols
3426
3502
  ctypes.c_void_p, # tpl_values
3427
3503
  ctypes.c_bool, # prune_numerical_zeros
3504
+ ctypes.c_bool, # masked
3428
3505
  ctypes.POINTER(ctypes.c_int), # bsr_offsets
3429
3506
  ctypes.POINTER(ctypes.c_int), # bsr_columns
3430
3507
  ctypes.c_void_p, # bsr_values
@@ -3459,8 +3536,6 @@ class Runtime:
3459
3536
  self.core.is_cuda_enabled.restype = ctypes.c_int
3460
3537
  self.core.is_cuda_compatibility_enabled.argtypes = None
3461
3538
  self.core.is_cuda_compatibility_enabled.restype = ctypes.c_int
3462
- self.core.is_cutlass_enabled.argtypes = None
3463
- self.core.is_cutlass_enabled.restype = ctypes.c_int
3464
3539
  self.core.is_mathdx_enabled.argtypes = None
3465
3540
  self.core.is_mathdx_enabled.restype = ctypes.c_int
3466
3541
 
@@ -3494,6 +3569,10 @@ class Runtime:
3494
3569
  self.core.cuda_device_set_mempool_release_threshold.restype = ctypes.c_int
3495
3570
  self.core.cuda_device_get_mempool_release_threshold.argtypes = [ctypes.c_int]
3496
3571
  self.core.cuda_device_get_mempool_release_threshold.restype = ctypes.c_uint64
3572
+ self.core.cuda_device_get_mempool_used_mem_current.argtypes = [ctypes.c_int]
3573
+ self.core.cuda_device_get_mempool_used_mem_current.restype = ctypes.c_uint64
3574
+ self.core.cuda_device_get_mempool_used_mem_high.argtypes = [ctypes.c_int]
3575
+ self.core.cuda_device_get_mempool_used_mem_high.restype = ctypes.c_uint64
3497
3576
  self.core.cuda_device_get_memory_info.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p]
3498
3577
  self.core.cuda_device_get_memory_info.restype = None
3499
3578
  self.core.cuda_device_get_uuid.argtypes = [ctypes.c_int, ctypes.c_char * 16]
@@ -3563,6 +3642,8 @@ class Runtime:
3563
3642
  self.core.cuda_stream_create.restype = ctypes.c_void_p
3564
3643
  self.core.cuda_stream_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3565
3644
  self.core.cuda_stream_destroy.restype = None
3645
+ self.core.cuda_stream_query.argtypes = [ctypes.c_void_p]
3646
+ self.core.cuda_stream_query.restype = ctypes.c_int
3566
3647
  self.core.cuda_stream_register.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3567
3648
  self.core.cuda_stream_register.restype = None
3568
3649
  self.core.cuda_stream_unregister.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
@@ -3584,7 +3665,9 @@ class Runtime:
3584
3665
  self.core.cuda_event_create.restype = ctypes.c_void_p
3585
3666
  self.core.cuda_event_destroy.argtypes = [ctypes.c_void_p]
3586
3667
  self.core.cuda_event_destroy.restype = None
3587
- self.core.cuda_event_record.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
3668
+ self.core.cuda_event_query.argtypes = [ctypes.c_void_p]
3669
+ self.core.cuda_event_query.restype = ctypes.c_int
3670
+ self.core.cuda_event_record.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_bool]
3588
3671
  self.core.cuda_event_record.restype = None
3589
3672
  self.core.cuda_event_synchronize.argtypes = [ctypes.c_void_p]
3590
3673
  self.core.cuda_event_synchronize.restype = None
@@ -3833,9 +3916,20 @@ class Runtime:
3833
3916
  cuda_device_count = len(self.cuda_devices)
3834
3917
  else:
3835
3918
  self.set_default_device("cuda:0")
3919
+
3920
+ # the minimum PTX architecture that supports all of Warp's features
3921
+ self.default_ptx_arch = 75
3922
+
3923
+ # Update the default PTX architecture based on devices present in the system.
3924
+ # Use the lowest architecture among devices that meet the minimum architecture requirement.
3925
+ # Devices below the required minimum will use the highest architecture they support.
3926
+ eligible_archs = [d.arch for d in self.cuda_devices if d.arch >= self.default_ptx_arch]
3927
+ if eligible_archs:
3928
+ self.default_ptx_arch = min(eligible_archs)
3836
3929
  else:
3837
3930
  # CUDA not available
3838
3931
  self.set_default_device("cpu")
3932
+ self.default_ptx_arch = None
3839
3933
 
3840
3934
  # initialize kernel cache
3841
3935
  warp.build.init_kernel_cache(warp.config.kernel_cache_dir)
@@ -3848,6 +3942,11 @@ class Runtime:
3848
3942
  greeting = []
3849
3943
 
3850
3944
  greeting.append(f"Warp {warp.config.version} initialized:")
3945
+
3946
+ # Add git commit hash to greeting if available
3947
+ if warp.config._git_commit_hash is not None:
3948
+ greeting.append(f" Git commit: {warp.config._git_commit_hash}")
3949
+
3851
3950
  if cuda_device_count > 0:
3852
3951
  # print CUDA version info
3853
3952
  greeting.append(
@@ -4200,7 +4299,7 @@ def set_device(ident: Devicelike) -> None:
4200
4299
  device.make_current()
4201
4300
 
4202
4301
 
4203
- def map_cuda_device(alias: str, context: ctypes.c_void_p = None) -> Device:
4302
+ def map_cuda_device(alias: str, context: Optional[ctypes.c_void_p] = None) -> Device:
4204
4303
  """Assign a device alias to a CUDA context.
4205
4304
 
4206
4305
  This function can be used to create a wp.Device for an external CUDA context.
@@ -4228,7 +4327,13 @@ def unmap_cuda_device(alias: str) -> None:
4228
4327
 
4229
4328
 
4230
4329
  def is_mempool_supported(device: Devicelike) -> bool:
4231
- """Check if CUDA memory pool allocators are available on the device."""
4330
+ """Check if CUDA memory pool allocators are available on the device.
4331
+
4332
+ Parameters:
4333
+ device: The :class:`Device <warp.context.Device>` or device identifier
4334
+ for which the query is to be performed.
4335
+ If ``None``, the default device will be used.
4336
+ """
4232
4337
 
4233
4338
  init()
4234
4339
 
@@ -4238,7 +4343,13 @@ def is_mempool_supported(device: Devicelike) -> bool:
4238
4343
 
4239
4344
 
4240
4345
  def is_mempool_enabled(device: Devicelike) -> bool:
4241
- """Check if CUDA memory pool allocators are enabled on the device."""
4346
+ """Check if CUDA memory pool allocators are enabled on the device.
4347
+
4348
+ Parameters:
4349
+ device: The :class:`Device <warp.context.Device>` or device identifier
4350
+ for which the query is to be performed.
4351
+ If ``None``, the default device will be used.
4352
+ """
4242
4353
 
4243
4354
  init()
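
A minimal usage sketch for the two queries above (not part of the diff; the "cuda:0" alias is assumed):

    import warp as wp

    wp.init()

    # check support before toggling or tuning the pool allocator
    if wp.is_mempool_supported("cuda:0"):
        print("mempool enabled:", wp.is_mempool_enabled("cuda:0"))
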
4244
4355
 
@@ -4258,6 +4369,11 @@ def set_mempool_enabled(device: Devicelike, enable: bool) -> None:
4258
4369
  to Warp. The preferred solution is to enable memory pool access using :func:`set_mempool_access_enabled`.
4259
4370
  If peer access is not supported, then the default CUDA allocators must be used to pre-allocate the memory
4260
4371
  prior to graph capture.
4372
+
4373
+ Parameters:
4374
+ device: The :class:`Device <warp.context.Device>` or device identifier
4375
+ for which the operation is to be performed.
4376
+ If ``None``, the default device will be used.
4261
4377
  """
4262
4378
 
4263
4379
  init()
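
A short sketch of the fallback workflow described in the docstring above, i.e. pre-allocating with the default allocators before a capture when peer access is unavailable (illustrative only; the array size is arbitrary):

    import warp as wp

    wp.init()

    device = "cuda:0"
    if wp.is_mempool_supported(device):
        # switch to the default CUDA allocators for allocations that must survive graph capture
        wp.set_mempool_enabled(device, False)
        buf = wp.zeros(1024, dtype=float, device=device)  # pre-allocate before capture
        wp.set_mempool_enabled(device, True)
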
@@ -4288,6 +4404,18 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, floa
4288
4404
  Values between 0 and 1 are interpreted as fractions of available memory. For example, 0.5 means
4289
4405
  half of the device's physical memory. Greater values are interpreted as an absolute number of bytes.
4290
4406
  For example, 1024**3 means one GiB of memory.
4407
+
4408
+ Parameters:
4409
+ device: The :class:`Device <warp.context.Device>` or device identifier
4410
+ for which the operation is to be performed.
4411
+ If ``None``, the default device will be used.
4412
+ threshold: An integer representing a number of bytes, or a ``float`` between 0 and 1,
4413
+ specifying the desired release threshold.
4414
+
4415
+ Raises:
4416
+ ValueError: If ``device`` is not a CUDA device.
4417
+ RuntimeError: If ``device`` is a CUDA device, but does not support memory pools.
4418
+ RuntimeError: Failed to set the memory pool release threshold.
4291
4419
  """
4292
4420
 
4293
4421
  init()
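
To illustrate the fraction-versus-bytes convention documented above (values chosen for illustration):

    import warp as wp

    wp.init()

    # keep up to half of the device's physical memory cached in the pool
    wp.set_mempool_release_threshold("cuda:0", 0.5)

    # or set an absolute budget of one GiB
    wp.set_mempool_release_threshold("cuda:0", 1024**3)
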
@@ -4309,8 +4437,21 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, floa
4309
4437
  raise RuntimeError(f"Failed to set memory pool release threshold for device {device}")
4310
4438
 
4311
4439
 
4312
- def get_mempool_release_threshold(device: Devicelike) -> int:
4313
- """Get the CUDA memory pool release threshold on the device in bytes."""
4440
+ def get_mempool_release_threshold(device: Devicelike = None) -> int:
4441
+ """Get the CUDA memory pool release threshold on the device.
4442
+
4443
+ Parameters:
4444
+ device: The :class:`Device <warp.context.Device>` or device identifier
4445
+ for which the query is to be performed.
4446
+ If ``None``, the default device will be used.
4447
+
4448
+ Returns:
4449
+ The memory pool release threshold in bytes.
4450
+
4451
+ Raises:
4452
+ ValueError: If ``device`` is not a CUDA device.
4453
+ RuntimeError: If ``device`` is a CUDA device, but does not support memory pools.
4454
+ """
4314
4455
 
4315
4456
  init()
4316
4457
 
@@ -4325,6 +4466,64 @@ def get_mempool_release_threshold(device: Devicelike) -> int:
4325
4466
  return runtime.core.cuda_device_get_mempool_release_threshold(device.ordinal)
4326
4467
 
4327
4468
 
4469
+ def get_mempool_used_mem_current(device: Devicelike = None) -> int:
4470
+ """Get the amount of memory from the device's memory pool that is currently in use by the application.
4471
+
4472
+ Parameters:
4473
+ device: The :class:`Device <warp.context.Device>` or device identifier
4474
+ for which the query is to be performed.
4475
+ If ``None``, the default device will be used.
4476
+
4477
+ Returns:
4478
+ The amount of memory used in bytes.
4479
+
4480
+ Raises:
4481
+ ValueError: If ``device`` is not a CUDA device.
4482
+ RuntimeError: If ``device`` is a CUDA device, but does not support memory pools.
4483
+ """
4484
+
4485
+ init()
4486
+
4487
+ device = runtime.get_device(device)
4488
+
4489
+ if not device.is_cuda:
4490
+ raise ValueError("Memory pools are only supported on CUDA devices")
4491
+
4492
+ if not device.is_mempool_supported:
4493
+ raise RuntimeError(f"Device {device} does not support memory pools")
4494
+
4495
+ return runtime.core.cuda_device_get_mempool_used_mem_current(device.ordinal)
4496
+
4497
+
4498
+ def get_mempool_used_mem_high(device: Devicelike = None) -> int:
4499
+ """Get the application's memory usage high-water mark from the device's CUDA memory pool.
4500
+
4501
+ Parameters:
4502
+ device: The :class:`Device <warp.context.Device>` or device identifier
4503
+ for which the query is to be performed.
4504
+ If ``None``, the default device will be used.
4505
+
4506
+ Returns:
4507
+ The high-water mark of memory used from the memory pool in bytes.
4508
+
4509
+ Raises:
4510
+ ValueError: If ``device`` is not a CUDA device.
4511
+ RuntimeError: If ``device`` is a CUDA device, but does not support memory pools.
4512
+ """
4513
+
4514
+ init()
4515
+
4516
+ device = runtime.get_device(device)
4517
+
4518
+ if not device.is_cuda:
4519
+ raise ValueError("Memory pools are only supported on CUDA devices")
4520
+
4521
+ if not device.is_mempool_supported:
4522
+ raise RuntimeError(f"Device {device} does not support memory pools")
4523
+
4524
+ return runtime.core.cuda_device_get_mempool_used_mem_high(device.ordinal)
4525
+
4526
+
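
A small sketch exercising the new read-back helpers together with the existing threshold query (device alias assumed):

    import warp as wp

    wp.init()

    device = "cuda:0"
    a = wp.zeros((1024, 1024), dtype=float, device=device)  # allocated from the pool

    print("release threshold:", wp.get_mempool_release_threshold(device))
    print("bytes currently in use:", wp.get_mempool_used_mem_current(device))
    print("high-water mark:", wp.get_mempool_used_mem_high(device))
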
4328
4527
  def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike) -> bool:
4329
4528
  """Check if `peer_device` can directly access the memory of `target_device` on this system.
4330
4529
 
@@ -4527,7 +4726,7 @@ def wait_event(event: Event):
4527
4726
  get_stream().wait_event(event)
4528
4727
 
4529
4728
 
4530
- def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: Optional[bool] = True):
4729
+ def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: bool = True):
4531
4730
  """Get the elapsed time between two recorded events.
4532
4731
 
4533
4732
  Both events must have been previously recorded with
@@ -4552,7 +4751,7 @@ def get_event_elapsed_time(start_event: Event, end_event: Event, synchronize: Op
4552
4751
  return runtime.core.cuda_event_elapsed_time(start_event.cuda_event, end_event.cuda_event)
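
A timing sketch against the signature above; it assumes events created with ``enable_timing=True``, which elapsed-time queries generally require:

    import warp as wp

    wp.init()

    with wp.ScopedDevice("cuda:0"):
        start = wp.Event(enable_timing=True)
        end = wp.Event(enable_timing=True)

        wp.record_event(start)
        a = wp.zeros(1 << 20, dtype=float)
        a.zero_()
        wp.record_event(end)

        # synchronize=True (the default) waits for both events to complete first
        print("elapsed ms:", wp.get_event_elapsed_time(start, end))
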
4553
4752
 
4554
4753
 
4555
- def wait_stream(other_stream: Stream, event: Event = None):
4754
+ def wait_stream(other_stream: Stream, event: Optional[Event] = None):
4556
4755
  """Convenience function for calling :meth:`Stream.wait_stream` on the current stream.
4557
4756
 
4558
4757
  Args:
@@ -4719,7 +4918,7 @@ class RegisteredGLBuffer:
4719
4918
 
4720
4919
 
4721
4920
  def zeros(
4722
- shape: Tuple = None,
4921
+ shape: Union[int, Tuple[int, ...], List[int], None] = None,
4723
4922
  dtype=float,
4724
4923
  device: Devicelike = None,
4725
4924
  requires_grad: bool = False,
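
The widened ``shape`` annotation reflects that an int, tuple, or list is accepted; for example:

    import warp as wp

    wp.init()

    a = wp.zeros(1024, dtype=wp.vec3)                               # int shape -> 1-D array
    b = wp.zeros((16, 16), dtype=float)                             # tuple shape -> 2-D array
    c = wp.zeros([2, 3, 4], dtype=wp.float64, requires_grad=True)   # list shape
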
@@ -4747,7 +4946,7 @@ def zeros(
4747
4946
 
4748
4947
 
4749
4948
  def zeros_like(
4750
- src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
4949
+ src: Array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
4751
4950
  ) -> warp.array:
4752
4951
  """Return a zero-initialized array with the same type and dimension of another array
4753
4952
 
@@ -4769,7 +4968,7 @@ def zeros_like(
4769
4968
 
4770
4969
 
4771
4970
  def ones(
4772
- shape: Tuple = None,
4971
+ shape: Union[int, Tuple[int, ...], List[int], None] = None,
4773
4972
  dtype=float,
4774
4973
  device: Devicelike = None,
4775
4974
  requires_grad: bool = False,
@@ -4793,7 +4992,7 @@ def ones(
4793
4992
 
4794
4993
 
4795
4994
  def ones_like(
4796
- src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
4995
+ src: Array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
4797
4996
  ) -> warp.array:
4798
4997
  """Return a one-initialized array with the same type and dimension of another array
4799
4998
 
@@ -4811,7 +5010,7 @@ def ones_like(
4811
5010
 
4812
5011
 
4813
5012
  def full(
4814
- shape: Tuple = None,
5013
+ shape: Union[int, Tuple[int, ...], List[int], None] = None,
4815
5014
  value=0,
4816
5015
  dtype=Any,
4817
5016
  device: Devicelike = None,
@@ -4877,7 +5076,11 @@ def full(
4877
5076
 
4878
5077
 
4879
5078
  def full_like(
4880
- src: warp.array, value: Any, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
5079
+ src: Array,
5080
+ value: Any,
5081
+ device: Devicelike = None,
5082
+ requires_grad: Optional[bool] = None,
5083
+ pinned: Optional[bool] = None,
4881
5084
  ) -> warp.array:
4882
5085
  """Return an array with all elements initialized to the given value with the same type and dimension of another array
4883
5086
 
@@ -4899,7 +5102,9 @@ def full_like(
4899
5102
  return arr
4900
5103
 
4901
5104
 
4902
- def clone(src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None) -> warp.array:
5105
+ def clone(
5106
+ src: warp.array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
5107
+ ) -> warp.array:
4903
5108
  """Clone an existing array, allocates a copy of the src memory
4904
5109
 
4905
5110
  Args:
@@ -4920,7 +5125,7 @@ def clone(src: warp.array, device: Devicelike = None, requires_grad: bool = None
4920
5125
 
4921
5126
 
4922
5127
  def empty(
4923
- shape: Tuple = None,
5128
+ shape: Union[int, Tuple[int, ...], List[int], None] = None,
4924
5129
  dtype=float,
4925
5130
  device: Devicelike = None,
4926
5131
  requires_grad: bool = False,
@@ -4953,7 +5158,7 @@ def empty(
4953
5158
 
4954
5159
 
4955
5160
  def empty_like(
4956
- src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
5161
+ src: Array, device: Devicelike = None, requires_grad: Optional[bool] = None, pinned: Optional[bool] = None
4957
5162
  ) -> warp.array:
4958
5163
  """Return an uninitialized array with the same type and dimension of another array
4959
5164
 
@@ -5185,8 +5390,6 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
5185
5390
  ) from e
5186
5391
 
5187
5392
 
5188
- # represents all data required for a kernel launch
5189
- # so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
5190
5393
  class Launch:
5191
5394
  """Represents all data required for a kernel launch so that launches can be replayed quickly.
5192
5395
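
A sketch of the record-and-replay pattern this class supports via ``wp.launch(..., record_cmd=True)`` (the kernel below is illustrative):

    import warp as wp

    wp.init()


    @wp.kernel
    def scale(x: wp.array(dtype=float), s: float):
        i = wp.tid()
        x[i] = x[i] * s


    x = wp.ones(1024, dtype=float, device="cuda:0")

    # build the Launch object once, then replay it without re-packing arguments
    cmd = wp.launch(scale, dim=x.shape[0], inputs=[x, 2.0], device="cuda:0", record_cmd=True)
    for _ in range(10):
        cmd.launch()
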
 
@@ -5457,7 +5660,7 @@ def launch(
5457
5660
  max_blocks: The maximum number of CUDA thread blocks to use.
5458
5661
  Only has an effect for CUDA kernel launches.
5459
5662
  If negative or zero, the maximum hardware value will be used.
5460
- block_dim: The number of threads per block.
5663
+ block_dim: The number of threads per block (always 1 for "cpu" devices).
5461
5664
  """
5462
5665
 
5463
5666
  init()
@@ -5468,6 +5671,9 @@ def launch(
5468
5671
  else:
5469
5672
  device = runtime.get_device(device)
5470
5673
 
5674
+ if device == "cpu":
5675
+ block_dim = 1
5676
+
5471
5677
  # check function is a Kernel
5472
5678
  if not isinstance(kernel, Kernel):
5473
5679
  raise RuntimeError("Error launching kernel, can only launch functions decorated with @wp.kernel.")
@@ -5700,6 +5906,18 @@ def launch_tiled(*args, **kwargs):
5700
5906
  "Launch block dimension 'block_dim' argument should be passed via. keyword args for wp.launch_tiled()"
5701
5907
  )
5702
5908
 
5909
+ if "device" in kwargs:
5910
+ device = kwargs["device"]
5911
+ else:
5912
+ # todo: this doesn't consider the case where device
5913
+ # is passed through positional args
5914
+ device = None
5915
+
5916
+ # force the block_dim to 1 if running on "cpu"
5917
+ device = runtime.get_device(device)
5918
+ if device.is_cpu:
5919
+ kwargs["block_dim"] = 1
5920
+
5703
5921
  dim = kwargs["dim"]
5704
5922
  if not isinstance(dim, list):
5705
5923
  dim = list(dim) if isinstance(dim, tuple) else [dim]
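
A sketch of a tiled launch; the ``tile_load``/``tile_sum``/``tile_store`` keyword forms below follow the recent tile API and are assumptions, not taken from this diff:

    import warp as wp

    wp.init()

    TILE_SIZE = 256
    TILE_THREADS = 64


    @wp.kernel
    def partial_sums(a: wp.array(dtype=float), b: wp.array(dtype=float)):
        i = wp.tid()  # block index when launched with launch_tiled()
        t = wp.tile_load(a, shape=TILE_SIZE, offset=i * TILE_SIZE)
        s = wp.tile_sum(t)
        wp.tile_store(b, s, offset=i)


    n_tiles = 16
    a = wp.ones(n_tiles * TILE_SIZE, dtype=float, device="cuda:0")
    b = wp.zeros(n_tiles, dtype=float, device="cuda:0")

    # block_dim must be passed as a keyword; on "cpu" it is forced to 1 by the code above
    wp.launch_tiled(partial_sums, dim=n_tiles, inputs=[a, b], device="cuda:0", block_dim=TILE_THREADS)
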
@@ -5868,6 +6086,7 @@ def set_module_options(options: Dict[str, Any], module: Optional[Any] = None):
5868
6086
 
5869
6087
  * **mode**: The compilation mode to use, can be "debug", or "release", defaults to the value of ``warp.config.mode``.
5870
6088
  * **max_unroll**: The maximum fixed-size loop to unroll, defaults to the value of ``warp.config.max_unroll``.
6089
+ * **block_dim**: The default number of threads to assign to each block.
5871
6090
 
5872
6091
  Args:
5873
6092
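
For instance, the new ``block_dim`` option can be set per module alongside the existing ones (values are illustrative):

    import warp as wp

    # applies to kernels defined in the calling module
    wp.set_module_options({"block_dim": 128, "max_unroll": 8})
    print(wp.get_module_options()["block_dim"])
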
 
@@ -5893,7 +6112,12 @@ def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
5893
6112
  return get_module(m.__name__).options
5894
6113
 
5895
6114
 
5896
- def capture_begin(device: Devicelike = None, stream=None, force_module_load=None, external=False):
6115
+ def capture_begin(
6116
+ device: Devicelike = None,
6117
+ stream: Optional[Stream] = None,
6118
+ force_module_load: Optional[bool] = None,
6119
+ external: bool = False,
6120
+ ):
5897
6121
  """Begin capture of a CUDA graph
5898
6122
 
5899
6123
  Captures all subsequent kernel launches and memory operations on CUDA devices.
@@ -5960,16 +6184,15 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=None
5960
6184
  runtime.captures[capture_id] = graph
5961
6185
 
5962
6186
 
5963
- def capture_end(device: Devicelike = None, stream: Stream = None) -> Graph:
5964
- """Ends the capture of a CUDA graph
6187
+ def capture_end(device: Devicelike = None, stream: Optional[Stream] = None) -> Graph:
6188
+ """End the capture of a CUDA graph.
5965
6189
 
5966
6190
  Args:
5967
-
5968
6191
  device: The CUDA device where capture began
5969
6192
  stream: The CUDA stream where capture began
5970
6193
 
5971
6194
  Returns:
5972
- A Graph object that can be launched with :func:`~warp.capture_launch()`
6195
+ A :class:`Graph` object that can be launched with :func:`~warp.capture_launch()`
5973
6196
  """
5974
6197
 
5975
6198
  if stream is not None:
@@ -6003,12 +6226,12 @@ def capture_end(device: Devicelike = None, stream: Stream = None) -> Graph:
6003
6226
  return graph
6004
6227
 
6005
6228
 
6006
- def capture_launch(graph: Graph, stream: Stream = None):
6229
+ def capture_launch(graph: Graph, stream: Optional[Stream] = None):
6007
6230
  """Launch a previously captured CUDA graph
6008
6231
 
6009
6232
  Args:
6010
- graph: A Graph as returned by :func:`~warp.capture_end()`
6011
- stream: A Stream to launch the graph on (optional)
6233
+ graph: A :class:`Graph` as returned by :func:`~warp.capture_end()`
6234
+ stream: A :class:`Stream` to launch the graph on
6012
6235
  """
6013
6236
 
6014
6237
  if stream is not None:
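
A compact capture/replay sketch tying the three functions together (the try/finally guard mirrors the usage recommended in the capture docs; the kernel is illustrative):

    import warp as wp

    wp.init()


    @wp.kernel
    def saxpy(x: wp.array(dtype=float), y: wp.array(dtype=float), a: float):
        i = wp.tid()
        y[i] = a * x[i] + y[i]


    with wp.ScopedDevice("cuda:0"):
        x = wp.ones(1 << 20, dtype=float)
        y = wp.zeros_like(x)

        wp.capture_begin()
        try:
            wp.launch(saxpy, dim=x.shape[0], inputs=[x, y, 2.0])
        finally:
            graph = wp.capture_end()

        for _ in range(100):
            wp.capture_launch(graph)
        wp.synchronize()
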
@@ -6024,24 +6247,28 @@ def capture_launch(graph: Graph, stream: Stream = None):
6024
6247
 
6025
6248
 
6026
6249
  def copy(
6027
- dest: warp.array, src: warp.array, dest_offset: int = 0, src_offset: int = 0, count: int = 0, stream: Stream = None
6250
+ dest: warp.array,
6251
+ src: warp.array,
6252
+ dest_offset: int = 0,
6253
+ src_offset: int = 0,
6254
+ count: int = 0,
6255
+ stream: Optional[Stream] = None,
6028
6256
  ):
6029
6257
  """Copy array contents from `src` to `dest`.
6030
6258
 
6031
6259
  Args:
6032
- dest: Destination array, must be at least as big as source buffer
6260
+ dest: Destination array, must be at least as large as the source buffer
6033
6261
  src: Source array
6034
6262
  dest_offset: Element offset in the destination array
6035
6263
  src_offset: Element offset in the source array
6036
6264
  count: Number of array elements to copy (will copy all elements if set to 0)
6037
- stream: The stream on which to perform the copy (optional)
6265
+ stream: The stream on which to perform the copy
6038
6266
 
6039
6267
  The stream, if specified, can be from any device. If the stream is omitted, then Warp selects a stream based on the following rules:
6040
6268
  (1) If the destination array is on a CUDA device, use the current stream on the destination device.
6041
6269
  (2) Otherwise, if the source array is on a CUDA device, use the current stream on the source device.
6042
6270
 
6043
6271
  If neither source nor destination is on a CUDA device, no stream is used for the copy.
6044
-
6045
6272
  """
6046
6273
 
6047
6274
  from warp.context import runtime
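
A sketch of the offset/count/stream parameters described above (sizes and the explicit stream are illustrative):

    import warp as wp

    wp.init()

    src = wp.full(8, 3.0, dtype=float, device="cuda:0")
    dst = wp.zeros(16, dtype=float, device="cpu")

    # copy 4 elements of src into dst starting at destination element 2,
    # using an explicit stream on the source device
    stream = wp.Stream("cuda:0")
    wp.copy(dst, src, dest_offset=2, src_offset=0, count=4, stream=stream)
    wp.synchronize_stream(stream)
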
@@ -6266,8 +6493,8 @@ def type_str(t):
6266
6493
  return f"Transformation[{type_str(t._wp_scalar_type_)}]"
6267
6494
 
6268
6495
  raise TypeError("Invalid vector or matrix dimensions")
6269
- elif warp.codegen.get_type_origin(t) in (list, tuple):
6270
- args_repr = ", ".join(type_str(x) for x in warp.codegen.get_type_args(t))
6496
+ elif get_origin(t) in (list, tuple):
6497
+ args_repr = ", ".join(type_str(x) for x in get_args(t))
6271
6498
  return f"{t._name}[{args_repr}]"
6272
6499
  elif t is Ellipsis:
6273
6500
  return "..."
@@ -6423,6 +6650,26 @@ def export_functions_rst(file): # pragma: no cover
6423
6650
  def export_stubs(file): # pragma: no cover
6424
6651
  """Generates stub file for auto-complete of builtin functions"""
6425
6652
 
6653
+ # Add copyright notice
6654
+ print(
6655
+ """# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
6656
+ # SPDX-License-Identifier: Apache-2.0
6657
+ #
6658
+ # Licensed under the Apache License, Version 2.0 (the "License");
6659
+ # you may not use this file except in compliance with the License.
6660
+ # You may obtain a copy of the License at
6661
+ #
6662
+ # http://www.apache.org/licenses/LICENSE-2.0
6663
+ #
6664
+ # Unless required by applicable law or agreed to in writing, software
6665
+ # distributed under the License is distributed on an "AS IS" BASIS,
6666
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6667
+ # See the License for the specific language governing permissions and
6668
+ # limitations under the License.
6669
+ """,
6670
+ file=file,
6671
+ )
6672
+
6426
6673
  print(
6427
6674
  "# Autogenerated file, do not edit, this file provides stubs for builtins autocomplete in VSCode, PyCharm, etc",
6428
6675
  file=file,