warp-lang 1.9.1__py3-none-manylinux_2_34_aarch64.whl → 1.10.0rc2__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (346) hide show
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +794 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1382 -377
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +1 -1
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +253 -171
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +14 -14
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +527 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/native/warp.cu CHANGED
@@ -222,6 +222,14 @@ struct ModuleInfo
222
222
  void* module = NULL;
223
223
  };
224
224
 
225
+ // Information used when deferring graph destruction.
226
+ struct GraphDestroyInfo
227
+ {
228
+ void* context = NULL;
229
+ void* graph = NULL;
230
+ void* graph_exec = NULL;
231
+ };
232
+
225
233
  static std::unordered_map<CUfunction, std::string> g_kernel_names;
226
234
 
227
235
  // cached info for all devices, indexed by ordinal
@@ -253,6 +261,11 @@ static std::vector<FreeInfo> g_deferred_free_list;
253
261
  // Call unload_deferred_modules() to release.
254
262
  static std::vector<ModuleInfo> g_deferred_module_list;
255
263
 
264
+ // Graphs that cannot be destroyed immediately get queued here.
265
+ // Call destroy_deferred_graphs() to release.
266
+ static std::vector<GraphDestroyInfo> g_deferred_graph_list;
267
+
268
+
256
269
  void wp_cuda_set_context_restore_policy(bool always_restore)
257
270
  {
258
271
  ContextGuard::always_restore = always_restore;
@@ -338,7 +351,7 @@ int cuda_init()
338
351
  }
339
352
 
340
353
 
341
- static inline CUcontext get_current_context()
354
+ CUcontext get_current_context()
342
355
  {
343
356
  CUcontext ctx;
344
357
  if (check_cu(cuCtxGetCurrent_f(&ctx)))
@@ -495,6 +508,38 @@ static int unload_deferred_modules(void* context = NULL)
495
508
  return num_unloaded_modules;
496
509
  }
497
510
 
511
+ static int destroy_deferred_graphs(void* context = NULL)
512
+ {
513
+ if (g_deferred_graph_list.empty() || !g_captures.empty())
514
+ return 0;
515
+
516
+ int num_destroyed_graphs = 0;
517
+ for (auto it = g_deferred_graph_list.begin(); it != g_deferred_graph_list.end(); /*noop*/)
518
+ {
519
+ // destroy the graph if it matches the given context or if the context is unspecified
520
+ const GraphDestroyInfo& graph_info = *it;
521
+ if (graph_info.context == context || !context)
522
+ {
523
+ if (graph_info.graph)
524
+ {
525
+ check_cuda(cudaGraphDestroy((cudaGraph_t)graph_info.graph));
526
+ }
527
+ if (graph_info.graph_exec)
528
+ {
529
+ check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_info.graph_exec));
530
+ }
531
+ ++num_destroyed_graphs;
532
+ it = g_deferred_graph_list.erase(it);
533
+ }
534
+ else
535
+ {
536
+ ++it;
537
+ }
538
+ }
539
+
540
+ return num_destroyed_graphs;
541
+ }
542
+
498
543
  static void CUDART_CB on_graph_destroy(void* user_data)
