warp-lang 1.9.0__py3-none-manylinux_2_34_aarch64.whl → 1.10.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry's advisory page for more details.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2302 -307
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/{builtins.py → _src/builtins.py} +1546 -224
  8. warp/_src/codegen.py +4361 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/{fem → _src/fem}/domain.py +42 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +32 -15
  22. warp/{fem → _src/fem}/field/restriction.py +3 -1
  23. warp/{fem → _src/fem}/field/virtual.py +55 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
  28. warp/{fem → _src/fem}/geometry/element.py +34 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +50 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
  34. warp/{fem → _src/fem}/geometry/partition.py +123 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
  38. warp/{fem → _src/fem}/integrate.py +166 -158
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
  46. warp/_src/fem/space/basis_space.py +681 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
  48. warp/{fem → _src/fem}/space/function_space.py +16 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
  53. warp/{fem → _src/fem}/space/partition.py +119 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
  55. warp/{fem → _src/fem}/space/restriction.py +68 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
  63. warp/_src/fem/space/topology.py +461 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
  65. warp/_src/fem/types.py +114 -0
  66. warp/_src/fem/utils.py +488 -0
  67. warp/_src/jax.py +188 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +389 -0
  70. warp/_src/jax_experimental/ffi.py +1286 -0
  71. warp/_src/jax_experimental/xla_ffi.py +658 -0
  72. warp/_src/marching_cubes.py +710 -0
  73. warp/_src/math.py +416 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +165 -0
  76. warp/_src/optim/linear.py +1608 -0
  77. warp/_src/optim/sgd.py +114 -0
  78. warp/_src/paddle.py +408 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +291 -0
  81. warp/_src/render/render_opengl.py +3638 -0
  82. warp/_src/render/render_usd.py +939 -0
  83. warp/_src/render/utils.py +162 -0
  84. warp/_src/sparse.py +2718 -0
  85. warp/_src/tape.py +1208 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +393 -0
  88. warp/_src/types.py +5888 -0
  89. warp/_src/utils.py +1695 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +3 -3
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +581 -280
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +18 -17
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +580 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
warp/native/warp.cu CHANGED
@@ -19,6 +19,7 @@
19
19
  #include "scan.h"
20
20
  #include "cuda_util.h"
21
21
  #include "error.h"
22
+ #include "sort.h"
22
23
 
23
24
  #include <cstdlib>
24
25
  #include <fstream>
@@ -37,6 +38,7 @@
37
38
  #include <iterator>
38
39
  #include <list>
39
40
  #include <map>
41
+ #include <mutex>
40
42
  #include <string>
41
43
  #include <unordered_map>
42
44
  #include <unordered_set>
@@ -175,11 +177,20 @@ struct ContextInfo
175
177
  CUmodule conditional_module = NULL;
176
178
  };
177
179
 
180
+ // Information used for freeing allocations.
181
+ struct FreeInfo
182
+ {
183
+ void* context = NULL;
184
+ void* ptr = NULL;
185
+ bool is_async = false;
186
+ };
187
+
178
188
  struct CaptureInfo
179
189
  {
180
190
  CUstream stream = NULL; // the main stream where capture begins and ends
181
191
  uint64_t id = 0; // unique capture id from CUDA
182
192
  bool external = false; // whether this is an external capture
193
+ std::vector<FreeInfo> tmp_allocs; // temporary allocations owned by the graph (e.g., staged array fill values)
183
194
  };
184
195
 
185
196
  struct StreamInfo
@@ -188,9 +199,13 @@ struct StreamInfo
188
199
  CaptureInfo* capture = NULL; // capture info (only if started on this stream)
189
200
  };
190
201
 
191
- struct GraphInfo
202
+ // Extra resources tied to a graph, freed after the graph is released by CUDA.
203
+ // Used with the on_graph_destroy() callback.
204
+ struct GraphDestroyCallbackInfo
192
205
  {
193
- std::vector<void*> unfreed_allocs;
206
+ void* context = NULL; // graph CUDA context
207
+ std::vector<void*> unfreed_allocs; // graph allocations not freed by the graph
208
+ std::vector<FreeInfo> tmp_allocs; // temporary allocations owned by the graph (e.g., staged array fill values)
194
209
  };
195
210
 
196
211
  // Information for graph allocations that are not freed by the graph.
@@ -206,19 +221,19 @@ struct GraphAllocInfo
206
221
  bool graph_destroyed = false; // whether graph instance was destroyed
207
222
  };
208
223
 
209
- // Information used when deferring deallocations.
210
- struct FreeInfo
224
+ // Information used when deferring module unloading.
225
+ struct ModuleInfo
211
226
  {
212
227
  void* context = NULL;
213
- void* ptr = NULL;
214
- bool is_async = false;
228
+ void* module = NULL;
215
229
  };
216
230
 
217
- // Information used when deferring module unloading.
218
- struct ModuleInfo
231
+ // Information used when deferring graph destruction.
232
+ struct GraphDestroyInfo
219
233
  {
220
234
  void* context = NULL;
221
- void* module = NULL;
235
+ void* graph = NULL;
236
+ void* graph_exec = NULL;
222
237
  };
223
238
 
224
239
  static std::unordered_map<CUfunction, std::string> g_kernel_names;
@@ -252,6 +267,15 @@ static std::vector<FreeInfo> g_deferred_free_list;
252
267
  // Call unload_deferred_modules() to release.
253
268
  static std::vector<ModuleInfo> g_deferred_module_list;
254
269
 
270
+ // Graphs that cannot be destroyed immediately get queued here.
271
+ // Call destroy_deferred_graphs() to release.
272
+ static std::vector<GraphDestroyInfo> g_deferred_graph_list;
273
+
274
+ // Data from on_graph_destroy() callbacks that run on a different thread.
275
+ static std::vector<GraphDestroyCallbackInfo*> g_deferred_graph_destroy_list;
276
+ static std::mutex g_graph_destroy_mutex;
277
+
278
+
255
279
  void wp_cuda_set_context_restore_policy(bool always_restore)
256
280
  {
257
281
  ContextGuard::always_restore = always_restore;
@@ -337,7 +361,7 @@ int cuda_init()
337
361
  }
338
362
 
339
363
 
340
- static inline CUcontext get_current_context()
364
+ CUcontext get_current_context()
341
365
  {
342
366
  CUcontext ctx;
343
367
  if (check_cu(cuCtxGetCurrent_f(&ctx)))
@@ -407,6 +431,114 @@ static inline StreamInfo* get_stream_info(CUstream stream)
407
431
  return NULL;
408
432
  }
409
433
 
