warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; consult the package registry's advisory page for more details.

Files changed (346)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +882 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/{builtins.py → _src/builtins.py} +1435 -379
  8. warp/_src/codegen.py +4361 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/{fem → _src/fem}/domain.py +42 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +32 -15
  22. warp/{fem → _src/fem}/field/restriction.py +3 -1
  23. warp/{fem → _src/fem}/field/virtual.py +55 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
  28. warp/{fem → _src/fem}/geometry/element.py +34 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +50 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
  34. warp/{fem → _src/fem}/geometry/partition.py +123 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
  38. warp/{fem → _src/fem}/integrate.py +166 -158
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
  46. warp/_src/fem/space/basis_space.py +681 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
  48. warp/{fem → _src/fem}/space/function_space.py +16 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
  53. warp/{fem → _src/fem}/space/partition.py +119 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
  55. warp/{fem → _src/fem}/space/restriction.py +68 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
  63. warp/_src/fem/space/topology.py +461 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
  65. warp/_src/fem/types.py +114 -0
  66. warp/_src/fem/utils.py +488 -0
  67. warp/_src/jax.py +188 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +389 -0
  70. warp/_src/jax_experimental/ffi.py +1286 -0
  71. warp/_src/jax_experimental/xla_ffi.py +658 -0
  72. warp/_src/marching_cubes.py +710 -0
  73. warp/_src/math.py +416 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +165 -0
  76. warp/_src/optim/linear.py +1608 -0
  77. warp/_src/optim/sgd.py +114 -0
  78. warp/_src/paddle.py +408 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +291 -0
  81. warp/_src/render/render_opengl.py +3638 -0
  82. warp/_src/render/render_usd.py +939 -0
  83. warp/_src/render/utils.py +162 -0
  84. warp/_src/sparse.py +2718 -0
  85. warp/_src/tape.py +1208 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +393 -0
  88. warp/_src/types.py +5888 -0
  89. warp/_src/utils.py +1695 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +3 -3
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +521 -250
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +18 -17
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +578 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
warp/native/warp.cu CHANGED
@@ -38,6 +38,7 @@
38
38
  #include <iterator>
39
39
  #include <list>
40
40
  #include <map>
41
+ #include <mutex>
41
42
  #include <string>
42
43
  #include <unordered_map>
43
44
  #include <unordered_set>
@@ -176,11 +177,20 @@ struct ContextInfo
176
177
  CUmodule conditional_module = NULL;
177
178
  };
178
179
 
180
+ // Information used for freeing allocations.
181
+ struct FreeInfo
182
+ {
183
+ void* context = NULL;
184
+ void* ptr = NULL;
185
+ bool is_async = false;
186
+ };
187
+
179
188
  struct CaptureInfo
180
189
  {
181
190
  CUstream stream = NULL; // the main stream where capture begins and ends
182
191
  uint64_t id = 0; // unique capture id from CUDA
183
192
  bool external = false; // whether this is an external capture
193
+ std::vector<FreeInfo> tmp_allocs; // temporary allocations owned by the graph (e.g., staged array fill values)
184
194
  };
185
195
 
186
196
  struct StreamInfo
@@ -189,9 +199,13 @@ struct StreamInfo
189
199
  CaptureInfo* capture = NULL; // capture info (only if started on this stream)
190
200
  };
191
201
 
192
- struct GraphInfo
202
+ // Extra resources tied to a graph, freed after the graph is released by CUDA.
203
+ // Used with the on_graph_destroy() callback.
204
+ struct GraphDestroyCallbackInfo
193
205
  {
194
- std::vector<void*> unfreed_allocs;
206
+ void* context = NULL; // graph CUDA context
207
+ std::vector<void*> unfreed_allocs; // graph allocations not freed by the graph
208
+ std::vector<FreeInfo> tmp_allocs; // temporary allocations owned by the graph (e.g., staged array fill values)
195
209
  };
196
210
 
197
211
  // Information for graph allocations that are not freed by the graph.
@@ -207,19 +221,19 @@ struct GraphAllocInfo
207
221
  bool graph_destroyed = false; // whether graph instance was destroyed
208
222
  };
209
223
 
210
- // Information used when deferring deallocations.
211
- struct FreeInfo
224
+ // Information used when deferring module unloading.
225
+ struct ModuleInfo
212
226
  {
213
227
  void* context = NULL;
214
- void* ptr = NULL;
215
- bool is_async = false;
228
+ void* module = NULL;
216
229
  };
217
230
 
218
- // Information used when deferring module unloading.
219
- struct ModuleInfo
231
+ // Information used when deferring graph destruction.
232
+ struct GraphDestroyInfo
220
233
  {
221
234
  void* context = NULL;
222
- void* module = NULL;
235
+ void* graph = NULL;
236
+ void* graph_exec = NULL;
223
237
  };
224
238
 
225
239
  static std::unordered_map<CUfunction, std::string> g_kernel_names;
@@ -253,6 +267,15 @@ static std::vector<FreeInfo> g_deferred_free_list;
253
267
  // Call unload_deferred_modules() to release.
254
268
  static std::vector<ModuleInfo> g_deferred_module_list;
255
269
 
270
+ // Graphs that cannot be destroyed immediately get queued here.
271
+ // Call destroy_deferred_graphs() to release.
272
+ static std::vector<GraphDestroyInfo> g_deferred_graph_list;
273
+
274
+ // Data from on_graph_destroy() callbacks that run on a different thread.
275
+ static std::vector<GraphDestroyCallbackInfo*> g_deferred_graph_destroy_list;
276
+ static std::mutex g_graph_destroy_mutex;
277
+
278
+
256
279
  void wp_cuda_set_context_restore_policy(bool always_restore)
257
280
  {
258
281
  ContextGuard::always_restore = always_restore;
@@ -338,7 +361,7 @@ int cuda_init()
338
361
  }
339
362
 
340
363
 
341
- static inline CUcontext get_current_context()
364
+ CUcontext get_current_context()
342
365
  {
343
366
  CUcontext ctx;
344
367
  if (check_cu(cuCtxGetCurrent_f(&ctx)))
@@ -408,6 +431,114 @@ static inline StreamInfo* get_stream_info(CUstream stream)
408
431
  return NULL;
409
432
  }
410
433
 
