warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/warp.cu CHANGED
@@ -73,10 +73,15 @@ struct DeviceInfo
73
73
  static constexpr int kNameLen = 128;
74
74
 
75
75
  CUdevice device = -1;
76
+ CUuuid uuid = {0};
76
77
  int ordinal = -1;
78
+ int pci_domain_id = -1;
79
+ int pci_bus_id = -1;
80
+ int pci_device_id = -1;
77
81
  char name[kNameLen] = "";
78
82
  int arch = 0;
79
83
  int is_uva = 0;
84
+ int is_memory_pool_supported = 0;
80
85
  };
81
86
 
82
87
  struct ContextInfo
@@ -125,7 +130,12 @@ int cuda_init()
125
130
  g_devices[i].device = device;
126
131
  g_devices[i].ordinal = i;
127
132
  check_cu(cuDeviceGetName_f(g_devices[i].name, DeviceInfo::kNameLen, device));
133
+ check_cu(cuDeviceGetUuid_f(&g_devices[i].uuid, device));
134
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_domain_id, CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID, device));
135
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_bus_id, CU_DEVICE_ATTRIBUTE_PCI_BUS_ID, device));
136
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_device_id, CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, device));
128
137
  check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_uva, CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device));
138
+ check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_memory_pool_supported, CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, device));
129
139
  int major = 0;
130
140
  int minor = 0;
131
141
  check_cu(cuDeviceGetAttribute_f(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device));
@@ -216,6 +226,26 @@ void* alloc_device(void* context, size_t s)
216
226
  return ptr;
217
227
  }
218
228
 
229
+ void* alloc_temp_device(void* context, size_t s)
230
+ {
231
+ // "cudaMallocAsync ignores the current device/context when determining where the allocation will reside. Instead,
232
+ // cudaMallocAsync determines the resident device based on the specified memory pool or the supplied stream."
233
+ ContextGuard guard(context);
234
+
235
+ void* ptr;
236
+
237
+ if (cuda_context_is_memory_pool_supported(context))
238
+ {
239
+ check_cuda(cudaMallocAsync(&ptr, s, get_current_stream()));
240
+ }
241
+ else
242
+ {
243
+ check_cuda(cudaMalloc(&ptr, s));
244
+ }
245
+
246
+ return ptr;
247
+ }
248
+
219
249
  void free_device(void* context, void* ptr)
220
250
  {
221
251
  ContextGuard guard(context);
@@ -223,6 +253,20 @@ void free_device(void* context, void* ptr)
223
253
  check_cuda(cudaFree(ptr));
224
254
  }
225
255
 
256
+ void free_temp_device(void* context, void* ptr)
257
+ {
258
+ ContextGuard guard(context);
259
+
260
+ if (cuda_context_is_memory_pool_supported(context))
261
+ {
262
+ check_cuda(cudaFreeAsync(ptr, get_current_stream()));
263
+ }
264
+ else
265
+ {
266
+ check_cuda(cudaFree(ptr));
267
+ }
268
+ }
269
+
226
270
  void memcpy_h2d(void* context, void* dest, void* src, size_t n)