434
+ static inline CaptureInfo* get_capture_info(CUstream stream)
435
+ {
436
+ if (!g_captures.empty() && wp_cuda_stream_is_capturing(stream))
437
+ {
438
+ uint64_t capture_id = get_capture_id(stream);
439
+ auto capture_iter = g_captures.find(capture_id);
440
+ if (capture_iter != g_captures.end())
441
+ return capture_iter->second;
442
+ }
443
+ return NULL;
444
+ }
445
+
446
+ // helper function to copy a value to device memory in a graph-friendly way
447
+ static bool capturable_tmp_alloc(void* context, const void* data, size_t size, void** devptr_ret, bool* free_devptr_ret)
448
+ {
449
+ ContextGuard guard(context);
450
+
451
+ CUstream stream = get_current_stream();
452
+ CaptureInfo* capture_info = get_capture_info(stream);
453
+ int device_ordinal = wp_cuda_context_get_device_ordinal(context);
454
+ void* devptr = NULL;
455
+ bool free_devptr = true;
456
+
457
+ if (capture_info)
458
+ {
459
+ // ongoing graph capture - need to stage the fill value so that it persists with the graph
460
+ if (CUDA_VERSION >= 12040 && wp_cuda_driver_version() >= 12040)
461
+ {
462
+ // pause the capture so that the alloc/memcpy won't be captured
463
+ void* graph = NULL;
464
+ if (!wp_cuda_graph_pause_capture(WP_CURRENT_CONTEXT, stream, &graph))
465
+ return false;
466
+
467
+ // copy value to device memory
468
+ devptr = wp_alloc_device(WP_CURRENT_CONTEXT, size);
469
+ if (!devptr)
470
+ {
471
+ fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cuda:%d' (in function %s)\n", (unsigned long long)size, device_ordinal, __FUNCTION__);
472
+ return false;
473
+ }
474
+ if (!check_cuda(cudaMemcpyAsync(devptr, data, size, cudaMemcpyHostToDevice, stream)))
475
+ return false;
476
+
477
+ // graph takes ownership of the value storage
478
+ FreeInfo free_info;
479
+ free_info.context = context ? context : get_current_context();
480
+ free_info.ptr = devptr;
481
+ free_info.is_async = wp_cuda_device_is_mempool_supported(device_ordinal);
482
+
483
+ // allocation will be freed when graph is destroyed
484
+ capture_info->tmp_allocs.push_back(free_info);
485
+
486
+ // resume the capture
487
+ if (!wp_cuda_graph_resume_capture(WP_CURRENT_CONTEXT, stream, graph))
488
+ return false;
489
+
490
+ free_devptr = false; // memory is owned by the graph, doesn't need to be freed
491
+ }
492
+ else
493
+ {
494
+ // older CUDA can't pause/resume the capture, so stage in CPU memory
495
+ void* hostptr = wp_alloc_host(size);
496
+ if (!hostptr)
497
+ {
498
+ fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cpu' (in function %s)\n", (unsigned long long)size, __FUNCTION__);
499
+ return false;
500
+ }
501
+ memcpy(hostptr, data, size);
502
+
503
+ // the device allocation and h2d copy will be captured in the graph
504
+ devptr = wp_alloc_device(WP_CURRENT_CONTEXT, size);
505
+ if (!devptr)
506
+ {
507
+ fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cuda:%d' (in function %s)\n", (unsigned long long)size, device_ordinal, __FUNCTION__);
508
+ return false;
509
+ }
510
+ if (!check_cuda(cudaMemcpyAsync(devptr, hostptr, size, cudaMemcpyHostToDevice, stream)))
511
+ return false;
512
+
513
+ // graph takes ownership of the value storage
514
+ FreeInfo free_info;
515
+ free_info.context = NULL;
516
+ free_info.ptr = hostptr;
517
+ free_info.is_async = false;
518
+
519
+ // allocation will be freed when graph is destroyed
520
+ capture_info->tmp_allocs.push_back(free_info);
521
+ }
522
+ }
523
+ else
524
+ {
525
+ // not capturing, copy the value to device memory
526
+ devptr = wp_alloc_device(WP_CURRENT_CONTEXT, size);
527
+ if (!devptr)
528
+ {
529
+ fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cuda:%d' (in function %s)\n", (unsigned long long)size, device_ordinal, __FUNCTION__);
530
+ return false;
531
+ }
532
+ if (!check_cuda(cudaMemcpyAsync(devptr, data, size, cudaMemcpyHostToDevice, stream)))
533
+ return false;
534
+ }
535
+
536
+ *devptr_ret = devptr;
537
+ *free_devptr_ret = free_devptr;
538
+
539
+ return true;
540
+ }
541
+
410
542
  static void deferred_free(void* ptr, void* context, bool is_async)
411
543
  {
412
544
  FreeInfo free_info;
@@ -494,34 +626,124 @@ static int unload_deferred_modules(void* context = NULL)
494
626
  return num_unloaded_modules;
495
627
  }
496
628
 
497
- static void CUDART_CB on_graph_destroy(void* user_data)
629
+ static int destroy_deferred_graphs(void* context = NULL)
498
630
  {
499
- if (!user_data)
500
- return;
631
+ if (g_deferred_graph_list.empty() || !g_captures.empty())
632
+ return 0;
633
+
634
+ int num_destroyed_graphs = 0;
635
+ for (auto it = g_deferred_graph_list.begin(); it != g_deferred_graph_list.end(); /*noop*/)
636
+ {
637
+ // destroy the graph if it matches the given context or if the context is unspecified
638
+ const GraphDestroyInfo& graph_info = *it;
639
+ if (graph_info.context == context || !context)
640
+ {
641
+ if (graph_info.graph)
642
+ {
643
+ check_cuda(cudaGraphDestroy((cudaGraph_t)graph_info.graph));
644
+ }
645
+ if (graph_info.graph_exec)
646
+ {
647
+ check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_info.graph_exec));
648
+ }
649
+ ++num_destroyed_graphs;
650
+ it = g_deferred_graph_list.erase(it);
651
+ }
652
+ else
653
+ {
654
+ ++it;
655
+ }
656
+ }
657
+
658
+ return num_destroyed_graphs;
659
+ }
660
+
661
+ static int process_deferred_graph_destroy_callbacks(void* context = NULL)
662
+ {
663
+ int num_freed = 0;
501
664
 
502
- GraphInfo* graph_info = static_cast<GraphInfo*>(user_data);
665
+ std::lock_guard<std::mutex> lock(g_graph_destroy_mutex);
503
666
 
504
- for (void* ptr : graph_info->unfreed_allocs)
667
+ for (auto it = g_deferred_graph_destroy_list.begin(); it != g_deferred_graph_destroy_list.end(); /*noop*/)
505
668
  {
506
- auto alloc_iter = g_graph_allocs.find(ptr);
507
- if (alloc_iter != g_graph_allocs.end())
669
+ GraphDestroyCallbackInfo* graph_info = *it;
670
+ if (graph_info->context == context || !context)
508
671
  {
509
- GraphAllocInfo& alloc_info = alloc_iter->second;
510
- if (alloc_info.ref_exists)
672
+ // handle unfreed graph allocations (may have outstanding user references)
673
+ for (void* ptr : graph_info->unfreed_allocs)
511
674
  {
512
- // unreference from graph so the pointer will be deallocated when the user reference goes away
513
- alloc_info.graph_destroyed = true;
675
+ auto alloc_iter = g_graph_allocs.find(ptr);
676
+ if (alloc_iter != g_graph_allocs.end())
677
+ {
678
+ GraphAllocInfo& alloc_info = alloc_iter->second;
679
+ if (alloc_info.ref_exists)
680
+ {
681
+ // unreference from graph so the pointer will be deallocated when the user reference goes away
682
+ alloc_info.graph_destroyed = true;
683
+ }
684
+ else
685
+ {
686
+ // the pointer can be freed, no references remain
687
+ wp_free_device_async(alloc_info.context, ptr);
688
+ g_graph_allocs.erase(alloc_iter);
689
+ }
690
+ }
514
691
  }
515
- else
692
+
693
+ // handle temporary allocations owned by the graph (no user references)
694
+ for (const FreeInfo& tmp_info : graph_info->tmp_allocs)
516
695
  {
517
- // the pointer can be freed, but we can't call CUDA functions in this callback, so defer it
518
- deferred_free(ptr, alloc_info.context, true);
519
- g_graph_allocs.erase(alloc_iter);
696
+ if (tmp_info.context)
697
+ {
698
+ // GPU alloc
699
+ if (tmp_info.is_async)
700
+ {
701
+ wp_free_device_async(tmp_info.context, tmp_info.ptr);
702
+ }
703
+ else
704
+ {
705
+ wp_free_device_default(tmp_info.context, tmp_info.ptr);
706
+ }
707
+ }
708
+ else
709
+ {
710
+ // CPU alloc
711
+ wp_free_host(tmp_info.ptr);
712
+ }
520
713
  }
714
+
715
+ ++num_freed;
716
+ delete graph_info;
717
+ it = g_deferred_graph_destroy_list.erase(it);
718
+ }
719
+ else
720
+ {
721
+ ++it;
521
722
  }
522
723
  }
523
724
 
524
- delete graph_info;
725
+ return num_freed;
726
+ }
727
+
728
+ static int run_deferred_actions(void* context = NULL)
729
+ {
730
+ int num_actions = 0;
731
+ num_actions += free_deferred_allocs(context);
732
+ num_actions += unload_deferred_modules(context);
733
+ num_actions += destroy_deferred_graphs(context);
734
+ num_actions += process_deferred_graph_destroy_callbacks(context);
735
+ return num_actions;
736
+ }
737
+
738
+ // Callback used when a graph is destroyed.
739
+ // NOTE: this runs on an internal CUDA thread and requires synchronization.
740
+ static void CUDART_CB on_graph_destroy(void* user_data)
741
+ {
742
+ if (user_data)
743
+ {
744
+ std::lock_guard<std::mutex> lock(g_graph_destroy_mutex);
745
+ g_deferred_graph_destroy_list.push_back(static_cast<GraphDestroyCallbackInfo*>(user_data));
746
+ }
525
747
  }