434
+ static inline CaptureInfo* get_capture_info(CUstream stream)
435
+ {
436
+ if (!g_captures.empty() && wp_cuda_stream_is_capturing(stream))
437
+ {
438
+ uint64_t capture_id = get_capture_id(stream);
439
+ auto capture_iter = g_captures.find(capture_id);
440
+ if (capture_iter != g_captures.end())
441
+ return capture_iter->second;
442
+ }
443
+ return NULL;
444
+ }
445
+
446
+ // helper function to copy a value to device memory in a graph-friendly way
447
+ static bool capturable_tmp_alloc(void* context, const void* data, size_t size, void** devptr_ret, bool* free_devptr_ret)
448
+ {
449
+ ContextGuard guard(context);
450
+
451
+ CUstream stream = get_current_stream();
452
+ CaptureInfo* capture_info = get_capture_info(stream);
453
+ int device_ordinal = wp_cuda_context_get_device_ordinal(context);
454
+ void* devptr = NULL;
455
+ bool free_devptr = true;
456
+
457
+ if (capture_info)
458
+ {
459
+ // ongoing graph capture - need to stage the fill value so that it persists with the graph
460
+ if (CUDA_VERSION >= 12040 && wp_cuda_driver_version() >= 12040)
461
+ {
462
+ // pause the capture so that the alloc/memcpy won't be captured
463
+ void* graph = NULL;
464
+ if (!wp_cuda_graph_pause_capture(WP_CURRENT_CONTEXT, stream, &graph))
465
+ return false;
466
+
467
+ // copy value to device memory
468
+ devptr = wp_alloc_device(WP_CURRENT_CONTEXT, size);
469
+ if (!devptr)
470
+ {
471
+ fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cuda:%d' (in function %s)\n", (unsigned long long)size, device_ordinal, __FUNCTION__);
472
+ return false;
473
+ }
474
+ if (!check_cuda(cudaMemcpyAsync(devptr, data, size, cudaMemcpyHostToDevice, stream)))
475
+ return false;
476
+
477
+ // graph takes ownership of the value storage
478
+ FreeInfo free_info;
479
+ free_info.context = context ? context : get_current_context();
480
+ free_info.ptr = devptr;
481
+ free_info.is_async = wp_cuda_device_is_mempool_supported(device_ordinal);
482
+
483
+ // allocation will be freed when graph is destroyed
484
+ capture_info->tmp_allocs.push_back(free_info);
485
+
486
+ // resume the capture
487
+ if (!wp_cuda_graph_resume_capture(WP_CURRENT_CONTEXT, stream, graph))
488
+ return false;
489
+
490
+ free_devptr = false; // memory is owned by the graph, doesn't need to be freed
491
+ }
492
+ else
493
+ {
494
+ // older CUDA can't pause/resume the capture, so stage in CPU memory
495
+ void* hostptr = wp_alloc_host(size);
496
+ if (!hostptr)
497
+ {
498
+ fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cpu' (in function %s)\n", (unsigned long long)size, __FUNCTION__);
499
+ return false;
500
+ }
501
+ memcpy(hostptr, data, size);
502
+
503
+ // the device allocation and h2d copy will be captured in the graph
504
+ devptr = wp_alloc_device(WP_CURRENT_CONTEXT, size);
505
+ if (!devptr)
506
+ {
507
+ fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cuda:%d' (in function %s)\n", (unsigned long long)size, device_ordinal, __FUNCTION__);
508
+ return false;
509
+ }
510
+ if (!check_cuda(cudaMemcpyAsync(devptr, hostptr, size, cudaMemcpyHostToDevice, stream)))
511
+ return false;
512
+
513
+ // graph takes ownership of the value storage
514
+ FreeInfo free_info;
515
+ free_info.context = NULL;
516
+ free_info.ptr = hostptr;
517
+ free_info.is_async = false;
518
+
519
+ // allocation will be freed when graph is destroyed
520
+ capture_info->tmp_allocs.push_back(free_info);
521
+ }
522
+ }
523
+ else
524
+ {
525
+ // not capturing, copy the value to device memory
526
+ devptr = wp_alloc_device(WP_CURRENT_CONTEXT, size);
527
+ if (!devptr)
528
+ {
529
+ fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cuda:%d' (in function %s)\n", (unsigned long long)size, device_ordinal, __FUNCTION__);
530
+ return false;
531
+ }
532
+ if (!check_cuda(cudaMemcpyAsync(devptr, data, size, cudaMemcpyHostToDevice, stream)))
533
+ return false;
534
+ }
535
+
536
+ *devptr_ret = devptr;
537
+ *free_devptr_ret = free_devptr;
538
+
539
+ return true;
540
+ }
541
+
411
542
  static void deferred_free(void* ptr, void* context, bool is_async)
412
543
  {
413
544
  FreeInfo free_info;
@@ -495,34 +626,124 @@ static int unload_deferred_modules(void* context = NULL)
495
626
  return num_unloaded_modules;
496
627
  }
497
628
 
498
- static void CUDART_CB on_graph_destroy(void* user_data)
629
+ static int destroy_deferred_graphs(void* context = NULL)
499
630
  {
500
- if (!user_data)
501
- return;
631
+ if (g_deferred_graph_list.empty() || !g_captures.empty())
632
+ return 0;
502
633
 
503
- GraphInfo* graph_info = static_cast<GraphInfo*>(user_data);
634
+ int num_destroyed_graphs = 0;
635
+ for (auto it = g_deferred_graph_list.begin(); it != g_deferred_graph_list.end(); /*noop*/)
636
+ {
637
+ // destroy the graph if it matches the given context or if the context is unspecified
638
+ const GraphDestroyInfo& graph_info = *it;
639
+ if (graph_info.context == context || !context)
640
+ {
641
+ if (graph_info.graph)
642
+ {
643
+ check_cuda(cudaGraphDestroy((cudaGraph_t)graph_info.graph));
644
+ }
645
+ if (graph_info.graph_exec)
646
+ {
647
+ check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_info.graph_exec));
648
+ }
649
+ ++num_destroyed_graphs;
650
+ it = g_deferred_graph_list.erase(it);
651
+ }
652
+ else
653
+ {
654
+ ++it;
655
+ }
656
+ }
504
657
 
