warp-lang 1.10.0 (py3-none-macosx_11_0_arm64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
warp/native/warp.cu ADDED
@@ -0,0 +1,4604 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ #include "warp.h"
19
+ #include "scan.h"
20
+ #include "cuda_util.h"
21
+ #include "error.h"
22
+ #include "sort.h"
23
+
24
+ #include <cstdlib>
25
+ #include <fstream>
26
+ #include <nvrtc.h>
27
+ #include <nvPTXCompiler.h>
28
+ #if WP_ENABLE_MATHDX
29
+ #include <nvJitLink.h>
30
+ #include <libmathdx.h>
31
+ #include <libcublasdx.h>
32
+ #include <libcufftdx.h>
33
+ #include <libcusolverdx.h>
34
+ #endif
35
+
36
+ #include <array>
37
+ #include <algorithm>
38
+ #include <iterator>
39
+ #include <list>
40
+ #include <map>
41
+ #include <mutex>
42
+ #include <string>
43
+ #include <unordered_map>
44
+ #include <unordered_set>
45
+ #include <vector>
46
+
47
+ #define check_any(result) (check_generic(result, __FILE__, __LINE__))
48
+ #define check_nvrtc(code) (check_nvrtc_result(code, __FILE__, __LINE__))
49
+ #define check_nvptx(code) (check_nvptx_result(code, __FILE__, __LINE__))
50
+ #define check_nvjitlink(handle, code) (check_nvjitlink_result(handle, code, __FILE__, __LINE__))
51
+ #define check_cufftdx(code) (check_cufftdx_result(code, __FILE__, __LINE__))
52
+ #define check_cublasdx(code) (check_cublasdx_result(code, __FILE__, __LINE__))
53
+ #define check_cusolver(code) (check_cusolver_result(code, __FILE__, __LINE__))
54
+ #define CHECK_ANY(code) \
55
+ { \
56
+ do { \
57
+ bool out = (check_any(code)); \
58
+ if(!out) { \
59
+ return out; \
60
+ } \
61
+ } while(0); \
62
+ }
63
+ #define CHECK_CUFFTDX(code) \
64
+ { \
65
+ do { \
66
+ bool out = (check_cufftdx(code)); \
67
+ if(!out) { \
68
+ return out; \
69
+ } \
70
+ } while(0); \
71
+ }
72
+ #define CHECK_CUBLASDX(code) \
73
+ { \
74
+ do { \
75
+ bool out = (check_cublasdx(code)); \
76
+ if(!out) { \
77
+ return out; \
78
+ } \
79
+ } while(0); \
80
+ }
81
+ #define CHECK_CUSOLVER(code) \
82
+ { \
83
+ do { \
84
+ bool out = (check_cusolver(code)); \
85
+ if(!out) { \
86
+ return out; \
87
+ } \
88
+ } while(0); \
89
+ }
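+ // Usage sketch (editorial assumption): these CHECK_* macros are intended for bool-returning helpers;
+ // on failure they report the error through the corresponding check_* wrapper and return false to the caller,
+ // e.g. CHECK_ANY(some_call()); CHECK_CUSOLVER(cusolver_status);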
90
+
91
+ bool check_nvrtc_result(nvrtcResult result, const char* file, int line)
92
+ {
93
+ if (result == NVRTC_SUCCESS)
94
+ return true;
95
+
96
+ const char* error_string = nvrtcGetErrorString(result);
97
+ fprintf(stderr, "Warp NVRTC compilation error %u: %s (%s:%d)\n", unsigned(result), error_string, file, line);
98
+ return false;
99
+ }
100
+
101
+ bool check_nvptx_result(nvPTXCompileResult result, const char* file, int line)
102
+ {
103
+ if (result == NVPTXCOMPILE_SUCCESS)
104
+ return true;
105
+
106
+ const char* error_string;
107
+ switch (result)
108
+ {
109
+ case NVPTXCOMPILE_ERROR_INVALID_COMPILER_HANDLE:
110
+ error_string = "Invalid compiler handle";
111
+ break;
112
+ case NVPTXCOMPILE_ERROR_INVALID_INPUT:
113
+ error_string = "Invalid input";
114
+ break;
115
+ case NVPTXCOMPILE_ERROR_COMPILATION_FAILURE:
116
+ error_string = "Compilation failure";
117
+ break;
118
+ case NVPTXCOMPILE_ERROR_INTERNAL:
119
+ error_string = "Internal error";
120
+ break;
121
+ case NVPTXCOMPILE_ERROR_OUT_OF_MEMORY:
122
+ error_string = "Out of memory";
123
+ break;
124
+ case NVPTXCOMPILE_ERROR_COMPILER_INVOCATION_INCOMPLETE:
125
+ error_string = "Incomplete compiler invocation";
126
+ break;
127
+ case NVPTXCOMPILE_ERROR_UNSUPPORTED_PTX_VERSION:
128
+ error_string = "Unsupported PTX version";
129
+ break;
130
+ default:
131
+ error_string = "Unknown error";
132
+ break;
133
+ }
134
+
135
+ fprintf(stderr, "Warp PTX compilation error %u: %s (%s:%d)\n", unsigned(result), error_string, file, line);
136
+ return false;
137
+ }
138
+
139
+ bool check_generic(int result, const char* file, int line)
140
+ {
141
+ if (!result) {
142
+ fprintf(stderr, "Error %d on %s:%d\n", (int)result, file, line);
143
+ return false;
144
+ } else {
145
+ return true;
146
+ }
147
+ }
148
+
149
+ struct DeviceInfo
150
+ {
151
+ static constexpr int kNameLen = 128;
152
+
153
+ CUdevice device = -1;
154
+ CUuuid uuid = {0};
155
+ int ordinal = -1;
156
+ int pci_domain_id = -1;
157
+ int pci_bus_id = -1;
158
+ int pci_device_id = -1;
159
+ char name[kNameLen] = "";
160
+ int arch = 0;
161
+ int is_uva = 0;
162
+ int is_mempool_supported = 0;
163
+ int sm_count = 0;
164
+ int is_ipc_supported = -1;
165
+ int max_smem_bytes = 0;
166
+ CUcontext primary_context = NULL;
167
+ };
168
+
169
+ struct ContextInfo
170
+ {
171
+ DeviceInfo* device_info = NULL;
172
+
173
+ // the current stream, managed from Python (see wp_cuda_context_set_stream() and wp_cuda_context_get_stream())
174
+ CUstream stream = NULL;
175
+
176
+ // conditional graph node support, loaded on demand if the driver supports it (CUDA 12.4+)
177
+ CUmodule conditional_module = NULL;
178
+ };
179
+
180
+ // Information used for freeing allocations.
181
+ struct FreeInfo
182
+ {
183
+ void* context = NULL;
184
+ void* ptr = NULL;
185
+ bool is_async = false;
186
+ };
187
+
188
+ struct CaptureInfo
189
+ {
190
+ CUstream stream = NULL; // the main stream where capture begins and ends
191
+ uint64_t id = 0; // unique capture id from CUDA
192
+ bool external = false; // whether this is an external capture
193
+ std::vector<FreeInfo> tmp_allocs; // temporary allocations owned by the graph (e.g., staged array fill values)
194
+ };
195
+
196
+ struct StreamInfo
197
+ {
198
+ CUevent cached_event = NULL; // event used for stream synchronization (cached to avoid creating temporary events)
199
+ CaptureInfo* capture = NULL; // capture info (only if started on this stream)
200
+ };
201
+
202
+ // Extra resources tied to a graph, freed after the graph is released by CUDA.
203
+ // Used with the on_graph_destroy() callback.
204
+ struct GraphDestroyCallbackInfo
205
+ {
206
+ void* context = NULL; // graph CUDA context
207
+ std::vector<void*> unfreed_allocs; // graph allocations not freed by the graph
208
+ std::vector<FreeInfo> tmp_allocs; // temporary allocations owned by the graph (e.g., staged array fill values)
209
+ };
210
+
211
+ // Information for graph allocations that are not freed by the graph.
212
+ // These allocations have a shared ownership:
213
+ // - The graph instance allocates/maps the memory on each launch, even if the user reference is released.
214
+ // - The user reference must remain valid even if the graph is destroyed.
215
+ // The memory will be freed once the user reference is released and the graph is destroyed.
216
+ struct GraphAllocInfo
217
+ {
218
+ uint64_t capture_id = 0;
219
+ void* context = NULL;
220
+ bool ref_exists = false; // whether user reference still exists
221
+ bool graph_destroyed = false; // whether graph instance was destroyed
222
+ };
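+ // Lifetime note (assumption, inferred from the code below): wp_free_device_async() clears ref_exists when the
+ // user reference is released, process_deferred_graph_destroy_callbacks() marks graph_destroyed when the owning
+ // graph goes away, and the memory is only freed once both sides have let go.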
223
+
224
+ // Information used when deferring module unloading.
225
+ struct ModuleInfo
226
+ {
227
+ void* context = NULL;
228
+ void* module = NULL;
229
+ };
230
+
231
+ // Information used when deferring graph destruction.
232
+ struct GraphDestroyInfo
233
+ {
234
+ void* context = NULL;
235
+ void* graph = NULL;
236
+ void* graph_exec = NULL;
237
+ };
238
+
239
+ static std::unordered_map<CUfunction, std::string> g_kernel_names;
240
+
241
+ // cached info for all devices, indexed by ordinal
242
+ static std::vector<DeviceInfo> g_devices;
243
+
244
+ // maps CUdevice to DeviceInfo
245
+ static std::map<CUdevice, DeviceInfo*> g_device_map;
246
+
247
+ // cached info for all known contexts
248
+ static std::map<CUcontext, ContextInfo> g_contexts;
249
+
250
+ // cached info for all known streams (including registered external streams)
251
+ static std::unordered_map<CUstream, StreamInfo> g_streams;
252
+
253
+ // Ongoing graph captures registered using wp.capture_begin().
254
+ // This maps the capture id to the stream where capture was started.
255
+ // See wp_cuda_graph_begin_capture(), wp_cuda_graph_end_capture(), and wp_free_device_async().
256
+ static std::unordered_map<uint64_t, CaptureInfo*> g_captures;
257
+
258
+ // Memory allocated during graph capture requires special handling.
259
+ // See wp_alloc_device_async() and wp_free_device_async().
260
+ static std::unordered_map<void*, GraphAllocInfo> g_graph_allocs;
261
+
262
+ // Memory that cannot be freed immediately gets queued here.
263
+ // Call free_deferred_allocs() to release.
264
+ static std::vector<FreeInfo> g_deferred_free_list;
265
+
266
+ // Modules that cannot be unloaded immediately get queued here.
267
+ // Call unload_deferred_modules() to release.
268
+ static std::vector<ModuleInfo> g_deferred_module_list;
269
+
270
+ // Graphs that cannot be destroyed immediately get queued here.
271
+ // Call destroy_deferred_graphs() to release.
272
+ static std::vector<GraphDestroyInfo> g_deferred_graph_list;
273
+
274
+ // Data from on_graph_destroy() callbacks that run on a different thread.
275
+ static std::vector<GraphDestroyCallbackInfo*> g_deferred_graph_destroy_list;
276
+ static std::mutex g_graph_destroy_mutex;
277
+
278
+
279
+ void wp_cuda_set_context_restore_policy(bool always_restore)
280
+ {
281
+ ContextGuard::always_restore = always_restore;
282
+ }
283
+
284
+ int wp_cuda_get_context_restore_policy()
285
+ {
286
+ return int(ContextGuard::always_restore);
287
+ }
288
+
289
+ int cuda_init()
290
+ {
291
+ if (!init_cuda_driver())
292
+ return -1;
293
+
294
+ int device_count = 0;
295
+ if (check_cu(cuDeviceGetCount_f(&device_count)))
296
+ {
297
+ g_devices.resize(device_count);
298
+
299
+ for (int i = 0; i < device_count; i++)
300
+ {
301
+ CUdevice device;
302
+ if (check_cu(cuDeviceGet_f(&device, i)))
303
+ {
304
+ // query device info
305
+ g_devices[i].device = device;
306
+ g_devices[i].ordinal = i;
307
+ check_cu(cuDeviceGetName_f(g_devices[i].name, DeviceInfo::kNameLen, device));
308
+ check_cu(cuDeviceGetUuid_f(&g_devices[i].uuid, device));
309
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_domain_id, CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID, device));
310
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_bus_id, CU_DEVICE_ATTRIBUTE_PCI_BUS_ID, device));
311
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_device_id, CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, device));
312
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_uva, CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device));
313
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_mempool_supported, CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, device));
314
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].sm_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
315
+ #ifdef CUDA_VERSION
316
+ #if CUDA_VERSION >= 12000
317
+ int device_attribute_integrated = 0;
318
+ check_cu(cuDeviceGetAttribute_f(&device_attribute_integrated, CU_DEVICE_ATTRIBUTE_INTEGRATED, device));
319
+ if (device_attribute_integrated == 0)
320
+ {
321
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_ipc_supported, CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED, device));
322
+ }
323
+ else
324
+ {
325
+ // integrated devices do not support CUDA IPC
326
+ g_devices[i].is_ipc_supported = 0;
327
+ }
328
+ #endif
329
+ #endif
330
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].max_smem_bytes, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device));
331
+ int major = 0;
332
+ int minor = 0;
333
+ check_cu(cuDeviceGetAttribute_f(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device));
334
+ check_cu(cuDeviceGetAttribute_f(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device));
335
+ g_devices[i].arch = 10 * major + minor;
336
+ #ifdef CUDA_VERSION
337
+ #if CUDA_VERSION < 13000
338
+ if (g_devices[i].arch == 110) {
339
+ g_devices[i].arch = 101; // Thor SM change
340
+ }
341
+ #endif
342
+ #endif
343
+ g_device_map[device] = &g_devices[i];
344
+ }
345
+ else
346
+ {
347
+ return -1;
348
+ }
349
+ }
350
+ }
351
+ else
352
+ {
353
+ return -1;
354
+ }
355
+
356
+ // initialize default timing state
357
+ static CudaTimingState default_timing_state(0, NULL);
358
+ g_cuda_timing_state = &default_timing_state;
359
+
360
+ return 0;
361
+ }
362
+
363
+
364
+ CUcontext get_current_context()
365
+ {
366
+ CUcontext ctx;
367
+ if (check_cu(cuCtxGetCurrent_f(&ctx)))
368
+ return ctx;
369
+ else
370
+ return NULL;
371
+ }
372
+
373
+ static inline CUstream get_current_stream(void* context=NULL)
374
+ {
375
+ return static_cast<CUstream>(wp_cuda_context_get_stream(context));
376
+ }
377
+
378
+ static ContextInfo* get_context_info(CUcontext ctx)
379
+ {
380
+ if (!ctx)
381
+ {
382
+ ctx = get_current_context();
383
+ if (!ctx)
384
+ return NULL;
385
+ }
386
+
387
+ auto it = g_contexts.find(ctx);
388
+ if (it != g_contexts.end())
389
+ {
390
+ return &it->second;
391
+ }
392
+ else
393
+ {
394
+ // previously unseen context, add the info
395
+ ContextGuard guard(ctx, true);
396
+
397
+ CUdevice device;
398
+ if (check_cu(cuCtxGetDevice_f(&device)))
399
+ {
400
+ DeviceInfo* device_info = g_device_map[device];
401
+
402
+ // workaround for https://nvbugspro.nvidia.com/bug/4456003
403
+ if (device_info->is_mempool_supported)
404
+ {
405
+ void* dummy = NULL;
406
+ check_cuda(cudaMallocAsync(&dummy, 1, NULL));
407
+ check_cuda(cudaFreeAsync(dummy, NULL));
408
+ }
409
+
410
+ ContextInfo context_info;
411
+ context_info.device_info = device_info;
412
+ auto result = g_contexts.insert(std::make_pair(ctx, context_info));
413
+ return &result.first->second;
414
+ }
415
+ }
416
+
417
+ return NULL;
418
+ }
419
+
420
+ static inline ContextInfo* get_context_info(void* context)
421
+ {
422
+ return get_context_info(static_cast<CUcontext>(context));
423
+ }
424
+
425
+ static inline StreamInfo* get_stream_info(CUstream stream)
426
+ {
427
+ auto it = g_streams.find(stream);
428
+ if (it != g_streams.end())
429
+ return &it->second;
430
+ else
431
+ return NULL;
432
+ }
433
+
434
+ static inline CaptureInfo* get_capture_info(CUstream stream)
435
+ {
436
+ if (!g_captures.empty() && wp_cuda_stream_is_capturing(stream))
437
+ {
438
+ uint64_t capture_id = get_capture_id(stream);
439
+ auto capture_iter = g_captures.find(capture_id);
440
+ if (capture_iter != g_captures.end())
441
+ return capture_iter->second;
442
+ }
443
+ return NULL;
444
+ }
445
+
446
+ // helper function to copy a value to device memory in a graph-friendly way
447
+ static bool capturable_tmp_alloc(void* context, const void* data, size_t size, void** devptr_ret, bool* free_devptr_ret)
448
+ {
449
+ ContextGuard guard(context);
450
+
451
+ CUstream stream = get_current_stream();
452
+ CaptureInfo* capture_info = get_capture_info(stream);
453
+ int device_ordinal = wp_cuda_context_get_device_ordinal(context);
454
+ void* devptr = NULL;
455
+ bool free_devptr = true;
456
+
457
+ if (capture_info)
458
+ {
459
+ // ongoing graph capture - need to stage the fill value so that it persists with the graph
460
+ if (CUDA_VERSION >= 12040 && wp_cuda_driver_version() >= 12040)
461
+ {
462
+ // pause the capture so that the alloc/memcpy won't be captured
463
+ void* graph = NULL;
464
+ if (!wp_cuda_graph_pause_capture(WP_CURRENT_CONTEXT, stream, &graph))
465
+ return false;
466
+
467
+ // copy value to device memory
468
+ devptr = wp_alloc_device(WP_CURRENT_CONTEXT, size);
469
+ if (!devptr)
470
+ {
471
+ fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cuda:%d' (in function %s)\n", (unsigned long long)size, device_ordinal, __FUNCTION__);
472
+ return false;
473
+ }
474
+ if (!check_cuda(cudaMemcpyAsync(devptr, data, size, cudaMemcpyHostToDevice, stream)))
475
+ return false;
476
+
477
+ // graph takes ownership of the value storage
478
+ FreeInfo free_info;
479
+ free_info.context = context ? context : get_current_context();
480
+ free_info.ptr = devptr;
481
+ free_info.is_async = wp_cuda_device_is_mempool_supported(device_ordinal);
482
+
483
+ // allocation will be freed when graph is destroyed
484
+ capture_info->tmp_allocs.push_back(free_info);
485
+
486
+ // resume the capture
487
+ if (!wp_cuda_graph_resume_capture(WP_CURRENT_CONTEXT, stream, graph))
488
+ return false;
489
+
490
+ free_devptr = false; // memory is owned by the graph, doesn't need to be freed
491
+ }
492
+ else
493
+ {
494
+ // older CUDA can't pause/resume the capture, so stage in CPU memory
495
+ void* hostptr = wp_alloc_host(size);
496
+ if (!hostptr)
497
+ {
498
+ fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cpu' (in function %s)\n", (unsigned long long)size, __FUNCTION__);
499
+ return false;
500
+ }
501
+ memcpy(hostptr, data, size);
502
+
503
+ // the device allocation and h2d copy will be captured in the graph
504
+ devptr = wp_alloc_device(WP_CURRENT_CONTEXT, size);
505
+ if (!devptr)
506
+ {
507
+ fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cuda:%d' (in function %s)\n", (unsigned long long)size, device_ordinal, __FUNCTION__);
508
+ return false;
509
+ }
510
+ if (!check_cuda(cudaMemcpyAsync(devptr, hostptr, size, cudaMemcpyHostToDevice, stream)))
511
+ return false;
512
+
513
+ // graph takes ownership of the value storage
514
+ FreeInfo free_info;
515
+ free_info.context = NULL;
516
+ free_info.ptr = hostptr;
517
+ free_info.is_async = false;
518
+
519
+ // allocation will be freed when graph is destroyed
520
+ capture_info->tmp_allocs.push_back(free_info);
521
+ }
522
+ }
523
+ else
524
+ {
525
+ // not capturing, copy the value to device memory
526
+ devptr = wp_alloc_device(WP_CURRENT_CONTEXT, size);
527
+ if (!devptr)
528
+ {
529
+ fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cuda:%d' (in function %s)\n", (unsigned long long)size, device_ordinal, __FUNCTION__);
530
+ return false;
531
+ }
532
+ if (!check_cuda(cudaMemcpyAsync(devptr, data, size, cudaMemcpyHostToDevice, stream)))
533
+ return false;
534
+ }
535
+
536
+ *devptr_ret = devptr;
537
+ *free_devptr_ret = free_devptr;
538
+
539
+ return true;
540
+ }
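+ // Usage sketch (assumption; launch_fill_kernel is a hypothetical caller, not part of this file):
+ //
+ //     void* d_value = NULL;
+ //     bool owns_value = true;
+ //     if (capturable_tmp_alloc(context, &host_value, sizeof(host_value), &d_value, &owns_value))
+ //     {
+ //         launch_fill_kernel(d_value);              // kernel reads the staged value from device memory
+ //         if (owns_value)                           // false if an active graph capture took ownership
+ //             wp_free_device(context, d_value);
+ //     }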
541
+
542
+ static void deferred_free(void* ptr, void* context, bool is_async)
543
+ {
544
+ FreeInfo free_info;
545
+ free_info.ptr = ptr;
546
+ free_info.context = context ? context : get_current_context();
547
+ free_info.is_async = is_async;
548
+ g_deferred_free_list.push_back(free_info);
549
+ }
550
+
551
+ static int free_deferred_allocs(void* context = NULL)
552
+ {
553
+ if (g_deferred_free_list.empty() || !g_captures.empty())
554
+ return 0;
555
+
556
+ int num_freed_allocs = 0;
557
+ for (auto it = g_deferred_free_list.begin(); it != g_deferred_free_list.end(); /*noop*/)
558
+ {
559
+ const FreeInfo& free_info = *it;
560
+
561
+ // free the pointer if it matches the given context or if the context is unspecified
562
+ if (free_info.context == context || !context)
563
+ {
564
+ ContextGuard guard(free_info.context);
565
+
566
+ if (free_info.is_async)
567
+ {
568
+ // this could be a regular stream-ordered allocation or a graph allocation
569
+ cudaError_t res = cudaFreeAsync(free_info.ptr, NULL);
570
+ if (res != cudaSuccess)
571
+ {
572
+ if (res == cudaErrorInvalidValue)
573
+ {
574
+ // This can happen if we try to release the pointer but the graph was
575
+ // never launched, so the memory isn't mapped.
576
+ // This is fine, so clear the error.
577
+ cudaGetLastError();
578
+ }
579
+ else
580
+ {
581
+ // something else went wrong, report error
582
+ check_cuda(res);
583
+ }
584
+ }
585
+ }
586
+ else
587
+ {
588
+ check_cuda(cudaFree(free_info.ptr));
589
+ }
590
+
591
+ ++num_freed_allocs;
592
+
593
+ it = g_deferred_free_list.erase(it);
594
+ }
595
+ else
596
+ {
597
+ ++it;
598
+ }
599
+ }
600
+
601
+ return num_freed_allocs;
602
+ }
603
+
604
+ static int unload_deferred_modules(void* context = NULL)
605
+ {
606
+ if (g_deferred_module_list.empty() || !g_captures.empty())
607
+ return 0;
608
+
609
+ int num_unloaded_modules = 0;
610
+ for (auto it = g_deferred_module_list.begin(); it != g_deferred_module_list.end(); /*noop*/)
611
+ {
612
+ // free the module if it matches the given context or if the context is unspecified
613
+ const ModuleInfo& module_info = *it;
614
+ if (module_info.context == context || !context)
615
+ {
616
+ wp_cuda_unload_module(module_info.context, module_info.module);
617
+ ++num_unloaded_modules;
618
+ it = g_deferred_module_list.erase(it);
619
+ }
620
+ else
621
+ {
622
+ ++it;
623
+ }
624
+ }
625
+
626
+ return num_unloaded_modules;
627
+ }
628
+
629
+ static int destroy_deferred_graphs(void* context = NULL)
630
+ {
631
+ if (g_deferred_graph_list.empty() || !g_captures.empty())
632
+ return 0;
633
+
634
+ int num_destroyed_graphs = 0;
635
+ for (auto it = g_deferred_graph_list.begin(); it != g_deferred_graph_list.end(); /*noop*/)
636
+ {
637
+ // destroy the graph if it matches the given context or if the context is unspecified
638
+ const GraphDestroyInfo& graph_info = *it;
639
+ if (graph_info.context == context || !context)
640
+ {
641
+ if (graph_info.graph)
642
+ {
643
+ check_cuda(cudaGraphDestroy((cudaGraph_t)graph_info.graph));
644
+ }
645
+ if (graph_info.graph_exec)
646
+ {
647
+ check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_info.graph_exec));
648
+ }
649
+ ++num_destroyed_graphs;
650
+ it = g_deferred_graph_list.erase(it);
651
+ }
652
+ else
653
+ {
654
+ ++it;
655
+ }
656
+ }
657
+
658
+ return num_destroyed_graphs;
659
+ }
660
+
661
+ static int process_deferred_graph_destroy_callbacks(void* context = NULL)
662
+ {
663
+ int num_freed = 0;
664
+
665
+ std::lock_guard<std::mutex> lock(g_graph_destroy_mutex);
666
+
667
+ for (auto it = g_deferred_graph_destroy_list.begin(); it != g_deferred_graph_destroy_list.end(); /*noop*/)
668
+ {
669
+ GraphDestroyCallbackInfo* graph_info = *it;
670
+ if (graph_info->context == context || !context)
671
+ {
672
+ // handle unfreed graph allocations (may have outstanding user references)
673
+ for (void* ptr : graph_info->unfreed_allocs)
674
+ {
675
+ auto alloc_iter = g_graph_allocs.find(ptr);
676
+ if (alloc_iter != g_graph_allocs.end())
677
+ {
678
+ GraphAllocInfo& alloc_info = alloc_iter->second;
679
+ if (alloc_info.ref_exists)
680
+ {
681
+ // unreference from graph so the pointer will be deallocated when the user reference goes away
682
+ alloc_info.graph_destroyed = true;
683
+ }
684
+ else
685
+ {
686
+ // the pointer can be freed, no references remain
687
+ wp_free_device_async(alloc_info.context, ptr);
688
+ g_graph_allocs.erase(alloc_iter);
689
+ }
690
+ }
691
+ }
692
+
693
+ // handle temporary allocations owned by the graph (no user references)
694
+ for (const FreeInfo& tmp_info : graph_info->tmp_allocs)
695
+ {
696
+ if (tmp_info.context)
697
+ {
698
+ // GPU alloc
699
+ if (tmp_info.is_async)
700
+ {
701
+ wp_free_device_async(tmp_info.context, tmp_info.ptr);
702
+ }
703
+ else
704
+ {
705
+ wp_free_device_default(tmp_info.context, tmp_info.ptr);
706
+ }
707
+ }
708
+ else
709
+ {
710
+ // CPU alloc
711
+ wp_free_host(tmp_info.ptr);
712
+ }
713
+ }
714
+
715
+ ++num_freed;
716
+ delete graph_info;
717
+ it = g_deferred_graph_destroy_list.erase(it);
718
+ }
719
+ else
720
+ {
721
+ ++it;
722
+ }
723
+ }
724
+
725
+ return num_freed;
726
+ }
727
+
728
+ static int run_deferred_actions(void* context = NULL)
729
+ {
730
+ int num_actions = 0;
731
+ num_actions += free_deferred_allocs(context);
732
+ num_actions += unload_deferred_modules(context);
733
+ num_actions += destroy_deferred_graphs(context);
734
+ num_actions += process_deferred_graph_destroy_callbacks(context);
735
+ return num_actions;
736
+ }
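+ // Deferral pattern (descriptive note): the free_deferred_*/unload_deferred_*/destroy_deferred_* helpers above
+ // all bail out while any capture in g_captures is active, so work queued during capture is drained here
+ // the next time run_deferred_actions() is called with no capture in flight.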
737
+
738
+ // Callback used when a graph is destroyed.
739
+ // NOTE: this runs on an internal CUDA thread and requires synchronization.
740
+ static void CUDART_CB on_graph_destroy(void* user_data)
741
+ {
742
+ if (user_data)
743
+ {
744
+ std::lock_guard<std::mutex> lock(g_graph_destroy_mutex);
745
+ g_deferred_graph_destroy_list.push_back(static_cast<GraphDestroyCallbackInfo*>(user_data));
746
+ }
747
+ }
748
+
749
+ static inline const char* get_cuda_kernel_name(void* kernel)
750
+ {
751
+ CUfunction cuda_func = static_cast<CUfunction>(kernel);
752
+ auto name_iter = g_kernel_names.find((CUfunction)cuda_func);
753
+ if (name_iter != g_kernel_names.end())
754
+ return name_iter->second.c_str();
755
+ else
756
+ return "unknown_kernel";
757
+ }
758
+
759
+
760
+ void* wp_alloc_pinned(size_t s)
761
+ {
762
+ void* ptr = NULL;
763
+ check_cuda(cudaMallocHost(&ptr, s));
764
+ return ptr;
765
+ }
766
+
767
+ void wp_free_pinned(void* ptr)
768
+ {
769
+ cudaFreeHost(ptr);
770
+ }
771
+
772
+ void* wp_alloc_device(void* context, size_t s)
773
+ {
774
+ int ordinal = wp_cuda_context_get_device_ordinal(context);
775
+
776
+ // use stream-ordered allocator if available
777
+ if (wp_cuda_device_is_mempool_supported(ordinal))
778
+ return wp_alloc_device_async(context, s);
779
+ else
780
+ return wp_alloc_device_default(context, s);
781
+ }
782
+
783
+ void wp_free_device(void* context, void* ptr)
784
+ {
785
+ int ordinal = wp_cuda_context_get_device_ordinal(context);
786
+
787
+ // use stream-ordered allocator if available
788
+ if (wp_cuda_device_is_mempool_supported(ordinal))
789
+ wp_free_device_async(context, ptr);
790
+ else
791
+ wp_free_device_default(context, ptr);
792
+ }
793
+
794
+ void* wp_alloc_device_default(void* context, size_t s)
795
+ {
796
+ ContextGuard guard(context);
797
+
798
+ void* ptr = NULL;
799
+ check_cuda(cudaMalloc(&ptr, s));
800
+
801
+ return ptr;
802
+ }
803
+
804
+ void wp_free_device_default(void* context, void* ptr)
805
+ {
806
+ ContextGuard guard(context);
807
+
808
+ // check if a capture is in progress
809
+ if (g_captures.empty())
810
+ {
811
+ check_cuda(cudaFree(ptr));
812
+ }
813
+ else
814
+ {
815
+ // we must defer the operation until graph captures complete
816
+ deferred_free(ptr, context, false);
817
+ }
818
+ }
819
+
820
+ void* wp_alloc_device_async(void* context, size_t s)
821
+ {
822
+ // stream-ordered allocations don't rely on the current context,
823
+ // but we set the context here for consistent behaviour
824
+ ContextGuard guard(context);
825
+
826
+ ContextInfo* context_info = get_context_info(context);
827
+ if (!context_info)
828
+ return NULL;
829
+
830
+ CUstream stream = context_info->stream;
831
+
832
+ void* ptr = NULL;
833
+ check_cuda(cudaMallocAsync(&ptr, s, stream));
834
+
835
+ if (ptr)
836
+ {
837
+ // if the stream is capturing, the allocation requires special handling
838
+ if (wp_cuda_stream_is_capturing(stream))
839
+ {
840
+ // check if this is a known capture
841
+ uint64_t capture_id = get_capture_id(stream);
842
+ auto capture_iter = g_captures.find(capture_id);
843
+ if (capture_iter != g_captures.end())
844
+ {
845
+ // remember graph allocation details
846
+ GraphAllocInfo alloc_info;
847
+ alloc_info.capture_id = capture_id;
848
+ alloc_info.context = context ? context : get_current_context();
849
+ alloc_info.ref_exists = true; // user reference created and returned here
850
+ alloc_info.graph_destroyed = false; // graph not destroyed yet
851
+ g_graph_allocs[ptr] = alloc_info;
852
+ }
853
+ }
854
+ }
855
+
856
+ return ptr;
857
+ }
858
+
859
+ void wp_free_device_async(void* context, void* ptr)
860
+ {
861
+ // stream-ordered allocators generally don't rely on the current context,
862
+ // but we set the context here for consistent behaviour
863
+ ContextGuard guard(context);
864
+
865
+ // NB: Stream-ordered deallocations are tricky, because the memory could still be used on another stream
866
+ // or even multiple streams. To avoid use-after-free errors, we need to ensure that all preceding work
867
+ // completes before releasing the memory. The strategy is different for regular stream-ordered allocations
868
+ // and allocations made during graph capture. See below for details.
869
+
870
+ // check if this allocation was made during graph capture
871
+ auto alloc_iter = g_graph_allocs.find(ptr);
872
+ if (alloc_iter == g_graph_allocs.end())
873
+ {
874
+ // Not a graph allocation.
875
+ // Check if graph capture is ongoing.
876
+ if (g_captures.empty())
877
+ {
878
+ // cudaFreeAsync on the null stream does not block or trigger synchronization, but it postpones
879
+ // the deallocation until a synchronization point is reached, so preceding work on this pointer
880
+ // should safely complete.
881
+ check_cuda(cudaFreeAsync(ptr, NULL));
882
+ }
883
+ else
884
+ {
885
+ // We must defer the free operation until graph capture completes.
886
+ deferred_free(ptr, context, true);
887
+ }
888
+ }
889
+ else
890
+ {
891
+ // get the graph allocation details
892
+ GraphAllocInfo& alloc_info = alloc_iter->second;
893
+
894
+ uint64_t capture_id = alloc_info.capture_id;
895
+
896
+ // check if the capture is still active
897
+ auto capture_iter = g_captures.find(capture_id);
898
+ if (capture_iter != g_captures.end())
899
+ {
900
+ // Add a mem free node. Use all current leaf nodes as dependencies to ensure that all prior
901
+ // work completes before deallocating. This works with both Warp-initiated and external captures
902
+ // and avoids the need to explicitly track all streams used during the capture.
903
+ CaptureInfo* capture = capture_iter->second;
904
+ cudaGraph_t graph = get_capture_graph(capture->stream);
905
+ std::vector<cudaGraphNode_t> leaf_nodes;
906
+ if (graph && get_graph_leaf_nodes(graph, leaf_nodes))
907
+ {
908
+ cudaGraphNode_t free_node;
909
+ check_cuda(cudaGraphAddMemFreeNode(&free_node, graph, leaf_nodes.data(), leaf_nodes.size(), ptr));
910
+ }
911
+
912
+ // we're done with this allocation, it's owned by the graph
913
+ g_graph_allocs.erase(alloc_iter);
914
+ }
915
+ else
916
+ {
917
+ // the capture has ended
918
+ // if the owning graph was already destroyed, we can free the pointer now
919
+ if (alloc_info.graph_destroyed)
920
+ {
921
+ if (g_captures.empty())
922
+ {
923
+ // try to free the pointer now
924
+ cudaError_t res = cudaFreeAsync(ptr, NULL);
925
+ if (res == cudaErrorInvalidValue)
926
+ {
927
+ // This can happen if we try to release the pointer but the graph was
928
+ // never launched, so the memory isn't mapped.
929
+ // This is fine, so clear the error.
930
+ cudaGetLastError();
931
+ }
932
+ else
933
+ {
934
+ // check for other errors
935
+ check_cuda(res);
936
+ }
937
+ }
938
+ else
939
+ {
940
+ // We must defer the operation until graph capture completes.
941
+ deferred_free(ptr, context, true);
942
+ }
943
+
944
+ // we're done with this allocation
945
+ g_graph_allocs.erase(alloc_iter);
946
+ }
947
+ else
948
+ {
949
+ // graph still exists
950
+ // unreference the pointer so it will be deallocated once the graph instance is destroyed
951
+ alloc_info.ref_exists = false;
952
+ }
953
+ }
954
+ }
955
+ }
956
+
957
+ bool wp_memcpy_h2d(void* context, void* dest, void* src, size_t n, void* stream)
958
+ {
959
+ ContextGuard guard(context);
960
+
961
+ CUstream cuda_stream;
962
+ if (stream != WP_CURRENT_STREAM)
963
+ cuda_stream = static_cast<CUstream>(stream);
964
+ else
965
+ cuda_stream = get_current_stream(context);
966
+
967
+ begin_cuda_range(WP_TIMING_MEMCPY, cuda_stream, context, "memcpy HtoD");
968
+
969
+ bool result = check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyHostToDevice, cuda_stream));
970
+
971
+ end_cuda_range(WP_TIMING_MEMCPY, cuda_stream);
972
+
973
+ return result;
974
+ }
975
+
976
+ bool wp_memcpy_d2h(void* context, void* dest, void* src, size_t n, void* stream)
977
+ {
978
+ ContextGuard guard(context);
979
+
980
+ CUstream cuda_stream;
981
+ if (stream != WP_CURRENT_STREAM)
982
+ cuda_stream = static_cast<CUstream>(stream);
983
+ else
984
+ cuda_stream = get_current_stream(context);
985
+
986
+ begin_cuda_range(WP_TIMING_MEMCPY, cuda_stream, context, "memcpy DtoH");
987
+
988
+ bool result = check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyDeviceToHost, cuda_stream));
989
+
990
+ end_cuda_range(WP_TIMING_MEMCPY, cuda_stream);
991
+
992
+ return result;
993
+ }
994
+
995
+ bool wp_memcpy_d2d(void* context, void* dest, void* src, size_t n, void* stream)
996
+ {
997
+ ContextGuard guard(context);
998
+
999
+ CUstream cuda_stream;
1000
+ if (stream != WP_CURRENT_STREAM)
1001
+ cuda_stream = static_cast<CUstream>(stream);
1002
+ else
1003
+ cuda_stream = get_current_stream(context);
1004
+
1005
+ begin_cuda_range(WP_TIMING_MEMCPY, cuda_stream, context, "memcpy DtoD");
1006
+
1007
+ bool result = check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyDeviceToDevice, cuda_stream));
1008
+
1009
+ end_cuda_range(WP_TIMING_MEMCPY, cuda_stream);
1010
+
1011
+ return result;
1012
+ }
1013
+
1014
+ bool wp_memcpy_p2p(void* dst_context, void* dst, void* src_context, void* src, size_t n, void* stream)
1015
+ {
1016
+ // ContextGuard guard(context);
1017
+
1018
+ CUstream cuda_stream;
1019
+ if (stream != WP_CURRENT_STREAM)
1020
+ cuda_stream = static_cast<CUstream>(stream);
1021
+ else
1022
+ cuda_stream = get_current_stream(dst_context);
1023
+
1024
+ // Notes:
1025
+ // - cuMemcpyPeerAsync() works fine with both regular and pooled allocations (cudaMalloc() and cudaMallocAsync(), respectively)
1026
+ // when not capturing a graph.
1027
+ // - cuMemcpyPeerAsync() is not supported during graph capture, so we must use cudaMemcpyAsync() with kind=cudaMemcpyDefault.
1028
+ // - cudaMemcpyAsync() works fine with regular allocations, but doesn't work with pooled allocations
1029
+ // unless mempool access has been enabled.
1030
+ // - There is no reliable way to check if mempool access is enabled during graph capture,
1031
+ // because cudaMemPoolGetAccess() cannot be called during graph capture.
1032
+ // - CUDA will report error 1 (invalid argument) if cudaMemcpyAsync() is called but mempool access is not enabled.
1033
+
1034
+ if (!wp_cuda_stream_is_capturing(stream))
1035
+ {
1036
+ begin_cuda_range(WP_TIMING_MEMCPY, cuda_stream, get_stream_context(stream), "memcpy PtoP");
1037
+
1038
+ bool result = check_cu(cuMemcpyPeerAsync_f(
1039
+ (CUdeviceptr)dst, (CUcontext)dst_context,
1040
+ (CUdeviceptr)src, (CUcontext)src_context,
1041
+ n, cuda_stream));
1042
+
1043
+ end_cuda_range(WP_TIMING_MEMCPY, cuda_stream);
1044
+
1045
+ return result;
1046
+ }
1047
+ else
1048
+ {
1049
+ cudaError_t result = cudaSuccess;
1050
+
1051
+ // cudaMemcpyAsync() is sensitive to the bound context to resolve pointer locations.
1052
+ // It fails with cudaErrorInvalidValue if it cannot resolve an argument.
1053
+ // We first try the copy in the destination context, then if it fails we retry in the source context.
1054
+ // The cudaErrorInvalidValue error doesn't cause graph capture to fail, so it's ok to retry.
1055
+ // Since this trial-and-error only happens during capture, there
1056
+ // is no perf impact when the graph is launched.
1057
+ // For bonus points, this approach simplifies memory pool access requirements.
1058
+ // Access only needs to be enabled one way, either from the source device to the destination device
1059
+ // or vice versa. Sometimes, when it's really quiet, you can actually hear my genius.
1060
+ {
1061
+ // try doing the copy in the destination context
1062
+ ContextGuard guard(dst_context);
1063
+ result = cudaMemcpyAsync(dst, src, n, cudaMemcpyDefault, cuda_stream);
1064
+
1065
+ if (result != cudaSuccess)
1066
+ {
1067
+ // clear error in destination context
1068
+ cudaGetLastError();
1069
+
1070
+ // try doing the copy in the source context
1071
+ ContextGuard guard(src_context);
1072
+ result = cudaMemcpyAsync(dst, src, n, cudaMemcpyDefault, cuda_stream);
1073
+
1074
+ // clear error in source context
1075
+ cudaGetLastError();
1076
+ }
1077
+ }
1078
+
1079
+ // If the copy failed, try to detect if mempool allocations are involved to generate a helpful error message.
1080
+ if (!check_cuda(result))
1081
+ {
1082
+ if (result == cudaErrorInvalidValue && src != NULL && dst != NULL)
1083
+ {
1084
+ // check if either of the pointers was allocated from a mempool
1085
+ void* src_mempool = NULL;
1086
+ void* dst_mempool = NULL;
1087
+ cuPointerGetAttribute_f(&src_mempool, CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE, (CUdeviceptr)src);
1088
+ cuPointerGetAttribute_f(&dst_mempool, CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE, (CUdeviceptr)dst);
1089
+ cudaGetLastError(); // clear any errors
1090
+ // check if either of the pointers was allocated during graph capture
1091
+ auto src_alloc = g_graph_allocs.find(src);
1092
+ auto dst_alloc = g_graph_allocs.find(dst);
1093
+ if (src_mempool != NULL || src_alloc != g_graph_allocs.end() ||
1094
+ dst_mempool != NULL || dst_alloc != g_graph_allocs.end())
1095
+ {
1096
+ wp::append_error_string("*** CUDA mempool allocations were used in a peer-to-peer copy during graph capture.");
1097
+ wp::append_error_string("*** This operation fails if mempool access is not enabled between the peer devices.");
1098
+ wp::append_error_string("*** Either enable mempool access between the devices or use the default CUDA allocator");
1099
+ wp::append_error_string("*** to pre-allocate the arrays before graph capture begins.");
1100
+ }
1101
+ }
1102
+
1103
+ return false;
1104
+ }
1105
+
1106
+ return true;
1107
+ }
1108
+ }
1109
+
1110
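
As the notes inside wp_memcpy_p2p() explain, pooled allocations can only take part in a captured peer-to-peer copy if memory pool access is enabled between the devices. The sketch below shows that prerequisite with the plain CUDA runtime API; the device ordinals are illustrative, and the same cudaMemPoolSetAccess() call is what wp_cuda_set_mempool_access_enabled() performs further down in this file.

#include <cuda_runtime.h>

// allow `peer_device` to read and write allocations from `target_device`'s default pool
static bool enable_mempool_access_sketch(int target_device, int peer_device)
{
    cudaMemPool_t pool;
    if (cudaDeviceGetDefaultMemPool(&pool, target_device) != cudaSuccess)
        return false;

    cudaMemAccessDesc desc = {};
    desc.location.type = cudaMemLocationTypeDevice;
    desc.location.id = peer_device;
    desc.flags = cudaMemAccessFlagsProtReadWrite;

    return cudaMemPoolSetAccess(pool, &desc, 1) == cudaSuccess;
}
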
+
1111
+ __global__ void memset_kernel(int* dest, int value, size_t n)
1112
+ {
1113
+ const size_t tid = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
1114
+
1115
+ if (tid < n)
1116
+ {
1117
+ dest[tid] = value;
1118
+ }
1119
+ }
1120
+
1121
+ void wp_memset_device(void* context, void* dest, int value, size_t n)
1122
+ {
1123
+ ContextGuard guard(context);
1124
+
1125
+ if (true)// ((n%4) > 0)
1126
+ {
1127
+ cudaStream_t stream = get_current_stream();
1128
+
1129
+ begin_cuda_range(WP_TIMING_MEMSET, stream, context, "memset");
1130
+
1131
+ // for unaligned lengths fallback to CUDA memset
1132
+ check_cuda(cudaMemsetAsync(dest, value, n, stream));
1133
+
1134
+ end_cuda_range(WP_TIMING_MEMSET, stream);
1135
+ }
1136
+ if (check_cu(cuIpcOpenMemHandle_f(&device_ptr, memHandle, CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS)))
1137
+ {
1138
+ // custom kernel to support 4-byte values (and slightly lower host overhead)
1139
+ const size_t num_words = n/4;
1140
+ wp_launch_device(WP_CURRENT_CONTEXT, memset_kernel, num_words, ((int*)dest, value, num_words));
1141
+ }
1142
+ }
1143
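
Note that the cudaMemsetAsync() path used above fills byte by byte, so only values whose bytes are all identical round-trip exactly; that limitation is why the (currently disabled) kernel path exists for arbitrary 4-byte values. A small host-side illustration of the distinction, using std::memset as a stand-in:

#include <cassert>
#include <cstdint>
#include <cstring>

int main()
{
    int32_t word;

    std::memset(&word, 0xFF, sizeof(word));  // byte-wise fill, like cudaMemset
    assert(word == -1);                      // 0xFFFFFFFF: every byte identical, so the value survives

    std::memset(&word, 0x01, sizeof(word));
    assert(word == 0x01010101);              // not 1; each byte is set to 0x01
    return 0;
}
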
+
1144
+ // fill memory buffer with a value: generic memtile kernel using memcpy for each element
1145
+ __global__ void memtile_kernel(void* dst, const void* src, size_t srcsize, size_t n)
1146
+ {
1147
+ size_t tid = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
1148
+ if (tid < n)
1149
+ {
1150
+ memcpy((int8_t*)dst + srcsize * tid, src, srcsize);
1151
+ }
1152
+ }
1153
+
1154
+ // this should be faster than memtile_kernel, but requires proper alignment of dst
1155
+ template <typename T>
1156
+ __global__ void memtile_value_kernel(T* dst, T value, size_t n)
1157
+ {
1158
+ size_t tid = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
1159
+ if (tid < n)
1160
+ {
1161
+ dst[tid] = value;
1162
+ }
1163
+ }
1164
+
1165
+ void wp_memtile_device(void* context, void* dst, const void* src, size_t srcsize, size_t n)
1166
+ {
1167
+ ContextGuard guard(context);
1168
+
1169
+ size_t dst_addr = reinterpret_cast<size_t>(dst);
1170
+ size_t src_addr = reinterpret_cast<size_t>(src);
1171
+
1172
+ // try memtile_value first because it should be faster, but we need to ensure proper alignment
1173
+ if (srcsize == 8 && (dst_addr & 7) == 0 && (src_addr & 7) == 0)
1174
+ {
1175
+ int64_t* p = reinterpret_cast<int64_t*>(dst);
1176
+ int64_t value = *reinterpret_cast<const int64_t*>(src);
1177
+ wp_launch_device(WP_CURRENT_CONTEXT, memtile_value_kernel, n, (p, value, n));
1178
+ }
1179
+ else if (srcsize == 4 && (dst_addr & 3) == 0 && (src_addr & 3) == 0)
1180
+ {
1181
+ int32_t* p = reinterpret_cast<int32_t*>(dst);
1182
+ int32_t value = *reinterpret_cast<const int32_t*>(src);
1183
+ wp_launch_device(WP_CURRENT_CONTEXT, memtile_value_kernel, n, (p, value, n));
1184
+ }
1185
+ else if (srcsize == 2 && (dst_addr & 1) == 0 && (src_addr & 1) == 0)
1186
+ {
1187
+ int16_t* p = reinterpret_cast<int16_t*>(dst);
1188
+ int16_t value = *reinterpret_cast<const int16_t*>(src);
1189
+ wp_launch_device(WP_CURRENT_CONTEXT, memtile_value_kernel, n, (p, value, n));
1190
+ }
1191
+ else if (srcsize == 1)
1192
+ {
1193
+ check_cuda(cudaMemset(dst, *reinterpret_cast<const int8_t*>(src), n));
1194
+ }
1195
+ else
1196
+ {
1197
+ // generic version
1198
+ void* value_devptr = NULL; // fill value in device memory
1199
+ bool free_devptr = true; // whether we need to free the memory
1200
+
1201
+ // prepare the fill value in a graph-friendly way
1202
+ if (!capturable_tmp_alloc(WP_CURRENT_CONTEXT, src, srcsize, &value_devptr, &free_devptr))
1203
+ {
1204
+ fprintf(stderr, "Warp fill error: failed to copy value to device memory\n");
1205
+ return;
1206
+ }
1207
+
1208
+ wp_launch_device(WP_CURRENT_CONTEXT, memtile_kernel, n, (dst, value_devptr, srcsize, n));
1209
+
1210
+ if (free_devptr)
1211
+ {
1212
+ wp_free_device(WP_CURRENT_CONTEXT, value_devptr);
1213
+ }
1214
+ }
1215
+ }
1216
+
1217
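
The dispatch above picks the fastest specialization only when both the element size and the pointer alignment allow it; an address is N-byte aligned (for power-of-two N) exactly when its low bits are zero, which is what the (addr & (N-1)) == 0 tests check. A tiny, self-contained illustration:

#include <cstddef>
#include <cstdint>
#include <cstdio>

// true when `p` is aligned to `alignment`, which must be a power of two
static bool is_aligned(const void* p, size_t alignment)
{
    return (reinterpret_cast<uintptr_t>(p) & (alignment - 1)) == 0;
}

int main()
{
    alignas(8) unsigned char buf[16];
    std::printf("%d\n", is_aligned(buf, 8));      // 1: the buffer starts on an 8-byte boundary
    std::printf("%d\n", is_aligned(buf + 4, 8));  // 0: offset by 4 bytes
    return 0;
}
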
+
1218
+ static __global__ void array_copy_1d_kernel(void* dst, const void* src,
1219
+ size_t dst_stride, size_t src_stride,
1220
+ const int* dst_indices, const int* src_indices,
1221
+ size_t n, size_t elem_size)
1222
+ {
1223
+ size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1224
+ if (i < n)
1225
+ {
1226
+ size_t src_idx = src_indices ? src_indices[i] : i;
1227
+ size_t dst_idx = dst_indices ? dst_indices[i] : i;
1228
+ const char* p = (const char*)src + src_idx * src_stride;
1229
+ char* q = (char*)dst + dst_idx * dst_stride;
1230
+ memcpy(q, p, elem_size);
1231
+ }
1232
+ }
1233
+
1234
+ static __global__ void array_copy_2d_kernel(void* dst, const void* src,
1235
+ wp::vec_t<2, size_t> dst_strides, wp::vec_t<2, size_t> src_strides,
1236
+ wp::vec_t<2, const int*> dst_indices, wp::vec_t<2, const int*> src_indices,
1237
+ wp::vec_t<2, size_t> shape, size_t elem_size)
1238
+ {
1239
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1240
+ size_t n = shape[1];
1241
+ size_t i = tid / n;
1242
+ size_t j = tid % n;
1243
+ if (i < shape[0] /*&& j < shape[1]*/)
1244
+ {
1245
+ size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1246
+ size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1247
+ size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1248
+ size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1249
+ const char* p = (const char*)src + src_idx0 * src_strides[0] + src_idx1 * src_strides[1];
1250
+ char* q = (char*)dst + dst_idx0 * dst_strides[0] + dst_idx1 * dst_strides[1];
1251
+ memcpy(q, p, elem_size);
1252
+ }
1253
+ }
1254
+
1255
+ static __global__ void array_copy_3d_kernel(void* dst, const void* src,
1256
+ wp::vec_t<3, size_t> dst_strides, wp::vec_t<3, size_t> src_strides,
1257
+ wp::vec_t<3, const int*> dst_indices, wp::vec_t<3, const int*> src_indices,
1258
+ wp::vec_t<3, size_t> shape, size_t elem_size)
1259
+ {
1260
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1261
+ size_t n = shape[1];
1262
+ size_t o = shape[2];
1263
+ size_t i = tid / (n * o);
1264
+ size_t j = tid % (n * o) / o;
1265
+ size_t k = tid % o;
1266
+ if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
1267
+ {
1268
+ size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1269
+ size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1270
+ size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1271
+ size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1272
+ size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1273
+ size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1274
+ const char* p = (const char*)src + src_idx0 * src_strides[0]
1275
+ + src_idx1 * src_strides[1]
1276
+ + src_idx2 * src_strides[2];
1277
+ char* q = (char*)dst + dst_idx0 * dst_strides[0]
1278
+ + dst_idx1 * dst_strides[1]
1279
+ + dst_idx2 * dst_strides[2];
1280
+ memcpy(q, p, elem_size);
1281
+ }
1282
+ }
1283
+
1284
+ static __global__ void array_copy_4d_kernel(void* dst, const void* src,
1285
+ wp::vec_t<4, size_t> dst_strides, wp::vec_t<4, size_t> src_strides,
1286
+ wp::vec_t<4, const int*> dst_indices, wp::vec_t<4, const int*> src_indices,
1287
+ wp::vec_t<4, size_t> shape, size_t elem_size)
1288
+ {
1289
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1290
+ size_t n = shape[1];
1291
+ size_t o = shape[2];
1292
+ size_t p = shape[3];
1293
+ size_t i = tid / (n * o * p);
1294
+ size_t j = tid % (n * o * p) / (o * p);
1295
+ size_t k = tid % (o * p) / p;
1296
+ size_t l = tid % p;
1297
+ if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
1298
+ {
1299
+ size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1300
+ size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1301
+ size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1302
+ size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1303
+ size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1304
+ size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1305
+ size_t src_idx3 = src_indices[3] ? src_indices[3][l] : l;
1306
+ size_t dst_idx3 = dst_indices[3] ? dst_indices[3][l] : l;
1307
+ const char* p = (const char*)src + src_idx0 * src_strides[0]
1308
+ + src_idx1 * src_strides[1]
1309
+ + src_idx2 * src_strides[2]
1310
+ + src_idx3 * src_strides[3];
1311
+ char* q = (char*)dst + dst_idx0 * dst_strides[0]
1312
+ + dst_idx1 * dst_strides[1]
1313
+ + dst_idx2 * dst_strides[2]
1314
+ + dst_idx3 * dst_strides[3];
1315
+ memcpy(q, p, elem_size);
1316
+ }
1317
+ }
1318
+
1319
+
1320
+ static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src,
1321
+ void* dst_data, size_t dst_stride, const int* dst_indices,
1322
+ size_t elem_size)
1323
+ {
1324
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1325
+
1326
+ if (tid < src.size)
1327
+ {
1328
+ size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
1329
+ void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
1330
+ const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
1331
+ memcpy(dst_ptr, src_ptr, elem_size);
1332
+ }
1333
+ }
1334
+
1335
+ static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricarray_t<void> src,
1336
+ void* dst_data, size_t dst_stride, const int* dst_indices,
1337
+ size_t elem_size)
1338
+ {
1339
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1340
+
1341
+ if (tid < src.size)
1342
+ {
1343
+ size_t src_index = src.indices[tid];
1344
+ size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
1345
+ void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
1346
+ const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
1347
+ memcpy(dst_ptr, src_ptr, elem_size);
1348
+ }
1349
+ }
1350
+
1351
+ static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
1352
+ const void* src_data, size_t src_stride, const int* src_indices,
1353
+ size_t elem_size)
1354
+ {
1355
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1356
+
1357
+ if (tid < dst.size)
1358
+ {
1359
+ size_t src_idx = src_indices ? src_indices[tid] : tid;
1360
+ const void* src_ptr = (const char*)src_data + src_idx * src_stride;
1361
+ void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
1362
+ memcpy(dst_ptr, src_ptr, elem_size);
1363
+ }
1364
+ }
1365
+
1366
+ static __global__ void array_copy_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst,
1367
+ const void* src_data, size_t src_stride, const int* src_indices,
1368
+ size_t elem_size)
1369
+ {
1370
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1371
+
1372
+ if (tid < dst.size)
1373
+ {
1374
+ size_t src_idx = src_indices ? src_indices[tid] : tid;
1375
+ const void* src_ptr = (const char*)src_data + src_idx * src_stride;
1376
+ size_t dst_idx = dst.indices[tid];
1377
+ void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
1378
+ memcpy(dst_ptr, src_ptr, elem_size);
1379
+ }
1380
+ }
1381
+
1382
+
1383
+ static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
1384
+ {
1385
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1386
+
1387
+ if (tid < dst.size)
1388
+ {
1389
+ const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
1390
+ void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
1391
+ memcpy(dst_ptr, src_ptr, elem_size);
1392
+ }
1393
+ }
1394
+
1395
+
1396
+ static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
1397
+ {
1398
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1399
+
1400
+ if (tid < dst.size)
1401
+ {
1402
+ const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
1403
+ size_t dst_index = dst.indices[tid];
1404
+ void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
1405
+ memcpy(dst_ptr, src_ptr, elem_size);
1406
+ }
1407
+ }
1408
+
1409
+
1410
+ static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
1411
+ {
1412
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1413
+
1414
+ if (tid < dst.size)
1415
+ {
1416
+ size_t src_index = src.indices[tid];
1417
+ const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
1418
+ void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
1419
+ memcpy(dst_ptr, src_ptr, elem_size);
1420
+ }
1421
+ }
1422
+
1423
+
1424
+ static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
1425
+ {
1426
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1427
+
1428
+ if (tid < dst.size)
1429
+ {
1430
+ size_t src_index = src.indices[tid];
1431
+ size_t dst_index = dst.indices[tid];
1432
+ const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
1433
+ void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
1434
+ memcpy(dst_ptr, src_ptr, elem_size);
1435
+ }
1436
+ }
1437
+
1438
+
1439
+ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_type, int src_type, int elem_size)
1440
+ {
1441
+ if (!src || !dst)
1442
+ return false;
1443
+
1444
+ const void* src_data = NULL;
1445
+ void* dst_data = NULL;
1446
+ int src_ndim = 0;
1447
+ int dst_ndim = 0;
1448
+ const int* src_shape = NULL;
1449
+ const int* dst_shape = NULL;
1450
+ const int* src_strides = NULL;
1451
+ const int* dst_strides = NULL;
1452
+ const int*const* src_indices = NULL;
1453
+ const int*const* dst_indices = NULL;
1454
+
1455
+ const wp::fabricarray_t<void>* src_fabricarray = NULL;
1456
+ wp::fabricarray_t<void>* dst_fabricarray = NULL;
1457
+
1458
+ const wp::indexedfabricarray_t<void>* src_indexedfabricarray = NULL;
1459
+ wp::indexedfabricarray_t<void>* dst_indexedfabricarray = NULL;
1460
+
1461
+ const int* null_indices[wp::ARRAY_MAX_DIMS] = { NULL };
1462
+
1463
+ if (src_type == wp::ARRAY_TYPE_REGULAR)
1464
+ {
1465
+ const wp::array_t<void>& src_arr = *static_cast<const wp::array_t<void>*>(src);
1466
+ src_data = src_arr.data;
1467
+ src_ndim = src_arr.ndim;
1468
+ src_shape = src_arr.shape.dims;
1469
+ src_strides = src_arr.strides;
1470
+ src_indices = null_indices;
1471
+ }
1472
+ else if (src_type == wp::ARRAY_TYPE_INDEXED)
1473
+ {
1474
+ const wp::indexedarray_t<void>& src_arr = *static_cast<const wp::indexedarray_t<void>*>(src);
1475
+ src_data = src_arr.arr.data;
1476
+ src_ndim = src_arr.arr.ndim;
1477
+ src_shape = src_arr.shape.dims;
1478
+ src_strides = src_arr.arr.strides;
1479
+ src_indices = src_arr.indices;
1480
+ }
1481
+ else if (src_type == wp::ARRAY_TYPE_FABRIC)
1482
+ {
1483
+ src_fabricarray = static_cast<const wp::fabricarray_t<void>*>(src);
1484
+ src_ndim = 1;
1485
+ }
1486
+ else if (src_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
1487
+ {
1488
+ src_indexedfabricarray = static_cast<const wp::indexedfabricarray_t<void>*>(src);
1489
+ src_ndim = 1;
1490
+ }
1491
+ else
1492
+ {
1493
+ fprintf(stderr, "Warp copy error: Invalid array type (%d)\n", src_type);
1494
+ return false;
1495
+ }
1496
+
1497
+ if (dst_type == wp::ARRAY_TYPE_REGULAR)
1498
+ {
1499
+ const wp::array_t<void>& dst_arr = *static_cast<const wp::array_t<void>*>(dst);
1500
+ dst_data = dst_arr.data;
1501
+ dst_ndim = dst_arr.ndim;
1502
+ dst_shape = dst_arr.shape.dims;
1503
+ dst_strides = dst_arr.strides;
1504
+ dst_indices = null_indices;
1505
+ }
1506
+ else if (dst_type == wp::ARRAY_TYPE_INDEXED)
1507
+ {
1508
+ const wp::indexedarray_t<void>& dst_arr = *static_cast<const wp::indexedarray_t<void>*>(dst);
1509
+ dst_data = dst_arr.arr.data;
1510
+ dst_ndim = dst_arr.arr.ndim;
1511
+ dst_shape = dst_arr.shape.dims;
1512
+ dst_strides = dst_arr.arr.strides;
1513
+ dst_indices = dst_arr.indices;
1514
+ }
1515
+ else if (dst_type == wp::ARRAY_TYPE_FABRIC)
1516
+ {
1517
+ dst_fabricarray = static_cast<wp::fabricarray_t<void>*>(dst);
1518
+ dst_ndim = 1;
1519
+ }
1520
+ else if (dst_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
1521
+ {
1522
+ dst_indexedfabricarray = static_cast<wp::indexedfabricarray_t<void>*>(dst);
1523
+ dst_ndim = 1;
1524
+ }
1525
+ else
1526
+ {
1527
+ fprintf(stderr, "Warp copy error: Invalid array type (%d)\n", dst_type);
1528
+ return false;
1529
+ }
1530
+
1531
+ if (src_ndim != dst_ndim)
1532
+ {
1533
+ fprintf(stderr, "Warp copy error: Incompatible array dimensionalities (%d and %d)\n", src_ndim, dst_ndim);
1534
+ return false;
1535
+ }
1536
+
1537
+ ContextGuard guard(context);
1538
+
1539
+ // handle fabric arrays
1540
+ if (dst_fabricarray)
1541
+ {
1542
+ size_t n = dst_fabricarray->size;
1543
+ if (src_fabricarray)
1544
+ {
1545
+ // copy from fabric to fabric
1546
+ if (src_fabricarray->size != n)
1547
+ {
1548
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
1549
+ return false;
1550
+ }
1551
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_to_fabric_kernel, n,
1552
+ (*dst_fabricarray, *src_fabricarray, elem_size));
1553
+ return true;
1554
+ }
1555
+ else if (src_indexedfabricarray)
1556
+ {
1557
+ // copy from fabric indexed to fabric
1558
+ if (src_indexedfabricarray->size != n)
1559
+ {
1560
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
1561
+ return false;
1562
+ }
1563
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_indexed_to_fabric_kernel, n,
1564
+ (*dst_fabricarray, *src_indexedfabricarray, elem_size));
1565
+ return true;
1566
+ }
1567
+ else
1568
+ {
1569
+ // copy to fabric
1570
+ if (size_t(src_shape[0]) != n)
1571
+ {
1572
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
1573
+ return false;
1574
+ }
1575
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_to_fabric_kernel, n,
1576
+ (*dst_fabricarray, src_data, src_strides[0], src_indices[0], elem_size));
1577
+ return true;
1578
+ }
1579
+ }
1580
+ if (dst_indexedfabricarray)
1581
+ {
1582
+ size_t n = dst_indexedfabricarray->size;
1583
+ if (src_fabricarray)
1584
+ {
1585
+ // copy from fabric to fabric indexed
1586
+ if (src_fabricarray->size != n)
1587
+ {
1588
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
1589
+ return false;
1590
+ }
1591
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_to_fabric_indexed_kernel, n,
1592
+ (*dst_indexedfabricarray, *src_fabricarray, elem_size));
1593
+ return true;
1594
+ }
1595
+ else if (src_indexedfabricarray)
1596
+ {
1597
+ // copy from fabric indexed to fabric indexed
1598
+ if (src_indexedfabricarray->size != n)
1599
+ {
1600
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
1601
+ return false;
1602
+ }
1603
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_indexed_to_fabric_indexed_kernel, n,
1604
+ (*dst_indexedfabricarray, *src_indexedfabricarray, elem_size));
1605
+ return true;
1606
+ }
1607
+ else
1608
+ {
1609
+ // copy to fabric indexed
1610
+ if (size_t(src_shape[0]) != n)
1611
+ {
1612
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
1613
+ return false;
1614
+ }
1615
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_to_fabric_indexed_kernel, n,
1616
+ (*dst_indexedfabricarray, src_data, src_strides[0], src_indices[0], elem_size));
1617
+ return true;
1618
+ }
1619
+ }
1620
+ else if (src_fabricarray)
1621
+ {
1622
+ // copy from fabric
1623
+ size_t n = src_fabricarray->size;
1624
+ if (size_t(dst_shape[0]) != n)
1625
+ {
1626
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
1627
+ return false;
1628
+ }
1629
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_from_fabric_kernel, n,
1630
+ (*src_fabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size));
1631
+ return true;
1632
+ }
1633
+ else if (src_indexedfabricarray)
1634
+ {
1635
+ // copy from fabric indexed
1636
+ size_t n = src_indexedfabricarray->size;
1637
+ if (size_t(dst_shape[0]) != n)
1638
+ {
1639
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
1640
+ return false;
1641
+ }
1642
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_from_fabric_indexed_kernel, n,
1643
+ (*src_indexedfabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size));
1644
+ return true;
1645
+ }
1646
+
1647
+ size_t n = 1;
1648
+ for (int i = 0; i < src_ndim; i++)
1649
+ {
1650
+ if (src_shape[i] != dst_shape[i])
1651
+ {
1652
+ fprintf(stderr, "Warp copy error: Incompatible array shapes\n");
1653
+ return false;
1654
+ }
1655
+ n *= src_shape[i];
1656
+ }
1657
+
1658
+ switch (src_ndim)
1659
+ {
1660
+ case 1:
1661
+ {
1662
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_1d_kernel, n, (dst_data, src_data,
1663
+ dst_strides[0], src_strides[0],
1664
+ dst_indices[0], src_indices[0],
1665
+ src_shape[0], elem_size));
1666
+ break;
1667
+ }
1668
+ case 2:
1669
+ {
1670
+ wp::vec_t<2, size_t> shape_v(src_shape[0], src_shape[1]);
1671
+ wp::vec_t<2, size_t> src_strides_v(src_strides[0], src_strides[1]);
1672
+ wp::vec_t<2, size_t> dst_strides_v(dst_strides[0], dst_strides[1]);
1673
+ wp::vec_t<2, const int*> src_indices_v(src_indices[0], src_indices[1]);
1674
+ wp::vec_t<2, const int*> dst_indices_v(dst_indices[0], dst_indices[1]);
1675
+
1676
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_2d_kernel, n, (dst_data, src_data,
1677
+ dst_strides_v, src_strides_v,
1678
+ dst_indices_v, src_indices_v,
1679
+ shape_v, elem_size));
1680
+ break;
1681
+ }
1682
+ case 3:
1683
+ {
1684
+ wp::vec_t<3, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2]);
1685
+ wp::vec_t<3, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2]);
1686
+ wp::vec_t<3, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2]);
1687
+ wp::vec_t<3, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2]);
1688
+ wp::vec_t<3, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2]);
1689
+
1690
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_3d_kernel, n, (dst_data, src_data,
1691
+ dst_strides_v, src_strides_v,
1692
+ dst_indices_v, src_indices_v,
1693
+ shape_v, elem_size));
1694
+ break;
1695
+ }
1696
+ case 4:
1697
+ {
1698
+ wp::vec_t<4, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
1699
+ wp::vec_t<4, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2], src_strides[3]);
1700
+ wp::vec_t<4, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2], dst_strides[3]);
1701
+ wp::vec_t<4, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2], src_indices[3]);
1702
+ wp::vec_t<4, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2], dst_indices[3]);
1703
+
1704
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_4d_kernel, n, (dst_data, src_data,
1705
+ dst_strides_v, src_strides_v,
1706
+ dst_indices_v, src_indices_v,
1707
+ shape_v, elem_size));
1708
+ break;
1709
+ }
1710
+ default:
1711
+ fprintf(stderr, "Warp copy error: invalid array dimensionality (%d)\n", src_ndim);
1712
+ return false;
1713
+ }
1714
+
1715
+ return check_cuda(cudaGetLastError());
1716
+ }
1717
+
1718
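
The 2D/3D/4D kernels above recover multi-dimensional indices from the flat thread id with successive division and modulo (for 3D: i = tid / (n*o), j = tid % (n*o) / o, k = tid % o). The short host-side check below verifies that this mapping matches the usual nested-loop ordering; the shape values are illustrative.

#include <cassert>
#include <cstddef>

int main()
{
    const size_t shape[3] = {3, 4, 5};
    const size_t n = shape[1];
    const size_t o = shape[2];

    size_t tid = 0;
    for (size_t i = 0; i < shape[0]; ++i)
        for (size_t j = 0; j < shape[1]; ++j)
            for (size_t k = 0; k < shape[2]; ++k, ++tid)
            {
                assert(tid / (n * o) == i);       // slowest-varying index
                assert(tid % (n * o) / o == j);
                assert(tid % o == k);             // fastest-varying index
            }
    return 0;
}
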
+
1719
+ static __global__ void array_fill_1d_kernel(void* data,
1720
+ size_t n,
1721
+ size_t stride,
1722
+ const int* indices,
1723
+ const void* value,
1724
+ size_t value_size)
1725
+ {
1726
+ size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1727
+ if (i < n)
1728
+ {
1729
+ size_t idx = indices ? indices[i] : i;
1730
+ char* p = (char*)data + idx * stride;
1731
+ memcpy(p, value, value_size);
1732
+ }
1733
+ }
1734
+
1735
+ static __global__ void array_fill_2d_kernel(void* data,
1736
+ wp::vec_t<2, size_t> shape,
1737
+ wp::vec_t<2, size_t> strides,
1738
+ wp::vec_t<2, const int*> indices,
1739
+ const void* value,
1740
+ size_t value_size)
1741
+ {
1742
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1743
+ size_t n = shape[1];
1744
+ size_t i = tid / n;
1745
+ size_t j = tid % n;
1746
+ if (i < shape[0] /*&& j < shape[1]*/)
1747
+ {
1748
+ size_t idx0 = indices[0] ? indices[0][i] : i;
1749
+ size_t idx1 = indices[1] ? indices[1][j] : j;
1750
+ char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1];
1751
+ memcpy(p, value, value_size);
1752
+ }
1753
+ }
1754
+
1755
+ static __global__ void array_fill_3d_kernel(void* data,
1756
+ wp::vec_t<3, size_t> shape,
1757
+ wp::vec_t<3, size_t> strides,
1758
+ wp::vec_t<3, const int*> indices,
1759
+ const void* value,
1760
+ size_t value_size)
1761
+ {
1762
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1763
+ size_t n = shape[1];
1764
+ size_t o = shape[2];
1765
+ size_t i = tid / (n * o);
1766
+ size_t j = tid % (n * o) / o;
1767
+ size_t k = tid % o;
1768
+ if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
1769
+ {
1770
+ size_t idx0 = indices[0] ? indices[0][i] : i;
1771
+ size_t idx1 = indices[1] ? indices[1][j] : j;
1772
+ size_t idx2 = indices[2] ? indices[2][k] : k;
1773
+ char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2];
1774
+ memcpy(p, value, value_size);
1775
+ }
1776
+ }
1777
+
1778
+ static __global__ void array_fill_4d_kernel(void* data,
1779
+ wp::vec_t<4, size_t> shape,
1780
+ wp::vec_t<4, size_t> strides,
1781
+ wp::vec_t<4, const int*> indices,
1782
+ const void* value,
1783
+ size_t value_size)
1784
+ {
1785
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1786
+ size_t n = shape[1];
1787
+ size_t o = shape[2];
1788
+ size_t p = shape[3];
1789
+ size_t i = tid / (n * o * p);
1790
+ size_t j = tid % (n * o * p) / (o * p);
1791
+ size_t k = tid % (o * p) / p;
1792
+ size_t l = tid % p;
1793
+ if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
1794
+ {
1795
+ size_t idx0 = indices[0] ? indices[0][i] : i;
1796
+ size_t idx1 = indices[1] ? indices[1][j] : j;
1797
+ size_t idx2 = indices[2] ? indices[2][k] : k;
1798
+ size_t idx3 = indices[3] ? indices[3][l] : l;
1799
+ char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2] + idx3 * strides[3];
1800
+ memcpy(p, value, value_size);
1801
+ }
1802
+ }
1803
+
1804
+
1805
+ static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value, size_t value_size)
1806
+ {
1807
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1808
+ if (tid < fa.size)
1809
+ {
1810
+ void* dst_ptr = fabricarray_element_ptr(fa, tid, value_size);
1811
+ memcpy(dst_ptr, value, value_size);
1812
+ }
1813
+ }
1814
+
1815
+
1816
+ static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value, size_t value_size)
1817
+ {
1818
+ size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
1819
+ if (tid < ifa.size)
1820
+ {
1821
+ size_t idx = size_t(ifa.indices[tid]);
1822
+ if (idx < ifa.fa.size)
1823
+ {
1824
+ void* dst_ptr = fabricarray_element_ptr(ifa.fa, idx, value_size);
1825
+ memcpy(dst_ptr, value, value_size);
1826
+ }
1827
+ }
1828
+ }
1829
+
1830
+
1831
+ WP_API void wp_array_fill_device(void* context, void* arr_ptr, int arr_type, const void* value_ptr, int value_size)
1832
+ {
1833
+ if (!arr_ptr || !value_ptr)
1834
+ return;
1835
+
1836
+ void* data = NULL;
1837
+ int ndim = 0;
1838
+ const int* shape = NULL;
1839
+ const int* strides = NULL;
1840
+ const int*const* indices = NULL;
1841
+
1842
+ wp::fabricarray_t<void>* fa = NULL;
1843
+ wp::indexedfabricarray_t<void>* ifa = NULL;
1844
+
1845
+ const int* null_indices[wp::ARRAY_MAX_DIMS] = { NULL };
1846
+
1847
+ if (arr_type == wp::ARRAY_TYPE_REGULAR)
1848
+ {
1849
+ wp::array_t<void>& arr = *static_cast<wp::array_t<void>*>(arr_ptr);
1850
+ data = arr.data;
1851
+ ndim = arr.ndim;
1852
+ shape = arr.shape.dims;
1853
+ strides = arr.strides;
1854
+ indices = null_indices;
1855
+ }
1856
+ else if (arr_type == wp::ARRAY_TYPE_INDEXED)
1857
+ {
1858
+ wp::indexedarray_t<void>& ia = *static_cast<wp::indexedarray_t<void>*>(arr_ptr);
1859
+ data = ia.arr.data;
1860
+ ndim = ia.arr.ndim;
1861
+ shape = ia.shape.dims;
1862
+ strides = ia.arr.strides;
1863
+ indices = ia.indices;
1864
+ }
1865
+ else if (arr_type == wp::ARRAY_TYPE_FABRIC)
1866
+ {
1867
+ fa = static_cast<wp::fabricarray_t<void>*>(arr_ptr);
1868
+ }
1869
+ else if (arr_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
1870
+ {
1871
+ ifa = static_cast<wp::indexedfabricarray_t<void>*>(arr_ptr);
1872
+ }
1873
+ else
1874
+ {
1875
+ fprintf(stderr, "Warp fill error: Invalid array type id %d\n", arr_type);
1876
+ return;
1877
+ }
1878
+
1879
+ size_t n = 1;
1880
+ for (int i = 0; i < ndim; i++)
1881
+ n *= shape[i];
1882
+
1883
+ ContextGuard guard(context);
1884
+
1885
+ void* value_devptr = NULL; // fill value in device memory
1886
+ bool free_devptr = true; // whether we need to free the memory
1887
+
1888
+ // prepare the fill value in a graph-friendly way
1889
+ if (!capturable_tmp_alloc(WP_CURRENT_CONTEXT, value_ptr, value_size, &value_devptr, &free_devptr))
1890
+ {
1891
+ fprintf(stderr, "Warp fill error: failed to copy value to device memory\n");
1892
+ return;
1893
+ }
1894
+
1895
+ if (fa)
1896
+ {
1897
+ // handle fabric arrays
1898
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_kernel, n,
1899
+ (*fa, value_devptr, value_size));
1900
+ }
1901
+ else if (ifa)
1902
+ {
1903
+ // handle indexed fabric arrays
1904
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_indexed_kernel, n,
1905
+ (*ifa, value_devptr, value_size));
1906
+ }
1907
+ else
1908
+ {
1909
+ // handle regular or indexed arrays
1910
+ switch (ndim)
1911
+ {
1912
+ case 1:
1913
+ {
1914
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_1d_kernel, n,
1915
+ (data, shape[0], strides[0], indices[0], value_devptr, value_size));
1916
+ break;
1917
+ }
1918
+ case 2:
1919
+ {
1920
+ wp::vec_t<2, size_t> shape_v(shape[0], shape[1]);
1921
+ wp::vec_t<2, size_t> strides_v(strides[0], strides[1]);
1922
+ wp::vec_t<2, const int*> indices_v(indices[0], indices[1]);
1923
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_2d_kernel, n,
1924
+ (data, shape_v, strides_v, indices_v, value_devptr, value_size));
1925
+ break;
1926
+ }
1927
+ case 3:
1928
+ {
1929
+ wp::vec_t<3, size_t> shape_v(shape[0], shape[1], shape[2]);
1930
+ wp::vec_t<3, size_t> strides_v(strides[0], strides[1], strides[2]);
1931
+ wp::vec_t<3, const int*> indices_v(indices[0], indices[1], indices[2]);
1932
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_3d_kernel, n,
1933
+ (data, shape_v, strides_v, indices_v, value_devptr, value_size));
1934
+ break;
1935
+ }
1936
+ case 4:
1937
+ {
1938
+ wp::vec_t<4, size_t> shape_v(shape[0], shape[1], shape[2], shape[3]);
1939
+ wp::vec_t<4, size_t> strides_v(strides[0], strides[1], strides[2], strides[3]);
1940
+ wp::vec_t<4, const int*> indices_v(indices[0], indices[1], indices[2], indices[3]);
1941
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_4d_kernel, n,
1942
+ (data, shape_v, strides_v, indices_v, value_devptr, value_size));
1943
+ break;
1944
+ }
1945
+ default:
1946
+ fprintf(stderr, "Warp fill error: invalid array dimensionality (%d)\n", ndim);
1947
+ break;
1948
+ }
1949
+ }
1950
+
1951
+ if (free_devptr)
1952
+ {
1953
+ wp_free_device(WP_CURRENT_CONTEXT, value_devptr);
1954
+ }
1955
+ }
1956
+
1957
+ void wp_array_scan_int_device(uint64_t in, uint64_t out, int len, bool inclusive)
1958
+ {
1959
+ scan_device((const int*)in, (int*)out, len, inclusive);
1960
+ }
1961
+
1962
+ void wp_array_scan_float_device(uint64_t in, uint64_t out, int len, bool inclusive)
1963
+ {
1964
+ scan_device((const float*)in, (float*)out, len, inclusive);
1965
+ }
1966
+
1967
+ int wp_cuda_driver_version()
1968
+ {
1969
+ int version;
1970
+ if (check_cu(cuDriverGetVersion_f(&version)))
1971
+ return version;
1972
+ else
1973
+ return 0;
1974
+ }
1975
+
1976
+ int wp_cuda_toolkit_version()
1977
+ {
1978
+ return CUDA_VERSION;
1979
+ }
1980
+
1981
+ bool wp_cuda_driver_is_initialized()
1982
+ {
1983
+ return is_cuda_driver_initialized();
1984
+ }
1985
+
1986
+ int wp_nvrtc_supported_arch_count()
1987
+ {
1988
+ int count;
1989
+ if (check_nvrtc(nvrtcGetNumSupportedArchs(&count)))
1990
+ return count;
1991
+ else
1992
+ return 0;
1993
+ }
1994
+
1995
+ void wp_nvrtc_supported_archs(int* archs)
1996
+ {
1997
+ if (archs)
1998
+ {
1999
+ check_nvrtc(nvrtcGetSupportedArchs(archs));
2000
+ }
2001
+ }
2002
+
2003
+ int wp_cuda_device_get_count()
2004
+ {
2005
+ int count = 0;
2006
+ check_cu(cuDeviceGetCount_f(&count));
2007
+ return count;
2008
+ }
2009
+
2010
+ void* wp_cuda_device_get_primary_context(int ordinal)
2011
+ {
2012
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
2013
+ {
2014
+ DeviceInfo& device_info = g_devices[ordinal];
2015
+
2016
+ // acquire the primary context if we haven't already
2017
+ if (!device_info.primary_context)
2018
+ check_cu(cuDevicePrimaryCtxRetain_f(&device_info.primary_context, device_info.device));
2019
+
2020
+ return device_info.primary_context;
2021
+ }
2022
+
2023
+ return NULL;
2024
+ }
2025
+
2026
+ const char* wp_cuda_device_get_name(int ordinal)
2027
+ {
2028
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
2029
+ return g_devices[ordinal].name;
2030
+ return NULL;
2031
+ }
2032
+
2033
+ int wp_cuda_device_get_arch(int ordinal)
2034
+ {
2035
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
2036
+ return g_devices[ordinal].arch;
2037
+ return 0;
2038
+ }
2039
+
2040
+ int wp_cuda_device_get_sm_count(int ordinal)
2041
+ {
2042
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
2043
+ return g_devices[ordinal].sm_count;
2044
+ return 0;
2045
+ }
2046
+
2047
+ void wp_cuda_device_get_uuid(int ordinal, char uuid[16])
2048
+ {
2049
+ memcpy(uuid, g_devices[ordinal].uuid.bytes, sizeof(char)*16);
2050
+ }
2051
+
2052
+ int wp_cuda_device_get_pci_domain_id(int ordinal)
2053
+ {
2054
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
2055
+ return g_devices[ordinal].pci_domain_id;
2056
+ return -1;
2057
+ }
2058
+
2059
+ int wp_cuda_device_get_pci_bus_id(int ordinal)
2060
+ {
2061
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
2062
+ return g_devices[ordinal].pci_bus_id;
2063
+ return -1;
2064
+ }
2065
+
2066
+ int wp_cuda_device_get_pci_device_id(int ordinal)
2067
+ {
2068
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
2069
+ return g_devices[ordinal].pci_device_id;
2070
+ return -1;
2071
+ }
2072
+
2073
+ int wp_cuda_device_is_uva(int ordinal)
2074
+ {
2075
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
2076
+ return g_devices[ordinal].is_uva;
2077
+ return 0;
2078
+ }
2079
+
2080
+ int wp_cuda_device_is_mempool_supported(int ordinal)
2081
+ {
2082
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
2083
+ return g_devices[ordinal].is_mempool_supported;
2084
+ return 0;
2085
+ }
2086
+
2087
+ int wp_cuda_device_is_ipc_supported(int ordinal)
2088
+ {
2089
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
2090
+ return g_devices[ordinal].is_ipc_supported;
2091
+ return 0;
2092
+ }
2093
+
2094
+ int wp_cuda_device_set_mempool_release_threshold(int ordinal, uint64_t threshold)
2095
+ {
2096
+ if (ordinal < 0 || ordinal >= int(g_devices.size()))
2097
+ {
2098
+ fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
2099
+ return 0;
2100
+ }
2101
+
2102
+ if (!g_devices[ordinal].is_mempool_supported)
2103
+ return 0;
2104
+
2105
+ cudaMemPool_t pool;
2106
+ if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
2107
+ {
2108
+ fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
2109
+ return 0;
2110
+ }
2111
+
2112
+ if (!check_cuda(cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold)))
2113
+ {
2114
+ fprintf(stderr, "Warp error: Failed to set memory pool attribute on device %d\n", ordinal);
2115
+ return 0;
2116
+ }
2117
+
2118
+ return 1; // success
2119
+ }
2120
+
2121
+ uint64_t wp_cuda_device_get_mempool_release_threshold(int ordinal)
2122
+ {
2123
+ if (ordinal < 0 || ordinal >= int(g_devices.size()))
2124
+ {
2125
+ fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
2126
+ return 0;
2127
+ }
2128
+
2129
+ if (!g_devices[ordinal].is_mempool_supported)
2130
+ return 0;
2131
+
2132
+ cudaMemPool_t pool;
2133
+ if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
2134
+ {
2135
+ fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
2136
+ return 0;
2137
+ }
2138
+
2139
+ uint64_t threshold = 0;
2140
+ if (!check_cuda(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold)))
2141
+ {
2142
+ fprintf(stderr, "Warp error: Failed to get memory pool release threshold on device %d\n", ordinal);
2143
+ return 0;
2144
+ }
2145
+
2146
+ return threshold;
2147
+ }
2148
+
2149
+ uint64_t wp_cuda_device_get_mempool_used_mem_current(int ordinal)
2150
+ {
2151
+ if (ordinal < 0 || ordinal >= int(g_devices.size()))
2152
+ {
2153
+ fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
2154
+ return 0;
2155
+ }
2156
+
2157
+ if (!g_devices[ordinal].is_mempool_supported)
2158
+ return 0;
2159
+
2160
+ cudaMemPool_t pool;
2161
+ if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
2162
+ {
2163
+ fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
2164
+ return 0;
2165
+ }
2166
+
2167
+ uint64_t mem_used = 0;
2168
+ if (!check_cuda(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrUsedMemCurrent, &mem_used)))
2169
+ {
2170
+ fprintf(stderr, "Warp error: Failed to get amount of currently used memory from the memory pool on device %d\n", ordinal);
2171
+ return 0;
2172
+ }
2173
+
2174
+ return mem_used;
2175
+ }
2176
+
2177
+ uint64_t wp_cuda_device_get_mempool_used_mem_high(int ordinal)
2178
+ {
2179
+ if (ordinal < 0 || ordinal >= int(g_devices.size()))
2180
+ {
2181
+ fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
2182
+ return 0;
2183
+ }
2184
+
2185
+ if (!g_devices[ordinal].is_mempool_supported)
2186
+ return 0;
2187
+
2188
+ cudaMemPool_t pool;
2189
+ if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
2190
+ {
2191
+ fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
2192
+ return 0;
2193
+ }
2194
+
2195
+ uint64_t mem_high_water_mark = 0;
2196
+ if (!check_cuda(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrUsedMemHigh, &mem_high_water_mark)))
2197
+ {
2198
+ fprintf(stderr, "Warp error: Failed to get memory usage high water mark from the memory pool on device %d\n", ordinal);
2199
+ return 0;
2200
+ }
2201
+
2202
+ return mem_high_water_mark;
2203
+ }
2204
+
2205
+ void wp_cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_mem)
2206
+ {
2207
+ // use temporary storage if user didn't specify pointers
2208
+ size_t tmp_free_mem, tmp_total_mem;
2209
+
2210
+ if (free_mem)
2211
+ *free_mem = 0;
2212
+ else
2213
+ free_mem = &tmp_free_mem;
2214
+
2215
+ if (total_mem)
2216
+ *total_mem = 0;
2217
+ else
2218
+ total_mem = &tmp_total_mem;
2219
+
2220
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
2221
+ {
2222
+ if (g_devices[ordinal].primary_context)
2223
+ {
2224
+ ContextGuard guard(g_devices[ordinal].primary_context, true);
2225
+ check_cu(cuMemGetInfo_f(free_mem, total_mem));
2226
+ }
2227
+ else
2228
+ {
2229
+ // if we haven't acquired the primary context yet, acquire it temporarily
2230
+ CUcontext primary_context = NULL;
2231
+ check_cu(cuDevicePrimaryCtxRetain_f(&primary_context, g_devices[ordinal].device));
2232
+ {
2233
+ ContextGuard guard(primary_context, true);
2234
+ check_cu(cuMemGetInfo_f(free_mem, total_mem));
2235
+ }
2236
+ check_cu(cuDevicePrimaryCtxRelease_f(g_devices[ordinal].device));
2237
+ }
2238
+ }
2239
+ }
2240
+
2241
+
2242
+ void* wp_cuda_context_get_current()
2243
+ {
2244
+ return get_current_context();
2245
+ }
2246
+
2247
+ void wp_cuda_context_set_current(void* context)
2248
+ {
2249
+ CUcontext ctx = static_cast<CUcontext>(context);
2250
+ CUcontext prev_ctx = NULL;
2251
+ check_cu(cuCtxGetCurrent_f(&prev_ctx));
2252
+ if (ctx != prev_ctx)
2253
+ {
2254
+ check_cu(cuCtxSetCurrent_f(ctx));
2255
+ }
2256
+ }
2257
+
2258
+ void wp_cuda_context_push_current(void* context)
2259
+ {
2260
+ check_cu(cuCtxPushCurrent_f(static_cast<CUcontext>(context)));
2261
+ }
2262
+
2263
+ void wp_cuda_context_pop_current()
2264
+ {
2265
+ CUcontext context;
2266
+ check_cu(cuCtxPopCurrent_f(&context));
2267
+ }
2268
+
2269
+ void* wp_cuda_context_create(int device_ordinal)
2270
+ {
2271
+ CUcontext ctx = NULL;
2272
+ CUdevice device;
2273
+ if (check_cu(cuDeviceGet_f(&device, device_ordinal)))
2274
+ check_cu(cuCtxCreate_f(&ctx, 0, device));
2275
+ return ctx;
2276
+ }
2277
+
2278
+ void wp_cuda_context_destroy(void* context)
2279
+ {
2280
+ if (context)
2281
+ {
2282
+ CUcontext ctx = static_cast<CUcontext>(context);
2283
+
2284
+ // ensure this is not the current context
2285
+ if (ctx == wp_cuda_context_get_current())
2286
+ wp_cuda_context_set_current(NULL);
2287
+
2288
+ // release the cached info about this context
2289
+ ContextInfo* info = get_context_info(ctx);
2290
+ if (info)
2291
+ {
2292
+ if (info->stream)
2293
+ check_cu(cuStreamDestroy_f(info->stream));
2294
+
2295
+ if (info->conditional_module)
2296
+ check_cu(cuModuleUnload_f(info->conditional_module));
2297
+
2298
+ g_contexts.erase(ctx);
2299
+ }
2300
+
2301
+ check_cu(cuCtxDestroy_f(ctx));
2302
+ }
2303
+ }
2304
+
2305
+ void wp_cuda_context_synchronize(void* context)
2306
+ {
2307
+ ContextGuard guard(context);
2308
+
2309
+ check_cu(cuCtxSynchronize_f());
2310
+
2311
+ if (!context)
2312
+ context = get_current_context();
2313
+
2314
+ if (run_deferred_actions(context) > 0)
2315
+ {
2316
+ // ensure deferred asynchronous operations complete
2317
+ check_cu(cuCtxSynchronize_f());
2318
+ }
2319
+
2320
+ // check_cuda(cudaDeviceGraphMemTrim(wp_cuda_context_get_device_ordinal(context)));
2321
+ }
2322
+
2323
+ uint64_t wp_cuda_context_check(void* context)
2324
+ {
2325
+ ContextGuard guard(context);
2326
+
2327
+ // check errors before syncing
2328
+ cudaError_t e = cudaGetLastError();
2329
+ check_cuda(e);
2330
+
2331
+ cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
2332
+ check_cuda(cudaStreamIsCapturing(get_current_stream(), &status));
2333
+
2334
+ // synchronize if the stream is not capturing
2335
+ if (status == cudaStreamCaptureStatusNone)
2336
+ {
2337
+ check_cuda(cudaDeviceSynchronize());
2338
+ e = cudaGetLastError();
2339
+ }
2340
+
2341
+ return static_cast<uint64_t>(e);
2342
+ }
2343
+
2344
+
2345
+ int wp_cuda_context_get_device_ordinal(void* context)
2346
+ {
2347
+ ContextInfo* info = get_context_info(static_cast<CUcontext>(context));
2348
+ return info && info->device_info ? info->device_info->ordinal : -1;
2349
+ }
2350
+
2351
+ int wp_cuda_context_is_primary(void* context)
2352
+ {
2353
+ CUcontext ctx = static_cast<CUcontext>(context);
2354
+ ContextInfo* context_info = get_context_info(ctx);
2355
+ if (!context_info)
2356
+ {
2357
+ fprintf(stderr, "Warp error: Failed to get context info\n");
2358
+ return 0;
2359
+ }
2360
+
2361
+ // if the device primary context is known, check if it matches the given context
2362
+ DeviceInfo* device_info = context_info->device_info;
2363
+ if (device_info->primary_context)
2364
+ return int(ctx == device_info->primary_context);
2365
+
2366
+ // there is no CUDA API to check if a context is primary, but we can temporarily
2367
+ // acquire the device's primary context to check the pointer
2368
+ CUcontext primary_ctx;
2369
+ if (check_cu(cuDevicePrimaryCtxRetain_f(&primary_ctx, device_info->device)))
2370
+ {
2371
+ check_cu(cuDevicePrimaryCtxRelease_f(device_info->device));
2372
+ return int(ctx == primary_ctx);
2373
+ }
2374
+
2375
+ return 0;
2376
+ }
2377
+
2378
+ void* wp_cuda_context_get_stream(void* context)
2379
+ {
2380
+ ContextInfo* info = get_context_info(static_cast<CUcontext>(context));
2381
+ if (info)
2382
+ {
2383
+ return info->stream;
2384
+ }
2385
+ return NULL;
2386
+ }
2387
+
2388
+ void wp_cuda_context_set_stream(void* context, void* stream, int sync)
2389
+ {
2390
+ ContextInfo* context_info = get_context_info(static_cast<CUcontext>(context));
2391
+ if (context_info)
2392
+ {
2393
+ CUstream new_stream = static_cast<CUstream>(stream);
2394
+
2395
+ // check whether we should sync with the previous stream on this device
2396
+ if (sync)
2397
+ {
2398
+ CUstream old_stream = context_info->stream;
2399
+ StreamInfo* old_stream_info = get_stream_info(old_stream);
2400
+ if (old_stream_info)
2401
+ {
2402
+ CUevent cached_event = old_stream_info->cached_event;
2403
+ check_cu(cuEventRecord_f(cached_event, old_stream));
2404
+ check_cu(cuStreamWaitEvent_f(new_stream, cached_event, CU_EVENT_WAIT_DEFAULT));
2405
+ }
2406
+ }
2407
+
2408
+ context_info->stream = new_stream;
2409
+ }
2410
+ }
2411
+
2412
+ int wp_cuda_is_peer_access_supported(int target_ordinal, int peer_ordinal)
2413
+ {
2414
+ int num_devices = int(g_devices.size());
2415
+
2416
+ if (target_ordinal < 0 || target_ordinal >= num_devices)
2417
+ {
2418
+ fprintf(stderr, "Warp error: Invalid target device ordinal %d\n", target_ordinal);
2419
+ return 0;
2420
+ }
2421
+
2422
+ if (peer_ordinal < 0 || peer_ordinal >= num_devices)
2423
+ {
2424
+ fprintf(stderr, "Warp error: Invalid peer device ordinal %d\n", peer_ordinal);
2425
+ return 0;
2426
+ }
2427
+
2428
+ if (target_ordinal == peer_ordinal)
2429
+ return 1;
2430
+
2431
+ int can_access = 0;
2432
+ check_cuda(cudaDeviceCanAccessPeer(&can_access, peer_ordinal, target_ordinal));
2433
+
2434
+ return can_access;
2435
+ }
2436
+
2437
+ int wp_cuda_is_peer_access_enabled(void* target_context, void* peer_context)
2438
+ {
2439
+ if (!target_context || !peer_context)
2440
+ {
2441
+ fprintf(stderr, "Warp error: invalid CUDA context\n");
2442
+ return 0;
2443
+ }
2444
+
2445
+ if (target_context == peer_context)
2446
+ return 1;
2447
+
2448
+ int target_ordinal = wp_cuda_context_get_device_ordinal(target_context);
2449
+ int peer_ordinal = wp_cuda_context_get_device_ordinal(peer_context);
2450
+
2451
+ // check if peer access is supported
2452
+ int can_access = 0;
2453
+ check_cuda(cudaDeviceCanAccessPeer(&can_access, peer_ordinal, target_ordinal));
2454
+ if (!can_access)
2455
+ return 0;
2456
+
2457
+ // There is no CUDA API to query if peer access is enabled, but we can try to enable it and check the result.
2458
+
2459
+ ContextGuard guard(peer_context, true);
2460
+
2461
+ CUcontext target_ctx = static_cast<CUcontext>(target_context);
2462
+
2463
+ CUresult result = cuCtxEnablePeerAccess_f(target_ctx, 0);
2464
+ if (result == CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED)
2465
+ {
2466
+ return 1;
2467
+ }
2468
+ else if (result == CUDA_SUCCESS)
2469
+ {
2470
+ // undo enablement
2471
+ check_cu(cuCtxDisablePeerAccess_f(target_ctx));
2472
+ return 0;
2473
+ }
2474
+ else
2475
+ {
2476
+ // report error
2477
+ check_cu(result);
2478
+ return 0;
2479
+ }
2480
+ }
2481
+
2482
+ int wp_cuda_set_peer_access_enabled(void* target_context, void* peer_context, int enable)
2483
+ {
2484
+ if (!target_context || !peer_context)
2485
+ {
2486
+ fprintf(stderr, "Warp error: invalid CUDA context\n");
2487
+ return 0;
2488
+ }
2489
+
2490
+ if (target_context == peer_context)
2491
+ return 1; // no-op
2492
+
2493
+ int target_ordinal = wp_cuda_context_get_device_ordinal(target_context);
2494
+ int peer_ordinal = wp_cuda_context_get_device_ordinal(peer_context);
2495
+
2496
+ // check if peer access is supported
2497
+ int can_access = 0;
2498
+ check_cuda(cudaDeviceCanAccessPeer(&can_access, peer_ordinal, target_ordinal));
2499
+ if (!can_access)
2500
+ {
2501
+ // failure if enabling, success if disabling
2502
+ if (enable)
2503
+ {
2504
+ fprintf(stderr, "Warp error: device %d cannot access device %d\n", peer_ordinal, target_ordinal);
2505
+ return 0;
2506
+ }
2507
+ else
2508
+ return 1;
2509
+ }
2510
+
2511
+ ContextGuard guard(peer_context, true);
2512
+
2513
+ CUcontext target_ctx = static_cast<CUcontext>(target_context);
2514
+
2515
+ if (enable)
2516
+ {
2517
+ CUresult status = cuCtxEnablePeerAccess_f(target_ctx, 0);
2518
+ if (status != CUDA_SUCCESS && status != CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED)
2519
+ {
2520
+ check_cu(status);
2521
+ fprintf(stderr, "Warp error: failed to enable peer access from device %d to device %d\n", peer_ordinal, target_ordinal);
2522
+ return 0;
2523
+ }
2524
+ }
2525
+ else
2526
+ {
2527
+ CUresult status = cuCtxDisablePeerAccess_f(target_ctx);
2528
+ if (status != CUDA_SUCCESS && status != CUDA_ERROR_PEER_ACCESS_NOT_ENABLED)
2529
+ {
2530
+ check_cu(status);
2531
+ fprintf(stderr, "Warp error: failed to disable peer access from device %d to device %d\n", peer_ordinal, target_ordinal);
2532
+ return 0;
2533
+ }
2534
+ }
2535
+
2536
+ return 1; // success
2537
+ }
2538
+
2539
+ int wp_cuda_is_mempool_access_enabled(int target_ordinal, int peer_ordinal)
2540
+ {
2541
+ int num_devices = int(g_devices.size());
2542
+
2543
+ if (target_ordinal < 0 || target_ordinal >= num_devices)
2544
+ {
2545
+ fprintf(stderr, "Warp error: Invalid device ordinal %d\n", target_ordinal);
2546
+ return 0;
2547
+ }
2548
+
2549
+ if (peer_ordinal < 0 || peer_ordinal >= num_devices)
2550
+ {
2551
+ fprintf(stderr, "Warp error: Invalid peer device ordinal %d\n", peer_ordinal);
2552
+ return 0;
2553
+ }
2554
+
2555
+ if (target_ordinal == peer_ordinal)
2556
+ return 1;
2557
+
2558
+ cudaMemPool_t pool;
2559
+ if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, target_ordinal)))
2560
+ {
2561
+ fprintf(stderr, "Warp error: Failed to get memory pool of device %d\n", target_ordinal);
2562
+ return 0;
2563
+ }
2564
+
2565
+ cudaMemAccessFlags flags = cudaMemAccessFlagsProtNone;
2566
+ cudaMemLocation location;
2567
+ location.id = peer_ordinal;
2568
+ location.type = cudaMemLocationTypeDevice;
2569
+ if (check_cuda(cudaMemPoolGetAccess(&flags, pool, &location)))
2570
+ return int(flags != cudaMemAccessFlagsProtNone);
2571
+
2572
+ return 0;
2573
+ }
2574
+
2575
+ int wp_cuda_set_mempool_access_enabled(int target_ordinal, int peer_ordinal, int enable)
2576
+ {
2577
+ int num_devices = int(g_devices.size());
2578
+
2579
+ if (target_ordinal < 0 || target_ordinal >= num_devices)
2580
+ {
2581
+ fprintf(stderr, "Warp error: Invalid device ordinal %d\n", target_ordinal);
2582
+ return 0;
2583
+ }
2584
+
2585
+ if (peer_ordinal < 0 || peer_ordinal >= num_devices)
2586
+ {
2587
+ fprintf(stderr, "Warp error: Invalid peer device ordinal %d\n", peer_ordinal);
2588
+ return 0;
2589
+ }
2590
+
2591
+ if (target_ordinal == peer_ordinal)
2592
+ return 1; // no-op
2593
+
2594
+ // get the memory pool
2595
+ cudaMemPool_t pool;
2596
+ if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, target_ordinal)))
2597
+ {
2598
+ fprintf(stderr, "Warp error: Failed to get memory pool of device %d\n", target_ordinal);
2599
+ return 0;
2600
+ }
2601
+
2602
+ cudaMemAccessDesc desc;
2603
+ desc.location.type = cudaMemLocationTypeDevice;
2604
+ desc.location.id = peer_ordinal;
2605
+
2606
+ // only cudaMemAccessFlagsProtReadWrite and cudaMemAccessFlagsProtNone are supported
2607
+ if (enable)
2608
+ desc.flags = cudaMemAccessFlagsProtReadWrite;
2609
+ else
2610
+ desc.flags = cudaMemAccessFlagsProtNone;
2611
+
2612
+ if (!check_cuda(cudaMemPoolSetAccess(pool, &desc, 1)))
2613
+ {
2614
+ fprintf(stderr, "Warp error: Failed to set mempool access from device %d to device %d\n", peer_ordinal, target_ordinal);
2615
+ return 0;
2616
+ }
2617
+
2618
+ return 1; // success
2619
+ }
2620
+
2621
+ void wp_cuda_ipc_get_mem_handle(void* ptr, char* out_buffer) {
2622
+ CUipcMemHandle memHandle;
2623
+ check_cu(cuIpcGetMemHandle_f(&memHandle, (CUdeviceptr)ptr));
2624
+ memcpy(out_buffer, memHandle.reserved, CU_IPC_HANDLE_SIZE);
2625
+ }
2626
+
2627
+ void* wp_cuda_ipc_open_mem_handle(void* context, char* handle) {
2628
+ ContextGuard guard(context);
2629
+
2630
+ CUipcMemHandle memHandle;
2631
+ memcpy(memHandle.reserved, handle, CU_IPC_HANDLE_SIZE);
2632
+
2633
+ CUdeviceptr device_ptr;
2634
+
2635
+ // Strangely, the CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS flag is required
2636
+ if check_cu(cuIpcOpenMemHandle_f(&device_ptr, memHandle, CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS))
2637
+ return (void*) device_ptr;
2638
+ else
2639
+ return NULL;
2640
+ }
2641
+
2642
+ void wp_cuda_ipc_close_mem_handle(void* ptr) {
2643
+ check_cu(cuIpcCloseMemHandle_f((CUdeviceptr) ptr));
2644
+ }
2645
+
2646
+ void wp_cuda_ipc_get_event_handle(void* context, void* event, char* out_buffer) {
2647
+ ContextGuard guard(context);
2648
+
2649
+ CUipcEventHandle eventHandle;
2650
+ check_cu(cuIpcGetEventHandle_f(&eventHandle, static_cast<CUevent>(event)));
2651
+ memcpy(out_buffer, eventHandle.reserved, CU_IPC_HANDLE_SIZE);
2652
+ }
2653
+
2654
+ void* wp_cuda_ipc_open_event_handle(void* context, char* handle) {
2655
+ ContextGuard guard(context);
2656
+
2657
+ CUipcEventHandle eventHandle;
2658
+ memcpy(eventHandle.reserved, handle, CU_IPC_HANDLE_SIZE);
2659
+
2660
+ CUevent event;
2661
+
2662
+ if (check_cu(cuIpcOpenEventHandle_f(&event, eventHandle)))
2663
+ return event;
2664
+ else
2665
+ return NULL;
2666
+ }
2667
+
2668
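
The IPC wrappers above split naturally across two processes: the owner of an allocation exports a fixed-size handle, and a peer process imports it to obtain a device pointer into the same memory. A sketch of both halves with the CUDA runtime equivalents follows; how the handle bytes are transported between processes (pipe, socket, shared file) is up to the caller.

#include <cuda_runtime.h>

// exporting process: produce a handle for an existing device allocation
static bool export_ipc_handle_sketch(void* device_ptr, cudaIpcMemHandle_t* out_handle)
{
    return cudaIpcGetMemHandle(out_handle, device_ptr) == cudaSuccess;
}

// importing process: map the peer allocation after receiving the handle
static void* import_ipc_handle_sketch(cudaIpcMemHandle_t handle)
{
    void* ptr = nullptr;
    if (cudaIpcOpenMemHandle(&ptr, handle, cudaIpcMemLazyEnablePeerAccess) != cudaSuccess)
        return nullptr;
    return ptr;  // unmap later with cudaIpcCloseMemHandle(ptr)
}
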
+ void* wp_cuda_stream_create(void* context, int priority)
2669
+ {
2670
+ ContextGuard guard(context, true);
2671
+
2672
+ CUstream stream;
2673
+ if (check_cu(cuStreamCreateWithPriority_f(&stream, CU_STREAM_DEFAULT, priority)))
2674
+ {
2675
+ wp_cuda_stream_register(WP_CURRENT_CONTEXT, stream);
2676
+ return stream;
2677
+ }
2678
+ else
2679
+ return NULL;
2680
+ }
2681
+
2682
+ void wp_cuda_stream_destroy(void* context, void* stream)
2683
+ {
2684
+ if (!stream)
2685
+ return;
2686
+
2687
+ wp_cuda_stream_unregister(context, stream);
2688
+
2689
+ // release temporary radix sort buffer associated with this stream
2690
+ radix_sort_release(context, stream);
2691
+
2692
+ check_cu(cuStreamDestroy_f(static_cast<CUstream>(stream)));
2693
+ }
2694
+
2695
+ int wp_cuda_stream_query(void* stream)
2696
+ {
2697
+ CUresult res = cuStreamQuery_f(static_cast<CUstream>(stream));
2698
+
2699
+ if ((res != CUDA_SUCCESS) && (res != CUDA_ERROR_NOT_READY))
2700
+ {
2701
+ // Abnormal, print out error
2702
+ check_cu(res);
2703
+ }
2704
+
2705
+ return res;
2706
+ }
2707
+
2708
+ void wp_cuda_stream_register(void* context, void* stream)
2709
+ {
2710
+ if (!stream)
2711
+ return;
2712
+
2713
+ ContextGuard guard(context);
2714
+
2715
+ // populate stream info
2716
+ StreamInfo& stream_info = g_streams[static_cast<CUstream>(stream)];
2717
+ check_cu(cuEventCreate_f(&stream_info.cached_event, CU_EVENT_DISABLE_TIMING));
2718
+ }
2719
+
2720
+ void wp_cuda_stream_unregister(void* context, void* stream)
2721
+ {
2722
+ if (!stream)
2723
+ return;
2724
+
2725
+ CUstream cuda_stream = static_cast<CUstream>(stream);
2726
+
2727
+ StreamInfo* stream_info = get_stream_info(cuda_stream);
2728
+ if (stream_info)
2729
+ {
2730
+ // release stream info
2731
+ check_cu(cuEventDestroy_f(stream_info->cached_event));
2732
+ g_streams.erase(cuda_stream);
2733
+ }
2734
+
2735
+ // make sure we don't leave dangling references to this stream
2736
+ ContextInfo* context_info = get_context_info(context);
2737
+ if (context_info)
2738
+ {
2739
+ if (cuda_stream == context_info->stream)
2740
+ context_info->stream = NULL;
2741
+ }
2742
+ }
2743
+
2744
+ void* wp_cuda_stream_get_current()
2745
+ {
2746
+ return get_current_stream();
2747
+ }
2748
+
2749
+ void wp_cuda_stream_synchronize(void* stream)
2750
+ {
2751
+ check_cu(cuStreamSynchronize_f(static_cast<CUstream>(stream)));
2752
+ }
2753
+
2754
+ void wp_cuda_stream_wait_event(void* stream, void* event, bool external)
2755
+ {
2756
+ // the external flag can only be used during graph capture
2757
+ if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
2758
+ {
2759
+ // wait for an external event during graph capture
2760
+ check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_EXTERNAL));
2761
+ }
2762
+ else
2763
+ {
2764
+ check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_DEFAULT));
2765
+ }
2766
+ }
2767
+
2768
+ void wp_cuda_stream_wait_stream(void* stream, void* other_stream, void* event, bool external)
2769
+ {
2770
+ unsigned record_flags = CU_EVENT_RECORD_DEFAULT;
2771
+ unsigned wait_flags = CU_EVENT_WAIT_DEFAULT;
2772
+
2773
+ // the external flag can only be used during graph capture
2774
+ if (external && !g_captures.empty())
2775
+ {
2776
+ if (wp_cuda_stream_is_capturing(other_stream))
2777
+ record_flags = CU_EVENT_RECORD_EXTERNAL;
2778
+ if (wp_cuda_stream_is_capturing(stream))
2779
+ wait_flags = CU_EVENT_WAIT_EXTERNAL;
2780
+ }
2781
+
2782
+ check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(other_stream), record_flags));
2783
+ check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), wait_flags));
2784
+ }
2785
+
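The record-then-wait pattern used above can be reduced to the following runtime-API sketch. Creation and ownership of the event (typically cudaEventCreateWithFlags with cudaEventDisableTiming) are assumed to be handled by the caller.

#include <cuda_runtime.h>

// Make `consumer` wait for all work submitted to `producer` so far,
// without blocking the host.
void add_stream_dependency(cudaStream_t consumer, cudaStream_t producer, cudaEvent_t event)
{
    cudaEventRecord(event, producer);        // capture the producer's current tail
    cudaStreamWaitEvent(consumer, event, 0); // consumer work launched after this waits on it
}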
2786
+ int wp_cuda_stream_is_capturing(void* stream)
2787
+ {
2788
+ cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
2789
+ check_cuda(cudaStreamIsCapturing(static_cast<cudaStream_t>(stream), &status));
2790
+
2791
+ return int(status != cudaStreamCaptureStatusNone);
2792
+ }
2793
+
2794
+ uint64_t wp_cuda_stream_get_capture_id(void* stream)
2795
+ {
2796
+ return get_capture_id(static_cast<CUstream>(stream));
2797
+ }
2798
+
2799
+ int wp_cuda_stream_get_priority(void* stream)
2800
+ {
2801
+ int priority = 0;
2802
+ check_cu(cuStreamGetPriority_f(static_cast<CUstream>(stream), &priority));
2803
+
2804
+ return priority;
2805
+ }
2806
+
2807
+ void* wp_cuda_event_create(void* context, unsigned flags)
2808
+ {
2809
+ ContextGuard guard(context, true);
2810
+
2811
+ CUevent event;
2812
+ if (check_cu(cuEventCreate_f(&event, flags)))
2813
+ return event;
2814
+ else
2815
+ return NULL;
2816
+ }
2817
+
2818
+ void wp_cuda_event_destroy(void* event)
2819
+ {
2820
+ check_cu(cuEventDestroy_f(static_cast<CUevent>(event)));
2821
+ }
2822
+
2823
+ int wp_cuda_event_query(void* event)
2824
+ {
2825
+ CUresult res = cuEventQuery_f(static_cast<CUevent>(event));
2826
+
2827
+ if ((res != CUDA_SUCCESS) && (res != CUDA_ERROR_NOT_READY))
2828
+ {
2829
+ // Abnormal, print out error
2830
+ check_cu(res);
2831
+ }
2832
+
2833
+ return res;
2834
+ }
2835
+
2836
+ void wp_cuda_event_record(void* event, void* stream, bool external)
2837
+ {
2838
+ // the external flag can only be used during graph capture
2839
+ if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
2840
+ {
2841
+ // record external event during graph capture (e.g., for timing or when explicitly specified by the user)
2842
+ check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(stream), CU_EVENT_RECORD_EXTERNAL));
2843
+ }
2844
+ else
2845
+ {
2846
+ check_cu(cuEventRecord_f(static_cast<CUevent>(event), static_cast<CUstream>(stream)));
2847
+ }
2848
+ }
2849
+
2850
+ void wp_cuda_event_synchronize(void* event)
2851
+ {
2852
+ check_cu(cuEventSynchronize_f(static_cast<CUevent>(event)));
2853
+ }
2854
+
2855
+ float wp_cuda_event_elapsed_time(void* start_event, void* end_event)
2856
+ {
2857
+ float elapsed = 0.0f;
2858
+ cudaEvent_t start = static_cast<cudaEvent_t>(start_event);
2859
+ cudaEvent_t end = static_cast<cudaEvent_t>(end_event);
2860
+ check_cuda(cudaEventElapsedTime(&elapsed, start, end));
2861
+ return elapsed;
2862
+ }
2863
+
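A sketch of the typical host-side timing pattern behind the elapsed-time helper above; `work` is a placeholder for any sequence of launches or copies on the stream.

#include <cstdio>
#include <cuda_runtime.h>

float time_on_stream(cudaStream_t stream, void (*work)(cudaStream_t))
{
    cudaEvent_t start, stop;
    cudaEventCreate(&start); // events created with cudaEventDisableTiming cannot be timed
    cudaEventCreate(&stop);

    cudaEventRecord(start, stream);
    work(stream);
    cudaEventRecord(stop, stream);

    cudaEventSynchronize(stop); // wait until the stop event has actually executed
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return ms;
}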
2864
+ bool wp_cuda_graph_begin_capture(void* context, void* stream, int external)
2865
+ {
2866
+ ContextGuard guard(context);
2867
+
2868
+ CUstream cuda_stream = static_cast<CUstream>(stream);
2869
+ StreamInfo* stream_info = get_stream_info(cuda_stream);
2870
+ if (!stream_info)
2871
+ {
2872
+ wp::set_error_string("Warp error: unknown stream");
2873
+ return false;
2874
+ }
2875
+
2876
+ if (external)
2877
+ {
2878
+ // if it's an external capture, make sure it's already active so we can get the capture id
2879
+ cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
2880
+ if (!check_cuda(cudaStreamIsCapturing(cuda_stream, &status)))
2881
+ return false;
2882
+ if (status != cudaStreamCaptureStatusActive)
2883
+ {
2884
+ wp::set_error_string("Warp error: stream is not capturing");
2885
+ return false;
2886
+ }
2887
+ }
2888
+ else
2889
+ {
2890
+ // start the capture
2891
+ if (!check_cuda(cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeThreadLocal)))
2892
+ return false;
2893
+ }
2894
+
2895
+ uint64_t capture_id = get_capture_id(cuda_stream);
2896
+
2897
+ CaptureInfo* capture = new CaptureInfo();
2898
+ capture->stream = cuda_stream;
2899
+ capture->id = capture_id;
2900
+ capture->external = bool(external);
2901
+
2902
+ // update stream info
2903
+ stream_info->capture = capture;
2904
+
2905
+ // add to known captures
2906
+ g_captures[capture_id] = capture;
2907
+
2908
+ return true;
2909
+ }
2910
+
2911
+ bool wp_cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
2912
+ {
2913
+ ContextGuard guard(context);
2914
+
2915
+ // check if this is a known stream
2916
+ CUstream cuda_stream = static_cast<CUstream>(stream);
2917
+ StreamInfo* stream_info = get_stream_info(cuda_stream);
2918
+ if (!stream_info)
2919
+ {
2920
+ wp::set_error_string("Warp error: unknown capture stream");
2921
+ return false;
2922
+ }
2923
+
2924
+ // check if this stream was used to start a capture
2925
+ CaptureInfo* capture = stream_info->capture;
2926
+ if (!capture)
2927
+ {
2928
+ wp::set_error_string("Warp error: stream has no capture started");
2929
+ return false;
2930
+ }
2931
+
2932
+ // get capture info
2933
+ bool external = capture->external;
2934
+ uint64_t capture_id = capture->id;
2935
+ std::vector<FreeInfo> tmp_allocs = capture->tmp_allocs;
2936
+
2937
+ // clear capture info
2938
+ stream_info->capture = NULL;
2939
+ g_captures.erase(capture_id);
2940
+ delete capture;
2941
+
2942
+ // a lambda to clean up on exit in case of error
2943
+ auto clean_up = [cuda_stream, capture_id, external]()
2944
+ {
2945
+ // unreference outstanding graph allocs so that they will be released with the user reference
2946
+ for (auto it = g_graph_allocs.begin(); it != g_graph_allocs.end(); ++it)
2947
+ {
2948
+ GraphAllocInfo& alloc_info = it->second;
2949
+ if (alloc_info.capture_id == capture_id)
2950
+ alloc_info.graph_destroyed = true;
2951
+ }
2952
+
2953
+ // make sure we terminate the capture
2954
+ if (!external)
2955
+ {
2956
+ cudaGraph_t graph = NULL;
2957
+ cudaStreamEndCapture(cuda_stream, &graph);
2958
+ cudaGetLastError();
2959
+ }
2960
+ };
2961
+
2962
+ // get captured graph without ending the capture in case it is external
2963
+ cudaGraph_t graph = get_capture_graph(cuda_stream);
2964
+ if (!graph)
2965
+ {
2966
+ clean_up();
2967
+ return false;
2968
+ }
2969
+
2970
+ // ensure that all forked streams are joined to the main capture stream by manually
2971
+ // adding outstanding capture dependencies gathered from the graph leaf nodes
2972
+ std::vector<cudaGraphNode_t> stream_dependencies;
2973
+ std::vector<cudaGraphNode_t> leaf_nodes;
2974
+ if (get_capture_dependencies(cuda_stream, stream_dependencies) && get_graph_leaf_nodes(graph, leaf_nodes))
2975
+ {
2976
+ // compute set difference to get unjoined dependencies
2977
+ std::vector<cudaGraphNode_t> unjoined_dependencies;
2978
+ std::sort(stream_dependencies.begin(), stream_dependencies.end());
2979
+ std::sort(leaf_nodes.begin(), leaf_nodes.end());
2980
+ std::set_difference(leaf_nodes.begin(), leaf_nodes.end(),
2981
+ stream_dependencies.begin(), stream_dependencies.end(),
2982
+ std::back_inserter(unjoined_dependencies));
2983
+ if (!unjoined_dependencies.empty())
2984
+ {
2985
+ check_cu(cuStreamUpdateCaptureDependencies_f(cuda_stream, unjoined_dependencies.data(), unjoined_dependencies.size(),
2986
+ CU_STREAM_ADD_CAPTURE_DEPENDENCIES));
2987
+ // ensure graph is still valid
2988
+ if (get_capture_graph(cuda_stream) != graph)
2989
+ {
2990
+ clean_up();
2991
+ return false;
2992
+ }
2993
+ }
2994
+ }
2995
+
2996
+ // check if this graph has unfreed allocations, which require special handling
2997
+ std::vector<void*> unfreed_allocs;
2998
+ for (auto it = g_graph_allocs.begin(); it != g_graph_allocs.end(); ++it)
2999
+ {
3000
+ GraphAllocInfo& alloc_info = it->second;
3001
+ if (alloc_info.capture_id == capture_id)
3002
+ unfreed_allocs.push_back(it->first);
3003
+ }
3004
+
3005
+ if (!unfreed_allocs.empty() || !tmp_allocs.empty())
3006
+ {
3007
+ // Create a user object that will notify us when the instantiated graph is destroyed.
3008
+ // This works for external captures also, since we wouldn't otherwise know when
3009
+ // the externally-created graph instance gets deleted.
3010
+ // This callback is guaranteed to arrive after the graph has finished executing on the device,
3011
+ // not necessarily when cudaGraphExecDestroy() is called.
3012
+ GraphDestroyCallbackInfo* graph_info = new GraphDestroyCallbackInfo;
3013
+ graph_info->context = context ? context : get_current_context();
3014
+ graph_info->unfreed_allocs = unfreed_allocs;
3015
+ graph_info->tmp_allocs = tmp_allocs;
3016
+ cudaUserObject_t user_object;
3017
+ check_cuda(cudaUserObjectCreate(&user_object, graph_info, on_graph_destroy, 1, cudaUserObjectNoDestructorSync));
3018
+ check_cuda(cudaGraphRetainUserObject(graph, user_object, 1, cudaGraphUserObjectMove));
3019
+
3020
+ // ensure graph is still valid
3021
+ if (get_capture_graph(cuda_stream) != graph)
3022
+ {
3023
+ clean_up();
3024
+ return false;
3025
+ }
3026
+ }
3027
+
3028
+ // for external captures, we don't instantiate the graph ourselves, so we're done
3029
+ if (external)
3030
+ return true;
3031
+
3032
+ // end the capture
3033
+ if (!check_cuda(cudaStreamEndCapture(cuda_stream, &graph)))
3034
+ return false;
3035
+
3036
+ // process deferred free list if no more captures are ongoing
3037
+ if (g_captures.empty())
3038
+ {
3039
+ run_deferred_actions();
3040
+ }
3041
+
3042
+ if (graph_ret)
3043
+ *graph_ret = graph;
3044
+
3045
+ return true;
3046
+ }
3047
+
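For orientation, a plain runtime-API sketch of the capture, instantiate, and launch flow that the begin/end-capture and create-exec wrappers build on; the allocation tracking, forked-stream joining, and deferred-free handling done above are intentionally omitted, and `record_work` is a placeholder for the launches being captured.

#include <cuda_runtime.h>

void capture_and_run(cudaStream_t stream, void (*record_work)(cudaStream_t))
{
    cudaGraph_t graph = nullptr;
    cudaGraphExec_t graph_exec = nullptr;

    cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal);
    record_work(stream);                 // launches are recorded, not executed
    cudaStreamEndCapture(stream, &graph);

    cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch);

    cudaGraphLaunch(graph_exec, stream); // replay the captured work
    cudaStreamSynchronize(stream);

    cudaGraphExecDestroy(graph_exec);
    cudaGraphDestroy(graph);
}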
3048
+ bool wp_capture_debug_dot_print(void* graph, const char *path, uint32_t flags)
3049
+ {
3050
+ if (!check_cuda(cudaGraphDebugDotPrint((cudaGraph_t)graph, path, flags)))
3051
+ return false;
3052
+ return true;
3053
+ }
3054
+
3055
+ bool wp_cuda_graph_create_exec(void* context, void* stream, void* graph, void** graph_exec_ret)
3056
+ {
3057
+ ContextGuard guard(context);
3058
+
3059
+ cudaGraphExec_t graph_exec = NULL;
3060
+ if (!check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, (cudaGraph_t)graph, cudaGraphInstantiateFlagAutoFreeOnLaunch)))
3061
+ return false;
3062
+
3063
+ // Uploading the graph explicitly is usually optional, but when graph nodes will be updated later (e.g., indirect
3064
+ // dispatch), the upload is required: otherwise the nodes being updated might not have been uploaded yet, which
3065
+ // results in undefined behavior.
3066
+ CUstream cuda_stream = static_cast<CUstream>(stream);
3067
+ if (!check_cuda(cudaGraphUpload(graph_exec, cuda_stream)))
3068
+ return false;
3069
+
3070
+ if (graph_exec_ret)
3071
+ *graph_exec_ret = graph_exec;
3072
+
3073
+ return true;
3074
+ }
3075
+
3076
+ // Support for conditional graph nodes available with CUDA 12.4+.
3077
+ #if CUDA_VERSION >= 12040
3078
+
3079
+ // CUBIN or PTX data for compiled conditional modules, loaded on demand, keyed on device architecture
3080
+ using ModuleKey = std::pair<int, bool>; // <arch, use_ptx>
3081
+ static std::map<ModuleKey, void*> g_conditional_modules;
3082
+
3083
+ // Compile module with conditional helper kernels
3084
+ static void* compile_conditional_module(int arch, bool use_ptx)
3085
+ {
3086
+ static const char* kernel_source = R"(
3087
+ typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
3088
+ extern "C" __device__ __cudart_builtin__ void cudaGraphSetConditional(cudaGraphConditionalHandle handle, unsigned int value);
3089
+
3090
+ extern "C" __global__ void set_conditional_if_handle_kernel(cudaGraphConditionalHandle handle, int* value)
3091
+ {
3092
+ if (threadIdx.x + blockIdx.x * blockDim.x == 0)
3093
+ cudaGraphSetConditional(handle, *value);
3094
+ }
3095
+
3096
+ extern "C" __global__ void set_conditional_else_handle_kernel(cudaGraphConditionalHandle handle, int* value)
3097
+ {
3098
+ if (threadIdx.x + blockIdx.x * blockDim.x == 0)
3099
+ cudaGraphSetConditional(handle, !*value);
3100
+ }
3101
+
3102
+ extern "C" __global__ void set_conditional_if_else_handles_kernel(cudaGraphConditionalHandle if_handle, cudaGraphConditionalHandle else_handle, int* value)
3103
+ {
3104
+ if (threadIdx.x + blockIdx.x * blockDim.x == 0)
3105
+ {
3106
+ cudaGraphSetConditional(if_handle, *value);
3107
+ cudaGraphSetConditional(else_handle, !*value);
3108
+ }
3109
+ }
3110
+ )";
3111
+
3112
+ // avoid recompilation
3113
+ ModuleKey key = {arch, use_ptx};
3114
+ auto it = g_conditional_modules.find(key);
3115
+ if (it != g_conditional_modules.end())
3116
+ return it->second;
3117
+
3118
+ nvrtcProgram prog;
3119
+ if (!check_nvrtc(nvrtcCreateProgram(&prog, kernel_source, "conditional_kernels", 0, NULL, NULL)))
3120
+ return NULL;
3121
+
3122
+ char arch_opt[128];
3123
+ if (use_ptx)
3124
+ snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=compute_%d", arch);
3125
+ else
3126
+ snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=sm_%d", arch);
3127
+
3128
+ std::vector<const char*> opts;
3129
+ opts.push_back(arch_opt);
3130
+
3131
+ const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);
3132
+ if (print_debug)
3133
+ {
3134
+ printf("NVRTC options (conditional module, arch=%d, use_ptx=%s):\n", arch, use_ptx ? "true" : "false");
3135
+ for(auto o: opts) {
3136
+ printf("%s\n", o);
3137
+ }
3138
+ }
3139
+
3140
+ if (!check_nvrtc(nvrtcCompileProgram(prog, int(opts.size()), opts.data())))
3141
+ {
3142
+ size_t log_size;
3143
+ if (check_nvrtc(nvrtcGetProgramLogSize(prog, &log_size)))
3144
+ {
3145
+ std::vector<char> log(log_size);
3146
+ if (check_nvrtc(nvrtcGetProgramLog(prog, log.data())))
3147
+ fprintf(stderr, "%s", log.data());
3148
+ }
3149
+ nvrtcDestroyProgram(&prog);
3150
+ return NULL;
3151
+ }
3152
+
3153
+ // get output
3154
+ char* output = NULL;
3155
+ size_t output_size = 0;
3156
+
3157
+ if (use_ptx)
3158
+ {
3159
+ check_nvrtc(nvrtcGetPTXSize(prog, &output_size));
3160
+ if (output_size > 0)
3161
+ {
3162
+ output = new char[output_size];
3163
+ if (check_nvrtc(nvrtcGetPTX(prog, output)))
3164
+ g_conditional_modules[key] = output;
3165
+ }
3166
+ }
3167
+ else
3168
+ {
3169
+ check_nvrtc(nvrtcGetCUBINSize(prog, &output_size));
3170
+ if (output_size > 0)
3171
+ {
3172
+ output = new char[output_size];
3173
+ if (check_nvrtc(nvrtcGetCUBIN(prog, output)))
3174
+ g_conditional_modules[key] = output;
3175
+ }
3176
+ }
3177
+
3178
+ nvrtcDestroyProgram(&prog);
3179
+
3180
+ // return CUBIN or PTX data
3181
+ return output;
3182
+ }
3183
+
3184
+
3185
+ // Load module with conditional helper kernels
3186
+ static CUmodule load_conditional_module(void* context, int arch, bool use_ptx)
3187
+ {
3188
+ ContextInfo* context_info = get_context_info(context);
3189
+ if (!context_info)
3190
+ return NULL;
3191
+
3192
+ // check if already loaded
3193
+ if (context_info->conditional_module)
3194
+ return context_info->conditional_module;
3195
+
3196
+ // compile if needed
3197
+ void* compiled_module = compile_conditional_module(arch, use_ptx);
3198
+ if (!compiled_module)
3199
+ {
3200
+ fprintf(stderr, "Warp error: Failed to compile conditional kernels\n");
3201
+ return NULL;
3202
+ }
3203
+
3204
+ // load module (handles both PTX and CUBIN data automatically)
3205
+ CUmodule module = NULL;
3206
+ if (!check_cu(cuModuleLoadDataEx_f(&module, compiled_module, 0, NULL, NULL)))
3207
+ {
3208
+ fprintf(stderr, "Warp error: Failed to load conditional kernels module\n");
3209
+ return NULL;
3210
+ }
3211
+
3212
+ context_info->conditional_module = module;
3213
+
3214
+ return module;
3215
+ }
3216
+
3217
+ static CUfunction get_conditional_kernel(void* context, int arch, bool use_ptx, const char* name)
3218
+ {
3219
+ // load module if needed
3220
+ CUmodule module = load_conditional_module(context, arch, use_ptx);
3221
+ if (!module)
3222
+ return NULL;
3223
+
3224
+ CUfunction kernel;
3225
+ if (!check_cu(cuModuleGetFunction_f(&kernel, module, name)))
3226
+ {
3227
+ fprintf(stderr, "Warp error: Failed to get kernel %s\n", name);
3228
+ return NULL;
3229
+ }
3230
+
3231
+ return kernel;
3232
+ }
3233
+
3234
+ bool wp_cuda_graph_pause_capture(void* context, void* stream, void** graph_ret)
3235
+ {
3236
+ ContextGuard guard(context);
3237
+
3238
+ CUstream cuda_stream = static_cast<CUstream>(stream);
3239
+ if (!check_cuda(cudaStreamEndCapture(cuda_stream, (cudaGraph_t*)graph_ret)))
3240
+ return false;
3241
+ return true;
3242
+ }
3243
+
3244
+ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
3245
+ {
3246
+ ContextGuard guard(context);
3247
+
3248
+ CUstream cuda_stream = static_cast<CUstream>(stream);
3249
+ cudaGraph_t cuda_graph = static_cast<cudaGraph_t>(graph);
3250
+
3251
+ std::vector<cudaGraphNode_t> leaf_nodes;
3252
+ if (!get_graph_leaf_nodes(cuda_graph, leaf_nodes))
3253
+ return false;
3254
+
3255
+ if (!check_cuda(cudaStreamBeginCaptureToGraph(cuda_stream,
3256
+ cuda_graph,
3257
+ leaf_nodes.data(),
3258
+ nullptr,
3259
+ leaf_nodes.size(),
3260
+ cudaStreamCaptureModeThreadLocal)))
3261
+ return false;
3262
+
3263
+ return true;
3264
+ }
3265
+
3266
+ // https://developer.nvidia.com/blog/constructing-cuda-graphs-with-dynamic-parameters/#combined_approach
3267
+ // https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes/
3268
+ // condition is a gpu pointer
3269
+ // if_graph_ret and else_graph_ret should be NULL if not needed
3270
+ bool wp_cuda_graph_insert_if_else(void* context, void* stream, int arch, bool use_ptx, int* condition, void** if_graph_ret, void** else_graph_ret)
3271
+ {
3272
+ bool has_if = if_graph_ret != NULL;
3273
+ bool has_else = else_graph_ret != NULL;
3274
+ int num_branches = int(has_if) + int(has_else);
3275
+
3276
+ // if neither the IF nor ELSE branches are required, it's a no-op
3277
+ if (num_branches == 0)
3278
+ return true;
3279
+
3280
+ ContextGuard guard(context);
3281
+
3282
+ CUstream cuda_stream = static_cast<CUstream>(stream);
3283
+
3284
+ // Get the current stream capturing graph
3285
+ CUstreamCaptureStatus capture_status = CU_STREAM_CAPTURE_STATUS_NONE;
3286
+ cudaGraph_t cuda_graph = NULL;
3287
+ const cudaGraphNode_t* capture_deps = NULL;
3288
+ size_t dep_count = 0;
3289
+ if (!check_cu(cuStreamGetCaptureInfo_f(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
3290
+ return false;
3291
+
3292
+ // abort if not capturing
3293
+ if (!cuda_graph || capture_status != CU_STREAM_CAPTURE_STATUS_ACTIVE)
3294
+ {
3295
+ wp::set_error_string("Stream is not capturing");
3296
+ return false;
3297
+ }
3298
+
3299
+ //int driver_version = wp_cuda_driver_version();
3300
+
3301
+ // IF-ELSE nodes are only supported with CUDA 12.8+
3302
+ // Somehow child graphs produce wrong results when an else branch is used
3303
+ // Seems to be a bug in the CUDA driver: https://nvbugs/5241330
3304
+ if (num_branches == 1 /*|| driver_version >= 12080*/)
3305
+ {
3306
+ cudaGraphConditionalHandle handle;
3307
+ check_cuda(cudaGraphConditionalHandleCreate(&handle, cuda_graph));
3308
+
3309
+ // run a kernel to set the condition handle from the condition pointer
3310
+ // (need to negate the condition if only the else branch is used)
3311
+ CUfunction kernel;
3312
+ if (has_if)
3313
+ kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
3314
+ else
3315
+ kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_else_handle_kernel");
3316
+
3317
+ if (!kernel)
3318
+ {
3319
+ wp::set_error_string("Failed to get built-in conditional kernel");
3320
+ return false;
3321
+ }
3322
+
3323
+ void* kernel_args[2];
3324
+ kernel_args[0] = &handle;
3325
+ kernel_args[1] = &condition;
3326
+
3327
+ if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
3328
+ return false;
3329
+
3330
+ if (!check_cu(cuStreamGetCaptureInfo_f(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
3331
+ return false;
3332
+
3333
+ // create conditional node
3334
+ CUgraphNode condition_node;
3335
+ CUgraphNodeParams condition_params = { CU_GRAPH_NODE_TYPE_CONDITIONAL };
3336
+ condition_params.conditional.handle = handle;
3337
+ condition_params.conditional.type = CU_GRAPH_COND_TYPE_IF;
3338
+ condition_params.conditional.size = num_branches;
3339
+ condition_params.conditional.ctx = get_current_context();
3340
+ if (!check_cu(cuGraphAddNode_f(&condition_node, cuda_graph, capture_deps, NULL, dep_count, &condition_params)))
3341
+ return false;
3342
+
3343
+ if (!check_cu(cuStreamUpdateCaptureDependencies_f(cuda_stream, &condition_node, 1, cudaStreamSetCaptureDependencies)))
3344
+ return false;
3345
+
3346
+ if (num_branches == 1)
3347
+ {
3348
+ if (has_if)
3349
+ *if_graph_ret = condition_params.conditional.phGraph_out[0];
3350
+ else
3351
+ *else_graph_ret = condition_params.conditional.phGraph_out[0];
3352
+ }
3353
+ else
3354
+ {
3355
+ *if_graph_ret = condition_params.conditional.phGraph_out[0];
3356
+ *else_graph_ret = condition_params.conditional.phGraph_out[1];
3357
+ }
3358
+ }
3359
+ else
3360
+ {
3361
+ // Create IF node followed by an additional IF node with negated condition
3362
+ cudaGraphConditionalHandle if_handle, else_handle;
3363
+ check_cuda(cudaGraphConditionalHandleCreate(&if_handle, cuda_graph));
3364
+ check_cuda(cudaGraphConditionalHandleCreate(&else_handle, cuda_graph));
3365
+
3366
+ CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_else_handles_kernel");
3367
+ if (!kernel)
3368
+ {
3369
+ wp::set_error_string("Failed to get built-in conditional kernel");
3370
+ return false;
3371
+ }
3372
+
3373
+ void* kernel_args[3];
3374
+ kernel_args[0] = &if_handle;
3375
+ kernel_args[1] = &else_handle;
3376
+ kernel_args[2] = &condition;
3377
+
3378
+ if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
3379
+ return false;
3380
+
3381
+ if (!check_cu(cuStreamGetCaptureInfo_f(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
3382
+ return false;
3383
+
3384
+ CUgraphNode if_node;
3385
+ CUgraphNodeParams if_params = { CU_GRAPH_NODE_TYPE_CONDITIONAL };
3386
+ if_params.conditional.handle = if_handle;
3387
+ if_params.conditional.type = CU_GRAPH_COND_TYPE_IF;
3388
+ if_params.conditional.size = 1;
3389
+ if_params.conditional.ctx = get_current_context();
3390
+ if (!check_cu(cuGraphAddNode_f(&if_node, cuda_graph, capture_deps, NULL, dep_count, &if_params)))
3391
+ return false;
3392
+
3393
+ CUgraphNode else_node;
3394
+ CUgraphNodeParams else_params = { CU_GRAPH_NODE_TYPE_CONDITIONAL };
3395
+ else_params.conditional.handle = else_handle;
3396
+ else_params.conditional.type = CU_GRAPH_COND_TYPE_IF;
3397
+ else_params.conditional.size = 1;
3398
+ else_params.conditional.ctx = get_current_context();
3399
+ if (!check_cu(cuGraphAddNode_f(&else_node, cuda_graph, &if_node, NULL, 1, &else_params)))
3400
+ return false;
3401
+
3402
+ if (!check_cu(cuStreamUpdateCaptureDependencies_f(cuda_stream, &else_node, 1, cudaStreamSetCaptureDependencies)))
3403
+ return false;
3404
+
3405
+ *if_graph_ret = if_params.conditional.phGraph_out[0];
3406
+ *else_graph_ret = else_params.conditional.phGraph_out[0];
3407
+ }
3408
+
3409
+ return true;
3410
+ }
3411
+
3412
+ // graph node type names for intelligible error reporting
3413
+ static const char* get_graph_node_type_name(CUgraphNodeType type)
3414
+ {
3415
+ static const std::unordered_map<CUgraphNodeType, const char*> names
3416
+ {
3417
+ {CU_GRAPH_NODE_TYPE_KERNEL, "kernel launch"},
3418
+ {CU_GRAPH_NODE_TYPE_MEMCPY, "memcpy"},
3419
+ {CU_GRAPH_NODE_TYPE_MEMSET, "memset"},
3420
+ {CU_GRAPH_NODE_TYPE_HOST, "host execution"},
3421
+ {CU_GRAPH_NODE_TYPE_GRAPH, "graph launch"},
3422
+ {CU_GRAPH_NODE_TYPE_EMPTY, "empty node"},
3423
+ {CU_GRAPH_NODE_TYPE_WAIT_EVENT, "event wait"},
3424
+ {CU_GRAPH_NODE_TYPE_EVENT_RECORD, "event record"},
3425
+ {CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL, "semaphore signal"},
3426
+ {CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT, "semaphore wait"},
3427
+ {CU_GRAPH_NODE_TYPE_MEM_ALLOC, "memory allocation"},
3428
+ {CU_GRAPH_NODE_TYPE_MEM_FREE, "memory deallocation"},
3429
+ {CU_GRAPH_NODE_TYPE_BATCH_MEM_OP, "batched mem op"},
3430
+ {CU_GRAPH_NODE_TYPE_CONDITIONAL, "conditional node"},
3431
+ };
3432
+
3433
+ auto it = names.find(type);
3434
+ if (it != names.end())
3435
+ return it->second;
3436
+ else
3437
+ return "unknown node";
3438
+ }
3439
+
3440
+ // check if a graph can be launched as a child graph
3441
+ static bool is_valid_child_graph(void* child_graph)
3442
+ {
3443
+ // disallowed child graph nodes according to the documentation of cuGraphAddChildGraphNode()
3444
+ static const std::unordered_set<CUgraphNodeType> disallowed_nodes
3445
+ {
3446
+ CU_GRAPH_NODE_TYPE_MEM_ALLOC,
3447
+ CU_GRAPH_NODE_TYPE_MEM_FREE,
3448
+ CU_GRAPH_NODE_TYPE_CONDITIONAL,
3449
+ };
3450
+
3451
+ if (!child_graph)
3452
+ {
3453
+ wp::set_error_string("Child graph is null");
3454
+ return false;
3455
+ }
3456
+
3457
+ size_t num_nodes = 0;
3458
+ if (!check_cuda(cudaGraphGetNodes((cudaGraph_t)child_graph, NULL, &num_nodes)))
3459
+ return false;
3460
+ std::vector<cudaGraphNode_t> nodes(num_nodes);
3461
+ if (!check_cuda(cudaGraphGetNodes((cudaGraph_t)child_graph, nodes.data(), &num_nodes)))
3462
+ return false;
3463
+
3464
+ for (size_t i = 0; i < num_nodes; i++)
3465
+ {
3466
+ // note: we use the driver API to get the node type, otherwise some nodes are not recognized correctly
3467
+ CUgraphNodeType node_type;
3468
+ check_cu(cuGraphNodeGetType_f(nodes[i], &node_type));
3469
+ auto it = disallowed_nodes.find(node_type);
3470
+ if (it != disallowed_nodes.end())
3471
+ {
3472
+ wp::set_error_string("Child graph contains an unsupported operation (%s)", get_graph_node_type_name(node_type));
3473
+ return false;
3474
+ }
3475
+ }
3476
+
3477
+ return true;
3478
+ }
3479
+
3480
+ // check if a graph can be used as a conditional body graph
3481
+ // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#condtional-node-body-graph-requirements
3482
+ bool wp_cuda_graph_check_conditional_body(void* body_graph)
3483
+ {
3484
+ static const std::unordered_set<CUgraphNodeType> allowed_nodes
3485
+ {
3486
+ CU_GRAPH_NODE_TYPE_MEMCPY,
3487
+ CU_GRAPH_NODE_TYPE_MEMSET,
3488
+ CU_GRAPH_NODE_TYPE_KERNEL,
3489
+ CU_GRAPH_NODE_TYPE_GRAPH,
3490
+ CU_GRAPH_NODE_TYPE_EMPTY,
3491
+ CU_GRAPH_NODE_TYPE_CONDITIONAL,
3492
+ };
3493
+
3494
+ if (!body_graph)
3495
+ {
3496
+ wp::set_error_string("Conditional body graph is null");
3497
+ return false;
3498
+ }
3499
+
3500
+ size_t num_nodes = 0;
3501
+ if (!check_cuda(cudaGraphGetNodes((cudaGraph_t)body_graph, NULL, &num_nodes)))
3502
+ return false;
3503
+ std::vector<cudaGraphNode_t> nodes(num_nodes);
3504
+ if (!check_cuda(cudaGraphGetNodes((cudaGraph_t)body_graph, nodes.data(), &num_nodes)))
3505
+ return false;
3506
+
3507
+ for (size_t i = 0; i < num_nodes; i++)
3508
+ {
3509
+ // note: we use the driver API to get the node type, otherwise some nodes are not recognized correctly
3510
+ CUgraphNodeType node_type;
3511
+ check_cu(cuGraphNodeGetType_f(nodes[i], &node_type));
3512
+ if (allowed_nodes.find(node_type) == allowed_nodes.end())
3513
+ {
3514
+ wp::set_error_string("Conditional body graph contains an unsupported operation (%s)", get_graph_node_type_name(node_type));
3515
+ return false;
3516
+ }
3517
+ else if (node_type == CU_GRAPH_NODE_TYPE_GRAPH)
3518
+ {
3519
+ // check nested child graphs recursively
3520
+ cudaGraph_t child_graph = NULL;
3521
+ if (!check_cuda(cudaGraphChildGraphNodeGetGraph(nodes[i], &child_graph)))
3522
+ return false;
3523
+ if (!wp_cuda_graph_check_conditional_body(child_graph))
3524
+ return false;
3525
+ }
3526
+ }
3527
+
3528
+ return true;
3529
+ }
3530
+
3531
+ bool wp_cuda_graph_insert_child_graph(void* context, void* stream, void* child_graph)
3532
+ {
3533
+ if (!is_valid_child_graph(child_graph))
3534
+ return false;
3535
+
3536
+ ContextGuard guard(context);
3537
+
3538
+ CUstream cuda_stream = static_cast<CUstream>(stream);
3539
+
3540
+ // Get the current stream capturing graph
3541
+ CUstreamCaptureStatus capture_status = CU_STREAM_CAPTURE_STATUS_NONE;
3542
+ void* cuda_graph = NULL;
3543
+ const CUgraphNode* capture_deps = NULL;
3544
+ size_t dep_count = 0;
3545
+ if (!check_cu(cuStreamGetCaptureInfo_f(cuda_stream, &capture_status, nullptr, (cudaGraph_t*)&cuda_graph, &capture_deps, &dep_count)))
3546
+ return false;
3547
+
3548
+ if (!wp_cuda_graph_pause_capture(context, cuda_stream, &cuda_graph))
3549
+ return false;
3550
+
3551
+ cudaGraphNode_t body_node;
3552
+ if (!check_cuda(cudaGraphAddChildGraphNode(&body_node,
3553
+ static_cast<cudaGraph_t>(cuda_graph),
3554
+ capture_deps, dep_count,
3555
+ static_cast<cudaGraph_t>(child_graph))))
3556
+ return false;
3557
+
3558
+ if (!wp_cuda_graph_resume_capture(context, cuda_stream, cuda_graph))
3559
+ return false;
3560
+
3561
+ if (!check_cu(cuStreamUpdateCaptureDependencies_f(cuda_stream, &body_node, 1, cudaStreamSetCaptureDependencies)))
3562
+ return false;
3563
+
3564
+ return true;
3565
+ }
3566
+
3567
+ bool wp_cuda_graph_insert_while(void* context, void* stream, int arch, bool use_ptx, int* condition, void** body_graph_ret, uint64_t* handle_ret)
3568
+ {
3569
+ // if there's no body, it's a no-op
3570
+ if (!body_graph_ret)
3571
+ return true;
3572
+
3573
+ ContextGuard guard(context);
3574
+
3575
+ CUstream cuda_stream = static_cast<CUstream>(stream);
3576
+
3577
+ // Get the current stream capturing graph
3578
+ CUstreamCaptureStatus capture_status = CU_STREAM_CAPTURE_STATUS_NONE;
3579
+ cudaGraph_t cuda_graph = NULL;
3580
+ const cudaGraphNode_t* capture_deps = NULL;
3581
+ size_t dep_count = 0;
3582
+ if (!check_cu(cuStreamGetCaptureInfo_f(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
3583
+ return false;
3584
+
3585
+ // abort if not capturing
3586
+ if (!cuda_graph || capture_status != CU_STREAM_CAPTURE_STATUS_ACTIVE)
3587
+ {
3588
+ wp::set_error_string("Stream is not capturing");
3589
+ return false;
3590
+ }
3591
+
3592
+ cudaGraphConditionalHandle handle;
3593
+ if (!check_cuda(cudaGraphConditionalHandleCreate(&handle, cuda_graph)))
3594
+ return false;
3595
+
3596
+ // launch a kernel to set the condition handle from condition pointer
3597
+ CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
3598
+ if (!kernel)
3599
+ {
3600
+ wp::set_error_string("Failed to get built-in conditional kernel");
3601
+ return false;
3602
+ }
3603
+
3604
+ void* kernel_args[2];
3605
+ kernel_args[0] = &handle;
3606
+ kernel_args[1] = &condition;
3607
+
3608
+ if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
3609
+ return false;
3610
+
3611
+ if (!check_cu(cuStreamGetCaptureInfo_f(cuda_stream, &capture_status, nullptr, &cuda_graph, &capture_deps, &dep_count)))
3612
+ return false;
3613
+
3614
+ // insert conditional graph node
3615
+ CUgraphNode while_node;
3616
+ CUgraphNodeParams while_params = { CU_GRAPH_NODE_TYPE_CONDITIONAL };
3617
+ while_params.conditional.handle = handle;
3618
+ while_params.conditional.type = CU_GRAPH_COND_TYPE_WHILE;
3619
+ while_params.conditional.size = 1;
3620
+ while_params.conditional.ctx = get_current_context();
3621
+ if (!check_cu(cuGraphAddNode_f(&while_node, cuda_graph, capture_deps, NULL, dep_count, &while_params)))
3622
+ return false;
3623
+
3624
+ if (!check_cu(cuStreamUpdateCaptureDependencies_f(cuda_stream, &while_node, 1, cudaStreamSetCaptureDependencies)))
3625
+ return false;
3626
+
3627
+ *body_graph_ret = while_params.conditional.phGraph_out[0];
3628
+ *handle_ret = handle;
3629
+
3630
+ return true;
3631
+ }
3632
+
3633
+ bool wp_cuda_graph_set_condition(void* context, void* stream, int arch, bool use_ptx, int* condition, uint64_t handle)
3634
+ {
3635
+ ContextGuard guard(context);
3636
+
3637
+ CUstream cuda_stream = static_cast<CUstream>(stream);
3638
+
3639
+ // launch a kernel to set the condition handle from condition pointer
3640
+ CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
3641
+ if (!kernel)
3642
+ {
3643
+ wp::set_error_string("Failed to get built-in conditional kernel");
3644
+ return false;
3645
+ }
3646
+
3647
+ void* kernel_args[2];
3648
+ kernel_args[0] = &handle;
3649
+ kernel_args[1] = &condition;
3650
+
3651
+ if (!check_cu(cuLaunchKernel_f(kernel, 1, 1, 1, 1, 1, 1, 0, cuda_stream, kernel_args, NULL)))
3652
+ return false;
3653
+
3654
+ return true;
3655
+ }
3656
+
3657
+ #else
3658
+ // stubs for conditional graph node API if CUDA toolkit is too old.
3659
+
3660
+ bool wp_cuda_graph_pause_capture(void* context, void* stream, void** graph_ret)
3661
+ {
3662
+ wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
3663
+ return false;
3664
+ }
3665
+
3666
+ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
3667
+ {
3668
+ wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
3669
+ return false;
3670
+ }
3671
+
3672
+ bool wp_cuda_graph_insert_if_else(void* context, void* stream, int arch, bool use_ptx, int* condition, void** if_graph_ret, void** else_graph_ret)
3673
+ {
3674
+ wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
3675
+ return false;
3676
+ }
3677
+
3678
+ bool wp_cuda_graph_insert_while(void* context, void* stream, int arch, bool use_ptx, int* condition, void** body_graph_ret, uint64_t* handle_ret)
3679
+ {
3680
+ wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
3681
+ return false;
3682
+ }
3683
+
3684
+ bool wp_cuda_graph_set_condition(void* context, void* stream, int arch, bool use_ptx, int* condition, uint64_t handle)
3685
+ {
3686
+ wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
3687
+ return false;
3688
+ }
3689
+
3690
+ bool wp_cuda_graph_insert_child_graph(void* context, void* stream, void* child_graph)
3691
+ {
3692
+ wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
3693
+ return false;
3694
+ }
3695
+
3696
+ bool wp_cuda_graph_check_conditional_body(void* body_graph)
3697
+ {
3698
+ wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
3699
+ return false;
3700
+ }
3701
+
3702
+ #endif // support for conditional graph nodes
3703
+
3704
+
3705
+ bool wp_cuda_graph_launch(void* graph_exec, void* stream)
3706
+ {
3707
+ // TODO: allow naming graphs?
3708
+ begin_cuda_range(WP_TIMING_GRAPH, stream, get_stream_context(stream), "graph");
3709
+
3710
+ bool result = check_cuda(cudaGraphLaunch((cudaGraphExec_t)graph_exec, (cudaStream_t)stream));
3711
+
3712
+ end_cuda_range(WP_TIMING_GRAPH, stream);
3713
+
3714
+ return result;
3715
+ }
3716
+
3717
+ bool wp_cuda_graph_destroy(void* context, void* graph)
3718
+ {
3719
+ // ensure there are no graph captures in progress
3720
+ if (g_captures.empty())
3721
+ {
3722
+ ContextGuard guard(context);
3723
+ return check_cuda(cudaGraphDestroy((cudaGraph_t)graph));
3724
+ }
3725
+ else
3726
+ {
3727
+ GraphDestroyInfo info;
3728
+ info.context = context ? context : get_current_context();
3729
+ info.graph = graph;
3730
+ g_deferred_graph_list.push_back(info);
3731
+ return true;
3732
+ }
3733
+ }
3734
+
3735
+ bool wp_cuda_graph_exec_destroy(void* context, void* graph_exec)
3736
+ {
3737
+ // ensure there are no graph captures in progress
3738
+ if (g_captures.empty())
3739
+ {
3740
+ ContextGuard guard(context);
3741
+ return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
3742
+ }
3743
+ else
3744
+ {
3745
+ GraphDestroyInfo info;
3746
+ info.context = context ? context : get_current_context();
3747
+ info.graph_exec = graph_exec;
3748
+ g_deferred_graph_list.push_back(info);
3749
+ return true;
3750
+ }
3751
+ }
3752
+
3753
+ bool write_file(const char* data, size_t size, std::string filename, const char* mode)
3754
+ {
3755
+ const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);
3756
+ if (print_debug)
3757
+ {
3758
+ printf("Writing %zu B to %s (%s)\n", size, filename.c_str(), mode);
3759
+ }
3760
+ FILE* file = fopen(filename.c_str(), mode);
3761
+ if (file)
3762
+ {
3763
+ if (fwrite(data, 1, size, file) != size) {
3764
+ fprintf(stderr, "Warp error: Failed to write to output file '%s'\n", filename.c_str());
3765
+ fclose(file);
+ return false;
3766
+ }
3767
+ fclose(file);
3768
+ return true;
3769
+ }
3770
+ else
3771
+ {
3772
+ fprintf(stderr, "Warp error: Failed to open file '%s'\n", filename.c_str());
3773
+ return false;
3774
+ }
3775
+ }
3776
+
3777
+ #if WP_ENABLE_MATHDX
3778
+ bool check_nvjitlink_result(nvJitLinkHandle handle, nvJitLinkResult result, const char* file, int line)
3779
+ {
3780
+ if (result != NVJITLINK_SUCCESS) {
3781
+ fprintf(stderr, "nvJitLink error: %d on %s:%d\n", (int)result, file, line);
3782
+ size_t lsize;
3783
+ result = nvJitLinkGetErrorLogSize(handle, &lsize);
3784
+ if (result == NVJITLINK_SUCCESS && lsize > 0) {
3785
+ std::vector<char> log(lsize);
3786
+ result = nvJitLinkGetErrorLog(handle, log.data());
3787
+ if (result == NVJITLINK_SUCCESS) {
3788
+ fprintf(stderr, "%s\n", log.data());
3789
+ }
3790
+ }
3791
+ return false;
3792
+ } else {
3793
+ return true;
3794
+ }
3795
+ }
3796
+ #endif
3797
+
3798
+ size_t wp_cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, bool compile_time_trace, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes, int* ltoir_input_types)
3799
+ {
3800
+ // use file extension to determine whether to output PTX or CUBIN
3801
+ const char* output_ext = strrchr(output_path, '.');
3802
+ bool use_ptx = output_ext && strcmp(output_ext + 1, "ptx") == 0;
3803
+ const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);
3804
+
3805
+ // check include dir path len (path + option)
3806
+ const int max_path = 4096 + 16;
3807
+ if (strlen(include_dir) + 16 > max_path) // leave room for the "--include-path=" prefix and terminator
3808
+ {
3809
+ fprintf(stderr, "Warp error: Include path too long\n");
3810
+ return size_t(-1);
3811
+ }
3812
+
3813
+ if (print_debug)
3814
+ {
3815
+ // Not available in all nvJitLink versions
3816
+ // unsigned major = 0;
3817
+ // unsigned minor = 0;
3818
+ // nvJitLinkVersion(&major, &minor);
3819
+ // printf("nvJitLink version %d.%d\n", major, minor);
3820
+ int major = 0;
3821
+ int minor = 0;
3822
+ nvrtcVersion(&major, &minor);
3823
+ printf("NVRTC version %d.%d\n", major, minor);
3824
+ }
3825
+
3826
+ char include_opt[max_path];
3827
+ strcpy(include_opt, "--include-path=");
3828
+ strcat(include_opt, include_dir);
3829
+
3830
+ const int max_arch = 128;
3831
+ char arch_opt[max_arch];
3832
+ char arch_opt_lto[max_arch];
3833
+
3834
+ if (use_ptx)
3835
+ {
3836
+ snprintf(arch_opt, max_arch, "--gpu-architecture=compute_%d", arch);
3837
+ snprintf(arch_opt_lto, max_arch, "-arch=compute_%d", arch);
3838
+ }
3839
+ else
3840
+ {
3841
+ snprintf(arch_opt, max_arch, "--gpu-architecture=sm_%d", arch);
3842
+ snprintf(arch_opt_lto, max_arch, "-arch=sm_%d", arch);
3843
+ }
3844
+
3845
+ std::vector<const char*> opts;
3846
+ opts.push_back(arch_opt);
3847
+ opts.push_back(include_opt);
3848
+ opts.push_back("--std=c++17");
3849
+
3850
+ if (debug)
3851
+ {
3852
+ opts.push_back("--define-macro=_DEBUG");
3853
+ opts.push_back("--generate-line-info");
3854
+ #ifndef _WIN32
3855
+ opts.push_back("--device-debug"); // -G
3856
+ #endif
3857
+ }
3858
+ else
3859
+ {
3860
+ opts.push_back("--define-macro=NDEBUG");
3861
+
3862
+ if (lineinfo)
3863
+ opts.push_back("--generate-line-info");
3864
+ }
3865
+
3866
+ if (verify_fp)
3867
+ opts.push_back("--define-macro=WP_VERIFY_FP");
3868
+ else
3869
+ opts.push_back("--undefine-macro=WP_VERIFY_FP");
3870
+
3871
+ #if WP_ENABLE_MATHDX
3872
+ opts.push_back("--define-macro=WP_ENABLE_MATHDX=1");
3873
+ #else
3874
+ opts.push_back("--define-macro=WP_ENABLE_MATHDX=0");
3875
+ #endif
3876
+
3877
+ if (fast_math)
3878
+ opts.push_back("--use_fast_math");
3879
+
3880
+ if (fuse_fp)
3881
+ opts.push_back("--fmad=true");
3882
+ else
3883
+ opts.push_back("--fmad=false");
3884
+
3885
+ std::vector<std::string> stored_options;
3886
+ for(int i = 0; i < num_cuda_include_dirs; i++)
3887
+ {
3888
+ stored_options.push_back(std::string("--include-path=") + cuda_include_dirs[i]);
3889
+ opts.push_back(stored_options.back().c_str());
3890
+ }
3891
+
3892
+ opts.push_back("--device-as-default-execution-space");
3893
+ opts.push_back("--extra-device-vectorization");
3894
+ opts.push_back("--restrict");
3895
+
3896
+ if (num_ltoirs > 0)
3897
+ {
3898
+ opts.push_back("-dlto");
3899
+ opts.push_back("--relocatable-device-code=true");
3900
+ }
3901
+
3902
+ if (compile_time_trace)
3903
+ {
3904
+ #if CUDA_VERSION >= 12080
3905
+ stored_options.push_back(std::string("--fdevice-time-trace=") + std::string(output_path).append("_compile-time-trace.json"));
3906
+ opts.push_back(stored_options.back().c_str());
3907
+ #else
3908
+ fprintf(stderr, "Warp warning: CUDA version is less than 12.8, compile_time_trace is not supported\n");
3909
+ #endif
3910
+ }
3911
+
3912
+ nvrtcProgram prog;
3913
+ nvrtcResult res;
3914
+
3915
+ res = nvrtcCreateProgram(
3916
+ &prog, // prog
3917
+ cuda_src, // buffer
3918
+ program_name, // name
3919
+ 0, // numHeaders
3920
+ NULL, // headers
3921
+ NULL); // includeNames
3922
+
3923
+ if (!check_nvrtc(res))
3924
+ return size_t(res);
3925
+
3926
+ if (print_debug)
3927
+ {
3928
+ printf("NVRTC options:\n");
3929
+ for(auto o: opts) {
3930
+ printf("%s\n", o);
3931
+ }
3932
+ }
3933
+ res = nvrtcCompileProgram(prog, int(opts.size()), opts.data());
3934
+
3935
+ if (!check_nvrtc(res) || verbose)
3936
+ {
3937
+ // get program log
3938
+ size_t log_size;
3939
+ if (check_nvrtc(nvrtcGetProgramLogSize(prog, &log_size)))
3940
+ {
3941
+ std::vector<char> log(log_size);
3942
+ if (check_nvrtc(nvrtcGetProgramLog(prog, log.data())))
3943
+ {
3944
+ // todo: figure out better way to return this to python
3945
+ if (res != NVRTC_SUCCESS)
3946
+ fprintf(stderr, "%s", log.data());
3947
+ else
3948
+ fprintf(stdout, "%s", log.data());
3949
+ }
3950
+ }
3951
+
3952
+ if (res != NVRTC_SUCCESS)
3953
+ {
3954
+ nvrtcDestroyProgram(&prog);
3955
+ return size_t(res);
3956
+ }
3957
+ }
3958
+
3959
+ nvrtcResult (*get_output_size)(nvrtcProgram, size_t*);
3960
+ nvrtcResult (*get_output_data)(nvrtcProgram, char*);
3961
+ const char* output_mode;
3962
+ if(num_ltoirs > 0) {
3963
+ #if WP_ENABLE_MATHDX
3964
+ get_output_size = nvrtcGetLTOIRSize;
3965
+ get_output_data = nvrtcGetLTOIR;
3966
+ output_mode = "wb";
3967
+ #else
3968
+ fprintf(stderr, "Warp error: num_ltoirs > 0 but Warp was not built with MathDx support\n");
3969
+ return size_t(-1);
3970
+ #endif
3971
+ }
3972
+ else if (use_ptx)
3973
+ {
3974
+ get_output_size = nvrtcGetPTXSize;
3975
+ get_output_data = nvrtcGetPTX;
3976
+ output_mode = "wt";
3977
+ }
3978
+ else
3979
+ {
3980
+ get_output_size = nvrtcGetCUBINSize;
3981
+ get_output_data = nvrtcGetCUBIN;
3982
+ output_mode = "wb";
3983
+ }
3984
+
3985
+ // save output
3986
+ size_t output_size;
3987
+ res = get_output_size(prog, &output_size);
3988
+ if (check_nvrtc(res))
3989
+ {
3990
+ std::vector<char> output(output_size);
3991
+ res = get_output_data(prog, output.data());
3992
+ if (check_nvrtc(res))
3993
+ {
3994
+
3995
+ // LTOIR case - need an extra step
3996
+ if (num_ltoirs > 0)
3997
+ {
3998
+ #if WP_ENABLE_MATHDX
3999
+ if(ltoir_input_types == nullptr || ltoirs == nullptr || ltoir_sizes == nullptr) {
4000
+ fprintf(stderr, "Warp error: num_ltoirs > 0 but ltoir_input_types, ltoirs or ltoir_sizes are NULL\n");
4001
+ return size_t(-1);
4002
+ }
4003
+ nvJitLinkHandle handle = nullptr;
4004
+ std::vector<const char *> lopts = {"-dlto", arch_opt_lto};
4005
+ if (use_ptx) {
4006
+ lopts.push_back("-ptx");
4007
+ }
4008
+ if (print_debug)
4009
+ {
4010
+ printf("nvJitLink options:\n");
4011
+ for(auto o: lopts) {
4012
+ printf("%s\n", o);
4013
+ }
4014
+ }
4015
+ if(!check_nvjitlink(handle, nvJitLinkCreate(&handle, lopts.size(), lopts.data())))
4016
+ {
4017
+ res = nvrtcResult(-1);
4018
+ }
4019
+ // Links
4020
+ if(std::getenv("WARP_DUMP_LTOIR"))
4021
+ {
4022
+ write_file(output.data(), output.size(), "nvrtc_output.ltoir", "wb");
4023
+ }
4024
+ if(!check_nvjitlink(handle, nvJitLinkAddData(handle, NVJITLINK_INPUT_LTOIR, output.data(), output.size(), "nvrtc_output"))) // NVRTC business
4025
+ {
4026
+ res = nvrtcResult(-1);
4027
+ }
4028
+ for(size_t ltoidx = 0; ltoidx < num_ltoirs; ltoidx++)
4029
+ {
4030
+ nvJitLinkInputType input_type = static_cast<nvJitLinkInputType>(ltoir_input_types[ltoidx]);
4031
+ const char* ext = ".unknown";
4032
+ switch(input_type) {
4033
+ case NVJITLINK_INPUT_CUBIN:
4034
+ ext = ".cubin";
4035
+ break;
4036
+ case NVJITLINK_INPUT_LTOIR:
4037
+ ext = ".ltoir";
4038
+ break;
4039
+ case NVJITLINK_INPUT_FATBIN:
4040
+ ext = ".fatbin";
4041
+ break;
4042
+ default:
4043
+ break;
4044
+ }
4045
+ if(std::getenv("WARP_DUMP_LTOIR"))
4046
+ {
4047
+ write_file(ltoirs[ltoidx], ltoir_sizes[ltoidx], std::string("lto_online_") + std::to_string(ltoidx) + ext, "wb");
4048
+ }
4049
+ if(!check_nvjitlink(handle, nvJitLinkAddData(handle, input_type, ltoirs[ltoidx], ltoir_sizes[ltoidx], "lto_online"))) // External LTOIR
4050
+ {
4051
+ res = nvrtcResult(-1);
4052
+ }
4053
+ }
4054
+ if(!check_nvjitlink(handle, nvJitLinkComplete(handle)))
4055
+ {
4056
+ res = nvrtcResult(-1);
4057
+ }
4058
+ else
4059
+ {
4060
+ if(use_ptx)
4061
+ {
4062
+ size_t ptx_size = 0;
4063
+ check_nvjitlink(handle, nvJitLinkGetLinkedPtxSize(handle, &ptx_size));
4064
+ std::vector<char> ptx(ptx_size);
4065
+ check_nvjitlink(handle, nvJitLinkGetLinkedPtx(handle, ptx.data()));
4066
+ output = ptx;
4067
+ }
4068
+ else
4069
+ {
4070
+ size_t cubin_size = 0;
4071
+ check_nvjitlink(handle, nvJitLinkGetLinkedCubinSize(handle, &cubin_size));
4072
+ std::vector<char> cubin(cubin_size);
4073
+ check_nvjitlink(handle, nvJitLinkGetLinkedCubin(handle, cubin.data()));
4074
+ output = cubin;
4075
+ }
4076
+ }
4077
+ check_nvjitlink(handle, nvJitLinkDestroy(&handle));
4078
+ #else
4079
+ fprintf(stderr, "Warp error: num_ltoirs > 0 but Warp was not built with MathDx support\n");
4080
+ return size_t(-1);
4081
+ #endif
4082
+ }
4083
+
4084
+ if(!write_file(output.data(), output.size(), output_path, output_mode)) {
4085
+ res = nvrtcResult(-1);
4086
+ }
4087
+ }
4088
+ }
4089
+
4090
+ check_nvrtc(nvrtcDestroyProgram(&prog));
4091
+
4092
+ return res;
4093
+ }
4094
+
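A bare-bones NVRTC round trip showing the create/compile/fetch sequence used above, reduced to PTX output for a single hard-coded architecture; the architecture string and option list are placeholders and none of the LTO or nvJitLink handling is reproduced.

#include <cstdio>
#include <string>
#include <vector>
#include <nvrtc.h>

std::string compile_to_ptx(const char* src, const char* name)
{
    nvrtcProgram prog;
    nvrtcCreateProgram(&prog, src, name, 0, nullptr, nullptr);

    const char* opts[] = { "--gpu-architecture=compute_75", "--std=c++17" };
    nvrtcResult res = nvrtcCompileProgram(prog, 2, opts);

    // always fetch the log; it may contain warnings even on success
    size_t log_size = 0;
    nvrtcGetProgramLogSize(prog, &log_size);
    std::vector<char> log(log_size);
    nvrtcGetProgramLog(prog, log.data());
    if (res != NVRTC_SUCCESS)
        fprintf(stderr, "%s", log.data());

    std::string ptx;
    if (res == NVRTC_SUCCESS)
    {
        size_t ptx_size = 0;
        nvrtcGetPTXSize(prog, &ptx_size);
        ptx.resize(ptx_size);
        nvrtcGetPTX(prog, ptx.data());
    }

    nvrtcDestroyProgram(&prog);
    return ptx;
}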
4095
+ #if WP_ENABLE_MATHDX
4096
+ bool check_cufftdx_result(commondxStatusType result, const char* file, int line)
4097
+ {
4098
+ if (result != commondxStatusType::COMMONDX_SUCCESS) {
4099
+ fprintf(stderr, "libmathdx cuFFTDx error: %d on %s:%d\n", (int)result, file, line);
4100
+ return false;
4101
+ } else {
4102
+ return true;
4103
+ }
4104
+ }
4105
+
4106
+ bool check_cublasdx_result(commondxStatusType result, const char* file, int line)
4107
+ {
4108
+ if (result != commondxStatusType::COMMONDX_SUCCESS) {
4109
+ fprintf(stderr, "libmathdx cuBLASDx error: %d on %s:%d\n", (int)result, file, line);
4110
+ return false;
4111
+ } else {
4112
+ return true;
4113
+ }
4114
+ }
4115
+
4116
+ bool check_cusolver_result(commondxStatusType result, const char* file, int line)
4117
+ {
4118
+ if (result != commondxStatusType::COMMONDX_SUCCESS) {
4119
+ fprintf(stderr, "libmathdx cuSOLVER error: %d on %s:%d\n", (int)result, file, line);
4120
+ return false;
4121
+ } else {
4122
+ return true;
4123
+ }
4124
+ }
4125
+
4126
+ bool wp_cuda_compile_fft(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int size, int elements_per_thread, int direction, int precision, int* shared_memory_size)
4127
+ {
4128
+
4129
+ CHECK_ANY(ltoir_output_path != nullptr);
4130
+ CHECK_ANY(symbol_name != nullptr);
4131
+ CHECK_ANY(shared_memory_size != nullptr);
4132
+ // Includes currently unused
4133
+ CHECK_ANY(include_dirs == nullptr);
4134
+ CHECK_ANY(mathdx_include_dir == nullptr);
4135
+ CHECK_ANY(num_include_dirs == 0);
4136
+
4137
+ bool res = true;
4138
+ cufftdxDescriptor h;
4139
+ CHECK_CUFFTDX(cufftdxCreateDescriptor(&h));
4140
+
4141
+ // CUFFTDX_API_LMEM means each thread starts with a subset of the data
4142
+ CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_API, cufftdxApi::CUFFTDX_API_LMEM));
4143
+ CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
4144
+ CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_SIZE, (long long)size));
4145
+ CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_DIRECTION, (cufftdxDirection)direction));
4146
+ CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_PRECISION, (commondxPrecision)precision));
4147
+ CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_SM, (long long)(arch * 10)));
4148
+ CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_ELEMENTS_PER_THREAD, (long long)(elements_per_thread)));
4149
+ CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_FFTS_PER_BLOCK, 1));
4150
+
4151
+ CHECK_CUFFTDX(cufftdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
4152
+
4153
+ size_t lto_size = 0;
4154
+ CHECK_CUFFTDX(cufftdxGetLTOIRSize(h, &lto_size));
4155
+
4156
+ std::vector<char> lto(lto_size);
4157
+ CHECK_CUFFTDX(cufftdxGetLTOIR(h, lto.size(), lto.data()));
4158
+
4159
+ long long int smem = 0;
4160
+ CHECK_CUFFTDX(cufftdxGetTraitInt64(h, cufftdxTraitType::CUFFTDX_TRAIT_SHARED_MEMORY_SIZE, &smem));
4161
+ *shared_memory_size = (int)smem;
4162
+
4163
+ if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
4164
+ res = false;
4165
+ }
4166
+
4167
+ CHECK_CUFFTDX(cufftdxDestroyDescriptor(h));
4168
+
4169
+ return res;
4170
+ }
4171
+
4172
+ bool wp_cuda_compile_dot(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int K, int precision_A, int precision_B, int precision_C, int type, int arrangement_A, int arrangement_B, int arrangement_C, int num_threads)
4173
+ {
4174
+
4175
+ CHECK_ANY(ltoir_output_path != nullptr);
4176
+ CHECK_ANY(symbol_name != nullptr);
4177
+ // Includes currently unused
4178
+ CHECK_ANY(include_dirs == nullptr);
4179
+ CHECK_ANY(mathdx_include_dir == nullptr);
4180
+ CHECK_ANY(num_include_dirs == 0);
4181
+
4182
+ bool res = true;
4183
+ cublasdxDescriptor h;
4184
+ CHECK_CUBLASDX(cublasdxCreateDescriptor(&h));
4185
+
4186
+ CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_FUNCTION, cublasdxFunction::CUBLASDX_FUNCTION_MM));
4187
+ CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
4188
+ CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_API, cublasdxApi::CUBLASDX_API_SMEM));
4189
+ std::array<long long int, 3> precisions = {precision_A, precision_B, precision_C};
4190
+ CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_PRECISION, 3, precisions.data()));
4191
+ CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SM, (long long)(arch * 10)));
4192
+ CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_TYPE, (cublasdxType)type));
4193
+ std::array<long long int, 3> block_dim = {num_threads, 1, 1};
4194
+ CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
4195
+ std::array<long long int, 3> size = {M, N, K};
4196
+ CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SIZE, size.size(), size.data()));
4197
+ std::array<long long int, 3> arrangement = {arrangement_A, arrangement_B, arrangement_C};
4198
+ CHECK_CUBLASDX(cublasdxSetOperatorInt64s(h, cublasdxOperatorType::CUBLASDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));
4199
+
4200
+ CHECK_CUBLASDX(cublasdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
4201
+
4202
+ size_t lto_size = 0;
4203
+ CHECK_CUBLASDX(cublasdxGetLTOIRSize(h, &lto_size));
4204
+
4205
+ std::vector<char> lto(lto_size);
4206
+ CHECK_CUBLASDX(cublasdxGetLTOIR(h, lto.size(), lto.data()));
4207
+
4208
+ if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
4209
+ res = false;
4210
+ }
4211
+
4212
+ CHECK_CUBLASDX(cublasdxDestroyDescriptor(h));
4213
+
4214
+ return res;
4215
+ }
4216
+
+ bool wp_cuda_compile_solver(const char* fatbin_output_path, const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int NRHS, int function, int side, int diag, int precision, int arrangement_A, int arrangement_B, int fill_mode, int num_threads)
+ {
+
+     CHECK_ANY(ltoir_output_path != nullptr);
+     CHECK_ANY(symbol_name != nullptr);
+     CHECK_ANY(mathdx_include_dir == nullptr);
+     CHECK_ANY(num_include_dirs == 0);
+     CHECK_ANY(include_dirs == nullptr);
+
+     bool res = true;
+
+     cusolverdxDescriptor h { 0 };
+     CHECK_CUSOLVER(cusolverdxCreateDescriptor(&h));
+     std::array<long long int, 3> size = {M, N, NRHS};
+     CHECK_CUSOLVER(cusolverdxSetOperatorInt64s(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SIZE, size.size(), size.data()));
+     std::array<long long int, 3> block_dim = {num_threads, 1, 1};
+     CHECK_CUSOLVER(cusolverdxSetOperatorInt64s(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
+     CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_TYPE, cusolverdxType::CUSOLVERDX_TYPE_REAL));
+     CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_API, cusolverdxApi::CUSOLVERDX_API_SMEM));
+     CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_FUNCTION, (cusolverdxFunction)function));
+     if (side >= 0) {
+         CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SIDE, (cusolverdxSide)side));
+     }
+     if (diag >= 0) {
+         CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_DIAG, (cusolverdxDiag)diag));
+     }
+     CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
+     CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_PRECISION, (commondxPrecision)precision));
+     std::array<long long int, 2> arrangement = {arrangement_A, arrangement_B};
+     CHECK_CUSOLVER(cusolverdxSetOperatorInt64s(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));
+     CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_FILL_MODE, (cusolverdxFillMode)fill_mode));
+     CHECK_CUSOLVER(cusolverdxSetOperatorInt64(h, cusolverdxOperatorType::CUSOLVERDX_OPERATOR_SM, (long long)(arch * 10)));
+
+     CHECK_CUSOLVER(cusolverdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
+
+     size_t lto_size = 0;
+     CHECK_CUSOLVER(cusolverdxGetLTOIRSize(h, &lto_size));
+
+     std::vector<char> lto(lto_size);
+     CHECK_CUSOLVER(cusolverdxGetLTOIR(h, lto.size(), lto.data()));
+
+     // This fatbin is universal, i.e. it is the same for all instantiations of a cuSOLVERDx device function
+     size_t fatbin_size = 0;
+     CHECK_CUSOLVER(cusolverdxGetUniversalFATBINSize(h, &fatbin_size));
+
+     std::vector<char> fatbin(fatbin_size);
+     CHECK_CUSOLVER(cusolverdxGetUniversalFATBIN(h, fatbin.size(), fatbin.data()));
+
+     if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
+         res = false;
+     }
+
+     if(!write_file(fatbin.data(), fatbin.size(), fatbin_output_path, "wb")) {
+         res = false;
+     }
+
+     CHECK_CUSOLVER(cusolverdxDestroyDescriptor(h));
+
+     return res;
+ }
+
+ #endif
+
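For reference, here is a minimal caller sketch for wp_cuda_compile_solver. It is not part of the package diff: every integer argument is an illustrative placeholder, since the real encodings of function, side, diag, precision, arrangement and fill_mode come from the cuSOLVERDx/commonDx enums and from Warp's Python-side bindings, neither of which is shown here. The nullptr/zero include arguments simply match the CHECK_ANY preconditions above.

    // Hypothetical call (placeholder values): compile a device-side solver to
    // LTOIR plus the universal fatbin for a 32x32 system with one right-hand side.
    const char* ltoir_path  = "solver.ltoir";      // placeholder output paths
    const char* fatbin_path = "solver.fatbin";
    bool ok = wp_cuda_compile_solver(
        fatbin_path, ltoir_path, "warp_solver_kernel",
        /*num_include_dirs*/ 0, /*include_dirs*/ nullptr, /*mathdx_include_dir*/ nullptr,
        /*arch*/ 86, /*M*/ 32, /*N*/ 32, /*NRHS*/ 1,
        /*function*/ 0, /*side*/ -1, /*diag*/ -1, /*precision*/ 0,
        /*arrangement_A*/ 0, /*arrangement_B*/ 0, /*fill_mode*/ 0,
        /*num_threads*/ 128);

Passing -1 for side and diag skips those operators, mirroring the conditional branches in the function body.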
+ void* wp_cuda_load_module(void* context, const char* path)
+ {
+     ContextGuard guard(context);
+
+     // use file extension to determine whether to load PTX or CUBIN
+     const char* input_ext = strrchr(path, '.');
+     bool load_ptx = input_ext && strcmp(input_ext + 1, "ptx") == 0;
+
+     std::vector<char> input;
+
+     FILE* file = fopen(path, "rb");
+     if (file)
+     {
+         fseek(file, 0, SEEK_END);
+         size_t length = ftell(file);
+         fseek(file, 0, SEEK_SET);
+
+         input.resize(length + 1);
+         if (fread(input.data(), 1, length, file) != length)
+         {
+             fprintf(stderr, "Warp error: Failed to read input file '%s'\n", path);
+             fclose(file);
+             return NULL;
+         }
+         fclose(file);
+
+         input[length] = '\0';
+     }
+     else
+     {
+         fprintf(stderr, "Warp error: Failed to open input file '%s'\n", path);
+         return NULL;
+     }
+
+     int driver_cuda_version = 0;
+     CUmodule module = NULL;
+
+     if (load_ptx)
+     {
+         if (check_cu(cuDriverGetVersion_f(&driver_cuda_version)) && driver_cuda_version >= CUDA_VERSION)
+         {
+             // let the driver compile the PTX
+
+             CUjit_option options[2];
+             void *option_vals[2];
+             char error_log[8192] = "";
+             unsigned int log_size = 8192;
+             // Set up loader options
+             // Pass a buffer for error message
+             options[0] = CU_JIT_ERROR_LOG_BUFFER;
+             option_vals[0] = (void*)error_log;
+             // Pass the size of the error buffer
+             options[1] = CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES;
+             option_vals[1] = (void*)(size_t)log_size;
+
+             if (!check_cu(cuModuleLoadDataEx_f(&module, input.data(), 2, options, option_vals)))
+             {
+                 fprintf(stderr, "Warp error: Loading PTX module failed\n");
+                 // print error log if not empty
+                 if (*error_log)
+                     fprintf(stderr, "PTX loader error:\n%s\n", error_log);
+                 return NULL;
+             }
+         }
+         else
+         {
+             // manually compile the PTX and load as CUBIN
+
+             ContextInfo* context_info = get_context_info(static_cast<CUcontext>(context));
+             if (!context_info || !context_info->device_info)
+             {
+                 fprintf(stderr, "Warp error: Failed to determine target architecture\n");
+                 return NULL;
+             }
+
+             int arch = context_info->device_info->arch;
+
+             char arch_opt[128];
+             sprintf(arch_opt, "--gpu-name=sm_%d", arch);
+
+             const char* compiler_options[] = { arch_opt };
+
+             nvPTXCompilerHandle compiler = NULL;
+             if (!check_nvptx(nvPTXCompilerCreate(&compiler, input.size(), input.data())))
+                 return NULL;
+
+             if (!check_nvptx(nvPTXCompilerCompile(compiler, sizeof(compiler_options) / sizeof(*compiler_options), compiler_options)))
+                 return NULL;
+
+             size_t cubin_size = 0;
+             if (!check_nvptx(nvPTXCompilerGetCompiledProgramSize(compiler, &cubin_size)))
+                 return NULL;
+
+             std::vector<char> cubin(cubin_size);
+             if (!check_nvptx(nvPTXCompilerGetCompiledProgram(compiler, cubin.data())))
+                 return NULL;
+
+             check_nvptx(nvPTXCompilerDestroy(&compiler));
+
+             if (!check_cu(cuModuleLoadDataEx_f(&module, cubin.data(), 0, NULL, NULL)))
+             {
+                 fprintf(stderr, "Warp CUDA error: Loading module failed\n");
+                 return NULL;
+             }
+         }
+     }
+     else
+     {
+         // load CUBIN
+         if (!check_cu(cuModuleLoadDataEx_f(&module, input.data(), 0, NULL, NULL)))
+         {
+             fprintf(stderr, "Warp CUDA error: Loading module failed\n");
+             return NULL;
+         }
+     }
+
+     return module;
+ }
+
+ void wp_cuda_unload_module(void* context, void* module)
+ {
+     // unload immediately only if no graph capture is in progress
+     if (g_captures.empty())
+     {
+         ContextGuard guard(context);
+         check_cu(cuModuleUnload_f((CUmodule)module));
+     }
+     else
+     {
+         // defer until graph capture completes
+         ModuleInfo module_info;
+         module_info.context = context ? context : get_current_context();
+         module_info.module = module;
+         g_deferred_module_list.push_back(module_info);
+     }
+ }
+
+
+ int wp_cuda_get_max_shared_memory(void* context)
+ {
+     ContextInfo* info = get_context_info(context);
+     if (!info)
+         return -1;
+
+     int max_smem_bytes = info->device_info->max_smem_bytes;
+     return max_smem_bytes;
+ }
+
+ bool wp_cuda_configure_kernel_shared_memory(void* kernel, int size)
+ {
+     int requested_smem_bytes = size;
+
+     // configure shared memory
+     CUresult res = cuFuncSetAttribute_f((CUfunction)kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, requested_smem_bytes);
+     if (res != CUDA_SUCCESS)
+         return false;
+
+     return true;
+ }
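A short usage sketch (not part of the package diff) tying the two shared-memory helpers together; context and kernel are assumed to come from wp_cuda_load_module and wp_cuda_get_kernel, defined below.

    // Query the device limit and opt the kernel into that much dynamic shared memory.
    int max_smem = wp_cuda_get_max_shared_memory(context);
    if (max_smem > 0 && !wp_cuda_configure_kernel_shared_memory(kernel, max_smem))
        fprintf(stderr, "Failed to configure %d bytes of dynamic shared memory\n", max_smem);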
+
+ void* wp_cuda_get_kernel(void* context, void* module, const char* name)
+ {
+     ContextGuard guard(context);
+
+     CUfunction kernel = NULL;
+     if (!check_cu(cuModuleGetFunction_f(&kernel, (CUmodule)module, name)))
+     {
+         fprintf(stderr, "Warp CUDA error: Failed to lookup kernel function %s in module\n", name);
+         return NULL;
+     }
+
+     g_kernel_names[kernel] = name;
+     return kernel;
+ }
+
+ size_t wp_cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, int block_dim, int shared_memory_bytes, void** args, void* stream)
+ {
+     ContextGuard guard(context);
+
+     if (block_dim <= 0)
+     {
+ #if defined(_DEBUG)
+         fprintf(stderr, "Warp warning: Launch got block_dim %d. Setting to 256.\n", block_dim);
+ #endif
+         block_dim = 256;
+     }
+
+     // CUDA specs up to compute capability 9.0 say the max x-dim grid size is 2**31-1, so
+     // grid_dim is fine as an int for the near future
+     int grid_dim = (dim + block_dim - 1)/block_dim;
+
+     if (max_blocks <= 0) {
+         max_blocks = 2147483647;
+     }
+
+     if (grid_dim < 0)
+     {
+ #if defined(_DEBUG)
+         fprintf(stderr, "Warp warning: Overflow in grid dimensions detected for %zu total elements and %d threads "
+                         "per block.\n Setting block count to %d.\n", dim, block_dim, max_blocks);
+ #endif
+         grid_dim = max_blocks;
+     }
+     else
+     {
+         if (grid_dim > max_blocks)
+         {
+             grid_dim = max_blocks;
+         }
+     }
+
+     begin_cuda_range(WP_TIMING_KERNEL, stream, context, get_cuda_kernel_name(kernel));
+
+     CUresult res = cuLaunchKernel_f(
+         (CUfunction)kernel,
+         grid_dim, 1, 1,
+         block_dim, 1, 1,
+         shared_memory_bytes,
+         static_cast<CUstream>(stream),
+         args,
+         0);
+
+     check_cu(res);
+
+     end_cuda_range(WP_TIMING_KERNEL, stream);
+
+     return res;
+ }
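A host-side sketch (not part of the package diff) of the load / lookup / launch / unload cycle built from the functions above; the module path, kernel name, argument pointers and stream are placeholders.

    void* module = wp_cuda_load_module(context, "kernels.ptx");     // or a .cubin path
    if (module)
    {
        void* kernel = wp_cuda_get_kernel(context, module, "my_kernel");
        if (kernel)
        {
            size_t n = 1 << 20;                                     // number of elements to process
            void* args[] = { &input_ptr, &output_ptr, &n };         // kernel parameters (placeholders)
            wp_cuda_launch_kernel(context, kernel, n, /*max_blocks*/ 0, /*block_dim*/ 256,
                                  /*shared_memory_bytes*/ 0, args, stream);
        }
        wp_cuda_unload_module(context, module);
    }

Passing max_blocks <= 0 lets the launch helper clamp the grid to the 2**31-1 driver limit, as handled in the body above.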
+
+ void wp_cuda_graphics_map(void* context, void* resource)
+ {
+     ContextGuard guard(context);
+
+     check_cu(cuGraphicsMapResources_f(1, (CUgraphicsResource*)resource, get_current_stream()));
+ }
+
+ void wp_cuda_graphics_unmap(void* context, void* resource)
+ {
+     ContextGuard guard(context);
+
+     check_cu(cuGraphicsUnmapResources_f(1, (CUgraphicsResource*)resource, get_current_stream()));
+ }
+
+ void wp_cuda_graphics_device_ptr_and_size(void* context, void* resource, uint64_t* ptr, size_t* size)
+ {
+     ContextGuard guard(context);
+
+     CUdeviceptr device_ptr;
+     size_t bytes;
+     check_cu(cuGraphicsResourceGetMappedPointer_f(&device_ptr, &bytes, *(CUgraphicsResource*)resource));
+
+     *ptr = device_ptr;
+     *size = bytes;
+ }
+
+ void* wp_cuda_graphics_register_gl_buffer(void* context, uint32_t gl_buffer, unsigned int flags)
+ {
+     ContextGuard guard(context);
+
+     CUgraphicsResource *resource = new CUgraphicsResource;
+     bool success = check_cu(cuGraphicsGLRegisterBuffer_f(resource, gl_buffer, flags));
+     if (!success)
+     {
+         delete resource;
+         return NULL;
+     }
+
+     return resource;
+ }
+
+ void wp_cuda_graphics_unregister_resource(void* context, void* resource)
+ {
+     ContextGuard guard(context);
+
+     CUgraphicsResource *res = (CUgraphicsResource*)resource;
+     check_cu(cuGraphicsUnregisterResource_f(*res));
+     delete res;
+ }
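A sketch (not part of the package diff) of a full OpenGL interop round trip using the five graphics helpers above; gl_vbo is a placeholder GL buffer id and the flags value mirrors the driver's CU_GRAPHICS_REGISTER_FLAGS_NONE.

    void* resource = wp_cuda_graphics_register_gl_buffer(context, gl_vbo, /*flags*/ 0);
    if (resource)
    {
        wp_cuda_graphics_map(context, resource);

        uint64_t device_ptr = 0;
        size_t num_bytes = 0;
        wp_cuda_graphics_device_ptr_and_size(context, resource, &device_ptr, &num_bytes);
        // ... read or write the mapped buffer through device_ptr on the current stream ...

        wp_cuda_graphics_unmap(context, resource);
        wp_cuda_graphics_unregister_resource(context, resource);
    }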
+
+ void wp_cuda_timing_begin(int flags)
+ {
+     g_cuda_timing_state = new CudaTimingState(flags, g_cuda_timing_state);
+ }
+
+ int wp_cuda_timing_get_result_count()
+ {
+     if (g_cuda_timing_state)
+         return int(g_cuda_timing_state->ranges.size());
+     return 0;
+ }
+
+ void wp_cuda_timing_end(timing_result_t* results, int size)
+ {
+     if (!g_cuda_timing_state)
+         return;
+
+     // number of results to write to the user buffer
+     int count = std::min(wp_cuda_timing_get_result_count(), size);
+
+     // compute timings and write results
+     for (int i = 0; i < count; i++)
+     {
+         const CudaTimingRange& range = g_cuda_timing_state->ranges[i];
+         timing_result_t& result = results[i];
+         result.context = range.context;
+         result.name = range.name;
+         result.flag = range.flag;
+         check_cuda(cudaEventElapsedTime(&result.elapsed, range.start, range.end));
+     }
+
+     // release events
+     for (CudaTimingRange& range : g_cuda_timing_state->ranges)
+     {
+         check_cu(cuEventDestroy_f(range.start));
+         check_cu(cuEventDestroy_f(range.end));
+     }
+
+     // restore previous state
+     CudaTimingState* parent_state = g_cuda_timing_state->parent;
+     delete g_cuda_timing_state;
+     g_cuda_timing_state = parent_state;
+ }
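A timing-session sketch (not part of the package diff): begin a nested timing scope, run some work, then drain the recorded ranges. WP_TIMING_KERNEL is the flag already used by the launch path above; the work being timed, the synchronization point, and the assumption that timing_result_t::name is a C string are placeholders.

    wp_cuda_timing_begin(WP_TIMING_KERNEL);
    // ... launch the kernels to be timed, then synchronize the streams involved ...
    int count = wp_cuda_timing_get_result_count();
    std::vector<timing_result_t> results(count);
    wp_cuda_timing_end(results.data(), count);
    for (const timing_result_t& r : results)
        printf("%s: %.3f ms\n", r.name, r.elapsed);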
+
+ //#include "spline.inl"
+ //#include "volume.inl"