526
748
 
527
749
  static inline const char* get_cuda_kernel_name(void* kernel)
@@ -973,30 +1195,36 @@ void wp_memtile_device(void* context, void* dst, const void* src, size_t srcsize
973
1195
  else
974
1196
  {
975
1197
  // generic version
1198
+ void* value_devptr = NULL; // fill value in device memory
1199
+ bool free_devptr = true; // whether we need to free the memory
976
1200
 
977
- // copy value to device memory
978
- // TODO: use a persistent stream-local staging buffer to avoid allocs?
979
- void* src_devptr = wp_alloc_device(WP_CURRENT_CONTEXT, srcsize);
980
- check_cuda(cudaMemcpyAsync(src_devptr, src, srcsize, cudaMemcpyHostToDevice, get_current_stream()));
981
-
982
- wp_launch_device(WP_CURRENT_CONTEXT, memtile_kernel, n, (dst, src_devptr, srcsize, n));
1201
+ // prepare the fill value in a graph-friendly way
1202
+ if (!capturable_tmp_alloc(WP_CURRENT_CONTEXT, src, srcsize, &value_devptr, &free_devptr))
1203
+ {
1204
+ fprintf(stderr, "Warp fill error: failed to copy value to device memory\n");
1205
+ return;
1206
+ }
983
1207
 
984
- wp_free_device(WP_CURRENT_CONTEXT, src_devptr);
1208
+ wp_launch_device(WP_CURRENT_CONTEXT, memtile_kernel, n, (dst, value_devptr, srcsize, n));
985
1209
 
1210
+ if (free_devptr)
1211
+ {
1212
+ wp_free_device(WP_CURRENT_CONTEXT, value_devptr);
1213
+ }
986
1214
  }
987
1215
  }
988
1216
 
989
1217
 