505
- for (void* ptr : graph_info->unfreed_allocs)
658
+ return num_destroyed_graphs;
659
+ }
660
+
661
+ static int process_deferred_graph_destroy_callbacks(void* context = NULL)
662
+ {
663
+ int num_freed = 0;
664
+
665
+ std::lock_guard<std::mutex> lock(g_graph_destroy_mutex);
666
+
667
+ for (auto it = g_deferred_graph_destroy_list.begin(); it != g_deferred_graph_destroy_list.end(); /*noop*/)
506
668
  {
507
- auto alloc_iter = g_graph_allocs.find(ptr);
508
- if (alloc_iter != g_graph_allocs.end())
669
+ GraphDestroyCallbackInfo* graph_info = *it;
670
+ if (graph_info->context == context || !context)
509
671
  {
510
- GraphAllocInfo& alloc_info = alloc_iter->second;
511
- if (alloc_info.ref_exists)
672
+ // handle unfreed graph allocations (may have outstanding user references)
673
+ for (void* ptr : graph_info->unfreed_allocs)
512
674
  {
513
- // unreference from graph so the pointer will be deallocated when the user reference goes away
514
- alloc_info.graph_destroyed = true;
675
+ auto alloc_iter = g_graph_allocs.find(ptr);
676
+ if (alloc_iter != g_graph_allocs.end())
677
+ {
678
+ GraphAllocInfo& alloc_info = alloc_iter->second;
679
+ if (alloc_info.ref_exists)
680
+ {
681
+ // unreference from graph so the pointer will be deallocated when the user reference goes away
682
+ alloc_info.graph_destroyed = true;
683
+ }
684
+ else
685
+ {
686
+ // the pointer can be freed, no references remain
687
+ wp_free_device_async(alloc_info.context, ptr);
688
+ g_graph_allocs.erase(alloc_iter);
689
+ }
690
+ }
515
691
  }
516
- else
692
+
693
+ // handle temporary allocations owned by the graph (no user references)
694
+ for (const FreeInfo& tmp_info : graph_info->tmp_allocs)
517
695
  {
518
- // the pointer can be freed, but we can't call CUDA functions in this callback, so defer it
519
- deferred_free(ptr, alloc_info.context, true);
520
- g_graph_allocs.erase(alloc_iter);
696
+ if (tmp_info.context)
697
+ {
698
+ // GPU alloc
699
+ if (tmp_info.is_async)
700
+ {
701
+ wp_free_device_async(tmp_info.context, tmp_info.ptr);
702
+ }
703
+ else
704
+ {
705
+ wp_free_device_default(tmp_info.context, tmp_info.ptr);
706
+ }
707
+ }
708
+ else
709
+ {
710
+ // CPU alloc
711
+ wp_free_host(tmp_info.ptr);
712
+ }
521
713
  }
714
+
715
+ ++num_freed;
716
+ delete graph_info;
717
+ it = g_deferred_graph_destroy_list.erase(it);
718
+ }
719
+ else
720
+ {
721
+ ++it;
522
722
  }
523
723
  }
524
724
 
525
- delete graph_info;
725
+ return num_freed;
726
+ }
727
+
728
+ static int run_deferred_actions(void* context = NULL)
729
+ {
730
+ int num_actions = 0;
731
+ num_actions += free_deferred_allocs(context);
732
+ num_actions += unload_deferred_modules(context);
733
+ num_actions += destroy_deferred_graphs(context);
734
+ num_actions += process_deferred_graph_destroy_callbacks(context);
735
+ return num_actions;
736
+ }
737
+
738
+ // Callback used when a graph is destroyed.
739
+ // NOTE: this runs on an internal CUDA thread and requires synchronization.
740
+ static void CUDART_CB on_graph_destroy(void* user_data)
741
+ {
742
+ if (user_data)
743
+ {
744
+ std::lock_guard<std::mutex> lock(g_graph_destroy_mutex);
745
+ g_deferred_graph_destroy_list.push_back(static_cast<GraphDestroyCallbackInfo*>(user_data));
746
+ }
526
747
  }
527
748
 
528
749
  static inline const char* get_cuda_kernel_name(void* kernel)
@@ -974,30 +1195,36 @@ void wp_memtile_device(void* context, void* dst, const void* src, size_t srcsize
974
1195
  else
975
1196
  {
976
1197
  // generic version
1198
+ void* value_devptr = NULL; // fill value in device memory
1199
+ bool free_devptr = true; // whether we need to free the memory
977
1200
 
978
- // copy value to device memory
979
- // TODO: use a persistent stream-local staging buffer to avoid allocs?
980
- void* src_devptr = wp_alloc_device(WP_CURRENT_CONTEXT, srcsize);
981
- check_cuda(cudaMemcpyAsync(src_devptr, src, srcsize, cudaMemcpyHostToDevice, get_current_stream()));
982
-
983
- wp_launch_device(WP_CURRENT_CONTEXT, memtile_kernel, n, (dst, src_devptr, srcsize, n));
1201
+ // prepare the fill value in a graph-friendly way
1202
+ if (!capturable_tmp_alloc(WP_CURRENT_CONTEXT, src, srcsize, &value_devptr, &free_devptr))
1203
+ {
1204
+ fprintf(stderr, "Warp fill error: failed to copy value to device memory\n");
1205
+ return;
1206
+ }
984
1207
 
985
- wp_free_device(WP_CURRENT_CONTEXT, src_devptr);
1208
+ wp_launch_device(WP_CURRENT_CONTEXT, memtile_kernel, n, (dst, value_devptr, srcsize, n));
986
1209
 
1210
+ if (free_devptr)
1211
+ {
1212
+ wp_free_device(WP_CURRENT_CONTEXT, value_devptr);
1213
+ }
987
1214
  }
988
1215
  }
989
1216
 
990
1217
 
