warp-lang 1.0.2__py3-none-manylinux2014_x86_64.whl → 1.1.0__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +108 -97
- warp/__init__.pyi +1 -1
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +115 -113
- warp/build_dll.py +383 -375
- warp/builtins.py +3425 -3354
- warp/codegen.py +2878 -2792
- warp/config.py +40 -36
- warp/constants.py +45 -45
- warp/context.py +5194 -5102
- warp/dlpack.py +442 -442
- warp/examples/__init__.py +16 -16
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -92
- warp/examples/assets/nv_humanoid.xml +183 -183
- warp/examples/assets/quadruped.urdf +267 -267
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +383 -383
- warp/examples/benchmarks/benchmark_cloth.py +278 -277
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
- warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
- warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
- warp/examples/benchmarks/benchmark_cloth_warp.py +146 -146
- warp/examples/benchmarks/benchmark_launches.py +295 -295
- warp/examples/browse.py +29 -29
- warp/examples/core/example_dem.py +234 -219
- warp/examples/core/example_fluid.py +293 -267
- warp/examples/core/example_graph_capture.py +144 -126
- warp/examples/core/example_marching_cubes.py +188 -174
- warp/examples/core/example_mesh.py +174 -155
- warp/examples/core/example_mesh_intersect.py +205 -193
- warp/examples/core/example_nvdb.py +176 -170
- warp/examples/core/example_raycast.py +105 -90
- warp/examples/core/example_raymarch.py +199 -178
- warp/examples/core/example_render_opengl.py +185 -141
- warp/examples/core/example_sph.py +405 -387
- warp/examples/core/example_torch.py +222 -181
- warp/examples/core/example_wave.py +263 -248
- warp/examples/fem/bsr_utils.py +378 -380
- warp/examples/fem/example_apic_fluid.py +407 -389
- warp/examples/fem/example_convection_diffusion.py +182 -168
- warp/examples/fem/example_convection_diffusion_dg.py +219 -209
- warp/examples/fem/example_convection_diffusion_dg0.py +204 -194
- warp/examples/fem/example_deformed_geometry.py +177 -159
- warp/examples/fem/example_diffusion.py +201 -173
- warp/examples/fem/example_diffusion_3d.py +177 -152
- warp/examples/fem/example_diffusion_mgpu.py +221 -214
- warp/examples/fem/example_mixed_elasticity.py +244 -222
- warp/examples/fem/example_navier_stokes.py +259 -243
- warp/examples/fem/example_stokes.py +220 -192
- warp/examples/fem/example_stokes_transfer.py +265 -249
- warp/examples/fem/mesh_utils.py +133 -109
- warp/examples/fem/plot_utils.py +292 -287
- warp/examples/optim/example_bounce.py +260 -246
- warp/examples/optim/example_cloth_throw.py +222 -209
- warp/examples/optim/example_diffray.py +566 -536
- warp/examples/optim/example_drone.py +864 -835
- warp/examples/optim/example_inverse_kinematics.py +176 -168
- warp/examples/optim/example_inverse_kinematics_torch.py +185 -169
- warp/examples/optim/example_spring_cage.py +239 -231
- warp/examples/optim/example_trajectory.py +223 -199
- warp/examples/optim/example_walker.py +306 -293
- warp/examples/sim/example_cartpole.py +139 -129
- warp/examples/sim/example_cloth.py +196 -186
- warp/examples/sim/example_granular.py +124 -111
- warp/examples/sim/example_granular_collision_sdf.py +197 -186
- warp/examples/sim/example_jacobian_ik.py +236 -214
- warp/examples/sim/example_particle_chain.py +118 -105
- warp/examples/sim/example_quadruped.py +193 -180
- warp/examples/sim/example_rigid_chain.py +197 -187
- warp/examples/sim/example_rigid_contact.py +189 -177
- warp/examples/sim/example_rigid_force.py +127 -125
- warp/examples/sim/example_rigid_gyroscopic.py +109 -95
- warp/examples/sim/example_rigid_soft_contact.py +134 -122
- warp/examples/sim/example_soft_body.py +190 -177
- warp/fabric.py +337 -335
- warp/fem/__init__.py +60 -27
- warp/fem/cache.py +401 -388
- warp/fem/dirichlet.py +178 -179
- warp/fem/domain.py +262 -263
- warp/fem/field/__init__.py +100 -101
- warp/fem/field/field.py +148 -149
- warp/fem/field/nodal_field.py +298 -299
- warp/fem/field/restriction.py +22 -21
- warp/fem/field/test.py +180 -181
- warp/fem/field/trial.py +183 -183
- warp/fem/geometry/__init__.py +15 -19
- warp/fem/geometry/closest_point.py +69 -70
- warp/fem/geometry/deformed_geometry.py +270 -271
- warp/fem/geometry/element.py +744 -744
- warp/fem/geometry/geometry.py +184 -186
- warp/fem/geometry/grid_2d.py +380 -373
- warp/fem/geometry/grid_3d.py +441 -435
- warp/fem/geometry/hexmesh.py +953 -953
- warp/fem/geometry/partition.py +374 -376
- warp/fem/geometry/quadmesh_2d.py +532 -532
- warp/fem/geometry/tetmesh.py +840 -840
- warp/fem/geometry/trimesh_2d.py +577 -577
- warp/fem/integrate.py +1630 -1615
- warp/fem/operator.py +190 -191
- warp/fem/polynomial.py +214 -213
- warp/fem/quadrature/__init__.py +2 -2
- warp/fem/quadrature/pic_quadrature.py +243 -245
- warp/fem/quadrature/quadrature.py +295 -294
- warp/fem/space/__init__.py +294 -292
- warp/fem/space/basis_space.py +488 -489
- warp/fem/space/collocated_function_space.py +100 -105
- warp/fem/space/dof_mapper.py +236 -236
- warp/fem/space/function_space.py +148 -145
- warp/fem/space/grid_2d_function_space.py +267 -267
- warp/fem/space/grid_3d_function_space.py +305 -306
- warp/fem/space/hexmesh_function_space.py +350 -352
- warp/fem/space/partition.py +350 -350
- warp/fem/space/quadmesh_2d_function_space.py +368 -369
- warp/fem/space/restriction.py +158 -160
- warp/fem/space/shape/__init__.py +13 -15
- warp/fem/space/shape/cube_shape_function.py +738 -738
- warp/fem/space/shape/shape_function.py +102 -103
- warp/fem/space/shape/square_shape_function.py +611 -611
- warp/fem/space/shape/tet_shape_function.py +565 -567
- warp/fem/space/shape/triangle_shape_function.py +429 -429
- warp/fem/space/tetmesh_function_space.py +294 -292
- warp/fem/space/topology.py +297 -295
- warp/fem/space/trimesh_2d_function_space.py +223 -221
- warp/fem/types.py +77 -77
- warp/fem/utils.py +495 -495
- warp/jax.py +166 -141
- warp/jax_experimental.py +341 -339
- warp/native/array.h +1072 -1025
- warp/native/builtin.h +1560 -1560
- warp/native/bvh.cpp +398 -398
- warp/native/bvh.cu +525 -525
- warp/native/bvh.h +429 -429
- warp/native/clang/clang.cpp +495 -464
- warp/native/crt.cpp +31 -31
- warp/native/crt.h +334 -334
- warp/native/cuda_crt.h +1049 -1049
- warp/native/cuda_util.cpp +549 -540
- warp/native/cuda_util.h +288 -203
- warp/native/cutlass_gemm.cpp +34 -34
- warp/native/cutlass_gemm.cu +372 -372
- warp/native/error.cpp +66 -66
- warp/native/error.h +27 -27
- warp/native/fabric.h +228 -228
- warp/native/hashgrid.cpp +301 -278
- warp/native/hashgrid.cu +78 -77
- warp/native/hashgrid.h +227 -227
- warp/native/initializer_array.h +32 -32
- warp/native/intersect.h +1204 -1204
- warp/native/intersect_adj.h +365 -365
- warp/native/intersect_tri.h +322 -322
- warp/native/marching.cpp +2 -2
- warp/native/marching.cu +497 -497
- warp/native/marching.h +2 -2
- warp/native/mat.h +1498 -1498
- warp/native/matnn.h +333 -333
- warp/native/mesh.cpp +203 -203
- warp/native/mesh.cu +293 -293
- warp/native/mesh.h +1887 -1887
- warp/native/nanovdb/NanoVDB.h +4782 -4782
- warp/native/nanovdb/PNanoVDB.h +2553 -2553
- warp/native/nanovdb/PNanoVDBWrite.h +294 -294
- warp/native/noise.h +850 -850
- warp/native/quat.h +1084 -1084
- warp/native/rand.h +299 -299
- warp/native/range.h +108 -108
- warp/native/reduce.cpp +156 -156
- warp/native/reduce.cu +348 -348
- warp/native/runlength_encode.cpp +61 -61
- warp/native/runlength_encode.cu +46 -46
- warp/native/scan.cpp +30 -30
- warp/native/scan.cu +36 -36
- warp/native/scan.h +7 -7
- warp/native/solid_angle.h +442 -442
- warp/native/sort.cpp +94 -94
- warp/native/sort.cu +97 -97
- warp/native/sort.h +14 -14
- warp/native/sparse.cpp +337 -337
- warp/native/sparse.cu +544 -544
- warp/native/spatial.h +630 -630
- warp/native/svd.h +562 -562
- warp/native/temp_buffer.h +30 -30
- warp/native/vec.h +1132 -1132
- warp/native/volume.cpp +297 -297
- warp/native/volume.cu +32 -32
- warp/native/volume.h +538 -538
- warp/native/volume_builder.cu +425 -425
- warp/native/volume_builder.h +19 -19
- warp/native/warp.cpp +1057 -1052
- warp/native/warp.cu +2943 -2828
- warp/native/warp.h +313 -305
- warp/optim/__init__.py +9 -9
- warp/optim/adam.py +120 -120
- warp/optim/linear.py +1104 -939
- warp/optim/sgd.py +104 -92
- warp/render/__init__.py +10 -10
- warp/render/render_opengl.py +3217 -3204
- warp/render/render_usd.py +768 -749
- warp/render/utils.py +152 -150
- warp/sim/__init__.py +52 -59
- warp/sim/articulation.py +685 -685
- warp/sim/collide.py +1594 -1590
- warp/sim/import_mjcf.py +489 -481
- warp/sim/import_snu.py +220 -221
- warp/sim/import_urdf.py +536 -516
- warp/sim/import_usd.py +887 -881
- warp/sim/inertia.py +316 -317
- warp/sim/integrator.py +234 -233
- warp/sim/integrator_euler.py +1956 -1956
- warp/sim/integrator_featherstone.py +1910 -1991
- warp/sim/integrator_xpbd.py +3294 -3312
- warp/sim/model.py +4473 -4314
- warp/sim/particles.py +113 -112
- warp/sim/render.py +417 -403
- warp/sim/utils.py +413 -410
- warp/sparse.py +1227 -1227
- warp/stubs.py +2109 -2469
- warp/tape.py +1162 -225
- warp/tests/__init__.py +1 -1
- warp/tests/__main__.py +4 -4
- warp/tests/assets/torus.usda +105 -105
- warp/tests/aux_test_class_kernel.py +26 -26
- warp/tests/aux_test_compile_consts_dummy.py +10 -10
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
- warp/tests/aux_test_dependent.py +22 -22
- warp/tests/aux_test_grad_customs.py +23 -23
- warp/tests/aux_test_reference.py +11 -11
- warp/tests/aux_test_reference_reference.py +10 -10
- warp/tests/aux_test_square.py +17 -17
- warp/tests/aux_test_unresolved_func.py +14 -14
- warp/tests/aux_test_unresolved_symbol.py +14 -14
- warp/tests/disabled_kinematics.py +239 -239
- warp/tests/run_coverage_serial.py +31 -31
- warp/tests/test_adam.py +157 -157
- warp/tests/test_arithmetic.py +1124 -1124
- warp/tests/test_array.py +2417 -2326
- warp/tests/test_array_reduce.py +150 -150
- warp/tests/test_async.py +668 -656
- warp/tests/test_atomic.py +141 -141
- warp/tests/test_bool.py +204 -149
- warp/tests/test_builtins_resolution.py +1292 -1292
- warp/tests/test_bvh.py +164 -171
- warp/tests/test_closest_point_edge_edge.py +228 -228
- warp/tests/test_codegen.py +566 -553
- warp/tests/test_compile_consts.py +97 -101
- warp/tests/test_conditional.py +246 -246
- warp/tests/test_copy.py +232 -215
- warp/tests/test_ctypes.py +632 -632
- warp/tests/test_dense.py +67 -67
- warp/tests/test_devices.py +91 -98
- warp/tests/test_dlpack.py +530 -529
- warp/tests/test_examples.py +400 -378
- warp/tests/test_fabricarray.py +955 -955
- warp/tests/test_fast_math.py +62 -54
- warp/tests/test_fem.py +1277 -1278
- warp/tests/test_fp16.py +130 -130
- warp/tests/test_func.py +338 -337
- warp/tests/test_generics.py +571 -571
- warp/tests/test_grad.py +746 -640
- warp/tests/test_grad_customs.py +333 -336
- warp/tests/test_hash_grid.py +210 -164
- warp/tests/test_import.py +39 -39
- warp/tests/test_indexedarray.py +1134 -1134
- warp/tests/test_intersect.py +67 -67
- warp/tests/test_jax.py +307 -307
- warp/tests/test_large.py +167 -164
- warp/tests/test_launch.py +354 -354
- warp/tests/test_lerp.py +261 -261
- warp/tests/test_linear_solvers.py +191 -171
- warp/tests/test_lvalue.py +421 -493
- warp/tests/test_marching_cubes.py +65 -65
- warp/tests/test_mat.py +1801 -1827
- warp/tests/test_mat_lite.py +115 -115
- warp/tests/test_mat_scalar_ops.py +2907 -2889
- warp/tests/test_math.py +126 -193
- warp/tests/test_matmul.py +500 -499
- warp/tests/test_matmul_lite.py +410 -410
- warp/tests/test_mempool.py +188 -190
- warp/tests/test_mesh.py +284 -324
- warp/tests/test_mesh_query_aabb.py +228 -241
- warp/tests/test_mesh_query_point.py +692 -702
- warp/tests/test_mesh_query_ray.py +292 -303
- warp/tests/test_mlp.py +276 -276
- warp/tests/test_model.py +110 -110
- warp/tests/test_modules_lite.py +39 -39
- warp/tests/test_multigpu.py +163 -163
- warp/tests/test_noise.py +248 -248
- warp/tests/test_operators.py +250 -250
- warp/tests/test_options.py +123 -125
- warp/tests/test_peer.py +133 -137
- warp/tests/test_pinned.py +78 -78
- warp/tests/test_print.py +54 -54
- warp/tests/test_quat.py +2086 -2086
- warp/tests/test_rand.py +288 -288
- warp/tests/test_reload.py +217 -217
- warp/tests/test_rounding.py +179 -179
- warp/tests/test_runlength_encode.py +190 -190
- warp/tests/test_sim_grad.py +243 -0
- warp/tests/test_sim_kinematics.py +91 -97
- warp/tests/test_smoothstep.py +168 -168
- warp/tests/test_snippet.py +305 -266
- warp/tests/test_sparse.py +468 -460
- warp/tests/test_spatial.py +2148 -2148
- warp/tests/test_streams.py +486 -473
- warp/tests/test_struct.py +710 -675
- warp/tests/test_tape.py +173 -148
- warp/tests/test_torch.py +743 -743
- warp/tests/test_transient_module.py +87 -87
- warp/tests/test_types.py +556 -659
- warp/tests/test_utils.py +490 -499
- warp/tests/test_vec.py +1264 -1268
- warp/tests/test_vec_lite.py +73 -73
- warp/tests/test_vec_scalar_ops.py +2099 -2099
- warp/tests/test_verify_fp.py +94 -94
- warp/tests/test_volume.py +737 -736
- warp/tests/test_volume_write.py +255 -265
- warp/tests/unittest_serial.py +37 -37
- warp/tests/unittest_suites.py +363 -359
- warp/tests/unittest_utils.py +603 -578
- warp/tests/unused_test_misc.py +71 -71
- warp/tests/walkthrough_debug.py +85 -85
- warp/thirdparty/appdirs.py +598 -598
- warp/thirdparty/dlpack.py +143 -143
- warp/thirdparty/unittest_parallel.py +566 -561
- warp/torch.py +321 -295
- warp/types.py +4504 -4450
- warp/utils.py +1008 -821
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/LICENSE.md +126 -126
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/METADATA +338 -400
- warp_lang-1.1.0.dist-info/RECORD +352 -0
- warp/examples/assets/cube.usda +0 -42
- warp/examples/assets/sphere.usda +0 -56
- warp/examples/assets/torus.usda +0 -105
- warp_lang-1.0.2.dist-info/RECORD +0 -352
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/top_level.txt +0 -0
warp/native/sparse.cu
CHANGED
|
@@ -1,544 +1,544 @@
|
|
|
1
|
-
#include "cuda_util.h"
|
|
2
|
-
#include "warp.h"
|
|
3
|
-
|
|
4
|
-
#define THRUST_IGNORE_CUB_VERSION_CHECK
|
|
5
|
-
|
|
6
|
-
#include <cub/device/device_radix_sort.cuh>
|
|
7
|
-
#include <cub/device/device_run_length_encode.cuh>
|
|
8
|
-
#include <cub/device/device_scan.cuh>
|
|
9
|
-
#include <cub/device/device_select.cuh>
|
|
10
|
-
|
|
11
|
-
namespace {
|
|
12
|
-
|
|
13
|
-
// 64-bit key packing a block row (high 32 bits) and a block column (low 32
// bits), so a single CUB radix sort over all 64 bits orders blocks by row,
// then by column.
typedef uint64_t BsrRowCol;
|
|
15
|
-
|
|
16
|
-
// Packs a block (row, col) pair into a single BsrRowCol key:
// row in the upper 32 bits, col in the lower 32 bits.
CUDA_CALLABLE BsrRowCol bsr_combine_row_col(uint32_t row, uint32_t col) {
    const BsrRowCol high = BsrRowCol(row) << 32;
    const BsrRowCol low = BsrRowCol(col);
    return high | low;
}
|
|
19
|
-
|
|
20
|
-
// Extracts the block row (upper 32 bits) from a packed BsrRowCol key.
CUDA_CALLABLE uint32_t bsr_get_row(const BsrRowCol &row_col) {
    const BsrRowCol shifted = row_col >> 32;
    return static_cast<uint32_t>(shifted);
}
|
|
23
|
-
|
|
24
|
-
// Extracts the block column (lower 32 bits) from a packed BsrRowCol key.
// All 32 low bits are kept (rather than masking with INT_MAX, which would drop
// bit 31) so this is the exact inverse of the column half packed by
// bsr_combine_row_col: bsr_get_col(bsr_combine_row_col(r, c)) == c for every c.
CUDA_CALLABLE uint32_t bsr_get_col(const BsrRowCol &row_col) {
    return static_cast<uint32_t>(row_col & 0xFFFFFFFFull);
}
|
|
27
|
-
|
|
28
|
-
// Per-context scratch state reused across calls to
// bsr_matrix_from_triplets_device. It holds:
//  - a single int obtained from alloc_pinned, which receives element counts
//    produced by device operations and is read both by kernels and by host
//    code;
//  - an event the host waits on before reading that count.
// It owns both resources, so it is non-copyable.
struct BsrFromTripletsTemp {