990
1218
  static __global__ void array_copy_1d_kernel(void* dst, const void* src,
991
- int dst_stride, int src_stride,
1219
+ size_t dst_stride, size_t src_stride,
992
1220
  const int* dst_indices, const int* src_indices,
993
- int n, int elem_size)
1221
+ size_t n, size_t elem_size)
994
1222
  {
995
- int i = blockIdx.x * blockDim.x + threadIdx.x;
1223
+ size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
996
1224
  if (i < n)
997
1225
  {
998
- int src_idx = src_indices ? src_indices[i] : i;
999
- int dst_idx = dst_indices ? dst_indices[i] : i;
1226
+ size_t src_idx = src_indices ? src_indices[i] : i;
1227
+ size_t dst_idx = dst_indices ? dst_indices[i] : i;
1000
1228
  const char* p = (const char*)src + src_idx * src_stride;
1001
1229
  char* q = (char*)dst + dst_idx * dst_stride;
1002
1230
  memcpy(q, p, elem_size);
@@ -1004,20 +1232,20 @@ static __global__ void array_copy_1d_kernel(void* dst, const void* src,
1004
1232
  }
1005
1233
 
1006
1234
  static __global__ void array_copy_2d_kernel(void* dst, const void* src,
1007
- wp::vec_t<2, int> dst_strides, wp::vec_t<2, int> src_strides,
1235
+ wp::vec_t<2, size_t> dst_strides, wp::vec_t<2, size_t> src_strides,
1008
1236
  wp::vec_t<2, const int*> dst_indices, wp::vec_t<2, const int*> src_indices,
1009
- wp::vec_t<2, int> shape, int elem_size)
1237
+ wp::vec_t<2, size_t> shape, size_t elem_size)
1010
1238
  {
1011
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1012
- int n = shape[1];
1013
- int i = tid / n;
1014
- int j = tid % n;
1239
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1240
+ size_t n = shape[1];
1241
+ size_t i = tid / n;
1242
+ size_t j = tid % n;
1015
1243
  if (i < shape[0] /*&& j < shape[1]*/)
1016
1244
  {
1017
- int src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1018
- int dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1019
- int src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1020
- int dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1245
+ size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1246
+ size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1247
+ size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1248
+ size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1021
1249
  const char* p = (const char*)src + src_idx0 * src_strides[0] + src_idx1 * src_strides[1];
1022
1250
  char* q = (char*)dst + dst_idx0 * dst_strides[0] + dst_idx1 * dst_strides[1];
1023
1251
  memcpy(q, p, elem_size);
@@ -1025,24 +1253,24 @@ static __global__ void array_copy_2d_kernel(void* dst, const void* src,
1025
1253
  }
1026
1254
 
1027
1255
  static __global__ void array_copy_3d_kernel(void* dst, const void* src,
1028
- wp::vec_t<3, int> dst_strides, wp::vec_t<3, int> src_strides,
1256
+ wp::vec_t<3, size_t> dst_strides, wp::vec_t<3, size_t> src_strides,
1029
1257
  wp::vec_t<3, const int*> dst_indices, wp::vec_t<3, const int*> src_indices,
1030
- wp::vec_t<3, int> shape, int elem_size)
1031
- {
1032
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1033
- int n = shape[1];
1034
- int o = shape[2];
1035
- int i = tid / (n * o);
1036
- int j = tid % (n * o) / o;
1037
- int k = tid % o;
1258
+ wp::vec_t<3, size_t> shape, size_t elem_size)
1259
+ {
1260
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1261
+ size_t n = shape[1];
1262
+ size_t o = shape[2];
1263
+ size_t i = tid / (n * o);
1264
+ size_t j = tid % (n * o) / o;
1265
+ size_t k = tid % o;
1038
1266
  if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
1039
1267
  {
1040
- int src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1041
- int dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1042
- int src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1043
- int dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1044
- int src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1045
- int dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1268
+ size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1269
+ size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1270
+ size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1271
+ size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1272
+ size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1273
+ size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1046
1274
  const char* p = (const char*)src + src_idx0 * src_strides[0]
1047
1275
  + src_idx1 * src_strides[1]
1048
1276
  + src_idx2 * src_strides[2];
@@ -1054,28 +1282,28 @@ static __global__ void array_copy_3d_kernel(void* dst, const void* src,
1054
1282
  }
1055
1283
 
1056
1284
  static __global__ void array_copy_4d_kernel(void* dst, const void* src,
1057
- wp::vec_t<4, int> dst_strides, wp::vec_t<4, int> src_strides,
1285
+ wp::vec_t<4, size_t> dst_strides, wp::vec_t<4, size_t> src_strides,
1058
1286
  wp::vec_t<4, const int*> dst_indices, wp::vec_t<4, const int*> src_indices,
1059
- wp::vec_t<4, int> shape, int elem_size)
1060
- {
1061
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1062
- int n = shape[1];
1063
- int o = shape[2];
1064
- int p = shape[3];
1065
- int i = tid / (n * o * p);
1066
- int j = tid % (n * o * p) / (o * p);
1067
- int k = tid % (o * p) / p;
1068
- int l = tid % p;
1287
+ wp::vec_t<4, size_t> shape, size_t elem_size)
1288
+ {
1289
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1290
+ size_t n = shape[1];
1291
+ size_t o = shape[2];
1292
+ size_t p = shape[3];
1293
+ size_t i = tid / (n * o * p);
1294
+ size_t j = tid % (n * o * p) / (o * p);
1295
+ size_t k = tid % (o * p) / p;
1296
+ size_t l = tid % p;
1069
1297
  if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
1070
1298
  {
1071
- int src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1072
- int dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1073
- int src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1074
- int dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1075
- int src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1076
- int dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1077
- int src_idx3 = src_indices[3] ? src_indices[3][l] : l;
1078
- int dst_idx3 = dst_indices[3] ? dst_indices[3][l] : l;
1299
+ size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1300
+ size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1301
+ size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1302
+ size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1303
+ size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1304
+ size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1305
+ size_t src_idx3 = src_indices[3] ? src_indices[3][l] : l;
1306
+ size_t dst_idx3 = dst_indices[3] ? dst_indices[3][l] : l;
1079
1307
  const char* p = (const char*)src + src_idx0 * src_strides[0]
1080
1308
  + src_idx1 * src_strides[1]
1081
1309
  + src_idx2 * src_strides[2]
@@ -1090,14 +1318,14 @@ static __global__ void array_copy_4d_kernel(void* dst, const void* src,
1090
1318
 
1091
1319
 
1092
1320
  static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src,
1093
- void* dst_data, int dst_stride, const int* dst_indices,
1094
- int elem_size)
1321
+ void* dst_data, size_t dst_stride, const int* dst_indices,
1322
+ size_t elem_size)
1095
1323
  {
1096
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1324
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1097
1325
 
1098
1326
  if (tid < src.size)
1099
1327
  {
1100
- int dst_idx = dst_indices ? dst_indices[tid] : tid;
1328
+ size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
1101
1329
  void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
1102
1330
  const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
1103
1331
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1105,15 +1333,15 @@ static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src
1105
1333
  }
1106
1334
 
1107
1335
  static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricarray_t<void> src,
1108
- void* dst_data, int dst_stride, const int* dst_indices,
1109
- int elem_size)
1336
+ void* dst_data, size_t dst_stride, const int* dst_indices,
1337
+ size_t elem_size)
1110
1338
  {
1111
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1339
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1112
1340
 
1113
1341
  if (tid < src.size)
1114
1342
  {
1115
- int src_index = src.indices[tid];
1116
- int dst_idx = dst_indices ? dst_indices[tid] : tid;
1343
+ size_t src_index = src.indices[tid];
1344
+ size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
1117
1345
  void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
1118
1346
  const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
1119
1347
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1121,14 +1349,14 @@ static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricar
1121
1349
  }
1122
1350
 
1123
1351
  static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
1124
- const void* src_data, int src_stride, const int* src_indices,
1125
- int elem_size)
1352
+ const void* src_data, size_t src_stride, const int* src_indices,
1353
+ size_t elem_size)
1126
1354
  {
1127
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1355
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1128
1356
 
1129
1357
  if (tid < dst.size)
1130
1358
  {
1131
- int src_idx = src_indices ? src_indices[tid] : tid;
1359
+ size_t src_idx = src_indices ? src_indices[tid] : tid;
1132
1360
  const void* src_ptr = (const char*)src_data + src_idx * src_stride;
1133
1361
  void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
1134
1362
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1136,25 +1364,25 @@ static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
1136
1364
  }
1137
1365
 
1138
1366
  static __global__ void array_copy_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst,
1139
- const void* src_data, int src_stride, const int* src_indices,
1140
- int elem_size)
1367
+ const void* src_data, size_t src_stride, const int* src_indices,
1368
+ size_t elem_size)
1141
1369
  {
1142
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1370
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1143
1371
 
1144
1372
  if (tid < dst.size)
1145
1373
  {
1146
- int src_idx = src_indices ? src_indices[tid] : tid;
1374
+ size_t src_idx = src_indices ? src_indices[tid] : tid;
1147
1375
  const void* src_ptr = (const char*)src_data + src_idx * src_stride;
1148
- int dst_idx = dst.indices[tid];
1376
+ size_t dst_idx = dst.indices[tid];
1149
1377
  void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
1150
1378
  memcpy(dst_ptr, src_ptr, elem_size);
1151
1379
  }
1152
1380
  }
1153
1381
 
1154
1382
 
1155
- static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src, int elem_size)
1383
+ static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
1156
1384
  {
1157
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1385
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1158
1386
 
1159
1387
  if (tid < dst.size)
1160
1388
  {
@@ -1165,27 +1393,27 @@ static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void
1165
1393
  }
1166
1394
 
1167
1395
 
1168
- static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src, int elem_size)
1396
+ static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
1169
1397
  {
1170
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1398
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1171
1399
 
1172
1400
  if (tid < dst.size)
1173
1401
  {
1174
1402
  const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
1175
- int dst_index = dst.indices[tid];
1403
+ size_t dst_index = dst.indices[tid];
1176
1404
  void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
1177
1405
  memcpy(dst_ptr, src_ptr, elem_size);
1178
1406
  }
1179
1407
  }
1180
1408
 
1181
1409
 
1182
- static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, int elem_size)
1410
+ static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
1183
1411
  {
1184
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1412
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1185
1413
 
1186
1414
  if (tid < dst.size)
1187
1415
  {
1188
- int src_index = src.indices[tid];
1416
+ size_t src_index = src.indices[tid];
1189
1417
  const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
1190
1418
  void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
1191
1419
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1193,14 +1421,14 @@ static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarra
1193
1421
  }
1194
1422
 
1195
1423
 
1196
- static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, int elem_size)
1424
+ static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
1197
1425
  {
1198
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1426
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1199
1427
 
1200
1428
  if (tid < dst.size)
1201
1429
  {
1202
- int src_index = src.indices[tid];
1203
- int dst_index = dst.indices[tid];
1430
+ size_t src_index = src.indices[tid];
1431
+ size_t dst_index = dst.indices[tid];
1204
1432
  const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
1205
1433
  void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
1206
1434
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1439,9 +1667,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
1439
1667
  }
1440
1668
  case 2:
1441
1669
  {
1442
- wp::vec_t<2, int> shape_v(src_shape[0], src_shape[1]);
1443
- wp::vec_t<2, int> src_strides_v(src_strides[0], src_strides[1]);
1444
- wp::vec_t<2, int> dst_strides_v(dst_strides[0], dst_strides[1]);
1670
+ wp::vec_t<2, size_t> shape_v(src_shape[0], src_shape[1]);
1671
+ wp::vec_t<2, size_t> src_strides_v(src_strides[0], src_strides[1]);
1672
+ wp::vec_t<2, size_t> dst_strides_v(dst_strides[0], dst_strides[1]);
1445
1673
  wp::vec_t<2, const int*> src_indices_v(src_indices[0], src_indices[1]);
1446
1674
  wp::vec_t<2, const int*> dst_indices_v(dst_indices[0], dst_indices[1]);
1447
1675
 
@@ -1453,9 +1681,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
1453
1681
  }
1454
1682
  case 3:
1455
1683
  {
1456
- wp::vec_t<3, int> shape_v(src_shape[0], src_shape[1], src_shape[2]);
1457
- wp::vec_t<3, int> src_strides_v(src_strides[0], src_strides[1], src_strides[2]);
1458
- wp::vec_t<3, int> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2]);
1684
+ wp::vec_t<3, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2]);
1685
+ wp::vec_t<3, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2]);
1686
+ wp::vec_t<3, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2]);
1459
1687
  wp::vec_t<3, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2]);
1460
1688
  wp::vec_t<3, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2]);
1461
1689
 
@@ -1467,9 +1695,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
1467
1695
  }
1468
1696
  case 4:
1469
1697
  {
1470
- wp::vec_t<4, int> shape_v(src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
1471
- wp::vec_t<4, int> src_strides_v(src_strides[0], src_strides[1], src_strides[2], src_strides[3]);
1472
- wp::vec_t<4, int> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2], dst_strides[3]);
1698
+ wp::vec_t<4, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
1699
+ wp::vec_t<4, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2], src_strides[3]);
1700
+ wp::vec_t<4, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2], dst_strides[3]);
1473
1701
  wp::vec_t<4, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2], src_indices[3]);
1474
1702
  wp::vec_t<4, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2], dst_indices[3]);