991
1218
  static __global__ void array_copy_1d_kernel(void* dst, const void* src,
992
- int dst_stride, int src_stride,
1219
+ size_t dst_stride, size_t src_stride,
993
1220
  const int* dst_indices, const int* src_indices,
994
- int n, int elem_size)
1221
+ size_t n, size_t elem_size)
995
1222
  {
996
- int i = blockIdx.x * blockDim.x + threadIdx.x;
1223
+ size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
997
1224
  if (i < n)
998
1225
  {
999
- int src_idx = src_indices ? src_indices[i] : i;
1000
- int dst_idx = dst_indices ? dst_indices[i] : i;
1226
+ size_t src_idx = src_indices ? src_indices[i] : i;
1227
+ size_t dst_idx = dst_indices ? dst_indices[i] : i;
1001
1228
  const char* p = (const char*)src + src_idx * src_stride;
1002
1229
  char* q = (char*)dst + dst_idx * dst_stride;
1003
1230
  memcpy(q, p, elem_size);
@@ -1005,20 +1232,20 @@ static __global__ void array_copy_1d_kernel(void* dst, const void* src,
1005
1232
  }
1006
1233
 
1007
1234
  static __global__ void array_copy_2d_kernel(void* dst, const void* src,
1008
- wp::vec_t<2, int> dst_strides, wp::vec_t<2, int> src_strides,
1235
+ wp::vec_t<2, size_t> dst_strides, wp::vec_t<2, size_t> src_strides,
1009
1236
  wp::vec_t<2, const int*> dst_indices, wp::vec_t<2, const int*> src_indices,
1010
- wp::vec_t<2, int> shape, int elem_size)
1237
+ wp::vec_t<2, size_t> shape, size_t elem_size)
1011
1238
  {
1012
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1013
- int n = shape[1];
1014
- int i = tid / n;
1015
- int j = tid % n;
1239
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1240
+ size_t n = shape[1];
1241
+ size_t i = tid / n;
1242
+ size_t j = tid % n;
1016
1243
  if (i < shape[0] /*&& j < shape[1]*/)
1017
1244
  {
1018
- int src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1019
- int dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1020
- int src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1021
- int dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1245
+ size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1246
+ size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1247
+ size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1248
+ size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1022
1249
  const char* p = (const char*)src + src_idx0 * src_strides[0] + src_idx1 * src_strides[1];
1023
1250
  char* q = (char*)dst + dst_idx0 * dst_strides[0] + dst_idx1 * dst_strides[1];
1024
1251
  memcpy(q, p, elem_size);
@@ -1026,24 +1253,24 @@ static __global__ void array_copy_2d_kernel(void* dst, const void* src,
1026
1253
  }
1027
1254
 
1028
1255
  static __global__ void array_copy_3d_kernel(void* dst, const void* src,
1029
- wp::vec_t<3, int> dst_strides, wp::vec_t<3, int> src_strides,
1256
+ wp::vec_t<3, size_t> dst_strides, wp::vec_t<3, size_t> src_strides,
1030
1257
  wp::vec_t<3, const int*> dst_indices, wp::vec_t<3, const int*> src_indices,
1031
- wp::vec_t<3, int> shape, int elem_size)
1032
- {
1033
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1034
- int n = shape[1];
1035
- int o = shape[2];
1036
- int i = tid / (n * o);
1037
- int j = tid % (n * o) / o;
1038
- int k = tid % o;
1258
+ wp::vec_t<3, size_t> shape, size_t elem_size)
1259
+ {
1260
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1261
+ size_t n = shape[1];
1262
+ size_t o = shape[2];
1263
+ size_t i = tid / (n * o);
1264
+ size_t j = tid % (n * o) / o;
1265
+ size_t k = tid % o;
1039
1266
  if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
1040
1267
  {
1041
- int src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1042
- int dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1043
- int src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1044
- int dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1045
- int src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1046
- int dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1268
+ size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1269
+ size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1270
+ size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1271
+ size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1272
+ size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1273
+ size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1047
1274
  const char* p = (const char*)src + src_idx0 * src_strides[0]
1048
1275
  + src_idx1 * src_strides[1]
1049
1276
  + src_idx2 * src_strides[2];
@@ -1055,28 +1282,28 @@ static __global__ void array_copy_3d_kernel(void* dst, const void* src,
1055
1282
  }
1056
1283
 
1057
1284
  static __global__ void array_copy_4d_kernel(void* dst, const void* src,
1058
- wp::vec_t<4, int> dst_strides, wp::vec_t<4, int> src_strides,
1285
+ wp::vec_t<4, size_t> dst_strides, wp::vec_t<4, size_t> src_strides,
1059
1286
  wp::vec_t<4, const int*> dst_indices, wp::vec_t<4, const int*> src_indices,
1060
- wp::vec_t<4, int> shape, int elem_size)
1061
- {
1062
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1063
- int n = shape[1];
1064
- int o = shape[2];
1065
- int p = shape[3];
1066
- int i = tid / (n * o * p);
1067
- int j = tid % (n * o * p) / (o * p);
1068
- int k = tid % (o * p) / p;
1069
- int l = tid % p;
1287
+ wp::vec_t<4, size_t> shape, size_t elem_size)
1288
+ {
1289
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1290
+ size_t n = shape[1];
1291
+ size_t o = shape[2];
1292
+ size_t p = shape[3];
1293
+ size_t i = tid / (n * o * p);
1294
+ size_t j = tid % (n * o * p) / (o * p);
1295
+ size_t k = tid % (o * p) / p;
1296
+ size_t l = tid % p;
1070
1297
  if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
1071
1298
  {
1072
- int src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1073
- int dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1074
- int src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1075
- int dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1076
- int src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1077
- int dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1078
- int src_idx3 = src_indices[3] ? src_indices[3][l] : l;
1079
- int dst_idx3 = dst_indices[3] ? dst_indices[3][l] : l;
1299
+ size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1300
+ size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1301
+ size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1302
+ size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1303
+ size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1304
+ size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1305
+ size_t src_idx3 = src_indices[3] ? src_indices[3][l] : l;
1306
+ size_t dst_idx3 = dst_indices[3] ? dst_indices[3][l] : l;
1080
1307
  const char* p = (const char*)src + src_idx0 * src_strides[0]
1081
1308
  + src_idx1 * src_strides[1]
1082
1309
  + src_idx2 * src_strides[2]
@@ -1091,14 +1318,14 @@ static __global__ void array_copy_4d_kernel(void* dst, const void* src,
1091
1318
 
1092
1319
 
1093
1320
  static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src,
1094
- void* dst_data, int dst_stride, const int* dst_indices,
1095
- int elem_size)
1321
+ void* dst_data, size_t dst_stride, const int* dst_indices,
1322
+ size_t elem_size)
1096
1323
  {
1097
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1324
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1098
1325
 
1099
1326
  if (tid < src.size)
1100
1327
  {
1101
- int dst_idx = dst_indices ? dst_indices[tid] : tid;
1328
+ size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
1102
1329
  void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
1103
1330
  const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
1104
1331
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1106,15 +1333,15 @@ static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src
1106
1333
  }
1107
1334
 
1108
1335
  static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricarray_t<void> src,
1109
- void* dst_data, int dst_stride, const int* dst_indices,
1110
- int elem_size)
1336
+ void* dst_data, size_t dst_stride, const int* dst_indices,
1337
+ size_t elem_size)
1111
1338
  {
1112
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1339
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1113
1340
 
1114
1341
  if (tid < src.size)
1115
1342
  {
1116
- int src_index = src.indices[tid];
1117
- int dst_idx = dst_indices ? dst_indices[tid] : tid;
1343
+ size_t src_index = src.indices[tid];
1344
+ size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
1118
1345
  void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
1119
1346
  const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
1120
1347
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1122,14 +1349,14 @@ static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricar
1122
1349
  }
1123
1350
 
1124
1351
  static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
1125
- const void* src_data, int src_stride, const int* src_indices,
1126
- int elem_size)
1352
+ const void* src_data, size_t src_stride, const int* src_indices,
1353
+ size_t elem_size)
1127
1354
  {
1128
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1355
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1129
1356
 
1130
1357
  if (tid < dst.size)
1131
1358
  {
1132
- int src_idx = src_indices ? src_indices[tid] : tid;
1359
+ size_t src_idx = src_indices ? src_indices[tid] : tid;
1133
1360
  const void* src_ptr = (const char*)src_data + src_idx * src_stride;
1134
1361
  void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
1135
1362
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1137,25 +1364,25 @@ static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
1137
1364
  }
1138
1365
 
1139
1366
  static __global__ void array_copy_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst,
1140
- const void* src_data, int src_stride, const int* src_indices,
1141
- int elem_size)
1367
+ const void* src_data, size_t src_stride, const int* src_indices,
1368
+ size_t elem_size)
1142
1369
  {
1143
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1370
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1144
1371
 
1145
1372
  if (tid < dst.size)
1146
1373
  {
1147
- int src_idx = src_indices ? src_indices[tid] : tid;
1374
+ size_t src_idx = src_indices ? src_indices[tid] : tid;
1148
1375
  const void* src_ptr = (const char*)src_data + src_idx * src_stride;
1149
- int dst_idx = dst.indices[tid];
1376
+ size_t dst_idx = dst.indices[tid];
1150
1377
  void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
1151
1378
  memcpy(dst_ptr, src_ptr, elem_size);
1152
1379
  }
1153
1380
  }
1154
1381
 
1155
1382
 
1156
- static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src, int elem_size)
1383
+ static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
1157
1384
  {
1158
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1385
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1159
1386
 
1160
1387
  if (tid < dst.size)
1161
1388
  {
@@ -1166,27 +1393,27 @@ static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void
1166
1393
  }
1167
1394
 
1168
1395
 
1169
- static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src, int elem_size)
1396
+ static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
1170
1397
  {
1171
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1398
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1172
1399
 
1173
1400
  if (tid < dst.size)
1174
1401
  {
1175
1402
  const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
1176
- int dst_index = dst.indices[tid];
1403
+ size_t dst_index = dst.indices[tid];
1177
1404
  void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
1178
1405
  memcpy(dst_ptr, src_ptr, elem_size);
1179
1406
  }
1180
1407
  }
1181
1408
 
1182
1409
 
1183
- static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, int elem_size)
1410
+ static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
1184
1411
  {
1185
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1412
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1186
1413
 
1187
1414
  if (tid < dst.size)
1188
1415
  {
1189
- int src_index = src.indices[tid];
1416
+ size_t src_index = src.indices[tid];
1190
1417
  const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
1191
1418
  void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
1192
1419
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1194,14 +1421,14 @@ static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarra
1194
1421
  }
1195
1422
 
1196
1423
 
1197
- static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, int elem_size)
1424
+ static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
1198
1425
  {
1199
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1426
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1200
1427
 
1201
1428
  if (tid < dst.size)
1202
1429
  {
1203
- int src_index = src.indices[tid];
1204
- int dst_index = dst.indices[tid];
1430
+ size_t src_index = src.indices[tid];
1431
+ size_t dst_index = dst.indices[tid];
1205
1432
  const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
1206
1433
  void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
1207
1434
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1440,9 +1667,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
1440
1667
  }
1441
1668
  case 2:
1442
1669
  {
1443
- wp::vec_t<2, int> shape_v(src_shape[0], src_shape[1]);
1444
- wp::vec_t<2, int> src_strides_v(src_strides[0], src_strides[1]);
1445
- wp::vec_t<2, int> dst_strides_v(dst_strides[0], dst_strides[1]);
1670
+ wp::vec_t<2, size_t> shape_v(src_shape[0], src_shape[1]);
1671
+ wp::vec_t<2, size_t> src_strides_v(src_strides[0], src_strides[1]);
1672
+ wp::vec_t<2, size_t> dst_strides_v(dst_strides[0], dst_strides[1]);
1446
1673
  wp::vec_t<2, const int*> src_indices_v(src_indices[0], src_indices[1]);
1447
1674
  wp::vec_t<2, const int*> dst_indices_v(dst_indices[0], dst_indices[1]);
1448
1675
 
@@ -1454,9 +1681,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
1454
1681
  }
1455
1682
  case 3:
1456
1683
  {
1457
- wp::vec_t<3, int> shape_v(src_shape[0], src_shape[1], src_shape[2]);
1458
- wp::vec_t<3, int> src_strides_v(src_strides[0], src_strides[1], src_strides[2]);
1459
- wp::vec_t<3, int> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2]);
1684
+ wp::vec_t<3, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2]);
1685
+ wp::vec_t<3, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2]);
1686
+ wp::vec_t<3, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2]);
1460
1687
  wp::vec_t<3, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2]);