499
544
  {
500
545
  if (!user_data)
@@ -989,15 +1034,15 @@ void wp_memtile_device(void* context, void* dst, const void* src, size_t srcsize
989
1034
 
990
1035
 
991
1036
  static __global__ void array_copy_1d_kernel(void* dst, const void* src,
992
- int dst_stride, int src_stride,
1037
+ size_t dst_stride, size_t src_stride,
993
1038
  const int* dst_indices, const int* src_indices,
994
- int n, int elem_size)
1039
+ size_t n, size_t elem_size)
995
1040
  {
996
- int i = blockIdx.x * blockDim.x + threadIdx.x;
1041
+ size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
997
1042
  if (i < n)
998
1043
  {
999
- int src_idx = src_indices ? src_indices[i] : i;
1000
- int dst_idx = dst_indices ? dst_indices[i] : i;
1044
+ size_t src_idx = src_indices ? src_indices[i] : i;
1045
+ size_t dst_idx = dst_indices ? dst_indices[i] : i;
1001
1046
  const char* p = (const char*)src + src_idx * src_stride;
1002
1047
  char* q = (char*)dst + dst_idx * dst_stride;
1003
1048
  memcpy(q, p, elem_size);
@@ -1005,20 +1050,20 @@ static __global__ void array_copy_1d_kernel(void* dst, const void* src,
1005
1050
  }
1006
1051
 
1007
1052
  static __global__ void array_copy_2d_kernel(void* dst, const void* src,
1008
- wp::vec_t<2, int> dst_strides, wp::vec_t<2, int> src_strides,
1053
+ wp::vec_t<2, size_t> dst_strides, wp::vec_t<2, size_t> src_strides,
1009
1054
  wp::vec_t<2, const int*> dst_indices, wp::vec_t<2, const int*> src_indices,
1010
- wp::vec_t<2, int> shape, int elem_size)
1055
+ wp::vec_t<2, size_t> shape, size_t elem_size)
1011
1056
  {
1012
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1013
- int n = shape[1];
1014
- int i = tid / n;
1015
- int j = tid % n;
1057
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1058
+ size_t n = shape[1];
1059
+ size_t i = tid / n;
1060
+ size_t j = tid % n;
1016
1061
  if (i < shape[0] /*&& j < shape[1]*/)
1017
1062
  {
1018
- int src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1019
- int dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1020
- int src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1021
- int dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1063
+ size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1064
+ size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1065
+ size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1066
+ size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1022
1067
  const char* p = (const char*)src + src_idx0 * src_strides[0] + src_idx1 * src_strides[1];
1023
1068
  char* q = (char*)dst + dst_idx0 * dst_strides[0] + dst_idx1 * dst_strides[1];
1024
1069
  memcpy(q, p, elem_size);
@@ -1026,24 +1071,24 @@ static __global__ void array_copy_2d_kernel(void* dst, const void* src,
1026
1071
  }
1027
1072
 
1028
1073
  static __global__ void array_copy_3d_kernel(void* dst, const void* src,
1029
- wp::vec_t<3, int> dst_strides, wp::vec_t<3, int> src_strides,
1074
+ wp::vec_t<3, size_t> dst_strides, wp::vec_t<3, size_t> src_strides,
1030
1075
  wp::vec_t<3, const int*> dst_indices, wp::vec_t<3, const int*> src_indices,
1031
- wp::vec_t<3, int> shape, int elem_size)
1032
- {
1033
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1034
- int n = shape[1];
1035
- int o = shape[2];
1036
- int i = tid / (n * o);
1037
- int j = tid % (n * o) / o;
1038
- int k = tid % o;
1076
+ wp::vec_t<3, size_t> shape, size_t elem_size)
1077
+ {
1078
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1079
+ size_t n = shape[1];
1080
+ size_t o = shape[2];
1081
+ size_t i = tid / (n * o);
1082
+ size_t j = tid % (n * o) / o;
1083
+ size_t k = tid % o;
1039
1084
  if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
1040
1085
  {
1041
- int src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1042
- int dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1043
- int src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1044
- int dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1045
- int src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1046
- int dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1086
+ size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1087
+ size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1088
+ size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1089
+ size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1090
+ size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1091
+ size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1047
1092
  const char* p = (const char*)src + src_idx0 * src_strides[0]
1048
1093
  + src_idx1 * src_strides[1]
1049
1094
  + src_idx2 * src_strides[2];
@@ -1055,28 +1100,28 @@ static __global__ void array_copy_3d_kernel(void* dst, const void* src,
1055
1100
  }
1056
1101
 
1057
1102
  static __global__ void array_copy_4d_kernel(void* dst, const void* src,
1058
- wp::vec_t<4, int> dst_strides, wp::vec_t<4, int> src_strides,
1103
+ wp::vec_t<4, size_t> dst_strides, wp::vec_t<4, size_t> src_strides,
1059
1104
  wp::vec_t<4, const int*> dst_indices, wp::vec_t<4, const int*> src_indices,
1060
- wp::vec_t<4, int> shape, int elem_size)
1061
- {
1062
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1063
- int n = shape[1];
1064
- int o = shape[2];
1065
- int p = shape[3];
1066
- int i = tid / (n * o * p);
1067
- int j = tid % (n * o * p) / (o * p);
1068
- int k = tid % (o * p) / p;
1069
- int l = tid % p;
1105
+ wp::vec_t<4, size_t> shape, size_t elem_size)
1106
+ {
1107
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1108
+ size_t n = shape[1];
1109
+ size_t o = shape[2];
1110
+ size_t p = shape[3];
1111
+ size_t i = tid / (n * o * p);
1112
+ size_t j = tid % (n * o * p) / (o * p);
1113
+ size_t k = tid % (o * p) / p;
1114
+ size_t l = tid % p;
1070
1115
  if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
1071
1116
  {
1072
- int src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1073
- int dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1074
- int src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1075
- int dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1076
- int src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1077
- int dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1078
- int src_idx3 = src_indices[3] ? src_indices[3][l] : l;
1079
- int dst_idx3 = dst_indices[3] ? dst_indices[3][l] : l;
1117
+ size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1118
+ size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1119
+ size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1120
+ size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1121
+ size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1122
+ size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1123
+ size_t src_idx3 = src_indices[3] ? src_indices[3][l] : l;
1124
+ size_t dst_idx3 = dst_indices[3] ? dst_indices[3][l] : l;
1080
1125
  const char* p = (const char*)src + src_idx0 * src_strides[0]
1081
1126
  + src_idx1 * src_strides[1]
1082
1127
  + src_idx2 * src_strides[2]
@@ -1091,14 +1136,14 @@ static __global__ void array_copy_4d_kernel(void* dst, const void* src,
1091
1136
 
1092
1137
 
1093
1138
  static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src,
1094
- void* dst_data, int dst_stride, const int* dst_indices,
1095
- int elem_size)
1139
+ void* dst_data, size_t dst_stride, const int* dst_indices,
1140
+ size_t elem_size)
1096
1141
  {
1097
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1142
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1098
1143
 
1099
1144
  if (tid < src.size)
1100
1145
  {
1101
- int dst_idx = dst_indices ? dst_indices[tid] : tid;
1146
+ size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
1102
1147
  void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
1103
1148
  const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
1104
1149
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1106,15 +1151,15 @@ static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src
1106
1151
  }
1107
1152
 
1108
1153
  static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricarray_t<void> src,
1109
- void* dst_data, int dst_stride, const int* dst_indices,
1110
- int elem_size)
1154
+ void* dst_data, size_t dst_stride, const int* dst_indices,
1155
+ size_t elem_size)
1111
1156
  {
1112
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1157
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1113
1158
 
1114
1159
  if (tid < src.size)
1115
1160
  {
1116
- int src_index = src.indices[tid];
1117
- int dst_idx = dst_indices ? dst_indices[tid] : tid;
1161
+ size_t src_index = src.indices[tid];
1162
+ size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
1118
1163
  void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
1119
1164
  const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
1120
1165
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1122,14 +1167,14 @@ static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricar
1122
1167
  }
1123
1168
 
1124
1169
  static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
1125
- const void* src_data, int src_stride, const int* src_indices,
1126
- int elem_size)
1170
+ const void* src_data, size_t src_stride, const int* src_indices,
1171
+ size_t elem_size)
1127
1172
  {
1128
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1173
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1129
1174
 
1130
1175
  if (tid < dst.size)
1131
1176
  {
1132
- int src_idx = src_indices ? src_indices[tid] : tid;
1177
+ size_t src_idx = src_indices ? src_indices[tid] : tid;
1133
1178
  const void* src_ptr = (const char*)src_data + src_idx * src_stride;
1134
1179
  void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
1135
1180
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1137,25 +1182,25 @@ static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
1137
1182
  }
1138
1183
 
1139
1184
  static __global__ void array_copy_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst,
1140
- const void* src_data, int src_stride, const int* src_indices,
1141
- int elem_size)
1185
+ const void* src_data, size_t src_stride, const int* src_indices,
1186
+ size_t elem_size)
1142
1187
  {
1143
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1188
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1144
1189
 
1145
1190
  if (tid < dst.size)
1146
1191
  {
1147
- int src_idx = src_indices ? src_indices[tid] : tid;
1192
+ size_t src_idx = src_indices ? src_indices[tid] : tid;
1148
1193
  const void* src_ptr = (const char*)src_data + src_idx * src_stride;
1149
- int dst_idx = dst.indices[tid];
1194
+ size_t dst_idx = dst.indices[tid];
1150
1195
  void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
1151
1196
  memcpy(dst_ptr, src_ptr, elem_size);
1152
1197
  }
1153
1198
  }
1154
1199
 
1155
1200
 
1156
- static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src, int elem_size)
1201
+ static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
1157
1202
  {
1158
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1203
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1159
1204
 
1160
1205
  if (tid < dst.size)
1161
1206
  {
@@ -1166,27 +1211,27 @@ static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void
1166
1211
  }
1167
1212
 
1168
1213
 
1169
- static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src, int elem_size)
1214
+ static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
1170
1215
  {
1171
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1216
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1172
1217
 
1173
1218
  if (tid < dst.size)
1174
1219
  {
1175
1220
  const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
1176
- int dst_index = dst.indices[tid];
1221
+ size_t dst_index = dst.indices[tid];
1177
1222
  void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
1178
1223
  memcpy(dst_ptr, src_ptr, elem_size);
1179
1224
  }
1180
1225
  }
1181
1226
 
1182
1227
 
1183
- static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, int elem_size)
1228
+ static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
1184
1229
  {
1185
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1230
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1186
1231
 
1187
1232
  if (tid < dst.size)
1188
1233
  {
1189
- int src_index = src.indices[tid];
1234
+ size_t src_index = src.indices[tid];
1190
1235
  const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
1191
1236
  void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
1192
1237
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1194,14 +1239,14 @@ static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarra
1194
1239
  }
1195
1240
 
1196
1241
 
1197
- static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, int elem_size)
1242
+ static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
1198
1243
  {
1199
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1244
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1200
1245
 
1201
1246
  if (tid < dst.size)
1202
1247
  {
1203
- int src_index = src.indices[tid];
1204
- int dst_index = dst.indices[tid];
1248
+ size_t src_index = src.indices[tid];
1249
+ size_t dst_index = dst.indices[tid];
1205
1250
  const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
1206
1251
  void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
1207
1252
  memcpy(dst_ptr, src_ptr, elem_size);
@@ -1440,9 +1485,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
1440
1485
  }
1441
1486
  case 2:
1442
1487
  {
1443
- wp::vec_t<2, int> shape_v(src_shape[0], src_shape[1]);
1444
- wp::vec_t<2, int> src_strides_v(src_strides[0], src_strides[1]);
1445
- wp::vec_t<2, int> dst_strides_v(dst_strides[0], dst_strides[1]);
1488
+ wp::vec_t<2, size_t> shape_v(src_shape[0], src_shape[1]);
1489
+ wp::vec_t<2, size_t> src_strides_v(src_strides[0], src_strides[1]);
1490
+ wp::vec_t<2, size_t> dst_strides_v(dst_strides[0], dst_strides[1]);
1446
1491
  wp::vec_t<2, const int*> src_indices_v(src_indices[0], src_indices[1]);
1447
1492
  wp::vec_t<2, const int*> dst_indices_v(dst_indices[0], dst_indices[1]);
1448
1493
 
@@ -1454,9 +1499,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
1454
1499
  }
1455
1500
  case 3:
1456
1501
  {
1457
- wp::vec_t<3, int> shape_v(src_shape[0], src_shape[1], src_shape[2]);
1458
- wp::vec_t<3, int> src_strides_v(src_strides[0], src_strides[1], src_strides[2]);
1459
- wp::vec_t<3, int> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2]);
1502
+ wp::vec_t<3, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2]);
1503
+ wp::vec_t<3, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2]);
1504
+ wp::vec_t<3, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2]);
1460
1505
  wp::vec_t<3, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2]);
