warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry's advisory page for more details.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/native/warp.cu CHANGED
@@ -19,6 +19,7 @@
19
19
  #include "scan.h"
20
20
  #include "cuda_util.h"
21
21
  #include "error.h"
22
+ #include "sort.h"
22
23
 
23
24
  #include <cstdlib>
24
25
  #include <fstream>
@@ -221,6 +222,14 @@ struct ModuleInfo
221
222
  void* module = NULL;
222
223
  };
223
224
 
225
+ // Information used when deferring graph destruction.
226
+ struct GraphDestroyInfo
227
+ {
228
+ void* context = NULL;
229
+ void* graph = NULL;
230
+ void* graph_exec = NULL;
231
+ };
232
+
224
233
  static std::unordered_map<CUfunction, std::string> g_kernel_names;
225
234
 
226
235
  // cached info for all devices, indexed by ordinal
@@ -252,6 +261,11 @@ static std::vector<FreeInfo> g_deferred_free_list;
252
261
  // Call unload_deferred_modules() to release.
253
262
  static std::vector<ModuleInfo> g_deferred_module_list;
254
263
 
264
+ // Graphs that cannot be destroyed immediately get queued here.
265
+ // Call destroy_deferred_graphs() to release.
266
+ static std::vector<GraphDestroyInfo> g_deferred_graph_list;
267
+
268
+
255
269
  void wp_cuda_set_context_restore_policy(bool always_restore)
256
270
  {
257
271
  ContextGuard::always_restore = always_restore;
@@ -337,7 +351,7 @@ int cuda_init()
337
351
  }
338
352
 
339
353
 
340
- static inline CUcontext get_current_context()
354
+ CUcontext get_current_context()
341
355
  {
342
356
  CUcontext ctx;
343
357
  if (check_cu(cuCtxGetCurrent_f(&ctx)))
@@ -494,6 +508,38 @@ static int unload_deferred_modules(void* context = NULL)
494
508
  return num_unloaded_modules;
495
509
  }
496
510
 
511
+ static int destroy_deferred_graphs(void* context = NULL)
512
+ {
513
+ if (g_deferred_graph_list.empty() || !g_captures.empty())
514
+ return 0;
515
+
516
+ int num_destroyed_graphs = 0;
517
+ for (auto it = g_deferred_graph_list.begin(); it != g_deferred_graph_list.end(); /*noop*/)
518
+ {
519
+ // destroy the graph if it matches the given context or if the context is unspecified
520
+ const GraphDestroyInfo& graph_info = *it;
521
+ if (graph_info.context == context || !context)
522
+ {
523
+ if (graph_info.graph)
524
+ {
525
+ check_cuda(cudaGraphDestroy((cudaGraph_t)graph_info.graph));
526
+ }
527
+ if (graph_info.graph_exec)
528
+ {
529
+ check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_info.graph_exec));
530
+ }
531
+ ++num_destroyed_graphs;
532
+ it = g_deferred_graph_list.erase(it);
533
+ }
534
+ else
535
+ {
536
+ ++it;
537
+ }
538
+ }
539
+
540
+ return num_destroyed_graphs;
541
+ }
542
+
497
543
  static void CUDART_CB on_graph_destroy(void* user_data)