1461
1688
  wp::vec_t<3, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2]);
1462
1689
 
@@ -1468,9 +1695,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
1468
1695
  }
1469
1696
  case 4:
1470
1697
  {
1471
- wp::vec_t<4, int> shape_v(src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
1472
- wp::vec_t<4, int> src_strides_v(src_strides[0], src_strides[1], src_strides[2], src_strides[3]);
1473
- wp::vec_t<4, int> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2], dst_strides[3]);
1698
+ wp::vec_t<4, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
1699
+ wp::vec_t<4, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2], src_strides[3]);
1700
+ wp::vec_t<4, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2], dst_strides[3]);
1474
1701
  wp::vec_t<4, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2], src_indices[3]);
1475
1702
  wp::vec_t<4, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2], dst_indices[3]);
1476
1703
 
@@ -1490,94 +1717,94 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
1490
1717
 
1491
1718
 
1492
1719
  static __global__ void array_fill_1d_kernel(void* data,
1493
- int n,
1494
- int stride,
1720
+ size_t n,
1721
+ size_t stride,
1495
1722
  const int* indices,
1496
1723
  const void* value,
1497
- int value_size)
1724
+ size_t value_size)
1498
1725
  {
1499
- int i = blockIdx.x * blockDim.x + threadIdx.x;
1726
+ size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1500
1727
  if (i < n)
1501
1728
  {
1502
- int idx = indices ? indices[i] : i;
1729
+ size_t idx = indices ? indices[i] : i;
1503
1730
  char* p = (char*)data + idx * stride;
1504
1731
  memcpy(p, value, value_size);
1505
1732
  }
1506
1733
  }
1507
1734
 
1508
1735
  static __global__ void array_fill_2d_kernel(void* data,
1509
- wp::vec_t<2, int> shape,
1510
- wp::vec_t<2, int> strides,
1736
+ wp::vec_t<2, size_t> shape,
1737
+ wp::vec_t<2, size_t> strides,
1511
1738
  wp::vec_t<2, const int*> indices,
1512
1739
  const void* value,
1513
- int value_size)
1740
+ size_t value_size)
1514
1741
  {
1515
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1516
- int n = shape[1];
1517
- int i = tid / n;
1518
- int j = tid % n;
1742
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1743
+ size_t n = shape[1];
1744
+ size_t i = tid / n;
1745
+ size_t j = tid % n;
1519
1746
  if (i < shape[0] /*&& j < shape[1]*/)
1520
1747
  {
1521
- int idx0 = indices[0] ? indices[0][i] : i;
1522
- int idx1 = indices[1] ? indices[1][j] : j;
1748
+ size_t idx0 = indices[0] ? indices[0][i] : i;
1749
+ size_t idx1 = indices[1] ? indices[1][j] : j;
1523
1750
  char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1];
1524
1751
  memcpy(p, value, value_size);
1525
1752
  }