1461
1506
  wp::vec_t<3, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2]);
1462
1507
 
@@ -1468,9 +1513,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
1468
1513
  }
1469
1514
  case 4:
1470
1515
  {
1471
- wp::vec_t<4, int> shape_v(src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
1472
- wp::vec_t<4, int> src_strides_v(src_strides[0], src_strides[1], src_strides[2], src_strides[3]);
1473
- wp::vec_t<4, int> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2], dst_strides[3]);
1516
+ wp::vec_t<4, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
1517
+ wp::vec_t<4, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2], src_strides[3]);
1518
+ wp::vec_t<4, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2], dst_strides[3]);
1474
1519
  wp::vec_t<4, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2], src_indices[3]);
1475
1520
  wp::vec_t<4, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2], dst_indices[3]);
1476
1521
 
@@ -1490,94 +1535,94 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
1490
1535
 
1491
1536
 
1492
1537
  static __global__ void array_fill_1d_kernel(void* data,
1493
- int n,
1494
- int stride,
1538
+ size_t n,
1539
+ size_t stride,
1495
1540
  const int* indices,
1496
1541
  const void* value,
1497
- int value_size)
1542
+ size_t value_size)
1498
1543
  {
1499
- int i = blockIdx.x * blockDim.x + threadIdx.x;
1544
+ size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1500
1545
  if (i < n)
1501
1546
  {
1502
- int idx = indices ? indices[i] : i;
1547
+ size_t idx = indices ? indices[i] : i;
1503
1548
  char* p = (char*)data + idx * stride;
1504
1549
  memcpy(p, value, value_size);
1505
1550
  }
1506
1551
  }