227
271
  {
228
272
  ContextGuard guard(context);
@@ -266,7 +310,7 @@ void memset_device(void* context, void* dest, int value, size_t n)
266
310
  {
267
311
  ContextGuard guard(context);
268
312
 
269
- if ((n%4) > 0)
313
+ if (true)// ((n%4) > 0)
270
314
  {
271
315
  // for unaligned lengths fallback to CUDA memset
272
316
  check_cuda(cudaMemsetAsync(dest, value, n, get_current_stream()));
@@ -448,6 +492,125 @@ static __global__ void array_copy_4d_kernel(void* dst, const void* src,
448
492
  }
449
493
 
450
494
 
495
+ static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src,
496
+ void* dst_data, int dst_stride, const int* dst_indices,
497
+ int elem_size)
498
+ {
499
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
500
+
501
+ if (tid < src.size)
502
+ {
503
+ int dst_idx = dst_indices ? dst_indices[tid] : tid;
504
+ void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
505
+ const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
506
+ memcpy(dst_ptr, src_ptr, elem_size);
507
+ }
508
+ }
509
+
510
+ static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricarray_t<void> src,
511
+ void* dst_data, int dst_stride, const int* dst_indices,
512
+ int elem_size)
513
+ {
514
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
515
+
516
+ if (tid < src.size)
517
+ {
518
+ int src_index = src.indices[tid];
519
+ int dst_idx = dst_indices ? dst_indices[tid] : tid;
520
+ void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
521
+ const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
522
+ memcpy(dst_ptr, src_ptr, elem_size);
523
+ }
524
+ }
525
+
526
+ static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
527
+ const void* src_data, int src_stride, const int* src_indices,
528
+ int elem_size)
529
+ {
530
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
531
+
532
+ if (tid < dst.size)
533
+ {
534
+ int src_idx = src_indices ? src_indices[tid] : tid;
535
+ const void* src_ptr = (const char*)src_data + src_idx * src_stride;
536
+ void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
537
+ memcpy(dst_ptr, src_ptr, elem_size);
538
+ }
539
+ }
540
+
541
+ static __global__ void array_copy_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst,
542
+ const void* src_data, int src_stride, const int* src_indices,
543
+ int elem_size)
544
+ {
545
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
546
+
547
+ if (tid < dst.size)
548
+ {
549
+ int src_idx = src_indices ? src_indices[tid] : tid;
550
+ const void* src_ptr = (const char*)src_data + src_idx * src_stride;
551
+ int dst_idx = dst.indices[tid];
552
+ void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
553
+ memcpy(dst_ptr, src_ptr, elem_size);
554
+ }
555
+ }
556
+
557
+
558
+ static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src, int elem_size)
559
+ {
560
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
561
+
562
+ if (tid < dst.size)
563
+ {
564
+ const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
565
+ void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
566
+ memcpy(dst_ptr, src_ptr, elem_size);
567
+ }
568
+ }
569
+
570
+
571
+ static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src, int elem_size)
572
+ {
573
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
574
+
575
+ if (tid < dst.size)
576
+ {
577
+ const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
578
+ int dst_index = dst.indices[tid];
579
+ void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
580
+ memcpy(dst_ptr, src_ptr, elem_size);
581
+ }
582
+ }
583
+
584
+
585
+ static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, int elem_size)
586
+ {
587
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
588
+
589
+ if (tid < dst.size)
590
+ {
591
+ int src_index = src.indices[tid];
592
+ const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
593
+ void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
594
+ memcpy(dst_ptr, src_ptr, elem_size);
595
+ }
596
+ }
597
+
598
+
599
+ static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, int elem_size)
600
+ {
601
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
602
+
603
+ if (tid < dst.size)
604
+ {
605
+ int src_index = src.indices[tid];
606
+ int dst_index = dst.indices[tid];
607
+ const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
608
+ void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
609
+ memcpy(dst_ptr, src_ptr, elem_size);
610
+ }
611
+ }
612
+
613
+
451
614
  WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_type, int src_type, int elem_size)
452
615
  {
453
616
  if (!src || !dst)
@@ -466,6 +629,12 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
466
629
  const int*const* src_indices = NULL;
467
630
  const int*const* dst_indices = NULL;
468
631
 
632
+ const wp::fabricarray_t<void>* src_fabricarray = NULL;
633
+ wp::fabricarray_t<void>* dst_fabricarray = NULL;
634
+
635
+ const wp::indexedfabricarray_t<void>* src_indexedfabricarray = NULL;
636
+ wp::indexedfabricarray_t<void>* dst_indexedfabricarray = NULL;
637
+
469
638
  const int* null_indices[wp::ARRAY_MAX_DIMS] = { NULL };
470
639
 
471
640
  if (src_type == wp::ARRAY_TYPE_REGULAR)
@@ -487,9 +656,19 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
487
656
  src_strides = src_arr.arr.strides;
488
657
  src_indices = src_arr.indices;
489
658
  }
659
+ else if (src_type == wp::ARRAY_TYPE_FABRIC)
660
+ {
661
+ src_fabricarray = static_cast<const wp::fabricarray_t<void>*>(src);
662
+ src_ndim = 1;
663
+ }
664
+ else if (src_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
665
+ {
666
+ src_indexedfabricarray = static_cast<const wp::indexedfabricarray_t<void>*>(src);
667
+ src_ndim = 1;
668
+ }
490
669
  else
491
670
  {
492
- fprintf(stderr, "Warp error: Invalid array type (%d)\n", src_type);
671
+ fprintf(stderr, "Warp copy error: Invalid array type (%d)\n", src_type);
493
672
  return 0;
494
673
  }
495
674
 
@@ -512,33 +691,149 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
512
691
  dst_strides = dst_arr.arr.strides;
513
692
  dst_indices = dst_arr.indices;
514
693
  }