1475
1703
 
@@ -1489,94 +1717,94 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
1489
1717
 
1490
1718
 
1491
1719
  static __global__ void array_fill_1d_kernel(void* data,
1492
- int n,
1493
- int stride,
1720
+ size_t n,
1721
+ size_t stride,
1494
1722
  const int* indices,
1495
1723
  const void* value,
1496
- int value_size)
1724
+ size_t value_size)
1497
1725
  {
1498
- int i = blockIdx.x * blockDim.x + threadIdx.x;
1726
+ size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1499
1727
  if (i < n)
1500
1728
  {
1501
- int idx = indices ? indices[i] : i;
1729
+ size_t idx = indices ? indices[i] : i;
1502
1730
  char* p = (char*)data + idx * stride;
1503
1731
  memcpy(p, value, value_size);
1504
1732
  }
1505
1733
  }
1506
1734
 
1507
1735
  static __global__ void array_fill_2d_kernel(void* data,
1508
- wp::vec_t<2, int> shape,
1509
- wp::vec_t<2, int> strides,
1736
+ wp::vec_t<2, size_t> shape,
1737
+ wp::vec_t<2, size_t> strides,
1510
1738
  wp::vec_t<2, const int*> indices,
1511
1739
  const void* value,
1512
- int value_size)
1740
+ size_t value_size)
1513
1741
  {
1514
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1515
- int n = shape[1];
1516
- int i = tid / n;
1517
- int j = tid % n;
1742
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1743
+ size_t n = shape[1];
1744
+ size_t i = tid / n;
1745
+ size_t j = tid % n;
1518
1746
  if (i < shape[0] /*&& j < shape[1]*/)
1519
1747
  {
1520
- int idx0 = indices[0] ? indices[0][i] : i;
1521
- int idx1 = indices[1] ? indices[1][j] : j;
1748
+ size_t idx0 = indices[0] ? indices[0][i] : i;
1749
+ size_t idx1 = indices[1] ? indices[1][j] : j;
1522
1750
  char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1];
1523
1751
  memcpy(p, value, value_size);
1524
1752
  }
1525
1753
  }
1526
1754
 
1527
1755
  static __global__ void array_fill_3d_kernel(void* data,
1528
- wp::vec_t<3, int> shape,
1529
- wp::vec_t<3, int> strides,
1756
+ wp::vec_t<3, size_t> shape,
1757
+ wp::vec_t<3, size_t> strides,
1530
1758
  wp::vec_t<3, const int*> indices,
1531
1759
  const void* value,
1532
- int value_size)
1533
- {
1534
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1535
- int n = shape[1];
1536
- int o = shape[2];
1537
- int i = tid / (n * o);
1538
- int j = tid % (n * o) / o;
1539
- int k = tid % o;
1760
+ size_t value_size)
1761
+ {
1762
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1763
+ size_t n = shape[1];
1764
+ size_t o = shape[2];
1765
+ size_t i = tid / (n * o);
1766
+ size_t j = tid % (n * o) / o;
1767
+ size_t k = tid % o;
1540
1768
  if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
1541
1769
  {
1542
- int idx0 = indices[0] ? indices[0][i] : i;
1543
- int idx1 = indices[1] ? indices[1][j] : j;
1544
- int idx2 = indices[2] ? indices[2][k] : k;
1770
+ size_t idx0 = indices[0] ? indices[0][i] : i;
1771
+ size_t idx1 = indices[1] ? indices[1][j] : j;
1772
+ size_t idx2 = indices[2] ? indices[2][k] : k;
1545
1773
  char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2];
1546
1774
  memcpy(p, value, value_size);
1547
1775
  }
1548
1776
  }
1549
1777
 
1550
1778
  static __global__ void array_fill_4d_kernel(void* data,
1551
- wp::vec_t<4, int> shape,
1552
- wp::vec_t<4, int> strides,
1779
+ wp::vec_t<4, size_t> shape,
1780
+ wp::vec_t<4, size_t> strides,
1553
1781
  wp::vec_t<4, const int*> indices,
1554
1782
  const void* value,
1555
- int value_size)
1556
- {
1557
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1558
- int n = shape[1];
1559
- int o = shape[2];
1560
- int p = shape[3];
1561
- int i = tid / (n * o * p);
1562
- int j = tid % (n * o * p) / (o * p);
1563
- int k = tid % (o * p) / p;
1564
- int l = tid % p;
1783
+ size_t value_size)
1784
+ {
1785
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1786
+ size_t n = shape[1];
1787
+ size_t o = shape[2];
1788
+ size_t p = shape[3];
1789
+ size_t i = tid / (n * o * p);
1790
+ size_t j = tid % (n * o * p) / (o * p);
1791
+ size_t k = tid % (o * p) / p;
1792
+ size_t l = tid % p;
1565
1793
  if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
1566
1794
  {
1567
- int idx0 = indices[0] ? indices[0][i] : i;
1568
- int idx1 = indices[1] ? indices[1][j] : j;
1569
- int idx2 = indices[2] ? indices[2][k] : k;
1570
- int idx3 = indices[3] ? indices[3][l] : l;
1795
+ size_t idx0 = indices[0] ? indices[0][i] : i;
1796
+ size_t idx1 = indices[1] ? indices[1][j] : j;
1797
+ size_t idx2 = indices[2] ? indices[2][k] : k;
1798
+ size_t idx3 = indices[3] ? indices[3][l] : l;
1571
1799
  char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2] + idx3 * strides[3];
1572
1800
  memcpy(p, value, value_size);
1573
1801
  }
1574
1802
  }
1575
1803
 
1576
1804
 