498
544
  {
499
545
  if (!user_data)
@@ -988,15 +1034,15 @@ void wp_memtile_device(void* context, void* dst, const void* src, size_t srcsize
988
1034
 
989
1035
 
990
1036
  static __global__ void array_copy_1d_kernel(void* dst, const void* src,
991
- int dst_stride, int src_stride,
1037
+ size_t dst_stride, size_t src_stride,
992
1038
  const int* dst_indices, const int* src_indices,
993
- int n, int elem_size)
1039
+ size_t n, size_t elem_size)
994
1040
  {
995
- int i = blockIdx.x * blockDim.x + threadIdx.x;
1041
+ size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
996
1042
  if (i < n)
997
1043
  {
998
- int src_idx = src_indices ? src_indices[i] : i;
999
- int dst_idx = dst_indices ? dst_indices[i] : i;
1044
+ size_t src_idx = src_indices ? src_indices[i] : i;
1045
+ size_t dst_idx = dst_indices ? dst_indices[i] : i;
1000
1046
  const char* p = (const char*)src + src_idx * src_stride;
1001
1047
  char* q = (char*)dst + dst_idx * dst_stride;
1002
1048
  memcpy(q, p, elem_size);
@@ -1004,20 +1050,20 @@ static __global__ void array_copy_1d_kernel(void* dst, const void* src,
1004
1050
  }
1005
1051
 
1006
1052
  static __global__ void array_copy_2d_kernel(void* dst, const void* src,
1007
- wp::vec_t<2, int> dst_strides, wp::vec_t<2, int> src_strides,
1053
+ wp::vec_t<2, size_t> dst_strides, wp::vec_t<2, size_t> src_strides,
1008
1054
  wp::vec_t<2, const int*> dst_indices, wp::vec_t<2, const int*> src_indices,
1009
- wp::vec_t<2, int> shape, int elem_size)
1055
+ wp::vec_t<2, size_t> shape, size_t elem_size)
1010
1056
  {
1011
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1012
- int n = shape[1];
1013
- int i = tid / n;
1014
- int j = tid % n;
1057
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1058
+ size_t n = shape[1];
1059
+ size_t i = tid / n;
1060
+ size_t j = tid % n;
1015
1061
  if (i < shape[0] /*&& j < shape[1]*/)
1016
1062
  {
1017
- int src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1018
- int dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1019
- int src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1020
- int dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1063
+ size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1064
+ size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1065
+ size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1066
+ size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1021
1067
  const char* p = (const char*)src + src_idx0 * src_strides[0] + src_idx1 * src_strides[1];
1022
1068
  char* q = (char*)dst + dst_idx0 * dst_strides[0] + dst_idx1 * dst_strides[1];
1023
1069
  memcpy(q, p, elem_size);
@@ -1025,24 +1071,24 @@ static __global__ void array_copy_2d_kernel(void* dst, const void* src,
1025
1071
  }
1026
1072
 
1027
1073
  static __global__ void array_copy_3d_kernel(void* dst, const void* src,
1028
- wp::vec_t<3, int> dst_strides, wp::vec_t<3, int> src_strides,
1074
+ wp::vec_t<3, size_t> dst_strides, wp::vec_t<3, size_t> src_strides,
1029
1075
  wp::vec_t<3, const int*> dst_indices, wp::vec_t<3, const int*> src_indices,
1030
- wp::vec_t<3, int> shape, int elem_size)
1031
- {
1032
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1033
- int n = shape[1];
1034
- int o = shape[2];
1035
- int i = tid / (n * o);
1036
- int j = tid % (n * o) / o;
1037
- int k = tid % o;
1076
+ wp::vec_t<3, size_t> shape, size_t elem_size)
1077
+ {
1078
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1079
+ size_t n = shape[1];
1080
+ size_t o = shape[2];
1081
+ size_t i = tid / (n * o);
1082
+ size_t j = tid % (n * o) / o;
1083
+ size_t k = tid % o;
1038
1084
  if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
1039
1085
  {
1040
- int src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1041
- int dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1042
- int src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1043
- int dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1044
- int src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1045
- int dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1086
+ size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1087
+ size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1088
+ size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1089
+ size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1090
+ size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1091
+ size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1046
1092
  const char* p = (const char*)src + src_idx0 * src_strides[0]
1047
1093
  + src_idx1 * src_strides[1]
1048
1094
  + src_idx2 * src_strides[2];
@@ -1054,28 +1100,28 @@ static __global__ void array_copy_3d_kernel(void* dst, const void* src,
1054
1100
  }
1055
1101
 
1056
1102
  static __global__ void array_copy_4d_kernel(void* dst, const void* src,
1057
- wp::vec_t<4, int> dst_strides, wp::vec_t<4, int> src_strides,
1103
+ wp::vec_t<4, size_t> dst_strides, wp::vec_t<4, size_t> src_strides,
1058
1104
  wp::vec_t<4, const int*> dst_indices, wp::vec_t<4, const int*> src_indices,
1059
- wp::vec_t<4, int> shape, int elem_size)
1060
- {
1061
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1062
- int n = shape[1];
1063
- int o = shape[2];
1064
- int p = shape[3];
1065
- int i = tid / (n * o * p);
1066
- int j = tid % (n * o * p) / (o * p);
1067
- int k = tid % (o * p) / p;
1068
- int l = tid % p;
1105
+ wp::vec_t<4, size_t> shape, size_t elem_size)
1106
+ {
1107
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1108
+ size_t n = shape[1];
1109
+ size_t o = shape[2];
1110
+ size_t p = shape[3];
1111
+ size_t i = tid / (n * o * p);
1112
+ size_t j = tid % (n * o * p) / (o * p);
1113
+ size_t k = tid % (o * p) / p;
1114
+ size_t l = tid % p;
1069
1115
  if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
1070
1116
  {
1071
- int src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1072
- int dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1073
- int src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1074
- int dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1075
- int src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1076
- int dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1077
- int src_idx3 = src_indices[3] ? src_indices[3][l] : l;
1078
- int dst_idx3 = dst_indices[3] ? dst_indices[3][l] : l;
1117
+ size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1118
+ size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1119
+ size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1120
+ size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1121
+ size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1122
+ size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1123
+ size_t src_idx3 = src_indices[3] ? src_indices[3][l] : l;
1124
+ size_t dst_idx3 = dst_indices[3] ? dst_indices[3][l] : l;
1079
1125
  const char* p = (const char*)src + src_idx0 * src_strides[0]
1080
1126
  + src_idx1 * src_strides[1]
1081
1127
  + src_idx2 * src_strides[2]
@@ -1090,14 +1136,14 @@ static __global__ void array_copy_4d_kernel(void* dst, const void* src,
1090
1136
 
1091
1137
 
1092
1138
  static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src,
1093
- void* dst_data, int dst_stride, const int* dst_indices,
1094
- int elem_size)
1139
+ void* dst_data, size_t dst_stride, const int* dst_indices,
1140
+ size_t elem_size)
1095
1141
  {
1096
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1142
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1097
1143
 
1098
1144
  if (tid < src.size)
1099
1145
  {
1100
- int dst_idx = dst_indices ? dst_indices[tid] : tid;
1146
+ size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
1101
1147
  void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
1102
1148
  const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
1103
1149
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1105,15 +1151,15 @@ static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src
1105
1151
  }
1106
1152
 
1107
1153
  static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricarray_t<void> src,
1108
- void* dst_data, int dst_stride, const int* dst_indices,
1109
- int elem_size)
1154
+ void* dst_data, size_t dst_stride, const int* dst_indices,
1155
+ size_t elem_size)
1110
1156
  {
1111
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1157
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1112
1158
 
1113
1159
  if (tid < src.size)
1114
1160
  {
1115
- int src_index = src.indices[tid];
1116
- int dst_idx = dst_indices ? dst_indices[tid] : tid;
1161
+ size_t src_index = src.indices[tid];
1162
+ size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
1117
1163
  void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
1118
1164
  const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
1119
1165
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1121,14 +1167,14 @@ static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricar
1121
1167
  }
1122
1168
 
1123
1169
  static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
1124
- const void* src_data, int src_stride, const int* src_indices,
1125
- int elem_size)
1170
+ const void* src_data, size_t src_stride, const int* src_indices,
1171
+ size_t elem_size)
1126
1172
  {
1127
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1173
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1128
1174
 
1129
1175
  if (tid < dst.size)
1130
1176
  {
1131
- int src_idx = src_indices ? src_indices[tid] : tid;
1177
+ size_t src_idx = src_indices ? src_indices[tid] : tid;
1132
1178
  const void* src_ptr = (const char*)src_data + src_idx * src_stride;
1133
1179
  void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
1134
1180
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1136,25 +1182,25 @@ static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
1136
1182
  }
1137
1183
 
1138
1184
  static __global__ void array_copy_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst,
1139
- const void* src_data, int src_stride, const int* src_indices,
1140
- int elem_size)
1185
+ const void* src_data, size_t src_stride, const int* src_indices,
1186
+ size_t elem_size)
1141
1187
  {
1142
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1188
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1143
1189
 
1144
1190
  if (tid < dst.size)
1145
1191
  {
1146
- int src_idx = src_indices ? src_indices[tid] : tid;
1192
+ size_t src_idx = src_indices ? src_indices[tid] : tid;
1147
1193
  const void* src_ptr = (const char*)src_data + src_idx * src_stride;
1148
- int dst_idx = dst.indices[tid];
1194
+ size_t dst_idx = dst.indices[tid];
1149
1195
  void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
1150
1196
  memcpy(dst_ptr, src_ptr, elem_size);
1151
1197
  }
1152
1198
  }
1153
1199
 
1154
1200
 
1155
- static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src, int elem_size)
1201
+ static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
1156
1202
  {
1157
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1203
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1158
1204
 
1159
1205
  if (tid < dst.size)
1160
1206
  {
@@ -1165,27 +1211,27 @@ static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void
1165
1211
  }
1166
1212
 
1167
1213
 
1168
- static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src, int elem_size)
1214
+ static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
1169
1215
  {
1170
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1216
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1171
1217
 
1172
1218
  if (tid < dst.size)
1173
1219
  {
1174
1220
  const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
1175
- int dst_index = dst.indices[tid];
1221
+ size_t dst_index = dst.indices[tid];
1176
1222
  void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
1177
1223
  memcpy(dst_ptr, src_ptr, elem_size);
1178
1224
  }
1179
1225
  }
1180
1226
 
1181
1227
 
1182
- static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, int elem_size)
1228
+ static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
1183
1229
  {
1184
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1230
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1185
1231
 
1186
1232
  if (tid < dst.size)
1187
1233
  {
1188
- int src_index = src.indices[tid];
1234
+ size_t src_index = src.indices[tid];
1189
1235
  const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
1190
1236
  void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
1191
1237
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1193,14 +1239,14 @@ static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarra
1193
1239
  }
1194
1240
 
1195
1241
 
1196
- static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, int elem_size)
1242
+ static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
1197
1243
  {
1198
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1244
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1199
1245
 
1200
1246
  if (tid < dst.size)
1201
1247
  {
1202
- int src_index = src.indices[tid];
1203
- int dst_index = dst.indices[tid];
1248
+ size_t src_index = src.indices[tid];
1249
+ size_t dst_index = dst.indices[tid];
1204
1250
  const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
1205
1251
  void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
1206
1252
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1439,9 +1485,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
1439
1485
  }
1440
1486
  case 2:
1441
1487
  {
1442
- wp::vec_t<2, int> shape_v(src_shape[0], src_shape[1]);
1443
- wp::vec_t<2, int> src_strides_v(src_strides[0], src_strides[1]);
1444
- wp::vec_t<2, int> dst_strides_v(dst_strides[0], dst_strides[1]);
1488
+ wp::vec_t<2, size_t> shape_v(src_shape[0], src_shape[1]);
1489
+ wp::vec_t<2, size_t> src_strides_v(src_strides[0], src_strides[1]);
1490
+ wp::vec_t<2, size_t> dst_strides_v(dst_strides[0], dst_strides[1]);
1445
1491
  wp::vec_t<2, const int*> src_indices_v(src_indices[0], src_indices[1]);
1446
1492
  wp::vec_t<2, const int*> dst_indices_v(dst_indices[0], dst_indices[1]);
1447
1493
 
@@ -1453,9 +1499,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
1453
1499
  }
1454
1500
  case 3:
1455
1501
  {
1456
- wp::vec_t<3, int> shape_v(src_shape[0], src_shape[1], src_shape[2]);
1457
- wp::vec_t<3, int> src_strides_v(src_strides[0], src_strides[1], src_strides[2]);
1458
- wp::vec_t<3, int> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2]);
1502
+ wp::vec_t<3, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2]);
1503
+ wp::vec_t<3, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2]);
1504
+ wp::vec_t<3, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2]);
1459
1505
  wp::vec_t<3, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2]);