1507
1552
 
1508
1553
  static __global__ void array_fill_2d_kernel(void* data,
1509
- wp::vec_t<2, int> shape,
1510
- wp::vec_t<2, int> strides,
1554
+ wp::vec_t<2, size_t> shape,
1555
+ wp::vec_t<2, size_t> strides,
1511
1556
  wp::vec_t<2, const int*> indices,
1512
1557
  const void* value,
1513
- int value_size)
1558
+ size_t value_size)
1514
1559
  {
1515
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1516
- int n = shape[1];
1517
- int i = tid / n;
1518
- int j = tid % n;
1560
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1561
+ size_t n = shape[1];
1562
+ size_t i = tid / n;
1563
+ size_t j = tid % n;
1519
1564
  if (i < shape[0] /*&& j < shape[1]*/)
1520
1565
  {
1521
- int idx0 = indices[0] ? indices[0][i] : i;
1522
- int idx1 = indices[1] ? indices[1][j] : j;
1566
+ size_t idx0 = indices[0] ? indices[0][i] : i;
1567
+ size_t idx1 = indices[1] ? indices[1][j] : j;
1523
1568
  char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1];
1524
1569
  memcpy(p, value, value_size);
1525
1570
  }
1526
1571
  }
1527
1572
 
1528
1573
  static __global__ void array_fill_3d_kernel(void* data,
1529
- wp::vec_t<3, int> shape,
1530
- wp::vec_t<3, int> strides,
1574
+ wp::vec_t<3, size_t> shape,
1575
+ wp::vec_t<3, size_t> strides,
1531
1576
  wp::vec_t<3, const int*> indices,
1532
1577
  const void* value,
1533
- int value_size)
1534
- {
1535
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1536
- int n = shape[1];
1537
- int o = shape[2];
1538
- int i = tid / (n * o);
1539
- int j = tid % (n * o) / o;
1540
- int k = tid % o;
1578
+ size_t value_size)
1579
+ {
1580
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1581
+ size_t n = shape[1];
1582
+ size_t o = shape[2];
1583
+ size_t i = tid / (n * o);
1584
+ size_t j = tid % (n * o) / o;
1585
+ size_t k = tid % o;
1541
1586
  if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
1542
1587
  {
1543
- int idx0 = indices[0] ? indices[0][i] : i;
1544
- int idx1 = indices[1] ? indices[1][j] : j;
1545
- int idx2 = indices[2] ? indices[2][k] : k;
1588
+ size_t idx0 = indices[0] ? indices[0][i] : i;
1589
+ size_t idx1 = indices[1] ? indices[1][j] : j;
1590
+ size_t idx2 = indices[2] ? indices[2][k] : k;
1546
1591
  char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2];
1547
1592
  memcpy(p, value, value_size);
1548
1593
  }
