warp-lang 0.9.0-py3-none-win_amd64.whl → 0.11.0-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/sparse.cu ADDED
@@ -0,0 +1,545 @@
+ #include "cuda_util.h"
+ #include "warp.h"
+
+ #define THRUST_IGNORE_CUB_VERSION_CHECK
+
+ #include <cub/device/device_radix_sort.cuh>
+ #include <cub/device/device_run_length_encode.cuh>
+ #include <cub/device/device_scan.cuh>
+ #include <cub/device/device_select.cuh>
+
+ namespace {
+
+ // Combined row+column value that can be radix-sorted with CUB
+ using BsrRowCol = uint64_t;
+
+ CUDA_CALLABLE BsrRowCol bsr_combine_row_col(uint32_t row, uint32_t col) {
+     return (static_cast<uint64_t>(row) << 32) | col;
+ }
+
+ CUDA_CALLABLE uint32_t bsr_get_row(const BsrRowCol &row_col) {
+     return row_col >> 32;
+ }
+
+ CUDA_CALLABLE uint32_t bsr_get_col(const BsrRowCol &row_col) {
+     return row_col & INT_MAX;
+ }
+
+ // Cached temporary storage
+ struct BsrFromTripletsTemp {
+
+     int *count_buffer = NULL;
+     cudaEvent_t host_sync_event = NULL;
+
+     BsrFromTripletsTemp()
+         : count_buffer(static_cast<int*>(alloc_pinned(sizeof(int))))
+     {
+         cudaEventCreateWithFlags(&host_sync_event, cudaEventDisableTiming);
+     }
+
+     ~BsrFromTripletsTemp()
+     {
+         cudaEventDestroy(host_sync_event);
+         free_pinned(count_buffer);
+     }
+
+     BsrFromTripletsTemp(const BsrFromTripletsTemp&) = delete;
+     BsrFromTripletsTemp& operator=(const BsrFromTripletsTemp&) = delete;
+
+ };
+
+ // map temp buffers to CUDA contexts
+ static std::unordered_map<void *, BsrFromTripletsTemp> g_bsr_from_triplets_temp_map;
+
+ template <typename T> struct BsrBlockIsNotZero {
+     int block_size;
+     const T *values;
+
+     CUDA_CALLABLE_DEVICE bool operator()(int i) const {
+         const T *val = values + i * block_size;
+         for (int i = 0; i < block_size; ++i, ++val) {
+             if (*val != T(0))
+                 return true;
+         }
+         return false;
+     }
+ };
+
+ __global__ void bsr_fill_block_indices(int nnz, int *block_indices) {
+     int i = blockIdx.x * blockDim.x + threadIdx.x;
+     if (i >= nnz)
+         return;
+
+     block_indices[i] = i;
+ }
+
+ __global__ void bsr_fill_row_col(const int *nnz, const int *block_indices,
+                                  const int *tpl_rows, const int *tpl_columns,
+                                  BsrRowCol *tpl_row_col) {
+     int i = blockIdx.x * blockDim.x + threadIdx.x;
+     if (i >= *nnz)
+         return;
+
+     const int block = block_indices[i];
+
+     BsrRowCol row_col = bsr_combine_row_col(tpl_rows[block], tpl_columns[block]);
+     tpl_row_col[i] = row_col;
+ }
+
+ template <typename T>
+ __global__ void
+ bsr_merge_blocks(int nnz, int block_size, const int *block_offsets,
+                  const int *sorted_block_indices,
+                  const BsrRowCol *unique_row_cols, const T *tpl_values,
+                  int *bsr_row_counts, int *bsr_cols, T *bsr_values)
+
+ {
+     int i = blockIdx.x * blockDim.x + threadIdx.x;
+     if (i >= nnz)
+         return;
+
+     const int beg = i ? block_offsets[i - 1] : 0;
+     const int end = block_offsets[i];
+
+     BsrRowCol row_col = unique_row_cols[i];
+
+     bsr_cols[i] = bsr_get_col(row_col);
+     atomicAdd(bsr_row_counts + bsr_get_row(row_col) + 1, 1);
+
+     if (bsr_values == nullptr)
+         return;
+
+     T *bsr_val = bsr_values + i * block_size;
+     const T *tpl_val = tpl_values + sorted_block_indices[beg] * block_size;
+
+     for (int k = 0; k < block_size; ++k) {
+         bsr_val[k] = tpl_val[k];
+     }
+
+     for (int cur = beg + 1; cur != end; ++cur) {
+         const T *tpl_val = tpl_values + sorted_block_indices[cur] * block_size;
+         for (int k = 0; k < block_size; ++k) {
+             bsr_val[k] += tpl_val[k];
+         }
+     }
+ }
+
+ template <typename T>
+ int bsr_matrix_from_triplets_device(const int rows_per_block,
+                                     const int cols_per_block,
+                                     const int row_count, const int nnz,
+                                     const int *tpl_rows, const int *tpl_columns,
+                                     const T *tpl_values, int *bsr_offsets,
+                                     int *bsr_columns, T *bsr_values) {
+     const int block_size = rows_per_block * cols_per_block;
+
+     void *context = cuda_context_get_current();
+     ContextGuard guard(context);
+
+     // Per-context cached temporary buffers
+     BsrFromTripletsTemp &bsr_temp = g_bsr_from_triplets_temp_map[context];
+
+     cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
+
+     ScopedTemporary<int> block_indices(context, 2*nnz);
+     ScopedTemporary<BsrRowCol> combined_row_col(context, 2*nnz);
+
+     cub::DoubleBuffer<int> d_keys(block_indices.buffer(),
+                                   block_indices.buffer() + nnz);
+     cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(),
+                                           combined_row_col.buffer() + nnz);
+
+     int *p_nz_triplet_count = bsr_temp.count_buffer;
+
+     wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_block_indices, nnz,
+                      (nnz, d_keys.Current()));
+
+     if (tpl_values) {
+
+         // Remove zero blocks
+         {
+             size_t buff_size = 0;
+             BsrBlockIsNotZero<T> isNotZero{block_size, tpl_values};
+             check_cuda(cub::DeviceSelect::If(nullptr, buff_size, d_keys.Current(),
+                                              d_keys.Alternate(), p_nz_triplet_count,
+                                              nnz, isNotZero, stream));
+             ScopedTemporary<> temp(context, buff_size);
+             check_cuda(cub::DeviceSelect::If(
+                 temp.buffer(), buff_size, d_keys.Current(), d_keys.Alternate(),
+                 p_nz_triplet_count, nnz, isNotZero, stream));
+         }
+         cudaEventRecord(bsr_temp.host_sync_event, stream);
+
+         // switch current/alternate in double buffer
+         d_keys.selector ^= 1;
+
+     } else {
+         *p_nz_triplet_count = nnz;
+     }
+
+     // Combine rows and columns so we can sort on them both
+     wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_row_col, nnz,
+                      (p_nz_triplet_count, d_keys.Current(), tpl_rows, tpl_columns,
+                       d_values.Current()));
+
+     if (tpl_values) {
+         // Make sure count is available on host
+         cudaEventSynchronize(bsr_temp.host_sync_event);
+     }
+
+     const int nz_triplet_count = *p_nz_triplet_count;
+
+     // Sort
+     {
+         size_t buff_size = 0;
+         check_cuda(cub::DeviceRadixSort::SortPairs(
+             nullptr, buff_size, d_values, d_keys, nz_triplet_count, 0, 64, stream));
+         ScopedTemporary<> temp(context, buff_size);
+         check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size,
+                                                    d_values, d_keys, nz_triplet_count,
+                                                    0, 64, stream));
+     }
+
+     // Runlength encode row-col sequences
+     {
+         size_t buff_size = 0;
+         check_cuda(cub::DeviceRunLengthEncode::Encode(
+             nullptr, buff_size, d_values.Current(), d_values.Alternate(),
+             d_keys.Alternate(), p_nz_triplet_count, nz_triplet_count, stream));
+         ScopedTemporary<> temp(context, buff_size);
+         check_cuda(cub::DeviceRunLengthEncode::Encode(
+             temp.buffer(), buff_size, d_values.Current(), d_values.Alternate(),
+             d_keys.Alternate(), p_nz_triplet_count, nz_triplet_count, stream));
+     }
+
+     cudaEventRecord(bsr_temp.host_sync_event, stream);
+
+     // Now we have the following:
+     // d_values.Current(): sorted block row-col
+     // d_values.Alternate(): sorted unique row-col
+     // d_keys.Current(): sorted block indices
+     // d_keys.Alternate(): repeated block-row count
+
+     // Scan repeated block counts
+     {
+         size_t buff_size = 0;
+         check_cuda(cub::DeviceScan::InclusiveSum(
+             nullptr, buff_size, d_keys.Alternate(), d_keys.Alternate(),
+             nz_triplet_count, stream));
+         ScopedTemporary<> temp(context, buff_size);
+         check_cuda(cub::DeviceScan::InclusiveSum(
+             temp.buffer(), buff_size, d_keys.Alternate(), d_keys.Alternate(),
+             nz_triplet_count, stream));
+     }
+
+     // While we're at it, zero the bsr offsets buffer
+     memset_device(WP_CURRENT_CONTEXT, bsr_offsets, 0,
+                   (row_count + 1) * sizeof(int));
+
+     // Wait for number of compressed blocks
+     cudaEventSynchronize(bsr_temp.host_sync_event);
+     const int compressed_nnz = *p_nz_triplet_count;
+
+     // We have all we need to accumulate our repeated blocks
+     wp_launch_device(WP_CURRENT_CONTEXT, bsr_merge_blocks, compressed_nnz,
+                      (compressed_nnz, block_size, d_keys.Alternate(),
+                       d_keys.Current(), d_values.Alternate(), tpl_values,
+                       bsr_offsets, bsr_columns, bsr_values));
+
+     // Last, prefix sum the row block counts
+     {
+         size_t buff_size = 0;
+         check_cuda(cub::DeviceScan::InclusiveSum(nullptr, buff_size, bsr_offsets,
+                                                  bsr_offsets, row_count + 1, stream));
+         ScopedTemporary<> temp(context, buff_size);
+         check_cuda(cub::DeviceScan::InclusiveSum(temp.buffer(), buff_size,
+                                                  bsr_offsets, bsr_offsets,
+                                                  row_count + 1, stream));
+     }
+
+     return compressed_nnz;
+ }
+
+ __global__ void bsr_transpose_fill_row_col(const int nnz, const int row_count,
+                                            const int *bsr_offsets,
+                                            const int *bsr_columns,
+                                            int *block_indices,
+                                            BsrRowCol *transposed_row_col,
+                                            int *transposed_bsr_offsets) {
+     int i = blockIdx.x * blockDim.x + threadIdx.x;
+     if (i >= nnz)
+         return;
+
+     block_indices[i] = i;
+
+     // Binary search for row
+     int lower = 0;
+     int upper = row_count - 1;
+
+     while (lower < upper) {
+         int mid = lower + (upper - lower) / 2;
+
+         if (bsr_offsets[mid + 1] <= i) {
+             lower = mid + 1;
+         } else {
+             upper = mid;
+         }
+     }
+
+     const int row = lower;
+     const int col = bsr_columns[i];
+     BsrRowCol row_col = bsr_combine_row_col(col, row);
+     transposed_row_col[i] = row_col;
+
+     atomicAdd(transposed_bsr_offsets + col + 1, 1);
+ }
+
+ template <int Rows, int Cols, typename T> struct BsrBlockTransposer {
+     void CUDA_CALLABLE_DEVICE operator()(const T *src, T *dest) const {
+         for (int r = 0; r < Rows; ++r) {
+             for (int c = 0; c < Cols; ++c) {
+                 dest[c * Rows + r] = src[r * Cols + c];
+             }
+         }
+     }
+ };
+
+ template <typename T> struct BsrBlockTransposer<-1, -1, T> {
+
+     int row_count;
+     int col_count;
+
+     void CUDA_CALLABLE_DEVICE operator()(const T *src, T *dest) const {
+         for (int r = 0; r < row_count; ++r) {
+             for (int c = 0; c < col_count; ++c) {
+                 dest[c * row_count + r] = src[r * col_count + c];
+             }
+         }
+     }
+ };
+
+ template <int Rows, int Cols, typename T>
+ __global__ void
+ bsr_transpose_blocks(const int nnz, const int block_size,
+                      BsrBlockTransposer<Rows, Cols, T> transposer,
+                      const int *block_indices,
+                      const BsrRowCol *transposed_indices, const T *bsr_values,
+                      int *transposed_bsr_columns, T *transposed_bsr_values) {
+     int i = blockIdx.x * blockDim.x + threadIdx.x;
+     if (i >= nnz)
+         return;
+
+     const int src_idx = block_indices[i];
+
+     transposer(bsr_values + src_idx * block_size,
+                transposed_bsr_values + i * block_size);
+
+     transposed_bsr_columns[i] = bsr_get_col(transposed_indices[i]);
+ }
+
+ template <typename T>
+ void
+ launch_bsr_transpose_blocks(const int nnz, const int block_size,
+                             const int rows_per_block, const int cols_per_block,
+                             const int *block_indices,
+                             const BsrRowCol *transposed_indices,
+                             const T *bsr_values,
+                             int *transposed_bsr_columns, T *transposed_bsr_values) {
+
+     switch (rows_per_block) {
+     case 1:
+         switch (cols_per_block) {
+         case 1:
+             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+                              (nnz, block_size, BsrBlockTransposer<1, 1, T>{},
+                               block_indices, transposed_indices, bsr_values,
+                               transposed_bsr_columns, transposed_bsr_values));
+             return;
+         case 2:
+             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+                              (nnz, block_size, BsrBlockTransposer<1, 2, T>{},
+                               block_indices, transposed_indices, bsr_values,
+                               transposed_bsr_columns, transposed_bsr_values));
+             return;
+         case 3:
+             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+                              (nnz, block_size, BsrBlockTransposer<1, 3, T>{},
+                               block_indices, transposed_indices, bsr_values,
+                               transposed_bsr_columns, transposed_bsr_values));
+             return;
+         }
+     case 2:
+         switch (cols_per_block) {
+         case 1:
+             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+                              (nnz, block_size, BsrBlockTransposer<2, 1, T>{},
+                               block_indices, transposed_indices, bsr_values,
+                               transposed_bsr_columns, transposed_bsr_values));
+             return;
+         case 2:
+             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+                              (nnz, block_size, BsrBlockTransposer<2, 2, T>{},
+                               block_indices, transposed_indices, bsr_values,
+                               transposed_bsr_columns, transposed_bsr_values));
+             return;
+         case 3:
+             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+                              (nnz, block_size, BsrBlockTransposer<2, 3, T>{},
+                               block_indices, transposed_indices, bsr_values,
+                               transposed_bsr_columns, transposed_bsr_values));
+             return;
+         }
+     case 3:
+         switch (cols_per_block) {
+         case 1:
+             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+                              (nnz, block_size, BsrBlockTransposer<3, 1, T>{},
+                               block_indices, transposed_indices, bsr_values,
+                               transposed_bsr_columns, transposed_bsr_values));
+             return;
+         case 2:
+             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+                              (nnz, block_size, BsrBlockTransposer<3, 2, T>{},
+                               block_indices, transposed_indices, bsr_values,
+                               transposed_bsr_columns, transposed_bsr_values));
+             return;
+         case 3:
+             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+                              (nnz, block_size, BsrBlockTransposer<3, 3, T>{},
+                               block_indices, transposed_indices, bsr_values,
+                               transposed_bsr_columns, transposed_bsr_values));
+             return;
+         }
+     }
+
+     wp_launch_device(
+         WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+         (nnz, block_size,
+          BsrBlockTransposer<-1, -1, T>{rows_per_block, cols_per_block},
+          block_indices, transposed_indices, bsr_values, transposed_bsr_columns,
+          transposed_bsr_values));
+ }
+
+ template <typename T>
+ void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
+                           int col_count, int nnz, const int *bsr_offsets,
+                           const int *bsr_columns, const T *bsr_values,
+                           int *transposed_bsr_offsets,
+                           int *transposed_bsr_columns,
+                           T *transposed_bsr_values) {
+
+     const int block_size = rows_per_block * cols_per_block;
+
+     void *context = cuda_context_get_current();
+     ContextGuard guard(context);
+
+     cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
+
+     // Zero the transposed offsets
+     memset_device(WP_CURRENT_CONTEXT, transposed_bsr_offsets, 0,
+                   (col_count + 1) * sizeof(int));
+
+     ScopedTemporary<int> block_indices(context, 2*nnz);
+     ScopedTemporary<BsrRowCol> combined_row_col(context, 2*nnz);
+
+     cub::DoubleBuffer<int> d_keys(block_indices.buffer(),
+                                   block_indices.buffer() + nnz);
+     cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(),
+                                           combined_row_col.buffer() + nnz);
+
+     wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_fill_row_col, nnz,
+                      (nnz, row_count, bsr_offsets, bsr_columns, d_keys.Current(),
+                       d_values.Current(), transposed_bsr_offsets));
+
+     // Sort blocks
+     {
+         size_t buff_size = 0;
+         check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values,
+                                                    d_keys, nnz, 0, 64, stream));
+         void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
+         ScopedTemporary<> temp(context, buff_size);
+         check_cuda(cub::DeviceRadixSort::SortPairs(
+             temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
+     }
+
+     // Prefix sum the transposed row block counts
+     {
+         size_t buff_size = 0;
+         check_cuda(cub::DeviceScan::InclusiveSum(
+             nullptr, buff_size, transposed_bsr_offsets, transposed_bsr_offsets,
+             col_count + 1, stream));
+         ScopedTemporary<> temp(context, buff_size);
+         check_cuda(cub::DeviceScan::InclusiveSum(
+             temp.buffer(), buff_size, transposed_bsr_offsets,
+             transposed_bsr_offsets, col_count + 1, stream));
+     }
+
+     // Move and transpose individual blocks
+     launch_bsr_transpose_blocks(
+         nnz, block_size,
+         rows_per_block, cols_per_block,
+         d_keys.Current(), d_values.Current(), bsr_values, transposed_bsr_columns,
+         transposed_bsr_values);
+ }
+
+ } // namespace
+
+ int bsr_matrix_from_triplets_float_device(
+     int rows_per_block, int cols_per_block, int row_count, int nnz,
+     uint64_t tpl_rows, uint64_t tpl_columns, uint64_t tpl_values,
+     uint64_t bsr_offsets, uint64_t bsr_columns, uint64_t bsr_values) {
+     return bsr_matrix_from_triplets_device<float>(
+         rows_per_block, cols_per_block, row_count, nnz,
+         reinterpret_cast<const int *>(tpl_rows),
+         reinterpret_cast<const int *>(tpl_columns),
+         reinterpret_cast<const float *>(tpl_values),
+         reinterpret_cast<int *>(bsr_offsets),
+         reinterpret_cast<int *>(bsr_columns),
+         reinterpret_cast<float *>(bsr_values));
+ }
+
+ int bsr_matrix_from_triplets_double_device(
+     int rows_per_block, int cols_per_block, int row_count, int nnz,
+     uint64_t tpl_rows, uint64_t tpl_columns, uint64_t tpl_values,
+     uint64_t bsr_offsets, uint64_t bsr_columns, uint64_t bsr_values) {
+     return bsr_matrix_from_triplets_device<double>(
+         rows_per_block, cols_per_block, row_count, nnz,
+         reinterpret_cast<const int *>(tpl_rows),
+         reinterpret_cast<const int *>(tpl_columns),
+         reinterpret_cast<const double *>(tpl_values),
+         reinterpret_cast<int *>(bsr_offsets),
+         reinterpret_cast<int *>(bsr_columns),
+         reinterpret_cast<double *>(bsr_values));
+ }
+
+ void bsr_transpose_float_device(int rows_per_block, int cols_per_block,
+                                 int row_count, int col_count, int nnz,
+                                 uint64_t bsr_offsets, uint64_t bsr_columns,
+                                 uint64_t bsr_values,
+                                 uint64_t transposed_bsr_offsets,
+                                 uint64_t transposed_bsr_columns,
+                                 uint64_t transposed_bsr_values) {
+     bsr_transpose_device(rows_per_block, cols_per_block, row_count, col_count,
+                          nnz, reinterpret_cast<const int *>(bsr_offsets),
+                          reinterpret_cast<const int *>(bsr_columns),
+                          reinterpret_cast<const float *>(bsr_values),
+                          reinterpret_cast<int *>(transposed_bsr_offsets),
+                          reinterpret_cast<int *>(transposed_bsr_columns),
+                          reinterpret_cast<float *>(transposed_bsr_values));
+ }
+
+ void bsr_transpose_double_device(int rows_per_block, int cols_per_block,
+                                  int row_count, int col_count, int nnz,
+                                  uint64_t bsr_offsets, uint64_t bsr_columns,
+                                  uint64_t bsr_values,
+                                  uint64_t transposed_bsr_offsets,
+                                  uint64_t transposed_bsr_columns,
+                                  uint64_t transposed_bsr_values) {
+     bsr_transpose_device(rows_per_block, cols_per_block, row_count, col_count,
+                          nnz, reinterpret_cast<const int *>(bsr_offsets),
+                          reinterpret_cast<const int *>(bsr_columns),
+                          reinterpret_cast<const double *>(bsr_values),
+                          reinterpret_cast<int *>(transposed_bsr_offsets),
+                          reinterpret_cast<int *>(transposed_bsr_columns),
+                          reinterpret_cast<double *>(transposed_bsr_values));
+ }
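
The pipeline above builds a BSR matrix from COO triplets in a handful of CUB passes: optionally select non-zero blocks, pack each (row, column) pair into one 64-bit key, radix-sort the keys, run-length encode duplicate keys (summing repeated blocks in bsr_merge_blocks), and prefix-sum the per-row counts into bsr_offsets. The whole sort hinges on the key packing done by bsr_combine_row_col. A minimal standalone host-side C++ sketch of that invariant follows; the names combine_row_col/get_row/get_col and the main harness are illustrative, not part of the package:

#include <cassert>
#include <cstdint>

// Pack (row, col) into one radix-sortable key, mirroring bsr_combine_row_col above
static uint64_t combine_row_col(uint32_t row, uint32_t col)
{
    return (static_cast<uint64_t>(row) << 32) | col;
}

static uint32_t get_row(uint64_t key) { return static_cast<uint32_t>(key >> 32); }
static uint32_t get_col(uint64_t key) { return static_cast<uint32_t>(key & 0xffffffffu); }

int main()
{
    // Sorting packed keys orders blocks by row first, then by column,
    // which is exactly the row-major order the BSR offsets expect.
    assert(combine_row_col(1, 3) < combine_row_col(1, 7));
    assert(combine_row_col(1, 7) < combine_row_col(2, 0));
    // Packing is lossless, so row/col can be recovered after the sort.
    assert(get_row(combine_row_col(5, 9)) == 5 && get_col(combine_row_col(5, 9)) == 9);
    return 0;
}

Duplicate triplets become equal keys after packing, which is what lets cub::DeviceRunLengthEncode::Encode count them and bsr_merge_blocks sum their values into a single block.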
warp/native/spatial.h CHANGED
@@ -265,13 +265,13 @@ inline CUDA_CALLABLE Type tensordot(const transform_t<Type>& a, const transform_
  }
 
  template<typename Type>
- inline CUDA_CALLABLE Type index(const transform_t<Type>& t, int i)
+ inline CUDA_CALLABLE Type extract(const transform_t<Type>& t, int i)
  {
      return t[i];
  }
 
  template<typename Type>
- inline void CUDA_CALLABLE adj_index(const transform_t<Type>& t, int i, transform_t<Type>& adj_t, int& adj_i, Type adj_ret)
+ inline void CUDA_CALLABLE adj_extract(const transform_t<Type>& t, int i, transform_t<Type>& adj_t, int& adj_i, Type adj_ret)
  {
      adj_t[i] += adj_ret;
  }
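
This change renames the transform component accessor from index to extract (and its adjoint accordingly), consistent with the accessor naming used elsewhere in this release's native headers. The adjoint rule itself is unchanged: the gradient of t[i] accumulates into component i of the adjoint transform. A CPU-only sketch of that forward/adjoint pair, using a plain 7-component array as a stand-in for transform_t<Type> (the transform7 type and main harness are assumptions for illustration):

#include <cassert>

// Stand-in for transform_t<float>: 3 position + 4 quaternion components
struct transform7
{
    float v[7] = {};
    float& operator[](int i) { return v[i]; }
};

float extract(transform7& t, int i) { return t[i]; }

// Reverse-mode rule mirrored from adj_extract above: the incoming
// gradient adj_ret flows straight into component i of the adjoint
void adj_extract(transform7& /*t*/, int i, transform7& adj_t, int& /*adj_i*/, float adj_ret)
{
    adj_t[i] += adj_ret;
}

int main()
{
    transform7 t, adj_t;
    int adj_i = 0;
    adj_extract(t, 2, adj_t, adj_i, 1.5f);
    adj_extract(t, 2, adj_t, adj_i, 0.5f); // gradients accumulate across uses
    assert(adj_t[2] == 2.0f);
    return 0;
}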
warp/native/temp_buffer.h ADDED
@@ -0,0 +1,30 @@
+
+ #pragma once
+
+ #include "cuda_util.h"
+ #include "warp.h"
+
+ #include <unordered_map>
+
+ template <typename T = char> struct ScopedTemporary
+ {
+
+     ScopedTemporary(void *context, size_t size)
+         : m_context(context), m_buffer(static_cast<T*>(alloc_temp_device(m_context, size * sizeof(T))))
+     {
+     }
+
+     ~ScopedTemporary()
+     {
+         free_temp_device(m_context, m_buffer);
+     }
+
+     T *buffer() const
+     {
+         return m_buffer;
+     }
+
+ private:
+     void *m_context;
+     T *m_buffer;
+ };
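
ScopedTemporary is the RAII helper sparse.cu leans on for every CUB scratch buffer: it allocates in the constructor and frees in the destructor, so each { ... } block in the code above releases its temporary on scope exit. A self-contained sketch of the same pattern, with std::malloc/std::free standing in for warp's alloc_temp_device/free_temp_device (the stand-in allocators and main harness are assumptions for illustration):

#include <cstdio>
#include <cstdlib>

// Hypothetical stand-ins for warp's device temp allocator
void *alloc_temp_device(void * /*context*/, size_t size) { return std::malloc(size); }
void free_temp_device(void * /*context*/, void *ptr) { std::free(ptr); }

template <typename T = char> struct ScopedTemporary
{
    ScopedTemporary(void *context, size_t size)
        : m_context(context),
          m_buffer(static_cast<T *>(alloc_temp_device(context, size * sizeof(T)))) {}
    ~ScopedTemporary() { free_temp_device(m_context, m_buffer); }
    T *buffer() const { return m_buffer; }

  private:
    void *m_context;
    T *m_buffer;
};

int main()
{
    void *context = nullptr;
    {
        ScopedTemporary<int> indices(context, 16); // sized in elements, not bytes
        indices.buffer()[0] = 42;
        std::printf("%d\n", indices.buffer()[0]);
    } // buffer freed here, even on early return or exception in real use
    return 0;
}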