    int *count_buffer = NULL;            // released with free_pinned
    cudaEvent_t host_sync_event = NULL;  // timing disabled; only used to wait on the host

    BsrFromTripletsTemp()
    {
        void *pinned = alloc_pinned(sizeof(int));
        count_buffer = static_cast<int *>(pinned);
        cudaEventCreateWithFlags(&host_sync_event, cudaEventDisableTiming);
    }

    ~BsrFromTripletsTemp()
    {
        // Release in reverse order of acquisition
        cudaEventDestroy(host_sync_event);
        free_pinned(count_buffer);
    }

    BsrFromTripletsTemp(const BsrFromTripletsTemp &) = delete;
    BsrFromTripletsTemp &operator=(const BsrFromTripletsTemp &) = delete;
};
|
|
50
|
-
|
|
51
|
-
// One BsrFromTripletsTemp per CUDA context, keyed by the context pointer and
// default-constructed in place on first lookup via operator[].
// NOTE(review): accessed without any locking; concurrent host threads would race.
using BsrFromTripletsTempMap = std::unordered_map<void *, BsrFromTripletsTemp>;
static BsrFromTripletsTempMap g_bsr_from_triplets_temp_map;
|
|
53
|
-
|
|
54
|
-
// Selection predicate for cub::DeviceSelect::If. It returns true when the
// block at index `i` has at least one entry different from T(0).
template <typename T> struct BsrBlockIsNotZero {
    int block_size;   // number of scalar entries per block
    const T *values;  // blocks stored contiguously, block_size entries each

    CUDA_CALLABLE_DEVICE bool operator()(int i) const {
        // Compute the offset in 64 bits: i * block_size may exceed INT_MAX
        // for large matrices.
        const T *val = values + static_cast<size_t>(i) * block_size;
        for (int k = 0; k < block_size; ++k) {
            if (val[k] != T(0))
                return true;
        }
        return false;
    }
};
|
|
67
|
-
|
|
68
|
-
// Writes the identity permutation block_indices[k] = k for k in [0, nnz).
__global__ void bsr_fill_block_indices(int nnz, int *block_indices) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < nnz)
        block_indices[tid] = tid;
}
|
|
75
|
-
|
|
76
|
-
// For each of the first *nnz entries of block_indices, packs the referenced
// triplet's (row, column) into tpl_row_col. The count is passed by pointer so
// that a preceding operation on the stream can produce it on the device.
__global__ void bsr_fill_row_col(const int *nnz, const int *block_indices,
                                 const int *tpl_rows, const int *tpl_columns,
                                 BsrRowCol *tpl_row_col) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < *nnz) {
        const int src = block_indices[tid];
        tpl_row_col[tid] = bsr_combine_row_col(tpl_rows[src], tpl_columns[src]);
    }
}
|
|
88
|
-
|
|
89
|
-
// For each unique (row, col) pair i (0 <= i < nnz, the compressed block count):
//  - writes the block column to bsr_cols[i];
//  - atomically increments bsr_row_counts[row + 1], which is turned into row
//    offsets later by a prefix sum;
//  - if bsr_values is non-null, stores in bsr_values[i] the sum of all triplet
//    blocks sharing that (row, col).
// block_offsets holds the inclusive prefix sum of run lengths, so run i covers
// sorted positions [block_offsets[i-1], block_offsets[i]) of
// sorted_block_indices.
template <typename T>
__global__ void
bsr_merge_blocks(int nnz, int block_size, const int *block_offsets,
                 const int *sorted_block_indices,
                 const BsrRowCol *unique_row_cols, const T *tpl_values,
                 int *bsr_row_counts, int *bsr_cols, T *bsr_values)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= nnz)
        return;

    const int beg = i ? block_offsets[i - 1] : 0;
    const int end = block_offsets[i];

    BsrRowCol row_col = unique_row_cols[i];

    bsr_cols[i] = bsr_get_col(row_col);
    atomicAdd(bsr_row_counts + bsr_get_row(row_col) + 1, 1);

    if (bsr_values == nullptr)
        return;

    // NOTE(review): tpl_values is assumed non-null whenever bsr_values is non-null.
    // Compute offsets in 64 bits: block index * block_size may exceed INT_MAX.
    const size_t bs = static_cast<size_t>(block_size);

    T *bsr_val = bsr_values + static_cast<size_t>(i) * bs;
    const T *tpl_val = tpl_values + static_cast<size_t>(sorted_block_indices[beg]) * bs;

    for (int k = 0; k < block_size; ++k) {
        bsr_val[k] = tpl_val[k];
    }

    for (int cur = beg + 1; cur != end; ++cur) {
        const T *dup_val = tpl_values + static_cast<size_t>(sorted_block_indices[cur]) * bs;
        for (int k = 0; k < block_size; ++k) {
            bsr_val[k] += dup_val[k];
        }
    }
}
|
|
126
|
-
|
|
127
|
-
// Builds a BSR matrix on the device from (row, column[, value]) block triplets.
//
// When tpl_values is non-null, triplets whose block is entirely zero are
// dropped first. Remaining triplets are sorted by (row, column), and
// duplicates are collapsed; their values are summed into bsr_values when it is
// non-null.
// Outputs: bsr_offsets[0..row_count] and bsr_columns, plus bsr_values if given.
// Returns the number of unique blocks.
// The host blocks on an event until the device-computed counts are available.
template <typename T>
int bsr_matrix_from_triplets_device(const int rows_per_block,
                                    const int cols_per_block,
                                    const int row_count, const int nnz,
                                    const int *tpl_rows, const int *tpl_columns,
                                    const T *tpl_values, int *bsr_offsets,
                                    int *bsr_columns, T *bsr_values) {
    const int block_size = rows_per_block * cols_per_block;

    void *ctx = cuda_context_get_current();
    ContextGuard guard(ctx);

    // Count buffer and sync event, cached per CUDA context
    BsrFromTripletsTemp &scratch = g_bsr_from_triplets_temp_map[ctx];
    int *count_ptr = scratch.count_buffer;

    cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());

    // Each allocation is split in two halves used as CUB ping-pong buffers
    ScopedTemporary<int> index_storage(ctx, 2*nnz);
    ScopedTemporary<BsrRowCol> rowcol_storage(ctx, 2*nnz);

    cub::DoubleBuffer<int> indices(index_storage.buffer(),
                                   index_storage.buffer() + nnz);
    cub::DoubleBuffer<BsrRowCol> row_cols(rowcol_storage.buffer(),
                                          rowcol_storage.buffer() + nnz);

    // indices.Current() <- 0, 1, ..., nnz-1
    wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_block_indices, nnz,
                     (nnz, indices.Current()));

    if (!tpl_values) {
        // No values to inspect: every triplet is kept
        *count_ptr = nnz;
    } else {
        // Compact indices of blocks with at least one non-zero entry into
        // indices.Alternate(); the number kept is written to *count_ptr
        {
            BsrBlockIsNotZero<T> keep_block{block_size, tpl_values};
            size_t scratch_bytes = 0;
            check_cuda(cub::DeviceSelect::If(nullptr, scratch_bytes,
                                             indices.Current(), indices.Alternate(),
                                             count_ptr, nnz, keep_block, stream));
            ScopedTemporary<> cub_scratch(ctx, scratch_bytes);
            check_cuda(cub::DeviceSelect::If(cub_scratch.buffer(), scratch_bytes,
                                             indices.Current(), indices.Alternate(),
                                             count_ptr, nnz, keep_block, stream));
        }
        cudaEventRecord(scratch.host_sync_event, stream);

        // The compacted indices become the current buffer
        indices.selector ^= 1;
    }

    // Pack (row, col) of each kept triplet into one sortable 64-bit key
    wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_row_col, nnz,
                     (count_ptr, indices.Current(), tpl_rows, tpl_columns,
                      row_cols.Current()));

    // The kept count was produced on device by the zero-block filter; wait for it
    if (tpl_values)
        cudaEventSynchronize(scratch.host_sync_event);

    const int kept_count = *count_ptr;

    // Sort triplet indices by their packed (row, col) key
    {
        size_t scratch_bytes = 0;
        check_cuda(cub::DeviceRadixSort::SortPairs(
            nullptr, scratch_bytes, row_cols, indices, kept_count, 0, 64, stream));
        ScopedTemporary<> cub_scratch(ctx, scratch_bytes);
        check_cuda(cub::DeviceRadixSort::SortPairs(
            cub_scratch.buffer(), scratch_bytes, row_cols, indices, kept_count,
            0, 64, stream));
    }

    // Collapse runs of equal keys:
    //  - unique keys go to row_cols.Alternate();
    //  - run lengths go to indices.Alternate();
    //  - the number of runs goes to *count_ptr.
    {
        size_t scratch_bytes = 0;
        check_cuda(cub::DeviceRunLengthEncode::Encode(
            nullptr, scratch_bytes, row_cols.Current(), row_cols.Alternate(),
            indices.Alternate(), count_ptr, kept_count, stream));
        ScopedTemporary<> cub_scratch(ctx, scratch_bytes);
        check_cuda(cub::DeviceRunLengthEncode::Encode(
            cub_scratch.buffer(), scratch_bytes, row_cols.Current(), row_cols.Alternate(),
            indices.Alternate(), count_ptr, kept_count, stream));
    }

    cudaEventRecord(scratch.host_sync_event, stream);

    // Buffer contents at this point:
    //   row_cols.Current():   sorted packed (row, col) of every kept triplet
    //   row_cols.Alternate(): unique packed (row, col), in sorted order
    //   indices.Current():    triplet indices in sorted order
    //   indices.Alternate():  length of each run of equal (row, col)

    // Run lengths -> inclusive run end positions (in place)
    {
        size_t scratch_bytes = 0;
        check_cuda(cub::DeviceScan::InclusiveSum(
            nullptr, scratch_bytes, indices.Alternate(), indices.Alternate(),
            kept_count, stream));
        ScopedTemporary<> cub_scratch(ctx, scratch_bytes);
        check_cuda(cub::DeviceScan::InclusiveSum(
            cub_scratch.buffer(), scratch_bytes, indices.Alternate(), indices.Alternate(),
            kept_count, stream));
    }

    // Per-row block counts are accumulated into bsr_offsets[1..row_count]; clear it first
    memset_device(WP_CURRENT_CONTEXT, bsr_offsets, 0,
                  (row_count + 1) * sizeof(int));

    // The host needs the number of unique blocks to size the merge launch
    cudaEventSynchronize(scratch.host_sync_event);
    const int unique_count = *count_ptr;

    // Write columns and row counts, and sum duplicate blocks
    wp_launch_device(WP_CURRENT_CONTEXT, bsr_merge_blocks, unique_count,
                     (unique_count, block_size, indices.Alternate(),
                      indices.Current(), row_cols.Alternate(), tpl_values,
                      bsr_offsets, bsr_columns, bsr_values));

    // Per-row block counts -> row offsets
    {
        size_t scratch_bytes = 0;
        check_cuda(cub::DeviceScan::InclusiveSum(nullptr, scratch_bytes,
                                                 bsr_offsets, bsr_offsets,
                                                 row_count + 1, stream));
        ScopedTemporary<> cub_scratch(ctx, scratch_bytes);
        check_cuda(cub::DeviceScan::InclusiveSum(cub_scratch.buffer(), scratch_bytes,
                                                 bsr_offsets, bsr_offsets,
                                                 row_count + 1, stream));
    }

    return unique_count;
}
|
|
262
|
-
|
|
263
|
-
// For each block index blk of a BSR matrix (0 <= blk < nnz):
//  - block_indices[blk] = blk;
//  - transposed_row_col[blk] = packed (col, row), i.e. the block's position in
//    the transpose;
//  - atomically increments transposed_bsr_offsets[col + 1], giving a
//    per-column block count.
__global__ void bsr_transpose_fill_row_col(const int nnz, const int row_count,
                                           const int *bsr_offsets,
                                           const int *bsr_columns,
                                           int *block_indices,
                                           BsrRowCol *transposed_row_col,
                                           int *transposed_bsr_offsets) {
    const int blk = blockIdx.x * blockDim.x + threadIdx.x;
    if (blk >= nnz)
        return;

    block_indices[blk] = blk;

    // Binary search for the owning row: the smallest r in [0, row_count - 1]
    // with bsr_offsets[r + 1] > blk (assumes bsr_offsets is non-decreasing)
    int first = 0;
    int last = row_count - 1;
    while (first < last) {
        const int probe = first + (last - first) / 2;
        if (bsr_offsets[probe + 1] > blk)
            last = probe;
        else
            first = probe + 1;
    }

    const int col = bsr_columns[blk];
    transposed_row_col[blk] = bsr_combine_row_col(col, first);

    atomicAdd(transposed_bsr_offsets + col + 1, 1);
}
|
|
296
|
-
|
|
297
|
-
// Transposes one dense block with compile-time dimensions.
// src is Rows x Cols row-major; dest receives the Cols x Rows row-major
// transpose.
template <int Rows, int Cols, typename T> struct BsrBlockTransposer {
    void CUDA_CALLABLE_DEVICE operator()(const T *src, T *dest) const {
        for (int r = 0; r < Rows; ++r) {
            const T *src_row = src + r * Cols;
            T *dest_col = dest + r;
            for (int c = 0; c < Cols; ++c)
                dest_col[c * Rows] = src_row[c];
        }
    }
};
|
|
306
|
-
|
|
307
|
-
// Runtime-sized fallback transposer, used when the block shape has no
// compile-time specialization.
// src is row_count x col_count row-major; dest receives the transpose,
// col_count x row_count row-major.
template <typename T> struct BsrBlockTransposer<-1, -1, T> {

    int row_count;  // rows per block
    int col_count;  // columns per block

    void CUDA_CALLABLE_DEVICE operator()(const T *src, T *dest) const {
        for (int r = 0; r < row_count; ++r) {
            const T *src_row = src + r * col_count;
            T *dest_col = dest + r;
            for (int c = 0; c < col_count; ++c)
                dest_col[c * row_count] = src_row[c];
        }
    }
};
|
|
320
|
-
|
|
321
|
-
// Gathers block values into transposed order.
// Output block i is the transpose (via `transposer`) of input block
// block_indices[i]. Its column is the low half of the packed key
// transposed_indices[i].
template <int Rows, int Cols, typename T>
__global__ void
bsr_transpose_blocks(const int nnz, const int block_size,
                     BsrBlockTransposer<Rows, Cols, T> transposer,
                     const int *block_indices,
                     const BsrRowCol *transposed_indices, const T *bsr_values,
                     int *transposed_bsr_columns, T *transposed_bsr_values) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= nnz)
        return;

    const int src_idx = block_indices[i];

    // Compute offsets in 64 bits: block index * block_size may exceed INT_MAX
    const size_t bs = static_cast<size_t>(block_size);
    transposer(bsr_values + static_cast<size_t>(src_idx) * bs,
               transposed_bsr_values + static_cast<size_t>(i) * bs);

    transposed_bsr_columns[i] = bsr_get_col(transposed_indices[i]);
}
|
|
339
|
-
|
|
340
|
-
// Launches bsr_transpose_blocks over nnz blocks. It uses a transposer with
// compile-time dimensions for 1..3 x 1..3 blocks and the runtime-sized
// transposer for any other shape.
template <typename T>
void
launch_bsr_transpose_blocks(const int nnz, const int block_size,
                            const int rows_per_block, const int cols_per_block,
                            const int *block_indices,
                            const BsrRowCol *transposed_indices,
                            const T *bsr_values,
                            int *transposed_bsr_columns, T *transposed_bsr_values) {

    // Each specialized case returns. An unmatched cols_per_block breaks out of
    // the outer switch explicitly (instead of falling through the sibling
    // cases) and reaches the generic launch below.
    switch (rows_per_block) {
    case 1:
        switch (cols_per_block) {
        case 1:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                             (nnz, block_size, BsrBlockTransposer<1, 1, T>{},
                              block_indices, transposed_indices, bsr_values,
                              transposed_bsr_columns, transposed_bsr_values));
            return;
        case 2:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                             (nnz, block_size, BsrBlockTransposer<1, 2, T>{},
                              block_indices, transposed_indices, bsr_values,
                              transposed_bsr_columns, transposed_bsr_values));
            return;
        case 3:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                             (nnz, block_size, BsrBlockTransposer<1, 3, T>{},
                              block_indices, transposed_indices, bsr_values,
                              transposed_bsr_columns, transposed_bsr_values));
            return;
        }
        break;
    case 2:
        switch (cols_per_block) {
        case 1:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                             (nnz, block_size, BsrBlockTransposer<2, 1, T>{},
                              block_indices, transposed_indices, bsr_values,
                              transposed_bsr_columns, transposed_bsr_values));
            return;
        case 2:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                             (nnz, block_size, BsrBlockTransposer<2, 2, T>{},
                              block_indices, transposed_indices, bsr_values,
                              transposed_bsr_columns, transposed_bsr_values));
            return;
        case 3:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                             (nnz, block_size, BsrBlockTransposer<2, 3, T>{},
                              block_indices, transposed_indices, bsr_values,
                              transposed_bsr_columns, transposed_bsr_values));
            return;
        }
        break;
    case 3:
        switch (cols_per_block) {
        case 1:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                             (nnz, block_size, BsrBlockTransposer<3, 1, T>{},
                              block_indices, transposed_indices, bsr_values,
                              transposed_bsr_columns, transposed_bsr_values));
            return;
        case 2:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                             (nnz, block_size, BsrBlockTransposer<3, 2, T>{},
                              block_indices, transposed_indices, bsr_values,
                              transposed_bsr_columns, transposed_bsr_values));
            return;
        case 3:
            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                             (nnz, block_size, BsrBlockTransposer<3, 3, T>{},
                              block_indices, transposed_indices, bsr_values,
                              transposed_bsr_columns, transposed_bsr_values));
            return;
        }
        break;
    }

    // Generic path: block dimensions only known at runtime
    wp_launch_device(
        WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
        (nnz, block_size,
         BsrBlockTransposer<-1, -1, T>{rows_per_block, cols_per_block},
         block_indices, transposed_indices, bsr_values, transposed_bsr_columns,
         transposed_bsr_values));
}
|
|
422
|
-
|
|
423
|
-
// Transposes a block-sparse-row (BSR) matrix on the current CUDA context and
// stream.
//
// Inputs: row_count block rows holding nnz blocks, described by bsr_offsets,
// bsr_columns and bsr_values (blocks of rows_per_block x cols_per_block values).
// Outputs: transposed_bsr_offsets (col_count + 1 entries, zeroed and then
// prefix-summed here), transposed_bsr_columns and transposed_bsr_values.
// NOTE(review): presumably bsr_offsets has row_count + 1 entries, as in the
// transposed output — confirm against callers.
template <typename T>
void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
                          int col_count, int nnz, const int *bsr_offsets,
                          const int *bsr_columns, const T *bsr_values,
                          int *transposed_bsr_offsets,
                          int *transposed_bsr_columns,
                          T *transposed_bsr_values) {

  // Number of scalar values in one block.
  const int block_size = rows_per_block * cols_per_block;

  void *context = cuda_context_get_current();
  ContextGuard guard(context);

  // All device work below is enqueued on the current stream, in order.
  cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());

  // Zero the transposed offsets
  // (they receive per-column block counts before being prefix-summed below)
  memset_device(WP_CURRENT_CONTEXT, transposed_bsr_offsets, 0,
                (col_count + 1) * sizeof(int));

  // Each temporary holds two nnz-sized halves used as CUB double buffers.
  ScopedTemporary<int> block_indices(context, 2*nnz);
  ScopedTemporary<BsrRowCol> combined_row_col(context, 2*nnz);

  cub::DoubleBuffer<int> d_keys(block_indices.buffer(),
                                block_indices.buffer() + nnz);
  cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(),
                                        combined_row_col.buffer() + nnz);

  // Fill original block indices and packed row/col sort keys; the kernel
  // also receives transposed_bsr_offsets to accumulate per-column counts.
  wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_fill_row_col, nnz,
                   (nnz, row_count, bsr_offsets, bsr_columns, d_keys.Current(),
                    d_values.Current(), transposed_bsr_offsets));

  // Sort blocks
  // (the packed 64-bit row/col keys, stored in d_values, are the radix-sort
  // keys; block indices in d_keys are the payload carried along with them)
  {
    size_t buff_size = 0;
    check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values,
                                               d_keys, nnz, 0, 64, stream));
    ScopedTemporary<> temp(context, buff_size);
    check_cuda(cub::DeviceRadixSort::SortPairs(
        temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
  }

  // Prefix sum the transposed row block counts
  {
    size_t buff_size = 0;
    check_cuda(cub::DeviceScan::InclusiveSum(
        nullptr, buff_size, transposed_bsr_offsets, transposed_bsr_offsets,
        col_count + 1, stream));
    ScopedTemporary<> temp(context, buff_size);
    check_cuda(cub::DeviceScan::InclusiveSum(
        temp.buffer(), buff_size, transposed_bsr_offsets,
        transposed_bsr_offsets, col_count + 1, stream));
  }

  // Move and transpose individual blocks
  // (d_keys.Current() / d_values.Current() now refer to the sorted buffers)
  launch_bsr_transpose_blocks(
      nnz, block_size,
      rows_per_block, cols_per_block,
      d_keys.Current(), d_values.Current(), bsr_values, transposed_bsr_columns,
      transposed_bsr_values);
}
|
|
483
|
-
|
|
484
|
-
} // namespace
|
|
485
|
-
|
|
486
|
-
// Single-precision entry point: converts the integer addresses into typed
// pointers and forwards to bsr_matrix_from_triplets_device<float>.
// Returns whatever the templated implementation returns.
int bsr_matrix_from_triplets_float_device(
    int rows_per_block, int cols_per_block, int row_count, int nnz,
    uint64_t tpl_rows, uint64_t tpl_columns, uint64_t tpl_values,
    uint64_t bsr_offsets, uint64_t bsr_columns, uint64_t bsr_values) {
  // Input triplets.
  const int *in_rows = reinterpret_cast<const int *>(tpl_rows);
  const int *in_cols = reinterpret_cast<const int *>(tpl_columns);
  const float *in_vals = reinterpret_cast<const float *>(tpl_values);

  // Output BSR arrays.
  int *out_offsets = reinterpret_cast<int *>(bsr_offsets);
  int *out_cols = reinterpret_cast<int *>(bsr_columns);
  float *out_vals = reinterpret_cast<float *>(bsr_values);

  return bsr_matrix_from_triplets_device<float>(
      rows_per_block, cols_per_block, row_count, nnz, in_rows, in_cols,
      in_vals, out_offsets, out_cols, out_vals);
}
|
|
499
|
-
|
|
500
|
-
// Double-precision entry point: converts the integer addresses into typed
// pointers and forwards to bsr_matrix_from_triplets_device<double>.
// Returns whatever the templated implementation returns.
int bsr_matrix_from_triplets_double_device(
    int rows_per_block, int cols_per_block, int row_count, int nnz,
    uint64_t tpl_rows, uint64_t tpl_columns, uint64_t tpl_values,
    uint64_t bsr_offsets, uint64_t bsr_columns, uint64_t bsr_values) {
  // Input triplets.
  const int *in_rows = reinterpret_cast<const int *>(tpl_rows);
  const int *in_cols = reinterpret_cast<const int *>(tpl_columns);
  const double *in_vals = reinterpret_cast<const double *>(tpl_values);

  // Output BSR arrays.
  int *out_offsets = reinterpret_cast<int *>(bsr_offsets);
  int *out_cols = reinterpret_cast<int *>(bsr_columns);
  double *out_vals = reinterpret_cast<double *>(bsr_values);

  return bsr_matrix_from_triplets_device<double>(
      rows_per_block, cols_per_block, row_count, nnz, in_rows, in_cols,
      in_vals, out_offsets, out_cols, out_vals);
}
|
|
513
|
-
|
|
514
|
-
// Single-precision BSR transpose entry point: converts the integer addresses
// into typed pointers and forwards to bsr_transpose_device<float>.
void bsr_transpose_float_device(int rows_per_block, int cols_per_block,
                                int row_count, int col_count, int nnz,
                                uint64_t bsr_offsets, uint64_t bsr_columns,
                                uint64_t bsr_values,
                                uint64_t transposed_bsr_offsets,
                                uint64_t transposed_bsr_columns,
                                uint64_t transposed_bsr_values) {
  // Source matrix.
  const int *src_offsets = reinterpret_cast<const int *>(bsr_offsets);
  const int *src_columns = reinterpret_cast<const int *>(bsr_columns);
  const float *src_values = reinterpret_cast<const float *>(bsr_values);

  // Destination (transposed) matrix.
  int *dst_offsets = reinterpret_cast<int *>(transposed_bsr_offsets);
  int *dst_columns = reinterpret_cast<int *>(transposed_bsr_columns);
  float *dst_values = reinterpret_cast<float *>(transposed_bsr_values);

  bsr_transpose_device<float>(rows_per_block, cols_per_block, row_count,
                              col_count, nnz, src_offsets, src_columns,
                              src_values, dst_offsets, dst_columns,
                              dst_values);
}
|
|
529
|
-
|
|
530
|
-
// Double-precision BSR transpose entry point: converts the integer addresses
// into typed pointers and forwards to bsr_transpose_device<double>.
void bsr_transpose_double_device(int rows_per_block, int cols_per_block,
                                 int row_count, int col_count, int nnz,
                                 uint64_t bsr_offsets, uint64_t bsr_columns,
                                 uint64_t bsr_values,
                                 uint64_t transposed_bsr_offsets,
                                 uint64_t transposed_bsr_columns,
                                 uint64_t transposed_bsr_values) {
  // Source matrix.
  const int *src_offsets = reinterpret_cast<const int *>(bsr_offsets);
  const int *src_columns = reinterpret_cast<const int *>(bsr_columns);
  const double *src_values = reinterpret_cast<const double *>(bsr_values);

  // Destination (transposed) matrix.
  int *dst_offsets = reinterpret_cast<int *>(transposed_bsr_offsets);
  int *dst_columns = reinterpret_cast<int *>(transposed_bsr_columns);
  double *dst_values = reinterpret_cast<double *>(transposed_bsr_values);

  bsr_transpose_device<double>(rows_per_block, cols_per_block, row_count,
                               col_count, nnz, src_offsets, src_columns,
                               src_values, dst_offsets, dst_columns,
                               dst_values);
}
|
|
1
|
+
#include "cuda_util.h"
|
|
2
|
+
#include "warp.h"
|
|
3
|
+
|
|
4
|
+
#define THRUST_IGNORE_CUB_VERSION_CHECK
|
|
5
|
+
|
|
6
|
+
#include <cub/device/device_radix_sort.cuh>
|
|
7
|
+
#include <cub/device/device_run_length_encode.cuh>
|
|
8
|
+
#include <cub/device/device_scan.cuh>
|
|
9
|
+
#include <cub/device/device_select.cuh>
|
|
10
|
+
|
|
11
|
+
namespace {

// Combined row+column value that can be radix-sorted with CUB.
// The row occupies the high 32 bits and the column the low 32 bits, so
// ascending order on the 64-bit key is (row, then column) order.
using BsrRowCol = uint64_t;

// Packs a (row, col) pair into a single sortable key.
CUDA_CALLABLE BsrRowCol bsr_combine_row_col(uint32_t row, uint32_t col) {
  return (static_cast<uint64_t>(row) << 32) | col;
}

// Returns the row stored in the high 32 bits of a packed key.
CUDA_CALLABLE uint32_t bsr_get_row(const BsrRowCol &row_col) {
  return row_col >> 32;
}

// Returns the column stored in the low 32 bits of a packed key.
// Truncating to 32 bits is the exact inverse of bsr_combine_row_col
// (masking with INT_MAX would drop bit 31 of the stored column).
CUDA_CALLABLE uint32_t bsr_get_col(const BsrRowCol &row_col) {
  return static_cast<uint32_t>(row_col);
}

// Cached temporary storage, one instance per CUDA context (see the map below).
struct BsrFromTripletsTemp {