1549
1594
  }
1550
1595
 
1551
1596
  static __global__ void array_fill_4d_kernel(void* data,
1552
- wp::vec_t<4, int> shape,
1553
- wp::vec_t<4, int> strides,
1597
+ wp::vec_t<4, size_t> shape,
1598
+ wp::vec_t<4, size_t> strides,
1554
1599
  wp::vec_t<4, const int*> indices,
1555
1600
  const void* value,
1556
- int value_size)
1557
- {
1558
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1559
- int n = shape[1];
1560
- int o = shape[2];
1561
- int p = shape[3];
1562
- int i = tid / (n * o * p);
1563
- int j = tid % (n * o * p) / (o * p);
1564
- int k = tid % (o * p) / p;
1565
- int l = tid % p;
1601
+ size_t value_size)
1602
+ {
1603
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1604
+ size_t n = shape[1];
1605
+ size_t o = shape[2];
1606
+ size_t p = shape[3];
1607
+ size_t i = tid / (n * o * p);
1608
+ size_t j = tid % (n * o * p) / (o * p);
1609
+ size_t k = tid % (o * p) / p;
1610
+ size_t l = tid % p;
1566
1611
  if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
1567
1612
  {
1568
- int idx0 = indices[0] ? indices[0][i] : i;
1569
- int idx1 = indices[1] ? indices[1][j] : j;
1570
- int idx2 = indices[2] ? indices[2][k] : k;
1571
- int idx3 = indices[3] ? indices[3][l] : l;
1613
+ size_t idx0 = indices[0] ? indices[0][i] : i;
1614
+ size_t idx1 = indices[1] ? indices[1][j] : j;
1615
+ size_t idx2 = indices[2] ? indices[2][k] : k;
1616
+ size_t idx3 = indices[3] ? indices[3][l] : l;
1572
1617
  char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2] + idx3 * strides[3];
1573
1618
  memcpy(p, value, value_size);
1574
1619
  }