1460
1506
  wp::vec_t<3, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2]);
1461
1507
 
@@ -1467,9 +1513,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
1467
1513
  }
1468
1514
  case 4:
1469
1515
  {
1470
- wp::vec_t<4, int> shape_v(src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
1471
- wp::vec_t<4, int> src_strides_v(src_strides[0], src_strides[1], src_strides[2], src_strides[3]);
1472
- wp::vec_t<4, int> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2], dst_strides[3]);
1516
+ wp::vec_t<4, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
1517
+ wp::vec_t<4, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2], src_strides[3]);
1518
+ wp::vec_t<4, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2], dst_strides[3]);
1473
1519
  wp::vec_t<4, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2], src_indices[3]);
1474
1520
  wp::vec_t<4, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2], dst_indices[3]);
1475
1521
 
@@ -1489,94 +1535,94 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
1489
1535
 
1490
1536
 
1491
1537
  static __global__ void array_fill_1d_kernel(void* data,
1492
- int n,
1493
- int stride,
1538
+ size_t n,
1539
+ size_t stride,
1494
1540
  const int* indices,
1495
1541
  const void* value,
1496
- int value_size)
1542
+ size_t value_size)
1497
1543
  {
1498
- int i = blockIdx.x * blockDim.x + threadIdx.x;
1544
+ size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1499
1545
  if (i < n)
1500
1546
  {
1501
- int idx = indices ? indices[i] : i;
1547
+ size_t idx = indices ? indices[i] : i;
1502
1548
  char* p = (char*)data + idx * stride;
1503
1549
  memcpy(p, value, value_size);
1504
1550
  }
1505
1551
  }
