warp-lang 1.2.2-py3-none-manylinux2014_aarch64.whl → 1.3.1-py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (193)
  1. warp/__init__.py +8 -6
  2. warp/autograd.py +823 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +6 -2
  5. warp/builtins.py +1412 -888
  6. warp/codegen.py +503 -166
  7. warp/config.py +48 -18
  8. warp/context.py +400 -198
  9. warp/dlpack.py +8 -0
  10. warp/examples/assets/bunny.usd +0 -0
  11. warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
  12. warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
  13. warp/examples/benchmarks/benchmark_launches.py +1 -1
  14. warp/examples/core/example_cupy.py +78 -0
  15. warp/examples/fem/example_apic_fluid.py +17 -36
  16. warp/examples/fem/example_burgers.py +9 -18
  17. warp/examples/fem/example_convection_diffusion.py +7 -17
  18. warp/examples/fem/example_convection_diffusion_dg.py +27 -47
  19. warp/examples/fem/example_deformed_geometry.py +11 -22
  20. warp/examples/fem/example_diffusion.py +7 -18
  21. warp/examples/fem/example_diffusion_3d.py +24 -28
  22. warp/examples/fem/example_diffusion_mgpu.py +7 -14
  23. warp/examples/fem/example_magnetostatics.py +190 -0
  24. warp/examples/fem/example_mixed_elasticity.py +111 -80
  25. warp/examples/fem/example_navier_stokes.py +30 -34
  26. warp/examples/fem/example_nonconforming_contact.py +290 -0
  27. warp/examples/fem/example_stokes.py +17 -32
  28. warp/examples/fem/example_stokes_transfer.py +12 -21
  29. warp/examples/fem/example_streamlines.py +350 -0
  30. warp/examples/fem/utils.py +936 -0
  31. warp/fabric.py +5 -2
  32. warp/fem/__init__.py +13 -3
  33. warp/fem/cache.py +161 -11
  34. warp/fem/dirichlet.py +37 -28
  35. warp/fem/domain.py +105 -14
  36. warp/fem/field/__init__.py +14 -3
  37. warp/fem/field/field.py +454 -11
  38. warp/fem/field/nodal_field.py +33 -18
  39. warp/fem/geometry/deformed_geometry.py +50 -15
  40. warp/fem/geometry/hexmesh.py +12 -24
  41. warp/fem/geometry/nanogrid.py +106 -31
  42. warp/fem/geometry/quadmesh_2d.py +6 -11
  43. warp/fem/geometry/tetmesh.py +103 -61
  44. warp/fem/geometry/trimesh_2d.py +98 -47
  45. warp/fem/integrate.py +231 -186
  46. warp/fem/operator.py +14 -9
  47. warp/fem/quadrature/pic_quadrature.py +35 -9
  48. warp/fem/quadrature/quadrature.py +119 -32
  49. warp/fem/space/basis_space.py +98 -22
  50. warp/fem/space/collocated_function_space.py +3 -1
  51. warp/fem/space/function_space.py +7 -2
  52. warp/fem/space/grid_2d_function_space.py +3 -3
  53. warp/fem/space/grid_3d_function_space.py +4 -4
  54. warp/fem/space/hexmesh_function_space.py +3 -2
  55. warp/fem/space/nanogrid_function_space.py +12 -14
  56. warp/fem/space/partition.py +45 -47
  57. warp/fem/space/restriction.py +19 -16
  58. warp/fem/space/shape/cube_shape_function.py +91 -3
  59. warp/fem/space/shape/shape_function.py +7 -0
  60. warp/fem/space/shape/square_shape_function.py +32 -0
  61. warp/fem/space/shape/tet_shape_function.py +11 -7
  62. warp/fem/space/shape/triangle_shape_function.py +10 -1
  63. warp/fem/space/topology.py +116 -42
  64. warp/fem/types.py +8 -1
  65. warp/fem/utils.py +301 -83
  66. warp/native/array.h +16 -0
  67. warp/native/builtin.h +0 -15
  68. warp/native/cuda_util.cpp +14 -6
  69. warp/native/exports.h +1348 -1308
  70. warp/native/quat.h +79 -0
  71. warp/native/rand.h +27 -4
  72. warp/native/sparse.cpp +83 -81
  73. warp/native/sparse.cu +381 -453
  74. warp/native/vec.h +64 -0
  75. warp/native/volume.cpp +40 -49
  76. warp/native/volume_builder.cu +2 -3
  77. warp/native/volume_builder.h +12 -17
  78. warp/native/warp.cu +3 -3
  79. warp/native/warp.h +69 -59
  80. warp/render/render_opengl.py +17 -9
  81. warp/sim/articulation.py +117 -17
  82. warp/sim/collide.py +35 -29
  83. warp/sim/model.py +123 -18
  84. warp/sim/render.py +3 -1
  85. warp/sparse.py +867 -203
  86. warp/stubs.py +312 -541
  87. warp/tape.py +29 -1
  88. warp/tests/disabled_kinematics.py +1 -1
  89. warp/tests/test_adam.py +1 -1
  90. warp/tests/test_arithmetic.py +1 -1
  91. warp/tests/test_array.py +58 -1
  92. warp/tests/test_array_reduce.py +1 -1
  93. warp/tests/test_async.py +1 -1
  94. warp/tests/test_atomic.py +1 -1
  95. warp/tests/test_bool.py +1 -1
  96. warp/tests/test_builtins_resolution.py +1 -1
  97. warp/tests/test_bvh.py +6 -1
  98. warp/tests/test_closest_point_edge_edge.py +1 -1
  99. warp/tests/test_codegen.py +91 -1
  100. warp/tests/test_compile_consts.py +1 -1
  101. warp/tests/test_conditional.py +1 -1
  102. warp/tests/test_copy.py +1 -1
  103. warp/tests/test_ctypes.py +1 -1
  104. warp/tests/test_dense.py +1 -1
  105. warp/tests/test_devices.py +1 -1
  106. warp/tests/test_dlpack.py +1 -1
  107. warp/tests/test_examples.py +33 -4
  108. warp/tests/test_fabricarray.py +5 -2
  109. warp/tests/test_fast_math.py +1 -1
  110. warp/tests/test_fem.py +213 -6
  111. warp/tests/test_fp16.py +1 -1
  112. warp/tests/test_func.py +1 -1
  113. warp/tests/test_future_annotations.py +90 -0
  114. warp/tests/test_generics.py +1 -1
  115. warp/tests/test_grad.py +1 -1
  116. warp/tests/test_grad_customs.py +1 -1
  117. warp/tests/test_grad_debug.py +247 -0
  118. warp/tests/test_hash_grid.py +6 -1
  119. warp/tests/test_implicit_init.py +354 -0
  120. warp/tests/test_import.py +1 -1
  121. warp/tests/test_indexedarray.py +1 -1
  122. warp/tests/test_intersect.py +1 -1
  123. warp/tests/test_jax.py +1 -1
  124. warp/tests/test_large.py +1 -1
  125. warp/tests/test_launch.py +1 -1
  126. warp/tests/test_lerp.py +1 -1
  127. warp/tests/test_linear_solvers.py +1 -1
  128. warp/tests/test_lvalue.py +1 -1
  129. warp/tests/test_marching_cubes.py +5 -2
  130. warp/tests/test_mat.py +34 -35
  131. warp/tests/test_mat_lite.py +2 -1
  132. warp/tests/test_mat_scalar_ops.py +1 -1
  133. warp/tests/test_math.py +1 -1
  134. warp/tests/test_matmul.py +20 -16
  135. warp/tests/test_matmul_lite.py +1 -1
  136. warp/tests/test_mempool.py +1 -1
  137. warp/tests/test_mesh.py +5 -2
  138. warp/tests/test_mesh_query_aabb.py +1 -1
  139. warp/tests/test_mesh_query_point.py +1 -1
  140. warp/tests/test_mesh_query_ray.py +1 -1
  141. warp/tests/test_mlp.py +1 -1
  142. warp/tests/test_model.py +1 -1
  143. warp/tests/test_module_hashing.py +77 -1
  144. warp/tests/test_modules_lite.py +1 -1
  145. warp/tests/test_multigpu.py +1 -1
  146. warp/tests/test_noise.py +1 -1
  147. warp/tests/test_operators.py +1 -1
  148. warp/tests/test_options.py +1 -1
  149. warp/tests/test_overwrite.py +542 -0
  150. warp/tests/test_peer.py +1 -1
  151. warp/tests/test_pinned.py +1 -1
  152. warp/tests/test_print.py +1 -1
  153. warp/tests/test_quat.py +15 -1
  154. warp/tests/test_rand.py +1 -1
  155. warp/tests/test_reload.py +1 -1
  156. warp/tests/test_rounding.py +1 -1
  157. warp/tests/test_runlength_encode.py +1 -1
  158. warp/tests/test_scalar_ops.py +95 -0
  159. warp/tests/test_sim_grad.py +1 -1
  160. warp/tests/test_sim_kinematics.py +1 -1
  161. warp/tests/test_smoothstep.py +1 -1
  162. warp/tests/test_sparse.py +82 -15
  163. warp/tests/test_spatial.py +1 -1
  164. warp/tests/test_special_values.py +2 -11
  165. warp/tests/test_streams.py +11 -1
  166. warp/tests/test_struct.py +1 -1
  167. warp/tests/test_tape.py +1 -1
  168. warp/tests/test_torch.py +194 -1
  169. warp/tests/test_transient_module.py +1 -1
  170. warp/tests/test_types.py +1 -1
  171. warp/tests/test_utils.py +1 -1
  172. warp/tests/test_vec.py +15 -63
  173. warp/tests/test_vec_lite.py +2 -1
  174. warp/tests/test_vec_scalar_ops.py +65 -1
  175. warp/tests/test_verify_fp.py +1 -1
  176. warp/tests/test_volume.py +28 -2
  177. warp/tests/test_volume_write.py +1 -1
  178. warp/tests/unittest_serial.py +1 -1
  179. warp/tests/unittest_suites.py +9 -1
  180. warp/tests/walkthrough_debug.py +1 -1
  181. warp/thirdparty/unittest_parallel.py +2 -5
  182. warp/torch.py +103 -41
  183. warp/types.py +341 -224
  184. warp/utils.py +11 -2
  185. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/METADATA +99 -46
  186. warp_lang-1.3.1.dist-info/RECORD +368 -0
  187. warp/examples/fem/bsr_utils.py +0 -378
  188. warp/examples/fem/mesh_utils.py +0 -133
  189. warp/examples/fem/plot_utils.py +0 -292
  190. warp_lang-1.2.2.dist-info/RECORD +0 -359
  191. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/LICENSE.md +0 -0
  192. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/WHEEL +0 -0
  193. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/top_level.txt +0 -0
warp/native/sparse.cu CHANGED
@@ -6,539 +6,467 @@
  #include <cub/device/device_radix_sort.cuh>
  #include <cub/device/device_run_length_encode.cuh>
  #include <cub/device/device_scan.cuh>
- #include <cub/device/device_select.cuh>

- namespace {
+ namespace
+ {

  // Combined row+column value that can be radix-sorted with CUB
  using BsrRowCol = uint64_t;

- CUDA_CALLABLE BsrRowCol bsr_combine_row_col(uint32_t row, uint32_t col) {
-   return (static_cast<uint64_t>(row) << 32) | col;
- }
-
- CUDA_CALLABLE uint32_t bsr_get_row(const BsrRowCol &row_col) {
-   return row_col >> 32;
- }
+ static constexpr BsrRowCol PRUNED_ROWCOL = ~BsrRowCol(0);

- CUDA_CALLABLE uint32_t bsr_get_col(const BsrRowCol &row_col) {
-   return row_col & INT_MAX;
+ CUDA_CALLABLE BsrRowCol bsr_combine_row_col(uint32_t row, uint32_t col)
+ {
+     return (static_cast<uint64_t>(row) << 32) | col;
  }

- // Cached temporary storage
- struct BsrFromTripletsTemp {
-
-   int *count_buffer = NULL;
-   cudaEvent_t host_sync_event = NULL;
-
-   BsrFromTripletsTemp()
-       : count_buffer(static_cast<int*>(alloc_pinned(sizeof(int))))
-   {
-     cudaEventCreateWithFlags(&host_sync_event, cudaEventDisableTiming);
-   }
-
-   ~BsrFromTripletsTemp()
-   {
-     cudaEventDestroy(host_sync_event);
-     free_pinned(count_buffer);
-   }
-
-   BsrFromTripletsTemp(const BsrFromTripletsTemp&) = delete;
-   BsrFromTripletsTemp& operator=(const BsrFromTripletsTemp&) = delete;
+ CUDA_CALLABLE uint32_t bsr_get_row(const BsrRowCol& row_col) { return row_col >> 32; }

- };
+ CUDA_CALLABLE uint32_t bsr_get_col(const BsrRowCol& row_col) { return row_col & INT_MAX; }

- // map temp buffers to CUDA contexts
- static std::unordered_map<void *, BsrFromTripletsTemp> g_bsr_from_triplets_temp_map;
-
- template <typename T> struct BsrBlockIsNotZero {
-   int block_size;
-   const T *values;
+ template <typename T> struct BsrBlockIsNotZero
+ {
+     int block_size;
+     const T* values;

-   CUDA_CALLABLE_DEVICE bool operator()(int i) const {
-     const T *val = values + i * block_size;
-     for (int i = 0; i < block_size; ++i, ++val) {
-       if (*val != T(0))
-         return true;
+     CUDA_CALLABLE_DEVICE bool operator()(int i) const
+     {
+         if (!values)
+             return true;
+
+         const T* val = values + i * block_size;
+         for (int i = 0; i < block_size; ++i, ++val)
+         {
+             if (*val != T(0))
+                 return true;
+         }
+         return false;
      }
-     return false;
-   }
  };

- __global__ void bsr_fill_block_indices(int nnz, int *block_indices) {
-   int i = blockIdx.x * blockDim.x + threadIdx.x;
-   if (i >= nnz)
-     return;
+ template <typename T>
+ __global__ void bsr_fill_triplet_key_values(const int nnz, const int nrow, const int* tpl_rows, const int* tpl_columns,
+                                             const BsrBlockIsNotZero<T> nonZero, uint32_t* block_indices,
+                                             BsrRowCol* tpl_row_col)
+ {
+     int block = blockIdx.x * blockDim.x + threadIdx.x;
+     if (block >= nnz)
+         return;
+
+     const int row = tpl_rows[block];
+     const int col = tpl_columns[block];
+     const bool is_valid = row >= 0 && row < nrow;

-   block_indices[i] = i;
+     const BsrRowCol row_col = is_valid && nonZero(block) ? bsr_combine_row_col(row, col) : PRUNED_ROWCOL;
+     tpl_row_col[block] = row_col;
+     block_indices[block] = block;
  }

- __global__ void bsr_fill_row_col(const int *nnz, const int *block_indices,
-                                  const int *tpl_rows, const int *tpl_columns,
-                                  BsrRowCol *tpl_row_col) {
-   int i = blockIdx.x * blockDim.x + threadIdx.x;
-   if (i >= *nnz)
-     return;
+ template <typename T>
+ __global__ void bsr_find_row_offsets(uint32_t row_count, const T* d_nnz, const BsrRowCol* unique_row_col,
+                                      int* row_offsets)
+ {
+     const uint32_t row = blockIdx.x * blockDim.x + threadIdx.x;
+
+     if (row > row_count)
+         return;

-   const int block = block_indices[i];
+     const uint32_t nnz = *d_nnz;
+     if (row == 0 || nnz == 0)
+     {
+         row_offsets[row] = 0;
+         return;
+     }
+
+     if (bsr_get_row(unique_row_col[nnz - 1]) < row)
+     {
+         row_offsets[row] = nnz;
+         return;
+     }
+
+     // binary search for row start
+     uint32_t lower = 0;
+     uint32_t upper = nnz - 1;
+     while (lower < upper)
+     {
+         uint32_t mid = lower + (upper - lower) / 2;
+
+         if (bsr_get_row(unique_row_col[mid]) < row)
+         {
+             lower = mid + 1;
+         }
+         else
+         {
+             upper = mid;
+         }
+     }

-   BsrRowCol row_col = bsr_combine_row_col(tpl_rows[block], tpl_columns[block]);
-   tpl_row_col[i] = row_col;
+     row_offsets[row] = lower;
  }

  template <typename T>
- __global__ void
- bsr_merge_blocks(int nnz, int block_size, const int *block_offsets,
-                  const int *sorted_block_indices,
-                  const BsrRowCol *unique_row_cols, const T *tpl_values,
-                  int *bsr_row_counts, int *bsr_cols, T *bsr_values)
+ __global__ void bsr_merge_blocks(const uint32_t* d_nnz, int block_size, const uint32_t* block_offsets,
+                                  const uint32_t* sorted_block_indices, const BsrRowCol* unique_row_cols,
+                                  const T* tpl_values, int* bsr_cols, T* bsr_values)

  {
-   int i = blockIdx.x * blockDim.x + threadIdx.x;
-   if (i >= nnz)
-     return;
+     const uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;

-   const int beg = i ? block_offsets[i - 1] : 0;
-   const int end = block_offsets[i];
+     if (i >= *d_nnz)
+         return;

-   BsrRowCol row_col = unique_row_cols[i];
+     const BsrRowCol row_col = unique_row_cols[i];
+     bsr_cols[i] = bsr_get_col(row_col);

-   bsr_cols[i] = bsr_get_col(row_col);
-   atomicAdd(bsr_row_counts + bsr_get_row(row_col) + 1, 1);
+     // Accumulate merged block values
+     if (row_col == PRUNED_ROWCOL || bsr_values == nullptr)
+         return;

-   if (bsr_values == nullptr)
-     return;
+     const uint32_t beg = i ? block_offsets[i - 1] : 0;
+     const uint32_t end = block_offsets[i];

-   T *bsr_val = bsr_values + i * block_size;
-   const T *tpl_val = tpl_values + sorted_block_indices[beg] * block_size;
+     T* bsr_val = bsr_values + i * block_size;
+     const T* tpl_val = tpl_values + sorted_block_indices[beg] * block_size;

-   for (int k = 0; k < block_size; ++k) {
-     bsr_val[k] = tpl_val[k];
-   }
+     for (int k = 0; k < block_size; ++k)
+     {
+         bsr_val[k] = tpl_val[k];
+     }

-   for (int cur = beg + 1; cur != end; ++cur) {
-     const T *tpl_val = tpl_values + sorted_block_indices[cur] * block_size;
-     for (int k = 0; k < block_size; ++k) {
-       bsr_val[k] += tpl_val[k];
+     for (uint32_t cur = beg + 1; cur != end; ++cur)
+     {
+         const T* tpl_val = tpl_values + sorted_block_indices[cur] * block_size;
+         for (int k = 0; k < block_size; ++k)
+         {
+             bsr_val[k] += tpl_val[k];
+         }
      }
-   }
  }

  template <typename T>
- int bsr_matrix_from_triplets_device(const int rows_per_block,
-                                     const int cols_per_block,
-                                     const int row_count, const int nnz,
-                                     const int *tpl_rows, const int *tpl_columns,
-                                     const T *tpl_values, int *bsr_offsets,
-                                     int *bsr_columns, T *bsr_values) {
-   const int block_size = rows_per_block * cols_per_block;
+ void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_per_block, const int row_count,
+                                      const int nnz, const int* tpl_rows, const int* tpl_columns, const T* tpl_values,
+                                      const bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
+                                      T* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+ {
+     const int block_size = rows_per_block * cols_per_block;
+
+     void* context = cuda_context_get_current();
+     ContextGuard guard(context);

-   void *context = cuda_context_get_current();
-   ContextGuard guard(context);
+     // Per-context cached temporary buffers
+     // BsrFromTripletsTemp& bsr_temp = g_bsr_from_triplets_temp_map[context];

-   // Per-context cached temporary buffers
-   BsrFromTripletsTemp &bsr_temp = g_bsr_from_triplets_temp_map[context];
+     cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());

-   cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
+     ScopedTemporary<uint32_t> block_indices(context, 2 * nnz + 1);
+     ScopedTemporary<BsrRowCol> combined_row_col(context, 2 * nnz);
+
+     cub::DoubleBuffer<uint32_t> d_keys(block_indices.buffer(), block_indices.buffer() + nnz);
+     cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(), combined_row_col.buffer() + nnz);
+
+     uint32_t* unique_triplet_count = block_indices.buffer() + 2 * nnz;
+
+     // Combine rows and columns so we can sort on them both
+     BsrBlockIsNotZero<T> isNotZero{block_size, prune_numerical_zeros ? tpl_values : nullptr};
+     wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_triplet_key_values, nnz,
+                      (nnz, row_count, tpl_rows, tpl_columns, isNotZero, d_keys.Current(), d_values.Current()));
+
+     // Sort
+     {
+         size_t buff_size = 0;
+         check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values, d_keys, nnz, 0, 64, stream));
+         ScopedTemporary<> temp(context, buff_size);
+         check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
+     }

-   ScopedTemporary<int> block_indices(context, 2*nnz);
-   ScopedTemporary<BsrRowCol> combined_row_col(context, 2*nnz);
+     // Runlength encode row-col sequences
+     {
+         size_t buff_size = 0;
+         check_cuda(cub::DeviceRunLengthEncode::Encode(nullptr, buff_size, d_values.Current(), d_values.Alternate(),
+                                                       d_keys.Alternate(), unique_triplet_count, nnz, stream));
+         ScopedTemporary<> temp(context, buff_size);
+         check_cuda(cub::DeviceRunLengthEncode::Encode(temp.buffer(), buff_size, d_values.Current(),
+                                                       d_values.Alternate(), d_keys.Alternate(), unique_triplet_count,
+                                                       nnz, stream));
+     }

-   cub::DoubleBuffer<int> d_keys(block_indices.buffer(),
-                                 block_indices.buffer() + nnz);
-   cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(),
-                                         combined_row_col.buffer() + nnz);
+     // Compute row offsets from sorted unique blocks
+     wp_launch_device(WP_CURRENT_CONTEXT, bsr_find_row_offsets, row_count + 1,
+                      (row_count, unique_triplet_count, d_values.Alternate(), bsr_offsets));

-   int *p_nz_triplet_count = bsr_temp.count_buffer;
+     if (bsr_nnz)
+     {
+         // Copy nnz to host, and record an event for the completed transfer if desired

-   wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_block_indices, nnz,
-                    (nnz, d_keys.Current()));
+         memcpy_d2h(WP_CURRENT_CONTEXT, bsr_nnz, bsr_offsets + row_count, sizeof(int), stream);

-   if (tpl_values) {
+         if (bsr_nnz_event)
+         {
+             cuda_event_record(bsr_nnz_event, stream);
+         }
+     }

-     // Remove zero blocks
+     // Scan repeated block counts
      {
-       size_t buff_size = 0;
-       BsrBlockIsNotZero<T> isNotZero{block_size, tpl_values};
-       check_cuda(cub::DeviceSelect::If(nullptr, buff_size, d_keys.Current(),
-                                        d_keys.Alternate(), p_nz_triplet_count,
-                                        nnz, isNotZero, stream));
-       ScopedTemporary<> temp(context, buff_size);
-       check_cuda(cub::DeviceSelect::If(
-           temp.buffer(), buff_size, d_keys.Current(), d_keys.Alternate(),
-           p_nz_triplet_count, nnz, isNotZero, stream));
+         size_t buff_size = 0;
+         check_cuda(
+             cub::DeviceScan::InclusiveSum(nullptr, buff_size, d_keys.Alternate(), d_keys.Alternate(), nnz, stream));
+         ScopedTemporary<> temp(context, buff_size);
+         check_cuda(cub::DeviceScan::InclusiveSum(temp.buffer(), buff_size, d_keys.Alternate(), d_keys.Alternate(), nnz,
+                                                  stream));
      }
-     cudaEventRecord(bsr_temp.host_sync_event, stream);
-
-     // switch current/alternate in double buffer
-     d_keys.selector ^= 1;
-
-   } else {
-     *p_nz_triplet_count = nnz;
-   }
-
-   // Combine rows and columns so we can sort on them both
-   wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_row_col, nnz,
-                    (p_nz_triplet_count, d_keys.Current(), tpl_rows, tpl_columns,
-                     d_values.Current()));
-
-   if (tpl_values) {
-     // Make sure count is available on host
-     cudaEventSynchronize(bsr_temp.host_sync_event);
-   }
-
-   const int nz_triplet_count = *p_nz_triplet_count;
-
-   // Sort
-   {
-     size_t buff_size = 0;
-     check_cuda(cub::DeviceRadixSort::SortPairs(
-         nullptr, buff_size, d_values, d_keys, nz_triplet_count, 0, 64, stream));
-     ScopedTemporary<> temp(context, buff_size);
-     check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size,
-                                                d_values, d_keys, nz_triplet_count,
-                                                0, 64, stream));
-   }
-
-   // Runlength encode row-col sequences
-   {
-     size_t buff_size = 0;
-     check_cuda(cub::DeviceRunLengthEncode::Encode(
-         nullptr, buff_size, d_values.Current(), d_values.Alternate(),
-         d_keys.Alternate(), p_nz_triplet_count, nz_triplet_count, stream));
-     ScopedTemporary<> temp(context, buff_size);
-     check_cuda(cub::DeviceRunLengthEncode::Encode(
-         temp.buffer(), buff_size, d_values.Current(), d_values.Alternate(),
-         d_keys.Alternate(), p_nz_triplet_count, nz_triplet_count, stream));
-   }
-
-   cudaEventRecord(bsr_temp.host_sync_event, stream);
-
-   // Now we have the following:
-   // d_values.Current(): sorted block row-col
-   // d_values.Alternate(): sorted unique row-col
-   // d_keys.Current(): sorted block indices
-   // d_keys.Alternate(): repeated block-row count
-
-   // Scan repeated block counts
-   {
-     size_t buff_size = 0;
-     check_cuda(cub::DeviceScan::InclusiveSum(
-         nullptr, buff_size, d_keys.Alternate(), d_keys.Alternate(),
-         nz_triplet_count, stream));
-     ScopedTemporary<> temp(context, buff_size);
-     check_cuda(cub::DeviceScan::InclusiveSum(
-         temp.buffer(), buff_size, d_keys.Alternate(), d_keys.Alternate(),
-         nz_triplet_count, stream));
-   }
-
-   // While we're at it, zero the bsr offsets buffer
-   memset_device(WP_CURRENT_CONTEXT, bsr_offsets, 0,
-                 (row_count + 1) * sizeof(int));
-
-   // Wait for number of compressed blocks
-   cudaEventSynchronize(bsr_temp.host_sync_event);
-   const int compressed_nnz = *p_nz_triplet_count;
-
-   // We have all we need to accumulate our repeated blocks
-   wp_launch_device(WP_CURRENT_CONTEXT, bsr_merge_blocks, compressed_nnz,
-                    (compressed_nnz, block_size, d_keys.Alternate(),
-                     d_keys.Current(), d_values.Alternate(), tpl_values,
-                     bsr_offsets, bsr_columns, bsr_values));
-
-   // Last, prefix sum the row block counts
-   {
-     size_t buff_size = 0;
-     check_cuda(cub::DeviceScan::InclusiveSum(nullptr, buff_size, bsr_offsets,
-                                              bsr_offsets, row_count + 1, stream));
-     ScopedTemporary<> temp(context, buff_size);
-     check_cuda(cub::DeviceScan::InclusiveSum(temp.buffer(), buff_size,
-                                              bsr_offsets, bsr_offsets,
-                                              row_count + 1, stream));
-   }
-
-   return compressed_nnz;
+
+     // Accumulate repeated blocks and set column indices
+     wp_launch_device(WP_CURRENT_CONTEXT, bsr_merge_blocks, nnz,
+                      (unique_triplet_count, block_size, d_keys.Alternate(), d_keys.Current(), d_values.Alternate(),
+                       tpl_values, bsr_columns, bsr_values));
  }

- __global__ void bsr_transpose_fill_row_col(const int nnz, const int row_count,
-                                            const int *bsr_offsets,
-                                            const int *bsr_columns,
-                                            int *block_indices,
-                                            BsrRowCol *transposed_row_col,
-                                            int *transposed_bsr_offsets) {
-   int i = blockIdx.x * blockDim.x + threadIdx.x;
-   if (i >= nnz)
-     return;
-
-   block_indices[i] = i;
-
-   // Binary search for row
-   int lower = 0;
-   int upper = row_count - 1;
-
-   while (lower < upper) {
-     int mid = lower + (upper - lower) / 2;
-
-     if (bsr_offsets[mid + 1] <= i) {
-       lower = mid + 1;
-     } else {
-       upper = mid;
+ __global__ void bsr_transpose_fill_row_col(const int nnz_upper_bound, const int row_count, const int* bsr_offsets,
+                                            const int* bsr_columns, int* block_indices, BsrRowCol* transposed_row_col)
+ {
+     int i = blockIdx.x * blockDim.x + threadIdx.x;
+
+     if (i >= nnz_upper_bound)
+     {
+         // Outside of allocated bounds, do nothing
+         return;
      }
-   }

-   const int row = lower;
-   const int col = bsr_columns[i];
-   BsrRowCol row_col = bsr_combine_row_col(col, row);
-   transposed_row_col[i] = row_col;
+     if (i >= bsr_offsets[row_count])
+     {
+         // Below upper bound but above actual nnz count, mark as invalid
+         transposed_row_col[i] = PRUNED_ROWCOL;
+         return;
+     }

-   atomicAdd(transposed_bsr_offsets + col + 1, 1);
+     block_indices[i] = i;
+
+     // Binary search for row
+     int lower = 0;
+     int upper = row_count - 1;
+
+     while (lower < upper)
+     {
+         int mid = lower + (upper - lower) / 2;
+
+         if (bsr_offsets[mid + 1] <= i)
+         {
+             lower = mid + 1;
+         }
+         else
+         {
+             upper = mid;
+         }
+     }
+
+     const int row = lower;
+     const int col = bsr_columns[i];
+     BsrRowCol row_col = bsr_combine_row_col(col, row);
+     transposed_row_col[i] = row_col;
  }

- template <int Rows, int Cols, typename T> struct BsrBlockTransposer {
-   void CUDA_CALLABLE_DEVICE operator()(const T *src, T *dest) const {
-     for (int r = 0; r < Rows; ++r) {
-       for (int c = 0; c < Cols; ++c) {
-         dest[c * Rows + r] = src[r * Cols + c];
-       }
+ template <int Rows, int Cols, typename T> struct BsrBlockTransposer
+ {
+     void CUDA_CALLABLE_DEVICE operator()(const T* src, T* dest) const
+     {
+         for (int r = 0; r < Rows; ++r)
+         {
+             for (int c = 0; c < Cols; ++c)
+             {
+                 dest[c * Rows + r] = src[r * Cols + c];
+             }
+         }
      }
-   }
  };

- template <typename T> struct BsrBlockTransposer<-1, -1, T> {
+ template <typename T> struct BsrBlockTransposer<-1, -1, T>
+ {

-   int row_count;
-   int col_count;
+     int row_count;
+     int col_count;

-   void CUDA_CALLABLE_DEVICE operator()(const T *src, T *dest) const {
-     for (int r = 0; r < row_count; ++r) {
-       for (int c = 0; c < col_count; ++c) {
-         dest[c * row_count + r] = src[r * col_count + c];
-       }
+     void CUDA_CALLABLE_DEVICE operator()(const T* src, T* dest) const
+     {
+         for (int r = 0; r < row_count; ++r)
+         {
+             for (int c = 0; c < col_count; ++c)
+             {
+                 dest[c * row_count + r] = src[r * col_count + c];
+             }
+         }
      }
-   }
  };

  template <int Rows, int Cols, typename T>
- __global__ void
- bsr_transpose_blocks(const int nnz, const int block_size,
-                      BsrBlockTransposer<Rows, Cols, T> transposer,
-                      const int *block_indices,
-                      const BsrRowCol *transposed_indices, const T *bsr_values,
-                      int *transposed_bsr_columns, T *transposed_bsr_values) {
-   int i = blockIdx.x * blockDim.x + threadIdx.x;
-   if (i >= nnz)
-     return;
-
-   const int src_idx = block_indices[i];
-
-   transposer(bsr_values + src_idx * block_size,
-              transposed_bsr_values + i * block_size);
-
-   transposed_bsr_columns[i] = bsr_get_col(transposed_indices[i]);
+ __global__ void bsr_transpose_blocks(const int* nnz, const int block_size, BsrBlockTransposer<Rows, Cols, T> transposer,
+                                      const int* block_indices, const BsrRowCol* transposed_indices, const T* bsr_values,
+                                      int* transposed_bsr_columns, T* transposed_bsr_values)
+ {
+     int i = blockIdx.x * blockDim.x + threadIdx.x;
+     if (i >= *nnz)
+         return;
+
+     const int src_idx = block_indices[i];
+
+     transposer(bsr_values + src_idx * block_size, transposed_bsr_values + i * block_size);
+
+     transposed_bsr_columns[i] = bsr_get_col(transposed_indices[i]);
  }

  template <typename T>
- void
- launch_bsr_transpose_blocks(const int nnz, const int block_size,
-                             const int rows_per_block, const int cols_per_block,
-                             const int *block_indices,
-                             const BsrRowCol *transposed_indices,
-                             const T *bsr_values,
-                             int *transposed_bsr_columns, T *transposed_bsr_values) {
-
-   switch (rows_per_block) {
-   case 1:
-     switch (cols_per_block) {
-     case 1:
-       wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-                        (nnz, block_size, BsrBlockTransposer<1, 1, T>{},
-                         block_indices, transposed_indices, bsr_values,
-                         transposed_bsr_columns, transposed_bsr_values));
-       return;
-     case 2:
-       wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-                        (nnz, block_size, BsrBlockTransposer<1, 2, T>{},
-                         block_indices, transposed_indices, bsr_values,
-                         transposed_bsr_columns, transposed_bsr_values));
-       return;
-     case 3:
-       wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-                        (nnz, block_size, BsrBlockTransposer<1, 3, T>{},
-                         block_indices, transposed_indices, bsr_values,
-                         transposed_bsr_columns, transposed_bsr_values));
-       return;
-     }
-   case 2:
-     switch (cols_per_block) {
-     case 1:
-       wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-                        (nnz, block_size, BsrBlockTransposer<2, 1, T>{},
-                         block_indices, transposed_indices, bsr_values,
-                         transposed_bsr_columns, transposed_bsr_values));
-       return;
-     case 2:
-       wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-                        (nnz, block_size, BsrBlockTransposer<2, 2, T>{},
-                         block_indices, transposed_indices, bsr_values,
-                         transposed_bsr_columns, transposed_bsr_values));
-       return;
-     case 3:
-       wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-                        (nnz, block_size, BsrBlockTransposer<2, 3, T>{},
-                         block_indices, transposed_indices, bsr_values,
-                         transposed_bsr_columns, transposed_bsr_values));
-       return;
-     }
-   case 3:
-     switch (cols_per_block) {
+ void launch_bsr_transpose_blocks(int nnz, const int* d_nnz, const int block_size, const int rows_per_block,
+                                  const int cols_per_block, const int* block_indices,
+                                  const BsrRowCol* transposed_indices, const T* bsr_values, int* transposed_bsr_columns,
+                                  T* transposed_bsr_values)
+ {
+
+     switch (rows_per_block)
+     {
      case 1:
-       wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-                        (nnz, block_size, BsrBlockTransposer<3, 1, T>{},
-                         block_indices, transposed_indices, bsr_values,
-                         transposed_bsr_columns, transposed_bsr_values));
-       return;
+         switch (cols_per_block)
+         {
+         case 1:
+             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+                              (d_nnz, block_size, BsrBlockTransposer<1, 1, T>{}, block_indices, transposed_indices,
+                               bsr_values, transposed_bsr_columns, transposed_bsr_values));
+             return;
+         case 2:
+             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+                              (d_nnz, block_size, BsrBlockTransposer<1, 2, T>{}, block_indices, transposed_indices,
+                               bsr_values, transposed_bsr_columns, transposed_bsr_values));
+             return;
+         case 3:
+             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+                              (d_nnz, block_size, BsrBlockTransposer<1, 3, T>{}, block_indices, transposed_indices,
+                               bsr_values, transposed_bsr_columns, transposed_bsr_values));
+             return;
+         }
      case 2:
-       wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-                        (nnz, block_size, BsrBlockTransposer<3, 2, T>{},
-                         block_indices, transposed_indices, bsr_values,
-                         transposed_bsr_columns, transposed_bsr_values));
-       return;
+         switch (cols_per_block)
+         {
+         case 1:
+             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+                              (d_nnz, block_size, BsrBlockTransposer<2, 1, T>{}, block_indices, transposed_indices,
+                               bsr_values, transposed_bsr_columns, transposed_bsr_values));
+             return;
+         case 2:
+             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+                              (d_nnz, block_size, BsrBlockTransposer<2, 2, T>{}, block_indices, transposed_indices,
+                               bsr_values, transposed_bsr_columns, transposed_bsr_values));
+             return;
+         case 3:
+             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+                              (d_nnz, block_size, BsrBlockTransposer<2, 3, T>{}, block_indices, transposed_indices,
+                               bsr_values, transposed_bsr_columns, transposed_bsr_values));
+             return;
+         }
      case 3:
-       wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-                        (nnz, block_size, BsrBlockTransposer<3, 3, T>{},
-                         block_indices, transposed_indices, bsr_values,
-                         transposed_bsr_columns, transposed_bsr_values));
-       return;
+         switch (cols_per_block)
+         {
+         case 1:
+             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+                              (d_nnz, block_size, BsrBlockTransposer<3, 1, T>{}, block_indices, transposed_indices,
+                               bsr_values, transposed_bsr_columns, transposed_bsr_values));
+             return;
+         case 2:
+             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+                              (d_nnz, block_size, BsrBlockTransposer<3, 2, T>{}, block_indices, transposed_indices,
+                               bsr_values, transposed_bsr_columns, transposed_bsr_values));
+             return;
+         case 3:
+             wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+                              (d_nnz, block_size, BsrBlockTransposer<3, 3, T>{}, block_indices, transposed_indices,
+                               bsr_values, transposed_bsr_columns, transposed_bsr_values));
+             return;
+         }
      }
-   }
-
-   wp_launch_device(
-       WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-       (nnz, block_size,
-        BsrBlockTransposer<-1, -1, T>{rows_per_block, cols_per_block},
-        block_indices, transposed_indices, bsr_values, transposed_bsr_columns,
-        transposed_bsr_values));
+
+     wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
+                      (d_nnz, block_size, BsrBlockTransposer<-1, -1, T>{rows_per_block, cols_per_block}, block_indices,
+                       transposed_indices, bsr_values, transposed_bsr_columns, transposed_bsr_values));
  }

  template <typename T>
- void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
-                           int col_count, int nnz, const int *bsr_offsets,
-                           const int *bsr_columns, const T *bsr_values,
-                           int *transposed_bsr_offsets,
-                           int *transposed_bsr_columns,
-                           T *transposed_bsr_values) {
-
-   const int block_size = rows_per_block * cols_per_block;
-
-   void *context = cuda_context_get_current();
-   ContextGuard guard(context);
-
-   cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
-
-   // Zero the transposed offsets
-   memset_device(WP_CURRENT_CONTEXT, transposed_bsr_offsets, 0,
-                 (col_count + 1) * sizeof(int));
-
-   ScopedTemporary<int> block_indices(context, 2*nnz);
-   ScopedTemporary<BsrRowCol> combined_row_col(context, 2*nnz);
-
-   cub::DoubleBuffer<int> d_keys(block_indices.buffer(),
-                                 block_indices.buffer() + nnz);
-   cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(),
-                                         combined_row_col.buffer() + nnz);
-
-   wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_fill_row_col, nnz,
-                    (nnz, row_count, bsr_offsets, bsr_columns, d_keys.Current(),
-                     d_values.Current(), transposed_bsr_offsets));
-
-   // Sort blocks
-   {
-     size_t buff_size = 0;
-     check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values,
-                                                d_keys, nnz, 0, 64, stream));
-     ScopedTemporary<> temp(context, buff_size);
-     check_cuda(cub::DeviceRadixSort::SortPairs(
-         temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
-   }
-
-   // Prefix sum the transposed row block counts
-   {
-     size_t buff_size = 0;
-     check_cuda(cub::DeviceScan::InclusiveSum(
-         nullptr, buff_size, transposed_bsr_offsets, transposed_bsr_offsets,
-         col_count + 1, stream));
-     ScopedTemporary<> temp(context, buff_size);
-     check_cuda(cub::DeviceScan::InclusiveSum(
-         temp.buffer(), buff_size, transposed_bsr_offsets,
-         transposed_bsr_offsets, col_count + 1, stream));
-   }
-
-   // Move and transpose individual blocks
-   launch_bsr_transpose_blocks(
-       nnz, block_size,
-       rows_per_block, cols_per_block,
-       d_keys.Current(), d_values.Current(), bsr_values, transposed_bsr_columns,
-       transposed_bsr_values);
+ void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
+                           const int* bsr_offsets, const int* bsr_columns, const T* bsr_values,
+                           int* transposed_bsr_offsets, int* transposed_bsr_columns, T* transposed_bsr_values)
+ {
+
+     const int block_size = rows_per_block * cols_per_block;
+
+     void* context = cuda_context_get_current();
+     ContextGuard guard(context);
+
+     cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
+
+     ScopedTemporary<int> block_indices(context, 2 * nnz);
+     ScopedTemporary<BsrRowCol> combined_row_col(context, 2 * nnz);
+
+     cub::DoubleBuffer<int> d_keys(block_indices.buffer(), block_indices.buffer() + nnz);
+     cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(), combined_row_col.buffer() + nnz);
+
+     wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_fill_row_col, nnz,
+                      (nnz, row_count, bsr_offsets, bsr_columns, d_keys.Current(), d_values.Current()));
+
+     // Sort blocks
+     {
+         size_t buff_size = 0;
+         check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values, d_keys, nnz, 0, 64, stream));
+         ScopedTemporary<> temp(context, buff_size);
+         check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
+     }
+
+     // Compute row offsets from sorted unique blocks
+     wp_launch_device(WP_CURRENT_CONTEXT, bsr_find_row_offsets, col_count + 1,
+                      (col_count, bsr_offsets + row_count, d_values.Current(), transposed_bsr_offsets));
+
+     // Move and transpose individual blocks
+     if (transposed_bsr_values != nullptr)
+     {
+         launch_bsr_transpose_blocks(nnz, bsr_offsets + row_count, block_size, rows_per_block, cols_per_block,
+                                     d_keys.Current(), d_values.Current(), bsr_values, transposed_bsr_columns,
+                                     transposed_bsr_values);
+     }
  }

  } // namespace

- int bsr_matrix_from_triplets_float_device(
-     int rows_per_block, int cols_per_block, int row_count, int nnz,
-     uint64_t tpl_rows, uint64_t tpl_columns, uint64_t tpl_values,
-     uint64_t bsr_offsets, uint64_t bsr_columns, uint64_t bsr_values) {
-   return bsr_matrix_from_triplets_device<float>(
-       rows_per_block, cols_per_block, row_count, nnz,
-       reinterpret_cast<const int *>(tpl_rows),
-       reinterpret_cast<const int *>(tpl_columns),
-       reinterpret_cast<const float *>(tpl_values),
-       reinterpret_cast<int *>(bsr_offsets),
-       reinterpret_cast<int *>(bsr_columns),
-       reinterpret_cast<float *>(bsr_values));
+ void bsr_matrix_from_triplets_float_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
+                                            int* tpl_rows, int* tpl_columns, void* tpl_values,
+                                            bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
+                                            void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+ {
+     return bsr_matrix_from_triplets_device<float>(
+         rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns, static_cast<const float*>(tpl_values),
+         prune_numerical_zeros, bsr_offsets, bsr_columns, static_cast<float*>(bsr_values), bsr_nnz, bsr_nnz_event);
  }

- int bsr_matrix_from_triplets_double_device(
-     int rows_per_block, int cols_per_block, int row_count, int nnz,
-     uint64_t tpl_rows, uint64_t tpl_columns, uint64_t tpl_values,
-     uint64_t bsr_offsets, uint64_t bsr_columns, uint64_t bsr_values) {
-   return bsr_matrix_from_triplets_device<double>(
-       rows_per_block, cols_per_block, row_count, nnz,
-       reinterpret_cast<const int *>(tpl_rows),
-       reinterpret_cast<const int *>(tpl_columns),
-       reinterpret_cast<const double *>(tpl_values),
-       reinterpret_cast<int *>(bsr_offsets),
-       reinterpret_cast<int *>(bsr_columns),
-       reinterpret_cast<double *>(bsr_values));
+ void bsr_matrix_from_triplets_double_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
+                                             int* tpl_rows, int* tpl_columns, void* tpl_values,
+                                             bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
+                                             void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+ {
+     return bsr_matrix_from_triplets_device<double>(
+         rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns, static_cast<const double*>(tpl_values),
+         prune_numerical_zeros, bsr_offsets, bsr_columns, static_cast<double*>(bsr_values), bsr_nnz, bsr_nnz_event);
  }

- void bsr_transpose_float_device(int rows_per_block, int cols_per_block,
-                                 int row_count, int col_count, int nnz,
-                                 uint64_t bsr_offsets, uint64_t bsr_columns,
-                                 uint64_t bsr_values,
-                                 uint64_t transposed_bsr_offsets,
-                                 uint64_t transposed_bsr_columns,
-                                 uint64_t transposed_bsr_values) {
-   bsr_transpose_device(rows_per_block, cols_per_block, row_count, col_count,
-                        nnz, reinterpret_cast<const int *>(bsr_offsets),
-                        reinterpret_cast<const int *>(bsr_columns),
-                        reinterpret_cast<const float *>(bsr_values),
-                        reinterpret_cast<int *>(transposed_bsr_offsets),
-                        reinterpret_cast<int *>(transposed_bsr_columns),
-                        reinterpret_cast<float *>(transposed_bsr_values));
+ void bsr_transpose_float_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
+                                 int* bsr_offsets, int* bsr_columns, void* bsr_values, int* transposed_bsr_offsets,
+                                 int* transposed_bsr_columns, void* transposed_bsr_values)
+ {
+     bsr_transpose_device(rows_per_block, cols_per_block, row_count, col_count, nnz, bsr_offsets, bsr_columns,
+                          static_cast<const float*>(bsr_values), transposed_bsr_offsets, transposed_bsr_columns,
+                          static_cast<float*>(transposed_bsr_values));
  }

- void bsr_transpose_double_device(int rows_per_block, int cols_per_block,
-                                  int row_count, int col_count, int nnz,
-                                  uint64_t bsr_offsets, uint64_t bsr_columns,
-                                  uint64_t bsr_values,
-                                  uint64_t transposed_bsr_offsets,
-                                  uint64_t transposed_bsr_columns,
-                                  uint64_t transposed_bsr_values) {
-   bsr_transpose_device(rows_per_block, cols_per_block, row_count, col_count,
-                        nnz, reinterpret_cast<const int *>(bsr_offsets),
-                        reinterpret_cast<const int *>(bsr_columns),
-                        reinterpret_cast<const double *>(bsr_values),
-                        reinterpret_cast<int *>(transposed_bsr_offsets),
-                        reinterpret_cast<int *>(transposed_bsr_columns),
-                        reinterpret_cast<double *>(transposed_bsr_values));
+ void bsr_transpose_double_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
+                                  int* bsr_offsets, int* bsr_columns, void* bsr_values, int* transposed_bsr_offsets,
+                                  int* transposed_bsr_columns, void* transposed_bsr_values)
+ {
+     bsr_transpose_device(rows_per_block, cols_per_block, row_count, col_count, nnz, bsr_offsets, bsr_columns,
+                          static_cast<const double*>(bsr_values), transposed_bsr_offsets, transposed_bsr_columns,
+                          static_cast<double*>(transposed_bsr_values));
  }
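
Note on the sparse.cu rewrite above: bsr_matrix_from_triplets_device no longer compacts zero blocks with cub::DeviceSelect or counts row entries with atomicAdd. Instead, pruned triplets receive the sentinel key PRUNED_ROWCOL so they sort to the end, duplicates are merged by a radix sort plus run-length encode, and row offsets are recovered with a per-row binary search, keeping the whole pipeline on-device (the new bsr_nnz/bsr_nnz_event arguments allow an asynchronous host readback of the block count). The following is a minimal host-side sketch of that pipeline, not part of warp: std::sort and std::lower_bound stand in for cub::DeviceRadixSort and the bsr_find_row_offsets kernel, and the per-block value accumulation done by bsr_merge_blocks is omitted.

// Hypothetical CPU sketch of the new triplets-to-BSR pipeline:
// pack (row, col) into one 64-bit key, sort, deduplicate, then
// derive row offsets by binary search.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

using BsrRowCol = uint64_t;

// Same packing as sparse.cu: row in the high 32 bits, column in the low 32
static BsrRowCol bsr_combine_row_col(uint32_t row, uint32_t col)
{
    return (static_cast<uint64_t>(row) << 32) | col;
}

int main()
{
    // Block-level COO triplets; the duplicate (1, 2) entry gets merged
    const std::vector<int> rows = {1, 0, 1, 2, 1};
    const std::vector<int> cols = {2, 1, 2, 0, 0};
    const int row_count = 3;

    // 1. Combine rows and columns so a single sort orders both
    //    (stands in for cub::DeviceRadixSort::SortPairs)
    std::vector<BsrRowCol> keys;
    for (size_t i = 0; i < rows.size(); ++i)
        keys.push_back(bsr_combine_row_col(rows[i], cols[i]));
    std::sort(keys.begin(), keys.end());

    // 2. Run-length encode: keep unique keys; in sparse.cu the values of
    //    repeated blocks are summed by bsr_merge_blocks at this point
    keys.erase(std::unique(keys.begin(), keys.end()), keys.end());

    // 3. Row offsets: first sorted position whose row is >= r, found by
    //    binary search (stands in for the bsr_find_row_offsets kernel)
    std::vector<int> offsets(row_count + 1);
    for (int r = 0; r <= row_count; ++r)
        offsets[r] = static_cast<int>(std::lower_bound(keys.begin(), keys.end(),
                                                       bsr_combine_row_col(r, 0)) -
                                      keys.begin());

    for (int r = 0; r <= row_count; ++r)
        printf("offsets[%d] = %d\n", r, offsets[r]); // prints 0, 1, 3, 4
    return 0;
}

Replacing the select/atomic scheme with sort + run-length encode + binary search removes the mid-pipeline host synchronization that the old cached pinned-count buffer existed for, which is presumably why BsrFromTripletsTemp is gone in 1.3.1.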