1575
1620
  }
1576
1621
 
1577
1622
 
1578
- static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value, int value_size)
1623
+ static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value, size_t value_size)
1579
1624
  {
1580
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1625
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1581
1626
  if (tid < fa.size)
1582
1627
  {
1583
1628
  void* dst_ptr = fabricarray_element_ptr(fa, tid, value_size);
@@ -1586,9 +1631,9 @@ static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, cons
1586
1631
  }
1587
1632
 
1588
1633
 
1589
- static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value, int value_size)
1634
+ static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value, size_t value_size)
1590
1635
  {
1591
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
1636
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1592
1637
  if (tid < ifa.size)
1593
1638
  {
1594
1639
  size_t idx = size_t(ifa.indices[tid]);
@@ -1685,8 +1730,8 @@ WP_API void wp_array_fill_device(void* context, void* arr_ptr, int arr_type, con
1685
1730
  }
1686
1731
  case 2:
1687
1732
  {
1688
- wp::vec_t<2, int> shape_v(shape[0], shape[1]);
1689
- wp::vec_t<2, int> strides_v(strides[0], strides[1]);
1733
+ wp::vec_t<2, size_t> shape_v(shape[0], shape[1]);
1734
+ wp::vec_t<2, size_t> strides_v(strides[0], strides[1]);
1690
1735
  wp::vec_t<2, const int*> indices_v(indices[0], indices[1]);
1691
1736
  wp_launch_device(WP_CURRENT_CONTEXT, array_fill_2d_kernel, n,
1692
1737
  (data, shape_v, strides_v, indices_v, value_devptr, value_size));
@@ -1694,8 +1739,8 @@ WP_API void wp_array_fill_device(void* context, void* arr_ptr, int arr_type, con
1694
1739
  }
1695
1740
  case 3:
1696
1741
  {
1697
- wp::vec_t<3, int> shape_v(shape[0], shape[1], shape[2]);
1698
- wp::vec_t<3, int> strides_v(strides[0], strides[1], strides[2]);
1742
+ wp::vec_t<3, size_t> shape_v(shape[0], shape[1], shape[2]);
1743
+ wp::vec_t<3, size_t> strides_v(strides[0], strides[1], strides[2]);
1699
1744
  wp::vec_t<3, const int*> indices_v(indices[0], indices[1], indices[2]);
1700
1745
  wp_launch_device(WP_CURRENT_CONTEXT, array_fill_3d_kernel, n,
1701
1746
  (data, shape_v, strides_v, indices_v, value_devptr, value_size));
@@ -1703,8 +1748,8 @@ WP_API void wp_array_fill_device(void* context, void* arr_ptr, int arr_type, con
1703
1748
  }
1704
1749
  case 4:
1705
1750
  {
1706
- wp::vec_t<4, int> shape_v(shape[0], shape[1], shape[2], shape[3]);
1707
- wp::vec_t<4, int> strides_v(strides[0], strides[1], strides[2], strides[3]);
1751
+ wp::vec_t<4, size_t> shape_v(shape[0], shape[1], shape[2], shape[3]);
1752
+ wp::vec_t<4, size_t> strides_v(strides[0], strides[1], strides[2], strides[3]);
1708
1753
  wp::vec_t<4, const int*> indices_v(indices[0], indices[1], indices[2], indices[3]);
1709
1754
  wp_launch_device(WP_CURRENT_CONTEXT, array_fill_4d_kernel, n,
1710
1755
  (data, shape_v, strides_v, indices_v, value_devptr, value_size));
@@ -2072,13 +2117,17 @@ void wp_cuda_context_synchronize(void* context)
2072
2117
 
2073
2118
  check_cu(cuCtxSynchronize_f());
2074
2119
 
2075
- if (free_deferred_allocs(context ? context : get_current_context()) > 0)
2120
+ if (!context)
2121
+ context = get_current_context();
2122
+
2123
+ if (free_deferred_allocs(context) > 0)
2076
2124
  {
2077
2125
  // ensure deferred asynchronous deallocations complete
2078
2126
  check_cu(cuCtxSynchronize_f());
2079
2127
  }
2080
2128
 
2081
2129
  unload_deferred_modules(context);
2130
+ destroy_deferred_graphs(context);
2082
2131
 
2083
2132
  // check_cuda(cudaDeviceGraphMemTrim(wp_cuda_context_get_device_ordinal(context)));
2084
2133
  }
@@ -2514,15 +2563,36 @@ void wp_cuda_stream_synchronize(void* stream)
2514
2563
  check_cu(cuStreamSynchronize_f(static_cast<CUstream>(stream)));
2515
2564
  }
2516
2565
 
2517
- void wp_cuda_stream_wait_event(void* stream, void* event)
2566
+ void wp_cuda_stream_wait_event(void* stream, void* event, bool external)
2518
2567
  {
2519
- check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), 0));
2568
+ // the external flag can only be used during graph capture
2569
+ if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
2570
+ {
2571
+ // wait for an external event during graph capture
2572
+ check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_EXTERNAL));
2573
+ }
2574
+ else
2575
+ {
2576
+ check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_DEFAULT));
2577
+ }
2520
2578
  }