1506
1552
 
1507
1553
  static __global__ void array_fill_2d_kernel(void* data,
1508
- wp::vec_t<2, int> shape,
1509
- wp::vec_t<2, int> strides,
1554
+ wp::vec_t<2, size_t> shape,
1555
+ wp::vec_t<2, size_t> strides,
1510
1556
  wp::vec_t<2, const int*> indices,
1511
1557
  const void* value,
1512
- int value_size)
1558
+ size_t value_size)
1513
1559
  {
1514
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1515
- int n = shape[1];
1516
- int i = tid / n;
1517
- int j = tid % n;
1560
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1561
+ size_t n = shape[1];
1562
+ size_t i = tid / n;
1563
+ size_t j = tid % n;
1518
1564
  if (i < shape[0] /*&& j < shape[1]*/)
1519
1565
  {
1520
- int idx0 = indices[0] ? indices[0][i] : i;
1521
- int idx1 = indices[1] ? indices[1][j] : j;
1566
+ size_t idx0 = indices[0] ? indices[0][i] : i;
1567
+ size_t idx1 = indices[1] ? indices[1][j] : j;
1522
1568
  char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1];
1523
1569
  memcpy(p, value, value_size);
1524
1570
  }
1525
1571
  }
1526
1572
 
1527
1573
  static __global__ void array_fill_3d_kernel(void* data,
1528
- wp::vec_t<3, int> shape,
1529
- wp::vec_t<3, int> strides,
1574
+ wp::vec_t<3, size_t> shape,
1575
+ wp::vec_t<3, size_t> strides,
1530
1576
  wp::vec_t<3, const int*> indices,
1531
1577
  const void* value,
1532
- int value_size)
1533
- {
1534
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1535
- int n = shape[1];
1536
- int o = shape[2];
1537
- int i = tid / (n * o);
1538
- int j = tid % (n * o) / o;
1539
- int k = tid % o;
1578
+ size_t value_size)
1579
+ {
1580
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1581
+ size_t n = shape[1];
1582
+ size_t o = shape[2];
1583
+ size_t i = tid / (n * o);
1584
+ size_t j = tid % (n * o) / o;
1585
+ size_t k = tid % o;
1540
1586
  if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
1541
1587
  {
1542
- int idx0 = indices[0] ? indices[0][i] : i;
1543
- int idx1 = indices[1] ? indices[1][j] : j;
1544
- int idx2 = indices[2] ? indices[2][k] : k;
1588
+ size_t idx0 = indices[0] ? indices[0][i] : i;
1589
+ size_t idx1 = indices[1] ? indices[1][j] : j;
1590
+ size_t idx2 = indices[2] ? indices[2][k] : k;
1545
1591
  char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2];
1546
1592
  memcpy(p, value, value_size);
1547
1593
  }
1548
1594
  }
1549
1595
 
1550
1596
  static __global__ void array_fill_4d_kernel(void* data,
1551
- wp::vec_t<4, int> shape,
1552
- wp::vec_t<4, int> strides,
1597
+ wp::vec_t<4, size_t> shape,
1598
+ wp::vec_t<4, size_t> strides,
1553
1599
  wp::vec_t<4, const int*> indices,
1554
1600
  const void* value,
1555
- int value_size)
1556
- {
1557
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1558
- int n = shape[1];
1559
- int o = shape[2];
1560
- int p = shape[3];
1561
- int i = tid / (n * o * p);
1562
- int j = tid % (n * o * p) / (o * p);
1563
- int k = tid % (o * p) / p;
1564
- int l = tid % p;
1601
+ size_t value_size)
1602
+ {
1603
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1604
+ size_t n = shape[1];
1605
+ size_t o = shape[2];
1606
+ size_t p = shape[3];
1607
+ size_t i = tid / (n * o * p);
1608
+ size_t j = tid % (n * o * p) / (o * p);
1609
+ size_t k = tid % (o * p) / p;
1610
+ size_t l = tid % p;
1565
1611
  if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
1566
1612
  {
1567
- int idx0 = indices[0] ? indices[0][i] : i;
1568
- int idx1 = indices[1] ? indices[1][j] : j;
1569
- int idx2 = indices[2] ? indices[2][k] : k;
1570
- int idx3 = indices[3] ? indices[3][l] : l;
1613
+ size_t idx0 = indices[0] ? indices[0][i] : i;
1614
+ size_t idx1 = indices[1] ? indices[1][j] : j;
1615
+ size_t idx2 = indices[2] ? indices[2][k] : k;
1616
+ size_t idx3 = indices[3] ? indices[3][l] : l;
1571
1617
  char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2] + idx3 * strides[3];
1572
1618
  memcpy(p, value, value_size);
1573
1619
  }
