warp-lang 1.7.0 (warp_lang-1.7.0-py3-none-manylinux_2_34_aarch64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/native/warp.cu ADDED
@@ -0,0 +1,3636 @@
+ /*
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ #include "warp.h"
+ #include "scan.h"
+ #include "cuda_util.h"
+ #include "error.h"
+
+ #include <cstdlib>
+ #include <fstream>
+ #include <nvrtc.h>
+ #include <nvPTXCompiler.h>
+ #if WP_ENABLE_MATHDX
+ #include <nvJitLink.h>
+ #include <libmathdx.h>
+ #endif
+
+ #include <array>
+ #include <algorithm>
+ #include <iterator>
+ #include <list>
+ #include <map>
+ #include <string>
+ #include <unordered_map>
+ #include <unordered_set>
+ #include <vector>
+
+ #define check_any(result) (check_generic(result, __FILE__, __LINE__))
+ #define check_nvrtc(code) (check_nvrtc_result(code, __FILE__, __LINE__))
+ #define check_nvptx(code) (check_nvptx_result(code, __FILE__, __LINE__))
+ #define check_nvjitlink(handle, code) (check_nvjitlink_result(handle, code, __FILE__, __LINE__))
+ #define check_cufftdx(code) (check_cufftdx_result(code, __FILE__, __LINE__))
+ #define check_cublasdx(code) (check_cublasdx_result(code, __FILE__, __LINE__))
+ #define check_cusolver(code) (check_cusolver_result(code, __FILE__, __LINE__))
+ #define CHECK_ANY(code) \
+ { \
+ do { \
+ bool out = (check_any(code)); \
+ if(!out) { \
+ return out; \
+ } \
+ } while(0); \
+ }
+ #define CHECK_CUFFTDX(code) \
+ { \
+ do { \
+ bool out = (check_cufftdx(code)); \
+ if(!out) { \
+ return out; \
+ } \
+ } while(0); \
+ }
+ #define CHECK_CUBLASDX(code) \
+ { \
+ do { \
+ bool out = (check_cufftdx(code)); \
+ if(!out) { \
+ return out; \
+ } \
+ } while(0); \
+ }
+ #define CHECK_CUSOLVER(code) \
+ { \
+ do { \
+ bool out = (check_cusolver(code)); \
+ if(!out) { \
+ return out; \
+ } \
+ } while(0); \
+ }
+
+ bool check_nvrtc_result(nvrtcResult result, const char* file, int line)
+ {
+ if (result == NVRTC_SUCCESS)
+ return true;
+
+ const char* error_string = nvrtcGetErrorString(result);
+ fprintf(stderr, "Warp NVRTC compilation error %u: %s (%s:%d)\n", unsigned(result), error_string, file, line);
+ return false;
+ }
+
+ bool check_nvptx_result(nvPTXCompileResult result, const char* file, int line)
+ {
+ if (result == NVPTXCOMPILE_SUCCESS)
+ return true;
+
+ const char* error_string;
+ switch (result)
+ {
+ case NVPTXCOMPILE_ERROR_INVALID_COMPILER_HANDLE:
+ error_string = "Invalid compiler handle";
+ break;
+ case NVPTXCOMPILE_ERROR_INVALID_INPUT:
+ error_string = "Invalid input";
+ break;
+ case NVPTXCOMPILE_ERROR_COMPILATION_FAILURE:
+ error_string = "Compilation failure";
+ break;
+ case NVPTXCOMPILE_ERROR_INTERNAL:
+ error_string = "Internal error";
+ break;
+ case NVPTXCOMPILE_ERROR_OUT_OF_MEMORY:
+ error_string = "Out of memory";
+ break;
+ case NVPTXCOMPILE_ERROR_COMPILER_INVOCATION_INCOMPLETE:
+ error_string = "Incomplete compiler invocation";
+ break;
+ case NVPTXCOMPILE_ERROR_UNSUPPORTED_PTX_VERSION:
+ error_string = "Unsupported PTX version";
+ break;
+ default:
+ error_string = "Unknown error";
+ break;
+ }
+
+ fprintf(stderr, "Warp PTX compilation error %u: %s (%s:%d)\n", unsigned(result), error_string, file, line);
+ return false;
+ }
+
+ bool check_generic(int result, const char* file, int line)
+ {
+ if (!result) {
+ fprintf(stderr, "Error %d on %s:%d\n", (int)result, file, line);
+ return false;
+ } else {
+ return true;
+ }
+ }
+
+ struct DeviceInfo
+ {
+ static constexpr int kNameLen = 128;
+
+ CUdevice device = -1;
+ CUuuid uuid = {0};
+ int ordinal = -1;
+ int pci_domain_id = -1;
+ int pci_bus_id = -1;
+ int pci_device_id = -1;
+ char name[kNameLen] = "";
+ int arch = 0;
+ int is_uva = 0;
+ int is_mempool_supported = 0;
+ int is_ipc_supported = -1;
+ int max_smem_bytes = 0;
+ CUcontext primary_context = NULL;
+ };
+
+ struct ContextInfo
+ {
+ DeviceInfo* device_info = NULL;
+
+ // the current stream, managed from Python (see cuda_context_set_stream() and cuda_context_get_stream())
+ CUstream stream = NULL;
+ };
+
+ struct CaptureInfo
+ {
+ CUstream stream = NULL; // the main stream where capture begins and ends
+ uint64_t id = 0; // unique capture id from CUDA
+ bool external = false; // whether this is an external capture
+ };
+
+ struct StreamInfo
+ {
+ CUevent cached_event = NULL; // event used for stream synchronization (cached to avoid creating temporary events)
+ CaptureInfo* capture = NULL; // capture info (only if started on this stream)
+ };
+
+ struct GraphInfo
+ {
+ std::vector<void*> unfreed_allocs;
+ };
+
+ // Information for graph allocations that are not freed by the graph.
+ // These allocations have a shared ownership:
+ // - The graph instance allocates/maps the memory on each launch, even if the user reference is released.
+ // - The user reference must remain valid even if the graph is destroyed.
+ // The memory will be freed once the user reference is released and the graph is destroyed.
+ struct GraphAllocInfo
+ {
+ uint64_t capture_id = 0;
+ void* context = NULL;
+ bool ref_exists = false; // whether user reference still exists
+ bool graph_destroyed = false; // whether graph instance was destroyed
+ };
+
+ // Information used when deferring deallocations.
+ struct FreeInfo
+ {
+ void* context = NULL;
+ void* ptr = NULL;
+ bool is_async = false;
+ };
+
+ // Information used when deferring module unloading.
+ struct ModuleInfo
+ {
+ void* context = NULL;
+ void* module = NULL;
+ };
+
+ static std::unordered_map<CUfunction, std::string> g_kernel_names;
+
+ // cached info for all devices, indexed by ordinal
+ static std::vector<DeviceInfo> g_devices;
+
+ // maps CUdevice to DeviceInfo
+ static std::map<CUdevice, DeviceInfo*> g_device_map;
+
+ // cached info for all known contexts
+ static std::map<CUcontext, ContextInfo> g_contexts;
+
+ // cached info for all known streams (including registered external streams)
+ static std::unordered_map<CUstream, StreamInfo> g_streams;
+
+ // Ongoing graph captures registered using wp.capture_begin().
+ // This maps the capture id to the stream where capture was started.
+ // See cuda_graph_begin_capture(), cuda_graph_end_capture(), and free_device_async().
+ static std::unordered_map<uint64_t, CaptureInfo*> g_captures;
+
+ // Memory allocated during graph capture requires special handling.
+ // See alloc_device_async() and free_device_async().
+ static std::unordered_map<void*, GraphAllocInfo> g_graph_allocs;
+
+ // Memory that cannot be freed immediately gets queued here.
+ // Call free_deferred_allocs() to release.
+ static std::vector<FreeInfo> g_deferred_free_list;
+
+ // Modules that cannot be unloaded immediately get queued here.
+ // Call unload_deferred_modules() to release.
+ static std::vector<ModuleInfo> g_deferred_module_list;
+
+ void cuda_set_context_restore_policy(bool always_restore)
+ {
+ ContextGuard::always_restore = always_restore;
+ }
+
+ int cuda_get_context_restore_policy()
+ {
+ return int(ContextGuard::always_restore);
+ }
+
+ int cuda_init()
+ {
+ if (!init_cuda_driver())
+ return -1;
+
+ int device_count = 0;
+ if (check_cu(cuDeviceGetCount_f(&device_count)))
+ {
+ g_devices.resize(device_count);
+
+ for (int i = 0; i < device_count; i++)
+ {
+ CUdevice device;
+ if (check_cu(cuDeviceGet_f(&device, i)))
+ {
+ // query device info
+ g_devices[i].device = device;
+ g_devices[i].ordinal = i;
+ check_cu(cuDeviceGetName_f(g_devices[i].name, DeviceInfo::kNameLen, device));
+ check_cu(cuDeviceGetUuid_f(&g_devices[i].uuid, device));
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_domain_id, CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID, device));
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_bus_id, CU_DEVICE_ATTRIBUTE_PCI_BUS_ID, device));
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_device_id, CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, device));
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_uva, CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device));
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_mempool_supported, CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, device));
+ #ifdef CUDA_VERSION
+ #if CUDA_VERSION >= 12000
+ int device_attribute_integrated = 0;
+ check_cu(cuDeviceGetAttribute_f(&device_attribute_integrated, CU_DEVICE_ATTRIBUTE_INTEGRATED, device));
+ if (device_attribute_integrated == 0)
+ {
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_ipc_supported, CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED, device));
+ }
+ else
+ {
+ // integrated devices do not support CUDA IPC
+ g_devices[i].is_ipc_supported = 0;
+ }
+ #endif
+ #endif
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].max_smem_bytes, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device));
+ int major = 0;
+ int minor = 0;
+ check_cu(cuDeviceGetAttribute_f(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device));
+ check_cu(cuDeviceGetAttribute_f(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device));
+ g_devices[i].arch = 10 * major + minor;
+
+ g_device_map[device] = &g_devices[i];
+ }
+ else
+ {
+ return -1;
+ }
+ }
+ }
+ else
+ {
+ return -1;
+ }
+
+ // initialize default timing state
+ static CudaTimingState default_timing_state(0, NULL);
+ g_cuda_timing_state = &default_timing_state;
+
+ return 0;
+ }
+
+
+ static inline CUcontext get_current_context()
+ {
+ CUcontext ctx;
+ if (check_cu(cuCtxGetCurrent_f(&ctx)))
+ return ctx;
+ else
+ return NULL;
+ }
+
+ static inline CUstream get_current_stream(void* context=NULL)
+ {
+ return static_cast<CUstream>(cuda_context_get_stream(context));
+ }
+
+ static ContextInfo* get_context_info(CUcontext ctx)
+ {
+ if (!ctx)
+ {
+ ctx = get_current_context();
+ if (!ctx)
+ return NULL;
+ }
+
+ auto it = g_contexts.find(ctx);
+ if (it != g_contexts.end())
+ {
+ return &it->second;
+ }
+ else
+ {
+ // previously unseen context, add the info
+ ContextGuard guard(ctx, true);
+
+ CUdevice device;
+ if (check_cu(cuCtxGetDevice_f(&device)))
+ {
+ DeviceInfo* device_info = g_device_map[device];
+
+ // workaround for https://nvbugspro.nvidia.com/bug/4456003
+ if (device_info->is_mempool_supported)
+ {
+ void* dummy = NULL;
+ check_cuda(cudaMallocAsync(&dummy, 1, NULL));
+ check_cuda(cudaFreeAsync(dummy, NULL));
+ }
+
+ ContextInfo context_info;
+ context_info.device_info = device_info;
+ auto result = g_contexts.insert(std::make_pair(ctx, context_info));
+ return &result.first->second;
+ }
+ }
+
+ return NULL;
+ }
+
+ static inline ContextInfo* get_context_info(void* context)
+ {
+ return get_context_info(static_cast<CUcontext>(context));
+ }
+
+ static inline StreamInfo* get_stream_info(CUstream stream)
+ {
+ auto it = g_streams.find(stream);
+ if (it != g_streams.end())
+ return &it->second;
+ else
+ return NULL;
+ }
+
+ static void deferred_free(void* ptr, void* context, bool is_async)
+ {
+ FreeInfo free_info;
+ free_info.ptr = ptr;
+ free_info.context = context ? context : get_current_context();
+ free_info.is_async = is_async;
+ g_deferred_free_list.push_back(free_info);
+ }
+
+ static int free_deferred_allocs(void* context = NULL)
+ {
+ if (g_deferred_free_list.empty() || !g_captures.empty())
+ return 0;
+
+ int num_freed_allocs = 0;
+ for (auto it = g_deferred_free_list.begin(); it != g_deferred_free_list.end(); /*noop*/)
+ {
+ const FreeInfo& free_info = *it;
+
+ // free the pointer if it matches the given context or if the context is unspecified
+ if (free_info.context == context || !context)
+ {
+ ContextGuard guard(free_info.context);
+
+ if (free_info.is_async)
+ {
+ // this could be a regular stream-ordered allocation or a graph allocation
+ cudaError_t res = cudaFreeAsync(free_info.ptr, NULL);
+ if (res != cudaSuccess)
+ {
+ if (res == cudaErrorInvalidValue)
+ {
+ // This can happen if we try to release the pointer but the graph was
+ // never launched, so the memory isn't mapped.
+ // This is fine, so clear the error.
+ cudaGetLastError();
+ }
+ else
+ {
+ // something else went wrong, report error
+ check_cuda(res);
+ }
+ }
+ }
+ else
+ {
+ check_cuda(cudaFree(free_info.ptr));
+ }
+
+ ++num_freed_allocs;
+
+ it = g_deferred_free_list.erase(it);
+ }
+ else
+ {
+ ++it;
+ }
+ }
+
+ return num_freed_allocs;
+ }
+
+ static int unload_deferred_modules(void* context = NULL)
+ {
+ if (g_deferred_module_list.empty() || !g_captures.empty())
+ return 0;
+
+ int num_unloaded_modules = 0;
+ for (auto it = g_deferred_module_list.begin(); it != g_deferred_module_list.end(); /*noop*/)
+ {
+ // free the module if it matches the given context or if the context is unspecified
+ const ModuleInfo& module_info = *it;
+ if (module_info.context == context || !context)
+ {
+ cuda_unload_module(module_info.context, module_info.module);
+ ++num_unloaded_modules;
+ it = g_deferred_module_list.erase(it);
+ }
+ else
+ {
+ ++it;
+ }
+ }
+
+ return num_unloaded_modules;
+ }
+
+ static void CUDART_CB on_graph_destroy(void* user_data)
+ {
+ if (!user_data)
+ return;
+
+ GraphInfo* graph_info = static_cast<GraphInfo*>(user_data);
+
+ for (void* ptr : graph_info->unfreed_allocs)
+ {
+ auto alloc_iter = g_graph_allocs.find(ptr);
+ if (alloc_iter != g_graph_allocs.end())
+ {
+ GraphAllocInfo& alloc_info = alloc_iter->second;
+ if (alloc_info.ref_exists)
+ {
+ // unreference from graph so the pointer will be deallocated when the user reference goes away
+ alloc_info.graph_destroyed = true;
+ }
+ else
+ {
+ // the pointer can be freed, but we can't call CUDA functions in this callback, so defer it
+ deferred_free(ptr, alloc_info.context, true);
+ g_graph_allocs.erase(alloc_iter);
+ }
+ }
+ }
+
+ delete graph_info;
+ }
+
+ static inline const char* get_cuda_kernel_name(void* kernel)
+ {
+ CUfunction cuda_func = static_cast<CUfunction>(kernel);
+ auto name_iter = g_kernel_names.find((CUfunction)cuda_func);
+ if (name_iter != g_kernel_names.end())
+ return name_iter->second.c_str();
+ else
+ return "unknown_kernel";
+ }
+
+
+ void* alloc_pinned(size_t s)
+ {
+ void* ptr = NULL;
+ check_cuda(cudaMallocHost(&ptr, s));
+ return ptr;
+ }
+
+ void free_pinned(void* ptr)
+ {
+ cudaFreeHost(ptr);
+ }
+
+ void* alloc_device(void* context, size_t s)
+ {
+ int ordinal = cuda_context_get_device_ordinal(context);
+
+ // use stream-ordered allocator if available
+ if (cuda_device_is_mempool_supported(ordinal))
+ return alloc_device_async(context, s);
+ else
+ return alloc_device_default(context, s);
+ }
+
+ void free_device(void* context, void* ptr)
+ {
+ int ordinal = cuda_context_get_device_ordinal(context);
+
+ // use stream-ordered allocator if available
+ if (cuda_device_is_mempool_supported(ordinal))
+ free_device_async(context, ptr);
+ else
+ free_device_default(context, ptr);
+ }
+
+ void* alloc_device_default(void* context, size_t s)
+ {
+ ContextGuard guard(context);
+
+ void* ptr = NULL;
+ check_cuda(cudaMalloc(&ptr, s));
+
+ return ptr;
+ }
+
+ void free_device_default(void* context, void* ptr)
+ {
+ ContextGuard guard(context);
+
+ // check if a capture is in progress
+ if (g_captures.empty())
+ {
+ check_cuda(cudaFree(ptr));
+ }
+ else
+ {
+ // we must defer the operation until graph captures complete
+ deferred_free(ptr, context, false);
+ }
+ }
+
+ void* alloc_device_async(void* context, size_t s)
+ {
+ // stream-ordered allocations don't rely on the current context,
+ // but we set the context here for consistent behaviour
+ ContextGuard guard(context);
+
+ ContextInfo* context_info = get_context_info(context);
+ if (!context_info)
+ return NULL;
+
+ CUstream stream = context_info->stream;
+
+ void* ptr = NULL;
+ check_cuda(cudaMallocAsync(&ptr, s, stream));
+
+ if (ptr)
+ {
+ // if the stream is capturing, the allocation requires special handling
+ if (cuda_stream_is_capturing(stream))
+ {
+ // check if this is a known capture
+ uint64_t capture_id = get_capture_id(stream);
+ auto capture_iter = g_captures.find(capture_id);
+ if (capture_iter != g_captures.end())
+ {
+ // remember graph allocation details
+ GraphAllocInfo alloc_info;
+ alloc_info.capture_id = capture_id;
+ alloc_info.context = context ? context : get_current_context();
+ alloc_info.ref_exists = true; // user reference created and returned here
+ alloc_info.graph_destroyed = false; // graph not destroyed yet
+ g_graph_allocs[ptr] = alloc_info;
+ }
+ }
+ }
+
+ return ptr;
+ }
+
+ void free_device_async(void* context, void* ptr)
+ {
+ // stream-ordered allocators generally don't rely on the current context,
+ // but we set the context here for consistent behaviour
+ ContextGuard guard(context);
+
+ // NB: Stream-ordered deallocations are tricky, because the memory could still be used on another stream
+ // or even multiple streams. To avoid use-after-free errors, we need to ensure that all preceding work
+ // completes before releasing the memory. The strategy is different for regular stream-ordered allocations
+ // and allocations made during graph capture. See below for details.
+
+ // check if this allocation was made during graph capture
+ auto alloc_iter = g_graph_allocs.find(ptr);
+ if (alloc_iter == g_graph_allocs.end())
+ {
+ // Not a graph allocation.
+ // Check if graph capture is ongoing.
+ if (g_captures.empty())
+ {
+ // cudaFreeAsync on the null stream does not block or trigger synchronization, but it postpones
+ // the deallocation until a synchronization point is reached, so preceding work on this pointer
+ // should safely complete.
+ check_cuda(cudaFreeAsync(ptr, NULL));
+ }
+ else
+ {
+ // We must defer the free operation until graph capture completes.
+ deferred_free(ptr, context, true);
+ }
+ }
+ else
+ {
+ // get the graph allocation details
+ GraphAllocInfo& alloc_info = alloc_iter->second;
+
+ uint64_t capture_id = alloc_info.capture_id;
+
+ // check if the capture is still active
+ auto capture_iter = g_captures.find(capture_id);
+ if (capture_iter != g_captures.end())
+ {
+ // Add a mem free node. Use all current leaf nodes as dependencies to ensure that all prior
+ // work completes before deallocating. This works with both Warp-initiated and external captures
+ // and avoids the need to explicitly track all streams used during the capture.
+ CaptureInfo* capture = capture_iter->second;
+ cudaGraph_t graph = get_capture_graph(capture->stream);
+ std::vector<cudaGraphNode_t> leaf_nodes;
+ if (graph && get_graph_leaf_nodes(graph, leaf_nodes))
+ {
+ cudaGraphNode_t free_node;
+ check_cuda(cudaGraphAddMemFreeNode(&free_node, graph, leaf_nodes.data(), leaf_nodes.size(), ptr));
+ }
+
+ // we're done with this allocation, it's owned by the graph
+ g_graph_allocs.erase(alloc_iter);
+ }
+ else
+ {
+ // the capture has ended
+ // if the owning graph was already destroyed, we can free the pointer now
+ if (alloc_info.graph_destroyed)
+ {
+ if (g_captures.empty())
+ {
+ // try to free the pointer now
+ cudaError_t res = cudaFreeAsync(ptr, NULL);
+ if (res == cudaErrorInvalidValue)
+ {
+ // This can happen if we try to release the pointer but the graph was
+ // never launched, so the memory isn't mapped.
+ // This is fine, so clear the error.
+ cudaGetLastError();
+ }
+ else
+ {
+ // check for other errors
+ check_cuda(res);
+ }
+ }
+ else
+ {
+ // We must defer the operation until graph capture completes.
+ deferred_free(ptr, context, true);
+ }
+
+ // we're done with this allocation
+ g_graph_allocs.erase(alloc_iter);
+ }
+ else
+ {
+ // graph still exists
+ // unreference the pointer so it will be deallocated once the graph instance is destroyed
+ alloc_info.ref_exists = false;
+ }
+ }
+ }
+ }
+
+ bool memcpy_h2d(void* context, void* dest, void* src, size_t n, void* stream)
+ {
+ ContextGuard guard(context);
+
+ CUstream cuda_stream;
+ if (stream != WP_CURRENT_STREAM)
+ cuda_stream = static_cast<CUstream>(stream);
+ else
+ cuda_stream = get_current_stream(context);
+
+ begin_cuda_range(WP_TIMING_MEMCPY, cuda_stream, context, "memcpy HtoD");
+
+ bool result = check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyHostToDevice, cuda_stream));
+
+ end_cuda_range(WP_TIMING_MEMCPY, cuda_stream);
+
+ return result;
+ }
+
+ bool memcpy_d2h(void* context, void* dest, void* src, size_t n, void* stream)
+ {
+ ContextGuard guard(context);
+
+ CUstream cuda_stream;
+ if (stream != WP_CURRENT_STREAM)
+ cuda_stream = static_cast<CUstream>(stream);
+ else
+ cuda_stream = get_current_stream(context);
+
+ begin_cuda_range(WP_TIMING_MEMCPY, cuda_stream, context, "memcpy DtoH");
+
+ bool result = check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyDeviceToHost, cuda_stream));
+
+ end_cuda_range(WP_TIMING_MEMCPY, cuda_stream);
+
+ return result;
+ }
+
+ bool memcpy_d2d(void* context, void* dest, void* src, size_t n, void* stream)
+ {
+ ContextGuard guard(context);
+
+ CUstream cuda_stream;
+ if (stream != WP_CURRENT_STREAM)
+ cuda_stream = static_cast<CUstream>(stream);
+ else
+ cuda_stream = get_current_stream(context);
+
+ begin_cuda_range(WP_TIMING_MEMCPY, cuda_stream, context, "memcpy DtoD");
+
+ bool result = check_cuda(cudaMemcpyAsync(dest, src, n, cudaMemcpyDeviceToDevice, cuda_stream));
+
+ end_cuda_range(WP_TIMING_MEMCPY, cuda_stream);
+
+ return result;
+ }
+
+ bool memcpy_p2p(void* dst_context, void* dst, void* src_context, void* src, size_t n, void* stream)
+ {
+ // ContextGuard guard(context);
+
+ CUstream cuda_stream;
+ if (stream != WP_CURRENT_STREAM)
+ cuda_stream = static_cast<CUstream>(stream);
+ else
+ cuda_stream = get_current_stream(dst_context);
+
+ // Notes:
+ // - cuMemcpyPeerAsync() works fine with both regular and pooled allocations (cudaMalloc() and cudaMallocAsync(), respectively)
+ // when not capturing a graph.
+ // - cuMemcpyPeerAsync() is not supported during graph capture, so we must use cudaMemcpyAsync() with kind=cudaMemcpyDefault.
+ // - cudaMemcpyAsync() works fine with regular allocations, but doesn't work with pooled allocations
+ // unless mempool access has been enabled.
+ // - There is no reliable way to check if mempool access is enabled during graph capture,
+ // because cudaMemPoolGetAccess() cannot be called during graph capture.
+ // - CUDA will report error 1 (invalid argument) if cudaMemcpyAsync() is called but mempool access is not enabled.
+
+ if (!cuda_stream_is_capturing(stream))
+ {
+ begin_cuda_range(WP_TIMING_MEMCPY, cuda_stream, get_stream_context(stream), "memcpy PtoP");
+
+ bool result = check_cu(cuMemcpyPeerAsync_f(
+ (CUdeviceptr)dst, (CUcontext)dst_context,
+ (CUdeviceptr)src, (CUcontext)src_context,
+ n, cuda_stream));
+
+ end_cuda_range(WP_TIMING_MEMCPY, cuda_stream);
+
+ return result;
+ }
+ else
+ {
+ cudaError_t result = cudaSuccess;
+
+ // cudaMemcpyAsync() is sensitive to the bound context to resolve pointer locations.
+ // If fails with cudaErrorInvalidValue if it cannot resolve an argument.
+ // We first try the copy in the destination context, then if it fails we retry in the source context.
+ // The cudaErrorInvalidValue error doesn't cause graph capture to fail, so it's ok to retry.
+ // Since this trial-and-error shenanigans only happens during capture, there
+ // is no perf impact when the graph is launched.
+ // For bonus points, this approach simplifies memory pool access requirements.
+ // Access only needs to be enabled one way, either from the source device to the destination device
+ // or vice versa. Sometimes, when it's really quiet, you can actually hear my genius.
+ {
+ // try doing the copy in the destination context
+ ContextGuard guard(dst_context);
+ result = cudaMemcpyAsync(dst, src, n, cudaMemcpyDefault, cuda_stream);
+
+ if (result != cudaSuccess)
+ {
+ // clear error in destination context
+ cudaGetLastError();
+
+ // try doing the copy in the source context
+ ContextGuard guard(src_context);
+ result = cudaMemcpyAsync(dst, src, n, cudaMemcpyDefault, cuda_stream);
+
+ // clear error in source context
+ cudaGetLastError();
+ }
+ }
+
+ // If the copy failed, try to detect if mempool allocations are involved to generate a helpful error message.
+ if (!check_cuda(result))
+ {
+ if (result == cudaErrorInvalidValue && src != NULL && dst != NULL)
+ {
+ // check if either of the pointers was allocated from a mempool
+ void* src_mempool = NULL;
+ void* dst_mempool = NULL;
+ cuPointerGetAttribute_f(&src_mempool, CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE, (CUdeviceptr)src);
+ cuPointerGetAttribute_f(&dst_mempool, CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE, (CUdeviceptr)dst);
+ cudaGetLastError(); // clear any errors
+ // check if either of the pointers was allocated during graph capture
+ auto src_alloc = g_graph_allocs.find(src);
+ auto dst_alloc = g_graph_allocs.find(dst);
+ if (src_mempool != NULL || src_alloc != g_graph_allocs.end() ||
+ dst_mempool != NULL || dst_alloc != g_graph_allocs.end())
+ {
+ wp::append_error_string("*** CUDA mempool allocations were used in a peer-to-peer copy during graph capture.");
+ wp::append_error_string("*** This operation fails if mempool access is not enabled between the peer devices.");
+ wp::append_error_string("*** Either enable mempool access between the devices or use the default CUDA allocator");
+ wp::append_error_string("*** to pre-allocate the arrays before graph capture begins.");
+ }
+ }
+
+ return false;
+ }
+
+ return true;
+ }
+ }
+
+
875
+ __global__ void memset_kernel(int* dest, int value, size_t n)
876
+ {
877
+ const size_t tid = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
878
+
879
+ if (tid < n)
880
+ {
881
+ dest[tid] = value;
882
+ }
883
+ }
884
+
885
+ void memset_device(void* context, void* dest, int value, size_t n)
886
+ {
887
+ ContextGuard guard(context);
888
+
889
+ if (true)// ((n%4) > 0)
890
+ {
891
+ cudaStream_t stream = get_current_stream();
892
+
893
+ begin_cuda_range(WP_TIMING_MEMSET, stream, context, "memset");
894
+
895
+ // for unaligned lengths fallback to CUDA memset
896
+ check_cuda(cudaMemsetAsync(dest, value, n, stream));
897
+
898
+ end_cuda_range(WP_TIMING_MEMSET, stream);
899
+ }
900
+ else
901
+ {
902
+ // custom kernel to support 4-byte values (and slightly lower host overhead)
903
+ const size_t num_words = n/4;
904
+ wp_launch_device(WP_CURRENT_CONTEXT, memset_kernel, num_words, ((int*)dest, value, num_words));
905
+ }
906
+ }
907
+
908
+ // fill memory buffer with a value: generic memtile kernel using memcpy for each element
909
+ __global__ void memtile_kernel(void* dst, const void* src, size_t srcsize, size_t n)
910
+ {
911
+ size_t tid = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
912
+ if (tid < n)
913
+ {
914
+ memcpy((int8_t*)dst + srcsize * tid, src, srcsize);
915
+ }
916
+ }
917
+
918
+ // this should be faster than memtile_kernel, but requires proper alignment of dst
919
+ template <typename T>
920
+ __global__ void memtile_value_kernel(T* dst, T value, size_t n)
921
+ {
922
+ size_t tid = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
923
+ if (tid < n)
924
+ {
925
+ dst[tid] = value;
926
+ }
927
+ }
928
+
929
+ void memtile_device(void* context, void* dst, const void* src, size_t srcsize, size_t n)
930
+ {
931
+ ContextGuard guard(context);
932
+
933
+ size_t dst_addr = reinterpret_cast<size_t>(dst);
934
+ size_t src_addr = reinterpret_cast<size_t>(src);
935
+
936
+ // try memtile_value first because it should be faster, but we need to ensure proper alignment
937
+ if (srcsize == 8 && (dst_addr & 7) == 0 && (src_addr & 7) == 0)
938
+ {
939
+ int64_t* p = reinterpret_cast<int64_t*>(dst);
940
+ int64_t value = *reinterpret_cast<const int64_t*>(src);
941
+ wp_launch_device(WP_CURRENT_CONTEXT, memtile_value_kernel, n, (p, value, n));
942
+ }
943
+ else if (srcsize == 4 && (dst_addr & 3) == 0 && (src_addr & 3) == 0)
944
+ {
945
+ int32_t* p = reinterpret_cast<int32_t*>(dst);
946
+ int32_t value = *reinterpret_cast<const int32_t*>(src);
947
+ wp_launch_device(WP_CURRENT_CONTEXT, memtile_value_kernel, n, (p, value, n));
948
+ }
949
+ else if (srcsize == 2 && (dst_addr & 1) == 0 && (src_addr & 1) == 0)
950
+ {
951
+ int16_t* p = reinterpret_cast<int16_t*>(dst);
952
+ int16_t value = *reinterpret_cast<const int16_t*>(src);
953
+ wp_launch_device(WP_CURRENT_CONTEXT, memtile_value_kernel, n, (p, value, n));
954
+ }
955
+ else if (srcsize == 1)
956
+ {
957
+ check_cuda(cudaMemset(dst, *reinterpret_cast<const int8_t*>(src), n));
958
+ }
959
+ else
960
+ {
961
+ // generic version
962
+
963
+ // copy value to device memory
964
+ // TODO: use a persistent stream-local staging buffer to avoid allocs?
965
+ void* src_devptr = alloc_device(WP_CURRENT_CONTEXT, srcsize);
966
+ check_cuda(cudaMemcpyAsync(src_devptr, src, srcsize, cudaMemcpyHostToDevice, get_current_stream()));
967
+
968
+ wp_launch_device(WP_CURRENT_CONTEXT, memtile_kernel, n, (dst, src_devptr, srcsize, n));
969
+
970
+ free_device(WP_CURRENT_CONTEXT, src_devptr);
971
+
972
+ }
973
+ }
974
+
975
+
976
+ static __global__ void array_copy_1d_kernel(void* dst, const void* src,
977
+ int dst_stride, int src_stride,
978
+ const int* dst_indices, const int* src_indices,
979
+ int n, int elem_size)
980
+ {
981
+ int i = blockIdx.x * blockDim.x + threadIdx.x;
982
+ if (i < n)
983
+ {
984
+ int src_idx = src_indices ? src_indices[i] : i;
985
+ int dst_idx = dst_indices ? dst_indices[i] : i;
986
+ const char* p = (const char*)src + src_idx * src_stride;
987
+ char* q = (char*)dst + dst_idx * dst_stride;
988
+ memcpy(q, p, elem_size);
989
+ }
990
+ }
991
+
992
+ static __global__ void array_copy_2d_kernel(void* dst, const void* src,
993
+ wp::vec_t<2, int> dst_strides, wp::vec_t<2, int> src_strides,
994
+ wp::vec_t<2, const int*> dst_indices, wp::vec_t<2, const int*> src_indices,
995
+ wp::vec_t<2, int> shape, int elem_size)
996
+ {
997
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
998
+ int n = shape[1];
999
+ int i = tid / n;
1000
+ int j = tid % n;
1001
+ if (i < shape[0] /*&& j < shape[1]*/)
1002
+ {
1003
+ int src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1004
+ int dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1005
+ int src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1006
+ int dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1007
+ const char* p = (const char*)src + src_idx0 * src_strides[0] + src_idx1 * src_strides[1];
1008
+ char* q = (char*)dst + dst_idx0 * dst_strides[0] + dst_idx1 * dst_strides[1];
1009
+ memcpy(q, p, elem_size);
1010
+ }
1011
+ }
1012
+
1013
+ static __global__ void array_copy_3d_kernel(void* dst, const void* src,
1014
+ wp::vec_t<3, int> dst_strides, wp::vec_t<3, int> src_strides,
1015
+ wp::vec_t<3, const int*> dst_indices, wp::vec_t<3, const int*> src_indices,
1016
+ wp::vec_t<3, int> shape, int elem_size)
1017
+ {
1018
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
1019
+ int n = shape[1];
1020
+ int o = shape[2];
1021
+ int i = tid / (n * o);
1022
+ int j = tid % (n * o) / o;
1023
+ int k = tid % o;
1024
+ if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
1025
+ {
1026
+ int src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1027
+ int dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1028
+ int src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1029
+ int dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1030
+ int src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1031
+ int dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1032
+ const char* p = (const char*)src + src_idx0 * src_strides[0]
1033
+ + src_idx1 * src_strides[1]
1034
+ + src_idx2 * src_strides[2];
1035
+ char* q = (char*)dst + dst_idx0 * dst_strides[0]
1036
+ + dst_idx1 * dst_strides[1]
1037
+ + dst_idx2 * dst_strides[2];
1038
+ memcpy(q, p, elem_size);
1039
+ }
1040
+ }
1041
+
1042
+ static __global__ void array_copy_4d_kernel(void* dst, const void* src,
1043
+ wp::vec_t<4, int> dst_strides, wp::vec_t<4, int> src_strides,
1044
+ wp::vec_t<4, const int*> dst_indices, wp::vec_t<4, const int*> src_indices,
1045
+ wp::vec_t<4, int> shape, int elem_size)
1046
+ {
1047
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
1048
+ int n = shape[1];
1049
+ int o = shape[2];
1050
+ int p = shape[3];
1051
+ int i = tid / (n * o * p);
1052
+ int j = tid % (n * o * p) / (o * p);
1053
+ int k = tid % (o * p) / p;
1054
+ int l = tid % p;
1055
+ if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
1056
+ {
1057
+ int src_idx0 = src_indices[0] ? src_indices[0][i] : i;
1058
+ int dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
1059
+ int src_idx1 = src_indices[1] ? src_indices[1][j] : j;
1060
+ int dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
1061
+ int src_idx2 = src_indices[2] ? src_indices[2][k] : k;
1062
+ int dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
1063
+ int src_idx3 = src_indices[3] ? src_indices[3][l] : l;
1064
+ int dst_idx3 = dst_indices[3] ? dst_indices[3][l] : l;
1065
+ const char* p = (const char*)src + src_idx0 * src_strides[0]
1066
+ + src_idx1 * src_strides[1]
1067
+ + src_idx2 * src_strides[2]
1068
+ + src_idx3 * src_strides[3];
1069
+ char* q = (char*)dst + dst_idx0 * dst_strides[0]
1070
+ + dst_idx1 * dst_strides[1]
1071
+ + dst_idx2 * dst_strides[2]
1072
+ + dst_idx3 * dst_strides[3];
1073
+ memcpy(q, p, elem_size);
1074
+ }
1075
+ }
1076
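// --- Editor's note: illustrative sketch, not part of the package source. ---
// The 2D/3D/4D kernels above recover multi-dimensional coordinates from a flat
// thread id with the usual row-major div/mod arithmetic. A minimal host-side
// check of the same decomposition for a hypothetical 4D shape:

#include <cassert>

static void check_row_major_decomposition()
{
    const int shape[4] = {2, 3, 4, 5};                  // hypothetical shape
    const int n = shape[1], o = shape[2], p = shape[3];
    for (int tid = 0; tid < shape[0] * n * o * p; ++tid)
    {
        const int i = tid / (n * o * p);
        const int j = tid % (n * o * p) / (o * p);
        const int k = tid % (o * p) / p;
        const int l = tid % p;
        // recombining the coordinates must reproduce the flat index
        assert(((i * n + j) * o + k) * p + l == tid);
    }
}
// ---------------------------------------------------------------------------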
+
1077
+
1078
+ static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src,
1079
+ void* dst_data, int dst_stride, const int* dst_indices,
1080
+ int elem_size)
1081
+ {
1082
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
1083
+
1084
+ if (tid < src.size)
1085
+ {
1086
+ int dst_idx = dst_indices ? dst_indices[tid] : tid;
1087
+ void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
1088
+ const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
1089
+ memcpy(dst_ptr, src_ptr, elem_size);
1090
+ }
1091
+ }
1092
+
1093
+ static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricarray_t<void> src,
1094
+ void* dst_data, int dst_stride, const int* dst_indices,
1095
+ int elem_size)
1096
+ {
1097
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
1098
+
1099
+ if (tid < src.size)
1100
+ {
1101
+ int src_index = src.indices[tid];
1102
+ int dst_idx = dst_indices ? dst_indices[tid] : tid;
1103
+ void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
1104
+ const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
1105
+ memcpy(dst_ptr, src_ptr, elem_size);
1106
+ }
1107
+ }
1108
+
1109
+ static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
1110
+ const void* src_data, int src_stride, const int* src_indices,
1111
+ int elem_size)
1112
+ {
1113
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
1114
+
1115
+ if (tid < dst.size)
1116
+ {
1117
+ int src_idx = src_indices ? src_indices[tid] : tid;
1118
+ const void* src_ptr = (const char*)src_data + src_idx * src_stride;
1119
+ void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
1120
+ memcpy(dst_ptr, src_ptr, elem_size);
1121
+ }
1122
+ }
1123
+
1124
+ static __global__ void array_copy_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst,
1125
+ const void* src_data, int src_stride, const int* src_indices,
1126
+ int elem_size)
1127
+ {
1128
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
1129
+
1130
+ if (tid < dst.size)
1131
+ {
1132
+ int src_idx = src_indices ? src_indices[tid] : tid;
1133
+ const void* src_ptr = (const char*)src_data + src_idx * src_stride;
1134
+ int dst_idx = dst.indices[tid];
1135
+ void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
1136
+ memcpy(dst_ptr, src_ptr, elem_size);
1137
+ }
1138
+ }
1139
+
1140
+
1141
+ static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src, int elem_size)
1142
+ {
1143
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
1144
+
1145
+ if (tid < dst.size)
1146
+ {
1147
+ const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
1148
+ void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
1149
+ memcpy(dst_ptr, src_ptr, elem_size);
1150
+ }
1151
+ }
1152
+
1153
+
1154
+ static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src, int elem_size)
1155
+ {
1156
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
1157
+
1158
+ if (tid < dst.size)
1159
+ {
1160
+ const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
1161
+ int dst_index = dst.indices[tid];
1162
+ void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
1163
+ memcpy(dst_ptr, src_ptr, elem_size);
1164
+ }
1165
+ }
1166
+
1167
+
1168
+ static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, int elem_size)
1169
+ {
1170
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
1171
+
1172
+ if (tid < dst.size)
1173
+ {
1174
+ int src_index = src.indices[tid];
1175
+ const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
1176
+ void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
1177
+ memcpy(dst_ptr, src_ptr, elem_size);
1178
+ }
1179
+ }
1180
+
1181
+
1182
+ static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, int elem_size)
1183
+ {
1184
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
1185
+
1186
+ if (tid < dst.size)
1187
+ {
1188
+ int src_index = src.indices[tid];
1189
+ int dst_index = dst.indices[tid];
1190
+ const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
1191
+ void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
1192
+ memcpy(dst_ptr, src_ptr, elem_size);
1193
+ }
1194
+ }
1195
+
1196
+
1197
+ WP_API bool array_copy_device(void* context, void* dst, void* src, int dst_type, int src_type, int elem_size)
1198
+ {
1199
+ if (!src || !dst)
1200
+ return false;
1201
+
1202
+ const void* src_data = NULL;
1203
+ void* dst_data = NULL;
1204
+ int src_ndim = 0;
1205
+ int dst_ndim = 0;
1206
+ const int* src_shape = NULL;
1207
+ const int* dst_shape = NULL;
1208
+ const int* src_strides = NULL;
1209
+ const int* dst_strides = NULL;
1210
+ const int*const* src_indices = NULL;
1211
+ const int*const* dst_indices = NULL;
1212
+
1213
+ const wp::fabricarray_t<void>* src_fabricarray = NULL;
1214
+ wp::fabricarray_t<void>* dst_fabricarray = NULL;
1215
+
1216
+ const wp::indexedfabricarray_t<void>* src_indexedfabricarray = NULL;
1217
+ wp::indexedfabricarray_t<void>* dst_indexedfabricarray = NULL;
1218
+
1219
+ const int* null_indices[wp::ARRAY_MAX_DIMS] = { NULL };
1220
+
1221
+ if (src_type == wp::ARRAY_TYPE_REGULAR)
1222
+ {
1223
+ const wp::array_t<void>& src_arr = *static_cast<const wp::array_t<void>*>(src);
1224
+ src_data = src_arr.data;
1225
+ src_ndim = src_arr.ndim;
1226
+ src_shape = src_arr.shape.dims;
1227
+ src_strides = src_arr.strides;
1228
+ src_indices = null_indices;
1229
+ }
1230
+ else if (src_type == wp::ARRAY_TYPE_INDEXED)
1231
+ {
1232
+ const wp::indexedarray_t<void>& src_arr = *static_cast<const wp::indexedarray_t<void>*>(src);
1233
+ src_data = src_arr.arr.data;
1234
+ src_ndim = src_arr.arr.ndim;
1235
+ src_shape = src_arr.shape.dims;
1236
+ src_strides = src_arr.arr.strides;
1237
+ src_indices = src_arr.indices;
1238
+ }
1239
+ else if (src_type == wp::ARRAY_TYPE_FABRIC)
1240
+ {
1241
+ src_fabricarray = static_cast<const wp::fabricarray_t<void>*>(src);
1242
+ src_ndim = 1;
1243
+ }
1244
+ else if (src_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
1245
+ {
1246
+ src_indexedfabricarray = static_cast<const wp::indexedfabricarray_t<void>*>(src);
1247
+ src_ndim = 1;
1248
+ }
1249
+ else
1250
+ {
1251
+ fprintf(stderr, "Warp copy error: Invalid array type (%d)\n", src_type);
1252
+ return false;
1253
+ }
1254
+
1255
+ if (dst_type == wp::ARRAY_TYPE_REGULAR)
1256
+ {
1257
+ const wp::array_t<void>& dst_arr = *static_cast<const wp::array_t<void>*>(dst);
1258
+ dst_data = dst_arr.data;
1259
+ dst_ndim = dst_arr.ndim;
1260
+ dst_shape = dst_arr.shape.dims;
1261
+ dst_strides = dst_arr.strides;
1262
+ dst_indices = null_indices;
1263
+ }
1264
+ else if (dst_type == wp::ARRAY_TYPE_INDEXED)
1265
+ {
1266
+ const wp::indexedarray_t<void>& dst_arr = *static_cast<const wp::indexedarray_t<void>*>(dst);
1267
+ dst_data = dst_arr.arr.data;
1268
+ dst_ndim = dst_arr.arr.ndim;
1269
+ dst_shape = dst_arr.shape.dims;
1270
+ dst_strides = dst_arr.arr.strides;
1271
+ dst_indices = dst_arr.indices;
1272
+ }
1273
+ else if (dst_type == wp::ARRAY_TYPE_FABRIC)
1274
+ {
1275
+ dst_fabricarray = static_cast<wp::fabricarray_t<void>*>(dst);
1276
+ dst_ndim = 1;
1277
+ }
1278
+ else if (dst_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
1279
+ {
1280
+ dst_indexedfabricarray = static_cast<wp::indexedfabricarray_t<void>*>(dst);
1281
+ dst_ndim = 1;
1282
+ }
1283
+ else
1284
+ {
1285
+ fprintf(stderr, "Warp copy error: Invalid array type (%d)\n", dst_type);
1286
+ return false;
1287
+ }
1288
+
1289
+ if (src_ndim != dst_ndim)
1290
+ {
1291
+ fprintf(stderr, "Warp copy error: Incompatible array dimensionalities (%d and %d)\n", src_ndim, dst_ndim);
1292
+ return false;
1293
+ }
1294
+
1295
+ ContextGuard guard(context);
1296
+
1297
+ // handle fabric arrays
1298
+ if (dst_fabricarray)
1299
+ {
1300
+ size_t n = dst_fabricarray->size;
1301
+ if (src_fabricarray)
1302
+ {
1303
+ // copy from fabric to fabric
1304
+ if (src_fabricarray->size != n)
1305
+ {
1306
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
1307
+ return false;
1308
+ }
1309
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_to_fabric_kernel, n,
1310
+ (*dst_fabricarray, *src_fabricarray, elem_size));
1311
+ return true;
1312
+ }
1313
+ else if (src_indexedfabricarray)
1314
+ {
1315
+ // copy from fabric indexed to fabric
1316
+ if (src_indexedfabricarray->size != n)
1317
+ {
1318
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
1319
+ return false;
1320
+ }
1321
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_indexed_to_fabric_kernel, n,
1322
+ (*dst_fabricarray, *src_indexedfabricarray, elem_size));
1323
+ return true;
1324
+ }
1325
+ else
1326
+ {
1327
+ // copy to fabric
1328
+ if (size_t(src_shape[0]) != n)
1329
+ {
1330
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
1331
+ return false;
1332
+ }
1333
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_to_fabric_kernel, n,
1334
+ (*dst_fabricarray, src_data, src_strides[0], src_indices[0], elem_size));
1335
+ return true;
1336
+ }
1337
+ }
1338
+ if (dst_indexedfabricarray)
1339
+ {
1340
+ size_t n = dst_indexedfabricarray->size;
1341
+ if (src_fabricarray)
1342
+ {
1343
+ // copy from fabric to fabric indexed
1344
+ if (src_fabricarray->size != n)
1345
+ {
1346
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
1347
+ return false;
1348
+ }
1349
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_to_fabric_indexed_kernel, n,
1350
+ (*dst_indexedfabricarray, *src_fabricarray, elem_size));
1351
+ return true;
1352
+ }
1353
+ else if (src_indexedfabricarray)
1354
+ {
1355
+ // copy from fabric indexed to fabric indexed
1356
+ if (src_indexedfabricarray->size != n)
1357
+ {
1358
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
1359
+ return false;
1360
+ }
1361
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_indexed_to_fabric_indexed_kernel, n,
1362
+ (*dst_indexedfabricarray, *src_indexedfabricarray, elem_size));
1363
+ return true;
1364
+ }
1365
+ else
1366
+ {
1367
+ // copy to fabric indexed
1368
+ if (size_t(src_shape[0]) != n)
1369
+ {
1370
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
1371
+ return false;
1372
+ }
1373
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_to_fabric_indexed_kernel, n,
1374
+ (*dst_indexedfabricarray, src_data, src_strides[0], src_indices[0], elem_size));
1375
+ return true;
1376
+ }
1377
+ }
1378
+ else if (src_fabricarray)
1379
+ {
1380
+ // copy from fabric
1381
+ size_t n = src_fabricarray->size;
1382
+ if (size_t(dst_shape[0]) != n)
1383
+ {
1384
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
1385
+ return false;
1386
+ }
1387
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_from_fabric_kernel, n,
1388
+ (*src_fabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size));
1389
+ return true;
1390
+ }
1391
+ else if (src_indexedfabricarray)
1392
+ {
1393
+ // copy from fabric indexed
1394
+ size_t n = src_indexedfabricarray->size;
1395
+ if (size_t(dst_shape[0]) != n)
1396
+ {
1397
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
1398
+ return false;
1399
+ }
1400
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_from_fabric_indexed_kernel, n,
1401
+ (*src_indexedfabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size));
1402
+ return true;
1403
+ }
1404
+
1405
+ size_t n = 1;
1406
+ for (int i = 0; i < src_ndim; i++)
1407
+ {
1408
+ if (src_shape[i] != dst_shape[i])
1409
+ {
1410
+ fprintf(stderr, "Warp copy error: Incompatible array shapes\n");
1411
+ return false;
1412
+ }
1413
+ n *= src_shape[i];
1414
+ }
1415
+
1416
+ switch (src_ndim)
1417
+ {
1418
+ case 1:
1419
+ {
1420
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_1d_kernel, n, (dst_data, src_data,
1421
+ dst_strides[0], src_strides[0],
1422
+ dst_indices[0], src_indices[0],
1423
+ src_shape[0], elem_size));
1424
+ break;
1425
+ }
1426
+ case 2:
1427
+ {
1428
+ wp::vec_t<2, int> shape_v(src_shape[0], src_shape[1]);
1429
+ wp::vec_t<2, int> src_strides_v(src_strides[0], src_strides[1]);
1430
+ wp::vec_t<2, int> dst_strides_v(dst_strides[0], dst_strides[1]);
1431
+ wp::vec_t<2, const int*> src_indices_v(src_indices[0], src_indices[1]);
1432
+ wp::vec_t<2, const int*> dst_indices_v(dst_indices[0], dst_indices[1]);
1433
+
1434
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_2d_kernel, n, (dst_data, src_data,
1435
+ dst_strides_v, src_strides_v,
1436
+ dst_indices_v, src_indices_v,
1437
+ shape_v, elem_size));
1438
+ break;
1439
+ }
1440
+ case 3:
1441
+ {
1442
+ wp::vec_t<3, int> shape_v(src_shape[0], src_shape[1], src_shape[2]);
1443
+ wp::vec_t<3, int> src_strides_v(src_strides[0], src_strides[1], src_strides[2]);
1444
+ wp::vec_t<3, int> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2]);
1445
+ wp::vec_t<3, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2]);
1446
+ wp::vec_t<3, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2]);
1447
+
1448
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_3d_kernel, n, (dst_data, src_data,
1449
+ dst_strides_v, src_strides_v,
1450
+ dst_indices_v, src_indices_v,
1451
+ shape_v, elem_size));
1452
+ break;
1453
+ }
1454
+ case 4:
1455
+ {
1456
+ wp::vec_t<4, int> shape_v(src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
1457
+ wp::vec_t<4, int> src_strides_v(src_strides[0], src_strides[1], src_strides[2], src_strides[3]);
1458
+ wp::vec_t<4, int> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2], dst_strides[3]);
1459
+ wp::vec_t<4, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2], src_indices[3]);
1460
+ wp::vec_t<4, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2], dst_indices[3]);
1461
+
1462
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_4d_kernel, n, (dst_data, src_data,
1463
+ dst_strides_v, src_strides_v,
1464
+ dst_indices_v, src_indices_v,
1465
+ shape_v, elem_size));
1466
+ break;
1467
+ }
1468
+ default:
1469
+ fprintf(stderr, "Warp copy error: invalid array dimensionality (%d)\n", src_ndim);
1470
+ return false;
1471
+ }
1472
+
1473
+ return check_cuda(cudaGetLastError());
1474
+ }
1475
+
1476
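// --- Editor's note: illustrative sketch, not part of the package source. ---
// For plain strided 2D copies the CUDA runtime already offers cudaMemcpy2D();
// the dispatcher above generalizes the idea to indexed and Fabric arrays by
// launching the strided-copy kernels instead. A standalone sketch of the plain
// pitched case (all names below are hypothetical):

#include <cuda_runtime.h>
#include <vector>

static void copy_2d_strided_example()
{
    const int rows = 4, cols = 8;
    std::vector<float> host_src(rows * cols, 1.0f);

    float* dev_dst = nullptr;
    size_t dev_pitch = 0;                               // destination row stride in bytes
    cudaMallocPitch((void**)&dev_dst, &dev_pitch, cols * sizeof(float), rows);

    // width is given in bytes, height in rows; the source pitch is the packed row size
    cudaMemcpy2D(dev_dst, dev_pitch,
                 host_src.data(), cols * sizeof(float),
                 cols * sizeof(float), rows,
                 cudaMemcpyHostToDevice);

    cudaFree(dev_dst);
}
// ---------------------------------------------------------------------------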
+
1477
+ static __global__ void array_fill_1d_kernel(void* data,
1478
+ int n,
1479
+ int stride,
1480
+ const int* indices,
1481
+ const void* value,
1482
+ int value_size)
1483
+ {
1484
+ int i = blockIdx.x * blockDim.x + threadIdx.x;
1485
+ if (i < n)
1486
+ {
1487
+ int idx = indices ? indices[i] : i;
1488
+ char* p = (char*)data + idx * stride;
1489
+ memcpy(p, value, value_size);
1490
+ }
1491
+ }
1492
+
1493
+ static __global__ void array_fill_2d_kernel(void* data,
1494
+ wp::vec_t<2, int> shape,
1495
+ wp::vec_t<2, int> strides,
1496
+ wp::vec_t<2, const int*> indices,
1497
+ const void* value,
1498
+ int value_size)
1499
+ {
1500
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
1501
+ int n = shape[1];
1502
+ int i = tid / n;
1503
+ int j = tid % n;
1504
+ if (i < shape[0] /*&& j < shape[1]*/)
1505
+ {
1506
+ int idx0 = indices[0] ? indices[0][i] : i;
1507
+ int idx1 = indices[1] ? indices[1][j] : j;
1508
+ char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1];
1509
+ memcpy(p, value, value_size);
1510
+ }
1511
+ }
1512
+
1513
+ static __global__ void array_fill_3d_kernel(void* data,
1514
+ wp::vec_t<3, int> shape,
1515
+ wp::vec_t<3, int> strides,
1516
+ wp::vec_t<3, const int*> indices,
1517
+ const void* value,
1518
+ int value_size)
1519
+ {
1520
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
1521
+ int n = shape[1];
1522
+ int o = shape[2];
1523
+ int i = tid / (n * o);
1524
+ int j = tid % (n * o) / o;
1525
+ int k = tid % o;
1526
+ if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
1527
+ {
1528
+ int idx0 = indices[0] ? indices[0][i] : i;
1529
+ int idx1 = indices[1] ? indices[1][j] : j;
1530
+ int idx2 = indices[2] ? indices[2][k] : k;
1531
+ char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2];
1532
+ memcpy(p, value, value_size);
1533
+ }
1534
+ }
1535
+
1536
+ static __global__ void array_fill_4d_kernel(void* data,
1537
+ wp::vec_t<4, int> shape,
1538
+ wp::vec_t<4, int> strides,
1539
+ wp::vec_t<4, const int*> indices,
1540
+ const void* value,
1541
+ int value_size)
1542
+ {
1543
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
1544
+ int n = shape[1];
1545
+ int o = shape[2];
1546
+ int p = shape[3];
1547
+ int i = tid / (n * o * p);
1548
+ int j = tid % (n * o * p) / (o * p);
1549
+ int k = tid % (o * p) / p;
1550
+ int l = tid % p;
1551
+ if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
1552
+ {
1553
+ int idx0 = indices[0] ? indices[0][i] : i;
1554
+ int idx1 = indices[1] ? indices[1][j] : j;
1555
+ int idx2 = indices[2] ? indices[2][k] : k;
1556
+ int idx3 = indices[3] ? indices[3][l] : l;
1557
+ char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2] + idx3 * strides[3];
1558
+ memcpy(p, value, value_size);
1559
+ }
1560
+ }
1561
+
1562
+
1563
+ static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value, int value_size)
1564
+ {
1565
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
1566
+ if (tid < fa.size)
1567
+ {
1568
+ void* dst_ptr = fabricarray_element_ptr(fa, tid, value_size);
1569
+ memcpy(dst_ptr, value, value_size);
1570
+ }
1571
+ }
1572
+
1573
+
1574
+ static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value, int value_size)
1575
+ {
1576
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
1577
+ if (tid < ifa.size)
1578
+ {
1579
+ size_t idx = size_t(ifa.indices[tid]);
1580
+ if (idx < ifa.fa.size)
1581
+ {
1582
+ void* dst_ptr = fabricarray_element_ptr(ifa.fa, idx, value_size);
1583
+ memcpy(dst_ptr, value, value_size);
1584
+ }
1585
+ }
1586
+ }
1587
+
1588
+
1589
+ WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const void* value_ptr, int value_size)
1590
+ {
1591
+ if (!arr_ptr || !value_ptr)
1592
+ return;
1593
+
1594
+ void* data = NULL;
1595
+ int ndim = 0;
1596
+ const int* shape = NULL;
1597
+ const int* strides = NULL;
1598
+ const int*const* indices = NULL;
1599
+
1600
+ wp::fabricarray_t<void>* fa = NULL;
1601
+ wp::indexedfabricarray_t<void>* ifa = NULL;
1602
+
1603
+ const int* null_indices[wp::ARRAY_MAX_DIMS] = { NULL };
1604
+
1605
+ if (arr_type == wp::ARRAY_TYPE_REGULAR)
1606
+ {
1607
+ wp::array_t<void>& arr = *static_cast<wp::array_t<void>*>(arr_ptr);
1608
+ data = arr.data;
1609
+ ndim = arr.ndim;
1610
+ shape = arr.shape.dims;
1611
+ strides = arr.strides;
1612
+ indices = null_indices;
1613
+ }
1614
+ else if (arr_type == wp::ARRAY_TYPE_INDEXED)
1615
+ {
1616
+ wp::indexedarray_t<void>& ia = *static_cast<wp::indexedarray_t<void>*>(arr_ptr);
1617
+ data = ia.arr.data;
1618
+ ndim = ia.arr.ndim;
1619
+ shape = ia.shape.dims;
1620
+ strides = ia.arr.strides;
1621
+ indices = ia.indices;
1622
+ }
1623
+ else if (arr_type == wp::ARRAY_TYPE_FABRIC)
1624
+ {
1625
+ fa = static_cast<wp::fabricarray_t<void>*>(arr_ptr);
1626
+ }
1627
+ else if (arr_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
1628
+ {
1629
+ ifa = static_cast<wp::indexedfabricarray_t<void>*>(arr_ptr);
1630
+ }
1631
+ else
1632
+ {
1633
+ fprintf(stderr, "Warp fill error: Invalid array type id %d\n", arr_type);
1634
+ return;
1635
+ }
1636
+
1637
+ size_t n = 1;
1638
+ for (int i = 0; i < ndim; i++)
1639
+ n *= shape[i];
1640
+
1641
+ ContextGuard guard(context);
1642
+
1643
+ // copy value to device memory
1644
+ // TODO: use a persistent stream-local staging buffer to avoid allocs?
1645
+ void* value_devptr = alloc_device(WP_CURRENT_CONTEXT, value_size);
1646
+ check_cuda(cudaMemcpyAsync(value_devptr, value_ptr, value_size, cudaMemcpyHostToDevice, get_current_stream()));
1647
+
1648
+ // handle fabric arrays
1649
+ if (fa)
1650
+ {
1651
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_kernel, n,
1652
+ (*fa, value_devptr, value_size));
1653
+ free_device(WP_CURRENT_CONTEXT, value_devptr);
+ return;
1654
+ }
1655
+ else if (ifa)
1656
+ {
1657
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_indexed_kernel, n,
1658
+ (*ifa, value_devptr, value_size));
1659
+ free_device(WP_CURRENT_CONTEXT, value_devptr);
+ return;
1660
+ }
1661
+
1662
+ // handle regular or indexed arrays
1663
+ switch (ndim)
1664
+ {
1665
+ case 1:
1666
+ {
1667
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_1d_kernel, n,
1668
+ (data, shape[0], strides[0], indices[0], value_devptr, value_size));
1669
+ break;
1670
+ }
1671
+ case 2:
1672
+ {
1673
+ wp::vec_t<2, int> shape_v(shape[0], shape[1]);
1674
+ wp::vec_t<2, int> strides_v(strides[0], strides[1]);
1675
+ wp::vec_t<2, const int*> indices_v(indices[0], indices[1]);
1676
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_2d_kernel, n,
1677
+ (data, shape_v, strides_v, indices_v, value_devptr, value_size));
1678
+ break;
1679
+ }
1680
+ case 3:
1681
+ {
1682
+ wp::vec_t<3, int> shape_v(shape[0], shape[1], shape[2]);
1683
+ wp::vec_t<3, int> strides_v(strides[0], strides[1], strides[2]);
1684
+ wp::vec_t<3, const int*> indices_v(indices[0], indices[1], indices[2]);
1685
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_3d_kernel, n,
1686
+ (data, shape_v, strides_v, indices_v, value_devptr, value_size));
1687
+ break;
1688
+ }
1689
+ case 4:
1690
+ {
1691
+ wp::vec_t<4, int> shape_v(shape[0], shape[1], shape[2], shape[3]);
1692
+ wp::vec_t<4, int> strides_v(strides[0], strides[1], strides[2], strides[3]);
1693
+ wp::vec_t<4, const int*> indices_v(indices[0], indices[1], indices[2], indices[3]);
1694
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_4d_kernel, n,
1695
+ (data, shape_v, strides_v, indices_v, value_devptr, value_size));
1696
+ break;
1697
+ }
1698
+ default:
1699
+ fprintf(stderr, "Warp fill error: invalid array dimensionality (%d)\n", ndim);
1700
+ return;
1701
+ }
1702
+
1703
+ free_device(WP_CURRENT_CONTEXT, value_devptr);
1704
+ }
1705
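// --- Editor's note: illustrative sketch, not part of the package source. ---
// array_fill_device() stages the fill value in device memory so the kernels can
// memcpy() it per element; cudaMemset() only covers single-byte patterns. A
// standalone sketch of the staging step with stream-ordered allocation
// (hypothetical names, requires CUDA 11.2+):

#include <cuda_runtime.h>

static void stage_fill_value_example(const void* host_value, size_t value_size, cudaStream_t stream)
{
    void* dev_value = nullptr;
    cudaMallocAsync(&dev_value, value_size, stream);
    cudaMemcpyAsync(dev_value, host_value, value_size, cudaMemcpyHostToDevice, stream);

    // ... launch the fill kernel on `stream`, reading from dev_value ...

    // stream-ordered free: executes only after preceding work on the stream completes
    cudaFreeAsync(dev_value, stream);
}
// ---------------------------------------------------------------------------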
+
1706
+ void array_scan_int_device(uint64_t in, uint64_t out, int len, bool inclusive)
1707
+ {
1708
+ scan_device((const int*)in, (int*)out, len, inclusive);
1709
+ }
1710
+
1711
+ void array_scan_float_device(uint64_t in, uint64_t out, int len, bool inclusive)
1712
+ {
1713
+ scan_device((const float*)in, (float*)out, len, inclusive);
1714
+ }
1715
+
1716
+ int cuda_driver_version()
1717
+ {
1718
+ int version;
1719
+ if (check_cu(cuDriverGetVersion_f(&version)))
1720
+ return version;
1721
+ else
1722
+ return 0;
1723
+ }
1724
+
1725
+ int cuda_toolkit_version()
1726
+ {
1727
+ return CUDA_VERSION;
1728
+ }
1729
+
1730
+ bool cuda_driver_is_initialized()
1731
+ {
1732
+ return is_cuda_driver_initialized();
1733
+ }
1734
+
1735
+ int nvrtc_supported_arch_count()
1736
+ {
1737
+ int count;
1738
+ if (check_nvrtc(nvrtcGetNumSupportedArchs(&count)))
1739
+ return count;
1740
+ else
1741
+ return 0;
1742
+ }
1743
+
1744
+ void nvrtc_supported_archs(int* archs)
1745
+ {
1746
+ if (archs)
1747
+ {
1748
+ check_nvrtc(nvrtcGetSupportedArchs(archs));
1749
+ }
1750
+ }
1751
+
1752
+ int cuda_device_get_count()
1753
+ {
1754
+ int count = 0;
1755
+ check_cu(cuDeviceGetCount_f(&count));
1756
+ return count;
1757
+ }
1758
+
1759
+ void* cuda_device_get_primary_context(int ordinal)
1760
+ {
1761
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
1762
+ {
1763
+ DeviceInfo& device_info = g_devices[ordinal];
1764
+
1765
+ // acquire the primary context if we haven't already
1766
+ if (!device_info.primary_context)
1767
+ check_cu(cuDevicePrimaryCtxRetain_f(&device_info.primary_context, device_info.device));
1768
+
1769
+ return device_info.primary_context;
1770
+ }
1771
+
1772
+ return NULL;
1773
+ }
1774
+
1775
+ const char* cuda_device_get_name(int ordinal)
1776
+ {
1777
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
1778
+ return g_devices[ordinal].name;
1779
+ return NULL;
1780
+ }
1781
+
1782
+ int cuda_device_get_arch(int ordinal)
1783
+ {
1784
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
1785
+ return g_devices[ordinal].arch;
1786
+ return 0;
1787
+ }
1788
+
1789
+ void cuda_device_get_uuid(int ordinal, char uuid[16])
1790
+ {
1791
+ memcpy(uuid, g_devices[ordinal].uuid.bytes, sizeof(char)*16);
1792
+ }
1793
+
1794
+ int cuda_device_get_pci_domain_id(int ordinal)
1795
+ {
1796
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
1797
+ return g_devices[ordinal].pci_domain_id;
1798
+ return -1;
1799
+ }
1800
+
1801
+ int cuda_device_get_pci_bus_id(int ordinal)
1802
+ {
1803
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
1804
+ return g_devices[ordinal].pci_bus_id;
1805
+ return -1;
1806
+ }
1807
+
1808
+ int cuda_device_get_pci_device_id(int ordinal)
1809
+ {
1810
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
1811
+ return g_devices[ordinal].pci_device_id;
1812
+ return -1;
1813
+ }
1814
+
1815
+ int cuda_device_is_uva(int ordinal)
1816
+ {
1817
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
1818
+ return g_devices[ordinal].is_uva;
1819
+ return 0;
1820
+ }
1821
+
1822
+ int cuda_device_is_mempool_supported(int ordinal)
1823
+ {
1824
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
1825
+ return g_devices[ordinal].is_mempool_supported;
1826
+ return 0;
1827
+ }
1828
+
1829
+ int cuda_device_is_ipc_supported(int ordinal)
1830
+ {
1831
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
1832
+ return g_devices[ordinal].is_ipc_supported;
1833
+ return 0;
1834
+ }
1835
+
1836
+ int cuda_device_set_mempool_release_threshold(int ordinal, uint64_t threshold)
1837
+ {
1838
+ if (ordinal < 0 || ordinal >= int(g_devices.size()))
1839
+ {
1840
+ fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
1841
+ return 0;
1842
+ }
1843
+
1844
+ if (!g_devices[ordinal].is_mempool_supported)
1845
+ return 0;
1846
+
1847
+ cudaMemPool_t pool;
1848
+ if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
1849
+ {
1850
+ fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
1851
+ return 0;
1852
+ }
1853
+
1854
+ if (!check_cuda(cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold)))
1855
+ {
1856
+ fprintf(stderr, "Warp error: Failed to set memory pool attribute on device %d\n", ordinal);
1857
+ return 0;
1858
+ }
1859
+
1860
+ return 1; // success
1861
+ }
1862
+
1863
+ uint64_t cuda_device_get_mempool_release_threshold(int ordinal)
1864
+ {
1865
+ if (ordinal < 0 || ordinal >= int(g_devices.size()))
1866
+ {
1867
+ fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
1868
+ return 0;
1869
+ }
1870
+
1871
+ if (!g_devices[ordinal].is_mempool_supported)
1872
+ return 0;
1873
+
1874
+ cudaMemPool_t pool;
1875
+ if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
1876
+ {
1877
+ fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
1878
+ return 0;
1879
+ }
1880
+
1881
+ uint64_t threshold = 0;
1882
+ if (!check_cuda(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold)))
1883
+ {
1884
+ fprintf(stderr, "Warp error: Failed to get memory pool release threshold on device %d\n", ordinal);
1885
+ return 0;
1886
+ }
1887
+
1888
+ return threshold;
1889
+ }
1890
+
1891
+ uint64_t cuda_device_get_mempool_used_mem_current(int ordinal)
1892
+ {
1893
+ if (ordinal < 0 || ordinal >= int(g_devices.size()))
1894
+ {
1895
+ fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
1896
+ return 0;
1897
+ }
1898
+
1899
+ if (!g_devices[ordinal].is_mempool_supported)
1900
+ return 0;
1901
+
1902
+ cudaMemPool_t pool;
1903
+ if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
1904
+ {
1905
+ fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
1906
+ return 0;
1907
+ }
1908
+
1909
+ uint64_t mem_used = 0;
1910
+ if (!check_cuda(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrUsedMemCurrent, &mem_used)))
1911
+ {
1912
+ fprintf(stderr, "Warp error: Failed to get amount of currently used memory from the memory pool on device %d\n", ordinal);
1913
+ return 0;
1914
+ }
1915
+
1916
+ return mem_used;
1917
+ }
1918
+
1919
+ uint64_t cuda_device_get_mempool_used_mem_high(int ordinal)
1920
+ {
1921
+ if (ordinal < 0 || ordinal >= int(g_devices.size()))
1922
+ {
1923
+ fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
1924
+ return 0;
1925
+ }
1926
+
1927
+ if (!g_devices[ordinal].is_mempool_supported)
1928
+ return 0;
1929
+
1930
+ cudaMemPool_t pool;
1931
+ if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
1932
+ {
1933
+ fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
1934
+ return 0;
1935
+ }
1936
+
1937
+ uint64_t mem_high_water_mark = 0;
1938
+ if (!check_cuda(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrUsedMemHigh, &mem_high_water_mark)))
1939
+ {
1940
+ fprintf(stderr, "Warp error: Failed to get memory usage high water mark from the memory pool on device %d\n", ordinal);
1941
+ return 0;
1942
+ }
1943
+
1944
+ return mem_high_water_mark;
1945
+ }
1946
+
1947
+ void cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_mem)
1948
+ {
1949
+ // use temporary storage if user didn't specify pointers
1950
+ size_t tmp_free_mem, tmp_total_mem;
1951
+
1952
+ if (free_mem)
1953
+ *free_mem = 0;
1954
+ else
1955
+ free_mem = &tmp_free_mem;
1956
+
1957
+ if (total_mem)
1958
+ *total_mem = 0;
1959
+ else
1960
+ total_mem = &tmp_total_mem;
1961
+
1962
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
1963
+ {
1964
+ if (g_devices[ordinal].primary_context)
1965
+ {
1966
+ ContextGuard guard(g_devices[ordinal].primary_context, true);
1967
+ check_cu(cuMemGetInfo_f(free_mem, total_mem));
1968
+ }
1969
+ else
1970
+ {
1971
+ // if we haven't acquired the primary context yet, acquire it temporarily
1972
+ CUcontext primary_context = NULL;
1973
+ check_cu(cuDevicePrimaryCtxRetain_f(&primary_context, g_devices[ordinal].device));
1974
+ {
1975
+ ContextGuard guard(primary_context, true);
1976
+ check_cu(cuMemGetInfo_f(free_mem, total_mem));
1977
+ }
1978
+ check_cu(cuDevicePrimaryCtxRelease_f(g_devices[ordinal].device));
1979
+ }
1980
+ }
1981
+ }
1982
+
1983
+
1984
+ void* cuda_context_get_current()
1985
+ {
1986
+ return get_current_context();
1987
+ }
1988
+
1989
+ void cuda_context_set_current(void* context)
1990
+ {
1991
+ CUcontext ctx = static_cast<CUcontext>(context);
1992
+ CUcontext prev_ctx = NULL;
1993
+ check_cu(cuCtxGetCurrent_f(&prev_ctx));
1994
+ if (ctx != prev_ctx)
1995
+ {
1996
+ check_cu(cuCtxSetCurrent_f(ctx));
1997
+ }
1998
+ }
1999
+
2000
+ void cuda_context_push_current(void* context)
2001
+ {
2002
+ check_cu(cuCtxPushCurrent_f(static_cast<CUcontext>(context)));
2003
+ }
2004
+
2005
+ void cuda_context_pop_current()
2006
+ {
2007
+ CUcontext context;
2008
+ check_cu(cuCtxPopCurrent_f(&context));
2009
+ }
2010
+
2011
+ void* cuda_context_create(int device_ordinal)
2012
+ {
2013
+ CUcontext ctx = NULL;
2014
+ CUdevice device;
2015
+ if (check_cu(cuDeviceGet_f(&device, device_ordinal)))
2016
+ check_cu(cuCtxCreate_f(&ctx, 0, device));
2017
+ return ctx;
2018
+ }
2019
+
2020
+ void cuda_context_destroy(void* context)
2021
+ {
2022
+ if (context)
2023
+ {
2024
+ CUcontext ctx = static_cast<CUcontext>(context);
2025
+
2026
+ // ensure this is not the current context
2027
+ if (ctx == cuda_context_get_current())
2028
+ cuda_context_set_current(NULL);
2029
+
2030
+ // release the cached info about this context
2031
+ ContextInfo* info = get_context_info(ctx);
2032
+ if (info)
2033
+ {
2034
+ if (info->stream)
2035
+ check_cu(cuStreamDestroy_f(info->stream));
2036
+
2037
+ g_contexts.erase(ctx);
2038
+ }
2039
+
2040
+ check_cu(cuCtxDestroy_f(ctx));
2041
+ }
2042
+ }
2043
+
2044
+ void cuda_context_synchronize(void* context)
2045
+ {
2046
+ ContextGuard guard(context);
2047
+
2048
+ check_cu(cuCtxSynchronize_f());
2049
+
2050
+ if (free_deferred_allocs(context ? context : get_current_context()) > 0)
2051
+ {
2052
+ // ensure deferred asynchronous deallocations complete
2053
+ check_cu(cuCtxSynchronize_f());
2054
+ }
2055
+
2056
+ unload_deferred_modules(context);
2057
+
2058
+ // check_cuda(cudaDeviceGraphMemTrim(cuda_context_get_device_ordinal(context)));
2059
+ }
2060
+
2061
+ uint64_t cuda_context_check(void* context)
2062
+ {
2063
+ ContextGuard guard(context);
2064
+
2065
+ // check errors before syncing
2066
+ cudaError_t e = cudaGetLastError();
2067
+ check_cuda(e);
2068
+
2069
+ cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
2070
+ check_cuda(cudaStreamIsCapturing(get_current_stream(), &status));
2071
+
2072
+ // synchronize if the stream is not capturing
2073
+ if (status == cudaStreamCaptureStatusNone)
2074
+ {
2075
+ check_cuda(cudaDeviceSynchronize());
2076
+ e = cudaGetLastError();
2077
+ }
2078
+
2079
+ return static_cast<uint64_t>(e);
2080
+ }
2081
+
2082
+
2083
+ int cuda_context_get_device_ordinal(void* context)
2084
+ {
2085
+ ContextInfo* info = get_context_info(static_cast<CUcontext>(context));
2086
+ return info && info->device_info ? info->device_info->ordinal : -1;
2087
+ }
2088
+
2089
+ int cuda_context_is_primary(void* context)
2090
+ {
2091
+ CUcontext ctx = static_cast<CUcontext>(context);
2092
+ ContextInfo* context_info = get_context_info(ctx);
2093
+ if (!context_info)
2094
+ {
2095
+ fprintf(stderr, "Warp error: Failed to get context info\n");
2096
+ return 0;
2097
+ }
2098
+
2099
+ // if the device primary context is known, check if it matches the given context
2100
+ DeviceInfo* device_info = context_info->device_info;
2101
+ if (device_info->primary_context)
2102
+ return int(ctx == device_info->primary_context);
2103
+
2104
+ // there is no CUDA API to check if a context is primary, but we can temporarily
2105
+ // acquire the device's primary context to check the pointer
2106
+ CUcontext primary_ctx;
2107
+ if (check_cu(cuDevicePrimaryCtxRetain_f(&primary_ctx, device_info->device)))
2108
+ {
2109
+ check_cu(cuDevicePrimaryCtxRelease_f(device_info->device));
2110
+ return int(ctx == primary_ctx);
2111
+ }
2112
+
2113
+ return 0;
2114
+ }
2115
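// --- Editor's note: illustrative sketch, not part of the package source. ---
// The driver API has no direct "is this the primary context?" query, so the
// function above retains the device's primary context and compares pointers.
// The same trick in isolation (hypothetical standalone helper):

#include <cuda.h>

static bool is_primary_context_example(CUcontext ctx, CUdevice device)
{
    CUcontext primary = nullptr;
    if (cuDevicePrimaryCtxRetain(&primary, device) != CUDA_SUCCESS)
        return false;
    cuDevicePrimaryCtxRelease(device);   // drop the temporary reference
    return ctx == primary;
}
// ---------------------------------------------------------------------------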
+
2116
+ void* cuda_context_get_stream(void* context)
2117
+ {
2118
+ ContextInfo* info = get_context_info(static_cast<CUcontext>(context));
2119
+ if (info)
2120
+ {
2121
+ return info->stream;
2122
+ }
2123
+ return NULL;
2124
+ }
2125
+
2126
+ void cuda_context_set_stream(void* context, void* stream, int sync)
2127
+ {
2128
+ ContextInfo* context_info = get_context_info(static_cast<CUcontext>(context));
2129
+ if (context_info)
2130
+ {
2131
+ CUstream new_stream = static_cast<CUstream>(stream);
2132
+
2133
+ // check whether we should sync with the previous stream on this device
2134
+ if (sync)
2135
+ {
2136
+ CUstream old_stream = context_info->stream;
2137
+ StreamInfo* old_stream_info = get_stream_info(old_stream);
2138
+ if (old_stream_info)
2139
+ {
2140
+ CUevent cached_event = old_stream_info->cached_event;
2141
+ check_cu(cuEventRecord_f(cached_event, old_stream));
2142
+ check_cu(cuStreamWaitEvent_f(new_stream, cached_event, CU_EVENT_WAIT_DEFAULT));
2143
+ }
2144
+ }
2145
+
2146
+ context_info->stream = new_stream;
2147
+ }
2148
+ }
2149
+
2150
+ int cuda_is_peer_access_supported(int target_ordinal, int peer_ordinal)
2151
+ {
2152
+ int num_devices = int(g_devices.size());
2153
+
2154
+ if (target_ordinal < 0 || target_ordinal >= num_devices)
2155
+ {
2156
+ fprintf(stderr, "Warp error: Invalid target device ordinal %d\n", target_ordinal);
2157
+ return 0;
2158
+ }
2159
+
2160
+ if (peer_ordinal < 0 || peer_ordinal >= num_devices)
2161
+ {
2162
+ fprintf(stderr, "Warp error: Invalid peer device ordinal %d\n", peer_ordinal);
2163
+ return 0;
2164
+ }
2165
+
2166
+ if (target_ordinal == peer_ordinal)
2167
+ return 1;
2168
+
2169
+ int can_access = 0;
2170
+ check_cuda(cudaDeviceCanAccessPeer(&can_access, peer_ordinal, target_ordinal));
2171
+
2172
+ return can_access;
2173
+ }
2174
+
2175
+ int cuda_is_peer_access_enabled(void* target_context, void* peer_context)
2176
+ {
2177
+ if (!target_context || !peer_context)
2178
+ {
2179
+ fprintf(stderr, "Warp error: invalid CUDA context\n");
2180
+ return 0;
2181
+ }
2182
+
2183
+ if (target_context == peer_context)
2184
+ return 1;
2185
+
2186
+ int target_ordinal = cuda_context_get_device_ordinal(target_context);
2187
+ int peer_ordinal = cuda_context_get_device_ordinal(peer_context);
2188
+
2189
+ // check if peer access is supported
2190
+ int can_access = 0;
2191
+ check_cuda(cudaDeviceCanAccessPeer(&can_access, peer_ordinal, target_ordinal));
2192
+ if (!can_access)
2193
+ return 0;
2194
+
2195
+ // There is no CUDA API to query if peer access is enabled, but we can try to enable it and check the result.
2196
+
2197
+ ContextGuard guard(peer_context, true);
2198
+
2199
+ CUcontext target_ctx = static_cast<CUcontext>(target_context);
2200
+
2201
+ CUresult result = cuCtxEnablePeerAccess_f(target_ctx, 0);
2202
+ if (result == CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED)
2203
+ {
2204
+ return 1;
2205
+ }
2206
+ else if (result == CUDA_SUCCESS)
2207
+ {
2208
+ // undo enablement
2209
+ check_cu(cuCtxDisablePeerAccess_f(target_ctx));
2210
+ return 0;
2211
+ }
2212
+ else
2213
+ {
2214
+ // report error
2215
+ check_cu(result);
2216
+ return 0;
2217
+ }
2218
+ }
2219
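// --- Editor's note: illustrative sketch, not part of the package source. ---
// Since the driver offers no query for "is peer access currently enabled?",
// the function above probes by calling cuCtxEnablePeerAccess() and inspecting
// the result. The probe in isolation, assuming the peer context is current:

#include <cuda.h>

static bool is_peer_access_enabled_example(CUcontext target_ctx)
{
    CUresult res = cuCtxEnablePeerAccess(target_ctx, 0);
    if (res == CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED)
        return true;                         // it was already on
    if (res == CUDA_SUCCESS)
        cuCtxDisablePeerAccess(target_ctx);  // the probe just enabled it: undo
    return false;
}
// ---------------------------------------------------------------------------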
+
2220
+ int cuda_set_peer_access_enabled(void* target_context, void* peer_context, int enable)
2221
+ {
2222
+ if (!target_context || !peer_context)
2223
+ {
2224
+ fprintf(stderr, "Warp error: invalid CUDA context\n");
2225
+ return 0;
2226
+ }
2227
+
2228
+ if (target_context == peer_context)
2229
+ return 1; // no-op
2230
+
2231
+ int target_ordinal = cuda_context_get_device_ordinal(target_context);
2232
+ int peer_ordinal = cuda_context_get_device_ordinal(peer_context);
2233
+
2234
+ // check if peer access is supported
2235
+ int can_access = 0;
2236
+ check_cuda(cudaDeviceCanAccessPeer(&can_access, peer_ordinal, target_ordinal));
2237
+ if (!can_access)
2238
+ {
2239
+ // failure if enabling, success if disabling
2240
+ if (enable)
2241
+ {
2242
+ fprintf(stderr, "Warp error: device %d cannot access device %d\n", peer_ordinal, target_ordinal);
2243
+ return 0;
2244
+ }
2245
+ else
2246
+ return 1;
2247
+ }
2248
+
2249
+ ContextGuard guard(peer_context, true);
2250
+
2251
+ CUcontext target_ctx = static_cast<CUcontext>(target_context);
2252
+
2253
+ if (enable)
2254
+ {
2255
+ CUresult status = cuCtxEnablePeerAccess_f(target_ctx, 0);
2256
+ if (status != CUDA_SUCCESS && status != CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED)
2257
+ {
2258
+ check_cu(status);
2259
+ fprintf(stderr, "Warp error: failed to enable peer access from device %d to device %d\n", peer_ordinal, target_ordinal);
2260
+ return 0;
2261
+ }
2262
+ }
2263
+ else
2264
+ {
2265
+ CUresult status = cuCtxDisablePeerAccess_f(target_ctx);
2266
+ if (status != CUDA_SUCCESS && status != CUDA_ERROR_PEER_ACCESS_NOT_ENABLED)
2267
+ {
2268
+ check_cu(status);
2269
+ fprintf(stderr, "Warp error: failed to disable peer access from device %d to device %d\n", peer_ordinal, target_ordinal);
2270
+ return 0;
2271
+ }
2272
+ }
2273
+
2274
+ return 1; // success
2275
+ }
2276
+
2277
+ int cuda_is_mempool_access_enabled(int target_ordinal, int peer_ordinal)
2278
+ {
2279
+ int num_devices = int(g_devices.size());
2280
+
2281
+ if (target_ordinal < 0 || target_ordinal >= num_devices)
2282
+ {
2283
+ fprintf(stderr, "Warp error: Invalid device ordinal %d\n", target_ordinal);
2284
+ return 0;
2285
+ }
2286
+
2287
+ if (peer_ordinal < 0 || peer_ordinal >= num_devices)
2288
+ {
2289
+ fprintf(stderr, "Warp error: Invalid peer device ordinal %d\n", peer_ordinal);
2290
+ return 0;
2291
+ }
2292
+
2293
+ if (target_ordinal == peer_ordinal)
2294
+ return 1;
2295
+
2296
+ cudaMemPool_t pool;
2297
+ if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, target_ordinal)))
2298
+ {
2299
+ fprintf(stderr, "Warp error: Failed to get memory pool of device %d\n", target_ordinal);
2300
+ return 0;
2301
+ }
2302
+
2303
+ cudaMemAccessFlags flags = cudaMemAccessFlagsProtNone;
2304
+ cudaMemLocation location;
2305
+ location.id = peer_ordinal;
2306
+ location.type = cudaMemLocationTypeDevice;
2307
+ if (check_cuda(cudaMemPoolGetAccess(&flags, pool, &location)))
2308
+ return int(flags != cudaMemAccessFlagsProtNone);
2309
+
2310
+ return 0;
2311
+ }
2312
+
2313
+ int cuda_set_mempool_access_enabled(int target_ordinal, int peer_ordinal, int enable)
2314
+ {
2315
+ int num_devices = int(g_devices.size());
2316
+
2317
+ if (target_ordinal < 0 || target_ordinal >= num_devices)
2318
+ {
2319
+ fprintf(stderr, "Warp error: Invalid device ordinal %d\n", target_ordinal);
2320
+ return 0;
2321
+ }
2322
+
2323
+ if (peer_ordinal < 0 || peer_ordinal >= num_devices)
2324
+ {
2325
+ fprintf(stderr, "Warp error: Invalid peer device ordinal %d\n", peer_ordinal);
2326
+ return 0;
2327
+ }
2328
+
2329
+ if (target_ordinal == peer_ordinal)
2330
+ return 1; // no-op
2331
+
2332
+ // get the memory pool
2333
+ cudaMemPool_t pool;
2334
+ if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, target_ordinal)))
2335
+ {
2336
+ fprintf(stderr, "Warp error: Failed to get memory pool of device %d\n", target_ordinal);
2337
+ return 0;
2338
+ }
2339
+
2340
+ cudaMemAccessDesc desc;
2341
+ desc.location.type = cudaMemLocationTypeDevice;
2342
+ desc.location.id = peer_ordinal;
2343
+
2344
+ // only cudaMemAccessFlagsProtReadWrite and cudaMemAccessFlagsProtNone are supported
2345
+ if (enable)
2346
+ desc.flags = cudaMemAccessFlagsProtReadWrite;
2347
+ else
2348
+ desc.flags = cudaMemAccessFlagsProtNone;
2349
+
2350
+ if (!check_cuda(cudaMemPoolSetAccess(pool, &desc, 1)))
2351
+ {
2352
+ fprintf(stderr, "Warp error: Failed to set mempool access from device %d to device %d\n", peer_ordinal, target_ordinal);
2353
+ return 0;
2354
+ }
2355
+
2356
+ return 1; // success
2357
+ }
2358
+
2359
+ void cuda_ipc_get_mem_handle(void* ptr, char* out_buffer) {
2360
+ CUipcMemHandle memHandle;
2361
+ check_cu(cuIpcGetMemHandle_f(&memHandle, (CUdeviceptr)ptr));
2362
+ memcpy(out_buffer, memHandle.reserved, CU_IPC_HANDLE_SIZE);
2363
+ }
2364
+
2365
+ void* cuda_ipc_open_mem_handle(void* context, char* handle) {
2366
+ ContextGuard guard(context);
2367
+
2368
+ CUipcMemHandle memHandle;
2369
+ memcpy(memHandle.reserved, handle, CU_IPC_HANDLE_SIZE);
2370
+
2371
+ CUdeviceptr device_ptr;
2372
+
2373
+ // Strangely, the CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS flag is required
2374
+ if (check_cu(cuIpcOpenMemHandle_f(&device_ptr, memHandle, CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS)))
2375
+ return (void*) device_ptr;
2376
+ else
2377
+ return NULL;
2378
+ }
2379
+
2380
+ void cuda_ipc_close_mem_handle(void* ptr) {
2381
+ check_cu(cuIpcCloseMemHandle_f((CUdeviceptr) ptr));
2382
+ }
2383
+
2384
+ void cuda_ipc_get_event_handle(void* context, void* event, char* out_buffer) {
2385
+ ContextGuard guard(context);
2386
+
2387
+ CUipcEventHandle eventHandle;
2388
+ check_cu(cuIpcGetEventHandle_f(&eventHandle, static_cast<CUevent>(event)));
2389
+ memcpy(out_buffer, eventHandle.reserved, CU_IPC_HANDLE_SIZE);
2390
+ }
2391
+
2392
+ void* cuda_ipc_open_event_handle(void* context, char* handle) {
2393
+ ContextGuard guard(context);
2394
+
2395
+ CUipcEventHandle eventHandle;
2396
+ memcpy(eventHandle.reserved, handle, CU_IPC_HANDLE_SIZE);
2397
+
2398
+ CUevent event;
2399
+
2400
+ if (check_cu(cuIpcOpenEventHandle_f(&event, eventHandle)))
2401
+ return event;
2402
+ else
2403
+ return NULL;
2404
+ }
2405
+
2406
+ void* cuda_stream_create(void* context, int priority)
2407
+ {
2408
+ ContextGuard guard(context, true);
2409
+
2410
+ CUstream stream;
2411
+ if (check_cu(cuStreamCreateWithPriority_f(&stream, CU_STREAM_DEFAULT, priority)))
2412
+ {
2413
+ cuda_stream_register(WP_CURRENT_CONTEXT, stream);
2414
+ return stream;
2415
+ }
2416
+ else
2417
+ return NULL;
2418
+ }
2419
+
2420
+ void cuda_stream_destroy(void* context, void* stream)
2421
+ {
2422
+ if (!stream)
2423
+ return;
2424
+
2425
+ cuda_stream_unregister(context, stream);
2426
+
2427
+ check_cu(cuStreamDestroy_f(static_cast<CUstream>(stream)));
2428
+ }
2429
+
2430
+ int cuda_stream_query(void* stream)
2431
+ {
2432
+ CUresult res = cuStreamQuery_f(static_cast<CUstream>(stream));
2433
+
2434
+ if ((res != CUDA_SUCCESS) && (res != CUDA_ERROR_NOT_READY))
2435
+ {
2436
+ // Abnormal, print out error
2437
+ check_cu(res);
2438
+ }
2439
+
2440
+ return res;
2441
+ }
2442
+
2443
+ void cuda_stream_register(void* context, void* stream)
2444
+ {
2445
+ if (!stream)
2446
+ return;
2447
+
2448
+ ContextGuard guard(context);
2449
+
2450
+ // populate stream info
2451
+ StreamInfo& stream_info = g_streams[static_cast<CUstream>(stream)];
2452
+ check_cu(cuEventCreate_f(&stream_info.cached_event, CU_EVENT_DISABLE_TIMING));
2453
+ }
2454
+
2455
+ void cuda_stream_unregister(void* context, void* stream)
2456
+ {
2457
+ if (!stream)
2458
+ return;
2459
+
2460
+ CUstream cuda_stream = static_cast<CUstream>(stream);
2461
+
2462
+ StreamInfo* stream_info = get_stream_info(cuda_stream);
2463
+ if (stream_info)
2464
+ {
2465
+ // release stream info
2466
+ check_cu(cuEventDestroy_f(stream_info->cached_event));
2467
+ g_streams.erase(cuda_stream);
2468
+ }
2469
+
2470
+ // make sure we don't leave dangling references to this stream
2471
+ ContextInfo* context_info = get_context_info(context);
2472
+ if (context_info)
2473
+ {
2474
+ if (cuda_stream == context_info->stream)
2475
+ context_info->stream = NULL;
2476
+ }
2477
+ }
2478
+
2479
+ void* cuda_stream_get_current()
2480
+ {
2481
+ return get_current_stream();
2482
+ }
2483
+
2484
+ void cuda_stream_synchronize(void* stream)
2485
+ {
2486
+ check_cu(cuStreamSynchronize_f(static_cast<CUstream>(stream)));
2487
+ }
2488
+
2489
+ void cuda_stream_wait_event(void* stream, void* event)
2490
+ {
2491
+ check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), 0));
2492
+ }
2493
+
2494
+ void cuda_stream_wait_stream(void* stream, void* other_stream, void* event)
2495
+ {
2496
+ check_cu(cuEventRecord_f(static_cast<CUevent>(event), static_cast<CUstream>(other_stream)));
2497
+ check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), 0));
2498
+ }
2499
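// --- Editor's note: illustrative sketch, not part of the package source. ---
// cuda_stream_wait_stream() orders one stream behind another by recording an
// event on the producer stream and making the consumer stream wait on it.
// The same pattern with the CUDA runtime API (hypothetical names):

#include <cuda_runtime.h>

static void stream_wait_stream_example(cudaStream_t consumer, cudaStream_t producer)
{
    cudaEvent_t ev = nullptr;
    cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
    cudaEventRecord(ev, producer);          // marks the point to wait for
    cudaStreamWaitEvent(consumer, ev, 0);   // consumer work now runs after it
    cudaEventDestroy(ev);                   // safe: the wait was already enqueued
}
// ---------------------------------------------------------------------------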
+
2500
+ int cuda_stream_is_capturing(void* stream)
2501
+ {
2502
+ cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
2503
+ check_cuda(cudaStreamIsCapturing(static_cast<cudaStream_t>(stream), &status));
2504
+
2505
+ return int(status != cudaStreamCaptureStatusNone);
2506
+ }
2507
+
2508
+ uint64_t cuda_stream_get_capture_id(void* stream)
2509
+ {
2510
+ return get_capture_id(static_cast<CUstream>(stream));
2511
+ }
2512
+
2513
+ int cuda_stream_get_priority(void* stream)
2514
+ {
2515
+ int priority = 0;
2516
+ check_cu(cuStreamGetPriority_f(static_cast<CUstream>(stream), &priority));
2517
+
2518
+ return priority;
2519
+ }
2520
+
2521
+ void* cuda_event_create(void* context, unsigned flags)
2522
+ {
2523
+ ContextGuard guard(context, true);
2524
+
2525
+ CUevent event;
2526
+ if (check_cu(cuEventCreate_f(&event, flags)))
2527
+ return event;
2528
+ else
2529
+ return NULL;
2530
+ }
2531
+
2532
+ void cuda_event_destroy(void* event)
2533
+ {
2534
+ check_cu(cuEventDestroy_f(static_cast<CUevent>(event)));
2535
+ }
2536
+
2537
+ int cuda_event_query(void* event)
2538
+ {
2539
+ CUresult res = cuEventQuery_f(static_cast<CUevent>(event));
2540
+
2541
+ if ((res != CUDA_SUCCESS) && (res != CUDA_ERROR_NOT_READY))
2542
+ {
2543
+ // Abnormal, print out error
2544
+ check_cu(res);
2545
+ }
2546
+
2547
+ return res;
2548
+ }
2549
+
2550
+ void cuda_event_record(void* event, void* stream, bool timing)
2551
+ {
2552
+ if (timing && !g_captures.empty() && cuda_stream_is_capturing(stream))
2553
+ {
2554
+ // record timing event during graph capture
2555
+ check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(stream), CU_EVENT_RECORD_EXTERNAL));
2556
+ }
2557
+ else
2558
+ {
2559
+ check_cu(cuEventRecord_f(static_cast<CUevent>(event), static_cast<CUstream>(stream)));
2560
+ }
2561
+ }
2562
+
2563
+ void cuda_event_synchronize(void* event)
2564
+ {
2565
+ check_cu(cuEventSynchronize_f(static_cast<CUevent>(event)));
2566
+ }
2567
+
2568
+ float cuda_event_elapsed_time(void* start_event, void* end_event)
2569
+ {
2570
+ float elapsed = 0.0f;
2571
+ cudaEvent_t start = static_cast<cudaEvent_t>(start_event);
2572
+ cudaEvent_t end = static_cast<cudaEvent_t>(end_event);
2573
+ check_cuda(cudaEventElapsedTime(&elapsed, start, end));
2574
+ return elapsed;
2575
+ }
2576
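// --- Editor's note: illustrative sketch, not part of the package source. ---
// cuda_event_elapsed_time() requires both events to have been created with
// timing enabled (i.e. without cudaEventDisableTiming) and already recorded.
// A typical standalone timing pattern (hypothetical names):

#include <cuda_runtime.h>
#include <cstdio>

static void time_section_example(cudaStream_t stream)
{
    cudaEvent_t start = nullptr, stop = nullptr;
    cudaEventCreate(&start);                 // timing enabled by default
    cudaEventCreate(&stop);

    cudaEventRecord(start, stream);
    // ... enqueue the work to be timed on `stream` ...
    cudaEventRecord(stop, stream);

    cudaEventSynchronize(stop);              // wait until `stop` has been reached
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    printf("section took %.3f ms\n", ms);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
}
// ---------------------------------------------------------------------------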
+
2577
+ bool cuda_graph_begin_capture(void* context, void* stream, int external)
2578
+ {
2579
+ ContextGuard guard(context);
2580
+
2581
+ CUstream cuda_stream = static_cast<CUstream>(stream);
2582
+ StreamInfo* stream_info = get_stream_info(cuda_stream);
2583
+ if (!stream_info)
2584
+ {
2585
+ wp::set_error_string("Warp error: unknown stream");
2586
+ return false;
2587
+ }
2588
+
2589
+ if (external)
2590
+ {
2591
+ // if it's an external capture, make sure it's already active so we can get the capture id
2592
+ cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
2593
+ if (!check_cuda(cudaStreamIsCapturing(cuda_stream, &status)))
2594
+ return false;
2595
+ if (status != cudaStreamCaptureStatusActive)
2596
+ {
2597
+ wp::set_error_string("Warp error: stream is not capturing");
2598
+ return false;
2599
+ }
2600
+ }
2601
+ else
2602
+ {
2603
+ // start the capture
2604
+ if (!check_cuda(cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeGlobal)))
2605
+ return false;
2606
+ }
2607
+
2608
+ uint64_t capture_id = get_capture_id(cuda_stream);
2609
+
2610
+ CaptureInfo* capture = new CaptureInfo();
2611
+ capture->stream = cuda_stream;
2612
+ capture->id = capture_id;
2613
+ capture->external = bool(external);
2614
+
2615
+ // update stream info
2616
+ stream_info->capture = capture;
2617
+
2618
+ // add to known captures
2619
+ g_captures[capture_id] = capture;
2620
+
2621
+ return true;
2622
+ }
2623
+
2624
+ bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
2625
+ {
2626
+ ContextGuard guard(context);
2627
+
2628
+ // check if this is a known stream
2629
+ CUstream cuda_stream = static_cast<CUstream>(stream);
2630
+ StreamInfo* stream_info = get_stream_info(cuda_stream);
2631
+ if (!stream_info)
2632
+ {
2633
+ wp::set_error_string("Warp error: unknown capture stream");
2634
+ return false;
2635
+ }
2636
+
2637
+ // check if this stream was used to start a capture
2638
+ CaptureInfo* capture = stream_info->capture;
2639
+ if (!capture)
2640
+ {
2641
+ wp::set_error_string("Warp error: stream has no capture started");
2642
+ return false;
2643
+ }
2644
+
2645
+ // get capture info
2646
+ bool external = capture->external;
2647
+ uint64_t capture_id = capture->id;
2648
+
2649
+ // clear capture info
2650
+ stream_info->capture = NULL;
2651
+ g_captures.erase(capture_id);
2652
+ delete capture;
2653
+
2654
+ // a lambda to clean up on exit in case of error
2655
+ auto clean_up = [cuda_stream, capture_id, external]()
2656
+ {
2657
+ // unreference outstanding graph allocs so that they will be released with the user reference
2658
+ for (auto it = g_graph_allocs.begin(); it != g_graph_allocs.end(); ++it)
2659
+ {
2660
+ GraphAllocInfo& alloc_info = it->second;
2661
+ if (alloc_info.capture_id == capture_id)
2662
+ alloc_info.graph_destroyed = true;
2663
+ }
2664
+
2665
+ // make sure we terminate the capture
2666
+ if (!external)
2667
+ {
2668
+ cudaGraph_t graph = NULL;
2669
+ cudaStreamEndCapture(cuda_stream, &graph);
2670
+ cudaGetLastError();
2671
+ }
2672
+ };
2673
+
2674
+ // get captured graph without ending the capture in case it is external
2675
+ cudaGraph_t graph = get_capture_graph(cuda_stream);
2676
+ if (!graph)
2677
+ {
2678
+ clean_up();
2679
+ return false;
2680
+ }
2681
+
2682
+ // ensure that all forked streams are joined to the main capture stream by manually
2683
+ // adding outstanding capture dependencies gathered from the graph leaf nodes
2684
+ std::vector<cudaGraphNode_t> stream_dependencies;
2685
+ std::vector<cudaGraphNode_t> leaf_nodes;
2686
+ if (get_capture_dependencies(cuda_stream, stream_dependencies) && get_graph_leaf_nodes(graph, leaf_nodes))
2687
+ {
2688
+ // compute set difference to get unjoined dependencies
2689
+ std::vector<cudaGraphNode_t> unjoined_dependencies;
2690
+ std::sort(stream_dependencies.begin(), stream_dependencies.end());
2691
+ std::sort(leaf_nodes.begin(), leaf_nodes.end());
2692
+ std::set_difference(leaf_nodes.begin(), leaf_nodes.end(),
2693
+ stream_dependencies.begin(), stream_dependencies.end(),
2694
+ std::back_inserter(unjoined_dependencies));
2695
+ if (!unjoined_dependencies.empty())
2696
+ {
2697
+ check_cu(cuStreamUpdateCaptureDependencies_f(cuda_stream, unjoined_dependencies.data(), unjoined_dependencies.size(),
2698
+ CU_STREAM_ADD_CAPTURE_DEPENDENCIES));
2699
+ // ensure graph is still valid
2700
+ if (get_capture_graph(cuda_stream) != graph)
2701
+ {
2702
+ clean_up();
2703
+ return false;
2704
+ }
2705
+ }
2706
+ }
2707
+
2708
+ // check if this graph has unfreed allocations, which require special handling
2709
+ std::vector<void*> unfreed_allocs;
2710
+ for (auto it = g_graph_allocs.begin(); it != g_graph_allocs.end(); ++it)
2711
+ {
2712
+ GraphAllocInfo& alloc_info = it->second;
2713
+ if (alloc_info.capture_id == capture_id)
2714
+ unfreed_allocs.push_back(it->first);
2715
+ }
2716
+
2717
+ if (!unfreed_allocs.empty())
2718
+ {
2719
+ // Create a user object that will notify us when the instantiated graph is destroyed.
2720
+ // This works for external captures also, since we wouldn't otherwise know when
2721
+ // the externally-created graph instance gets deleted.
2722
+ // This callback is guaranteed to arrive after the graph has finished executing on the device,
2723
+ // not necessarily when cudaGraphExecDestroy() is called.
2724
+ GraphInfo* graph_info = new GraphInfo;
2725
+ graph_info->unfreed_allocs = unfreed_allocs;
2726
+ cudaUserObject_t user_object;
2727
+ check_cuda(cudaUserObjectCreate(&user_object, graph_info, on_graph_destroy, 1, cudaUserObjectNoDestructorSync));
2728
+ check_cuda(cudaGraphRetainUserObject(graph, user_object, 1, cudaGraphUserObjectMove));
2729
+
2730
+ // ensure graph is still valid
2731
+ if (get_capture_graph(cuda_stream) != graph)
2732
+ {
2733
+ clean_up();
2734
+ return false;
2735
+ }
2736
+ }
2737
+
2738
+ // for external captures, we don't instantiate the graph ourselves, so we're done
2739
+ if (external)
2740
+ return true;
2741
+
2742
+ cudaGraphExec_t graph_exec = NULL;
2743
+
2744
+ // end the capture
2745
+ if (!check_cuda(cudaStreamEndCapture(cuda_stream, &graph)))
2746
+ return false;
2747
+
2748
+ // enable to create debug GraphVis visualization of graph
2749
+ // cudaGraphDebugDotPrint(graph, "graph.dot", cudaGraphDebugDotFlagsVerbose);
2750
+
2751
+ // available from CUDA 11.4 onward; permits graphs to capture cudaMallocAsync() operations
2752
+ if (!check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch)))
2753
+ return false;
2754
+
2755
+ // free source graph
2756
+ check_cuda(cudaGraphDestroy(graph));
2757
+
2758
+ // process deferred free list if no more captures are ongoing
2759
+ if (g_captures.empty())
2760
+ {
2761
+ free_deferred_allocs();
2762
+ unload_deferred_modules();
2763
+ }
2764
+
2765
+ if (graph_ret)
2766
+ *graph_ret = graph_exec;
2767
+
2768
+ return true;
2769
+ }
2770
+
2771
+ bool cuda_graph_launch(void* graph_exec, void* stream)
2772
+ {
2773
+ // TODO: allow naming graphs?
2774
+ begin_cuda_range(WP_TIMING_GRAPH, stream, get_stream_context(stream), "graph");
2775
+
2776
+ bool result = check_cuda(cudaGraphLaunch((cudaGraphExec_t)graph_exec, (cudaStream_t)stream));
2777
+
2778
+ end_cuda_range(WP_TIMING_GRAPH, stream);
2779
+
2780
+ return result;
2781
+ }
2782
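// --- Editor's note: illustrative sketch, not part of the package source. ---
// The begin/end-capture functions above wrap the usual CUDA graph lifecycle:
// capture work on a stream, instantiate the captured graph, then replay it.
// A minimal standalone version (hypothetical names, no error handling):

#include <cuda_runtime.h>

static cudaGraphExec_t capture_and_instantiate_example(cudaStream_t stream)
{
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    // ... enqueue kernels / memcpys on `stream`; nothing executes yet ...
    cudaGraph_t graph = nullptr;
    cudaStreamEndCapture(stream, &graph);

    cudaGraphExec_t graph_exec = nullptr;
    cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch);
    cudaGraphDestroy(graph);                 // the executable graph keeps what it needs

    cudaGraphLaunch(graph_exec, stream);     // replay the captured work
    return graph_exec;                       // destroy later with cudaGraphExecDestroy()
}
// ---------------------------------------------------------------------------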
+
2783
+ bool cuda_graph_destroy(void* context, void* graph_exec)
2784
+ {
2785
+ ContextGuard guard(context);
2786
+
2787
+ return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
2788
+ }
2789
+
2790
+ bool write_file(const char* data, size_t size, std::string filename, const char* mode)
2791
+ {
2792
+ const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);
2793
+ if (print_debug)
2794
+ {
2795
+ printf("Writing %zu B to %s (%s)\n", size, filename.c_str(), mode);
2796
+ }
2797
+ FILE* file = fopen(filename.c_str(), mode);
2798
+ if (file)
2799
+ {
2800
+ if (fwrite(data, 1, size, file) != size) {
2801
+ fprintf(stderr, "Warp error: Failed to write to output file '%s'\n", filename.c_str());
2802
+ fclose(file);
+ return false;
2803
+ }
2804
+ fclose(file);
2805
+ return true;
2806
+ }
2807
+ else
2808
+ {
2809
+ fprintf(stderr, "Warp error: Failed to open file '%s'\n", filename.c_str());
2810
+ return false;
2811
+ }
2812
+ }
2813
+
2814
+ #if WP_ENABLE_MATHDX
2815
+ bool check_nvjitlink_result(nvJitLinkHandle handle, nvJitLinkResult result, const char* file, int line)
2816
+ {
2817
+ if (result != NVJITLINK_SUCCESS) {
2818
+ fprintf(stderr, "nvJitLink error: %d on %s:%d\n", (int)result, file, line);
2819
+ size_t lsize;
2820
+ result = nvJitLinkGetErrorLogSize(handle, &lsize);
2821
+ if (result == NVJITLINK_SUCCESS && lsize > 0) {
2822
+ std::vector<char> log(lsize);
2823
+ result = nvJitLinkGetErrorLog(handle, log.data());
2824
+ if (result == NVJITLINK_SUCCESS) {
2825
+ fprintf(stderr, "%s\n", log.data());
2826
+ }
2827
+ }
2828
+ return false;
2829
+ } else {
2830
+ return true;
2831
+ }
2832
+ }
2833
+ #endif
2834
+
2835
+ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes, int* ltoir_input_types)
2836
+ {
2837
+ // use file extension to determine whether to output PTX or CUBIN
2838
+ const char* output_ext = strrchr(output_path, '.');
2839
+ bool use_ptx = output_ext && strcmp(output_ext + 1, "ptx") == 0;
2840
+ const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);
2841
+
2842
+ // check include dir path len (path + option)
2843
+ const int max_path = 4096 + 16;
2844
+        if (strlen(include_dir) + 16 > max_path)  // leave room for the "--include-path=" prefix
2845
+ {
2846
+ fprintf(stderr, "Warp error: Include path too long\n");
2847
+ return size_t(-1);
2848
+ }
2849
+
2850
+ if (print_debug)
2851
+ {
2852
+ // Not available in all nvJitLink versions
2853
+ // unsigned major = 0;
2854
+ // unsigned minor = 0;
2855
+ // nvJitLinkVersion(&major, &minor);
2856
+ // printf("nvJitLink version %d.%d\n", major, minor);
2857
+ int major = 0;
2858
+ int minor = 0;
2859
+ nvrtcVersion(&major, &minor);
2860
+ printf("NVRTC version %d.%d\n", major, minor);
2861
+ }
2862
+
2863
+ char include_opt[max_path];
2864
+ strcpy(include_opt, "--include-path=");
2865
+ strcat(include_opt, include_dir);
2866
+
2867
+ const int max_arch = 128;
2868
+ char arch_opt[max_arch];
2869
+ char arch_opt_lto[max_arch];
2870
+
2871
+ if (use_ptx)
2872
+ {
2873
+ snprintf(arch_opt, max_arch, "--gpu-architecture=compute_%d", arch);
2874
+ snprintf(arch_opt_lto, max_arch, "-arch=compute_%d", arch);
2875
+ }
2876
+ else
2877
+ {
2878
+ snprintf(arch_opt, max_arch, "--gpu-architecture=sm_%d", arch);
2879
+ snprintf(arch_opt_lto, max_arch, "-arch=sm_%d", arch);
2880
+ }
2881
+
2882
+ std::vector<const char*> opts;
2883
+ opts.push_back(arch_opt);
2884
+ opts.push_back(include_opt);
2885
+ opts.push_back("--std=c++17");
2886
+
2887
+ if (debug)
2888
+ {
2889
+ opts.push_back("--define-macro=_DEBUG");
2890
+ opts.push_back("--generate-line-info");
2891
+
2892
+            // disabled since it triggers "Unresolved extern function 'cudaGetParameterBufferV2'" errors
2893
+ //opts.push_back("--device-debug");
2894
+ }
2895
+ else
2896
+ {
2897
+ opts.push_back("--define-macro=NDEBUG");
2898
+
2899
+ if (lineinfo)
2900
+ opts.push_back("--generate-line-info");
2901
+ }
2902
+
2903
+ if (verify_fp)
2904
+ opts.push_back("--define-macro=WP_VERIFY_FP");
2905
+ else
2906
+ opts.push_back("--undefine-macro=WP_VERIFY_FP");
2907
+
2908
+ #if WP_ENABLE_MATHDX
2909
+ opts.push_back("--define-macro=WP_ENABLE_MATHDX=1");
2910
+ #else
2911
+ opts.push_back("--define-macro=WP_ENABLE_MATHDX=0");
2912
+ #endif
2913
+
2914
+ if (fast_math)
2915
+ opts.push_back("--use_fast_math");
2916
+
2917
+ if (fuse_fp)
2918
+ opts.push_back("--fmad=true");
2919
+ else
2920
+ opts.push_back("--fmad=false");
2921
+
2922
+ std::vector<std::string> cuda_include_opt;
2923
+ for(int i = 0; i < num_cuda_include_dirs; i++)
2924
+ {
2925
+ cuda_include_opt.push_back(std::string("--include-path=") + cuda_include_dirs[i]);
2926
+ opts.push_back(cuda_include_opt.back().c_str());
2927
+ }
2928
+
2929
+ opts.push_back("--device-as-default-execution-space");
2930
+ opts.push_back("--extra-device-vectorization");
2931
+ opts.push_back("--restrict");
2932
+
2933
+ if (num_ltoirs > 0)
2934
+ {
2935
+ opts.push_back("-dlto");
2936
+ opts.push_back("--relocatable-device-code=true");
2937
+ }
2938
+
2939
+ nvrtcProgram prog;
2940
+ nvrtcResult res;
2941
+
2942
+ res = nvrtcCreateProgram(
2943
+ &prog, // prog
2944
+ cuda_src, // buffer
2945
+ program_name, // name
2946
+ 0, // numHeaders
2947
+ NULL, // headers
2948
+ NULL); // includeNames
2949
+
2950
+ if (!check_nvrtc(res))
2951
+ return size_t(res);
2952
+
2953
+ if (print_debug)
2954
+ {
2955
+ printf("NVRTC options:\n");
2956
+ for(auto o: opts) {
2957
+ printf("%s\n", o);
2958
+ }
2959
+ }
2960
+ res = nvrtcCompileProgram(prog, int(opts.size()), opts.data());
2961
+
2962
+ if (!check_nvrtc(res) || verbose)
2963
+ {
2964
+ // get program log
2965
+ size_t log_size;
2966
+ if (check_nvrtc(nvrtcGetProgramLogSize(prog, &log_size)))
2967
+ {
2968
+ std::vector<char> log(log_size);
2969
+ if (check_nvrtc(nvrtcGetProgramLog(prog, log.data())))
2970
+ {
2971
+ // todo: figure out better way to return this to python
2972
+ if (res != NVRTC_SUCCESS)
2973
+ fprintf(stderr, "%s", log.data());
2974
+ else
2975
+ fprintf(stdout, "%s", log.data());
2976
+ }
2977
+ }
2978
+
2979
+ if (res != NVRTC_SUCCESS)
2980
+ {
2981
+ nvrtcDestroyProgram(&prog);
2982
+ return size_t(res);
2983
+ }
2984
+ }
2985
+
2986
+ nvrtcResult (*get_output_size)(nvrtcProgram, size_t*);
2987
+ nvrtcResult (*get_output_data)(nvrtcProgram, char*);
2988
+ const char* output_mode;
2989
+ if(num_ltoirs > 0) {
2990
+ #if WP_ENABLE_MATHDX
2991
+ get_output_size = nvrtcGetLTOIRSize;
2992
+ get_output_data = nvrtcGetLTOIR;
2993
+ output_mode = "wb";
2994
+ #else
2995
+ fprintf(stderr, "Warp error: num_ltoirs > 0 but Warp was not built with MathDx support\n");
2996
+ return size_t(-1);
2997
+ #endif
2998
+ }
2999
+ else if (use_ptx)
3000
+ {
3001
+ get_output_size = nvrtcGetPTXSize;
3002
+ get_output_data = nvrtcGetPTX;
3003
+ output_mode = "wt";
3004
+ }
3005
+ else
3006
+ {
3007
+ get_output_size = nvrtcGetCUBINSize;
3008
+ get_output_data = nvrtcGetCUBIN;
3009
+ output_mode = "wb";
3010
+ }
3011
+
3012
+ // save output
3013
+ size_t output_size;
3014
+ res = get_output_size(prog, &output_size);
3015
+ if (check_nvrtc(res))
3016
+ {
3017
+ std::vector<char> output(output_size);
3018
+ res = get_output_data(prog, output.data());
3019
+ if (check_nvrtc(res))
3020
+ {
3021
+
3022
+ // LTOIR case - need an extra step
3023
+ if (num_ltoirs > 0)
3024
+ {
3025
+ #if WP_ENABLE_MATHDX
3026
+ if(ltoir_input_types == nullptr || ltoirs == nullptr || ltoir_sizes == nullptr) {
3027
+ fprintf(stderr, "Warp error: num_ltoirs > 0 but ltoir_input_types, ltoirs or ltoir_sizes are NULL\n");
3028
+ return size_t(-1);
3029
+ }
3030
+ nvJitLinkHandle handle;
3031
+ std::vector<const char *> lopts = {"-dlto", arch_opt_lto};
3032
+ if (use_ptx) {
3033
+ lopts.push_back("-ptx");
3034
+ }
3035
+ if (print_debug)
3036
+ {
3037
+ printf("nvJitLink options:\n");
3038
+ for(auto o: lopts) {
3039
+ printf("%s\n", o);
3040
+ }
3041
+ }
3042
+ if(!check_nvjitlink(handle, nvJitLinkCreate(&handle, lopts.size(), lopts.data())))
3043
+ {
3044
+ res = nvrtcResult(-1);
3045
+ }
3046
+                    // Link the NVRTC output with the supplied LTOIR inputs
3047
+ if(std::getenv("WARP_DUMP_LTOIR"))
3048
+ {
3049
+ write_file(output.data(), output.size(), "nvrtc_output.ltoir", "wb");
3050
+ }
3051
+ if(!check_nvjitlink(handle, nvJitLinkAddData(handle, NVJITLINK_INPUT_LTOIR, output.data(), output.size(), "nvrtc_output"))) // NVRTC business
3052
+ {
3053
+ res = nvrtcResult(-1);
3054
+ }
3055
+ for(size_t ltoidx = 0; ltoidx < num_ltoirs; ltoidx++)
3056
+ {
3057
+ nvJitLinkInputType input_type = static_cast<nvJitLinkInputType>(ltoir_input_types[ltoidx]);
3058
+ const char* ext = ".unknown";
3059
+ switch(input_type) {
3060
+ case NVJITLINK_INPUT_CUBIN:
3061
+ ext = ".cubin";
3062
+ break;
3063
+ case NVJITLINK_INPUT_LTOIR:
3064
+ ext = ".ltoir";
3065
+ break;
3066
+ case NVJITLINK_INPUT_FATBIN:
3067
+ ext = ".fatbin";
3068
+ break;
3069
+ default:
3070
+ break;
3071
+ }
3072
+ if(std::getenv("WARP_DUMP_LTOIR"))
3073
+ {
3074
+ write_file(ltoirs[ltoidx], ltoir_sizes[ltoidx], std::string("lto_online_") + std::to_string(ltoidx) + ext, "wb");
3075
+ }
3076
+ if(!check_nvjitlink(handle, nvJitLinkAddData(handle, input_type, ltoirs[ltoidx], ltoir_sizes[ltoidx], "lto_online"))) // External LTOIR
3077
+ {
3078
+ res = nvrtcResult(-1);
3079
+ }
3080
+ }
3081
+ if(!check_nvjitlink(handle, nvJitLinkComplete(handle)))
3082
+ {
3083
+ res = nvrtcResult(-1);
3084
+ }
3085
+ else
3086
+ {
3087
+ if(use_ptx)
3088
+ {
3089
+ size_t ptx_size = 0;
3090
+ check_nvjitlink(handle, nvJitLinkGetLinkedPtxSize(handle, &ptx_size));
3091
+ std::vector<char> ptx(ptx_size);
3092
+ check_nvjitlink(handle, nvJitLinkGetLinkedPtx(handle, ptx.data()));
3093
+ output = ptx;
3094
+ }
3095
+ else
3096
+ {
3097
+ size_t cubin_size = 0;
3098
+ check_nvjitlink(handle, nvJitLinkGetLinkedCubinSize(handle, &cubin_size));
3099
+ std::vector<char> cubin(cubin_size);
3100
+ check_nvjitlink(handle, nvJitLinkGetLinkedCubin(handle, cubin.data()));
3101
+ output = cubin;
3102
+ }
3103
+ }
3104
+ check_nvjitlink(handle, nvJitLinkDestroy(&handle));
3105
+ #else
3106
+ fprintf(stderr, "Warp error: num_ltoirs > 0 but Warp was not built with MathDx support\n");
3107
+ return size_t(-1);
3108
+ #endif
3109
+ }
3110
+
3111
+ if(!write_file(output.data(), output.size(), output_path, output_mode)) {
3112
+ res = nvrtcResult(-1);
3113
+ }
3114
+ }
3115
+ }
3116
+
3117
+ check_nvrtc(nvrtcDestroyProgram(&prog));
3118
+
3119
+ return res;
3120
+ }
3121
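The heart of this function is a standard NVRTC round trip. A minimal, self-contained sketch of that flow follows; `compile_to_ptx` is an illustrative helper, not part of Warp, and it omits the include paths, LTO, and MathDx handling above:

```cpp
// Sketch only: compile a CUDA source string to PTX with NVRTC.
#include <nvrtc.h>
#include <cstdio>
#include <vector>

bool compile_to_ptx(const char* src, int arch, std::vector<char>& ptx)
{
    nvrtcProgram prog;
    if (nvrtcCreateProgram(&prog, src, "example.cu", 0, nullptr, nullptr) != NVRTC_SUCCESS)
        return false;

    char arch_opt[64];
    snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=compute_%d", arch);
    const char* opts[] = { arch_opt, "--std=c++17" };

    nvrtcResult res = nvrtcCompileProgram(prog, 2, opts);

    // the log holds warnings even on success
    size_t log_size = 0;
    nvrtcGetProgramLogSize(prog, &log_size);
    if (log_size > 1)
    {
        std::vector<char> log(log_size);
        nvrtcGetProgramLog(prog, log.data());
        fprintf(stderr, "%s\n", log.data());
    }

    if (res == NVRTC_SUCCESS)
    {
        size_t ptx_size = 0;
        nvrtcGetPTXSize(prog, &ptx_size);
        ptx.resize(ptx_size);
        nvrtcGetPTX(prog, ptx.data());
    }

    nvrtcDestroyProgram(&prog);
    return res == NVRTC_SUCCESS;
}
```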
+
3122
+ #if WP_ENABLE_MATHDX
3123
+ bool check_cufftdx_result(commondxStatusType result, const char* file, int line)
3124
+ {
3125
+ if (result != commondxStatusType::COMMONDX_SUCCESS) {
3126
+ fprintf(stderr, "libmathdx cuFFTDx error: %d on %s:%d\n", (int)result, file, line);
3127
+ return false;
3128
+ } else {
3129
+ return true;
3130
+ }
3131
+ }
3132
+
3133
+ bool check_cublasdx_result(commondxStatusType result, const char* file, int line)
3134
+ {
3135
+ if (result != commondxStatusType::COMMONDX_SUCCESS) {
3136
+ fprintf(stderr, "libmathdx cuBLASDx error: %d on %s:%d\n", (int)result, file, line);
3137
+ return false;
3138
+ } else {
3139
+ return true;
3140
+ }
3141
+ }
3142
+
3143
+ bool check_cusolver_result(commondxStatusType result, const char* file, int line)
3144
+ {
3145
+ if (result != commondxStatusType::COMMONDX_SUCCESS) {
3146
+ fprintf(stderr, "libmathdx cuSOLVER error: %d on %s:%d\n", (int)result, file, line);
3147
+ return false;
3148
+ } else {
3149
+ return true;
3150
+ }
3151
+ }
3152
+
3153
+ bool cuda_compile_fft(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int size, int elements_per_thread, int direction, int precision, int* shared_memory_size)
3154
+ {
3155
+
3156
+ CHECK_ANY(ltoir_output_path != nullptr);
3157
+ CHECK_ANY(symbol_name != nullptr);
3158
+ CHECK_ANY(shared_memory_size != nullptr);
3159
+ // Includes currently unused
3160
+ CHECK_ANY(include_dirs == nullptr);
3161
+ CHECK_ANY(mathdx_include_dir == nullptr);
3162
+ CHECK_ANY(num_include_dirs == 0);
3163
+
3164
+ bool res = true;
3165
+ cufftdxHandle h;
3166
+ CHECK_CUFFTDX(cufftdxCreate(&h));
3167
+
3168
+ // CUFFTDX_API_BLOCK_LMEM means each thread starts with a subset of the data
3169
+ CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_API, cufftdxApi::CUFFTDX_API_BLOCK_LMEM));
3170
+ CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
3171
+ CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_SIZE, (long long)size));
3172
+ CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_DIRECTION, (cufftdxDirection)direction));
3173
+ CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_PRECISION, (commondxPrecision)precision));
3174
+ CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_SM, (long long)(arch * 10)));
3175
+ CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_ELEMENTS_PER_THREAD, (long long)(elements_per_thread)));
3176
+ CHECK_CUFFTDX(cufftdxSetOperatorInt64(h, cufftdxOperatorType::CUFFTDX_OPERATOR_FFTS_PER_BLOCK, 1));
3177
+
3178
+ CHECK_CUFFTDX(cufftdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
3179
+
3180
+ size_t lto_size = 0;
3181
+ CHECK_CUFFTDX(cufftdxGetLTOIRSize(h, &lto_size));
3182
+
3183
+ std::vector<char> lto(lto_size);
3184
+ CHECK_CUFFTDX(cufftdxGetLTOIR(h, lto.size(), lto.data()));
3185
+
3186
+ long long int smem = 0;
3187
+ CHECK_CUFFTDX(cufftdxGetTraitInt64(h, cufftdxTraitType::CUFFTDX_TRAIT_SHARED_MEMORY_SIZE, &smem));
3188
+ *shared_memory_size = (int)smem;
3189
+
3190
+ if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
3191
+ res = false;
3192
+ }
3193
+
3194
+ CHECK_CUFFTDX(cufftdxDestroy(h));
3195
+
3196
+ return res;
3197
+ }
3198
+
3199
+ bool cuda_compile_dot(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int K, int precision_A, int precision_B, int precision_C, int type, int arrangement_A, int arrangement_B, int arrangement_C, int num_threads)
3200
+ {
3201
+
3202
+ CHECK_ANY(ltoir_output_path != nullptr);
3203
+ CHECK_ANY(symbol_name != nullptr);
3204
+ // Includes currently unused
3205
+ CHECK_ANY(include_dirs == nullptr);
3206
+ CHECK_ANY(mathdx_include_dir == nullptr);
3207
+ CHECK_ANY(num_include_dirs == 0);
3208
+
3209
+ bool res = true;
3210
+ cublasdxHandle h;
3211
+ CHECK_CUBLASDX(cublasdxCreate(&h));
3212
+
3213
+ CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_FUNCTION, cublasdxFunction::CUBLASDX_FUNCTION_MM));
3214
+ CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
3215
+ CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_API, cublasdxApi::CUBLASDX_API_BLOCK_SMEM));
3216
+ std::array<long long int, 3> precisions = {precision_A, precision_B, precision_C};
3217
+ CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_PRECISION, 3, precisions.data()));
3218
+ CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SM, (long long)(arch * 10)));
3219
+ CHECK_CUBLASDX(cublasdxSetOperatorInt64(h, cublasdxOperatorType::CUBLASDX_OPERATOR_TYPE, (cublasdxType)type));
3220
+ std::array<long long int, 3> block_dim = {num_threads, 1, 1};
3221
+ CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_BLOCK_DIM, block_dim.size(), block_dim.data()));
3222
+ std::array<long long int, 3> size = {M, N, K};
3223
+ CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_SIZE, size.size(), size.data()));
3224
+ std::array<long long int, 3> arrangement = {arrangement_A, arrangement_B, arrangement_C};
3225
+ CHECK_CUBLASDX(cublasdxSetOperatorInt64Array(h, cublasdxOperatorType::CUBLASDX_OPERATOR_ARRANGEMENT, arrangement.size(), arrangement.data()));
3226
+
3227
+ CHECK_CUBLASDX(cublasdxSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
3228
+
3229
+ size_t lto_size = 0;
3230
+ CHECK_CUBLASDX(cublasdxGetLTOIRSize(h, &lto_size));
3231
+
3232
+ std::vector<char> lto(lto_size);
3233
+ CHECK_CUBLASDX(cublasdxGetLTOIR(h, lto.size(), lto.data()));
3234
+
3235
+ if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
3236
+ res = false;
3237
+ }
3238
+
3239
+ CHECK_CUBLASDX(cublasdxDestroy(h));
3240
+
3241
+ return res;
3242
+ }
3243
+
3244
+ bool cuda_compile_solver(const char* fatbin_output_path, const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int function, int precision, int fill_mode, int num_threads)
3245
+ {
3246
+
3247
+ CHECK_ANY(ltoir_output_path != nullptr);
3248
+ CHECK_ANY(symbol_name != nullptr);
3249
+ CHECK_ANY(mathdx_include_dir == nullptr);
3250
+ CHECK_ANY(num_include_dirs == 0);
3251
+ CHECK_ANY(include_dirs == nullptr);
3252
+
3253
+ bool res = true;
3254
+
3255
+ cusolverHandle h { 0 };
3256
+ CHECK_CUSOLVER(cusolverCreate(&h));
3257
+ long long int size[2] = {M, N};
3258
+ long long int block_dim[3] = {num_threads, 1, 1};
3259
+ CHECK_CUSOLVER(cusolverSetOperatorInt64Array(h, cusolverOperatorType::CUSOLVER_OPERATOR_SIZE, 2, size));
3260
+ CHECK_CUSOLVER(cusolverSetOperatorInt64Array(h, cusolverOperatorType::CUSOLVER_OPERATOR_BLOCK_DIM, 3, block_dim));
3261
+ CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_TYPE, cusolverType::CUSOLVER_TYPE_REAL));
3262
+ CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_API, cusolverApi::CUSOLVER_API_BLOCK_SMEM));
3263
+ CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_FUNCTION, (cusolverFunction)function));
3264
+ CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK));
3265
+ CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_PRECISION, (commondxPrecision)precision));
3266
+ CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_FILL_MODE, (cusolverFillMode)fill_mode));
3267
+ CHECK_CUSOLVER(cusolverSetOperatorInt64(h, cusolverOperatorType::CUSOLVER_OPERATOR_SM, (long long)(arch * 10)));
3268
+
3269
+ CHECK_CUSOLVER(cusolverSetOptionStr(h, commondxOption::COMMONDX_OPTION_SYMBOL_NAME, symbol_name));
3270
+
3271
+ size_t lto_size = 0;
3272
+ CHECK_CUSOLVER(cusolverGetLTOIRSize(h, &lto_size));
3273
+
3274
+ std::vector<char> lto(lto_size);
3275
+ CHECK_CUSOLVER(cusolverGetLTOIR(h, lto.size(), lto.data()));
3276
+
3277
+        // This fatbin is universal, i.e. it is the same for all instantiations of a cuSOLVER device function
3278
+ size_t fatbin_size = 0;
3279
+ CHECK_CUSOLVER(cusolverGetUniversalFATBINSize(h, &fatbin_size));
3280
+
3281
+ std::vector<char> fatbin(fatbin_size);
3282
+ CHECK_CUSOLVER(cusolverGetUniversalFATBIN(h, fatbin.size(), fatbin.data()));
3283
+
3284
+ if(!write_file(lto.data(), lto.size(), ltoir_output_path, "wb")) {
3285
+ res = false;
3286
+ }
3287
+
3288
+ if(!write_file(fatbin.data(), fatbin.size(), fatbin_output_path, "wb")) {
3289
+ res = false;
3290
+ }
3291
+
3292
+ CHECK_CUSOLVER(cusolverDestroy(h));
3293
+
3294
+ return res;
3295
+ }
3296
+
3297
+ #endif
3298
+
3299
+ void* cuda_load_module(void* context, const char* path)
3300
+ {
3301
+ ContextGuard guard(context);
3302
+
3303
+ // use file extension to determine whether to load PTX or CUBIN
3304
+ const char* input_ext = strrchr(path, '.');
3305
+ bool load_ptx = input_ext && strcmp(input_ext + 1, "ptx") == 0;
3306
+
3307
+ std::vector<char> input;
3308
+
3309
+ FILE* file = fopen(path, "rb");
3310
+ if (file)
3311
+ {
3312
+ fseek(file, 0, SEEK_END);
3313
+ size_t length = ftell(file);
3314
+ fseek(file, 0, SEEK_SET);
3315
+
3316
+ input.resize(length + 1);
3317
+ if (fread(input.data(), 1, length, file) != length)
3318
+ {
3319
+ fprintf(stderr, "Warp error: Failed to read input file '%s'\n", path);
3320
+ fclose(file);
3321
+ return NULL;
3322
+ }
3323
+ fclose(file);
3324
+
3325
+ input[length] = '\0';
3326
+ }
3327
+ else
3328
+ {
3329
+ fprintf(stderr, "Warp error: Failed to open input file '%s'\n", path);
3330
+ return NULL;
3331
+ }
3332
+
3333
+ int driver_cuda_version = 0;
3334
+ CUmodule module = NULL;
3335
+
3336
+ if (load_ptx)
3337
+ {
3338
+ if (check_cu(cuDriverGetVersion_f(&driver_cuda_version)) && driver_cuda_version >= CUDA_VERSION)
3339
+ {
3340
+ // let the driver compile the PTX
3341
+
3342
+ CUjit_option options[2];
3343
+ void *option_vals[2];
3344
+ char error_log[8192] = "";
3345
+ unsigned int log_size = 8192;
3346
+ // Set up loader options
3347
+ // Pass a buffer for error message
3348
+ options[0] = CU_JIT_ERROR_LOG_BUFFER;
3349
+ option_vals[0] = (void*)error_log;
3350
+ // Pass the size of the error buffer
3351
+ options[1] = CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES;
3352
+ option_vals[1] = (void*)(size_t)log_size;
3353
+
3354
+ if (!check_cu(cuModuleLoadDataEx_f(&module, input.data(), 2, options, option_vals)))
3355
+ {
3356
+ fprintf(stderr, "Warp error: Loading PTX module failed\n");
3357
+ // print error log if not empty
3358
+ if (*error_log)
3359
+ fprintf(stderr, "PTX loader error:\n%s\n", error_log);
3360
+ return NULL;
3361
+ }
3362
+ }
3363
+ else
3364
+ {
3365
+ // manually compile the PTX and load as CUBIN
3366
+
3367
+ ContextInfo* context_info = get_context_info(static_cast<CUcontext>(context));
3368
+ if (!context_info || !context_info->device_info)
3369
+ {
3370
+ fprintf(stderr, "Warp error: Failed to determine target architecture\n");
3371
+ return NULL;
3372
+ }
3373
+
3374
+ int arch = context_info->device_info->arch;
3375
+
3376
+ char arch_opt[128];
3377
+ sprintf(arch_opt, "--gpu-name=sm_%d", arch);
3378
+
3379
+ const char* compiler_options[] = { arch_opt };
3380
+
3381
+ nvPTXCompilerHandle compiler = NULL;
3382
+ if (!check_nvptx(nvPTXCompilerCreate(&compiler, input.size(), input.data())))
3383
+ return NULL;
3384
+
3385
+ if (!check_nvptx(nvPTXCompilerCompile(compiler, sizeof(compiler_options) / sizeof(*compiler_options), compiler_options)))
3386
+ return NULL;
3387
+
3388
+ size_t cubin_size = 0;
3389
+ if (!check_nvptx(nvPTXCompilerGetCompiledProgramSize(compiler, &cubin_size)))
3390
+ return NULL;
3391
+
3392
+ std::vector<char> cubin(cubin_size);
3393
+ if (!check_nvptx(nvPTXCompilerGetCompiledProgram(compiler, cubin.data())))
3394
+ return NULL;
3395
+
3396
+ check_nvptx(nvPTXCompilerDestroy(&compiler));
3397
+
3398
+ if (!check_cu(cuModuleLoadDataEx_f(&module, cubin.data(), 0, NULL, NULL)))
3399
+ {
3400
+ fprintf(stderr, "Warp CUDA error: Loading module failed\n");
3401
+ return NULL;
3402
+ }
3403
+ }
3404
+ }
3405
+ else
3406
+ {
3407
+ // load CUBIN
3408
+ if (!check_cu(cuModuleLoadDataEx_f(&module, input.data(), 0, NULL, NULL)))
3409
+ {
3410
+ fprintf(stderr, "Warp CUDA error: Loading module failed\n");
3411
+ return NULL;
3412
+ }
3413
+ }
3414
+
3415
+ return module;
3416
+ }
3417
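Stripped of Warp's dynamically resolved `*_f` entry points and the PTX fallback path, the load step reduces to the usual driver-API calls. A hedged sketch (the helper name is a placeholder, not Warp code):

```cpp
// Sketch of the plain CUDA driver-API path: load a PTX/CUBIN image from
// memory and look up a kernel by name. Error handling is simplified.
#include <cuda.h>

CUfunction load_kernel(const void* image, const char* kernel_name)
{
    CUmodule module = nullptr;
    if (cuModuleLoadDataEx(&module, image, 0, nullptr, nullptr) != CUDA_SUCCESS)
        return nullptr;

    CUfunction kernel = nullptr;
    if (cuModuleGetFunction(&kernel, module, kernel_name) != CUDA_SUCCESS)
    {
        cuModuleUnload(module);
        return nullptr;
    }
    return kernel;
}
```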
+
3418
+ void cuda_unload_module(void* context, void* module)
3419
+ {
3420
+ // ensure there are no graph captures in progress
3421
+ if (g_captures.empty())
3422
+ {
3423
+ ContextGuard guard(context);
3424
+ check_cu(cuModuleUnload_f((CUmodule)module));
3425
+ }
3426
+ else
3427
+ {
3428
+ // defer until graph capture completes
3429
+ ModuleInfo module_info;
3430
+ module_info.context = context ? context : get_current_context();
3431
+ module_info.module = module;
3432
+ g_deferred_module_list.push_back(module_info);
3433
+ }
3434
+ }
3435
+
3436
+
3437
+ int cuda_get_max_shared_memory(void* context)
3438
+ {
3439
+ ContextInfo* info = get_context_info(context);
3440
+ if (!info)
3441
+ return -1;
3442
+
3443
+ int max_smem_bytes = info->device_info->max_smem_bytes;
3444
+ return max_smem_bytes;
3445
+ }
3446
+
3447
+ bool cuda_configure_kernel_shared_memory(void* kernel, int size)
3448
+ {
3449
+ int requested_smem_bytes = size;
3450
+
3451
+ // configure shared memory
3452
+ CUresult res = cuFuncSetAttribute_f((CUfunction)kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, requested_smem_bytes);
3453
+ if (res != CUDA_SUCCESS)
3454
+ return false;
3455
+
3456
+ return true;
3457
+ }
3458
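This attribute matters because launches requesting more than the default 48 KiB of dynamic shared memory fail unless the kernel has opted in first. A bare driver-API sketch of the same call (illustrative only):

```cpp
// Opt a kernel into a larger dynamic shared memory carve-out (sketch).
// Without raising CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, requests
// above the 48 KiB default are rejected at launch time.
#include <cuda.h>

bool request_shared_memory(CUfunction kernel, int bytes)
{
    return cuFuncSetAttribute(kernel,
                              CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                              bytes) == CUDA_SUCCESS;
}
```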
+
3459
+ void* cuda_get_kernel(void* context, void* module, const char* name)
3460
+ {
3461
+ ContextGuard guard(context);
3462
+
3463
+ CUfunction kernel = NULL;
3464
+ if (!check_cu(cuModuleGetFunction_f(&kernel, (CUmodule)module, name)))
3465
+ {
3466
+ fprintf(stderr, "Warp CUDA error: Failed to lookup kernel function %s in module\n", name);
3467
+ return NULL;
3468
+ }
3469
+
3470
+ g_kernel_names[kernel] = name;
3471
+ return kernel;
3472
+ }
3473
+
3474
+ size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, int block_dim, int shared_memory_bytes, void** args, void* stream)
3475
+ {
3476
+ ContextGuard guard(context);
3477
+
3478
+ if (block_dim <= 0)
3479
+ {
3480
+ #if defined(_DEBUG)
3481
+ fprintf(stderr, "Warp warning: Launch got block_dim %d. Setting to 256.\n", block_dim);
3482
+ #endif
3483
+ block_dim = 256;
3484
+ }
3485
+
3486
+        // CUDA specs up to compute capability 9.0 say the max x-dim grid size is 2**31-1, so
3487
+ // grid_dim is fine as an int for the near future
3488
+ int grid_dim = (dim + block_dim - 1)/block_dim;
3489
+
3490
+ if (max_blocks <= 0) {
3491
+ max_blocks = 2147483647;
3492
+ }
3493
+
3494
+ if (grid_dim < 0)
3495
+ {
3496
+ #if defined(_DEBUG)
3497
+            fprintf(stderr, "Warp warning: Overflow in grid dimensions detected for %zu total elements and %d threads "
3498
+                            "per block. Setting block count to %d.\n", dim, block_dim, max_blocks);
3499
+ #endif
3500
+ grid_dim = max_blocks;
3501
+ }
3502
+ else
3503
+ {
3504
+ if (grid_dim > max_blocks)
3505
+ {
3506
+ grid_dim = max_blocks;
3507
+ }
3508
+ }
3509
+
3510
+ begin_cuda_range(WP_TIMING_KERNEL, stream, context, get_cuda_kernel_name(kernel));
3511
+
3512
+ CUresult res = cuLaunchKernel_f(
3513
+ (CUfunction)kernel,
3514
+ grid_dim, 1, 1,
3515
+ block_dim, 1, 1,
3516
+ shared_memory_bytes,
3517
+ static_cast<CUstream>(stream),
3518
+ args,
3519
+ 0);
3520
+
3521
+ check_cu(res);
3522
+
3523
+ end_cuda_range(WP_TIMING_KERNEL, stream);
3524
+
3525
+ return res;
3526
+ }
3527
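For callers, the important details are the ceil-divide grid sizing and the `args` convention: `cuLaunchKernel` receives an array of pointers to each argument value. A small illustrative sketch (not a Warp entry point):

```cpp
// How a caller might prepare a 1D launch (sketch only).
#include <cuda.h>

CUresult launch_1d(CUfunction kernel, CUstream stream, float* data, int n)
{
    const int block_dim = 256;
    int grid_dim = (n + block_dim - 1) / block_dim;   // ceil(n / block_dim)

    void* args[] = { &data, &n };                     // pointers to the argument values
    return cuLaunchKernel(kernel,
                          grid_dim, 1, 1,             // grid dimensions
                          block_dim, 1, 1,            // block dimensions
                          0,                          // dynamic shared memory bytes
                          stream, args, nullptr);
}
```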
+
3528
+ void cuda_graphics_map(void* context, void* resource)
3529
+ {
3530
+ ContextGuard guard(context);
3531
+
3532
+ check_cu(cuGraphicsMapResources_f(1, (CUgraphicsResource*)resource, get_current_stream()));
3533
+ }
3534
+
3535
+ void cuda_graphics_unmap(void* context, void* resource)
3536
+ {
3537
+ ContextGuard guard(context);
3538
+
3539
+ check_cu(cuGraphicsUnmapResources_f(1, (CUgraphicsResource*)resource, get_current_stream()));
3540
+ }
3541
+
3542
+ void cuda_graphics_device_ptr_and_size(void* context, void* resource, uint64_t* ptr, size_t* size)
3543
+ {
3544
+ ContextGuard guard(context);
3545
+
3546
+ CUdeviceptr device_ptr;
3547
+ size_t bytes;
3548
+ check_cu(cuGraphicsResourceGetMappedPointer_f(&device_ptr, &bytes, *(CUgraphicsResource*)resource));
3549
+
3550
+ *ptr = device_ptr;
3551
+ *size = bytes;
3552
+ }
3553
+
3554
+ void* cuda_graphics_register_gl_buffer(void* context, uint32_t gl_buffer, unsigned int flags)
3555
+ {
3556
+ ContextGuard guard(context);
3557
+
3558
+ CUgraphicsResource *resource = new CUgraphicsResource;
3559
+ bool success = check_cu(cuGraphicsGLRegisterBuffer_f(resource, gl_buffer, flags));
3560
+ if (!success)
3561
+ {
3562
+ delete resource;
3563
+ return NULL;
3564
+ }
3565
+
3566
+ return resource;
3567
+ }
3568
+
3569
+ void cuda_graphics_unregister_resource(void* context, void* resource)
3570
+ {
3571
+ ContextGuard guard(context);
3572
+
3573
+ CUgraphicsResource *res = (CUgraphicsResource*)resource;
3574
+ check_cu(cuGraphicsUnregisterResource_f(*res));
3575
+ delete res;
3576
+ }
3577
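A typical round trip through these interop wrappers, written directly against the driver API for clarity (illustrative sketch; the GL buffer id and stream are assumed to exist and error checks are elided):

```cpp
// Register an OpenGL buffer, map it, write through the device pointer, then unmap.
#include <cuda.h>
#include <cudaGL.h>

void write_into_gl_buffer(unsigned int gl_buffer, CUstream stream)
{
    CUgraphicsResource resource = nullptr;
    cuGraphicsGLRegisterBuffer(&resource, gl_buffer, CU_GRAPHICS_REGISTER_FLAGS_WRITE_DISCARD);

    // map for CUDA access and fetch the device pointer + size
    cuGraphicsMapResources(1, &resource, stream);
    CUdeviceptr ptr = 0;
    size_t size = 0;
    cuGraphicsResourceGetMappedPointer(&ptr, &size, resource);

    // ... launch kernels that write to `ptr` on `stream` ...

    cuGraphicsUnmapResources(1, &resource, stream);
    cuGraphicsUnregisterResource(resource);
}
```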
+
3578
+ void cuda_timing_begin(int flags)
3579
+ {
3580
+ g_cuda_timing_state = new CudaTimingState(flags, g_cuda_timing_state);
3581
+ }
3582
+
3583
+ int cuda_timing_get_result_count()
3584
+ {
3585
+ if (g_cuda_timing_state)
3586
+ return int(g_cuda_timing_state->ranges.size());
3587
+ return 0;
3588
+ }
3589
+
3590
+ void cuda_timing_end(timing_result_t* results, int size)
3591
+ {
3592
+ if (!g_cuda_timing_state)
3593
+ return;
3594
+
3595
+ // number of results to write to the user buffer
3596
+ int count = std::min(cuda_timing_get_result_count(), size);
3597
+
3598
+ // compute timings and write results
3599
+ for (int i = 0; i < count; i++)
3600
+ {
3601
+ const CudaTimingRange& range = g_cuda_timing_state->ranges[i];
3602
+ timing_result_t& result = results[i];
3603
+ result.context = range.context;
3604
+ result.name = range.name;
3605
+ result.flag = range.flag;
3606
+ check_cuda(cudaEventElapsedTime(&result.elapsed, range.start, range.end));
3607
+ }
3608
+
3609
+ // release events
3610
+ for (CudaTimingRange& range : g_cuda_timing_state->ranges)
3611
+ {
3612
+ check_cu(cuEventDestroy_f(range.start));
3613
+ check_cu(cuEventDestroy_f(range.end));
3614
+ }
3615
+
3616
+ // restore previous state
3617
+ CudaTimingState* parent_state = g_cuda_timing_state->parent;
3618
+ delete g_cuda_timing_state;
3619
+ g_cuda_timing_state = parent_state;
3620
+ }
3621
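Each timing range above is just a pair of CUDA events recorded around the work and resolved with `cudaEventElapsedTime`. A standalone sketch of the same measurement, without Warp's state stack:

```cpp
// Time a region of stream work with CUDA events (runtime API sketch).
#include <cuda_runtime.h>
#include <cstdio>

void time_region(cudaStream_t stream)
{
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start, stream);
    // ... enqueue the work to be timed on `stream` ...
    cudaEventRecord(stop, stream);

    cudaEventSynchronize(stop);          // wait until the region has finished
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    printf("region took %.3f ms\n", ms);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
}
```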
+
3622
+ // impl. files
3623
+ #include "bvh.cu"
3624
+ #include "mesh.cu"
3625
+ #include "sort.cu"
3626
+ #include "hashgrid.cu"
3627
+ #include "reduce.cu"
3628
+ #include "runlength_encode.cu"
3629
+ #include "scan.cu"
3630
+ #include "marching.cu"
3631
+ #include "sparse.cu"
3632
+ #include "volume.cu"
3633
+ #include "volume_builder.cu"
3634
+
3635
+ //#include "spline.inl"
3636
+ //#include "volume.inl"