2521
2579
 
2522
- void wp_cuda_stream_wait_stream(void* stream, void* other_stream, void* event)
2580
+ void wp_cuda_stream_wait_stream(void* stream, void* other_stream, void* event, bool external)
2523
2581
  {
2524
- check_cu(cuEventRecord_f(static_cast<CUevent>(event), static_cast<CUstream>(other_stream)));
2525
- check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), 0));
2582
+ unsigned record_flags = CU_EVENT_RECORD_DEFAULT;
2583
+ unsigned wait_flags = CU_EVENT_WAIT_DEFAULT;
2584
+
2585
+ // the external flag can only be used during graph capture
2586
+ if (external && !g_captures.empty())
2587
+ {
2588
+ if (wp_cuda_stream_is_capturing(other_stream))
2589
+ record_flags = CU_EVENT_RECORD_EXTERNAL;
2590
+ if (wp_cuda_stream_is_capturing(stream))
2591
+ wait_flags = CU_EVENT_WAIT_EXTERNAL;
2592
+ }
2593
+
2594
+ check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(other_stream), record_flags));
2595
+ check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), wait_flags));
2526
2596
  }
2527
2597
 
2528
2598
  int wp_cuda_stream_is_capturing(void* stream)
@@ -2575,11 +2645,12 @@ int wp_cuda_event_query(void* event)
2575
2645
  return res;