1526
1753
  }
1527
1754
 
1528
1755
  static __global__ void array_fill_3d_kernel(void* data,
1529
- wp::vec_t<3, int> shape,
1530
- wp::vec_t<3, int> strides,
1756
+ wp::vec_t<3, size_t> shape,
1757
+ wp::vec_t<3, size_t> strides,
1531
1758
  wp::vec_t<3, const int*> indices,
1532
1759
  const void* value,
1533
- int value_size)
1534
- {
1535
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1536
- int n = shape[1];
1537
- int o = shape[2];
1538
- int i = tid / (n * o);
1539
- int j = tid % (n * o) / o;
1540
- int k = tid % o;
1760
+ size_t value_size)
1761
+ {
1762
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1763
+ size_t n = shape[1];
1764
+ size_t o = shape[2];
1765
+ size_t i = tid / (n * o);
1766
+ size_t j = tid % (n * o) / o;
1767
+ size_t k = tid % o;
1541
1768
  if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
1542
1769
  {
1543
- int idx0 = indices[0] ? indices[0][i] : i;
1544
- int idx1 = indices[1] ? indices[1][j] : j;
1545
- int idx2 = indices[2] ? indices[2][k] : k;
1770
+ size_t idx0 = indices[0] ? indices[0][i] : i;
1771
+ size_t idx1 = indices[1] ? indices[1][j] : j;
1772
+ size_t idx2 = indices[2] ? indices[2][k] : k;
1546
1773
  char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2];
1547
1774
  memcpy(p, value, value_size);
1548
1775
  }
1549
1776
  }
1550
1777
 
1551
1778
  static __global__ void array_fill_4d_kernel(void* data,
1552
- wp::vec_t<4, int> shape,
1553
- wp::vec_t<4, int> strides,
1779
+ wp::vec_t<4, size_t> shape,
1780
+ wp::vec_t<4, size_t> strides,
1554
1781
  wp::vec_t<4, const int*> indices,
1555
1782
  const void* value,
1556
- int value_size)
1557
- {
1558
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1559
- int n = shape[1];
1560
- int o = shape[2];
1561
- int p = shape[3];
1562
- int i = tid / (n * o * p);
1563
- int j = tid % (n * o * p) / (o * p);
1564
- int k = tid % (o * p) / p;
1565
- int l = tid % p;
1783
+ size_t value_size)
1784
+ {
1785
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1786
+ size_t n = shape[1];
1787
+ size_t o = shape[2];
1788
+ size_t p = shape[3];
1789
+ size_t i = tid / (n * o * p);
1790
+ size_t j = tid % (n * o * p) / (o * p);
1791
+ size_t k = tid % (o * p) / p;
1792
+ size_t l = tid % p;
1566
1793
  if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
1567
1794
  {
1568
- int idx0 = indices[0] ? indices[0][i] : i;
1569
- int idx1 = indices[1] ? indices[1][j] : j;
1570
- int idx2 = indices[2] ? indices[2][k] : k;
1571
- int idx3 = indices[3] ? indices[3][l] : l;
1795
+ size_t idx0 = indices[0] ? indices[0][i] : i;
1796
+ size_t idx1 = indices[1] ? indices[1][j] : j;
1797
+ size_t idx2 = indices[2] ? indices[2][k] : k;
1798
+ size_t idx3 = indices[3] ? indices[3][l] : l;
1572
1799
  char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2] + idx3 * strides[3];
1573
1800
  memcpy(p, value, value_size);
1574
1801
  }
1575
1802
  }
1576
1803
 
1577
1804
 
