warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/warp.cu CHANGED
@@ -73,10 +73,15 @@ struct DeviceInfo
     static constexpr int kNameLen = 128;
 
     CUdevice device = -1;
+    CUuuid uuid = {0};
     int ordinal = -1;
+    int pci_domain_id = -1;
+    int pci_bus_id = -1;
+    int pci_device_id = -1;
     char name[kNameLen] = "";
     int arch = 0;
     int is_uva = 0;
+    int is_memory_pool_supported = 0;
 };
 
 struct ContextInfo
@@ -125,7 +130,12 @@ int cuda_init()
         g_devices[i].device = device;
         g_devices[i].ordinal = i;
         check_cu(cuDeviceGetName_f(g_devices[i].name, DeviceInfo::kNameLen, device));
+        check_cu(cuDeviceGetUuid_f(&g_devices[i].uuid, device));
+        check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_domain_id, CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID, device));
+        check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_bus_id, CU_DEVICE_ATTRIBUTE_PCI_BUS_ID, device));
+        check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_device_id, CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, device));
         check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_uva, CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device));
+        check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_memory_pool_supported, CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, device));
         int major = 0;
         int minor = 0;
         check_cu(cuDeviceGetAttribute_f(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device));
@@ -216,6 +226,26 @@ void* alloc_device(void* context, size_t s)
     return ptr;
 }
 
+void* alloc_temp_device(void* context, size_t s)
+{
+    // "cudaMallocAsync ignores the current device/context when determining where the allocation will reside. Instead,
+    // cudaMallocAsync determines the resident device based on the specified memory pool or the supplied stream."
+    ContextGuard guard(context);
+
+    void* ptr;
+
+    if (cuda_context_is_memory_pool_supported(context))
+    {
+        check_cuda(cudaMallocAsync(&ptr, s, get_current_stream()));
+    }
+    else
+    {
+        check_cuda(cudaMalloc(&ptr, s));
+    }
+
+    return ptr;
+}
+
 void free_device(void* context, void* ptr)
 {
     ContextGuard guard(context);
@@ -223,6 +253,20 @@ void free_device(void* context, void* ptr)
     check_cuda(cudaFree(ptr));
 }
 
+void free_temp_device(void* context, void* ptr)
+{
+    ContextGuard guard(context);
+
+    if (cuda_context_is_memory_pool_supported(context))
+    {
+        check_cuda(cudaFreeAsync(ptr, get_current_stream()));
+    }
+    else
+    {
+        check_cuda(cudaFree(ptr));
+    }
+}
+
 void memcpy_h2d(void* context, void* dest, void* src, size_t n)
 {
     ContextGuard guard(context);
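Note: the new alloc_temp_device/free_temp_device pair above prefers CUDA's stream-ordered allocator when the device reports memory-pool support and falls back to plain cudaMalloc/cudaFree otherwise. A minimal standalone sketch of that pattern, outside of Warp's check_cuda/get_current_stream helpers (the names device_supports_memory_pools, alloc_temp, free_temp and use_pools are illustrative, not Warp API; error handling omitted):

    #include <cuda_runtime.h>

    // Query once whether a device supports stream-ordered (pooled) allocations.
    static bool device_supports_memory_pools(int device)
    {
        int supported = 0;
        cudaDeviceGetAttribute(&supported, cudaDevAttrMemoryPoolsSupported, device);
        return supported != 0;
    }

    // Allocate a temporary buffer in stream order when pools are available.
    void* alloc_temp(size_t size, cudaStream_t stream, bool use_pools)
    {
        void* ptr = nullptr;
        if (use_pools)
            cudaMallocAsync(&ptr, size, stream);  // stream-ordered, served from the device's default pool
        else
            cudaMalloc(&ptr, size);               // synchronizing fallback
        return ptr;
    }

    void free_temp(void* ptr, cudaStream_t stream, bool use_pools)
    {
        if (use_pools)
            cudaFreeAsync(ptr, stream);           // freed in stream order, memory returns to the pool
        else
            cudaFree(ptr);
    }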
@@ -266,7 +310,7 @@ void memset_device(void* context, void* dest, int value, size_t n)
 {
     ContextGuard guard(context);
 
-    if ((n%4) > 0)
+    if (true)// ((n%4) > 0)
     {
         // for unaligned lengths fallback to CUDA memset
         check_cuda(cudaMemsetAsync(dest, value, n, get_current_stream()));
@@ -279,6 +323,72 @@ void memset_device(void* context, void* dest, int value, size_t n)
     }
 }
 
+// fill memory buffer with a value: generic memtile kernel using memcpy for each element
+__global__ void memtile_kernel(void* dst, const void* src, size_t srcsize, size_t n)
+{
+    size_t tid = wp::grid_index();
+    if (tid < n)
+    {
+        memcpy((int8_t*)dst + srcsize * tid, src, srcsize);
+    }
+}
+
+// this should be faster than memtile_kernel, but requires proper alignment of dst
+template <typename T>
+__global__ void memtile_value_kernel(T* dst, T value, size_t n)
+{
+    size_t tid = wp::grid_index();
+    if (tid < n)
+    {
+        dst[tid] = value;
+    }
+}
+
+void memtile_device(void* context, void* dst, const void* src, size_t srcsize, size_t n)
+{
+    ContextGuard guard(context);
+
+    size_t dst_addr = reinterpret_cast<size_t>(dst);
+    size_t src_addr = reinterpret_cast<size_t>(src);
+
+    // try memtile_value first because it should be faster, but we need to ensure proper alignment
+    if (srcsize == 8 && (dst_addr & 7) == 0 && (src_addr & 7) == 0)
+    {
+        int64_t* p = reinterpret_cast<int64_t*>(dst);
+        int64_t value = *reinterpret_cast<const int64_t*>(src);
+        wp_launch_device(WP_CURRENT_CONTEXT, memtile_value_kernel, n, (p, value, n));
+    }
+    else if (srcsize == 4 && (dst_addr & 3) == 0 && (src_addr & 3) == 0)
+    {
+        int32_t* p = reinterpret_cast<int32_t*>(dst);
+        int32_t value = *reinterpret_cast<const int32_t*>(src);
+        wp_launch_device(WP_CURRENT_CONTEXT, memtile_value_kernel, n, (p, value, n));
+    }
+    else if (srcsize == 2 && (dst_addr & 1) == 0 && (src_addr & 1) == 0)
+    {
+        int16_t* p = reinterpret_cast<int16_t*>(dst);
+        int16_t value = *reinterpret_cast<const int16_t*>(src);
+        wp_launch_device(WP_CURRENT_CONTEXT, memtile_value_kernel, n, (p, value, n));
+    }
+    else if (srcsize == 1)
+    {
+        check_cuda(cudaMemset(dst, *reinterpret_cast<const int8_t*>(src), n));
+    }
+    else
+    {
+        // generic version
+
+        // TODO: use a persistent stream-local staging buffer to avoid allocs?
+        void* src_device;
+        check_cuda(cudaMalloc(&src_device, srcsize));
+        check_cuda(cudaMemcpyAsync(src_device, src, srcsize, cudaMemcpyHostToDevice, get_current_stream()));
+
+        wp_launch_device(WP_CURRENT_CONTEXT, memtile_kernel, n, (dst, src_device, srcsize, n));
+
+        check_cuda(cudaFree(src_device));
+    }
+}
+
 
 static __global__ void array_copy_1d_kernel(void* dst, const void* src,
                                             int dst_stride, int src_stride,
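Note: memtile_device above picks the fastest fill path by element size and pointer alignment, and only falls back to the generic byte-wise memcpy kernel when no aligned integer width matches. A small host-side sketch of that dispatch rule (the name pick_fill_width is illustrative, not part of the Warp API):

    #include <cstddef>
    #include <cstdint>

    // Returns the element width (in bytes) usable for a typed fill of srcsize-byte
    // elements, or 0 if only the generic memcpy path applies. Mirrors the checks in
    // memtile_device: the size must match exactly and both pointers must be aligned.
    size_t pick_fill_width(size_t srcsize, const void* dst, const void* src)
    {
        const uintptr_t d = reinterpret_cast<uintptr_t>(dst);
        const uintptr_t s = reinterpret_cast<uintptr_t>(src);

        if (srcsize == 8 && (d & 7) == 0 && (s & 7) == 0) return 8;  // int64 path
        if (srcsize == 4 && (d & 3) == 0 && (s & 3) == 0) return 4;  // int32 path
        if (srcsize == 2 && (d & 1) == 0 && (s & 1) == 0) return 2;  // int16 path
        if (srcsize == 1)                                 return 1;  // cudaMemset path
        return 0;  // e.g. a 12-byte vec3 value: use the generic memtile_kernel
    }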
@@ -382,6 +492,125 @@ static __global__ void array_copy_4d_kernel(void* dst, const void* src,
 }
 
 
+static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src,
+                                                     void* dst_data, int dst_stride, const int* dst_indices,
+                                                     int elem_size)
+{
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (tid < src.size)
+    {
+        int dst_idx = dst_indices ? dst_indices[tid] : tid;
+        void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
+        const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
+        memcpy(dst_ptr, src_ptr, elem_size);
+    }
+}
+
+static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricarray_t<void> src,
+                                                             void* dst_data, int dst_stride, const int* dst_indices,
+                                                             int elem_size)
+{
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (tid < src.size)
+    {
+        int src_index = src.indices[tid];
+        int dst_idx = dst_indices ? dst_indices[tid] : tid;
+        void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
+        const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
+        memcpy(dst_ptr, src_ptr, elem_size);
+    }
+}
+
+static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
+                                                   const void* src_data, int src_stride, const int* src_indices,
+                                                   int elem_size)
+{
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (tid < dst.size)
+    {
+        int src_idx = src_indices ? src_indices[tid] : tid;
+        const void* src_ptr = (const char*)src_data + src_idx * src_stride;
+        void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
+        memcpy(dst_ptr, src_ptr, elem_size);
+    }
+}
+
+static __global__ void array_copy_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst,
+                                                           const void* src_data, int src_stride, const int* src_indices,
+                                                           int elem_size)
+{
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (tid < dst.size)
+    {
+        int src_idx = src_indices ? src_indices[tid] : tid;
+        const void* src_ptr = (const char*)src_data + src_idx * src_stride;
+        int dst_idx = dst.indices[tid];
+        void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
+        memcpy(dst_ptr, src_ptr, elem_size);
+    }
+}
+
+
+static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src, int elem_size)
+{
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (tid < dst.size)
+    {
+        const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
+        void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
+        memcpy(dst_ptr, src_ptr, elem_size);
+    }
+}
+
+
+static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src, int elem_size)
+{
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (tid < dst.size)
+    {
+        const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
+        int dst_index = dst.indices[tid];
+        void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
+        memcpy(dst_ptr, src_ptr, elem_size);
+    }
+}
+
+
+static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, int elem_size)
+{
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (tid < dst.size)
+    {
+        int src_index = src.indices[tid];
+        const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
+        void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
+        memcpy(dst_ptr, src_ptr, elem_size);
+    }
+}
+
+
+static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, int elem_size)
+{
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (tid < dst.size)
+    {
+        int src_index = src.indices[tid];
+        int dst_index = dst.indices[tid];
+        const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
+        void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
+        memcpy(dst_ptr, src_ptr, elem_size);
+    }
+}
+
+
 WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_type, int src_type, int elem_size)
 {
     if (!src || !dst)
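Note: all of the fabric copy kernels above share one shape: one thread per element, an optional index map on either side, a byte stride on the non-fabric side, and a memcpy of elem_size bytes. A stripped-down sketch of that pattern for two plain strided buffers (copy_elements_kernel is illustrative only, not one of the Warp kernels):

    #include <cstring>

    // One thread per element: gather from an optionally indexed, strided source
    // and scatter to an optionally indexed, strided destination.
    __global__ void copy_elements_kernel(void* dst, const void* src,
                                         int dst_stride, int src_stride,
                                         const int* dst_indices, const int* src_indices,
                                         int n, int elem_size)
    {
        int tid = blockIdx.x * blockDim.x + threadIdx.x;
        if (tid < n)
        {
            int si = src_indices ? src_indices[tid] : tid;  // identity map when no index array
            int di = dst_indices ? dst_indices[tid] : tid;
            const char* s = (const char*)src + (size_t)si * src_stride;
            char* d = (char*)dst + (size_t)di * dst_stride;
            memcpy(d, s, elem_size);
        }
    }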
@@ -400,6 +629,12 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
     const int*const* src_indices = NULL;
     const int*const* dst_indices = NULL;
 
+    const wp::fabricarray_t<void>* src_fabricarray = NULL;
+    wp::fabricarray_t<void>* dst_fabricarray = NULL;
+
+    const wp::indexedfabricarray_t<void>* src_indexedfabricarray = NULL;
+    wp::indexedfabricarray_t<void>* dst_indexedfabricarray = NULL;
+
     const int* null_indices[wp::ARRAY_MAX_DIMS] = { NULL };
 
     if (src_type == wp::ARRAY_TYPE_REGULAR)
@@ -421,9 +656,19 @@
         src_strides = src_arr.arr.strides;
         src_indices = src_arr.indices;
     }
+    else if (src_type == wp::ARRAY_TYPE_FABRIC)
+    {
+        src_fabricarray = static_cast<const wp::fabricarray_t<void>*>(src);
+        src_ndim = 1;
+    }
+    else if (src_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
+    {
+        src_indexedfabricarray = static_cast<const wp::indexedfabricarray_t<void>*>(src);
+        src_ndim = 1;
+    }
     else
     {
-        fprintf(stderr, "Warp error: Invalid array type (%d)\n", src_type);
+        fprintf(stderr, "Warp copy error: Invalid array type (%d)\n", src_type);
         return 0;
     }
 
@@ -446,33 +691,149 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
         dst_strides = dst_arr.arr.strides;
         dst_indices = dst_arr.indices;
     }
+    else if (dst_type == wp::ARRAY_TYPE_FABRIC)
+    {
+        dst_fabricarray = static_cast<wp::fabricarray_t<void>*>(dst);
+        dst_ndim = 1;
+    }
+    else if (dst_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
+    {
+        dst_indexedfabricarray = static_cast<wp::indexedfabricarray_t<void>*>(dst);
+        dst_ndim = 1;
+    }
     else
     {
-        fprintf(stderr, "Warp error: Invalid array type (%d)\n", dst_type);
+        fprintf(stderr, "Warp copy error: Invalid array type (%d)\n", dst_type);
         return 0;
     }
 
     if (src_ndim != dst_ndim)
     {
-        fprintf(stderr, "Warp error: Incompatible array dimensionalities (%d and %d)\n", src_ndim, dst_ndim);
+        fprintf(stderr, "Warp copy error: Incompatible array dimensionalities (%d and %d)\n", src_ndim, dst_ndim);
         return 0;
     }
 
-    bool has_grad = (src_grad && dst_grad);
-    size_t n = 1;
+    ContextGuard guard(context);
 
+    // handle fabric arrays
+    if (dst_fabricarray)
+    {
+        size_t n = dst_fabricarray->size;
+        if (src_fabricarray)
+        {
+            // copy from fabric to fabric
+            if (src_fabricarray->size != n)
+            {
+                fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
+                return 0;
+            }
+            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_to_fabric_kernel, n,
+                             (*dst_fabricarray, *src_fabricarray, elem_size));
+            return n;
+        }
+        else if (src_indexedfabricarray)
+        {
+            // copy from fabric indexed to fabric
+            if (src_indexedfabricarray->size != n)
+            {
+                fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
+                return 0;
+            }
+            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_indexed_to_fabric_kernel, n,
+                             (*dst_fabricarray, *src_indexedfabricarray, elem_size));
+            return n;
+        }
+        else
+        {
+            // copy to fabric
+            if (size_t(src_shape[0]) != n)
+            {
+                fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
+                return 0;
+            }
+            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_to_fabric_kernel, n,
+                             (*dst_fabricarray, src_data, src_strides[0], src_indices[0], elem_size));
+            return n;
+        }
+    }
+    if (dst_indexedfabricarray)
+    {
+        size_t n = dst_indexedfabricarray->size;
+        if (src_fabricarray)
+        {
+            // copy from fabric to fabric indexed
+            if (src_fabricarray->size != n)
+            {
+                fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
+                return 0;
+            }
+            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_to_fabric_indexed_kernel, n,
+                             (*dst_indexedfabricarray, *src_fabricarray, elem_size));
+            return n;
+        }
+        else if (src_indexedfabricarray)
+        {
+            // copy from fabric indexed to fabric indexed
+            if (src_indexedfabricarray->size != n)
+            {
+                fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
+                return 0;
+            }
+            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_indexed_to_fabric_indexed_kernel, n,
+                             (*dst_indexedfabricarray, *src_indexedfabricarray, elem_size));
+            return n;
+        }
+        else
+        {
+            // copy to fabric indexed
+            if (size_t(src_shape[0]) != n)
+            {
+                fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
+                return 0;
+            }
+            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_to_fabric_indexed_kernel, n,
+                             (*dst_indexedfabricarray, src_data, src_strides[0], src_indices[0], elem_size));
+            return n;
+        }
+    }
+    else if (src_fabricarray)
+    {
+        // copy from fabric
+        size_t n = src_fabricarray->size;
+        if (size_t(dst_shape[0]) != n)
+        {
+            fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
+            return 0;
+        }
+        wp_launch_device(WP_CURRENT_CONTEXT, array_copy_from_fabric_kernel, n,
+                         (*src_fabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size));
+        return n;
+    }
+    else if (src_indexedfabricarray)
+    {
+        // copy from fabric indexed
+        size_t n = src_indexedfabricarray->size;
+        if (size_t(dst_shape[0]) != n)
+        {
+            fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
+            return 0;
+        }
+        wp_launch_device(WP_CURRENT_CONTEXT, array_copy_from_fabric_indexed_kernel, n,
+                         (*src_indexedfabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size));
+        return n;
+    }
+
+    size_t n = 1;
     for (int i = 0; i < src_ndim; i++)
     {
         if (src_shape[i] != dst_shape[i])
         {
-            fprintf(stderr, "Warp error: Incompatible array shapes\n");
+            fprintf(stderr, "Warp copy error: Incompatible array shapes\n");
             return 0;
         }
         n *= src_shape[i];
     }
 
-    ContextGuard guard(context);
-
     switch (src_ndim)
     {
     case 1:
@@ -481,13 +842,6 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
                          dst_strides[0], src_strides[0],
                          dst_indices[0], src_indices[0],
                          src_shape[0], elem_size));
-        if (has_grad)
-        {
-            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_1d_kernel, n, (dst_grad, src_grad,
-                             dst_strides[0], src_strides[0],
-                             dst_indices[0], src_indices[0],
-                             src_shape[0], elem_size));
-        }
         break;
     }
     case 2:
@@ -502,13 +856,6 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
                          dst_strides_v, src_strides_v,
                          dst_indices_v, src_indices_v,
                          shape_v, elem_size));
-        if (has_grad)
-        {
-            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_2d_kernel, n, (dst_grad, src_grad,
-                             dst_strides_v, src_strides_v,
-                             dst_indices_v, src_indices_v,
-                             shape_v, elem_size));
-        }
         break;
     }
     case 3:
@@ -523,13 +870,6 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
                          dst_strides_v, src_strides_v,
                          dst_indices_v, src_indices_v,
                          shape_v, elem_size));
-        if (has_grad)
-        {
-            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_3d_kernel, n, (dst_grad, src_grad,
-                             dst_strides_v, src_strides_v,
-                             dst_indices_v, src_indices_v,
-                             shape_v, elem_size));
-        }
         break;
     }
     case 4:
@@ -544,17 +884,10 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
                          dst_strides_v, src_strides_v,
                          dst_indices_v, src_indices_v,
                          shape_v, elem_size));
-        if (has_grad)
-        {
-            wp_launch_device(WP_CURRENT_CONTEXT, array_copy_4d_kernel, n, (dst_grad, src_grad,
-                             dst_strides_v, src_strides_v,
-                             dst_indices_v, src_indices_v,
-                             shape_v, elem_size));
-        }
         break;
     }
     default:
-        fprintf(stderr, "Warp error: invalid array dimensionality (%d)\n", src_ndim);
+        fprintf(stderr, "Warp copy error: invalid array dimensionality (%d)\n", src_ndim);
         return 0;
     }
 
@@ -565,43 +898,231 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
 }
 
 
-__global__ void memtile_kernel(char* dest, char* src, size_t srcsize, size_t n)
+static __global__ void array_fill_1d_kernel(void* data,
+                                            int n,
+                                            int stride,
+                                            const int* indices,
+                                            const void* value,
+                                            int value_size)
 {
-    const size_t tid = wp::grid_index();
-
-    if (tid < n)
+    int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i < n)
     {
-        char *d = dest + srcsize * tid;
-        char *s = src;
-        for( size_t i=0; i < srcsize; ++i,++d,++s )
-        {
-            *d = *s;
-        }
+        int idx = indices ? indices[i] : i;
+        char* p = (char*)data + idx * stride;
+        memcpy(p, value, value_size);
     }
 }
 
-void memtile_device(void* context, void* dest, void *src, size_t srcsize, size_t n)
+static __global__ void array_fill_2d_kernel(void* data,
+                                            wp::vec_t<2, int> shape,
+                                            wp::vec_t<2, int> strides,
+                                            wp::vec_t<2, const int*> indices,
+                                            const void* value,
+                                            int value_size)
 {
-    ContextGuard guard(context);
-
-    void* src_device;
-    check_cuda(cudaMalloc(&src_device, srcsize));
-    check_cuda(cudaMemcpyAsync(src_device, src, srcsize, cudaMemcpyHostToDevice, get_current_stream()));
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    int n = shape[1];
+    int i = tid / n;
+    int j = tid % n;
+    if (i < shape[0] /*&& j < shape[1]*/)
+    {
+        int idx0 = indices[0] ? indices[0][i] : i;
+        int idx1 = indices[1] ? indices[1][j] : j;
+        char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1];
+        memcpy(p, value, value_size);
+    }
+}
 
-    wp_launch_device(WP_CURRENT_CONTEXT, memtile_kernel, n, ((char *)dest,(char *)src_device,srcsize,n));
+static __global__ void array_fill_3d_kernel(void* data,
+                                            wp::vec_t<3, int> shape,
+                                            wp::vec_t<3, int> strides,
+                                            wp::vec_t<3, const int*> indices,
+                                            const void* value,
+                                            int value_size)
+{
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    int n = shape[1];
+    int o = shape[2];
+    int i = tid / (n * o);
+    int j = tid % (n * o) / o;
+    int k = tid % o;
+    if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
+    {
+        int idx0 = indices[0] ? indices[0][i] : i;
+        int idx1 = indices[1] ? indices[1][j] : j;
+        int idx2 = indices[2] ? indices[2][k] : k;
+        char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2];
+        memcpy(p, value, value_size);
+    }
+}
 
-    check_cuda(cudaFree(src_device));
+static __global__ void array_fill_4d_kernel(void* data,
+                                            wp::vec_t<4, int> shape,
+                                            wp::vec_t<4, int> strides,
+                                            wp::vec_t<4, const int*> indices,
+                                            const void* value,
+                                            int value_size)
+{
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    int n = shape[1];
+    int o = shape[2];
+    int p = shape[3];
+    int i = tid / (n * o * p);
+    int j = tid % (n * o * p) / (o * p);
+    int k = tid % (o * p) / p;
+    int l = tid % p;
+    if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
+    {
+        int idx0 = indices[0] ? indices[0][i] : i;
+        int idx1 = indices[1] ? indices[1][j] : j;
+        int idx2 = indices[2] ? indices[2][k] : k;
+        int idx3 = indices[3] ? indices[3][l] : l;
+        char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2] + idx3 * strides[3];
+        memcpy(p, value, value_size);
+    }
 }
 
 
-void array_inner_device(uint64_t a, uint64_t b, uint64_t out, int len)
+static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value, int value_size)
 {
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    if (tid < fa.size)
+    {
+        void* dst_ptr = fabricarray_element_ptr(fa, tid, value_size);
+        memcpy(dst_ptr, value, value_size);
+    }
+}
 
+
+static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value, int value_size)
+{
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    if (tid < ifa.size)
+    {
+        size_t idx = size_t(ifa.indices[tid]);
+        if (idx < ifa.fa.size)
+        {
+            void* dst_ptr = fabricarray_element_ptr(ifa.fa, idx, value_size);
+            memcpy(dst_ptr, value, value_size);
+        }
+    }
 }
 
-void array_sum_device(uint64_t a, uint64_t out, int len)
+
+WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const void* value_ptr, int value_size)
 {
-
+    if (!arr_ptr || !value_ptr)
+        return;
+
+    void* data = NULL;
+    int ndim = 0;
+    const int* shape = NULL;
+    const int* strides = NULL;
+    const int*const* indices = NULL;
+
+    wp::fabricarray_t<void>* fa = NULL;
+    wp::indexedfabricarray_t<void>* ifa = NULL;
+
+    const int* null_indices[wp::ARRAY_MAX_DIMS] = { NULL };
+
+    if (arr_type == wp::ARRAY_TYPE_REGULAR)
+    {
+        wp::array_t<void>& arr = *static_cast<wp::array_t<void>*>(arr_ptr);
+        data = arr.data;
+        ndim = arr.ndim;
+        shape = arr.shape.dims;
+        strides = arr.strides;
+        indices = null_indices;
+    }
+    else if (arr_type == wp::ARRAY_TYPE_INDEXED)
+    {
+        wp::indexedarray_t<void>& ia = *static_cast<wp::indexedarray_t<void>*>(arr_ptr);
+        data = ia.arr.data;
+        ndim = ia.arr.ndim;
+        shape = ia.shape.dims;
+        strides = ia.arr.strides;
+        indices = ia.indices;
+    }
+    else if (arr_type == wp::ARRAY_TYPE_FABRIC)
+    {
+        fa = static_cast<wp::fabricarray_t<void>*>(arr_ptr);
+    }
+    else if (arr_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
+    {
+        ifa = static_cast<wp::indexedfabricarray_t<void>*>(arr_ptr);
+    }
+    else
+    {
+        fprintf(stderr, "Warp fill error: Invalid array type id %d\n", arr_type);
+        return;
+    }
+
+    size_t n = 1;
+    for (int i = 0; i < ndim; i++)
+        n *= shape[i];
+
+    ContextGuard guard(context);
+
+    // copy value to device memory
+    void* value_devptr;
+    check_cuda(cudaMalloc(&value_devptr, value_size));
+    check_cuda(cudaMemcpyAsync(value_devptr, value_ptr, value_size, cudaMemcpyHostToDevice, get_current_stream()));
+
+    // handle fabric arrays
+    if (fa)
+    {
+        wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_kernel, n,
+                         (*fa, value_devptr, value_size));
+        return;
+    }
+    else if (ifa)
+    {
+        wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_indexed_kernel, n,
+                         (*ifa, value_devptr, value_size));
+        return;
+    }
+
+    // handle regular or indexed arrays
+    switch (ndim)
+    {
+    case 1:
+    {
+        wp_launch_device(WP_CURRENT_CONTEXT, array_fill_1d_kernel, n,
+                         (data, shape[0], strides[0], indices[0], value_devptr, value_size));
+        break;
+    }
+    case 2:
+    {
+        wp::vec_t<2, int> shape_v(shape[0], shape[1]);
+        wp::vec_t<2, int> strides_v(strides[0], strides[1]);
+        wp::vec_t<2, const int*> indices_v(indices[0], indices[1]);
+        wp_launch_device(WP_CURRENT_CONTEXT, array_fill_2d_kernel, n,
+                         (data, shape_v, strides_v, indices_v, value_devptr, value_size));
+        break;
+    }
+    case 3:
+    {
+        wp::vec_t<3, int> shape_v(shape[0], shape[1], shape[2]);
+        wp::vec_t<3, int> strides_v(strides[0], strides[1], strides[2]);
+        wp::vec_t<3, const int*> indices_v(indices[0], indices[1], indices[2]);
+        wp_launch_device(WP_CURRENT_CONTEXT, array_fill_3d_kernel, n,
+                         (data, shape_v, strides_v, indices_v, value_devptr, value_size));
+        break;
+    }
+    case 4:
+    {
+        wp::vec_t<4, int> shape_v(shape[0], shape[1], shape[2], shape[3]);
+        wp::vec_t<4, int> strides_v(strides[0], strides[1], strides[2], strides[3]);
+        wp::vec_t<4, const int*> indices_v(indices[0], indices[1], indices[2], indices[3]);
+        wp_launch_device(WP_CURRENT_CONTEXT, array_fill_4d_kernel, n,
+                         (data, shape_v, strides_v, indices_v, value_devptr, value_size));
+        break;
+    }
+    default:
+        fprintf(stderr, "Warp fill error: invalid array dimensionality (%d)\n", ndim);
+        return;
+    }
 }
 
 void array_scan_int_device(uint64_t in, uint64_t out, int len, bool inclusive)
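Note: array_fill_device stages the fill value on the device once (a small cudaMalloc plus cudaMemcpyAsync) and then launches one of the kernels above; the 2D/3D/4D kernels recover multi-dimensional indices from the flat thread id by division and modulo. A short host-side check of that decomposition for a 3D shape (illustrative only):

    #include <cassert>

    // For shape (S0, S1, S2) the fill kernels map a flat id back to (i, j, k) exactly as
    // array_fill_3d_kernel does: i = tid / (S1*S2), j = tid % (S1*S2) / S2, k = tid % S2.
    int main()
    {
        const int S0 = 2, S1 = 3, S2 = 4;
        int tid = 0;
        for (int i = 0; i < S0; ++i)
            for (int j = 0; j < S1; ++j)
                for (int k = 0; k < S2; ++k, ++tid)
                {
                    assert(tid / (S1 * S2) == i);
                    assert(tid % (S1 * S2) / S2 == j);
                    assert(tid % S2 == k);
                }
        return 0;
    }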
@@ -628,6 +1149,11 @@ int cuda_toolkit_version()
     return CUDA_VERSION;
 }
 
+bool cuda_driver_is_initialized()
+{
+    return is_cuda_driver_initialized();
+}
+
 int nvrtc_supported_arch_count()
 {
     int count;
@@ -682,6 +1208,32 @@ int cuda_device_get_arch(int ordinal)
     return 0;
 }
 
+void cuda_device_get_uuid(int ordinal, char uuid[16])
+{
+    memcpy(uuid, g_devices[ordinal].uuid.bytes, sizeof(char)*16);
+}
+
+int cuda_device_get_pci_domain_id(int ordinal)
+{
+    if (ordinal >= 0 && ordinal < int(g_devices.size()))
+        return g_devices[ordinal].pci_domain_id;
+    return -1;
+}
+
+int cuda_device_get_pci_bus_id(int ordinal)
+{
+    if (ordinal >= 0 && ordinal < int(g_devices.size()))
+        return g_devices[ordinal].pci_bus_id;
+    return -1;
+}
+
+int cuda_device_get_pci_device_id(int ordinal)
+{
+    if (ordinal >= 0 && ordinal < int(g_devices.size()))
+        return g_devices[ordinal].pci_device_id;
+    return -1;
+}
+
 int cuda_device_is_uva(int ordinal)
 {
     if (ordinal >= 0 && ordinal < int(g_devices.size()))
@@ -689,6 +1241,13 @@ int cuda_device_is_uva(int ordinal)
     return 0;
 }
 
+int cuda_device_is_memory_pool_supported(int ordinal)
+{
+    if (ordinal >= 0 && ordinal < int(g_devices.size()))
+        return g_devices[ordinal].is_memory_pool_supported;
+    return false;
+}
+
 void* cuda_context_get_current()
 {
     return get_current_context();
@@ -797,6 +1356,16 @@ int cuda_context_is_primary(void* context)
     return 0;
 }
 
+int cuda_context_is_memory_pool_supported(void* context)
+{
+    int ordinal = cuda_context_get_device_ordinal(context);
+    if (ordinal != -1)
+    {
+        return cuda_device_is_memory_pool_supported(ordinal);
+    }
+    return 0;
+}
+
 void* cuda_context_get_stream(void* context)
 {
     ContextInfo* info = get_context_info(static_cast<CUcontext>(context));
@@ -1006,10 +1575,10 @@ void* cuda_graph_end_capture(void* context)
     //cudaGraphDebugDotPrint(graph, "graph.dot", cudaGraphDebugDotFlagsVerbose);
 
     cudaGraphExec_t graph_exec = NULL;
-    check_cuda(cudaGraphInstantiate(&graph_exec, graph, NULL, NULL, 0));
+    //check_cuda(cudaGraphInstantiate(&graph_exec, graph, NULL, NULL, 0));
 
     // can use after CUDA 11.4 to permit graphs to capture cudaMallocAsync() operations
-    //check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch));
+    check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch));
 
     // free source graph
     check_cuda(cudaGraphDestroy(graph));
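Note: graph capture now instantiates with cudaGraphInstantiateWithFlags and cudaGraphInstantiateFlagAutoFreeOnLaunch (CUDA 11.4+), which permits graphs that captured cudaMallocAsync/cudaFreeAsync calls. A minimal standalone sketch of that capture-and-instantiate flow (capture_and_run is illustrative; error handling omitted):

    #include <cuda_runtime.h>

    void capture_and_run(cudaStream_t stream)
    {
        // capture stream-ordered work, including pooled allocations, into a graph
        cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);

        void* tmp = nullptr;
        cudaMallocAsync(&tmp, 1 << 20, stream);    // captured allocation
        cudaMemsetAsync(tmp, 0, 1 << 20, stream);  // captured work
        cudaFreeAsync(tmp, stream);                // captured free

        cudaGraph_t graph = nullptr;
        cudaStreamEndCapture(stream, &graph);

        // auto-free-on-launch releases graph-owned allocations before each relaunch
        cudaGraphExec_t exec = nullptr;
        cudaGraphInstantiateWithFlags(&exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch);
        cudaGraphDestroy(graph);

        cudaGraphLaunch(exec, stream);
        cudaStreamSynchronize(stream);
        cudaGraphExecDestroy(exec);
    }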
@@ -1064,10 +1633,8 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
 
     std::vector<const char*> opts;
     opts.push_back(arch_opt);
-    opts.push_back(include_opt);
-    opts.push_back("--device-as-default-execution-space");
+    opts.push_back(include_opt);
     opts.push_back("--std=c++11");
-    opts.push_back("--define-macro=WP_CUDA");
 
     if (debug)
     {
@@ -1193,7 +1760,7 @@ void* cuda_load_module(void* context, const char* path)
         size_t length = ftell(file);
         fseek(file, 0, SEEK_SET);
 
-        input.resize(length);
+        input.resize(length + 1);
         if (fread(input.data(), 1, length, file) != length)
        {
             fprintf(stderr, "Warp error: Failed to read input file '%s'\n", path);
@@ -1201,6 +1768,8 @@
             return NULL;
         }
         fclose(file);
+
+        input[length] = '\0';
     }
     else
     {
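Note: the loader change above resizes the buffer to length + 1 and writes a trailing '\0' because a PTX image handed to the driver is parsed as a text string. A hedged sketch of reading a PTX file that way (load_ptx_module is illustrative; Warp routes its driver calls through its own function-pointer table, this sketch uses the plain cuModuleLoadData call and assumes a current CUDA context; error handling omitted):

    #include <cuda.h>
    #include <cstdio>
    #include <vector>

    CUmodule load_ptx_module(const char* path)
    {
        FILE* file = fopen(path, "rb");
        fseek(file, 0, SEEK_END);
        size_t length = ftell(file);
        fseek(file, 0, SEEK_SET);

        std::vector<char> image(length + 1);  // one extra byte for the terminator
        fread(image.data(), 1, length, file);
        fclose(file);
        image[length] = '\0';                 // PTX is treated as a null-terminated string

        CUmodule module = nullptr;
        cuModuleLoadData(&module, image.data());
        return module;
    }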
@@ -1306,19 +1875,39 @@ void* cuda_get_kernel(void* context, void* module, const char* name)
 
     CUfunction kernel = NULL;
     if (!check_cu(cuModuleGetFunction_f(&kernel, (CUmodule)module, name)))
-        printf("Warp: Failed to lookup kernel function %s in module\n", name);
+        fprintf(stderr, "Warp CUDA error: Failed to lookup kernel function %s in module\n", name);
 
     return kernel;
 }
 
-size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, void** args)
+size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, void** args)
 {
     ContextGuard guard(context);
 
     const int block_dim = 256;
     // CUDA specs up to compute capability 9.0 says the max x-dim grid is 2**31-1, so
     // grid_dim is fine as an int for the near future
-    const int grid_dim = (dim + block_dim - 1)/block_dim;
+    int grid_dim = (dim + block_dim - 1)/block_dim;
+
+    if (max_blocks <= 0) {
+        max_blocks = 2147483647;
+    }
+
+    if (grid_dim < 0)
+    {
+    #if defined(_DEBUG)
+        fprintf(stderr, "Warp warning: Overflow in grid dimensions detected for %zu total elements and 256 threads "
+                        "per block.\n Setting block count to %d.\n", dim, max_blocks);
+    #endif
+        grid_dim = max_blocks;
+    }
+    else
+    {
+        if (grid_dim > max_blocks)
+        {
+            grid_dim = max_blocks;
+        }
+    }
 
     CUresult res = cuLaunchKernel_f(
         (CUfunction)kernel,
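Note: cuda_launch_kernel now takes a max_blocks limit. The block count is the usual ceiling division of the element count by 256 threads per block, clamped to max_blocks, and reset to max_blocks if the int computation wraps negative for very large dims. A small host-side sketch of that arithmetic (compute_grid_dim is illustrative only):

    #include <cstddef>

    // Mirror of the launch-dimension logic: ceil(dim / block_dim), clamped to max_blocks.
    // If the computed count does not fit in an int (and wraps negative), fall back to max_blocks.
    int compute_grid_dim(size_t dim, int max_blocks)
    {
        const int block_dim = 256;
        int grid_dim = (int)((dim + block_dim - 1) / block_dim);

        if (max_blocks <= 0)
            max_blocks = 2147483647;  // no explicit limit: cap at INT_MAX

        if (grid_dim < 0 || grid_dim > max_blocks)
            grid_dim = max_blocks;

        return grid_dim;
    }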
@@ -1384,8 +1973,11 @@ void cuda_graphics_unregister_resource(void* context, void* resource)
 #include "mesh.cu"
 #include "sort.cu"
 #include "hashgrid.cu"
+#include "reduce.cu"
+#include "runlength_encode.cu"
 #include "scan.cu"
 #include "marching.cu"
+#include "sparse.cu"
 #include "volume.cu"
 #include "volume_builder.cu"
 #if WP_ENABLE_CUTLASS