1574
1620
  }
1575
1621
 
1576
1622
 
1577
- static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value, int value_size)
1623
+ static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value, size_t value_size)
1578
1624
  {
1579
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1625
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1580
1626
  if (tid < fa.size)
1581
1627
  {
1582
1628
  void* dst_ptr = fabricarray_element_ptr(fa, tid, value_size);
@@ -1585,9 +1631,9 @@ static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, cons
1585
1631
  }
1586
1632
 
1587
1633
 
1588
- static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value, int value_size)
1634
+ static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value, size_t value_size)
1589
1635
  {
1590
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1636
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1591
1637
  if (tid < ifa.size)
1592
1638
  {
1593
1639
  size_t idx = size_t(ifa.indices[tid]);
@@ -1684,8 +1730,8 @@ WP_API void wp_array_fill_device(void* context, void* arr_ptr, int arr_type, con
1684
1730
  }
1685
1731
  case 2:
1686
1732
  {
1687
- wp::vec_t<2, int> shape_v(shape[0], shape[1]);
1688
- wp::vec_t<2, int> strides_v(strides[0], strides[1]);
1733
+ wp::vec_t<2, size_t> shape_v(shape[0], shape[1]);
1734
+ wp::vec_t<2, size_t> strides_v(strides[0], strides[1]);
1689
1735
  wp::vec_t<2, const int*> indices_v(indices[0], indices[1]);
1690
1736
  wp_launch_device(WP_CURRENT_CONTEXT, array_fill_2d_kernel, n,
1691
1737
  (data, shape_v, strides_v, indices_v, value_devptr, value_size));
@@ -1693,8 +1739,8 @@ WP_API void wp_array_fill_device(void* context, void* arr_ptr, int arr_type, con
1693
1739
  }
1694
1740
  case 3:
1695
1741
  {
1696
- wp::vec_t<3, int> shape_v(shape[0], shape[1], shape[2]);
1697
- wp::vec_t<3, int> strides_v(strides[0], strides[1], strides[2]);
1742
+ wp::vec_t<3, size_t> shape_v(shape[0], shape[1], shape[2]);
1743
+ wp::vec_t<3, size_t> strides_v(strides[0], strides[1], strides[2]);
1698
1744
  wp::vec_t<3, const int*> indices_v(indices[0], indices[1], indices[2]);
1699
1745
  wp_launch_device(WP_CURRENT_CONTEXT, array_fill_3d_kernel, n,
1700
1746
  (data, shape_v, strides_v, indices_v, value_devptr, value_size));
@@ -1702,8 +1748,8 @@ WP_API void wp_array_fill_device(void* context, void* arr_ptr, int arr_type, con
1702
1748
  }
1703
1749
  case 4:
1704
1750
  {
1705
- wp::vec_t<4, int> shape_v(shape[0], shape[1], shape[2], shape[3]);
1706
- wp::vec_t<4, int> strides_v(strides[0], strides[1], strides[2], strides[3]);
1751
+ wp::vec_t<4, size_t> shape_v(shape[0], shape[1], shape[2], shape[3]);
1752
+ wp::vec_t<4, size_t> strides_v(strides[0], strides[1], strides[2], strides[3]);
1707
1753
  wp::vec_t<4, const int*> indices_v(indices[0], indices[1], indices[2], indices[3]);
1708
1754
  wp_launch_device(WP_CURRENT_CONTEXT, array_fill_4d_kernel, n,
1709
1755
  (data, shape_v, strides_v, indices_v, value_devptr, value_size));
@@ -2071,13 +2117,17 @@ void wp_cuda_context_synchronize(void* context)
2071
2117
 
2072
2118
  check_cu(cuCtxSynchronize_f());
2073
2119
 
2074
- if (free_deferred_allocs(context ? context : get_current_context()) > 0)
2120
+ if (!context)
2121
+ context = get_current_context();
2122
+
2123
+ if (free_deferred_allocs(context) > 0)
2075
2124
  {
2076
2125
  // ensure deferred asynchronous deallocations complete
2077
2126
  check_cu(cuCtxSynchronize_f());
2078
2127
  }
2079
2128
 
2080
2129
  unload_deferred_modules(context);
2130
+ destroy_deferred_graphs(context);
2081
2131
 
2082
2132
  // check_cuda(cudaDeviceGraphMemTrim(wp_cuda_context_get_device_ordinal(context)));
2083
2133
  }
@@ -2448,6 +2498,9 @@ void wp_cuda_stream_destroy(void* context, void* stream)
2448
2498
 
2449
2499
  wp_cuda_stream_unregister(context, stream);
2450
2500
 
2501
+ // release temporary radix sort buffer associated with this stream
2502
+ radix_sort_release(context, stream);
2503
+
2451
2504
  check_cu(cuStreamDestroy_f(static_cast<CUstream>(stream)));
2452
2505
  }
2453
2506
 
@@ -2510,15 +2563,36 @@ void wp_cuda_stream_synchronize(void* stream)
2510
2563
  check_cu(cuStreamSynchronize_f(static_cast<CUstream>(stream)));
2511
2564
  }
2512
2565
 
2513
- void wp_cuda_stream_wait_event(void* stream, void* event)
2566
+ void wp_cuda_stream_wait_event(void* stream, void* event, bool external)
2514
2567
  {
2515
- check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), 0));
2568
+ // the external flag can only be used during graph capture
2569
+ if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
2570
+ {
2571
+ // wait for an external event during graph capture
2572
+ check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_EXTERNAL));
2573
+ }
2574
+ else
2575
+ {
2576
+ check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_DEFAULT));
2577
+ }
2516
2578
  }