1578
- static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value, int value_size)
1805
+ static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value, size_t value_size)
1579
1806
  {
1580
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1807
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1581
1808
  if (tid < fa.size)
1582
1809
  {
1583
1810
  void* dst_ptr = fabricarray_element_ptr(fa, tid, value_size);
@@ -1586,9 +1813,9 @@ static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, cons
1586
1813
  }
1587
1814
 
1588
1815
 
1589
- static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value, int value_size)
1816
+ static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value, size_t value_size)
1590
1817
  {
1591
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1818
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1592
1819
  if (tid < ifa.size)
1593
1820
  {
1594
1821
  size_t idx = size_t(ifa.indices[tid]);
@@ -1655,67 +1882,76 @@ WP_API void wp_array_fill_device(void* context, void* arr_ptr, int arr_type, con
1655
1882
 
1656
1883
  ContextGuard guard(context);
1657
1884
 
1658
- // copy value to device memory
1659
- // TODO: use a persistent stream-local staging buffer to avoid allocs?
1660
- void* value_devptr = wp_alloc_device(WP_CURRENT_CONTEXT, value_size);
1661
- check_cuda(cudaMemcpyAsync(value_devptr, value_ptr, value_size, cudaMemcpyHostToDevice, get_current_stream()));
1885
+ void* value_devptr = NULL; // fill value in device memory
1886
+ bool free_devptr = true; // whether we need to free the memory
1887
+
1888
+ // prepare the fill value in a graph-friendly way
1889
+ if (!capturable_tmp_alloc(WP_CURRENT_CONTEXT, value_ptr, value_size, &value_devptr, &free_devptr))
1890
+ {
1891
+ fprintf(stderr, "Warp fill error: failed to copy value to device memory\n");
1892
+ return;
1893
+ }
1662
1894
 
1663
- // handle fabric arrays
1664
1895
  if (fa)
1665
1896
  {
1897
+ // handle fabric arrays
1666
1898
  wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_kernel, n,
1667
1899
  (*fa, value_devptr, value_size));
1668
- return;
1669
1900
  }
1670
1901
  else if (ifa)
1671
1902
  {
1903
+ // handle indexed fabric arrays
1672
1904
  wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_indexed_kernel, n,
1673
1905
  (*ifa, value_devptr, value_size));
1674
- return;
1675
1906
  }
1676
-
1677
- // handle regular or indexed arrays
1678
- switch (ndim)
1679
- {
1680
- case 1:
1681
- {
1682
- wp_launch_device(WP_CURRENT_CONTEXT, array_fill_1d_kernel, n,
1683
- (data, shape[0], strides[0], indices[0], value_devptr, value_size));
1684
- break;
1685
- }
1686
- case 2:
1687
- {
1688
- wp::vec_t<2, int> shape_v(shape[0], shape[1]);
1689
- wp::vec_t<2, int> strides_v(strides[0], strides[1]);
1690
- wp::vec_t<2, const int*> indices_v(indices[0], indices[1]);
1691
- wp_launch_device(WP_CURRENT_CONTEXT, array_fill_2d_kernel, n,
1692
- (data, shape_v, strides_v, indices_v, value_devptr, value_size));
1693
- break;
1694
- }
1695
- case 3:
1907
+ else
1696
1908
  {
1697
- wp::vec_t<3, int> shape_v(shape[0], shape[1], shape[2]);
1698
- wp::vec_t<3, int> strides_v(strides[0], strides[1], strides[2]);
1699
- wp::vec_t<3, const int*> indices_v(indices[0], indices[1], indices[2]);
1700
- wp_launch_device(WP_CURRENT_CONTEXT, array_fill_3d_kernel, n,
1701
- (data, shape_v, strides_v, indices_v, value_devptr, value_size));
1702
- break;
1909
+ // handle regular or indexed arrays
1910
+ switch (ndim)
1911
+ {
1912
+ case 1:
1913
+ {
1914
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_1d_kernel, n,
1915
+ (data, shape[0], strides[0], indices[0], value_devptr, value_size));
1916
+ break;
1917
+ }
1918
+ case 2:
1919
+ {
1920
+ wp::vec_t<2, size_t> shape_v(shape[0], shape[1]);
1921
+ wp::vec_t<2, size_t> strides_v(strides[0], strides[1]);
1922
+ wp::vec_t<2, const int*> indices_v(indices[0], indices[1]);
1923
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_2d_kernel, n,
1924
+ (data, shape_v, strides_v, indices_v, value_devptr, value_size));
1925
+ break;
1926
+ }
1927
+ case 3:
1928
+ {
1929
+ wp::vec_t<3, size_t> shape_v(shape[0], shape[1], shape[2]);
1930
+ wp::vec_t<3, size_t> strides_v(strides[0], strides[1], strides[2]);
1931
+ wp::vec_t<3, const int*> indices_v(indices[0], indices[1], indices[2]);
1932
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_3d_kernel, n,
1933
+ (data, shape_v, strides_v, indices_v, value_devptr, value_size));
1934
+ break;
1935
+ }
1936
+ case 4:
1937
+ {
1938
+ wp::vec_t<4, size_t> shape_v(shape[0], shape[1], shape[2], shape[3]);
1939
+ wp::vec_t<4, size_t> strides_v(strides[0], strides[1], strides[2], strides[3]);
1940
+ wp::vec_t<4, const int*> indices_v(indices[0], indices[1], indices[2], indices[3]);
1941
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_4d_kernel, n,
1942
+ (data, shape_v, strides_v, indices_v, value_devptr, value_size));
1943
+ break;
1944
+ }
1945
+ default:
1946
+ fprintf(stderr, "Warp fill error: invalid array dimensionality (%d)\n", ndim);
1947
+ break;
1948
+ }
1703
1949
  }
1704
- case 4:
1950
+
1951
+ if (free_devptr)
1705
1952
  {
1706
- wp::vec_t<4, int> shape_v(shape[0], shape[1], shape[2], shape[3]);
1707
- wp::vec_t<4, int> strides_v(strides[0], strides[1], strides[2], strides[3]);
1708
- wp::vec_t<4, const int*> indices_v(indices[0], indices[1], indices[2], indices[3]);
1709
- wp_launch_device(WP_CURRENT_CONTEXT, array_fill_4d_kernel, n,
1710
- (data, shape_v, strides_v, indices_v, value_devptr, value_size));
1711
- break;
1712
- }
1713
- default:
1714
- fprintf(stderr, "Warp fill error: invalid array dimensionality (%d)\n", ndim);
1715
- return;
1953
+ wp_free_device(WP_CURRENT_CONTEXT, value_devptr);
1716
1954
  }
1717
-
1718
- wp_free_device(WP_CURRENT_CONTEXT, value_devptr);
1719
1955
  }
1720
1956
 
1721
1957
  void wp_array_scan_int_device(uint64_t in, uint64_t out, int len, bool inclusive)
@@ -2072,14 +2308,15 @@ void wp_cuda_context_synchronize(void* context)
2072
2308
 
2073
2309
  check_cu(cuCtxSynchronize_f());
2074
2310
 
2075
- if (free_deferred_allocs(context ? context : get_current_context()) > 0)
2311
+ if (!context)
2312
+ context = get_current_context();
2313
+
2314
+ if (run_deferred_actions(context) > 0)
2076
2315
  {
2077
- // ensure deferred asynchronous deallocations complete
2316
+ // ensure deferred asynchronous operations complete
2078
2317
  check_cu(cuCtxSynchronize_f());
2079
2318
  }
2080
2319
 
2081
- unload_deferred_modules(context);
2082
-
2083
2320
  // check_cuda(cudaDeviceGraphMemTrim(wp_cuda_context_get_device_ordinal(context)));
2084
2321
  }
2085
2322
 
@@ -2514,15 +2751,36 @@ void wp_cuda_stream_synchronize(void* stream)
2514
2751
  check_cu(cuStreamSynchronize_f(static_cast<CUstream>(stream)));
2515
2752
  }
2516
2753
 
2517
- void wp_cuda_stream_wait_event(void* stream, void* event)
2754
+ void wp_cuda_stream_wait_event(void* stream, void* event, bool external)
2518
2755
  {
2519
- check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), 0));
2756
+ // the external flag can only be used during graph capture
2757
+ if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
2758
+ {
2759
+ // wait for an external event during graph capture
2760
+ check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_EXTERNAL));
2761
+ }
2762
+ else
2763
+ {
2764
+ check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_DEFAULT));
2765
+ }
2520
2766
  }
2521
2767
 
2522
- void wp_cuda_stream_wait_stream(void* stream, void* other_stream, void* event)
2768
+ void wp_cuda_stream_wait_stream(void* stream, void* other_stream, void* event, bool external)
2523
2769
  {
2524
- check_cu(cuEventRecord_f(static_cast<CUevent>(event), static_cast<CUstream>(other_stream)));
2525
- check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), 0));
2770
+ unsigned record_flags = CU_EVENT_RECORD_DEFAULT;
2771
+ unsigned wait_flags = CU_EVENT_WAIT_DEFAULT;
2772
+
2773
+ // the external flag can only be used during graph capture
2774
+ if (external && !g_captures.empty())
2775
+ {
2776
+ if (wp_cuda_stream_is_capturing(other_stream))
2777
+ record_flags = CU_EVENT_RECORD_EXTERNAL;
2778
+ if (wp_cuda_stream_is_capturing(stream))
2779
+ wait_flags = CU_EVENT_WAIT_EXTERNAL;
2780
+ }
2781
+
2782
+ check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(other_stream), record_flags));
2783
+ check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), wait_flags));
2526
2784
  }