  // Single int allocated with alloc_pinned; device-side CUB algorithms write
  // element counts here, and the host reads them after host_sync_event.
  int *count_buffer = NULL;

  // Recorded after count-producing work so the host can wait for the count.
  cudaEvent_t host_sync_event = NULL;

  BsrFromTripletsTemp()
      : count_buffer(static_cast<int*>(alloc_pinned(sizeof(int))))
  {
    check_cuda(cudaEventCreateWithFlags(&host_sync_event, cudaEventDisableTiming));
  }

  ~BsrFromTripletsTemp()
  {
    // Return codes intentionally not checked: instances live in a static map,
    // so this may run during process teardown.
    cudaEventDestroy(host_sync_event);
    free_pinned(count_buffer);
  }

  // Owns raw resources, so copying is disabled.
  BsrFromTripletsTemp(const BsrFromTripletsTemp&) = delete;
  BsrFromTripletsTemp& operator=(const BsrFromTripletsTemp&) = delete;

};

// map temp buffers to CUDA contexts
// NOTE(review): accessed without a lock; concurrent calls from several host
// threads would race on insertion.
static std::unordered_map<void *, BsrFromTripletsTemp> g_bsr_from_triplets_temp_map;

// Selection predicate: true if block `block` has at least one non-zero value.
// Blocks are stored contiguously, block_size values each.
template <typename T> struct BsrBlockIsNotZero {
  int block_size;
  const T *values;

  CUDA_CALLABLE_DEVICE bool operator()(int block) const {
    const T *val = values + block * block_size;
    for (int k = 0; k < block_size; ++k) {
      if (val[k] != T(0))
        return true;
    }
    return false;
  }
};

// Writes the identity permutation: block_indices[i] = i for i in [0, nnz).
__global__ void bsr_fill_block_indices(int nnz, int *block_indices) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= nnz)
    return;

  block_indices[i] = i;
}