2517
2579
 
2518
- void wp_cuda_stream_wait_stream(void* stream, void* other_stream, void* event)
2580
+ void wp_cuda_stream_wait_stream(void* stream, void* other_stream, void* event, bool external)
2519
2581
  {
2520
- check_cu(cuEventRecord_f(static_cast<CUevent>(event), static_cast<CUstream>(other_stream)));
2521
- check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), 0));
2582
+ unsigned record_flags = CU_EVENT_RECORD_DEFAULT;
2583
+ unsigned wait_flags = CU_EVENT_WAIT_DEFAULT;
2584
+
2585
+ // the external flag can only be used during graph capture
2586
+ if (external && !g_captures.empty())
2587
+ {
2588
+ if (wp_cuda_stream_is_capturing(other_stream))
2589
+ record_flags = CU_EVENT_RECORD_EXTERNAL;
2590
+ if (wp_cuda_stream_is_capturing(stream))
2591
+ wait_flags = CU_EVENT_WAIT_EXTERNAL;
2592
+ }
2593
+
2594
+ check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(other_stream), record_flags));
2595
+ check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), wait_flags));
2522
2596
  }
2523
2597
 
2524
2598
  int wp_cuda_stream_is_capturing(void* stream)
@@ -2571,11 +2645,12 @@ int wp_cuda_event_query(void* event)
2571
2645
  return res;
2572
2646
  }
2573
2647
 
2574
- void wp_cuda_event_record(void* event, void* stream, bool timing)
2648
+ void wp_cuda_event_record(void* event, void* stream, bool external)
2575
2649
  {
2576
- if (timing && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
2650
+ // the external flag can only be used during graph capture
2651
+ if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
2577
2652
  {
2578
- // record timing event during graph capture
2653
+ // record external event during graph capture (e.g., for timing or when explicitly specified by the user)
2579
2654
  check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(stream), CU_EVENT_RECORD_EXTERNAL));
2580
2655
  }
2581
2656
  else
@@ -2625,7 +2700,7 @@ bool wp_cuda_graph_begin_capture(void* context, void* stream, int external)
2625
2700
  else
2626
2701
  {
2627
2702
  // start the capture
2628
- if (!check_cuda(cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeGlobal)))
2703
+ if (!check_cuda(cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeThreadLocal)))
2629
2704
  return false;
2630
2705
  }
2631
2706
 
@@ -2772,6 +2847,7 @@ bool wp_cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
2772
2847
  {
2773
2848
  free_deferred_allocs();
2774
2849
  unload_deferred_modules();
2850
+ destroy_deferred_graphs();
2775
2851
  }
2776
2852
 
2777
2853
  if (graph_ret)
@@ -2811,11 +2887,12 @@ bool wp_cuda_graph_create_exec(void* context, void* stream, void* graph, void**
2811
2887
  // Support for conditional graph nodes available with CUDA 12.4+.
2812
2888
  #if CUDA_VERSION >= 12040
2813
2889
 
2814
- // CUBIN data for compiled conditional modules, loaded on demand, keyed on device architecture
2815
- static std::map<int, void*> g_conditional_cubins;
2890
+ // CUBIN or PTX data for compiled conditional modules, loaded on demand, keyed on device architecture
2891
+ using ModuleKey = std::pair<int, bool>; // <arch, use_ptx>
2892
+ static std::map<ModuleKey, void*> g_conditional_modules;
2816
2893
 
2817
2894
  // Compile module with conditional helper kernels
2818
- static void* compile_conditional_module(int arch)
2895
+ static void* compile_conditional_module(int arch, bool use_ptx)
2819
2896
  {
2820
2897
  static const char* kernel_source = R"(
2821
2898
  typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
@@ -2844,8 +2921,9 @@ static void* compile_conditional_module(int arch)
2844
2921
  )";
2845
2922
 
2846
2923
  // avoid recompilation
2847
- auto it = g_conditional_cubins.find(arch);
2848
- if (it != g_conditional_cubins.end())
2924
+ ModuleKey key = {arch, use_ptx};
2925
+ auto it = g_conditional_modules.find(key);
2926
+ if (it != g_conditional_modules.end())
2849
2927
  return it->second;
2850
2928
 
2851
2929
  nvrtcProgram prog;
@@ -2853,11 +2931,23 @@ static void* compile_conditional_module(int arch)
2853
2931
  return NULL;
2854
2932
 
2855
2933
  char arch_opt[128];
2856
- snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=sm_%d", arch);
2934
+ if (use_ptx)
2935
+ snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=compute_%d", arch);
2936
+ else
2937
+ snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=sm_%d", arch);
2857
2938
 
2858
2939
  std::vector<const char*> opts;
2859
2940
  opts.push_back(arch_opt);
2860
2941
 
2942
+ const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);
2943
+ if (print_debug)
2944
+ {
2945
+ printf("NVRTC options (conditional module, arch=%d, use_ptx=%s):\n", arch, use_ptx ? "true" : "false");
2946
+ for(auto o: opts) {
2947
+ printf("%s\n", o);
2948
+ }
2949
+ }
2950
+
2861
2951
  if (!check_nvrtc(nvrtcCompileProgram(prog, int(opts.size()), opts.data())))