1577
- static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value, int value_size)
1805
+ static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value, size_t value_size)
1578
1806
  {
1579
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1807
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1580
1808
  if (tid < fa.size)
1581
1809
  {
1582
1810
  void* dst_ptr = fabricarray_element_ptr(fa, tid, value_size);
@@ -1585,9 +1813,9 @@ static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, cons
1585
1813
  }
1586
1814
 
1587
1815
 
1588
- static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value, int value_size)
1816
+ static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value, size_t value_size)
1589
1817
  {
1590
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1818
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1591
1819
  if (tid < ifa.size)
1592
1820
  {
1593
1821
  size_t idx = size_t(ifa.indices[tid]);
@@ -1654,67 +1882,76 @@ WP_API void wp_array_fill_device(void* context, void* arr_ptr, int arr_type, con
1654
1882
 
1655
1883
  ContextGuard guard(context);
1656
1884
 
1657
- // copy value to device memory
1658
- // TODO: use a persistent stream-local staging buffer to avoid allocs?
1659
- void* value_devptr = wp_alloc_device(WP_CURRENT_CONTEXT, value_size);
1660
- check_cuda(cudaMemcpyAsync(value_devptr, value_ptr, value_size, cudaMemcpyHostToDevice, get_current_stream()));
1885
+ void* value_devptr = NULL; // fill value in device memory
1886
+ bool free_devptr = true; // whether we need to free the memory
1887
+
1888
+ // prepare the fill value in a graph-friendly way
1889
+ if (!capturable_tmp_alloc(WP_CURRENT_CONTEXT, value_ptr, value_size, &value_devptr, &free_devptr))
1890
+ {
1891
+ fprintf(stderr, "Warp fill error: failed to copy value to device memory\n");
1892
+ return;
1893
+ }
1661
1894
 
1662
- // handle fabric arrays
1663
1895
  if (fa)
1664
1896
  {
1897
+ // handle fabric arrays
1665
1898
  wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_kernel, n,
1666
1899
  (*fa, value_devptr, value_size));
1667
- return;
1668
1900
  }
1669
1901
  else if (ifa)
1670
1902
  {
1903
+ // handle indexed fabric arrays
1671
1904
  wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_indexed_kernel, n,
1672
1905
  (*ifa, value_devptr, value_size));
1673
- return;
1674
- }
1675
-
1676
- // handle regular or indexed arrays
1677
- switch (ndim)
1678
- {
1679
- case 1:
1680
- {
1681
- wp_launch_device(WP_CURRENT_CONTEXT, array_fill_1d_kernel, n,
1682
- (data, shape[0], strides[0], indices[0], value_devptr, value_size));
1683
- break;
1684
- }
1685
- case 2:
1686
- {
1687
- wp::vec_t<2, int> shape_v(shape[0], shape[1]);
1688
- wp::vec_t<2, int> strides_v(strides[0], strides[1]);
1689
- wp::vec_t<2, const int*> indices_v(indices[0], indices[1]);
1690
- wp_launch_device(WP_CURRENT_CONTEXT, array_fill_2d_kernel, n,
1691
- (data, shape_v, strides_v, indices_v, value_devptr, value_size));
1692
- break;
1693
1906
  }
1694
- case 3:
1907
+ else
1695
1908
  {
1696
- wp::vec_t<3, int> shape_v(shape[0], shape[1], shape[2]);
1697
- wp::vec_t<3, int> strides_v(strides[0], strides[1], strides[2]);
1698
- wp::vec_t<3, const int*> indices_v(indices[0], indices[1], indices[2]);
1699
- wp_launch_device(WP_CURRENT_CONTEXT, array_fill_3d_kernel, n,
1700
- (data, shape_v, strides_v, indices_v, value_devptr, value_size));
1701
- break;
1909
+ // handle regular or indexed arrays
1910
+ switch (ndim)
1911
+ {
1912
+ case 1:
1913
+ {
1914
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_1d_kernel, n,
1915
+ (data, shape[0], strides[0], indices[0], value_devptr, value_size));
1916
+ break;
1917
+ }
1918
+ case 2:
1919
+ {
1920
+ wp::vec_t<2, size_t> shape_v(shape[0], shape[1]);
1921
+ wp::vec_t<2, size_t> strides_v(strides[0], strides[1]);
1922
+ wp::vec_t<2, const int*> indices_v(indices[0], indices[1]);
1923
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_2d_kernel, n,
1924
+ (data, shape_v, strides_v, indices_v, value_devptr, value_size));
1925
+ break;
1926
+ }
1927
+ case 3:
1928
+ {
1929
+ wp::vec_t<3, size_t> shape_v(shape[0], shape[1], shape[2]);
1930
+ wp::vec_t<3, size_t> strides_v(strides[0], strides[1], strides[2]);
1931
+ wp::vec_t<3, const int*> indices_v(indices[0], indices[1], indices[2]);
1932
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_3d_kernel, n,
1933
+ (data, shape_v, strides_v, indices_v, value_devptr, value_size));
1934
+ break;
1935
+ }
1936
+ case 4:
1937
+ {
1938
+ wp::vec_t<4, size_t> shape_v(shape[0], shape[1], shape[2], shape[3]);
1939
+ wp::vec_t<4, size_t> strides_v(strides[0], strides[1], strides[2], strides[3]);
1940
+ wp::vec_t<4, const int*> indices_v(indices[0], indices[1], indices[2], indices[3]);
1941
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_4d_kernel, n,
1942
+ (data, shape_v, strides_v, indices_v, value_devptr, value_size));
1943
+ break;
1944
+ }
1945
+ default:
1946
+ fprintf(stderr, "Warp fill error: invalid array dimensionality (%d)\n", ndim);
1947
+ break;
1948
+ }
1702
1949
  }
1703
- case 4:
1950
+
1951
+ if (free_devptr)
1704
1952
  {
1705
- wp::vec_t<4, int> shape_v(shape[0], shape[1], shape[2], shape[3]);
1706
- wp::vec_t<4, int> strides_v(strides[0], strides[1], strides[2], strides[3]);
1707
- wp::vec_t<4, const int*> indices_v(indices[0], indices[1], indices[2], indices[3]);
1708
- wp_launch_device(WP_CURRENT_CONTEXT, array_fill_4d_kernel, n,
1709
- (data, shape_v, strides_v, indices_v, value_devptr, value_size));
1710
- break;
1953
+ wp_free_device(WP_CURRENT_CONTEXT, value_devptr);
1711
1954
  }
1712
- default:
1713
- fprintf(stderr, "Warp fill error: invalid array dimensionality (%d)\n", ndim);
1714
- return;
1715
- }
1716
-
1717
- wp_free_device(WP_CURRENT_CONTEXT, value_devptr);
1718
1955
  }
1719
1956
 
1720
1957
  void wp_array_scan_int_device(uint64_t in, uint64_t out, int len, bool inclusive)
@@ -2071,14 +2308,15 @@ void wp_cuda_context_synchronize(void* context)
2071
2308
 
2072
2309
  check_cu(cuCtxSynchronize_f());
2073
2310
 
2074
- if (free_deferred_allocs(context ? context : get_current_context()) > 0)
2311
+ if (!context)
2312
+ context = get_current_context();
2313
+
2314
+ if (run_deferred_actions(context) > 0)
2075
2315
  {
2076
- // ensure deferred asynchronous deallocations complete
2316
+ // ensure deferred asynchronous operations complete
2077
2317
  check_cu(cuCtxSynchronize_f());
2078
2318
  }
2079
2319
 
2080
- unload_deferred_modules(context);
2081
-
2082
2320
  // check_cuda(cudaDeviceGraphMemTrim(wp_cuda_context_get_device_ordinal(context)));
2083
2321
  }
2084
2322
 
@@ -2448,6 +2686,9 @@ void wp_cuda_stream_destroy(void* context, void* stream)
2448
2686
 
2449
2687
  wp_cuda_stream_unregister(context, stream);
2450
2688
 
2689
+ // release temporary radix sort buffer associated with this stream
2690
+ radix_sort_release(context, stream);
2691
+
2451
2692
  check_cu(cuStreamDestroy_f(static_cast<CUstream>(stream)));
2452
2693
  }
2453
2694
 
@@ -2510,15 +2751,36 @@ void wp_cuda_stream_synchronize(void* stream)
2510
2751
  check_cu(cuStreamSynchronize_f(static_cast<CUstream>(stream)));
2511
2752
  }
2512
2753
 
2513
- void wp_cuda_stream_wait_event(void* stream, void* event)
2754
+ void wp_cuda_stream_wait_event(void* stream, void* event, bool external)
2514
2755
  {
2515
- check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), 0));
2756
+ // the external flag can only be used during graph capture
2757
+ if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
2758
+ {
2759
+ // wait for an external event during graph capture
2760
+ check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_EXTERNAL));
2761
+ }
2762
+ else
2763
+ {
2764
+ check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_DEFAULT));
2765
+ }
2516
2766
  }
2517
2767
 
2518
- void wp_cuda_stream_wait_stream(void* stream, void* other_stream, void* event)
2768
+ void wp_cuda_stream_wait_stream(void* stream, void* other_stream, void* event, bool external)
2519
2769
  {
2520
- check_cu(cuEventRecord_f(static_cast<CUevent>(event), static_cast<CUstream>(other_stream)));
2521
- check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), 0));
2770
+ unsigned record_flags = CU_EVENT_RECORD_DEFAULT;
2771
+ unsigned wait_flags = CU_EVENT_WAIT_DEFAULT;
2772
+
2773
+ // the external flag can only be used during graph capture
2774
+ if (external && !g_captures.empty())
2775
+ {
2776
+ if (wp_cuda_stream_is_capturing(other_stream))
2777
+ record_flags = CU_EVENT_RECORD_EXTERNAL;
2778
+ if (wp_cuda_stream_is_capturing(stream))
2779
+ wait_flags = CU_EVENT_WAIT_EXTERNAL;
2780
+ }
2781
+
2782
+ check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(other_stream), record_flags));
2783
+ check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), wait_flags));
2522
2784
  }
2523
2785
 
2524
2786
  int wp_cuda_stream_is_capturing(void* stream)
@@ -2571,11 +2833,12 @@ int wp_cuda_event_query(void* event)
2571
2833
  return res;
2572
2834
  }
2573
2835
 