// For each of the first *nnz entries of block_indices, packs the row and
// column of the referenced triplet into tpl_row_col[i]. The count is read
// through a pointer because a preceding stream operation (zero-block removal)
// may have produced it on the device.
__global__ void bsr_fill_row_col(const int *nnz, const int *block_indices,
                                 const int *tpl_rows, const int *tpl_columns,
                                 BsrRowCol *tpl_row_col) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= *nnz)
    return;

  const int block = block_indices[i];

  BsrRowCol row_col = bsr_combine_row_col(tpl_rows[block], tpl_columns[block]);
  tpl_row_col[i] = row_col;
}

// One thread per unique (row, col) pair i. Each thread:
//  - writes the column of pair i to bsr_cols[i];
//  - atomically increments bsr_row_counts[row + 1] (prefix-summed afterwards);
//  - if bsr_values is non-null, sums the values of every triplet sharing this
//    (row, col) into block i of bsr_values.
// block_offsets are inclusive-scanned run lengths, so pair i covers
// sorted_block_indices[beg, end) with beg = block_offsets[i-1] (0 for i == 0).
// NOTE(review): tpl_values must be non-null whenever bsr_values is non-null.
template <typename T>
__global__ void
bsr_merge_blocks(int nnz, int block_size, const int *block_offsets,
                 const int *sorted_block_indices,
                 const BsrRowCol *unique_row_cols, const T *tpl_values,
                 int *bsr_row_counts, int *bsr_cols, T *bsr_values)