2527
2785
 
2528
2786
  int wp_cuda_stream_is_capturing(void* stream)
@@ -2575,11 +2833,12 @@ int wp_cuda_event_query(void* event)
2575
2833
  return res;
2576
2834
  }
2577
2835
 
2578
- void wp_cuda_event_record(void* event, void* stream, bool timing)
2836
+ void wp_cuda_event_record(void* event, void* stream, bool external)
2579
2837
  {
2580
- if (timing && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
2838
+ // the external flag can only be used during graph capture
2839
+ if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
2581
2840
  {
2582
- // record timing event during graph capture
2841
+ // record external event during graph capture (e.g., for timing or when explicitly specified by the user)
2583
2842
  check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(stream), CU_EVENT_RECORD_EXTERNAL));
2584
2843
  }
2585
2844
  else
@@ -2629,7 +2888,7 @@ bool wp_cuda_graph_begin_capture(void* context, void* stream, int external)
2629
2888
  else
2630
2889
  {
2631
2890
  // start the capture
2632
- if (!check_cuda(cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeGlobal)))
2891
+ if (!check_cuda(cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeThreadLocal)))
2633
2892
  return false;
2634
2893
  }
2635
2894
 
@@ -2673,6 +2932,7 @@ bool wp_cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
2673
2932
  // get capture info
2674
2933
  bool external = capture->external;
2675
2934
  uint64_t capture_id = capture->id;
2935
+ std::vector<FreeInfo> tmp_allocs = capture->tmp_allocs;
2676
2936
 
2677
2937
  // clear capture info
2678
2938
  stream_info->capture = NULL;
@@ -2742,15 +3002,17 @@ bool wp_cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
2742
3002
  unfreed_allocs.push_back(it->first);
2743
3003
  }
2744
3004
 
2745
- if (!unfreed_allocs.empty())
3005
+ if (!unfreed_allocs.empty() || !tmp_allocs.empty())
2746
3006
  {
2747
3007
  // Create a user object that will notify us when the instantiated graph is destroyed.
2748
3008
  // This works for external captures also, since we wouldn't otherwise know when
2749
3009
  // the externally-created graph instance gets deleted.
2750
3010
  // This callback is guaranteed to arrive after the graph has finished executing on the device,
2751
3011
  // not necessarily when cudaGraphExecDestroy() is called.
2752
- GraphInfo* graph_info = new GraphInfo;
3012
+ GraphDestroyCallbackInfo* graph_info = new GraphDestroyCallbackInfo;
3013
+ graph_info->context = context ? context : get_current_context();
2753
3014
  graph_info->unfreed_allocs = unfreed_allocs;
3015
+ graph_info->tmp_allocs = tmp_allocs;
2754
3016
  cudaUserObject_t user_object;
2755
3017
  check_cuda(cudaUserObjectCreate(&user_object, graph_info, on_graph_destroy, 1, cudaUserObjectNoDestructorSync));
2756
3018
  check_cuda(cudaGraphRetainUserObject(graph, user_object, 1, cudaGraphUserObjectMove));
@@ -2774,8 +3036,7 @@ bool wp_cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
2774
3036
  // process deferred free list if no more captures are ongoing
2775
3037
  if (g_captures.empty())
2776
3038
  {
2777
- free_deferred_allocs();
2778
- unload_deferred_modules();
3039
+ run_deferred_actions();
2779
3040
  }
2780
3041
 
2781
3042
  if (graph_ret)
@@ -2996,7 +3257,7 @@ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
2996
3257
  leaf_nodes.data(),
2997
3258
  nullptr,
2998
3259
  leaf_nodes.size(),
2999
- cudaStreamCaptureModeGlobal)))
3260
+ cudaStreamCaptureModeThreadLocal)))
3000
3261
  return false;
3001
3262
 
3002
3263
  return true;
@@ -3455,16 +3716,38 @@ bool wp_cuda_graph_launch(void* graph_exec, void* stream)
3455
3716
 
3456
3717
  bool wp_cuda_graph_destroy(void* context, void* graph)
3457
3718
  {
3458
- ContextGuard guard(context);
3459
-
3460
- return check_cuda(cudaGraphDestroy((cudaGraph_t)graph));
3719
+ // ensure there are no graph captures in progress
3720
+ if (g_captures.empty())
3721
+ {
3722
+ ContextGuard guard(context);
3723
+ return check_cuda(cudaGraphDestroy((cudaGraph_t)graph));
3724
+ }
3725
+ else
3726
+ {
3727
+ GraphDestroyInfo info;
3728
+ info.context = context ? context : get_current_context();
3729
+ info.graph = graph;
3730
+ g_deferred_graph_list.push_back(info);
3731
+ return true;
3732
+ }
3461
3733
  }
3462
3734
 
3463
3735
  bool wp_cuda_graph_exec_destroy(void* context, void* graph_exec)
3464
3736
  {
3465
- ContextGuard guard(context);
3466
-
3467
- return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
3737
+ // ensure there are no graph captures in progress
3738
+ if (g_captures.empty())
3739
+ {
3740
+ ContextGuard guard(context);
3741
+ return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
3742
+ }
3743
+ else
3744
+ {
3745
+ GraphDestroyInfo info;
3746
+ info.context = context ? context : get_current_context();
3747
+ info.graph_exec = graph_exec;
3748
+ g_deferred_graph_list.push_back(info);
3749
+ return true;
3750
+ }
3468
3751
  }
3469
3752
 
3470
3753
  bool write_file(const char* data, size_t size, std::string filename, const char* mode)
@@ -4317,17 +4600,5 @@ void wp_cuda_timing_end(timing_result_t* results, int size)
4317
4600
  g_cuda_timing_state = parent_state;
4318
4601
  }
4319
4602
 
4320
- // impl. files
4321
- #include "bvh.cu"
4322
- #include "mesh.cu"
4323
- #include "sort.cu"
4324
- #include "hashgrid.cu"
4325
- #include "reduce.cu"
4326
- #include "runlength_encode.cu"
4327
- #include "scan.cu"
4328
- #include "sparse.cu"
4329
- #include "volume.cu"
4330
- #include "volume_builder.cu"
4331
-
4332
4603
  //#include "spline.inl"
4333
4604
  //#include "volume.inl"