694
+ else if (dst_type == wp::ARRAY_TYPE_FABRIC)
695
+ {
696
+ dst_fabricarray = static_cast<wp::fabricarray_t<void>*>(dst);
697
+ dst_ndim = 1;
698
+ }
699
+ else if (dst_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
700
+ {
701
+ dst_indexedfabricarray = static_cast<wp::indexedfabricarray_t<void>*>(dst);
702
+ dst_ndim = 1;
703
+ }
515
704
  else
516
705
  {
517
- fprintf(stderr, "Warp error: Invalid array type (%d)\n", dst_type);
706
+ fprintf(stderr, "Warp copy error: Invalid array type (%d)\n", dst_type);
518
707
  return 0;
519
708
  }
520
709
 
521
710
  if (src_ndim != dst_ndim)
522
711
  {
523
- fprintf(stderr, "Warp error: Incompatible array dimensionalities (%d and %d)\n", src_ndim, dst_ndim);
712
+ fprintf(stderr, "Warp copy error: Incompatible array dimensionalities (%d and %d)\n", src_ndim, dst_ndim);
524
713
  return 0;
525
714
  }
526
715
 
527
- bool has_grad = (src_grad && dst_grad);
528
- size_t n = 1;
716
+ ContextGuard guard(context);
717
+
718
+ // handle fabric arrays
719
+ if (dst_fabricarray)
720
+ {
721
+ size_t n = dst_fabricarray->size;
722
+ if (src_fabricarray)
723
+ {
724
+ // copy from fabric to fabric
725
+ if (src_fabricarray->size != n)
726
+ {
727
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
728
+ return 0;
729
+ }
730
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_to_fabric_kernel, n,
731
+ (*dst_fabricarray, *src_fabricarray, elem_size));
732
+ return n;
733
+ }
734
+ else if (src_indexedfabricarray)
735
+ {
736
+ // copy from fabric indexed to fabric
737
+ if (src_indexedfabricarray->size != n)
738
+ {
739
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
740
+ return 0;
741
+ }
742
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_indexed_to_fabric_kernel, n,
743
+ (*dst_fabricarray, *src_indexedfabricarray, elem_size));
744
+ return n;
745
+ }
746
+ else
747
+ {
748
+ // copy to fabric
749
+ if (size_t(src_shape[0]) != n)
750
+ {
751
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
752
+ return 0;
753
+ }
754
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_to_fabric_kernel, n,
755
+ (*dst_fabricarray, src_data, src_strides[0], src_indices[0], elem_size));
756
+ return n;
757
+ }
758
+ }
759
+ if (dst_indexedfabricarray)
760
+ {
761
+ size_t n = dst_indexedfabricarray->size;
762
+ if (src_fabricarray)
763
+ {
764
+ // copy from fabric to fabric indexed
765
+ if (src_fabricarray->size != n)
766
+ {
767
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
768
+ return 0;
769
+ }
770
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_to_fabric_indexed_kernel, n,
771
+ (*dst_indexedfabricarray, *src_fabricarray, elem_size));
772
+ return n;
773
+ }
774
+ else if (src_indexedfabricarray)
775
+ {
776
+ // copy from fabric indexed to fabric indexed
777
+ if (src_indexedfabricarray->size != n)
778
+ {
779
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
780
+ return 0;
781
+ }
782
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_indexed_to_fabric_indexed_kernel, n,
783
+ (*dst_indexedfabricarray, *src_indexedfabricarray, elem_size));
784
+ return n;
785
+ }
786
+ else
787
+ {
788
+ // copy to fabric indexed
789
+ if (size_t(src_shape[0]) != n)
790
+ {
791
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
792
+ return 0;
793
+ }
794
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_to_fabric_indexed_kernel, n,
795
+ (*dst_indexedfabricarray, src_data, src_strides[0], src_indices[0], elem_size));
796
+ return n;
797
+ }
798
+ }
799
+ else if (src_fabricarray)
800
+ {
801
+ // copy from fabric
802
+ size_t n = src_fabricarray->size;
803
+ if (size_t(dst_shape[0]) != n)
804
+ {
805
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
806
+ return 0;
807
+ }
808
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_from_fabric_kernel, n,
809
+ (*src_fabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size));
810
+ return n;
811
+ }
812
+ else if (src_indexedfabricarray)
813
+ {
814
+ // copy from fabric indexed
815
+ size_t n = src_indexedfabricarray->size;
816
+ if (size_t(dst_shape[0]) != n)
817
+ {
818
+ fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
819
+ return 0;
820
+ }
821
+ wp_launch_device(WP_CURRENT_CONTEXT, array_copy_from_fabric_indexed_kernel, n,
822
+ (*src_indexedfabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size));
823
+ return n;
824
+ }
529
825
 