2574
- void wp_cuda_event_record(void* event, void* stream, bool timing)
2836
+ void wp_cuda_event_record(void* event, void* stream, bool external)
2575
2837
  {
2576
- if (timing && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
2838
+ // the external flag can only be used during graph capture
2839
+ if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
2577
2840
  {
2578
- // record timing event during graph capture
2841
+ // record external event during graph capture (e.g., for timing or when explicitly specified by the user)
2579
2842
  check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(stream), CU_EVENT_RECORD_EXTERNAL));
2580
2843
  }
2581
2844
  else
@@ -2625,7 +2888,7 @@ bool wp_cuda_graph_begin_capture(void* context, void* stream, int external)
2625
2888
  else
2626
2889
  {
2627
2890
  // start the capture
2628
- if (!check_cuda(cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeGlobal)))
2891
+ if (!check_cuda(cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeThreadLocal)))
2629
2892
  return false;
2630
2893
  }
2631
2894
 
@@ -2669,6 +2932,7 @@ bool wp_cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
2669
2932
  // get capture info
2670
2933
  bool external = capture->external;
2671
2934
  uint64_t capture_id = capture->id;
2935
+ std::vector<FreeInfo> tmp_allocs = capture->tmp_allocs;
2672
2936
 
2673
2937
  // clear capture info
2674
2938
  stream_info->capture = NULL;
@@ -2738,15 +3002,17 @@ bool wp_cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
2738
3002
  unfreed_allocs.push_back(it->first);
2739
3003
  }
2740
3004
 
2741
- if (!unfreed_allocs.empty())
3005
+ if (!unfreed_allocs.empty() || !tmp_allocs.empty())
2742
3006
  {
2743
3007
  // Create a user object that will notify us when the instantiated graph is destroyed.
2744
3008
  // This works for external captures also, since we wouldn't otherwise know when
2745
3009
  // the externally-created graph instance gets deleted.
2746
3010
  // This callback is guaranteed to arrive after the graph has finished executing on the device,
2747
3011
  // not necessarily when cudaGraphExecDestroy() is called.
2748
- GraphInfo* graph_info = new GraphInfo;
3012
+ GraphDestroyCallbackInfo* graph_info = new GraphDestroyCallbackInfo;
3013
+ graph_info->context = context ? context : get_current_context();
2749
3014
  graph_info->unfreed_allocs = unfreed_allocs;
3015
+ graph_info->tmp_allocs = tmp_allocs;
2750
3016
  cudaUserObject_t user_object;
2751
3017
  check_cuda(cudaUserObjectCreate(&user_object, graph_info, on_graph_destroy, 1, cudaUserObjectNoDestructorSync));
2752
3018
  check_cuda(cudaGraphRetainUserObject(graph, user_object, 1, cudaGraphUserObjectMove));
@@ -2770,8 +3036,7 @@ bool wp_cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
2770
3036
  // process deferred free list if no more captures are ongoing
2771
3037
  if (g_captures.empty())
2772
3038
  {
2773
- free_deferred_allocs();
2774
- unload_deferred_modules();
3039
+ run_deferred_actions();
2775
3040
  }
2776
3041
 
2777
3042
  if (graph_ret)
@@ -2811,11 +3076,12 @@ bool wp_cuda_graph_create_exec(void* context, void* stream, void* graph, void**
2811
3076
  // Support for conditional graph nodes available with CUDA 12.4+.
2812
3077
  #if CUDA_VERSION >= 12040
2813
3078
 
2814
- // CUBIN data for compiled conditional modules, loaded on demand, keyed on device architecture
2815
- static std::map<int, void*> g_conditional_cubins;
3079
+ // CUBIN or PTX data for compiled conditional modules, loaded on demand, keyed on device architecture
3080
+ using ModuleKey = std::pair<int, bool>; // <arch, use_ptx>
3081
+ static std::map<ModuleKey, void*> g_conditional_modules;
2816
3082
 
2817
3083
  // Compile module with conditional helper kernels
2818
- static void* compile_conditional_module(int arch)
3084
+ static void* compile_conditional_module(int arch, bool use_ptx)
2819
3085
  {
2820
3086
  static const char* kernel_source = R"(
2821
3087
  typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
@@ -2844,8 +3110,9 @@ static void* compile_conditional_module(int arch)
2844
3110
  )";
2845
3111
 
2846
3112
  // avoid recompilation
2847
- auto it = g_conditional_cubins.find(arch);
2848
- if (it != g_conditional_cubins.end())
3113
+ ModuleKey key = {arch, use_ptx};
3114
+ auto it = g_conditional_modules.find(key);
3115
+ if (it != g_conditional_modules.end())
2849
3116
  return it->second;
2850
3117
 
2851
3118
  nvrtcProgram prog;
@@ -2853,11 +3120,23 @@ static void* compile_conditional_module(int arch)
2853
3120
  return NULL;
2854
3121
 
2855
3122
  char arch_opt[128];
2856
- snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=sm_%d", arch);
3123
+ if (use_ptx)
3124
+ snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=compute_%d", arch);
3125
+ else
3126
+ snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=sm_%d", arch);
2857
3127
 
2858
3128
  std::vector<const char*> opts;
2859
3129
  opts.push_back(arch_opt);
2860
3130
 
3131
+ const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);
3132
+ if (print_debug)
3133
+ {
3134
+ printf("NVRTC options (conditional module, arch=%d, use_ptx=%s):\n", arch, use_ptx ? "true" : "false");
3135
+ for(auto o: opts) {
3136
+ printf("%s\n", o);
3137
+ }
3138
+ }
3139
+
2861
3140
  if (!check_nvrtc(nvrtcCompileProgram(prog, int(opts.size()), opts.data())))
2862
3141
  {
2863
3142
  size_t log_size;
@@ -2874,23 +3153,37 @@ static void* compile_conditional_module(int arch)
2874
3153
  // get output
2875
3154
  char* output = NULL;
2876
3155
  size_t output_size = 0;
2877
- check_nvrtc(nvrtcGetCUBINSize(prog, &output_size));
2878
- if (output_size > 0)
3156
+
3157
+ if (use_ptx)
3158
+ {
3159
+ check_nvrtc(nvrtcGetPTXSize(prog, &output_size));
3160
+ if (output_size > 0)
3161
+ {
3162
+ output = new char[output_size];
3163
+ if (check_nvrtc(nvrtcGetPTX(prog, output)))
3164
+ g_conditional_modules[key] = output;
3165
+ }
3166
+ }
3167
+ else
2879
3168
  {
2880
- output = new char[output_size];
2881
- if (check_nvrtc(nvrtcGetCUBIN(prog, output)))
2882
- g_conditional_cubins[arch] = output;
3169
+ check_nvrtc(nvrtcGetCUBINSize(prog, &output_size));
3170
+ if (output_size > 0)
3171
+ {
3172
+ output = new char[output_size];
3173
+ if (check_nvrtc(nvrtcGetCUBIN(prog, output)))
3174
+ g_conditional_modules[key] = output;
3175
+ }
2883
3176
  }
2884
3177
 
2885
3178
  nvrtcDestroyProgram(&prog);
2886
3179
 
2887
- // return CUBIN data
3180
+ // return CUBIN or PTX data
2888
3181
  return output;
2889
3182
  }
2890
3183
 
2891
3184
 
2892
3185
  // Load module with conditional helper kernels
2893
- static CUmodule load_conditional_module(void* context)
3186
+ static CUmodule load_conditional_module(void* context, int arch, bool use_ptx)
2894
3187
  {
2895
3188
  ContextInfo* context_info = get_context_info(context);
2896
3189
  if (!context_info)
@@ -2900,17 +3193,15 @@ static CUmodule load_conditional_module(void* context)
2900
3193
  if (context_info->conditional_module)
2901
3194
  return context_info->conditional_module;
2902
3195
 
2903
- int arch = context_info->device_info->arch;
2904
-
2905
3196
  // compile if needed
2906
- void* compiled_module = compile_conditional_module(arch);
3197
+ void* compiled_module = compile_conditional_module(arch, use_ptx);
2907
3198
  if (!compiled_module)
2908
3199
  {
2909
3200
  fprintf(stderr, "Warp error: Failed to compile conditional kernels\n");
2910
3201
  return NULL;
2911
3202
  }
2912
3203
 
2913
- // load module
3204
+ // load module (handles both PTX and CUBIN data automatically)
2914
3205
  CUmodule module = NULL;
2915
3206
  if (!check_cu(cuModuleLoadDataEx_f(&module, compiled_module, 0, NULL, NULL)))
2916
3207
  {
@@ -2923,10 +3214,10 @@ static CUmodule load_conditional_module(void* context)
2923
3214
  return module;
2924
3215
  }
2925
3216
 
2926
- static CUfunction get_conditional_kernel(void* context, const char* name)
3217
+ static CUfunction get_conditional_kernel(void* context, int arch, bool use_ptx, const char* name)
2927
3218
  {
2928
3219
  // load module if needed
2929
- CUmodule module = load_conditional_module(context);
3220
+ CUmodule module = load_conditional_module(context, arch, use_ptx);
2930
3221
  if (!module)
2931
3222
  return NULL;
2932
3223
 
@@ -2966,7 +3257,7 @@ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
2966
3257
  leaf_nodes.data(),
2967
3258
  nullptr,
2968
3259
  leaf_nodes.size(),
2969
- cudaStreamCaptureModeGlobal)))
3260
+ cudaStreamCaptureModeThreadLocal)))
2970
3261
  return false;