2862
2952
  {
2863
2953
  size_t log_size;
@@ -2874,23 +2964,37 @@ static void* compile_conditional_module(int arch)
2874
2964
  // get output
2875
2965
  char* output = NULL;
2876
2966
  size_t output_size = 0;
2877
- check_nvrtc(nvrtcGetCUBINSize(prog, &output_size));
2878
- if (output_size > 0)
2967
+
2968
+ if (use_ptx)
2969
+ {
2970
+ check_nvrtc(nvrtcGetPTXSize(prog, &output_size));
2971
+ if (output_size > 0)
2972
+ {
2973
+ output = new char[output_size];
2974
+ if (check_nvrtc(nvrtcGetPTX(prog, output)))
2975
+ g_conditional_modules[key] = output;
2976
+ }
2977
+ }
2978
+ else
2879
2979
  {
2880
- output = new char[output_size];
2881
- if (check_nvrtc(nvrtcGetCUBIN(prog, output)))
2882
- g_conditional_cubins[arch] = output;
2980
+ check_nvrtc(nvrtcGetCUBINSize(prog, &output_size));
2981
+ if (output_size > 0)
2982
+ {
2983
+ output = new char[output_size];
2984
+ if (check_nvrtc(nvrtcGetCUBIN(prog, output)))
2985
+ g_conditional_modules[key] = output;
2986
+ }
2883
2987
  }
2884
2988
 
2885
2989
  nvrtcDestroyProgram(&prog);
2886
2990
 
2887
- // return CUBIN data
2991
+ // return CUBIN or PTX data
2888
2992
  return output;
2889
2993
  }
2890
2994
 
2891
2995
 
2892
2996
  // Load module with conditional helper kernels
2893
- static CUmodule load_conditional_module(void* context)
2997
+ static CUmodule load_conditional_module(void* context, int arch, bool use_ptx)
2894
2998
  {
2895
2999
  ContextInfo* context_info = get_context_info(context);
2896
3000
  if (!context_info)
@@ -2900,17 +3004,15 @@ static CUmodule load_conditional_module(void* context)
2900
3004
  if (context_info->conditional_module)
2901
3005
  return context_info->conditional_module;
2902
3006
 
2903
- int arch = context_info->device_info->arch;
2904
-
2905
3007
  // compile if needed
2906
- void* compiled_module = compile_conditional_module(arch);
3008
+ void* compiled_module = compile_conditional_module(arch, use_ptx);
2907
3009
  if (!compiled_module)
2908
3010
  {
2909
3011
  fprintf(stderr, "Warp error: Failed to compile conditional kernels\n");
2910
3012
  return NULL;
2911
3013
  }
2912
3014
 
2913
- // load module
3015
+ // load module (handles both PTX and CUBIN data automatically)
2914
3016
  CUmodule module = NULL;
2915
3017
  if (!check_cu(cuModuleLoadDataEx_f(&module, compiled_module, 0, NULL, NULL)))
2916
3018
  {
@@ -2923,10 +3025,10 @@ static CUmodule load_conditional_module(void* context)
2923
3025
  return module;
2924
3026
  }
2925
3027
 
2926
- static CUfunction get_conditional_kernel(void* context, const char* name)
3028
+ static CUfunction get_conditional_kernel(void* context, int arch, bool use_ptx, const char* name)
2927
3029
  {
2928
3030
  // load module if needed
2929
- CUmodule module = load_conditional_module(context);
3031
+ CUmodule module = load_conditional_module(context, arch, use_ptx);
2930
3032
  if (!module)
2931
3033
  return NULL;
2932
3034
 
@@ -2966,7 +3068,7 @@ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
2966
3068
  leaf_nodes.data(),
2967
3069
  nullptr,
2968
3070
  leaf_nodes.size(),
2969
- cudaStreamCaptureModeGlobal)))
3071
+ cudaStreamCaptureModeThreadLocal)))
2970
3072
  return false;
2971
3073
 
2972
3074
  return true;
@@ -2976,7 +3078,7 @@ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
2976
3078
  // https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes/
2977
3079
  // condition is a gpu pointer
2978
3080
  // if_graph_ret and else_graph_ret should be NULL if not needed
