warp-lang 1.2.2__py3-none-manylinux2014_x86_64.whl → 1.3.1__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +8 -6
- warp/autograd.py +823 -0
- warp/bin/warp.so +0 -0
- warp/build.py +6 -2
- warp/builtins.py +1412 -888
- warp/codegen.py +503 -166
- warp/config.py +48 -18
- warp/context.py +400 -198
- warp/dlpack.py +8 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
- warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
- warp/examples/benchmarks/benchmark_launches.py +1 -1
- warp/examples/core/example_cupy.py +78 -0
- warp/examples/fem/example_apic_fluid.py +17 -36
- warp/examples/fem/example_burgers.py +9 -18
- warp/examples/fem/example_convection_diffusion.py +7 -17
- warp/examples/fem/example_convection_diffusion_dg.py +27 -47
- warp/examples/fem/example_deformed_geometry.py +11 -22
- warp/examples/fem/example_diffusion.py +7 -18
- warp/examples/fem/example_diffusion_3d.py +24 -28
- warp/examples/fem/example_diffusion_mgpu.py +7 -14
- warp/examples/fem/example_magnetostatics.py +190 -0
- warp/examples/fem/example_mixed_elasticity.py +111 -80
- warp/examples/fem/example_navier_stokes.py +30 -34
- warp/examples/fem/example_nonconforming_contact.py +290 -0
- warp/examples/fem/example_stokes.py +17 -32
- warp/examples/fem/example_stokes_transfer.py +12 -21
- warp/examples/fem/example_streamlines.py +350 -0
- warp/examples/fem/utils.py +936 -0
- warp/fabric.py +5 -2
- warp/fem/__init__.py +13 -3
- warp/fem/cache.py +161 -11
- warp/fem/dirichlet.py +37 -28
- warp/fem/domain.py +105 -14
- warp/fem/field/__init__.py +14 -3
- warp/fem/field/field.py +454 -11
- warp/fem/field/nodal_field.py +33 -18
- warp/fem/geometry/deformed_geometry.py +50 -15
- warp/fem/geometry/hexmesh.py +12 -24
- warp/fem/geometry/nanogrid.py +106 -31
- warp/fem/geometry/quadmesh_2d.py +6 -11
- warp/fem/geometry/tetmesh.py +103 -61
- warp/fem/geometry/trimesh_2d.py +98 -47
- warp/fem/integrate.py +231 -186
- warp/fem/operator.py +14 -9
- warp/fem/quadrature/pic_quadrature.py +35 -9
- warp/fem/quadrature/quadrature.py +119 -32
- warp/fem/space/basis_space.py +98 -22
- warp/fem/space/collocated_function_space.py +3 -1
- warp/fem/space/function_space.py +7 -2
- warp/fem/space/grid_2d_function_space.py +3 -3
- warp/fem/space/grid_3d_function_space.py +4 -4
- warp/fem/space/hexmesh_function_space.py +3 -2
- warp/fem/space/nanogrid_function_space.py +12 -14
- warp/fem/space/partition.py +45 -47
- warp/fem/space/restriction.py +19 -16
- warp/fem/space/shape/cube_shape_function.py +91 -3
- warp/fem/space/shape/shape_function.py +7 -0
- warp/fem/space/shape/square_shape_function.py +32 -0
- warp/fem/space/shape/tet_shape_function.py +11 -7
- warp/fem/space/shape/triangle_shape_function.py +10 -1
- warp/fem/space/topology.py +116 -42
- warp/fem/types.py +8 -1
- warp/fem/utils.py +301 -83
- warp/native/array.h +16 -0
- warp/native/builtin.h +0 -15
- warp/native/cuda_util.cpp +14 -6
- warp/native/exports.h +1348 -1308
- warp/native/quat.h +79 -0
- warp/native/rand.h +27 -4
- warp/native/sparse.cpp +83 -81
- warp/native/sparse.cu +381 -453
- warp/native/vec.h +64 -0
- warp/native/volume.cpp +40 -49
- warp/native/volume_builder.cu +2 -3
- warp/native/volume_builder.h +12 -17
- warp/native/warp.cu +3 -3
- warp/native/warp.h +69 -59
- warp/render/render_opengl.py +17 -9
- warp/sim/articulation.py +117 -17
- warp/sim/collide.py +35 -29
- warp/sim/model.py +123 -18
- warp/sim/render.py +3 -1
- warp/sparse.py +867 -203
- warp/stubs.py +312 -541
- warp/tape.py +29 -1
- warp/tests/disabled_kinematics.py +1 -1
- warp/tests/test_adam.py +1 -1
- warp/tests/test_arithmetic.py +1 -1
- warp/tests/test_array.py +58 -1
- warp/tests/test_array_reduce.py +1 -1
- warp/tests/test_async.py +1 -1
- warp/tests/test_atomic.py +1 -1
- warp/tests/test_bool.py +1 -1
- warp/tests/test_builtins_resolution.py +1 -1
- warp/tests/test_bvh.py +6 -1
- warp/tests/test_closest_point_edge_edge.py +1 -1
- warp/tests/test_codegen.py +91 -1
- warp/tests/test_compile_consts.py +1 -1
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_copy.py +1 -1
- warp/tests/test_ctypes.py +1 -1
- warp/tests/test_dense.py +1 -1
- warp/tests/test_devices.py +1 -1
- warp/tests/test_dlpack.py +1 -1
- warp/tests/test_examples.py +33 -4
- warp/tests/test_fabricarray.py +5 -2
- warp/tests/test_fast_math.py +1 -1
- warp/tests/test_fem.py +213 -6
- warp/tests/test_fp16.py +1 -1
- warp/tests/test_func.py +1 -1
- warp/tests/test_future_annotations.py +90 -0
- warp/tests/test_generics.py +1 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_grad_customs.py +1 -1
- warp/tests/test_grad_debug.py +247 -0
- warp/tests/test_hash_grid.py +6 -1
- warp/tests/test_implicit_init.py +354 -0
- warp/tests/test_import.py +1 -1
- warp/tests/test_indexedarray.py +1 -1
- warp/tests/test_intersect.py +1 -1
- warp/tests/test_jax.py +1 -1
- warp/tests/test_large.py +1 -1
- warp/tests/test_launch.py +1 -1
- warp/tests/test_lerp.py +1 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_lvalue.py +1 -1
- warp/tests/test_marching_cubes.py +5 -2
- warp/tests/test_mat.py +34 -35
- warp/tests/test_mat_lite.py +2 -1
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_math.py +1 -1
- warp/tests/test_matmul.py +20 -16
- warp/tests/test_matmul_lite.py +1 -1
- warp/tests/test_mempool.py +1 -1
- warp/tests/test_mesh.py +5 -2
- warp/tests/test_mesh_query_aabb.py +1 -1
- warp/tests/test_mesh_query_point.py +1 -1
- warp/tests/test_mesh_query_ray.py +1 -1
- warp/tests/test_mlp.py +1 -1
- warp/tests/test_model.py +1 -1
- warp/tests/test_module_hashing.py +77 -1
- warp/tests/test_modules_lite.py +1 -1
- warp/tests/test_multigpu.py +1 -1
- warp/tests/test_noise.py +1 -1
- warp/tests/test_operators.py +1 -1
- warp/tests/test_options.py +1 -1
- warp/tests/test_overwrite.py +542 -0
- warp/tests/test_peer.py +1 -1
- warp/tests/test_pinned.py +1 -1
- warp/tests/test_print.py +1 -1
- warp/tests/test_quat.py +15 -1
- warp/tests/test_rand.py +1 -1
- warp/tests/test_reload.py +1 -1
- warp/tests/test_rounding.py +1 -1
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +95 -0
- warp/tests/test_sim_grad.py +1 -1
- warp/tests/test_sim_kinematics.py +1 -1
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +82 -15
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_special_values.py +2 -11
- warp/tests/test_streams.py +11 -1
- warp/tests/test_struct.py +1 -1
- warp/tests/test_tape.py +1 -1
- warp/tests/test_torch.py +194 -1
- warp/tests/test_transient_module.py +1 -1
- warp/tests/test_types.py +1 -1
- warp/tests/test_utils.py +1 -1
- warp/tests/test_vec.py +15 -63
- warp/tests/test_vec_lite.py +2 -1
- warp/tests/test_vec_scalar_ops.py +65 -1
- warp/tests/test_verify_fp.py +1 -1
- warp/tests/test_volume.py +28 -2
- warp/tests/test_volume_write.py +1 -1
- warp/tests/unittest_serial.py +1 -1
- warp/tests/unittest_suites.py +9 -1
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +2 -5
- warp/torch.py +103 -41
- warp/types.py +341 -224
- warp/utils.py +11 -2
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/METADATA +99 -46
- warp_lang-1.3.1.dist-info/RECORD +368 -0
- warp/examples/fem/bsr_utils.py +0 -378
- warp/examples/fem/mesh_utils.py +0 -133
- warp/examples/fem/plot_utils.py +0 -292
- warp_lang-1.2.2.dist-info/RECORD +0 -359
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/top_level.txt +0 -0
warp/native/sparse.cu
CHANGED
|
@@ -6,539 +6,467 @@
|
|
|
6
6
|
#include <cub/device/device_radix_sort.cuh>
|
|
7
7
|
#include <cub/device/device_run_length_encode.cuh>
|
|
8
8
|
#include <cub/device/device_scan.cuh>
|
|
9
|
-
#include <cub/device/device_select.cuh>
|
|
10
9
|
|
|
11
|
-
namespace
|
|
10
|
+
namespace
|
|
11
|
+
{
|
|
12
12
|
|
|
13
13
|
// Combined row+column value that can be radix-sorted with CUB
|
|
14
14
|
using BsrRowCol = uint64_t;
|
|
15
15
|
|
|
16
|
-
|
|
17
|
-
return (static_cast<uint64_t>(row) << 32) | col;
|
|
18
|
-
}
|
|
19
|
-
|
|
20
|
-
CUDA_CALLABLE uint32_t bsr_get_row(const BsrRowCol &row_col) {
|
|
21
|
-
return row_col >> 32;
|
|
22
|
-
}
|
|
16
|
+
static constexpr BsrRowCol PRUNED_ROWCOL = ~BsrRowCol(0);
|
|
23
17
|
|
|
24
|
-
CUDA_CALLABLE
|
|
25
|
-
|
|
18
|
+
CUDA_CALLABLE BsrRowCol bsr_combine_row_col(uint32_t row, uint32_t col)
|
|
19
|
+
{
|
|
20
|
+
return (static_cast<uint64_t>(row) << 32) | col;
|
|
26
21
|
}
|
|
27
22
|
|
|
28
|
-
|
|
29
|
-
struct BsrFromTripletsTemp {
|
|
30
|
-
|
|
31
|
-
int *count_buffer = NULL;
|
|
32
|
-
cudaEvent_t host_sync_event = NULL;
|
|
33
|
-
|
|
34
|
-
BsrFromTripletsTemp()
|
|
35
|
-
: count_buffer(static_cast<int*>(alloc_pinned(sizeof(int))))
|
|
36
|
-
{
|
|
37
|
-
cudaEventCreateWithFlags(&host_sync_event, cudaEventDisableTiming);
|
|
38
|
-
}
|
|
39
|
-
|
|
40
|
-
~BsrFromTripletsTemp()
|
|
41
|
-
{
|
|
42
|
-
cudaEventDestroy(host_sync_event);
|
|
43
|
-
free_pinned(count_buffer);
|
|
44
|
-
}
|
|
45
|
-
|
|
46
|
-
BsrFromTripletsTemp(const BsrFromTripletsTemp&) = delete;
|
|
47
|
-
BsrFromTripletsTemp& operator=(const BsrFromTripletsTemp&) = delete;
|
|
23
|
+
CUDA_CALLABLE uint32_t bsr_get_row(const BsrRowCol& row_col) { return row_col >> 32; }
|
|
48
24
|
|
|
49
|
-
}
|
|
25
|
+
CUDA_CALLABLE uint32_t bsr_get_col(const BsrRowCol& row_col) { return row_col & INT_MAX; }
|
|
50
26
|
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
int block_size;
|
|
56
|
-
const T *values;
|
|
27
|
+
template <typename T> struct BsrBlockIsNotZero
|
|
28
|
+
{
|
|
29
|
+
int block_size;
|
|
30
|
+
const T* values;
|
|
57
31
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
32
|
+
CUDA_CALLABLE_DEVICE bool operator()(int i) const
|
|
33
|
+
{
|
|
34
|
+
if (!values)
|
|
35
|
+
return true;
|
|
36
|
+
|
|
37
|
+
const T* val = values + i * block_size;
|
|
38
|
+
for (int i = 0; i < block_size; ++i, ++val)
|
|
39
|
+
{
|
|
40
|
+
if (*val != T(0))
|
|
41
|
+
return true;
|
|
42
|
+
}
|
|
43
|
+
return false;
|
|
63
44
|
}
|
|
64
|
-
return false;
|
|
65
|
-
}
|
|
66
45
|
};
|
|
67
46
|
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
47
|
+
template <typename T>
|
|
48
|
+
__global__ void bsr_fill_triplet_key_values(const int nnz, const int nrow, const int* tpl_rows, const int* tpl_columns,
|
|
49
|
+
const BsrBlockIsNotZero<T> nonZero, uint32_t* block_indices,
|
|
50
|
+
BsrRowCol* tpl_row_col)
|
|
51
|
+
{
|
|
52
|
+
int block = blockIdx.x * blockDim.x + threadIdx.x;
|
|
53
|
+
if (block >= nnz)
|
|
54
|
+
return;
|
|
55
|
+
|
|
56
|
+
const int row = tpl_rows[block];
|
|
57
|
+
const int col = tpl_columns[block];
|
|
58
|
+
const bool is_valid = row >= 0 && row < nrow;
|
|
72
59
|
|
|
73
|
-
|
|
60
|
+
const BsrRowCol row_col = is_valid && nonZero(block) ? bsr_combine_row_col(row, col) : PRUNED_ROWCOL;
|
|
61
|
+
tpl_row_col[block] = row_col;
|
|
62
|
+
block_indices[block] = block;
|
|
74
63
|
}
|
|
75
64
|
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
65
|
+
template <typename T>
|
|
66
|
+
__global__ void bsr_find_row_offsets(uint32_t row_count, const T* d_nnz, const BsrRowCol* unique_row_col,
|
|
67
|
+
int* row_offsets)
|
|
68
|
+
{
|
|
69
|
+
const uint32_t row = blockIdx.x * blockDim.x + threadIdx.x;
|
|
70
|
+
|
|
71
|
+
if (row > row_count)
|
|
72
|
+
return;
|
|
82
73
|
|
|
83
|
-
|
|
74
|
+
const uint32_t nnz = *d_nnz;
|
|
75
|
+
if (row == 0 || nnz == 0)
|
|
76
|
+
{
|
|
77
|
+
row_offsets[row] = 0;
|
|
78
|
+
return;
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
if (bsr_get_row(unique_row_col[nnz - 1]) < row)
|
|
82
|
+
{
|
|
83
|
+
row_offsets[row] = nnz;
|
|
84
|
+
return;
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
// binary search for row start
|
|
88
|
+
uint32_t lower = 0;
|
|
89
|
+
uint32_t upper = nnz - 1;
|
|
90
|
+
while (lower < upper)
|
|
91
|
+
{
|
|
92
|
+
uint32_t mid = lower + (upper - lower) / 2;
|
|
93
|
+
|
|
94
|
+
if (bsr_get_row(unique_row_col[mid]) < row)
|
|
95
|
+
{
|
|
96
|
+
lower = mid + 1;
|
|
97
|
+
}
|
|
98
|
+
else
|
|
99
|
+
{
|
|
100
|
+
upper = mid;
|
|
101
|
+
}
|
|
102
|
+
}
|
|
84
103
|
|
|
85
|
-
|
|
86
|
-
tpl_row_col[i] = row_col;
|
|
104
|
+
row_offsets[row] = lower;
|
|
87
105
|
}
|
|
88
106
|
|
|
89
107
|
template <typename T>
|
|
90
|
-
__global__ void
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
const BsrRowCol *unique_row_cols, const T *tpl_values,
|
|
94
|
-
int *bsr_row_counts, int *bsr_cols, T *bsr_values)
|
|
108
|
+
__global__ void bsr_merge_blocks(const uint32_t* d_nnz, int block_size, const uint32_t* block_offsets,
|
|
109
|
+
const uint32_t* sorted_block_indices, const BsrRowCol* unique_row_cols,
|
|
110
|
+
const T* tpl_values, int* bsr_cols, T* bsr_values)
|
|
95
111
|
|
|
96
112
|
{
|
|
97
|
-
|
|
98
|
-
if (i >= nnz)
|
|
99
|
-
return;
|
|
113
|
+
const uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
|
|
100
114
|
|
|
101
|
-
|
|
102
|
-
|
|
115
|
+
if (i >= *d_nnz)
|
|
116
|
+
return;
|
|
103
117
|
|
|
104
|
-
|
|
118
|
+
const BsrRowCol row_col = unique_row_cols[i];
|
|
119
|
+
bsr_cols[i] = bsr_get_col(row_col);
|
|
105
120
|
|
|
106
|
-
|
|
107
|
-
|
|
121
|
+
// Accumulate merged block values
|
|
122
|
+
if (row_col == PRUNED_ROWCOL || bsr_values == nullptr)
|
|
123
|
+
return;
|
|
108
124
|
|
|
109
|
-
|
|
110
|
-
|
|
125
|
+
const uint32_t beg = i ? block_offsets[i - 1] : 0;
|
|
126
|
+
const uint32_t end = block_offsets[i];
|
|
111
127
|
|
|
112
|
-
|
|
113
|
-
|
|
128
|
+
T* bsr_val = bsr_values + i * block_size;
|
|
129
|
+
const T* tpl_val = tpl_values + sorted_block_indices[beg] * block_size;
|
|
114
130
|
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
131
|
+
for (int k = 0; k < block_size; ++k)
|
|
132
|
+
{
|
|
133
|
+
bsr_val[k] = tpl_val[k];
|
|
134
|
+
}
|
|
118
135
|
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
136
|
+
for (uint32_t cur = beg + 1; cur != end; ++cur)
|
|
137
|
+
{
|
|
138
|
+
const T* tpl_val = tpl_values + sorted_block_indices[cur] * block_size;
|
|
139
|
+
for (int k = 0; k < block_size; ++k)
|
|
140
|
+
{
|
|
141
|
+
bsr_val[k] += tpl_val[k];
|
|
142
|
+
}
|
|
123
143
|
}
|
|
124
|
-
}
|
|
125
144
|
}
|
|
126
145
|
|
|
127
146
|
template <typename T>
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
147
|
+
void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_per_block, const int row_count,
|
|
148
|
+
const int nnz, const int* tpl_rows, const int* tpl_columns, const T* tpl_values,
|
|
149
|
+
const bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
|
|
150
|
+
T* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
|
|
151
|
+
{
|
|
152
|
+
const int block_size = rows_per_block * cols_per_block;
|
|
153
|
+
|
|
154
|
+
void* context = cuda_context_get_current();
|
|
155
|
+
ContextGuard guard(context);
|
|
135
156
|
|
|
136
|
-
|
|
137
|
-
|
|
157
|
+
// Per-context cached temporary buffers
|
|
158
|
+
// BsrFromTripletsTemp& bsr_temp = g_bsr_from_triplets_temp_map[context];
|
|
138
159
|
|
|
139
|
-
|
|
140
|
-
BsrFromTripletsTemp &bsr_temp = g_bsr_from_triplets_temp_map[context];
|
|
160
|
+
cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
|
|
141
161
|
|
|
142
|
-
|
|
162
|
+
ScopedTemporary<uint32_t> block_indices(context, 2 * nnz + 1);
|
|
163
|
+
ScopedTemporary<BsrRowCol> combined_row_col(context, 2 * nnz);
|
|
164
|
+
|
|
165
|
+
cub::DoubleBuffer<uint32_t> d_keys(block_indices.buffer(), block_indices.buffer() + nnz);
|
|
166
|
+
cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(), combined_row_col.buffer() + nnz);
|
|
167
|
+
|
|
168
|
+
uint32_t* unique_triplet_count = block_indices.buffer() + 2 * nnz;
|
|
169
|
+
|
|
170
|
+
// Combine rows and columns so we can sort on them both
|
|
171
|
+
BsrBlockIsNotZero<T> isNotZero{block_size, prune_numerical_zeros ? tpl_values : nullptr};
|
|
172
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_triplet_key_values, nnz,
|
|
173
|
+
(nnz, row_count, tpl_rows, tpl_columns, isNotZero, d_keys.Current(), d_values.Current()));
|
|
174
|
+
|
|
175
|
+
// Sort
|
|
176
|
+
{
|
|
177
|
+
size_t buff_size = 0;
|
|
178
|
+
check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values, d_keys, nnz, 0, 64, stream));
|
|
179
|
+
ScopedTemporary<> temp(context, buff_size);
|
|
180
|
+
check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
|
|
181
|
+
}
|
|
143
182
|
|
|
144
|
-
|
|
145
|
-
|
|
183
|
+
// Runlength encode row-col sequences
|
|
184
|
+
{
|
|
185
|
+
size_t buff_size = 0;
|
|
186
|
+
check_cuda(cub::DeviceRunLengthEncode::Encode(nullptr, buff_size, d_values.Current(), d_values.Alternate(),
|
|
187
|
+
d_keys.Alternate(), unique_triplet_count, nnz, stream));
|
|
188
|
+
ScopedTemporary<> temp(context, buff_size);
|
|
189
|
+
check_cuda(cub::DeviceRunLengthEncode::Encode(temp.buffer(), buff_size, d_values.Current(),
|
|
190
|
+
d_values.Alternate(), d_keys.Alternate(), unique_triplet_count,
|
|
191
|
+
nnz, stream));
|
|
192
|
+
}
|
|
146
193
|
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
combined_row_col.buffer() + nnz);
|
|
194
|
+
// Compute row offsets from sorted unique blocks
|
|
195
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_find_row_offsets, row_count + 1,
|
|
196
|
+
(row_count, unique_triplet_count, d_values.Alternate(), bsr_offsets));
|
|
151
197
|
|
|
152
|
-
|
|
198
|
+
if (bsr_nnz)
|
|
199
|
+
{
|
|
200
|
+
// Copy nnz to host, and record an event for the competed transfer if desired
|
|
153
201
|
|
|
154
|
-
|
|
155
|
-
(nnz, d_keys.Current()));
|
|
202
|
+
memcpy_d2h(WP_CURRENT_CONTEXT, bsr_nnz, bsr_offsets + row_count, sizeof(int), stream);
|
|
156
203
|
|
|
157
|
-
|
|
204
|
+
if (bsr_nnz_event)
|
|
205
|
+
{
|
|
206
|
+
cuda_event_record(bsr_nnz_event, stream);
|
|
207
|
+
}
|
|
208
|
+
}
|
|
158
209
|
|
|
159
|
-
//
|
|
210
|
+
// Scan repeated block counts
|
|
160
211
|
{
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
check_cuda(cub::DeviceSelect::If(
|
|
168
|
-
temp.buffer(), buff_size, d_keys.Current(), d_keys.Alternate(),
|
|
169
|
-
p_nz_triplet_count, nnz, isNotZero, stream));
|
|
212
|
+
size_t buff_size = 0;
|
|
213
|
+
check_cuda(
|
|
214
|
+
cub::DeviceScan::InclusiveSum(nullptr, buff_size, d_keys.Alternate(), d_keys.Alternate(), nnz, stream));
|
|
215
|
+
ScopedTemporary<> temp(context, buff_size);
|
|
216
|
+
check_cuda(cub::DeviceScan::InclusiveSum(temp.buffer(), buff_size, d_keys.Alternate(), d_keys.Alternate(), nnz,
|
|
217
|
+
stream));
|
|
170
218
|
}
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
} else {
|
|
177
|
-
*p_nz_triplet_count = nnz;
|
|
178
|
-
}
|
|
179
|
-
|
|
180
|
-
// Combine rows and columns so we can sort on them both
|
|
181
|
-
wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_row_col, nnz,
|
|
182
|
-
(p_nz_triplet_count, d_keys.Current(), tpl_rows, tpl_columns,
|
|
183
|
-
d_values.Current()));
|
|
184
|
-
|
|
185
|
-
if (tpl_values) {
|
|
186
|
-
// Make sure count is available on host
|
|
187
|
-
cudaEventSynchronize(bsr_temp.host_sync_event);
|
|
188
|
-
}
|
|
189
|
-
|
|
190
|
-
const int nz_triplet_count = *p_nz_triplet_count;
|
|
191
|
-
|
|
192
|
-
// Sort
|
|
193
|
-
{
|
|
194
|
-
size_t buff_size = 0;
|
|
195
|
-
check_cuda(cub::DeviceRadixSort::SortPairs(
|
|
196
|
-
nullptr, buff_size, d_values, d_keys, nz_triplet_count, 0, 64, stream));
|
|
197
|
-
ScopedTemporary<> temp(context, buff_size);
|
|
198
|
-
check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size,
|
|
199
|
-
d_values, d_keys, nz_triplet_count,
|
|
200
|
-
0, 64, stream));
|
|
201
|
-
}
|
|
202
|
-
|
|
203
|
-
// Runlength encode row-col sequences
|
|
204
|
-
{
|
|
205
|
-
size_t buff_size = 0;
|
|
206
|
-
check_cuda(cub::DeviceRunLengthEncode::Encode(
|
|
207
|
-
nullptr, buff_size, d_values.Current(), d_values.Alternate(),
|
|
208
|
-
d_keys.Alternate(), p_nz_triplet_count, nz_triplet_count, stream));
|
|
209
|
-
ScopedTemporary<> temp(context, buff_size);
|
|
210
|
-
check_cuda(cub::DeviceRunLengthEncode::Encode(
|
|
211
|
-
temp.buffer(), buff_size, d_values.Current(), d_values.Alternate(),
|
|
212
|
-
d_keys.Alternate(), p_nz_triplet_count, nz_triplet_count, stream));
|
|
213
|
-
}
|
|
214
|
-
|
|
215
|
-
cudaEventRecord(bsr_temp.host_sync_event, stream);
|
|
216
|
-
|
|
217
|
-
// Now we have the following:
|
|
218
|
-
// d_values.Current(): sorted block row-col
|
|
219
|
-
// d_values.Alternate(): sorted unique row-col
|
|
220
|
-
// d_keys.Current(): sorted block indices
|
|
221
|
-
// d_keys.Alternate(): repeated block-row count
|
|
222
|
-
|
|
223
|
-
// Scan repeated block counts
|
|
224
|
-
{
|
|
225
|
-
size_t buff_size = 0;
|
|
226
|
-
check_cuda(cub::DeviceScan::InclusiveSum(
|
|
227
|
-
nullptr, buff_size, d_keys.Alternate(), d_keys.Alternate(),
|
|
228
|
-
nz_triplet_count, stream));
|
|
229
|
-
ScopedTemporary<> temp(context, buff_size);
|
|
230
|
-
check_cuda(cub::DeviceScan::InclusiveSum(
|
|
231
|
-
temp.buffer(), buff_size, d_keys.Alternate(), d_keys.Alternate(),
|
|
232
|
-
nz_triplet_count, stream));
|
|
233
|
-
}
|
|
234
|
-
|
|
235
|
-
// While we're at it, zero the bsr offsets buffer
|
|
236
|
-
memset_device(WP_CURRENT_CONTEXT, bsr_offsets, 0,
|
|
237
|
-
(row_count + 1) * sizeof(int));
|
|
238
|
-
|
|
239
|
-
// Wait for number of compressed blocks
|
|
240
|
-
cudaEventSynchronize(bsr_temp.host_sync_event);
|
|
241
|
-
const int compressed_nnz = *p_nz_triplet_count;
|
|
242
|
-
|
|
243
|
-
// We have all we need to accumulate our repeated blocks
|
|
244
|
-
wp_launch_device(WP_CURRENT_CONTEXT, bsr_merge_blocks, compressed_nnz,
|
|
245
|
-
(compressed_nnz, block_size, d_keys.Alternate(),
|
|
246
|
-
d_keys.Current(), d_values.Alternate(), tpl_values,
|
|
247
|
-
bsr_offsets, bsr_columns, bsr_values));
|
|
248
|
-
|
|
249
|
-
// Last, prefix sum the row block counts
|
|
250
|
-
{
|
|
251
|
-
size_t buff_size = 0;
|
|
252
|
-
check_cuda(cub::DeviceScan::InclusiveSum(nullptr, buff_size, bsr_offsets,
|
|
253
|
-
bsr_offsets, row_count + 1, stream));
|
|
254
|
-
ScopedTemporary<> temp(context, buff_size);
|
|
255
|
-
check_cuda(cub::DeviceScan::InclusiveSum(temp.buffer(), buff_size,
|
|
256
|
-
bsr_offsets, bsr_offsets,
|
|
257
|
-
row_count + 1, stream));
|
|
258
|
-
}
|
|
259
|
-
|
|
260
|
-
return compressed_nnz;
|
|
219
|
+
|
|
220
|
+
// Accumulate repeated blocks and set column indices
|
|
221
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_merge_blocks, nnz,
|
|
222
|
+
(unique_triplet_count, block_size, d_keys.Alternate(), d_keys.Current(), d_values.Alternate(),
|
|
223
|
+
tpl_values, bsr_columns, bsr_values));
|
|
261
224
|
}
|
|
262
225
|
|
|
263
|
-
__global__ void bsr_transpose_fill_row_col(const int
|
|
264
|
-
const int *
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
block_indices[i] = i;
|
|
274
|
-
|
|
275
|
-
// Binary search for row
|
|
276
|
-
int lower = 0;
|
|
277
|
-
int upper = row_count - 1;
|
|
278
|
-
|
|
279
|
-
while (lower < upper) {
|
|
280
|
-
int mid = lower + (upper - lower) / 2;
|
|
281
|
-
|
|
282
|
-
if (bsr_offsets[mid + 1] <= i) {
|
|
283
|
-
lower = mid + 1;
|
|
284
|
-
} else {
|
|
285
|
-
upper = mid;
|
|
226
|
+
__global__ void bsr_transpose_fill_row_col(const int nnz_upper_bound, const int row_count, const int* bsr_offsets,
|
|
227
|
+
const int* bsr_columns, int* block_indices, BsrRowCol* transposed_row_col)
|
|
228
|
+
{
|
|
229
|
+
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
|
230
|
+
|
|
231
|
+
if (i >= nnz_upper_bound)
|
|
232
|
+
{
|
|
233
|
+
// Outside of allocated bounds, do nothing
|
|
234
|
+
return;
|
|
286
235
|
}
|
|
287
|
-
}
|
|
288
236
|
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
237
|
+
if (i >= bsr_offsets[row_count])
|
|
238
|
+
{
|
|
239
|
+
// Below upper bound but above actual nnz count, mark as invalid
|
|
240
|
+
transposed_row_col[i] = PRUNED_ROWCOL;
|
|
241
|
+
return;
|
|
242
|
+
}
|
|
293
243
|
|
|
294
|
-
|
|
244
|
+
block_indices[i] = i;
|
|
245
|
+
|
|
246
|
+
// Binary search for row
|
|
247
|
+
int lower = 0;
|
|
248
|
+
int upper = row_count - 1;
|
|
249
|
+
|
|
250
|
+
while (lower < upper)
|
|
251
|
+
{
|
|
252
|
+
int mid = lower + (upper - lower) / 2;
|
|
253
|
+
|
|
254
|
+
if (bsr_offsets[mid + 1] <= i)
|
|
255
|
+
{
|
|
256
|
+
lower = mid + 1;
|
|
257
|
+
}
|
|
258
|
+
else
|
|
259
|
+
{
|
|
260
|
+
upper = mid;
|
|
261
|
+
}
|
|
262
|
+
}
|
|
263
|
+
|
|
264
|
+
const int row = lower;
|
|
265
|
+
const int col = bsr_columns[i];
|
|
266
|
+
BsrRowCol row_col = bsr_combine_row_col(col, row);
|
|
267
|
+
transposed_row_col[i] = row_col;
|
|
295
268
|
}
|
|
296
269
|
|
|
297
|
-
template <int Rows, int Cols, typename T> struct BsrBlockTransposer
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
270
|
+
template <int Rows, int Cols, typename T> struct BsrBlockTransposer
|
|
271
|
+
{
|
|
272
|
+
void CUDA_CALLABLE_DEVICE operator()(const T* src, T* dest) const
|
|
273
|
+
{
|
|
274
|
+
for (int r = 0; r < Rows; ++r)
|
|
275
|
+
{
|
|
276
|
+
for (int c = 0; c < Cols; ++c)
|
|
277
|
+
{
|
|
278
|
+
dest[c * Rows + r] = src[r * Cols + c];
|
|
279
|
+
}
|
|
280
|
+
}
|
|
303
281
|
}
|
|
304
|
-
}
|
|
305
282
|
};
|
|
306
283
|
|
|
307
|
-
template <typename T> struct BsrBlockTransposer<-1, -1, T>
|
|
284
|
+
template <typename T> struct BsrBlockTransposer<-1, -1, T>
|
|
285
|
+
{
|
|
308
286
|
|
|
309
|
-
|
|
310
|
-
|
|
287
|
+
int row_count;
|
|
288
|
+
int col_count;
|
|
311
289
|
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
290
|
+
void CUDA_CALLABLE_DEVICE operator()(const T* src, T* dest) const
|
|
291
|
+
{
|
|
292
|
+
for (int r = 0; r < row_count; ++r)
|
|
293
|
+
{
|
|
294
|
+
for (int c = 0; c < col_count; ++c)
|
|
295
|
+
{
|
|
296
|
+
dest[c * row_count + r] = src[r * col_count + c];
|
|
297
|
+
}
|
|
298
|
+
}
|
|
317
299
|
}
|
|
318
|
-
}
|
|
319
300
|
};
|
|
320
301
|
|
|
321
302
|
template <int Rows, int Cols, typename T>
|
|
322
|
-
__global__ void
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
transposed_bsr_values + i * block_size);
|
|
336
|
-
|
|
337
|
-
transposed_bsr_columns[i] = bsr_get_col(transposed_indices[i]);
|
|
303
|
+
__global__ void bsr_transpose_blocks(const int* nnz, const int block_size, BsrBlockTransposer<Rows, Cols, T> transposer,
|
|
304
|
+
const int* block_indices, const BsrRowCol* transposed_indices, const T* bsr_values,
|
|
305
|
+
int* transposed_bsr_columns, T* transposed_bsr_values)
|
|
306
|
+
{
|
|
307
|
+
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
|
308
|
+
if (i >= *nnz)
|
|
309
|
+
return;
|
|
310
|
+
|
|
311
|
+
const int src_idx = block_indices[i];
|
|
312
|
+
|
|
313
|
+
transposer(bsr_values + src_idx * block_size, transposed_bsr_values + i * block_size);
|
|
314
|
+
|
|
315
|
+
transposed_bsr_columns[i] = bsr_get_col(transposed_indices[i]);
|
|
338
316
|
}
|
|
339
317
|
|
|
340
318
|
template <typename T>
|
|
341
|
-
void
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
switch (rows_per_block) {
|
|
350
|
-
case 1:
|
|
351
|
-
switch (cols_per_block) {
|
|
352
|
-
case 1:
|
|
353
|
-
wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
|
|
354
|
-
(nnz, block_size, BsrBlockTransposer<1, 1, T>{},
|
|
355
|
-
block_indices, transposed_indices, bsr_values,
|
|
356
|
-
transposed_bsr_columns, transposed_bsr_values));
|
|
357
|
-
return;
|
|
358
|
-
case 2:
|
|
359
|
-
wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
|
|
360
|
-
(nnz, block_size, BsrBlockTransposer<1, 2, T>{},
|
|
361
|
-
block_indices, transposed_indices, bsr_values,
|
|
362
|
-
transposed_bsr_columns, transposed_bsr_values));
|
|
363
|
-
return;
|
|
364
|
-
case 3:
|
|
365
|
-
wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
|
|
366
|
-
(nnz, block_size, BsrBlockTransposer<1, 3, T>{},
|
|
367
|
-
block_indices, transposed_indices, bsr_values,
|
|
368
|
-
transposed_bsr_columns, transposed_bsr_values));
|
|
369
|
-
return;
|
|
370
|
-
}
|
|
371
|
-
case 2:
|
|
372
|
-
switch (cols_per_block) {
|
|
373
|
-
case 1:
|
|
374
|
-
wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
|
|
375
|
-
(nnz, block_size, BsrBlockTransposer<2, 1, T>{},
|
|
376
|
-
block_indices, transposed_indices, bsr_values,
|
|
377
|
-
transposed_bsr_columns, transposed_bsr_values));
|
|
378
|
-
return;
|
|
379
|
-
case 2:
|
|
380
|
-
wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
|
|
381
|
-
(nnz, block_size, BsrBlockTransposer<2, 2, T>{},
|
|
382
|
-
block_indices, transposed_indices, bsr_values,
|
|
383
|
-
transposed_bsr_columns, transposed_bsr_values));
|
|
384
|
-
return;
|
|
385
|
-
case 3:
|
|
386
|
-
wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
|
|
387
|
-
(nnz, block_size, BsrBlockTransposer<2, 3, T>{},
|
|
388
|
-
block_indices, transposed_indices, bsr_values,
|
|
389
|
-
transposed_bsr_columns, transposed_bsr_values));
|
|
390
|
-
return;
|
|
391
|
-
}
|
|
392
|
-
case 3:
|
|
393
|
-
switch (cols_per_block) {
|
|
319
|
+
void launch_bsr_transpose_blocks(int nnz, const int* d_nnz, const int block_size, const int rows_per_block,
|
|
320
|
+
const int cols_per_block, const int* block_indices,
|
|
321
|
+
const BsrRowCol* transposed_indices, const T* bsr_values, int* transposed_bsr_columns,
|
|
322
|
+
T* transposed_bsr_values)
|
|
323
|
+
{
|
|
324
|
+
|
|
325
|
+
switch (rows_per_block)
|
|
326
|
+
{
|
|
394
327
|
case 1:
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
328
|
+
switch (cols_per_block)
|
|
329
|
+
{
|
|
330
|
+
case 1:
|
|
331
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
|
|
332
|
+
(d_nnz, block_size, BsrBlockTransposer<1, 1, T>{}, block_indices, transposed_indices,
|
|
333
|
+
bsr_values, transposed_bsr_columns, transposed_bsr_values));
|
|
334
|
+
return;
|
|
335
|
+
case 2:
|
|
336
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
|
|
337
|
+
(d_nnz, block_size, BsrBlockTransposer<1, 2, T>{}, block_indices, transposed_indices,
|
|
338
|
+
bsr_values, transposed_bsr_columns, transposed_bsr_values));
|
|
339
|
+
return;
|
|
340
|
+
case 3:
|
|
341
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
|
|
342
|
+
(d_nnz, block_size, BsrBlockTransposer<1, 3, T>{}, block_indices, transposed_indices,
|
|
343
|
+
bsr_values, transposed_bsr_columns, transposed_bsr_values));
|
|
344
|
+
return;
|
|
345
|
+
}
|
|
400
346
|
case 2:
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
347
|
+
switch (cols_per_block)
|
|
348
|
+
{
|
|
349
|
+
case 1:
|
|
350
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
|
|
351
|
+
(d_nnz, block_size, BsrBlockTransposer<2, 1, T>{}, block_indices, transposed_indices,
|
|
352
|
+
bsr_values, transposed_bsr_columns, transposed_bsr_values));
|
|
353
|
+
return;
|
|
354
|
+
case 2:
|
|
355
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
|
|
356
|
+
(d_nnz, block_size, BsrBlockTransposer<2, 2, T>{}, block_indices, transposed_indices,
|
|
357
|
+
bsr_values, transposed_bsr_columns, transposed_bsr_values));
|
|
358
|
+
return;
|
|
359
|
+
case 3:
|
|
360
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
|
|
361
|
+
(d_nnz, block_size, BsrBlockTransposer<2, 3, T>{}, block_indices, transposed_indices,
|
|
362
|
+
bsr_values, transposed_bsr_columns, transposed_bsr_values));
|
|
363
|
+
return;
|
|
364
|
+
}
|
|
406
365
|
case 3:
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
366
|
+
switch (cols_per_block)
|
|
367
|
+
{
|
|
368
|
+
case 1:
|
|
369
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
|
|
370
|
+
(d_nnz, block_size, BsrBlockTransposer<3, 1, T>{}, block_indices, transposed_indices,
|
|
371
|
+
bsr_values, transposed_bsr_columns, transposed_bsr_values));
|
|
372
|
+
return;
|
|
373
|
+
case 2:
|
|
374
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
|
|
375
|
+
(d_nnz, block_size, BsrBlockTransposer<3, 2, T>{}, block_indices, transposed_indices,
|
|
376
|
+
bsr_values, transposed_bsr_columns, transposed_bsr_values));
|
|
377
|
+
return;
|
|
378
|
+
case 3:
|
|
379
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
|
|
380
|
+
(d_nnz, block_size, BsrBlockTransposer<3, 3, T>{}, block_indices, transposed_indices,
|
|
381
|
+
bsr_values, transposed_bsr_columns, transposed_bsr_values));
|
|
382
|
+
return;
|
|
383
|
+
}
|
|
412
384
|
}
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
(nnz, block_size,
|
|
418
|
-
BsrBlockTransposer<-1, -1, T>{rows_per_block, cols_per_block},
|
|
419
|
-
block_indices, transposed_indices, bsr_values, transposed_bsr_columns,
|
|
420
|
-
transposed_bsr_values));
|
|
385
|
+
|
|
386
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
|
|
387
|
+
(d_nnz, block_size, BsrBlockTransposer<-1, -1, T>{rows_per_block, cols_per_block}, block_indices,
|
|
388
|
+
transposed_indices, bsr_values, transposed_bsr_columns, transposed_bsr_values));
|
|
421
389
|
}
|
|
422
390
|
|
|
423
391
|
template <typename T>
|
|
424
|
-
void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
|
|
425
|
-
int
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
// Prefix sum the transposed row block counts
|
|
465
|
-
{
|
|
466
|
-
size_t buff_size = 0;
|
|
467
|
-
check_cuda(cub::DeviceScan::InclusiveSum(
|
|
468
|
-
nullptr, buff_size, transposed_bsr_offsets, transposed_bsr_offsets,
|
|
469
|
-
col_count + 1, stream));
|
|
470
|
-
ScopedTemporary<> temp(context, buff_size);
|
|
471
|
-
check_cuda(cub::DeviceScan::InclusiveSum(
|
|
472
|
-
temp.buffer(), buff_size, transposed_bsr_offsets,
|
|
473
|
-
transposed_bsr_offsets, col_count + 1, stream));
|
|
474
|
-
}
|
|
475
|
-
|
|
476
|
-
// Move and transpose individual blocks
|
|
477
|
-
launch_bsr_transpose_blocks(
|
|
478
|
-
nnz, block_size,
|
|
479
|
-
rows_per_block, cols_per_block,
|
|
480
|
-
d_keys.Current(), d_values.Current(), bsr_values, transposed_bsr_columns,
|
|
481
|
-
transposed_bsr_values);
|
|
392
|
+
void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
|
|
393
|
+
const int* bsr_offsets, const int* bsr_columns, const T* bsr_values,
|
|
394
|
+
int* transposed_bsr_offsets, int* transposed_bsr_columns, T* transposed_bsr_values)
|
|
395
|
+
{
|
|
396
|
+
|
|
397
|
+
const int block_size = rows_per_block * cols_per_block;
|
|
398
|
+
|
|
399
|
+
void* context = cuda_context_get_current();
|
|
400
|
+
ContextGuard guard(context);
|
|
401
|
+
|
|
402
|
+
cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
|
|
403
|
+
|
|
404
|
+
ScopedTemporary<int> block_indices(context, 2 * nnz);
|
|
405
|
+
ScopedTemporary<BsrRowCol> combined_row_col(context, 2 * nnz);
|
|
406
|
+
|
|
407
|
+
cub::DoubleBuffer<int> d_keys(block_indices.buffer(), block_indices.buffer() + nnz);
|
|
408
|
+
cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(), combined_row_col.buffer() + nnz);
|
|
409
|
+
|
|
410
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_fill_row_col, nnz,
|
|
411
|
+
(nnz, row_count, bsr_offsets, bsr_columns, d_keys.Current(), d_values.Current()));
|
|
412
|
+
|
|
413
|
+
// Sort blocks
|
|
414
|
+
{
|
|
415
|
+
size_t buff_size = 0;
|
|
416
|
+
check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values, d_keys, nnz, 0, 64, stream));
|
|
417
|
+
ScopedTemporary<> temp(context, buff_size);
|
|
418
|
+
check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
|
|
419
|
+
}
|
|
420
|
+
|
|
421
|
+
// Compute row offsets from sorted unique blocks
|
|
422
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_find_row_offsets, col_count + 1,
|
|
423
|
+
(col_count, bsr_offsets + row_count, d_values.Current(), transposed_bsr_offsets));
|
|
424
|
+
|
|
425
|
+
// Move and transpose individual blocks
|
|
426
|
+
if (transposed_bsr_values != nullptr)
|
|
427
|
+
{
|
|
428
|
+
launch_bsr_transpose_blocks(nnz, bsr_offsets + row_count, block_size, rows_per_block, cols_per_block,
|
|
429
|
+
d_keys.Current(), d_values.Current(), bsr_values, transposed_bsr_columns,
|
|
430
|
+
transposed_bsr_values);
|
|
431
|
+
}
|
|
482
432
|
}
|
|
483
433
|
|
|
484
434
|
} // namespace
|
|
485
435
|
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
reinterpret_cast<const float *>(tpl_values),
|
|
495
|
-
reinterpret_cast<int *>(bsr_offsets),
|
|
496
|
-
reinterpret_cast<int *>(bsr_columns),
|
|
497
|
-
reinterpret_cast<float *>(bsr_values));
|
|
436
|
+
void bsr_matrix_from_triplets_float_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
|
|
437
|
+
int* tpl_rows, int* tpl_columns, void* tpl_values,
|
|
438
|
+
bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
|
|
439
|
+
void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
|
|
440
|
+
{
|
|
441
|
+
return bsr_matrix_from_triplets_device<float>(
|
|
442
|
+
rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns, static_cast<const float*>(tpl_values),
|
|
443
|
+
prune_numerical_zeros, bsr_offsets, bsr_columns, static_cast<float*>(bsr_values), bsr_nnz, bsr_nnz_event);
|
|
498
444
|
}
|
|
499
445
|
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
reinterpret_cast<const double *>(tpl_values),
|
|
509
|
-
reinterpret_cast<int *>(bsr_offsets),
|
|
510
|
-
reinterpret_cast<int *>(bsr_columns),
|
|
511
|
-
reinterpret_cast<double *>(bsr_values));
|
|
446
|
+
void bsr_matrix_from_triplets_double_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
|
|
447
|
+
int* tpl_rows, int* tpl_columns, void* tpl_values,
|
|
448
|
+
bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
|
|
449
|
+
void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
|
|
450
|
+
{
|
|
451
|
+
return bsr_matrix_from_triplets_device<double>(
|
|
452
|
+
rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns, static_cast<const double*>(tpl_values),
|
|
453
|
+
prune_numerical_zeros, bsr_offsets, bsr_columns, static_cast<double*>(bsr_values), bsr_nnz, bsr_nnz_event);
|
|
512
454
|
}
|
|
513
455
|
|
|
514
|
-
void bsr_transpose_float_device(int rows_per_block, int cols_per_block,
|
|
515
|
-
int
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
bsr_transpose_device(rows_per_block, cols_per_block, row_count, col_count,
|
|
522
|
-
nnz, reinterpret_cast<const int *>(bsr_offsets),
|
|
523
|
-
reinterpret_cast<const int *>(bsr_columns),
|
|
524
|
-
reinterpret_cast<const float *>(bsr_values),
|
|
525
|
-
reinterpret_cast<int *>(transposed_bsr_offsets),
|
|
526
|
-
reinterpret_cast<int *>(transposed_bsr_columns),
|
|
527
|
-
reinterpret_cast<float *>(transposed_bsr_values));
|
|
456
|
+
void bsr_transpose_float_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
|
|
457
|
+
int* bsr_offsets, int* bsr_columns, void* bsr_values, int* transposed_bsr_offsets,
|
|
458
|
+
int* transposed_bsr_columns, void* transposed_bsr_values)
|
|
459
|
+
{
|
|
460
|
+
bsr_transpose_device(rows_per_block, cols_per_block, row_count, col_count, nnz, bsr_offsets, bsr_columns,
|
|
461
|
+
static_cast<const float*>(bsr_values), transposed_bsr_offsets, transposed_bsr_columns,
|
|
462
|
+
static_cast<float*>(transposed_bsr_values));
|
|
528
463
|
}
|
|
529
464
|
|
|
530
|
-
void bsr_transpose_double_device(int rows_per_block, int cols_per_block,
|
|
531
|
-
int
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
bsr_transpose_device(rows_per_block, cols_per_block, row_count, col_count,
|
|
538
|
-
nnz, reinterpret_cast<const int *>(bsr_offsets),
|
|
539
|
-
reinterpret_cast<const int *>(bsr_columns),
|
|
540
|
-
reinterpret_cast<const double *>(bsr_values),
|
|
541
|
-
reinterpret_cast<int *>(transposed_bsr_offsets),
|
|
542
|
-
reinterpret_cast<int *>(transposed_bsr_columns),
|
|
543
|
-
reinterpret_cast<double *>(transposed_bsr_values));
|
|
465
|
+
void bsr_transpose_double_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
|
|
466
|
+
int* bsr_offsets, int* bsr_columns, void* bsr_values, int* transposed_bsr_offsets,
|
|
467
|
+
int* transposed_bsr_columns, void* transposed_bsr_values)
|
|
468
|
+
{
|
|
469
|
+
bsr_transpose_device(rows_per_block, cols_per_block, row_count, col_count, nnz, bsr_offsets, bsr_columns,
|
|
470
|
+
static_cast<const double*>(bsr_values), transposed_bsr_offsets, transposed_bsr_columns,
|
|
471
|
+
static_cast<double*>(transposed_bsr_values));
|
|
544
472
|
}
|