{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= nnz)
    return;

  const int beg = i ? block_offsets[i - 1] : 0;
  const int end = block_offsets[i];

  BsrRowCol row_col = unique_row_cols[i];

  bsr_cols[i] = bsr_get_col(row_col);
  atomicAdd(bsr_row_counts + bsr_get_row(row_col) + 1, 1);

  if (bsr_values == nullptr)
    return;

  T *bsr_val = bsr_values + i * block_size;
  const T *tpl_val = tpl_values + sorted_block_indices[beg] * block_size;

  // Initialize with the first duplicate, then accumulate the remaining ones.
  for (int k = 0; k < block_size; ++k) {
    bsr_val[k] = tpl_val[k];
  }

  for (int cur = beg + 1; cur != end; ++cur) {
    const T *src = tpl_values + sorted_block_indices[cur] * block_size;
    for (int k = 0; k < block_size; ++k) {
      bsr_val[k] += src[k];
    }
  }
}

// Builds a BSR matrix from nnz (row, column[, value]) triplets on the current
// CUDA context and stream. The steps are:
//   1. If tpl_values is given, drop triplets whose block is entirely zero.
//   2. Sort the remaining triplets by (row, column).
//   3. Run-length encode identical (row, column) pairs and sum their values.
//   4. Prefix-sum the per-row counts into bsr_offsets (row_count + 1 entries).
// Returns the number of distinct blocks written to bsr_columns / bsr_values.
// Blocks the calling host thread while waiting for device-computed counts.
template <typename T>
int bsr_matrix_from_triplets_device(const int rows_per_block,
                                    const int cols_per_block,
                                    const int row_count, const int nnz,
                                    const int *tpl_rows, const int *tpl_columns,
                                    const T *tpl_values, int *bsr_offsets,
                                    int *bsr_columns, T *bsr_values) {
  const int block_size = rows_per_block * cols_per_block;

  void *context = cuda_context_get_current();
  ContextGuard guard(context);

  // Per-context cached temporary buffers
  BsrFromTripletsTemp &bsr_temp = g_bsr_from_triplets_temp_map[context];

  cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());

  // Each temporary holds two nnz-sized halves used as CUB double buffers.
  ScopedTemporary<int> block_indices(context, 2*nnz);
  ScopedTemporary<BsrRowCol> combined_row_col(context, 2*nnz);

  cub::DoubleBuffer<int> d_keys(block_indices.buffer(),
                                block_indices.buffer() + nnz);
  cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(),
                                        combined_row_col.buffer() + nnz);

  int *p_nz_triplet_count = bsr_temp.count_buffer;

  wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_block_indices, nnz,
                   (nnz, d_keys.Current()));

  if (tpl_values) {

    // Remove zero blocks
    {
      size_t buff_size = 0;
      BsrBlockIsNotZero<T> isNotZero{block_size, tpl_values};
      check_cuda(cub::DeviceSelect::If(nullptr, buff_size, d_keys.Current(),
                                       d_keys.Alternate(), p_nz_triplet_count,
                                       nnz, isNotZero, stream));
      ScopedTemporary<> temp(context, buff_size);
      check_cuda(cub::DeviceSelect::If(
          temp.buffer(), buff_size, d_keys.Current(), d_keys.Alternate(),
          p_nz_triplet_count, nnz, isNotZero, stream));
    }
    check_cuda(cudaEventRecord(bsr_temp.host_sync_event, stream));

    // switch current/alternate in double buffer
    d_keys.selector ^= 1;

  } else {
    *p_nz_triplet_count = nnz;
  }

  // Combine rows and columns so we can sort on them both
  // (launched over nnz threads; the kernel reads the possibly smaller
  // filtered count from p_nz_triplet_count on the device)
  wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_row_col, nnz,
                   (p_nz_triplet_count, d_keys.Current(), tpl_rows, tpl_columns,
                    d_values.Current()));

  if (tpl_values) {
    // Make sure count is available on host
    check_cuda(cudaEventSynchronize(bsr_temp.host_sync_event));
  }

  const int nz_triplet_count = *p_nz_triplet_count;

  // Sort
  // (packed row/col keys in d_values are the radix-sort keys; block indices
  // in d_keys are carried along as the payload)
  {
    size_t buff_size = 0;
    check_cuda(cub::DeviceRadixSort::SortPairs(
        nullptr, buff_size, d_values, d_keys, nz_triplet_count, 0, 64, stream));
    ScopedTemporary<> temp(context, buff_size);
    check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size,
                                               d_values, d_keys, nz_triplet_count,
                                               0, 64, stream));
  }

  // Runlength encode row-col sequences
  // (the number of runs overwrites *p_nz_triplet_count)
  {
    size_t buff_size = 0;
    check_cuda(cub::DeviceRunLengthEncode::Encode(
        nullptr, buff_size, d_values.Current(), d_values.Alternate(),
        d_keys.Alternate(), p_nz_triplet_count, nz_triplet_count, stream));
    ScopedTemporary<> temp(context, buff_size);
    check_cuda(cub::DeviceRunLengthEncode::Encode(
        temp.buffer(), buff_size, d_values.Current(), d_values.Alternate(),
        d_keys.Alternate(), p_nz_triplet_count, nz_triplet_count, stream));
  }

  check_cuda(cudaEventRecord(bsr_temp.host_sync_event, stream));

  // Now we have the following:
  // d_values.Current(): sorted block row-col
  // d_values.Alternate(): sorted unique row-col
  // d_keys.Current(): sorted block indices
  // d_keys.Alternate(): repeated block-row count

  // Scan repeated block counts
  {
    size_t buff_size = 0;
    check_cuda(cub::DeviceScan::InclusiveSum(
        nullptr, buff_size, d_keys.Alternate(), d_keys.Alternate(),
        nz_triplet_count, stream));
    ScopedTemporary<> temp(context, buff_size);
    check_cuda(cub::DeviceScan::InclusiveSum(
        temp.buffer(), buff_size, d_keys.Alternate(), d_keys.Alternate(),
        nz_triplet_count, stream));
  }

  // While we're at it, zero the bsr offsets buffer
  memset_device(WP_CURRENT_CONTEXT, bsr_offsets, 0,
                (row_count + 1) * sizeof(int));

  // Wait for number of compressed blocks
  check_cuda(cudaEventSynchronize(bsr_temp.host_sync_event));
  const int compressed_nnz = *p_nz_triplet_count;

  // We have all we need to accumulate our repeated blocks
  wp_launch_device(WP_CURRENT_CONTEXT, bsr_merge_blocks, compressed_nnz,
                   (compressed_nnz, block_size, d_keys.Alternate(),
                    d_keys.Current(), d_values.Alternate(), tpl_values,
                    bsr_offsets, bsr_columns, bsr_values));

  // Last, prefix sum the row block counts
  {
    size_t buff_size = 0;
    check_cuda(cub::DeviceScan::InclusiveSum(nullptr, buff_size, bsr_offsets,
                                             bsr_offsets, row_count + 1, stream));
    ScopedTemporary<> temp(context, buff_size);
    check_cuda(cub::DeviceScan::InclusiveSum(temp.buffer(), buff_size,
                                             bsr_offsets, bsr_offsets,
                                             row_count + 1, stream));
  }

  return compressed_nnz;
}