2576
2646
  }
2577
2647
 
2578
- void wp_cuda_event_record(void* event, void* stream, bool timing)
2648
+ void wp_cuda_event_record(void* event, void* stream, bool external)
2579
2649
  {
2580
- if (timing && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
2650
+ // the external flag can only be used during graph capture
2651
+ if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
2581
2652
  {
2582
- // record timing event during graph capture
2653
+ // record external event during graph capture (e.g., for timing or when explicitly specified by the user)
2583
2654
  check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(stream), CU_EVENT_RECORD_EXTERNAL));
2584
2655
  }
2585
2656
  else
@@ -2629,7 +2700,7 @@ bool wp_cuda_graph_begin_capture(void* context, void* stream, int external)
2629
2700
  else
2630
2701
  {
2631
2702
  // start the capture
2632
- if (!check_cuda(cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeGlobal)))
2703
+ if (!check_cuda(cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeThreadLocal)))
2633
2704
  return false;
2634
2705
  }
2635
2706
 
@@ -2776,6 +2847,7 @@ bool wp_cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
2776
2847
  {
2777
2848
  free_deferred_allocs();
2778
2849
  unload_deferred_modules();
2850
+ destroy_deferred_graphs();
2779
2851
  }
2780
2852
 
2781
2853
  if (graph_ret)
@@ -2996,7 +3068,7 @@ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
2996
3068
  leaf_nodes.data(),
2997
3069
  nullptr,
2998
3070
  leaf_nodes.size(),
2999
- cudaStreamCaptureModeGlobal)))
3071
+ cudaStreamCaptureModeThreadLocal)))
3000
3072
  return false;
3001
3073
 
3002
3074
  return true;
@@ -3455,16 +3527,38 @@ bool wp_cuda_graph_launch(void* graph_exec, void* stream)
3455
3527
 
3456
3528
  bool wp_cuda_graph_destroy(void* context, void* graph)
3457
3529
  {
3458
- ContextGuard guard(context);
3459
-
3460
- return check_cuda(cudaGraphDestroy((cudaGraph_t)graph));
3530
+ // ensure there are no graph captures in progress
3531
+ if (g_captures.empty())
3532
+ {
3533
+ ContextGuard guard(context);
3534
+ return check_cuda(cudaGraphDestroy((cudaGraph_t)graph));
3535
+ }
3536
+ else
3537
+ {
3538
+ GraphDestroyInfo info;
3539
+ info.context = context ? context : get_current_context();
3540
+ info.graph = graph;
3541
+ g_deferred_graph_list.push_back(info);
3542
+ return true;
3543
+ }
3461
3544
  }
3462
3545
 
3463
3546
  bool wp_cuda_graph_exec_destroy(void* context, void* graph_exec)
3464
3547
  {
3465
- ContextGuard guard(context);
3466
-
3467
- return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
3548
+ // ensure there are no graph captures in progress
3549
+ if (g_captures.empty())
3550
+ {
3551
+ ContextGuard guard(context);
3552
+ return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
3553
+ }
3554
+ else
3555
+ {
3556
+ GraphDestroyInfo info;
3557
+ info.context = context ? context : get_current_context();
3558
+ info.graph_exec = graph_exec;
3559
+ g_deferred_graph_list.push_back(info);
3560
+ return true;
3561
+ }
3468
3562
  }
3469
3563
 
3470
3564
  bool write_file(const char* data, size_t size, std::string filename, const char* mode)
@@ -4317,17 +4411,5 @@ void wp_cuda_timing_end(timing_result_t* results, int size)
4317
4411
  g_cuda_timing_state = parent_state;
4318
4412
  }
4319
4413
 
4320
- // impl. files
4321
- #include "bvh.cu"
4322
- #include "mesh.cu"
4323
- #include "sort.cu"
4324
- #include "hashgrid.cu"
4325
- #include "reduce.cu"
4326
- #include "runlength_encode.cu"
4327
- #include "scan.cu"
4328
- #include "sparse.cu"
4329
- #include "volume.cu"
4330
- #include "volume_builder.cu"
4331
-
4332
4414
  //#include "spline.inl"
4333
4415
  //#include "volume.inl"