warp-lang 1.7.2-py3-none-manylinux_2_34_aarch64.whl → 1.8.0-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (180)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +125 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +257 -101
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +657 -223
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  35. warp/examples/optim/example_drone.py +1 -1
  36. warp/examples/sim/example_cloth.py +1 -1
  37. warp/examples/sim/example_cloth_self_contact.py +48 -54
  38. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  39. warp/examples/tile/example_tile_cholesky.py +2 -1
  40. warp/examples/tile/example_tile_convolution.py +1 -1
  41. warp/examples/tile/example_tile_filtering.py +1 -1
  42. warp/examples/tile/example_tile_matmul.py +1 -1
  43. warp/examples/tile/example_tile_mlp.py +2 -0
  44. warp/fabric.py +7 -7
  45. warp/fem/__init__.py +5 -0
  46. warp/fem/adaptivity.py +1 -1
  47. warp/fem/cache.py +152 -63
  48. warp/fem/dirichlet.py +2 -2
  49. warp/fem/domain.py +136 -6
  50. warp/fem/field/field.py +141 -99
  51. warp/fem/field/nodal_field.py +85 -39
  52. warp/fem/field/virtual.py +97 -52
  53. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  54. warp/fem/geometry/closest_point.py +13 -0
  55. warp/fem/geometry/deformed_geometry.py +102 -40
  56. warp/fem/geometry/element.py +56 -2
  57. warp/fem/geometry/geometry.py +323 -22
  58. warp/fem/geometry/grid_2d.py +157 -62
  59. warp/fem/geometry/grid_3d.py +116 -20
  60. warp/fem/geometry/hexmesh.py +86 -20
  61. warp/fem/geometry/nanogrid.py +166 -86
  62. warp/fem/geometry/partition.py +59 -25
  63. warp/fem/geometry/quadmesh.py +86 -135
  64. warp/fem/geometry/tetmesh.py +47 -119
  65. warp/fem/geometry/trimesh.py +77 -270
  66. warp/fem/integrate.py +107 -52
  67. warp/fem/linalg.py +25 -58
  68. warp/fem/operator.py +124 -27
  69. warp/fem/quadrature/pic_quadrature.py +36 -14
  70. warp/fem/quadrature/quadrature.py +40 -16
  71. warp/fem/space/__init__.py +1 -1
  72. warp/fem/space/basis_function_space.py +66 -46
  73. warp/fem/space/basis_space.py +17 -4
  74. warp/fem/space/dof_mapper.py +1 -1
  75. warp/fem/space/function_space.py +2 -2
  76. warp/fem/space/grid_2d_function_space.py +4 -1
  77. warp/fem/space/hexmesh_function_space.py +4 -2
  78. warp/fem/space/nanogrid_function_space.py +3 -1
  79. warp/fem/space/partition.py +11 -2
  80. warp/fem/space/quadmesh_function_space.py +4 -1
  81. warp/fem/space/restriction.py +5 -2
  82. warp/fem/space/shape/__init__.py +10 -8
  83. warp/fem/space/tetmesh_function_space.py +4 -1
  84. warp/fem/space/topology.py +52 -21
  85. warp/fem/space/trimesh_function_space.py +4 -1
  86. warp/fem/utils.py +53 -8
  87. warp/jax.py +1 -2
  88. warp/jax_experimental/ffi.py +12 -17
  89. warp/jax_experimental/xla_ffi.py +37 -24
  90. warp/math.py +171 -1
  91. warp/native/array.h +99 -0
  92. warp/native/builtin.h +174 -31
  93. warp/native/coloring.cpp +1 -1
  94. warp/native/exports.h +118 -63
  95. warp/native/intersect.h +3 -3
  96. warp/native/mat.h +5 -10
  97. warp/native/mathdx.cpp +11 -5
  98. warp/native/matnn.h +1 -123
  99. warp/native/quat.h +28 -4
  100. warp/native/sparse.cpp +121 -258
  101. warp/native/sparse.cu +181 -274
  102. warp/native/spatial.h +305 -17
  103. warp/native/tile.h +583 -72
  104. warp/native/tile_radix_sort.h +1108 -0
  105. warp/native/tile_reduce.h +237 -2
  106. warp/native/tile_scan.h +240 -0
  107. warp/native/tuple.h +189 -0
  108. warp/native/vec.h +6 -16
  109. warp/native/warp.cpp +36 -4
  110. warp/native/warp.cu +574 -51
  111. warp/native/warp.h +47 -74
  112. warp/optim/linear.py +5 -1
  113. warp/paddle.py +7 -8
  114. warp/py.typed +0 -0
  115. warp/render/render_opengl.py +58 -29
  116. warp/render/render_usd.py +124 -61
  117. warp/sim/__init__.py +9 -0
  118. warp/sim/collide.py +252 -78
  119. warp/sim/graph_coloring.py +8 -1
  120. warp/sim/import_mjcf.py +4 -3
  121. warp/sim/import_usd.py +11 -7
  122. warp/sim/integrator.py +5 -2
  123. warp/sim/integrator_euler.py +1 -1
  124. warp/sim/integrator_featherstone.py +1 -1
  125. warp/sim/integrator_vbd.py +751 -320
  126. warp/sim/integrator_xpbd.py +1 -1
  127. warp/sim/model.py +265 -260
  128. warp/sim/utils.py +10 -7
  129. warp/sparse.py +303 -166
  130. warp/tape.py +52 -51
  131. warp/tests/cuda/test_conditional_captures.py +1046 -0
  132. warp/tests/cuda/test_streams.py +1 -1
  133. warp/tests/geometry/test_volume.py +2 -2
  134. warp/tests/interop/test_dlpack.py +9 -9
  135. warp/tests/interop/test_jax.py +0 -1
  136. warp/tests/run_coverage_serial.py +1 -1
  137. warp/tests/sim/disabled_kinematics.py +2 -2
  138. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  139. warp/tests/sim/test_collision.py +159 -51
  140. warp/tests/sim/test_coloring.py +15 -1
  141. warp/tests/test_array.py +254 -2
  142. warp/tests/test_array_reduce.py +2 -2
  143. warp/tests/test_atomic_cas.py +299 -0
  144. warp/tests/test_codegen.py +142 -19
  145. warp/tests/test_conditional.py +47 -1
  146. warp/tests/test_ctypes.py +0 -20
  147. warp/tests/test_devices.py +8 -0
  148. warp/tests/test_fabricarray.py +4 -2
  149. warp/tests/test_fem.py +58 -25
  150. warp/tests/test_func.py +42 -1
  151. warp/tests/test_grad.py +1 -1
  152. warp/tests/test_lerp.py +1 -3
  153. warp/tests/test_map.py +481 -0
  154. warp/tests/test_mat.py +1 -24
  155. warp/tests/test_quat.py +6 -15
  156. warp/tests/test_rounding.py +10 -38
  157. warp/tests/test_runlength_encode.py +7 -7
  158. warp/tests/test_smoothstep.py +1 -1
  159. warp/tests/test_sparse.py +51 -2
  160. warp/tests/test_spatial.py +507 -1
  161. warp/tests/test_struct.py +2 -2
  162. warp/tests/test_tuple.py +265 -0
  163. warp/tests/test_types.py +2 -2
  164. warp/tests/test_utils.py +24 -18
  165. warp/tests/tile/test_tile.py +420 -1
  166. warp/tests/tile/test_tile_mathdx.py +518 -14
  167. warp/tests/tile/test_tile_reduce.py +213 -0
  168. warp/tests/tile/test_tile_shared_memory.py +130 -1
  169. warp/tests/tile/test_tile_sort.py +117 -0
  170. warp/tests/unittest_suites.py +4 -6
  171. warp/types.py +462 -308
  172. warp/utils.py +647 -86
  173. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  174. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +177 -165
  175. warp/stubs.py +0 -3381
  176. warp/tests/sim/test_xpbd.py +0 -399
  177. warp/tests/test_mlp.py +0 -282
  178. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  179. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  180. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/native/sparse.cu CHANGED
@@ -17,6 +17,8 @@
 
 #include "cuda_util.h"
 #include "warp.h"
+#include "stdint.h"
+#include <cstdint>
 
 #define THRUST_IGNORE_CUB_VERSION_CHECK
 
@@ -45,29 +47,55 @@ template <typename T> struct BsrBlockIsNotZero
 {
     int block_size;
     const T* values;
+    T zero_mask;
 
-    CUDA_CALLABLE_DEVICE bool operator()(int i) const
+    BsrBlockIsNotZero(int block_size, const void* values, const uint64_t zero_mask)
+        : block_size(block_size), values(static_cast<const T*>(values)), zero_mask(static_cast<const T>(zero_mask))
+    {}
+
+    CUDA_CALLABLE_DEVICE bool operator()(int block) const
     {
         if (!values)
             return true;
 
-        const T* val = values + i * block_size;
+        const T* val = values + block * block_size;
         for (int i = 0; i < block_size; ++i, ++val)
         {
-            if (*val != T(0))
+            if ((*val & zero_mask) != 0)
                 return true;
         }
         return false;
     }
 };
 
+template <> struct BsrBlockIsNotZero<void>
+{
+    BsrBlockIsNotZero(int block_size, const void* values, const uint64_t zero_mask)
+    {}
+
+    CUDA_CALLABLE_DEVICE bool operator()(int block) const
+    {
+        return true;
+    }
+};
+
 struct BsrBlockInMask
 {
+    const int nrow;
+    const int ncol;
     const int* bsr_offsets;
     const int* bsr_columns;
+    const int* device_nnz;
 
-    CUDA_CALLABLE_DEVICE bool operator()(int row, int col) const
+    CUDA_CALLABLE_DEVICE bool operator()(int index, int row, int col) const
     {
+        if (device_nnz != nullptr && index >= *device_nnz)
+            return false;
+
+        if (row < 0 || row >= nrow || col < 0 || col >= ncol){
+            return false;
+        }
+
         if (bsr_offsets == nullptr)
             return true;
 
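
Note on the new zero_mask test above: rather than comparing scalars against T(0), the predicate now reinterprets each scalar's raw bits as an unsigned integer and ANDs them with a caller-supplied mask, so the caller decides which bit patterns count as zero. A minimal host-side sketch of the idea (the mask value is an assumption here; a float mask of 0x7fffffff ignores the sign bit, so -0.0f is pruned like +0.0f):

    #include <cstdint>
    #include <cstring>

    // Returns true if any scalar in the block has a nonzero bit under the mask.
    static bool block_is_nonzero(const float* block, int block_size, uint32_t zero_mask)
    {
        for (int i = 0; i < block_size; ++i)
        {
            uint32_t bits;
            std::memcpy(&bits, &block[i], sizeof(bits)); // reinterpret the float's raw bits
            if ((bits & zero_mask) != 0)
                return true; // at least one scalar survives the mask -> keep the block
        }
        return false; // every scalar masked to zero -> block can be pruned
    }
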
@@ -93,9 +121,9 @@ struct BsrBlockInMask
 };
 
 template <typename T>
-__global__ void bsr_fill_triplet_key_values(const int nnz, const int nrow, const int* tpl_rows, const int* tpl_columns,
+__global__ void bsr_fill_triplet_key_values(const int nnz, const int* tpl_rows, const int* tpl_columns,
                                             const BsrBlockIsNotZero<T> nonZero, const BsrBlockInMask mask,
-                                            uint32_t* block_indices, BsrRowCol* tpl_row_col)
+                                            int* block_indices, BsrRowCol* tpl_row_col)
 {
     int block = blockIdx.x * blockDim.x + threadIdx.x;
     if (block >= nnz)
@@ -103,10 +131,10 @@ __global__ void bsr_fill_triplet_key_values(const int nnz, const int nrow, const
 
     const int row = tpl_rows[block];
     const int col = tpl_columns[block];
-    const bool is_valid = row >= 0 && row < nrow;
 
     const BsrRowCol row_col =
-        is_valid && nonZero(block) && mask(row, col) ? bsr_combine_row_col(row, col) : PRUNED_ROWCOL;
+        mask(block, row, col) && nonZero(block) ? bsr_combine_row_col(row, col) : PRUNED_ROWCOL;
+
     tpl_row_col[block] = row_col;
     block_indices[block] = block;
 }
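
The kernel now delegates all validity checks to mask(block, row, col) and sorts on a single 64-bit key. bsr_combine_row_col is defined elsewhere in the Warp sources; a plausible packing (an assumption, shown for illustration only) places the row in the high 32 bits so that one 64-bit radix sort orders blocks by (row, col):

    #include <cstdint>

    using BsrRowCol = uint64_t;

    // Pack (row, col) so that sorting the 64-bit key sorts by row, then column.
    static inline BsrRowCol bsr_combine_row_col_sketch(uint32_t row, uint32_t col)
    {
        return (static_cast<uint64_t>(row) << 32) | col;
    }

    static inline int bsr_get_col_sketch(BsrRowCol rc)
    {
        return static_cast<int>(rc & 0xffffffffu); // low 32 bits hold the column
    }
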
@@ -153,126 +181,34 @@ __global__ void bsr_find_row_offsets(uint32_t row_count, const T* d_nnz, const B
     row_offsets[row] = lower;
 }
 
-template <typename T>
-__global__ void bsr_merge_blocks(const int* d_nnz, int block_size, const uint32_t* block_offsets,
-                                 const uint32_t* sorted_block_indices, const BsrRowCol* unique_row_cols,
-                                 const T* tpl_values, int* bsr_cols, T* bsr_values)
-
+__global__ void bsr_set_column(const int* d_nnz, const BsrRowCol* unique_row_cols, int* bsr_cols)
 {
     const uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
-
     if (i >= *d_nnz)
         return;
-
     const BsrRowCol row_col = unique_row_cols[i];
     bsr_cols[i] = bsr_get_col(row_col);
-
-    // Accumulate merged block values
-    if (row_col == PRUNED_ROWCOL || bsr_values == nullptr)
-        return;
-
-    const uint32_t beg = i ? block_offsets[i - 1] : 0;
-    const uint32_t end = block_offsets[i];
-
-    T* bsr_val = bsr_values + i * block_size;
-    const T* tpl_val = tpl_values + sorted_block_indices[beg] * block_size;
-
-    for (int k = 0; k < block_size; ++k)
-    {
-        bsr_val[k] = tpl_val[k];
-    }
-
-    for (uint32_t cur = beg + 1; cur != end; ++cur)
-    {
-        const T* tpl_val = tpl_values + sorted_block_indices[cur] * block_size;
-        for (int k = 0; k < block_size; ++k)
-        {
-            bsr_val[k] += tpl_val[k];
-        }
-    }
 }
 
 template <typename T>
-void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_per_block, const int row_count,
-                                     const int nnz, const int* tpl_rows, const int* tpl_columns, const T* tpl_values,
-                                     const bool prune_numerical_zeros, const bool masked, int* bsr_offsets,
-                                     int* bsr_columns, T* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+void launch_bsr_fill_triplet_key_values(
+    const int block_size,
+    const int nnz,
+    const BsrBlockInMask& mask,
+    const int* tpl_rows,
+    const int* tpl_columns,
+    const void* tpl_values,
+    const uint64_t scalar_zero_mask,
+    int* block_indices,
+    BsrRowCol* row_col
+)
 {
-    const int block_size = rows_per_block * cols_per_block;
-
-    void* context = cuda_context_get_current();
-    ContextGuard guard(context);
-
-    // Per-context cached temporary buffers
-    // BsrFromTripletsTemp& bsr_temp = g_bsr_from_triplets_temp_map[context];
-
-    cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
-
-    ScopedTemporary<uint32_t> block_indices(context, 2 * nnz + 1);
-    ScopedTemporary<BsrRowCol> combined_row_col(context, 2 * nnz);
-
-    cub::DoubleBuffer<uint32_t> d_keys(block_indices.buffer(), block_indices.buffer() + nnz);
-    cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(), combined_row_col.buffer() + nnz);
-
-    uint32_t* unique_triplet_count = block_indices.buffer() + 2 * nnz;
-
-    // Combine rows and columns so we can sort on them both
-    BsrBlockIsNotZero<T> isNotZero{block_size, prune_numerical_zeros ? tpl_values : nullptr};
-    BsrBlockInMask mask{masked ? bsr_offsets : nullptr, bsr_columns};
+    BsrBlockIsNotZero<T> isNotZero{block_size, tpl_values, scalar_zero_mask};
     wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_triplet_key_values, nnz,
-                     (nnz, row_count, tpl_rows, tpl_columns, isNotZero, mask, d_keys.Current(), d_values.Current()));
-
-    // Sort
-    {
-        size_t buff_size = 0;
-        check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values, d_keys, nnz, 0, 64, stream));
-        ScopedTemporary<> temp(context, buff_size);
-        check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
-    }
-
-    // Runlength encode row-col sequences
-    {
-        size_t buff_size = 0;
-        check_cuda(cub::DeviceRunLengthEncode::Encode(nullptr, buff_size, d_values.Current(), d_values.Alternate(),
-                                                      d_keys.Alternate(), unique_triplet_count, nnz, stream));
-        ScopedTemporary<> temp(context, buff_size);
-        check_cuda(cub::DeviceRunLengthEncode::Encode(temp.buffer(), buff_size, d_values.Current(),
-                                                      d_values.Alternate(), d_keys.Alternate(), unique_triplet_count,
-                                                      nnz, stream));
-    }
-
-    // Compute row offsets from sorted unique blocks
-    wp_launch_device(WP_CURRENT_CONTEXT, bsr_find_row_offsets, row_count + 1,
-                     (row_count, unique_triplet_count, d_values.Alternate(), bsr_offsets));
-
-    if (bsr_nnz)
-    {
-        // Copy nnz to host, and record an event for the completed transfer if desired
-
-        memcpy_d2h(WP_CURRENT_CONTEXT, bsr_nnz, bsr_offsets + row_count, sizeof(int), stream);
-
-        if (bsr_nnz_event)
-        {
-            cuda_event_record(bsr_nnz_event, stream);
-        }
-    }
-
-    // Scan repeated block counts
-    {
-        size_t buff_size = 0;
-        check_cuda(
-            cub::DeviceScan::InclusiveSum(nullptr, buff_size, d_keys.Alternate(), d_keys.Alternate(), nnz, stream));
-        ScopedTemporary<> temp(context, buff_size);
-        check_cuda(cub::DeviceScan::InclusiveSum(temp.buffer(), buff_size, d_keys.Alternate(), d_keys.Alternate(), nnz,
-                                                 stream));
-    }
-
-    // Accumulate repeated blocks and set column indices
-    wp_launch_device(WP_CURRENT_CONTEXT, bsr_merge_blocks, nnz,
-                     (bsr_offsets + row_count, block_size, d_keys.Alternate(), d_keys.Current(), d_values.Alternate(),
-                      tpl_values, bsr_columns, bsr_values));
+                     (nnz, tpl_rows, tpl_columns, isNotZero, mask, block_indices, row_col));
 }
 
+
 __global__ void bsr_transpose_fill_row_col(const int nnz_upper_bound, const int row_count, const int* bsr_offsets,
                                            const int* bsr_columns, int* block_indices, BsrRowCol* transposed_row_col)
 {
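
Both the removed monolithic implementation above and its replacement later in this file drive CUB primitives with the usual two-phase pattern: the first call passes a null temporary buffer and only computes the required scratch size, the second call does the actual work. A self-contained sketch of that pattern (names are illustrative, not Warp APIs):

    #include <cub/cub.cuh>

    void sort_pairs_two_phase(uint64_t* d_keys_in, uint64_t* d_keys_out,
                              int* d_vals_in, int* d_vals_out, int n, cudaStream_t stream)
    {
        void* d_temp = nullptr;
        size_t temp_bytes = 0;
        // Phase 1: null temp storage -> CUB only reports the required bytes.
        cub::DeviceRadixSort::SortPairs(d_temp, temp_bytes, d_keys_in, d_keys_out,
                                        d_vals_in, d_vals_out, n, 0, 64, stream);
        cudaMallocAsync(&d_temp, temp_bytes, stream);
        // Phase 2: same call with real storage sorts over all 64 key bits.
        cub::DeviceRadixSort::SortPairs(d_temp, temp_bytes, d_keys_in, d_keys_out,
                                        d_vals_in, d_vals_out, n, 0, 64, stream);
        cudaFreeAsync(d_temp, stream);
    }
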
@@ -283,6 +219,8 @@ __global__ void bsr_transpose_fill_row_col(const int nnz_upper_bound, const int
         // Outside of allocated bounds, do nothing
         return;
     }
+
+    block_indices[i] = i;
 
     if (i >= bsr_offsets[row_count])
     {
@@ -291,8 +229,6 @@ __global__ void bsr_transpose_fill_row_col(const int nnz_upper_bound, const int
         return;
     }
 
-    block_indices[i] = i;
-
     // Binary search for row
     int lower = 0;
     int upper = row_count - 1;
@@ -317,144 +253,153 @@ __global__ void bsr_transpose_fill_row_col(const int nnz_upper_bound, const int
     transposed_row_col[i] = row_col;
 }
 
-template <int Rows, int Cols, typename T> struct BsrBlockTransposer
+} // namespace
+
+
+WP_API void bsr_matrix_from_triplets_device(
+    const int block_size,
+    int scalar_size,
+    const int row_count,
+    const int col_count,
+    const int nnz,
+    const int* tpl_nnz,
+    const int* tpl_rows,
+    const int* tpl_columns,
+    const void* tpl_values,
+    const uint64_t scalar_zero_mask,
+    const bool masked_topology,
+    int* tpl_block_offsets,
+    int* tpl_block_indices,
+    int* bsr_offsets,
+    int* bsr_columns,
+    int* bsr_nnz, void* bsr_nnz_event)
 {
-    void CUDA_CALLABLE_DEVICE operator()(const T* src, T* dest) const
+    void* context = cuda_context_get_current();
+    ContextGuard guard(context);
+
+    // Per-context cached temporary buffers
+    // BsrFromTripletsTemp& bsr_temp = g_bsr_from_triplets_temp_map[context];
+
+    cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
+
+    ScopedTemporary<BsrRowCol> combined_row_col(context, 2 * size_t(nnz));
+    ScopedTemporary<int> unique_triplet_count(context, 1);
+
+    bool return_summed_blocks = tpl_block_offsets != nullptr && tpl_block_indices != nullptr;
+    if(!return_summed_blocks)
     {
-        for (int r = 0; r < Rows; ++r)
-        {
-            for (int c = 0; c < Cols; ++c)
-            {
-                dest[c * Rows + r] = src[r * Cols + c];
-            }
-        }
+        // if not provided, allocate temporary offset and indices buffers
+        tpl_block_offsets = static_cast<int*>(alloc_device(context, size_t(nnz) * sizeof(int)));
+        tpl_block_indices = static_cast<int*>(alloc_device(context, size_t(nnz) * sizeof(int)));
     }
-};
 
-template <typename T> struct BsrBlockTransposer<-1, -1, T>
-{
 
-    int row_count;
-    int col_count;
+    cub::DoubleBuffer<int> d_keys(tpl_block_indices, tpl_block_offsets);
+    cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(), combined_row_col.buffer() + nnz);
 
-    void CUDA_CALLABLE_DEVICE operator()(const T* src, T* dest) const
+    // Combine rows and columns so we can sort on them both,
+    // ensuring that blocks that should be pruned are moved to the end
+    BsrBlockInMask mask{row_count, col_count, masked_topology ? bsr_offsets : nullptr, bsr_columns, tpl_nnz};
+    if (scalar_zero_mask == 0 || tpl_values == nullptr)
+        scalar_size = 0;
+    switch(scalar_size)
     {
-        for (int r = 0; r < row_count; ++r)
-        {
-            for (int c = 0; c < col_count; ++c)
-            {
-                dest[c * row_count + r] = src[r * col_count + c];
-            }
-        }
+    case sizeof(uint8_t):
+        launch_bsr_fill_triplet_key_values<uint8_t>(block_size, nnz, mask, tpl_rows, tpl_columns, tpl_values, scalar_zero_mask, d_keys.Current(), d_values.Current());
+        break;
+    case sizeof(uint16_t):
+        launch_bsr_fill_triplet_key_values<uint16_t>(block_size, nnz, mask, tpl_rows, tpl_columns, tpl_values, scalar_zero_mask, d_keys.Current(), d_values.Current());
+        break;
+    case sizeof(uint32_t):
+        launch_bsr_fill_triplet_key_values<uint32_t>(block_size, nnz, mask, tpl_rows, tpl_columns, tpl_values, scalar_zero_mask, d_keys.Current(), d_values.Current());
+        break;
+    case sizeof(uint64_t):
+        launch_bsr_fill_triplet_key_values<uint64_t>(block_size, nnz, mask, tpl_rows, tpl_columns, tpl_values, scalar_zero_mask, d_keys.Current(), d_values.Current());
+        break;
+    default:
+        // no scalar-level pruning
+        launch_bsr_fill_triplet_key_values<void>(block_size, nnz, mask, tpl_rows, tpl_columns, tpl_values, scalar_zero_mask, d_keys.Current(), d_values.Current());
+        break;
     }
-};
 
-template <int Rows, int Cols, typename T>
-__global__ void bsr_transpose_blocks(const int* nnz, const int block_size, BsrBlockTransposer<Rows, Cols, T> transposer,
-                                     const int* block_indices, const BsrRowCol* transposed_indices, const T* bsr_values,
-                                     int* transposed_bsr_columns, T* transposed_bsr_values)
-{
-    int i = blockIdx.x * blockDim.x + threadIdx.x;
-    if (i >= *nnz)
-        return;
 
-    const int src_idx = block_indices[i];
+    // Sort
+    {
+        size_t buff_size = 0;
+        check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values, d_keys, nnz, 0, 64, stream));
+        ScopedTemporary<> temp(context, buff_size);
+        check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
 
-    transposer(bsr_values + src_idx * block_size, transposed_bsr_values + i * block_size);
+        // Depending on data size and GPU architecture buffers may have been swapped or not
+        // Ensures the sorted keys are available in summed_block_indices if needed
+        if(return_summed_blocks && d_keys.Current() != tpl_block_indices)
+        {
+            check_cuda(cudaMemcpy(tpl_block_indices, d_keys.Current(), nnz * sizeof(int), cudaMemcpyDeviceToDevice));
+        }
+    }
 
-    transposed_bsr_columns[i] = bsr_get_col(transposed_indices[i]);
-}
+    // Runlength encode row-col sequences
+    {
+        size_t buff_size = 0;
+        check_cuda(cub::DeviceRunLengthEncode::Encode(nullptr, buff_size, d_values.Current(), d_values.Alternate(),
+                                                      tpl_block_offsets, unique_triplet_count.buffer(), nnz, stream));
+        ScopedTemporary<> temp(context, buff_size);
+        check_cuda(cub::DeviceRunLengthEncode::Encode(temp.buffer(), buff_size, d_values.Current(),
+                                                      d_values.Alternate(), tpl_block_offsets, unique_triplet_count.buffer(),
+                                                      nnz, stream));
+    }
 
-template <typename T>
-void launch_bsr_transpose_blocks(int nnz, const int* d_nnz, const int block_size, const int rows_per_block,
-                                 const int cols_per_block, const int* block_indices,
-                                 const BsrRowCol* transposed_indices, const T* bsr_values, int* transposed_bsr_columns,
-                                 T* transposed_bsr_values)
-{
+    // Compute row offsets from sorted unique blocks
+    wp_launch_device(WP_CURRENT_CONTEXT, bsr_find_row_offsets, row_count + 1,
+                     (row_count, unique_triplet_count.buffer(), d_values.Alternate(), bsr_offsets));
 
-    switch (rows_per_block)
+    if (bsr_nnz)
     {
-    case 1:
-        switch (cols_per_block)
-        {
-        case 1:
-            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-                             (d_nnz, block_size, BsrBlockTransposer<1, 1, T>{}, block_indices, transposed_indices,
-                              bsr_values, transposed_bsr_columns, transposed_bsr_values));
-            return;
-        case 2:
-            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-                             (d_nnz, block_size, BsrBlockTransposer<1, 2, T>{}, block_indices, transposed_indices,
-                              bsr_values, transposed_bsr_columns, transposed_bsr_values));
-            return;
-        case 3:
-            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-                             (d_nnz, block_size, BsrBlockTransposer<1, 3, T>{}, block_indices, transposed_indices,
-                              bsr_values, transposed_bsr_columns, transposed_bsr_values));
-            return;
-        }
-    case 2:
-        switch (cols_per_block)
-        {
-        case 1:
-            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-                             (d_nnz, block_size, BsrBlockTransposer<2, 1, T>{}, block_indices, transposed_indices,
-                              bsr_values, transposed_bsr_columns, transposed_bsr_values));
-            return;
-        case 2:
-            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-                             (d_nnz, block_size, BsrBlockTransposer<2, 2, T>{}, block_indices, transposed_indices,
-                              bsr_values, transposed_bsr_columns, transposed_bsr_values));
-            return;
-        case 3:
-            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-                             (d_nnz, block_size, BsrBlockTransposer<2, 3, T>{}, block_indices, transposed_indices,
-                              bsr_values, transposed_bsr_columns, transposed_bsr_values));
-            return;
-        }
-    case 3:
-        switch (cols_per_block)
+        // Copy nnz to host, and record an event for the completed transfer if desired
+
+        memcpy_d2h(WP_CURRENT_CONTEXT, bsr_nnz, bsr_offsets + row_count, sizeof(int), stream);
+
+        if (bsr_nnz_event)
         {
-        case 1:
-            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-                             (d_nnz, block_size, BsrBlockTransposer<3, 1, T>{}, block_indices, transposed_indices,
-                              bsr_values, transposed_bsr_columns, transposed_bsr_values));
-            return;
-        case 2:
-            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-                             (d_nnz, block_size, BsrBlockTransposer<3, 2, T>{}, block_indices, transposed_indices,
-                              bsr_values, transposed_bsr_columns, transposed_bsr_values));
-            return;
-        case 3:
-            wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-                             (d_nnz, block_size, BsrBlockTransposer<3, 3, T>{}, block_indices, transposed_indices,
-                              bsr_values, transposed_bsr_columns, transposed_bsr_values));
-            return;
+            cuda_event_record(bsr_nnz_event, stream);
         }
     }
 
-    wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
-                     (d_nnz, block_size, BsrBlockTransposer<-1, -1, T>{rows_per_block, cols_per_block}, block_indices,
-                      transposed_indices, bsr_values, transposed_bsr_columns, transposed_bsr_values));
-}
+    // Set column indices
+    wp_launch_device(WP_CURRENT_CONTEXT, bsr_set_column, nnz,
+                     (bsr_offsets + row_count, d_values.Alternate(),
+                      bsr_columns));
 
-template <typename T>
-void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
-                          const int* bsr_offsets, const int* bsr_columns, const T* bsr_values,
-                          int* transposed_bsr_offsets, int* transposed_bsr_columns, T* transposed_bsr_values)
-{
+    // Scan repeated block counts
+    if(return_summed_blocks)
+    {
+        size_t buff_size = 0;
+        check_cuda(
+            cub::DeviceScan::InclusiveSum(nullptr, buff_size, tpl_block_offsets, tpl_block_offsets, nnz, stream));
+        ScopedTemporary<> temp(context, buff_size);
+        check_cuda(cub::DeviceScan::InclusiveSum(temp.buffer(), buff_size, tpl_block_offsets, tpl_block_offsets, nnz,
+                                                 stream));
+    } else {
+        // free our temporary buffers
+        free_device(context, tpl_block_offsets);
+        free_device(context, tpl_block_indices);
+    }
+}
 
-    const int block_size = rows_per_block * cols_per_block;
 
+WP_API void bsr_transpose_device(int row_count, int col_count, int nnz,
+                                 const int* bsr_offsets, const int* bsr_columns,
+                                 int* transposed_bsr_offsets, int* transposed_bsr_columns,
+                                 int* src_block_indices)
+{
     void* context = cuda_context_get_current();
     ContextGuard guard(context);
 
     cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
 
-    ScopedTemporary<int> block_indices(context, 2 * nnz);
     ScopedTemporary<BsrRowCol> combined_row_col(context, 2 * nnz);
 
-    cub::DoubleBuffer<int> d_keys(block_indices.buffer(), block_indices.buffer() + nnz);
+    cub::DoubleBuffer<int> d_keys(src_block_indices + nnz, src_block_indices);
     cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(), combined_row_col.buffer() + nnz);
 
     wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_fill_row_col, nnz,
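
The cudaMemcpy guarded by d_keys.Current() != tpl_block_indices above exists because the cub::DoubleBuffer overload of SortPairs flips the buffer selector internally: the sorted keys may land in either of the two buffers depending on how many radix passes ran. A minimal sketch of the same check (illustrative helper, not part of this file):

    #include <cub/cub.cuh>

    // Ensure the sorted keys end up in dst, whichever buffer CUB left them in.
    void gather_sorted_keys(int* dst, cub::DoubleBuffer<int>& d_keys, int n, cudaStream_t stream)
    {
        if (d_keys.Current() != dst)
        {
            cudaMemcpyAsync(dst, d_keys.Current(), n * sizeof(int),
                            cudaMemcpyDeviceToDevice, stream);
        }
    }
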
@@ -466,59 +411,21 @@ void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
         check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values, d_keys, nnz, 0, 64, stream));
         ScopedTemporary<> temp(context, buff_size);
         check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
+
+        // Depending on data size and GPU architecture buffers may have been swapped or not
+        // Ensures the sorted keys are available in summed_block_indices if needed
+        if(d_keys.Current() != src_block_indices)
+        {
+            check_cuda(cudaMemcpy(src_block_indices, src_block_indices+nnz, size_t(nnz) * sizeof(int), cudaMemcpyDeviceToDevice));
+        }
     }
 
     // Compute row offsets from sorted unique blocks
     wp_launch_device(WP_CURRENT_CONTEXT, bsr_find_row_offsets, col_count + 1,
                      (col_count, bsr_offsets + row_count, d_values.Current(), transposed_bsr_offsets));
 
-    // Move and transpose individual blocks
-    if (transposed_bsr_values != nullptr)
-    {
-        launch_bsr_transpose_blocks(nnz, bsr_offsets + row_count, block_size, rows_per_block, cols_per_block,
-                                    d_keys.Current(), d_values.Current(), bsr_values, transposed_bsr_columns,
-                                    transposed_bsr_values);
-    }
-}
-
-} // namespace
 
-void bsr_matrix_from_triplets_float_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
-                                           int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                           bool prune_numerical_zeros, bool masked, int* bsr_offsets, int* bsr_columns,
-                                           void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
-{
-    return bsr_matrix_from_triplets_device<float>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
-                                                  static_cast<const float*>(tpl_values), prune_numerical_zeros, masked,
-                                                  bsr_offsets, bsr_columns, static_cast<float*>(bsr_values), bsr_nnz,
-                                                  bsr_nnz_event);
-}
-
-void bsr_matrix_from_triplets_double_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
-                                            int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                            bool prune_numerical_zeros, bool masked, int* bsr_offsets, int* bsr_columns,
-                                            void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
-{
-    return bsr_matrix_from_triplets_device<double>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows,
-                                                   tpl_columns, static_cast<const double*>(tpl_values),
-                                                   prune_numerical_zeros, masked, bsr_offsets, bsr_columns,
-                                                   static_cast<double*>(bsr_values), bsr_nnz, bsr_nnz_event);
-}
-
-void bsr_transpose_float_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
-                                int* bsr_offsets, int* bsr_columns, void* bsr_values, int* transposed_bsr_offsets,
-                                int* transposed_bsr_columns, void* transposed_bsr_values)
-{
-    bsr_transpose_device(rows_per_block, cols_per_block, row_count, col_count, nnz, bsr_offsets, bsr_columns,
-                         static_cast<const float*>(bsr_values), transposed_bsr_offsets, transposed_bsr_columns,
-                         static_cast<float*>(transposed_bsr_values));
-}
-
-void bsr_transpose_double_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
-                                 int* bsr_offsets, int* bsr_columns, void* bsr_values, int* transposed_bsr_offsets,
-                                 int* transposed_bsr_columns, void* transposed_bsr_values)
-{
-    bsr_transpose_device(rows_per_block, cols_per_block, row_count, col_count, nnz, bsr_offsets, bsr_columns,
-                         static_cast<const double*>(bsr_values), transposed_bsr_offsets, transposed_bsr_columns,
-                         static_cast<double*>(transposed_bsr_values));
+    wp_launch_device(WP_CURRENT_CONTEXT, bsr_set_column, nnz,
+                     (bsr_offsets + row_count, d_values.Current(),
+                      transposed_bsr_columns));
 }
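
With value handling removed from bsr_transpose_device, the function now computes topology only and reports, via src_block_indices, where each transposed block came from; permuting (and transposing) the value blocks is left to the caller. A hypothetical follow-up kernel sketching that step (not part of this file):

    // Gather value blocks into transposed order using the returned source indices.
    __global__ void permute_value_blocks(int nnz, int block_size, const int* src_block_indices,
                                         const float* src_values, float* dst_values)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i >= nnz)
            return;
        const float* src = src_values + src_block_indices[i] * block_size;
        float* dst = dst_values + i * block_size;
        for (int k = 0; k < block_size; ++k)
            dst[k] = src[k]; // a per-block row/col transpose of the layout would go here
    }
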