// One thread per block i of the source matrix. Each thread:
//  - writes the identity index block_indices[i] = i;
//  - binary-searches bsr_offsets for the source row containing block i;
//  - stores the transposed key (col as row, row as column) in transposed_row_col;
//  - atomically counts the block in transposed_bsr_offsets[col + 1].
__global__ void bsr_transpose_fill_row_col(const int nnz, const int row_count,
                                           const int *bsr_offsets,
                                           const int *bsr_columns,
                                           int *block_indices,
                                           BsrRowCol *transposed_row_col,
                                           int *transposed_bsr_offsets) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= nnz)
    return;

  block_indices[i] = i;

  // Binary search for row
  int lower = 0;
  int upper = row_count - 1;

  while (lower < upper) {
    int mid = lower + (upper - lower) / 2;

    if (bsr_offsets[mid + 1] <= i) {
      lower = mid + 1;
    } else {
      upper = mid;
    }
  }

  const int row = lower;
  const int col = bsr_columns[i];
  BsrRowCol row_col = bsr_combine_row_col(col, row);
  transposed_row_col[i] = row_col;

  atomicAdd(transposed_bsr_offsets + col + 1, 1);
}

// Transposes one row-major Rows x Cols block into a row-major Cols x Rows
// block, with dimensions fixed at compile time.
template <int Rows, int Cols, typename T> struct BsrBlockTransposer {
  void CUDA_CALLABLE_DEVICE operator()(const T *src, T *dest) const {
    for (int r = 0; r < Rows; ++r) {
      for (int c = 0; c < Cols; ++c) {
        dest[c * Rows + r] = src[r * Cols + c];
      }
    }
  }
};