2979
- bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
3081
+ bool wp_cuda_graph_insert_if_else(void* context, void* stream, int arch, bool use_ptx, int* condition, void** if_graph_ret, void** else_graph_ret)
2980
3082
  {
2981
3083
  bool has_if = if_graph_ret != NULL;
2982
3084
  bool has_else = else_graph_ret != NULL;
@@ -3019,9 +3121,9 @@ bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, v
3019
3121
  // (need to negate the condition if only the else branch is used)
3020
3122
  CUfunction kernel;
3021
3123
  if (has_if)
3022
- kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
3124
+ kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
3023
3125
  else
3024
- kernel = get_conditional_kernel(context, "set_conditional_else_handle_kernel");
3126
+ kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_else_handle_kernel");
3025
3127
 
3026
3128
  if (!kernel)
3027
3129
  {
@@ -3072,7 +3174,7 @@ bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, v
3072
3174
  check_cuda(cudaGraphConditionalHandleCreate(&if_handle, cuda_graph));
3073
3175
  check_cuda(cudaGraphConditionalHandleCreate(&else_handle, cuda_graph));
3074
3176
 
3075
- CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_else_handles_kernel");
3177
+ CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_else_handles_kernel");
3076
3178
  if (!kernel)
3077
3179
  {
3078
3180
  wp::set_error_string("Failed to get built-in conditional kernel");
@@ -3273,7 +3375,7 @@ bool wp_cuda_graph_insert_child_graph(void* context, void* stream, void* child_g
3273
3375
  return true;
3274
3376
  }
3275
3377
 
3276
- bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
3378
+ bool wp_cuda_graph_insert_while(void* context, void* stream, int arch, bool use_ptx, int* condition, void** body_graph_ret, uint64_t* handle_ret)
3277
3379
  {
3278
3380
  // if there's no body, it's a no-op
3279
3381
  if (!body_graph_ret)
@@ -3303,7 +3405,7 @@ bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, voi
3303
3405
  return false;
3304
3406
 
3305
3407
  // launch a kernel to set the condition handle from condition pointer
3306
- CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
3408
+ CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
3307
3409
  if (!kernel)
3308
3410
  {
3309
3411
  wp::set_error_string("Failed to get built-in conditional kernel");
@@ -3339,14 +3441,14 @@ bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, voi
3339
3441
  return true;
3340
3442
  }
3341
3443
 
3342
- bool wp_cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
3444
+ bool wp_cuda_graph_set_condition(void* context, void* stream, int arch, bool use_ptx, int* condition, uint64_t handle)
3343
3445
  {
3344
3446
  ContextGuard guard(context);
3345
3447
 
3346
3448
  CUstream cuda_stream = static_cast<CUstream>(stream);
3347
3449
 
3348
3450
  // launch a kernel to set the condition handle from condition pointer
3349
- CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
3451
+ CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
3350
3452
  if (!kernel)
3351
3453
  {
3352
3454
  wp::set_error_string("Failed to get built-in conditional kernel");
@@ -3378,19 +3480,19 @@ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
3378
3480
  return false;
3379
3481
  }
3380
3482
 
3381
- bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
3483
+ bool wp_cuda_graph_insert_if_else(void* context, void* stream, int arch, bool use_ptx, int* condition, void** if_graph_ret, void** else_graph_ret)
3382
3484
  {
3383
3485
  wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
3384
3486
  return false;
3385
3487
  }
3386
3488
 
3387
- bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
3489
+ bool wp_cuda_graph_insert_while(void* context, void* stream, int arch, bool use_ptx, int* condition, void** body_graph_ret, uint64_t* handle_ret)
3388
3490
  {
3389
3491
  wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
3390
3492
  return false;
3391
3493
  }
3392
3494
 
3393
- bool wp_cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
3495
+ bool wp_cuda_graph_set_condition(void* context, void* stream, int arch, bool use_ptx, int* condition, uint64_t handle)
3394
3496
  {
3395
3497
  wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
3396
3498
  return false;
@@ -3425,16 +3527,38 @@ bool wp_cuda_graph_launch(void* graph_exec, void* stream)
3425
3527
 
3426
3528
  bool wp_cuda_graph_destroy(void* context, void* graph)
3427
3529
  {
3428
- ContextGuard guard(context);
3429
-
3430
- return check_cuda(cudaGraphDestroy((cudaGraph_t)graph));
3530
+ // ensure there are no graph captures in progress
3531
+ if (g_captures.empty())
3532
+ {
3533
+ ContextGuard guard(context);
3534
+ return check_cuda(cudaGraphDestroy((cudaGraph_t)graph));
3535
+ }
3536
+ else
3537
+ {
3538
+ GraphDestroyInfo info;
3539
+ info.context = context ? context : get_current_context();
3540
+ info.graph = graph;
3541
+ g_deferred_graph_list.push_back(info);
3542
+ return true;
3543
+ }
3431
3544
  }
3432
3545
 
3433
3546
  bool wp_cuda_graph_exec_destroy(void* context, void* graph_exec)
3434
3547
  {
3435
- ContextGuard guard(context);
3436
-
3437
- return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
3548
+ // ensure there are no graph captures in progress
3549
+ if (g_captures.empty())
3550
+ {
3551
+ ContextGuard guard(context);
3552
+ return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
3553
+ }
3554
+ else
3555
+ {
3556
+ GraphDestroyInfo info;
3557
+ info.context = context ? context : get_current_context();
3558
+ info.graph_exec = graph_exec;
3559
+ g_deferred_graph_list.push_back(info);
3560
+ return true;
3561
+ }
3438
3562
  }
3439
3563
 
3440
3564
  bool write_file(const char* data, size_t size, std::string filename, const char* mode)
@@ -4287,17 +4411,5 @@ void wp_cuda_timing_end(timing_result_t* results, int size)
4287
4411
  g_cuda_timing_state = parent_state;
4288
4412
  }
4289
4413
 
4290
- // impl. files
4291
- #include "bvh.cu"
4292
- #include "mesh.cu"
4293
- #include "sort.cu"
4294
- #include "hashgrid.cu"
4295
- #include "reduce.cu"
4296
- #include "runlength_encode.cu"
4297
- #include "scan.cu"
4298
- #include "sparse.cu"
4299
- #include "volume.cu"
4300
- #include "volume_builder.cu"
4301
-
4302
4414
  //#include "spline.inl"
4303
4415
  //#include "volume.inl"