warp-lang 0.10.1-py3-none-win_amd64.whl → 0.11.0-py3-none-win_amd64.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of warp-lang might be problematic.
Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/sparse.cu CHANGED
@@ -1,8 +1,6 @@
 #include "cuda_util.h"
 #include "warp.h"
 
-#include "temp_buffer.h"
-
 #define THRUST_IGNORE_CUB_VERSION_CHECK
 
 #include <cub/device/device_radix_sort.cuh>
@@ -29,40 +27,29 @@ CUDA_CALLABLE uint32_t bsr_get_col(const BsrRowCol &row_col) {
 
 // Cached temporary storage
 struct BsrFromTripletsTemp {
-    // Temp work buffers
-    int nnz = 0;
-    int *block_indices = NULL;
-
-    BsrRowCol *combined_row_col = NULL;
-
+
+    int *count_buffer = NULL;
     cudaEvent_t host_sync_event = NULL;
 
-    void ensure_fits(size_t size) {
-
-        if (size > nnz) {
-            size = std::max(2 * size, (static_cast<size_t>(nnz) * 3) / 2);
-
-            free_device(WP_CURRENT_CONTEXT, block_indices);
-            free_device(WP_CURRENT_CONTEXT, combined_row_col);
-
-            // Factor 2 for in / out versions , +1 for count
-            block_indices = static_cast<int *>(
-                alloc_device(WP_CURRENT_CONTEXT, (2 * size + 1) * sizeof(int)));
-            combined_row_col = static_cast<BsrRowCol *>(
-                alloc_device(WP_CURRENT_CONTEXT, 2 * size * sizeof(BsrRowCol)));
+    BsrFromTripletsTemp()
+        : count_buffer(static_cast<int*>(alloc_pinned(sizeof(int))))
+    {
+        cudaEventCreateWithFlags(&host_sync_event, cudaEventDisableTiming);
+    }
+
+    ~BsrFromTripletsTemp()
+    {
+        cudaEventDestroy(host_sync_event);
+        free_pinned(count_buffer);
+    }
 
-            nnz = size;
-        }
+    BsrFromTripletsTemp(const BsrFromTripletsTemp&) = delete;
+    BsrFromTripletsTemp& operator=(const BsrFromTripletsTemp&) = delete;
 
-        if (host_sync_event == NULL) {
-            cudaEventCreateWithFlags(&host_sync_event, cudaEventDisableTiming);
-        }
-    }
 };
 
 // map temp buffers to CUDA contexts
-static std::unordered_map<void *, BsrFromTripletsTemp>
-    g_bsr_from_triplets_temp_map;
+static std::unordered_map<void *, BsrFromTripletsTemp> g_bsr_from_triplets_temp_map;
 
 template <typename T> struct BsrBlockIsNotZero {
     int block_size;
@@ -147,25 +134,22 @@ int bsr_matrix_from_triplets_device(const int rows_per_block,
     const int block_size = rows_per_block * cols_per_block;
 
     void *context = cuda_context_get_current();
+    ContextGuard guard(context);
 
     // Per-context cached temporary buffers
-    TemporaryBuffer &cub_temp = g_temp_buffer_map[context];
-    PinnedTemporaryBuffer &pinned_temp = g_pinned_temp_buffer_map[context];
     BsrFromTripletsTemp &bsr_temp = g_bsr_from_triplets_temp_map[context];
 
-    ContextGuard guard(context);
-
     cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
-    bsr_temp.ensure_fits(nnz);
 
-    cub::DoubleBuffer<int> d_keys(bsr_temp.block_indices,
-                                  bsr_temp.block_indices + nnz);
-    cub::DoubleBuffer<BsrRowCol> d_values(bsr_temp.combined_row_col,
-                                          bsr_temp.combined_row_col + nnz);
+    ScopedTemporary<int> block_indices(context, 2*nnz);
+    ScopedTemporary<BsrRowCol> combined_row_col(context, 2*nnz);
+
+    cub::DoubleBuffer<int> d_keys(block_indices.buffer(),
+                                  block_indices.buffer() + nnz);
+    cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(),
+                                          combined_row_col.buffer() + nnz);
 
-    int *d_nz_triplet_count = bsr_temp.block_indices + 2 * nnz;
-    pinned_temp.ensure_fits(sizeof(int));
-    int *pinned_count = static_cast<int *>(pinned_temp.buffer);
+    int *p_nz_triplet_count = bsr_temp.count_buffer;
 
     wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_block_indices, nnz,
                      (nnz, d_keys.Current()));
@@ -173,32 +157,29 @@ int bsr_matrix_from_triplets_device(const int rows_per_block,
     if (tpl_values) {
 
         // Remove zero blocks
-        size_t buff_size = 0;
-        BsrBlockIsNotZero<T> isNotZero{block_size, tpl_values};
-        check_cuda(cub::DeviceSelect::If(nullptr, buff_size, d_keys.Current(),
-                                         d_keys.Alternate(), d_nz_triplet_count,
-                                         nnz, isNotZero, stream));
-        cub_temp.ensure_fits(buff_size);
-        check_cuda(cub::DeviceSelect::If(
-            cub_temp.buffer, buff_size, d_keys.Current(), d_keys.Alternate(),
-            d_nz_triplet_count, nnz, isNotZero, stream));
+        {
+            size_t buff_size = 0;
+            BsrBlockIsNotZero<T> isNotZero{block_size, tpl_values};
+            check_cuda(cub::DeviceSelect::If(nullptr, buff_size, d_keys.Current(),
+                                             d_keys.Alternate(), p_nz_triplet_count,
+                                             nnz, isNotZero, stream));
+            ScopedTemporary<> temp(context, buff_size);
+            check_cuda(cub::DeviceSelect::If(
+                temp.buffer(), buff_size, d_keys.Current(), d_keys.Alternate(),
+                p_nz_triplet_count, nnz, isNotZero, stream));
+        }
+        cudaEventRecord(bsr_temp.host_sync_event, stream);
 
         // switch current/alternate in double buffer
         d_keys.selector ^= 1;
 
-        // Copy number of remaining items to host, needed for further launches
-        memcpy_d2h(WP_CURRENT_CONTEXT, pinned_count, d_nz_triplet_count,
-                   sizeof(int));
-        cudaEventRecord(bsr_temp.host_sync_event, stream);
     } else {
-        *pinned_count = nnz;
-        memcpy_h2d(WP_CURRENT_CONTEXT, d_nz_triplet_count, pinned_count,
-                   sizeof(int));
+        *p_nz_triplet_count = nnz;
    }
 
     // Combine rows and columns so we can sort on them both
     wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_row_col, nnz,
-                     (d_nz_triplet_count, d_keys.Current(), tpl_rows, tpl_columns,
+                     (p_nz_triplet_count, d_keys.Current(), tpl_rows, tpl_columns,
                       d_values.Current()));
 
     if (tpl_values) {
@@ -206,27 +187,31 @@ int bsr_matrix_from_triplets_device(const int rows_per_block,
         cudaEventSynchronize(bsr_temp.host_sync_event);
     }
 
-    const int nz_triplet_count = *pinned_count;
+    const int nz_triplet_count = *p_nz_triplet_count;
 
     // Sort
-    size_t buff_size = 0;
-    check_cuda(cub::DeviceRadixSort::SortPairs(
-        nullptr, buff_size, d_values, d_keys, nz_triplet_count, 0, 64, stream));
-    cub_temp.ensure_fits(buff_size);
-    check_cuda(cub::DeviceRadixSort::SortPairs(cub_temp.buffer, buff_size,
-                                               d_values, d_keys, nz_triplet_count,
-                                               0, 64, stream));
+    {
+        size_t buff_size = 0;
+        check_cuda(cub::DeviceRadixSort::SortPairs(
+            nullptr, buff_size, d_values, d_keys, nz_triplet_count, 0, 64, stream));
+        ScopedTemporary<> temp(context, buff_size);
+        check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size,
+                                                   d_values, d_keys, nz_triplet_count,
+                                                   0, 64, stream));
+    }
 
     // Runlength encode row-col sequences
-    check_cuda(cub::DeviceRunLengthEncode::Encode(
-        nullptr, buff_size, d_values.Current(), d_values.Alternate(),
-        d_keys.Alternate(), d_nz_triplet_count, nz_triplet_count, stream));
-    cub_temp.ensure_fits(buff_size);
-    check_cuda(cub::DeviceRunLengthEncode::Encode(
-        cub_temp.buffer, buff_size, d_values.Current(), d_values.Alternate(),
-        d_keys.Alternate(), d_nz_triplet_count, nz_triplet_count, stream));
-
-    memcpy_d2h(WP_CURRENT_CONTEXT, pinned_count, d_nz_triplet_count, sizeof(int));
+    {
+        size_t buff_size = 0;
+        check_cuda(cub::DeviceRunLengthEncode::Encode(
+            nullptr, buff_size, d_values.Current(), d_values.Alternate(),
+            d_keys.Alternate(), p_nz_triplet_count, nz_triplet_count, stream));
+        ScopedTemporary<> temp(context, buff_size);
+        check_cuda(cub::DeviceRunLengthEncode::Encode(
+            temp.buffer(), buff_size, d_values.Current(), d_values.Alternate(),
+            d_keys.Alternate(), p_nz_triplet_count, nz_triplet_count, stream));
+    }
+
     cudaEventRecord(bsr_temp.host_sync_event, stream);
 
     // Now we have the following:
@@ -236,13 +221,16 @@ int bsr_matrix_from_triplets_device(const int rows_per_block,
     // d_keys.Alternate(): repeated block-row count
 
     // Scan repeated block counts
-    check_cuda(cub::DeviceScan::InclusiveSum(
-        nullptr, buff_size, d_keys.Alternate(), d_keys.Alternate(),
-        nz_triplet_count, stream));
-    cub_temp.ensure_fits(buff_size);
-    check_cuda(cub::DeviceScan::InclusiveSum(
-        cub_temp.buffer, buff_size, d_keys.Alternate(), d_keys.Alternate(),
-        nz_triplet_count, stream));
+    {
+        size_t buff_size = 0;
+        check_cuda(cub::DeviceScan::InclusiveSum(
+            nullptr, buff_size, d_keys.Alternate(), d_keys.Alternate(),
+            nz_triplet_count, stream));
+        ScopedTemporary<> temp(context, buff_size);
+        check_cuda(cub::DeviceScan::InclusiveSum(
+            temp.buffer(), buff_size, d_keys.Alternate(), d_keys.Alternate(),
+            nz_triplet_count, stream));
+    }
 
     // While we're at it, zero the bsr offsets buffer
     memset_device(WP_CURRENT_CONTEXT, bsr_offsets, 0,
@@ -250,7 +238,7 @@ int bsr_matrix_from_triplets_device(const int rows_per_block,
 
     // Wait for number of compressed blocks
     cudaEventSynchronize(bsr_temp.host_sync_event);
-    const int compressed_nnz = *pinned_count;
+    const int compressed_nnz = *p_nz_triplet_count;
 
     // We have all we need to accumulate our repeated blocks
     wp_launch_device(WP_CURRENT_CONTEXT, bsr_merge_blocks, compressed_nnz,
@@ -259,12 +247,15 @@ int bsr_matrix_from_triplets_device(const int rows_per_block,
                      bsr_offsets, bsr_columns, bsr_values));
 
     // Last, prefix sum the row block counts
-    check_cuda(cub::DeviceScan::InclusiveSum(nullptr, buff_size, bsr_offsets,
-                                             bsr_offsets, row_count + 1, stream));
-    cub_temp.ensure_fits(buff_size);
-    check_cuda(cub::DeviceScan::InclusiveSum(cub_temp.buffer, buff_size,
-                                             bsr_offsets, bsr_offsets,
-                                             row_count + 1, stream));
+    {
+        size_t buff_size = 0;
+        check_cuda(cub::DeviceScan::InclusiveSum(nullptr, buff_size, bsr_offsets,
+                                                 bsr_offsets, row_count + 1, stream));
+        ScopedTemporary<> temp(context, buff_size);
+        check_cuda(cub::DeviceScan::InclusiveSum(temp.buffer(), buff_size,
+                                                 bsr_offsets, bsr_offsets,
+                                                 row_count + 1, stream));
+    }
 
     return compressed_nnz;
 }
@@ -347,118 +338,75 @@ bsr_transpose_blocks(const int nnz, const int block_size,
 }
 
 template <typename T>
-void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
-                          int col_count, int nnz, const int *bsr_offsets,
-                          const int *bsr_columns, const T *bsr_values,
-                          int *transposed_bsr_offsets,
-                          int *transposed_bsr_columns,
-                          T *transposed_bsr_values) {
-
-    const int block_size = rows_per_block * cols_per_block;
-
-    void *context = cuda_context_get_current();
-
-    // Per-context cached temporary buffers
-    TemporaryBuffer &cub_temp = g_temp_buffer_map[context];
-    BsrFromTripletsTemp &bsr_temp = g_bsr_from_triplets_temp_map[context];
-
-    ContextGuard guard(context);
-
-    cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
-    bsr_temp.ensure_fits(nnz);
-
-    // Zero the transposed offsets
-    memset_device(WP_CURRENT_CONTEXT, transposed_bsr_offsets, 0,
-                  (col_count + 1) * sizeof(int));
-
-    cub::DoubleBuffer<int> d_keys(bsr_temp.block_indices,
-                                  bsr_temp.block_indices + nnz);
-    cub::DoubleBuffer<BsrRowCol> d_values(bsr_temp.combined_row_col,
-                                          bsr_temp.combined_row_col + nnz);
-
-    wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_fill_row_col, nnz,
-                     (nnz, row_count, bsr_offsets, bsr_columns, d_keys.Current(),
-                      d_values.Current(), transposed_bsr_offsets));
+void
+launch_bsr_transpose_blocks(const int nnz, const int block_size,
+                            const int rows_per_block, const int cols_per_block,
+                            const int *block_indices,
+                            const BsrRowCol *transposed_indices,
+                            const T *bsr_values,
+                            int *transposed_bsr_columns, T *transposed_bsr_values) {
 
-    // Sort blocks
-    size_t buff_size = 0;
-    check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values,
-                                               d_keys, nnz, 0, 64, stream));
-    cub_temp.ensure_fits(buff_size);
-    check_cuda(cub::DeviceRadixSort::SortPairs(
-        cub_temp.buffer, buff_size, d_values, d_keys, nnz, 0, 64, stream));
-
-    // Prefix sum the trasnposed row block counts
-    check_cuda(cub::DeviceScan::InclusiveSum(
-        nullptr, buff_size, transposed_bsr_offsets, transposed_bsr_offsets,
-        col_count + 1, stream));
-    cub_temp.ensure_fits(buff_size);
-    check_cuda(cub::DeviceScan::InclusiveSum(
-        cub_temp.buffer, buff_size, transposed_bsr_offsets,
-        transposed_bsr_offsets, col_count + 1, stream));
-
-    // Move and transpose invidual blocks
-    switch (row_count) {
+    switch (rows_per_block) {
     case 1:
-        switch (col_count) {
+        switch (cols_per_block) {
         case 1:
             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                              (nnz, block_size, BsrBlockTransposer<1, 1, T>{},
-                              d_keys.Current(), d_values.Current(), bsr_values,
+                              block_indices, transposed_indices, bsr_values,
                               transposed_bsr_columns, transposed_bsr_values));
             return;
         case 2:
             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                              (nnz, block_size, BsrBlockTransposer<1, 2, T>{},
-                              d_keys.Current(), d_values.Current(), bsr_values,
+                              block_indices, transposed_indices, bsr_values,
                               transposed_bsr_columns, transposed_bsr_values));
             return;
         case 3:
             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                              (nnz, block_size, BsrBlockTransposer<1, 3, T>{},
-                              d_keys.Current(), d_values.Current(), bsr_values,
+                              block_indices, transposed_indices, bsr_values,
                               transposed_bsr_columns, transposed_bsr_values));
             return;
         }
     case 2:
-        switch (col_count) {
+        switch (cols_per_block) {
         case 1:
             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                              (nnz, block_size, BsrBlockTransposer<2, 1, T>{},
-                              d_keys.Current(), d_values.Current(), bsr_values,
+                              block_indices, transposed_indices, bsr_values,
                               transposed_bsr_columns, transposed_bsr_values));
             return;
         case 2:
             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                              (nnz, block_size, BsrBlockTransposer<2, 2, T>{},
-                              d_keys.Current(), d_values.Current(), bsr_values,
+                              block_indices, transposed_indices, bsr_values,
                               transposed_bsr_columns, transposed_bsr_values));
             return;
        case 3:
             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                              (nnz, block_size, BsrBlockTransposer<2, 3, T>{},
-                              d_keys.Current(), d_values.Current(), bsr_values,
+                              block_indices, transposed_indices, bsr_values,
                               transposed_bsr_columns, transposed_bsr_values));
             return;
         }
     case 3:
-        switch (col_count) {
+        switch (cols_per_block) {
         case 1:
             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                              (nnz, block_size, BsrBlockTransposer<3, 1, T>{},
-                              d_keys.Current(), d_values.Current(), bsr_values,
+                              block_indices, transposed_indices, bsr_values,
                               transposed_bsr_columns, transposed_bsr_values));
             return;
         case 2:
             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                              (nnz, block_size, BsrBlockTransposer<3, 2, T>{},
-                              d_keys.Current(), d_values.Current(), bsr_values,
+                              block_indices, transposed_indices, bsr_values,
                               transposed_bsr_columns, transposed_bsr_values));
             return;
         case 3:
             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                              (nnz, block_size, BsrBlockTransposer<3, 3, T>{},
-                              d_keys.Current(), d_values.Current(), bsr_values,
+                              block_indices, transposed_indices, bsr_values,
                               transposed_bsr_columns, transposed_bsr_values));
             return;
         }
@@ -468,10 +416,72 @@ void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
         WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
         (nnz, block_size,
          BsrBlockTransposer<-1, -1, T>{rows_per_block, cols_per_block},
-         d_keys.Current(), d_values.Current(), bsr_values, transposed_bsr_columns,
+         block_indices, transposed_indices, bsr_values, transposed_bsr_columns,
          transposed_bsr_values));
 }
 
+template <typename T>
+void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
+                          int col_count, int nnz, const int *bsr_offsets,
+                          const int *bsr_columns, const T *bsr_values,
+                          int *transposed_bsr_offsets,
+                          int *transposed_bsr_columns,
+                          T *transposed_bsr_values) {
+
+    const int block_size = rows_per_block * cols_per_block;
+
+    void *context = cuda_context_get_current();
+    ContextGuard guard(context);
+
+    cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
+
+    // Zero the transposed offsets
+    memset_device(WP_CURRENT_CONTEXT, transposed_bsr_offsets, 0,
+                  (col_count + 1) * sizeof(int));
+
+    ScopedTemporary<int> block_indices(context, 2*nnz);
+    ScopedTemporary<BsrRowCol> combined_row_col(context, 2*nnz);
+
+    cub::DoubleBuffer<int> d_keys(block_indices.buffer(),
+                                  block_indices.buffer() + nnz);
+    cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(),
+                                          combined_row_col.buffer() + nnz);
+
+    wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_fill_row_col, nnz,
+                     (nnz, row_count, bsr_offsets, bsr_columns, d_keys.Current(),
+                      d_values.Current(), transposed_bsr_offsets));
+
+    // Sort blocks
+    {
+        size_t buff_size = 0;
+        check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values,
+                                                   d_keys, nnz, 0, 64, stream));
+        void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
+        ScopedTemporary<> temp(context, buff_size);
+        check_cuda(cub::DeviceRadixSort::SortPairs(
+            temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
+    }
+
+    // Prefix sum the transposed row block counts
+    {
+        size_t buff_size = 0;
+        check_cuda(cub::DeviceScan::InclusiveSum(
+            nullptr, buff_size, transposed_bsr_offsets, transposed_bsr_offsets,
+            col_count + 1, stream));
+        ScopedTemporary<> temp(context, buff_size);
+        check_cuda(cub::DeviceScan::InclusiveSum(
+            temp.buffer(), buff_size, transposed_bsr_offsets,
+            transposed_bsr_offsets, col_count + 1, stream));
+    }
+
+    // Move and transpose individual blocks
+    launch_bsr_transpose_blocks(
+        nnz, block_size,
+        rows_per_block, cols_per_block,
+        d_keys.Current(), d_values.Current(), bsr_values, transposed_bsr_columns,
+        transposed_bsr_values);
+}
+
 } // namespace
 
 int bsr_matrix_from_triplets_float_device(
@@ -532,4 +542,4 @@ void bsr_transpose_double_device(int rows_per_block, int cols_per_block,
         reinterpret_cast<int *>(transposed_bsr_offsets),
         reinterpret_cast<int *>(transposed_bsr_columns),
         reinterpret_cast<double *>(transposed_bsr_values));
-}
+}
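
Note: every CUB primitive used above (DeviceSelect, DeviceRadixSort, DeviceRunLengthEncode, DeviceScan) follows the same two-pass protocol: a first call with a null temp-storage pointer only reports the scratch size needed, and a second call does the actual work. Where 0.10.1 served that scratch from a grow-only per-context TemporaryBuffer cache, 0.11.0 wraps each pair of calls in its own block scope holding a ScopedTemporary. A minimal standalone sketch of the idiom follows, with a plain cudaMalloc/cudaFree RAII buffer standing in for warp's internal ScopedTemporary (which is backed by alloc_temp_device/free_temp_device):

// Standalone sketch (not warp code): the CUB size-query/allocate/run idiom.
#include <cub/device/device_scan.cuh>
#include <cuda_runtime.h>

template <typename T = char>
struct ScopedDeviceBuffer {
    explicit ScopedDeviceBuffer(size_t count) {
        cudaMalloc(reinterpret_cast<void**>(&m_buffer), count * sizeof(T));
    }
    ~ScopedDeviceBuffer() { cudaFree(m_buffer); }
    ScopedDeviceBuffer(const ScopedDeviceBuffer&) = delete;
    ScopedDeviceBuffer& operator=(const ScopedDeviceBuffer&) = delete;
    T* buffer() const { return m_buffer; }
private:
    T* m_buffer = nullptr;
};

// In-place inclusive prefix sum over n ints, mirroring the bsr_offsets
// scan at the end of bsr_matrix_from_triplets_device.
void inclusive_sum(int* d_data, int n, cudaStream_t stream)
{
    // Pass 1: null temp storage -- CUB only writes the byte count it needs.
    size_t buff_size = 0;
    cub::DeviceScan::InclusiveSum(nullptr, buff_size, d_data, d_data, n, stream);

    // Pass 2: run for real; the scratch is released at the closing brace.
    ScopedDeviceBuffer<> temp(buff_size);
    cub::DeviceScan::InclusiveSum(temp.buffer(), buff_size, d_data, d_data, n, stream);
}

This scoping is also why each CUB call in the new code re-declares its own buff_size: the temporary now lives for exactly one pair of calls, instead of an unbounded per-context cache that was never trimmed.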
warp/native/spatial.h CHANGED
@@ -265,13 +265,13 @@ inline CUDA_CALLABLE Type tensordot(const transform_t<Type>& a, const transform_
 }
 
 template<typename Type>
-inline CUDA_CALLABLE Type index(const transform_t<Type>& t, int i)
+inline CUDA_CALLABLE Type extract(const transform_t<Type>& t, int i)
 {
     return t[i];
 }
 
 template<typename Type>
-inline void CUDA_CALLABLE adj_index(const transform_t<Type>& t, int i, transform_t<Type>& adj_t, int& adj_i, Type adj_ret)
+inline void CUDA_CALLABLE adj_extract(const transform_t<Type>& t, int i, transform_t<Type>& adj_t, int& adj_i, Type adj_ret)
 {
     adj_t[i] += adj_ret;
 }
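
Note: the index → extract rename follows warp's adjoint naming convention, in which the reverse-mode counterpart of a builtin is found by prefixing adj_; it presumably aligns transform element reads with the extract builtin used for other vector types. The body shows the rule being implemented: extraction y = t[i] is linear in t, so the incoming adjoint accumulates into the selected component and the integer index receives no gradient. A minimal host-side illustration of the pair (the fixed-size Xform here is a hypothetical stand-in, not warp's transform_t):

#include <cstdio>

// Hypothetical stand-in for warp's transform_t (not the real type).
struct Xform {
    float v[7] = {};
    float& operator[](int i) { return v[i]; }
    float operator[](int i) const { return v[i]; }
};

// Forward: read one component, as in the extract() overload above.
float extract(const Xform& t, int i) { return t[i]; }

// Reverse: the incoming adjoint adj_ret accumulates into component i;
// the index itself carries no gradient.
void adj_extract(const Xform& t, int i, Xform& adj_t, int& adj_i, float adj_ret)
{
    (void)t;
    (void)adj_i;
    adj_t[i] += adj_ret;
}

int main()
{
    Xform t;
    t[2] = 5.0f;

    Xform adj_t;
    int adj_i = 0;

    float y = extract(t, 2);               // forward pass
    adj_extract(t, 2, adj_t, adj_i, 1.0f); // backward pass with dL/dy = 1

    printf("y = %g, adj_t[2] = %g\n", y, adj_t[2]); // y = 5, adj_t[2] = 1
    return 0;
}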
warp/native/temp_buffer.h CHANGED
@@ -1,46 +1,30 @@
 
 #pragma once
 
-#include "warp.h"
 #include "cuda_util.h"
+#include "warp.h"
 
 #include <unordered_map>
 
-// temporary buffer, useful for cub algorithms
-struct TemporaryBuffer
+template <typename T = char> struct ScopedTemporary
 {
-    void *buffer = NULL;
-    size_t buffer_size = 0;
 
-    void ensure_fits(size_t size)
+    ScopedTemporary(void *context, size_t size)
+        : m_context(context), m_buffer(static_cast<T*>(alloc_temp_device(m_context, size * sizeof(T))))
     {
-        if (size > buffer_size)
-        {
-            size = std::max(2 * size, (buffer_size * 3) / 2);
-
-            free_device(WP_CURRENT_CONTEXT, buffer);
-            buffer = alloc_device(WP_CURRENT_CONTEXT, size);
-            buffer_size = size;
-        }
     }
-};
 
-struct PinnedTemporaryBuffer
-{
-    void *buffer = NULL;
-    size_t buffer_size = 0;
+    ~ScopedTemporary()
+    {
+        free_temp_device(m_context, m_buffer);
+    }
 
-    void ensure_fits(size_t size)
+    T *buffer() const
     {
-        if (size > buffer_size)
-        {
-            free_pinned(buffer);
-            buffer = alloc_pinned(size);
-            buffer_size = size;
-        }
+        return m_buffer;
     }
-};
 
-// map temp buffers to CUDA contexts
-static std::unordered_map<void *, TemporaryBuffer> g_temp_buffer_map;
-static std::unordered_map<void *, PinnedTemporaryBuffer> g_pinned_temp_buffer_map;
+private:
+    void *m_context;
+    T *m_buffer;
+};
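
Note: the net effect of this header change is that the grow-only TemporaryBuffer / PinnedTemporaryBuffer caches (and the file-static per-context maps that every includer of this header duplicated) are replaced by an RAII wrapper whose lifetime is a single C++ scope, with buffer reuse delegated to warp's alloc_temp_device / free_temp_device allocator. A sketch of the resulting call-site shape, with the two warp-internal allocator functions stubbed out so it compiles standalone:

// Sketch only: host malloc/free stand in for warp's temp-device allocator.
#include <cstdio>
#include <cstdlib>

static void *alloc_temp_device(void * /*context*/, size_t size) { return std::malloc(size); }
static void free_temp_device(void * /*context*/, void *ptr) { std::free(ptr); }

template <typename T = char>
struct ScopedTemporary {
    ScopedTemporary(void *context, size_t size)
        : m_context(context),
          m_buffer(static_cast<T*>(alloc_temp_device(m_context, size * sizeof(T)))) {}
    ~ScopedTemporary() { free_temp_device(m_context, m_buffer); }
    T *buffer() const { return m_buffer; }
private:
    void *m_context;
    T *m_buffer;
};

int main()
{
    void *context = nullptr; // stands in for a CUDA context pointer

    {
        // The scratch lives exactly as long as this block, even on early
        // return -- the old ensure_fits() cache could only grow and was
        // never released until process exit.
        ScopedTemporary<int> indices(context, 16);
        indices.buffer()[0] = 42;
        printf("%d\n", indices.buffer()[0]);
    } // freed here by ~ScopedTemporary

    return 0;
}

Cross-call reuse, which the old cache provided implicitly, now lives behind warp's temp-device allocator, so the per-context static maps could be dropped from the header entirely.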