826
+ size_t n = 1;
530
827
  for (int i = 0; i < src_ndim; i++)
531
828
  {
532
829
  if (src_shape[i] != dst_shape[i])
533
830
  {
534
- fprintf(stderr, "Warp error: Incompatible array shapes\n");
831
+ fprintf(stderr, "Warp copy error: Incompatible array shapes\n");
535
832
  return 0;
536
833
  }
537
834
  n *= src_shape[i];
538
835
  }
539
836
 
540
- ContextGuard guard(context);
541
-
542
837
  switch (src_ndim)
543
838
  {
544
839
  case 1:
@@ -547,13 +842,6 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
547
842
  dst_strides[0], src_strides[0],
548
843
  dst_indices[0], src_indices[0],
549
844
  src_shape[0], elem_size));
550
- if (has_grad)
551
- {
552
- wp_launch_device(WP_CURRENT_CONTEXT, array_copy_1d_kernel, n, (dst_grad, src_grad,
553
- dst_strides[0], src_strides[0],
554
- dst_indices[0], src_indices[0],
555
- src_shape[0], elem_size));
556
- }
557
845
  break;
558
846
  }
559
847
  case 2:
@@ -568,13 +856,6 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
568
856
  dst_strides_v, src_strides_v,
569
857
  dst_indices_v, src_indices_v,
570
858
  shape_v, elem_size));
571
- if (has_grad)
572
- {
573
- wp_launch_device(WP_CURRENT_CONTEXT, array_copy_2d_kernel, n, (dst_grad, src_grad,
574
- dst_strides_v, src_strides_v,
575
- dst_indices_v, src_indices_v,
576
- shape_v, elem_size));
577
- }
578
859
  break;
579
860
  }
580
861
  case 3:
@@ -589,13 +870,6 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
589
870
  dst_strides_v, src_strides_v,
590
871
  dst_indices_v, src_indices_v,
591
872
  shape_v, elem_size));
592
- if (has_grad)
593
- {
594
- wp_launch_device(WP_CURRENT_CONTEXT, array_copy_3d_kernel, n, (dst_grad, src_grad,
595
- dst_strides_v, src_strides_v,
596
- dst_indices_v, src_indices_v,
597
- shape_v, elem_size));
598
- }
599
873
  break;
600
874
  }
601
875
  case 4:
@@ -610,17 +884,10 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
610
884
  dst_strides_v, src_strides_v,
611
885
  dst_indices_v, src_indices_v,
612
886
  shape_v, elem_size));
613
- if (has_grad)
614
- {
615
- wp_launch_device(WP_CURRENT_CONTEXT, array_copy_4d_kernel, n, (dst_grad, src_grad,
616
- dst_strides_v, src_strides_v,
617
- dst_indices_v, src_indices_v,
618
- shape_v, elem_size));
619
- }
620
887
  break;
621
888
  }
622
889
  default:
623
- fprintf(stderr, "Warp error: invalid array dimensionality (%d)\n", src_ndim);
890
+ fprintf(stderr, "Warp copy error: invalid array dimensionality (%d)\n", src_ndim);
624
891
  return 0;
625
892
  }
626
893
 
@@ -717,6 +984,32 @@ static __global__ void array_fill_4d_kernel(void* data,
717
984
  }
718
985
 
719
986
 