// Fallback for block dimensions only known at runtime.
template <typename T> struct BsrBlockTransposer<-1, -1, T> {

  int row_count;
  int col_count;

  void CUDA_CALLABLE_DEVICE operator()(const T *src, T *dest) const {
    for (int r = 0; r < row_count; ++r) {
      for (int c = 0; c < col_count; ++c) {
        dest[c * row_count + r] = src[r * col_count + c];
      }
    }
  }
};

// One thread per destination block i. Each thread copies and transposes source
// block block_indices[i] into slot i, and takes its column from the
// transposed key.
template <int Rows, int Cols, typename T>
__global__ void
bsr_transpose_blocks(const int nnz, const int block_size,
                     BsrBlockTransposer<Rows, Cols, T> transposer,
                     const int *block_indices,
                     const BsrRowCol *transposed_indices, const T *bsr_values,
                     int *transposed_bsr_columns, T *transposed_bsr_values) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= nnz)
    return;

  const int src_idx = block_indices[i];

  transposer(bsr_values + src_idx * block_size,
             transposed_bsr_values + i * block_size);

  transposed_bsr_columns[i] = bsr_get_col(transposed_indices[i]);
}

// Launches bsr_transpose_blocks with a compile-time-sized transposer for block
// shapes up to 3x3, and falls back to the runtime-sized transposer otherwise.
// Each handled (rows, cols) pair returns immediately. The explicit `break`s
// route every unhandled shape to the generic launch without relying on
// implicit case fall-through.
template <typename T>
void
launch_bsr_transpose_blocks(const int nnz, const int block_size,
                            const int rows_per_block, const int cols_per_block,
                            const int *block_indices,
                            const BsrRowCol *transposed_indices,
                            const T *bsr_values,
                            int *transposed_bsr_columns, T *transposed_bsr_values) {

  switch (rows_per_block) {
  case 1:
    switch (cols_per_block) {
    case 1:
      wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                       (nnz, block_size, BsrBlockTransposer<1, 1, T>{},
                        block_indices, transposed_indices, bsr_values,
                        transposed_bsr_columns, transposed_bsr_values));
      return;
    case 2:
      wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                       (nnz, block_size, BsrBlockTransposer<1, 2, T>{},
                        block_indices, transposed_indices, bsr_values,
                        transposed_bsr_columns, transposed_bsr_values));
      return;
    case 3:
      wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                       (nnz, block_size, BsrBlockTransposer<1, 3, T>{},
                        block_indices, transposed_indices, bsr_values,
                        transposed_bsr_columns, transposed_bsr_values));
      return;
    }
    break;
  case 2:
    switch (cols_per_block) {
    case 1:
      wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                       (nnz, block_size, BsrBlockTransposer<2, 1, T>{},
                        block_indices, transposed_indices, bsr_values,
                        transposed_bsr_columns, transposed_bsr_values));
      return;
    case 2:
      wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                       (nnz, block_size, BsrBlockTransposer<2, 2, T>{},
                        block_indices, transposed_indices, bsr_values,
                        transposed_bsr_columns, transposed_bsr_values));
      return;
    case 3:
      wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                       (nnz, block_size, BsrBlockTransposer<2, 3, T>{},
                        block_indices, transposed_indices, bsr_values,
                        transposed_bsr_columns, transposed_bsr_values));
      return;
    }
    break;
  case 3:
    switch (cols_per_block) {
    case 1:
      wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                       (nnz, block_size, BsrBlockTransposer<3, 1, T>{},
                        block_indices, transposed_indices, bsr_values,
                        transposed_bsr_columns, transposed_bsr_values));
      return;
    case 2:
      wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                       (nnz, block_size, BsrBlockTransposer<3, 2, T>{},
                        block_indices, transposed_indices, bsr_values,
                        transposed_bsr_columns, transposed_bsr_values));
      return;
    case 3:
      wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                       (nnz, block_size, BsrBlockTransposer<3, 3, T>{},
                        block_indices, transposed_indices, bsr_values,
                        transposed_bsr_columns, transposed_bsr_values));
      return;
    }
    break;
  }

  wp_launch_device(
      WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
      (nnz, block_size,
       BsrBlockTransposer<-1, -1, T>{rows_per_block, cols_per_block},
       block_indices, transposed_indices, bsr_values, transposed_bsr_columns,
       transposed_bsr_values));
}