2971
3262
 
2972
3263
  return true;
@@ -2976,7 +3267,7 @@ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
2976
3267
  // https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes/
2977
3268
  // condition is a gpu pointer
2978
3269
  // if_graph_ret and else_graph_ret should be NULL if not needed
2979
- bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
3270
+ bool wp_cuda_graph_insert_if_else(void* context, void* stream, int arch, bool use_ptx, int* condition, void** if_graph_ret, void** else_graph_ret)
2980
3271
  {
2981
3272
  bool has_if = if_graph_ret != NULL;
2982
3273
  bool has_else = else_graph_ret != NULL;
@@ -3019,9 +3310,9 @@ bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, v
3019
3310
  // (need to negate the condition if only the else branch is used)
3020
3311
  CUfunction kernel;
3021
3312
  if (has_if)
3022
- kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
3313
+ kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
3023
3314
  else
3024
- kernel = get_conditional_kernel(context, "set_conditional_else_handle_kernel");
3315
+ kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_else_handle_kernel");
3025
3316
 
3026
3317
  if (!kernel)
3027
3318
  {
@@ -3072,7 +3363,7 @@ bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, v
3072
3363
  check_cuda(cudaGraphConditionalHandleCreate(&if_handle, cuda_graph));
3073
3364
  check_cuda(cudaGraphConditionalHandleCreate(&else_handle, cuda_graph));
3074
3365
 
3075
- CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_else_handles_kernel");
3366
+ CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_else_handles_kernel");
3076
3367
  if (!kernel)
3077
3368
  {
3078
3369
  wp::set_error_string("Failed to get built-in conditional kernel");
@@ -3273,7 +3564,7 @@ bool wp_cuda_graph_insert_child_graph(void* context, void* stream, void* child_g
3273
3564
  return true;
3274
3565
  }
3275
3566
 
3276
- bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
3567
+ bool wp_cuda_graph_insert_while(void* context, void* stream, int arch, bool use_ptx, int* condition, void** body_graph_ret, uint64_t* handle_ret)
3277
3568
  {
3278
3569
  // if there's no body, it's a no-op
3279
3570
  if (!body_graph_ret)
@@ -3303,7 +3594,7 @@ bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, voi
3303
3594
  return false;
3304
3595
 
3305
3596
  // launch a kernel to set the condition handle from condition pointer
3306
- CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
3597
+ CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
3307
3598
  if (!kernel)
3308
3599
  {
3309
3600
  wp::set_error_string("Failed to get built-in conditional kernel");
@@ -3339,14 +3630,14 @@ bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, voi
3339
3630
  return true;
3340
3631
  }
3341
3632
 
3342
- bool wp_cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
3633
+ bool wp_cuda_graph_set_condition(void* context, void* stream, int arch, bool use_ptx, int* condition, uint64_t handle)
3343
3634
  {
3344
3635
  ContextGuard guard(context);
3345
3636
 
3346
3637
  CUstream cuda_stream = static_cast<CUstream>(stream);
3347
3638
 
3348
3639
  // launch a kernel to set the condition handle from condition pointer
3349
- CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
3640
+ CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
3350
3641
  if (!kernel)
3351
3642
  {
3352
3643
  wp::set_error_string("Failed to get built-in conditional kernel");
@@ -3378,19 +3669,19 @@ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
3378
3669
  return false;
3379
3670
  }
3380
3671
 
3381
- bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
3672
+ bool wp_cuda_graph_insert_if_else(void* context, void* stream, int arch, bool use_ptx, int* condition, void** if_graph_ret, void** else_graph_ret)
3382
3673
  {
3383
3674
  wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
3384
3675
  return false;
3385
3676
  }
3386
3677
 
3387
- bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
3678
+ bool wp_cuda_graph_insert_while(void* context, void* stream, int arch, bool use_ptx, int* condition, void** body_graph_ret, uint64_t* handle_ret)
3388
3679
  {
3389
3680
  wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
3390
3681
  return false;
3391
3682
  }
3392
3683
 
3393
- bool wp_cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
3684
+ bool wp_cuda_graph_set_condition(void* context, void* stream, int arch, bool use_ptx, int* condition, uint64_t handle)
3394
3685
  {
3395
3686
  wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
3396
3687
  return false;
@@ -3425,16 +3716,38 @@ bool wp_cuda_graph_launch(void* graph_exec, void* stream)
3425
3716
 
3426
3717
  bool wp_cuda_graph_destroy(void* context, void* graph)
3427
3718
  {
3428
- ContextGuard guard(context);
3429
-
3430
- return check_cuda(cudaGraphDestroy((cudaGraph_t)graph));
3719
+ // ensure there are no graph captures in progress
3720
+ if (g_captures.empty())
3721
+ {
3722
+ ContextGuard guard(context);
3723
+ return check_cuda(cudaGraphDestroy((cudaGraph_t)graph));
3724
+ }
3725
+ else
3726
+ {
3727
+ GraphDestroyInfo info;
3728
+ info.context = context ? context : get_current_context();
3729
+ info.graph = graph;
3730
+ g_deferred_graph_list.push_back(info);
3731
+ return true;
3732
+ }
3431
3733
  }
3432
3734
 
3433
3735
  bool wp_cuda_graph_exec_destroy(void* context, void* graph_exec)
3434
3736
  {
3435
- ContextGuard guard(context);
3436
-
3437
- return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
3737
+ // ensure there are no graph captures in progress
3738
+ if (g_captures.empty())
3739
+ {
3740
+ ContextGuard guard(context);
3741
+ return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
3742
+ }
3743
+ else
3744
+ {
3745
+ GraphDestroyInfo info;
3746
+ info.context = context ? context : get_current_context();
3747
+ info.graph_exec = graph_exec;
3748
+ g_deferred_graph_list.push_back(info);
3749
+ return true;
3750
+ }
3438
3751
  }
3439
3752
 
3440
3753
  bool write_file(const char* data, size_t size, std::string filename, const char* mode)
@@ -4287,17 +4600,5 @@ void wp_cuda_timing_end(timing_result_t* results, int size)
4287
4600
  g_cuda_timing_state = parent_state;
4288
4601
  }
4289
4602
 
4290
- // impl. files
4291
- #include "bvh.cu"
4292
- #include "mesh.cu"
4293
- #include "sort.cu"
4294
- #include "hashgrid.cu"
4295
- #include "reduce.cu"
4296
- #include "runlength_encode.cu"
4297
- #include "scan.cu"
4298
- #include "sparse.cu"
4299
- #include "volume.cu"
4300
- #include "volume_builder.cu"
4301
-
4302
4603
  //#include "spline.inl"
4303
4604
  //#include "volume.inl"