987
+ static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value, int value_size)
988
+ {
989
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
990
+ if (tid < fa.size)
991
+ {
992
+ void* dst_ptr = fabricarray_element_ptr(fa, tid, value_size);
993
+ memcpy(dst_ptr, value, value_size);
994
+ }
995
+ }
996
+
997
+
998
+ static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value, int value_size)
999
+ {
1000
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
1001
+ if (tid < ifa.size)
1002
+ {
1003
+ size_t idx = size_t(ifa.indices[tid]);
1004
+ if (idx < ifa.fa.size)
1005
+ {
1006
+ void* dst_ptr = fabricarray_element_ptr(ifa.fa, idx, value_size);
1007
+ memcpy(dst_ptr, value, value_size);
1008
+ }
1009
+ }
1010
+ }
1011
+
1012
+
720
1013
  WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const void* value_ptr, int value_size)
721
1014
  {
722
1015
  if (!arr_ptr || !value_ptr)
@@ -728,6 +1021,9 @@ WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const
728
1021
  const int* strides = NULL;
729
1022
  const int*const* indices = NULL;
730
1023
 
1024
+ wp::fabricarray_t<void>* fa = NULL;
1025
+ wp::indexedfabricarray_t<void>* ifa = NULL;
1026
+
731
1027
  const int* null_indices[wp::ARRAY_MAX_DIMS] = { NULL };
732
1028
 
733
1029
  if (arr_type == wp::ARRAY_TYPE_REGULAR)
@@ -748,9 +1044,17 @@ WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const
748
1044
  strides = ia.arr.strides;
749
1045
  indices = ia.indices;
750
1046
  }
1047
+ else if (arr_type == wp::ARRAY_TYPE_FABRIC)
1048
+ {
1049
+ fa = static_cast<wp::fabricarray_t<void>*>(arr_ptr);
1050
+ }
1051
+ else if (arr_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
1052
+ {
1053
+ ifa = static_cast<wp::indexedfabricarray_t<void>*>(arr_ptr);
1054
+ }
751
1055
  else
752
1056
  {
753
- fprintf(stderr, "Warp error: Invalid array type id %d\n", arr_type);
1057
+ fprintf(stderr, "Warp fill error: Invalid array type id %d\n", arr_type);
754
1058
  return;
755
1059
  }
756
1060
 
@@ -765,6 +1069,21 @@ WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const
765
1069
  check_cuda(cudaMalloc(&value_devptr, value_size));
766
1070
  check_cuda(cudaMemcpyAsync(value_devptr, value_ptr, value_size, cudaMemcpyHostToDevice, get_current_stream()));
767
1071
 
1072
+ // handle fabric arrays
1073
+ if (fa)
1074
+ {
1075
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_kernel, n,
1076
+ (*fa, value_devptr, value_size));
1077
+ return;
1078
+ }
1079
+ else if (ifa)
1080
+ {
1081
+ wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_indexed_kernel, n,
1082
+ (*ifa, value_devptr, value_size));
1083
+ return;
1084
+ }
1085
+
1086
+ // handle regular or indexed arrays
768
1087
  switch (ndim)
769
1088
  {
770
1089
  case 1:
@@ -801,7 +1120,7 @@ WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const
801
1120
  break;
802
1121
  }
803
1122
  default:
804
- fprintf(stderr, "Warp error: invalid array dimensionality (%d)\n", ndim);
1123
+ fprintf(stderr, "Warp fill error: invalid array dimensionality (%d)\n", ndim);
805
1124
  return;
806
1125
  }
807
1126
  }
@@ -830,6 +1149,11 @@ int cuda_toolkit_version()
830
1149
  return CUDA_VERSION;
831
1150
  }
832
1151
 
1152
+ bool cuda_driver_is_initialized()
1153
+ {
1154
+ return is_cuda_driver_initialized();
1155
+ }
1156
+
833
1157
  int nvrtc_supported_arch_count()
834
1158
  {
835
1159
  int count;
@@ -884,6 +1208,32 @@ int cuda_device_get_arch(int ordinal)
884
1208
  return 0;
885
1209
  }
886
1210
 