// Transposes a BSR matrix with row_count block rows and nnz blocks on the
// current CUDA context and stream. The outputs are:
//   - transposed_bsr_offsets: col_count + 1 entries;
//   - transposed_bsr_columns;
//   - transposed_bsr_values: each block stored as cols_per_block x rows_per_block.
template <typename T>
void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
                          int col_count, int nnz, const int *bsr_offsets,
                          const int *bsr_columns, const T *bsr_values,
                          int *transposed_bsr_offsets,
                          int *transposed_bsr_columns,
                          T *transposed_bsr_values) {

  const int block_size = rows_per_block * cols_per_block;

  void *context = cuda_context_get_current();
  ContextGuard guard(context);

  cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());

  // Zero the transposed offsets
  // (bsr_transpose_fill_row_col accumulates per-column counts into them)
  memset_device(WP_CURRENT_CONTEXT, transposed_bsr_offsets, 0,
                (col_count + 1) * sizeof(int));

  // Each temporary holds two nnz-sized halves used as CUB double buffers.
  ScopedTemporary<int> block_indices(context, 2*nnz);
  ScopedTemporary<BsrRowCol> combined_row_col(context, 2*nnz);

  cub::DoubleBuffer<int> d_keys(block_indices.buffer(),
                                block_indices.buffer() + nnz);
  cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(),
                                        combined_row_col.buffer() + nnz);

  wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_fill_row_col, nnz,
                   (nnz, row_count, bsr_offsets, bsr_columns, d_keys.Current(),
                    d_values.Current(), transposed_bsr_offsets));

  // Sort blocks
  // (by transposed (row, col) key; source block indices ride along)
  {
    size_t buff_size = 0;
    check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values,
                                               d_keys, nnz, 0, 64, stream));
    ScopedTemporary<> temp(context, buff_size);
    check_cuda(cub::DeviceRadixSort::SortPairs(
        temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
  }

  // Prefix sum the transposed row block counts
  {
    size_t buff_size = 0;
    check_cuda(cub::DeviceScan::InclusiveSum(
        nullptr, buff_size, transposed_bsr_offsets, transposed_bsr_offsets,
        col_count + 1, stream));
    ScopedTemporary<> temp(context, buff_size);
    check_cuda(cub::DeviceScan::InclusiveSum(
        temp.buffer(), buff_size, transposed_bsr_offsets,
        transposed_bsr_offsets, col_count + 1, stream));
  }

  // Move and transpose individual blocks
  launch_bsr_transpose_blocks(
      nnz, block_size,
      rows_per_block, cols_per_block,
      d_keys.Current(), d_values.Current(), bsr_values, transposed_bsr_columns,
      transposed_bsr_values);
}

} // namespace
|
|
485
|
+
|
|
486
|
+
// Single-precision entry point: converts the integer addresses into typed
// pointers and forwards to bsr_matrix_from_triplets_device<float>.
// Returns the number of blocks reported by the templated implementation.
int bsr_matrix_from_triplets_float_device(
    int rows_per_block, int cols_per_block, int row_count, int nnz,
    uint64_t tpl_rows, uint64_t tpl_columns, uint64_t tpl_values,
    uint64_t bsr_offsets, uint64_t bsr_columns, uint64_t bsr_values) {
  // Input triplets.
  const int *in_rows = reinterpret_cast<const int *>(tpl_rows);
  const int *in_cols = reinterpret_cast<const int *>(tpl_columns);
  const float *in_vals = reinterpret_cast<const float *>(tpl_values);

  // Output BSR arrays.
  int *out_offsets = reinterpret_cast<int *>(bsr_offsets);
  int *out_cols = reinterpret_cast<int *>(bsr_columns);
  float *out_vals = reinterpret_cast<float *>(bsr_values);

  return bsr_matrix_from_triplets_device<float>(
      rows_per_block, cols_per_block, row_count, nnz, in_rows, in_cols,
      in_vals, out_offsets, out_cols, out_vals);
}
|
|
499
|
+
|
|
500
|
+
// Double-precision entry point: converts the integer addresses into typed
// pointers and forwards to bsr_matrix_from_triplets_device<double>.
// Returns the number of blocks reported by the templated implementation.
int bsr_matrix_from_triplets_double_device(
    int rows_per_block, int cols_per_block, int row_count, int nnz,
    uint64_t tpl_rows, uint64_t tpl_columns, uint64_t tpl_values,
    uint64_t bsr_offsets, uint64_t bsr_columns, uint64_t bsr_values) {
  // Input triplets.
  const int *in_rows = reinterpret_cast<const int *>(tpl_rows);
  const int *in_cols = reinterpret_cast<const int *>(tpl_columns);
  const double *in_vals = reinterpret_cast<const double *>(tpl_values);

  // Output BSR arrays.
  int *out_offsets = reinterpret_cast<int *>(bsr_offsets);
  int *out_cols = reinterpret_cast<int *>(bsr_columns);
  double *out_vals = reinterpret_cast<double *>(bsr_values);

  return bsr_matrix_from_triplets_device<double>(
      rows_per_block, cols_per_block, row_count, nnz, in_rows, in_cols,
      in_vals, out_offsets, out_cols, out_vals);
}
|
|
513
|
+
|
|
514
|
+
// Single-precision BSR transpose entry point: converts the integer addresses
// into typed pointers and forwards to bsr_transpose_device<float>.
void bsr_transpose_float_device(int rows_per_block, int cols_per_block,
                                int row_count, int col_count, int nnz,
                                uint64_t bsr_offsets, uint64_t bsr_columns,
                                uint64_t bsr_values,
                                uint64_t transposed_bsr_offsets,
                                uint64_t transposed_bsr_columns,
                                uint64_t transposed_bsr_values) {
  // Source matrix.
  const int *src_offsets = reinterpret_cast<const int *>(bsr_offsets);
  const int *src_columns = reinterpret_cast<const int *>(bsr_columns);
  const float *src_values = reinterpret_cast<const float *>(bsr_values);

  // Destination (transposed) matrix.
  int *dst_offsets = reinterpret_cast<int *>(transposed_bsr_offsets);
  int *dst_columns = reinterpret_cast<int *>(transposed_bsr_columns);
  float *dst_values = reinterpret_cast<float *>(transposed_bsr_values);

  bsr_transpose_device<float>(rows_per_block, cols_per_block, row_count,
                              col_count, nnz, src_offsets, src_columns,
                              src_values, dst_offsets, dst_columns,
                              dst_values);
}
|
|
529
|
+
|
|
530
|
+
// Double-precision BSR transpose entry point: converts the integer addresses
// into typed pointers and forwards to bsr_transpose_device<double>.
void bsr_transpose_double_device(int rows_per_block, int cols_per_block,
                                 int row_count, int col_count, int nnz,
                                 uint64_t bsr_offsets, uint64_t bsr_columns,
                                 uint64_t bsr_values,
                                 uint64_t transposed_bsr_offsets,
                                 uint64_t transposed_bsr_columns,
                                 uint64_t transposed_bsr_values) {
  // Source matrix.
  const int *src_offsets = reinterpret_cast<const int *>(bsr_offsets);
  const int *src_columns = reinterpret_cast<const int *>(bsr_columns);
  const double *src_values = reinterpret_cast<const double *>(bsr_values);

  // Destination (transposed) matrix.
  int *dst_offsets = reinterpret_cast<int *>(transposed_bsr_offsets);
  int *dst_columns = reinterpret_cast<int *>(transposed_bsr_columns);
  double *dst_values = reinterpret_cast<double *>(transposed_bsr_values);

  bsr_transpose_device<double>(rows_per_block, cols_per_block, row_count,
                               col_count, nnz, src_offsets, src_columns,
                               src_values, dst_offsets, dst_columns,
                               dst_values);
}
|