warp-lang 0.10.1-py3-none-win_amd64.whl → 0.11.0-py3-none-win_amd64.whl
- warp/__init__.py +10 -4
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +868 -507
- warp/codegen.py +1074 -638
- warp/config.py +3 -3
- warp/constants.py +6 -0
- warp/context.py +715 -222
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +147 -44
- warp/native/builtin.h +122 -149
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +34 -43
- warp/native/clang/clang.cpp +13 -8
- warp/native/crt.h +2 -0
- warp/native/cuda_crt.h +5 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -952
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +1 -1
- warp/native/marching.cu +157 -161
- warp/native/mat.h +80 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -23
- warp/native/mesh.h +446 -46
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +1 -1
- warp/native/reduce.cu +10 -12
- warp/native/runlength_encode.cu +6 -10
- warp/native/scan.cu +8 -11
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +164 -154
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +14 -30
- warp/native/vec.h +107 -23
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +560 -30
- warp/native/warp.cu +431 -44
- warp/native/warp.h +13 -4
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +335 -119
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +8 -0
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +158 -16
- warp/sim/model.py +795 -291
- warp/sim/render.py +3 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +640 -150
- warp/stubs.py +606 -267
- warp/tape.py +61 -10
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +212 -97
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +42 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +208 -130
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +178 -109
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +32 -31
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +140 -22
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +9 -6
- warp/types.py +1089 -366
- warp/utils.py +93 -387
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/tests/test_all.py +0 -219
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.10.1.dist-info/METADATA +0 -21
- warp_lang-0.10.1.dist-info/RECORD +0 -188
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/sparse.cu
CHANGED
@@ -1,8 +1,6 @@
 #include "cuda_util.h"
 #include "warp.h"
 
-#include "temp_buffer.h"
-
 #define THRUST_IGNORE_CUB_VERSION_CHECK
 
 #include <cub/device/device_radix_sort.cuh>
@@ -29,40 +27,29 @@ CUDA_CALLABLE uint32_t bsr_get_col(const BsrRowCol &row_col) {
 
 // Cached temporary storage
 struct BsrFromTripletsTemp {
-
-    int
-    int *block_indices = NULL;
-
-    BsrRowCol *combined_row_col = NULL;
-
+
+    int *count_buffer = NULL;
     cudaEvent_t host_sync_event = NULL;
 
-
-
-
-
-
-
-
-
-
-
-
-            combined_row_col = static_cast<BsrRowCol *>(
-                alloc_device(WP_CURRENT_CONTEXT, 2 * size * sizeof(BsrRowCol)));
+    BsrFromTripletsTemp()
+        : count_buffer(static_cast<int*>(alloc_pinned(sizeof(int))))
+    {
+        cudaEventCreateWithFlags(&host_sync_event, cudaEventDisableTiming);
+    }
+
+    ~BsrFromTripletsTemp()
+    {
+        cudaEventDestroy(host_sync_event);
+        free_pinned(count_buffer);
+    }
 
-
-
+    BsrFromTripletsTemp(const BsrFromTripletsTemp&) = delete;
+    BsrFromTripletsTemp& operator=(const BsrFromTripletsTemp&) = delete;
 
-        if (host_sync_event == NULL) {
-            cudaEventCreateWithFlags(&host_sync_event, cudaEventDisableTiming);
-        }
-    }
 };
 
 // map temp buffers to CUDA contexts
-static std::unordered_map<void *, BsrFromTripletsTemp>
-    g_bsr_from_triplets_temp_map;
+static std::unordered_map<void *, BsrFromTripletsTemp> g_bsr_from_triplets_temp_map;
 
 template <typename T> struct BsrBlockIsNotZero {
     int block_size;
@@ -147,25 +134,22 @@ int bsr_matrix_from_triplets_device(const int rows_per_block,
     const int block_size = rows_per_block * cols_per_block;
 
     void *context = cuda_context_get_current();
+    ContextGuard guard(context);
 
     // Per-context cached temporary buffers
-    TemporaryBuffer &cub_temp = g_temp_buffer_map[context];
-    PinnedTemporaryBuffer &pinned_temp = g_pinned_temp_buffer_map[context];
     BsrFromTripletsTemp &bsr_temp = g_bsr_from_triplets_temp_map[context];
 
-    ContextGuard guard(context);
-
     cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
-    bsr_temp.ensure_fits(nnz);
 
-
-
-
-
+    ScopedTemporary<int> block_indices(context, 2*nnz);
+    ScopedTemporary<BsrRowCol> combined_row_col(context, 2*nnz);
+
+    cub::DoubleBuffer<int> d_keys(block_indices.buffer(),
+                                  block_indices.buffer() + nnz);
+    cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(),
+                                          combined_row_col.buffer() + nnz);
 
-    int *
-    pinned_temp.ensure_fits(sizeof(int));
-    int *pinned_count = static_cast<int *>(pinned_temp.buffer);
+    int *p_nz_triplet_count = bsr_temp.count_buffer;
 
     wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_block_indices, nnz,
                      (nnz, d_keys.Current()));
@@ -173,32 +157,29 @@ int bsr_matrix_from_triplets_device(const int rows_per_block,
     if (tpl_values) {
 
         // Remove zero blocks
-
-
-
-
-
-
-
-
-
+        {
+            size_t buff_size = 0;
+            BsrBlockIsNotZero<T> isNotZero{block_size, tpl_values};
+            check_cuda(cub::DeviceSelect::If(nullptr, buff_size, d_keys.Current(),
+                                             d_keys.Alternate(), p_nz_triplet_count,
+                                             nnz, isNotZero, stream));
+            ScopedTemporary<> temp(context, buff_size);
+            check_cuda(cub::DeviceSelect::If(
+                temp.buffer(), buff_size, d_keys.Current(), d_keys.Alternate(),
+                p_nz_triplet_count, nnz, isNotZero, stream));
+        }
+        cudaEventRecord(bsr_temp.host_sync_event, stream);
 
         // switch current/alternate in double buffer
         d_keys.selector ^= 1;
 
-        // Copy number of remaining items to host, needed for further launches
-        memcpy_d2h(WP_CURRENT_CONTEXT, pinned_count, d_nz_triplet_count,
-                   sizeof(int));
-        cudaEventRecord(bsr_temp.host_sync_event, stream);
     } else {
-        *
-        memcpy_h2d(WP_CURRENT_CONTEXT, d_nz_triplet_count, pinned_count,
-                   sizeof(int));
+        *p_nz_triplet_count = nnz;
    }
 
     // Combine rows and columns so we can sort on them both
     wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_row_col, nnz,
-                     (
+                     (p_nz_triplet_count, d_keys.Current(), tpl_rows, tpl_columns,
                       d_values.Current()));
 
     if (tpl_values) {
@@ -206,27 +187,31 @@ int bsr_matrix_from_triplets_device(const int rows_per_block,
         cudaEventSynchronize(bsr_temp.host_sync_event);
     }
 
-    const int nz_triplet_count = *
+    const int nz_triplet_count = *p_nz_triplet_count;
 
     // Sort
-
-
-
-
-
-
-
+    {
+        size_t buff_size = 0;
+        check_cuda(cub::DeviceRadixSort::SortPairs(
+            nullptr, buff_size, d_values, d_keys, nz_triplet_count, 0, 64, stream));
+        ScopedTemporary<> temp(context, buff_size);
+        check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size,
+                                                   d_values, d_keys, nz_triplet_count,
+                                                   0, 64, stream));
+    }
 
     // Runlength encode row-col sequences
-
-
-
-
-
-
-
-
-
+    {
+        size_t buff_size = 0;
+        check_cuda(cub::DeviceRunLengthEncode::Encode(
+            nullptr, buff_size, d_values.Current(), d_values.Alternate(),
+            d_keys.Alternate(), p_nz_triplet_count, nz_triplet_count, stream));
+        ScopedTemporary<> temp(context, buff_size);
+        check_cuda(cub::DeviceRunLengthEncode::Encode(
+            temp.buffer(), buff_size, d_values.Current(), d_values.Alternate(),
+            d_keys.Alternate(), p_nz_triplet_count, nz_triplet_count, stream));
+    }
+
     cudaEventRecord(bsr_temp.host_sync_event, stream);
 
     // Now we have the following:
@@ -236,13 +221,16 @@ int bsr_matrix_from_triplets_device(const int rows_per_block,
     // d_keys.Alternate(): repeated block-row count
 
     // Scan repeated block counts
-
-
-
-
-
-
-
+    {
+        size_t buff_size = 0;
+        check_cuda(cub::DeviceScan::InclusiveSum(
+            nullptr, buff_size, d_keys.Alternate(), d_keys.Alternate(),
+            nz_triplet_count, stream));
+        ScopedTemporary<> temp(context, buff_size);
+        check_cuda(cub::DeviceScan::InclusiveSum(
+            temp.buffer(), buff_size, d_keys.Alternate(), d_keys.Alternate(),
+            nz_triplet_count, stream));
+    }
 
     // While we're at it, zero the bsr offsets buffer
     memset_device(WP_CURRENT_CONTEXT, bsr_offsets, 0,
@@ -250,7 +238,7 @@ int bsr_matrix_from_triplets_device(const int rows_per_block,
 
     // Wait for number of compressed blocks
     cudaEventSynchronize(bsr_temp.host_sync_event);
-    const int compressed_nnz = *
+    const int compressed_nnz = *p_nz_triplet_count;
 
     // We have all we need to accumulate our repeated blocks
     wp_launch_device(WP_CURRENT_CONTEXT, bsr_merge_blocks, compressed_nnz,
@@ -259,12 +247,15 @@ int bsr_matrix_from_triplets_device(const int rows_per_block,
                      bsr_offsets, bsr_columns, bsr_values));
 
     // Last, prefix sum the row block counts
-
-
-
-
-
-
+    {
+        size_t buff_size = 0;
+        check_cuda(cub::DeviceScan::InclusiveSum(nullptr, buff_size, bsr_offsets,
+                                                 bsr_offsets, row_count + 1, stream));
+        ScopedTemporary<> temp(context, buff_size);
+        check_cuda(cub::DeviceScan::InclusiveSum(temp.buffer(), buff_size,
+                                                 bsr_offsets, bsr_offsets,
+                                                 row_count + 1, stream));
+    }
 
     return compressed_nnz;
 }
@@ -347,118 +338,75 @@ bsr_transpose_blocks(const int nnz, const int block_size,
 }
 
 template <typename T>
-void
-
-
-
-
-
-
-    const int block_size = rows_per_block * cols_per_block;
-
-    void *context = cuda_context_get_current();
-
-    // Per-context cached temporary buffers
-    TemporaryBuffer &cub_temp = g_temp_buffer_map[context];
-    BsrFromTripletsTemp &bsr_temp = g_bsr_from_triplets_temp_map[context];
-
-    ContextGuard guard(context);
-
-    cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
-    bsr_temp.ensure_fits(nnz);
-
-    // Zero the transposed offsets
-    memset_device(WP_CURRENT_CONTEXT, transposed_bsr_offsets, 0,
-                  (col_count + 1) * sizeof(int));
-
-    cub::DoubleBuffer<int> d_keys(bsr_temp.block_indices,
-                                  bsr_temp.block_indices + nnz);
-    cub::DoubleBuffer<BsrRowCol> d_values(bsr_temp.combined_row_col,
-                                          bsr_temp.combined_row_col + nnz);
-
-    wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_fill_row_col, nnz,
-                     (nnz, row_count, bsr_offsets, bsr_columns, d_keys.Current(),
-                      d_values.Current(), transposed_bsr_offsets));
+void
+launch_bsr_transpose_blocks(const int nnz, const int block_size,
+                            const int rows_per_block, const int cols_per_block,
+                            const int *block_indices,
+                            const BsrRowCol *transposed_indices,
+                            const T *bsr_values,
+                            int *transposed_bsr_columns, T *transposed_bsr_values) {
 
-
-    size_t buff_size = 0;
-    check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values,
-                                               d_keys, nnz, 0, 64, stream));
-    cub_temp.ensure_fits(buff_size);
-    check_cuda(cub::DeviceRadixSort::SortPairs(
-        cub_temp.buffer, buff_size, d_values, d_keys, nnz, 0, 64, stream));
-
-    // Prefix sum the trasnposed row block counts
-    check_cuda(cub::DeviceScan::InclusiveSum(
-        nullptr, buff_size, transposed_bsr_offsets, transposed_bsr_offsets,
-        col_count + 1, stream));
-    cub_temp.ensure_fits(buff_size);
-    check_cuda(cub::DeviceScan::InclusiveSum(
-        cub_temp.buffer, buff_size, transposed_bsr_offsets,
-        transposed_bsr_offsets, col_count + 1, stream));
-
-    // Move and transpose invidual blocks
-    switch (row_count) {
+    switch (rows_per_block) {
     case 1:
-        switch (
+        switch (cols_per_block) {
         case 1:
             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                              (nnz, block_size, BsrBlockTransposer<1, 1, T>{},
-
+                              block_indices, transposed_indices, bsr_values,
                               transposed_bsr_columns, transposed_bsr_values));
             return;
         case 2:
             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                              (nnz, block_size, BsrBlockTransposer<1, 2, T>{},
-
+                              block_indices, transposed_indices, bsr_values,
                               transposed_bsr_columns, transposed_bsr_values));
             return;
        case 3:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                              (nnz, block_size, BsrBlockTransposer<1, 3, T>{},
-
+                              block_indices, transposed_indices, bsr_values,
                               transposed_bsr_columns, transposed_bsr_values));
             return;
         }
     case 2:
-        switch (
+        switch (cols_per_block) {
        case 1:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                              (nnz, block_size, BsrBlockTransposer<2, 1, T>{},
-
+                              block_indices, transposed_indices, bsr_values,
                               transposed_bsr_columns, transposed_bsr_values));
             return;
        case 2:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                              (nnz, block_size, BsrBlockTransposer<2, 2, T>{},
-
+                              block_indices, transposed_indices, bsr_values,
                               transposed_bsr_columns, transposed_bsr_values));
             return;
        case 3:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                              (nnz, block_size, BsrBlockTransposer<2, 3, T>{},
-
+                              block_indices, transposed_indices, bsr_values,
                               transposed_bsr_columns, transposed_bsr_values));
             return;
        }
    case 3:
-        switch (
+        switch (cols_per_block) {
        case 1:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                              (nnz, block_size, BsrBlockTransposer<3, 1, T>{},
-
+                              block_indices, transposed_indices, bsr_values,
                               transposed_bsr_columns, transposed_bsr_values));
             return;
        case 2:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                              (nnz, block_size, BsrBlockTransposer<3, 2, T>{},
-
+                              block_indices, transposed_indices, bsr_values,
                               transposed_bsr_columns, transposed_bsr_values));
             return;
        case 3:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                              (nnz, block_size, BsrBlockTransposer<3, 3, T>{},
-
+                              block_indices, transposed_indices, bsr_values,
                               transposed_bsr_columns, transposed_bsr_values));
             return;
        }
@@ -468,10 +416,72 @@ void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
         WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
         (nnz, block_size,
          BsrBlockTransposer<-1, -1, T>{rows_per_block, cols_per_block},
-
+         block_indices, transposed_indices, bsr_values, transposed_bsr_columns,
          transposed_bsr_values));
 }
 
+template <typename T>
+void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
+                          int col_count, int nnz, const int *bsr_offsets,
+                          const int *bsr_columns, const T *bsr_values,
+                          int *transposed_bsr_offsets,
+                          int *transposed_bsr_columns,
+                          T *transposed_bsr_values) {
+
+    const int block_size = rows_per_block * cols_per_block;
+
+    void *context = cuda_context_get_current();
+    ContextGuard guard(context);
+
+    cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
+
+    // Zero the transposed offsets
+    memset_device(WP_CURRENT_CONTEXT, transposed_bsr_offsets, 0,
+                  (col_count + 1) * sizeof(int));
+
+    ScopedTemporary<int> block_indices(context, 2*nnz);
+    ScopedTemporary<BsrRowCol> combined_row_col(context, 2*nnz);
+
+    cub::DoubleBuffer<int> d_keys(block_indices.buffer(),
+                                  block_indices.buffer() + nnz);
+    cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(),
+                                          combined_row_col.buffer() + nnz);
+
+    wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_fill_row_col, nnz,
+                     (nnz, row_count, bsr_offsets, bsr_columns, d_keys.Current(),
+                      d_values.Current(), transposed_bsr_offsets));
+
+    // Sort blocks
+    {
+        size_t buff_size = 0;
+        check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values,
+                                                   d_keys, nnz, 0, 64, stream));
+        void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
+        ScopedTemporary<> temp(context, buff_size);
+        check_cuda(cub::DeviceRadixSort::SortPairs(
+            temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
+    }
+
+    // Prefix sum the transposed row block counts
+    {
+        size_t buff_size = 0;
+        check_cuda(cub::DeviceScan::InclusiveSum(
+            nullptr, buff_size, transposed_bsr_offsets, transposed_bsr_offsets,
+            col_count + 1, stream));
+        ScopedTemporary<> temp(context, buff_size);
+        check_cuda(cub::DeviceScan::InclusiveSum(
+            temp.buffer(), buff_size, transposed_bsr_offsets,
+            transposed_bsr_offsets, col_count + 1, stream));
+    }
+
+    // Move and transpose individual blocks
+    launch_bsr_transpose_blocks(
+        nnz, block_size,
+        rows_per_block, cols_per_block,
+        d_keys.Current(), d_values.Current(), bsr_values, transposed_bsr_columns,
+        transposed_bsr_values);
+}
+
 } // namespace
 
 int bsr_matrix_from_triplets_float_device(
@@ -532,4 +542,4 @@ void bsr_transpose_double_device(int rows_per_block, int cols_per_block,
         reinterpret_cast<int *>(transposed_bsr_offsets),
         reinterpret_cast<int *>(transposed_bsr_columns),
         reinterpret_cast<double *>(transposed_bsr_values));
-}
+}
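
Note on the pattern above (reviewer's sketch, not part of the diff): the rewrite drops the old device-side counter plus memcpy_d2h/memcpy_h2d round trips in favor of a single pinned host int (BsrFromTripletsTemp::count_buffer). Pinned memory allocated through alloc_pinned is directly addressable from the device under CUDA unified virtual addressing, so CUB can write its num-selected/num-runs output straight into it, and a cudaEventRecord on the stream followed by cudaEventSynchronize on the host is enough to read the count back without an explicit copy. A minimal standalone illustration of that pattern using the plain CUDA runtime and CUB (all names below are mine, not Warp's):

    #include <cub/device/device_select.cuh>
    #include <cstdio>

    // Predicate for DeviceSelect::If
    struct IsPositive {
        __host__ __device__ bool operator()(int x) const { return x > 0; }
    };

    int main()
    {
        const int n = 8;
        int h_in[n] = {3, -1, 0, 7, -2, 5, 0, 1};
        int *d_in, *d_out;
        cudaMalloc(&d_in, n * sizeof(int));
        cudaMalloc(&d_out, n * sizeof(int));
        cudaMemcpy(d_in, h_in, n * sizeof(int), cudaMemcpyHostToDevice);

        // Pinned host int: the device writes the selected count directly into it.
        int *pinned_count;
        cudaMallocHost(&pinned_count, sizeof(int));

        cudaEvent_t done;
        cudaEventCreateWithFlags(&done, cudaEventDisableTiming);

        // Standard two-phase CUB call: query scratch size, allocate, then run.
        void *d_temp = nullptr;
        size_t temp_bytes = 0;
        cub::DeviceSelect::If(nullptr, temp_bytes, d_in, d_out, pinned_count, n, IsPositive{});
        cudaMalloc(&d_temp, temp_bytes);
        cub::DeviceSelect::If(d_temp, temp_bytes, d_in, d_out, pinned_count, n, IsPositive{});

        // Record-and-wait replaces an explicit device-to-host copy of the count.
        cudaEventRecord(done);
        cudaEventSynchronize(done);
        printf("selected %d of %d items\n", *pinned_count, n);

        cudaFree(d_temp);
        cudaFree(d_out);
        cudaFree(d_in);
        cudaFreeHost(pinned_count);
        cudaEventDestroy(done);
        return 0;
    }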
warp/native/spatial.h
CHANGED
@@ -265,13 +265,13 @@ inline CUDA_CALLABLE Type tensordot(const transform_t<Type>& a, const transform_
 }
 
 template<typename Type>
-inline CUDA_CALLABLE Type
+inline CUDA_CALLABLE Type extract(const transform_t<Type>& t, int i)
 {
     return t[i];
 }
 
 template<typename Type>
-inline void CUDA_CALLABLE
+inline void CUDA_CALLABLE adj_extract(const transform_t<Type>& t, int i, transform_t<Type>& adj_t, int& adj_i, Type adj_ret)
 {
     adj_t[i] += adj_ret;
 }
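
For context on this spatial.h change (my reading, not text from the diff): Warp's code generator pairs each differentiable builtin f with a reverse-mode counterpart adj_f whose parameters are the forward inputs, then the input adjoints passed by reference, then the adjoint of the return value. The hunk above fills in that pair for transform component access: extract(t, i) reads component i, and adj_extract accumulates adj_ret into adj_t[i], leaving adj_i untouched since an integer index carries no gradient. A hypothetical scalar analogue of the convention:

    // Illustrative scalar analogue of Warp's f / adj_f pairing (names are mine).
    float extract(const float t[], int i)
    {
        return t[i];
    }

    // Reverse mode: dL/dt[i] += dL/dret; the integer index gets no adjoint.
    void adj_extract(const float t[], int i, float adj_t[], int &adj_i, float adj_ret)
    {
        (void)t;
        (void)adj_i; // indices are not differentiable
        adj_t[i] += adj_ret;
    }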
warp/native/temp_buffer.h
CHANGED
@@ -1,46 +1,30 @@
 
 #pragma once
 
-#include "warp.h"
 #include "cuda_util.h"
+#include "warp.h"
 
 #include <unordered_map>
 
-
-struct TemporaryBuffer
+template <typename T = char> struct ScopedTemporary
 {
-    void *buffer = NULL;
-    size_t buffer_size = 0;
 
-    void
+    ScopedTemporary(void *context, size_t size)
+        : m_context(context), m_buffer(static_cast<T*>(alloc_temp_device(m_context, size * sizeof(T))))
     {
-        if (size > buffer_size)
-        {
-            size = std::max(2 * size, (buffer_size * 3) / 2);
-
-            free_device(WP_CURRENT_CONTEXT, buffer);
-            buffer = alloc_device(WP_CURRENT_CONTEXT, size);
-            buffer_size = size;
-        }
     }
-};
 
-
-{
-
-
+    ~ScopedTemporary()
+    {
+        free_temp_device(m_context, m_buffer);
+    }
 
-
+    T *buffer() const
     {
-
-        {
-            free_pinned(buffer);
-            buffer = alloc_pinned(size);
-            buffer_size = size;
-        }
+        return m_buffer;
     }
-};
 
-
-
-
+private:
+    void *m_context;
+    T *m_buffer;
+};
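
Taken together with the sparse.cu changes, the new header swaps the grow-only, context-cached TemporaryBuffer/PinnedTemporaryBuffer pair for an RAII ScopedTemporary built on Warp's alloc_temp_device/free_temp_device: scratch memory now lives exactly as long as the enclosing block instead of persisting per context. A usage sketch in the shape the sparse.cu hunks follow (check_cuda, context, and stream as in that file; d_data and n are placeholders of mine):

    // Per-call CUB scratch space, sized by the usual two-phase query and
    // released automatically when the block ends.
    {
        size_t buff_size = 0;
        check_cuda(cub::DeviceScan::InclusiveSum(nullptr, buff_size,
                                                 d_data, d_data, n, stream));
        ScopedTemporary<> temp(context, buff_size); // defaults to T = char, i.e. raw bytes
        check_cuda(cub::DeviceScan::InclusiveSum(temp.buffer(), buff_size,
                                                 d_data, d_data, n, stream));
    } // ~ScopedTemporary frees the buffer here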