1211
+ void cuda_device_get_uuid(int ordinal, char uuid[16])
1212
+ {
1213
+ memcpy(uuid, g_devices[ordinal].uuid.bytes, sizeof(char)*16);
1214
+ }
1215
+
1216
+ int cuda_device_get_pci_domain_id(int ordinal)
1217
+ {
1218
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
1219
+ return g_devices[ordinal].pci_domain_id;
1220
+ return -1;
1221
+ }
1222
+
1223
+ int cuda_device_get_pci_bus_id(int ordinal)
1224
+ {
1225
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
1226
+ return g_devices[ordinal].pci_bus_id;
1227
+ return -1;
1228
+ }
1229
+
1230
+ int cuda_device_get_pci_device_id(int ordinal)
1231
+ {
1232
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
1233
+ return g_devices[ordinal].pci_device_id;
1234
+ return -1;
1235
+ }
1236
+
887
1237
  int cuda_device_is_uva(int ordinal)
888
1238
  {
889
1239
  if (ordinal >= 0 && ordinal < int(g_devices.size()))
@@ -891,6 +1241,13 @@ int cuda_device_is_uva(int ordinal)
891
1241
  return 0;
892
1242
  }
893
1243
 
1244
+ int cuda_device_is_memory_pool_supported(int ordinal)
1245
+ {
1246
+ if (ordinal >= 0 && ordinal < int(g_devices.size()))
1247
+ return g_devices[ordinal].is_memory_pool_supported;
1248
+ return false;
1249
+ }
1250
+
894
1251
  void* cuda_context_get_current()
895
1252
  {
896
1253
  return get_current_context();
@@ -999,6 +1356,16 @@ int cuda_context_is_primary(void* context)
999
1356
  return 0;
1000
1357
  }
1001
1358
 
1359
+ int cuda_context_is_memory_pool_supported(void* context)
1360
+ {
1361
+ int ordinal = cuda_context_get_device_ordinal(context);
1362
+ if (ordinal != -1)
1363
+ {
1364
+ return cuda_device_is_memory_pool_supported(ordinal);
1365
+ }
1366
+ return 0;
1367
+ }
1368
+
1002
1369
  void* cuda_context_get_stream(void* context)
1003
1370
  {
1004
1371
  ContextInfo* info = get_context_info(static_cast<CUcontext>(context));
@@ -1208,10 +1575,10 @@ void* cuda_graph_end_capture(void* context)
1208
1575
  //cudaGraphDebugDotPrint(graph, "graph.dot", cudaGraphDebugDotFlagsVerbose);
1209
1576
 
1210
1577
  cudaGraphExec_t graph_exec = NULL;
1211
- check_cuda(cudaGraphInstantiate(&graph_exec, graph, NULL, NULL, 0));
1578
+ //check_cuda(cudaGraphInstantiate(&graph_exec, graph, NULL, NULL, 0));
1212
1579
 
1213
1580
  // can use after CUDA 11.4 to permit graphs to capture cudaMallocAsync() operations
1214
- //check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch));
1581
+ check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch));
1215
1582
 
1216
1583
  // free source graph
1217
1584
  check_cuda(cudaGraphDestroy(graph));
@@ -1513,14 +1880,34 @@ void* cuda_get_kernel(void* context, void* module, const char* name)
1513
1880
  return kernel;
1514
1881
  }
1515
1882
 
1516
- size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, void** args)
1883
+ size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, void** args)
1517
1884
  {
1518
1885
  ContextGuard guard(context);
1519
1886
 
1520
1887
  const int block_dim = 256;
1521
1888
  // CUDA specs up to compute capability 9.0 says the max x-dim grid is 2**31-1, so
1522
1889
  // grid_dim is fine as an int for the near future
1523
- const int grid_dim = (dim + block_dim - 1)/block_dim;
1890
+ int grid_dim = (dim + block_dim - 1)/block_dim;
1891
+
1892
+ if (max_blocks <= 0) {
1893
+ max_blocks = 2147483647;
1894
+ }
1895
+
1896
+ if (grid_dim < 0)
1897
+ {
1898
+ #if defined(_DEBUG)
1899
+ fprintf(stderr, "Warp warning: Overflow in grid dimensions detected for %zu total elements and 256 threads "
1900
+ "per block.\n Setting block count to %d.\n", dim, max_blocks);
1901
+ #endif
1902
+ grid_dim = max_blocks;
1903
+ }
1904
+ else
1905
+ {
1906
+ if (grid_dim > max_blocks)
1907
+ {
1908
+ grid_dim = max_blocks;
1909
+ }
1910
+ }
1524
1911
 
1525
1912
  CUresult res = cuLaunchKernel_f(
1526
1913
  (CUfunction)kernel,