warp-lang 1.7.2__py3-none-win_amd64.whl → 1.8.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (181)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.dll +0 -0
  5. warp/bin/warp.dll +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +125 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +257 -101
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +657 -223
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/optim/example_drone.py +1 -1
  37. warp/examples/sim/example_cloth.py +1 -1
  38. warp/examples/sim/example_cloth_self_contact.py +48 -54
  39. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  40. warp/examples/tile/example_tile_cholesky.py +2 -1
  41. warp/examples/tile/example_tile_convolution.py +1 -1
  42. warp/examples/tile/example_tile_filtering.py +1 -1
  43. warp/examples/tile/example_tile_matmul.py +1 -1
  44. warp/examples/tile/example_tile_mlp.py +2 -0
  45. warp/fabric.py +7 -7
  46. warp/fem/__init__.py +5 -0
  47. warp/fem/adaptivity.py +1 -1
  48. warp/fem/cache.py +152 -63
  49. warp/fem/dirichlet.py +2 -2
  50. warp/fem/domain.py +136 -6
  51. warp/fem/field/field.py +141 -99
  52. warp/fem/field/nodal_field.py +85 -39
  53. warp/fem/field/virtual.py +97 -52
  54. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  55. warp/fem/geometry/closest_point.py +13 -0
  56. warp/fem/geometry/deformed_geometry.py +102 -40
  57. warp/fem/geometry/element.py +56 -2
  58. warp/fem/geometry/geometry.py +323 -22
  59. warp/fem/geometry/grid_2d.py +157 -62
  60. warp/fem/geometry/grid_3d.py +116 -20
  61. warp/fem/geometry/hexmesh.py +86 -20
  62. warp/fem/geometry/nanogrid.py +166 -86
  63. warp/fem/geometry/partition.py +59 -25
  64. warp/fem/geometry/quadmesh.py +86 -135
  65. warp/fem/geometry/tetmesh.py +47 -119
  66. warp/fem/geometry/trimesh.py +77 -270
  67. warp/fem/integrate.py +107 -52
  68. warp/fem/linalg.py +25 -58
  69. warp/fem/operator.py +124 -27
  70. warp/fem/quadrature/pic_quadrature.py +36 -14
  71. warp/fem/quadrature/quadrature.py +40 -16
  72. warp/fem/space/__init__.py +1 -1
  73. warp/fem/space/basis_function_space.py +66 -46
  74. warp/fem/space/basis_space.py +17 -4
  75. warp/fem/space/dof_mapper.py +1 -1
  76. warp/fem/space/function_space.py +2 -2
  77. warp/fem/space/grid_2d_function_space.py +4 -1
  78. warp/fem/space/hexmesh_function_space.py +4 -2
  79. warp/fem/space/nanogrid_function_space.py +3 -1
  80. warp/fem/space/partition.py +11 -2
  81. warp/fem/space/quadmesh_function_space.py +4 -1
  82. warp/fem/space/restriction.py +5 -2
  83. warp/fem/space/shape/__init__.py +10 -8
  84. warp/fem/space/tetmesh_function_space.py +4 -1
  85. warp/fem/space/topology.py +52 -21
  86. warp/fem/space/trimesh_function_space.py +4 -1
  87. warp/fem/utils.py +53 -8
  88. warp/jax.py +1 -2
  89. warp/jax_experimental/ffi.py +12 -17
  90. warp/jax_experimental/xla_ffi.py +37 -24
  91. warp/math.py +171 -1
  92. warp/native/array.h +99 -0
  93. warp/native/builtin.h +174 -31
  94. warp/native/coloring.cpp +1 -1
  95. warp/native/exports.h +118 -63
  96. warp/native/intersect.h +3 -3
  97. warp/native/mat.h +5 -10
  98. warp/native/mathdx.cpp +11 -5
  99. warp/native/matnn.h +1 -123
  100. warp/native/quat.h +28 -4
  101. warp/native/sparse.cpp +121 -258
  102. warp/native/sparse.cu +181 -274
  103. warp/native/spatial.h +305 -17
  104. warp/native/tile.h +583 -72
  105. warp/native/tile_radix_sort.h +1108 -0
  106. warp/native/tile_reduce.h +237 -2
  107. warp/native/tile_scan.h +240 -0
  108. warp/native/tuple.h +189 -0
  109. warp/native/vec.h +6 -16
  110. warp/native/warp.cpp +36 -4
  111. warp/native/warp.cu +574 -51
  112. warp/native/warp.h +47 -74
  113. warp/optim/linear.py +5 -1
  114. warp/paddle.py +7 -8
  115. warp/py.typed +0 -0
  116. warp/render/render_opengl.py +58 -29
  117. warp/render/render_usd.py +124 -61
  118. warp/sim/__init__.py +9 -0
  119. warp/sim/collide.py +252 -78
  120. warp/sim/graph_coloring.py +8 -1
  121. warp/sim/import_mjcf.py +4 -3
  122. warp/sim/import_usd.py +11 -7
  123. warp/sim/integrator.py +5 -2
  124. warp/sim/integrator_euler.py +1 -1
  125. warp/sim/integrator_featherstone.py +1 -1
  126. warp/sim/integrator_vbd.py +751 -320
  127. warp/sim/integrator_xpbd.py +1 -1
  128. warp/sim/model.py +265 -260
  129. warp/sim/utils.py +10 -7
  130. warp/sparse.py +303 -166
  131. warp/tape.py +52 -51
  132. warp/tests/cuda/test_conditional_captures.py +1046 -0
  133. warp/tests/cuda/test_streams.py +1 -1
  134. warp/tests/geometry/test_volume.py +2 -2
  135. warp/tests/interop/test_dlpack.py +9 -9
  136. warp/tests/interop/test_jax.py +0 -1
  137. warp/tests/run_coverage_serial.py +1 -1
  138. warp/tests/sim/disabled_kinematics.py +2 -2
  139. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  140. warp/tests/sim/test_collision.py +159 -51
  141. warp/tests/sim/test_coloring.py +15 -1
  142. warp/tests/test_array.py +254 -2
  143. warp/tests/test_array_reduce.py +2 -2
  144. warp/tests/test_atomic_cas.py +299 -0
  145. warp/tests/test_codegen.py +142 -19
  146. warp/tests/test_conditional.py +47 -1
  147. warp/tests/test_ctypes.py +0 -20
  148. warp/tests/test_devices.py +8 -0
  149. warp/tests/test_fabricarray.py +4 -2
  150. warp/tests/test_fem.py +58 -25
  151. warp/tests/test_func.py +42 -1
  152. warp/tests/test_grad.py +1 -1
  153. warp/tests/test_lerp.py +1 -3
  154. warp/tests/test_map.py +481 -0
  155. warp/tests/test_mat.py +1 -24
  156. warp/tests/test_quat.py +6 -15
  157. warp/tests/test_rounding.py +10 -38
  158. warp/tests/test_runlength_encode.py +7 -7
  159. warp/tests/test_smoothstep.py +1 -1
  160. warp/tests/test_sparse.py +51 -2
  161. warp/tests/test_spatial.py +507 -1
  162. warp/tests/test_struct.py +2 -2
  163. warp/tests/test_tuple.py +265 -0
  164. warp/tests/test_types.py +2 -2
  165. warp/tests/test_utils.py +24 -18
  166. warp/tests/tile/test_tile.py +420 -1
  167. warp/tests/tile/test_tile_mathdx.py +518 -14
  168. warp/tests/tile/test_tile_reduce.py +213 -0
  169. warp/tests/tile/test_tile_shared_memory.py +130 -1
  170. warp/tests/tile/test_tile_sort.py +117 -0
  171. warp/tests/unittest_suites.py +4 -6
  172. warp/types.py +462 -308
  173. warp/utils.py +647 -86
  174. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  175. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
  176. warp/stubs.py +0 -3381
  177. warp/tests/sim/test_xpbd.py +0 -399
  178. warp/tests/test_mlp.py +0 -282
  179. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  180. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  181. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/native/sparse.cpp CHANGED
@@ -18,143 +18,103 @@
18
18
  #include "warp.h"
19
19
 
20
20
  #include <algorithm>
21
+ #include <cstddef>
21
22
  #include <numeric>
22
23
  #include <vector>
23
24
 
24
25
  namespace
25
26
  {
26
27
 
27
- // Specialized is_zero and accumulation function for common block sizes
28
- // Rely on compiler to unroll loops when block size is known
29
-
30
- template <int N, typename T> bool bsr_fixed_block_is_zero(const T* val, int value_size)
31
- {
32
- return std::all_of(val, val + N, [](float v) { return v == T(0); });
33
- }
34
-
35
- template <typename T> bool bsr_dyn_block_is_zero(const T* val, int value_size)
28
+ template <typename T> bool bsr_block_is_zero(int block_idx, int block_size, const void* values, const uint64_t scalar_zero_mask)
36
29
  {
37
- return std::all_of(val, val + value_size, [](float v) { return v == T(0); });
38
- }
30
+ const T* block_values = static_cast<const T*>(values) + block_idx * block_size;
31
+ const T zero_mask = static_cast<T>(scalar_zero_mask);
39
32
 
40
- template <int N, typename T> void bsr_fixed_block_accumulate(const T* val, T* sum, int value_size)
41
- {
42
- for (int i = 0; i < N; ++i, ++val, ++sum)
43
- {
44
- *sum += *val;
45
- }
33
+ return std::all_of(block_values, block_values + block_size, [zero_mask](T v) { return (v & zero_mask) == T(0); });
46
34
  }
47
35
 
48
- template <typename T> void bsr_dyn_block_accumulate(const T* val, T* sum, int value_size)
49
- {
50
- for (int i = 0; i < value_size; ++i, ++val, ++sum)
51
- {
52
- *sum += *val;
53
- }
54
- }
36
+ } // namespace
55
37
 
56
- template <int Rows, int Cols, typename T>
57
- void bsr_fixed_block_transpose(const T* src, T* dest, int row_count, int col_count)
58
- {
59
- for (int r = 0; r < Rows; ++r)
60
- {
61
- for (int c = 0; c < Cols; ++c)
62
- {
63
- dest[c * Rows + r] = src[r * Cols + c];
64
- }
65
- }
66
- }
67
38
 
68
- template <typename T> void bsr_dyn_block_transpose(const T* src, T* dest, int row_count, int col_count)
69
- {
70
- for (int r = 0; r < row_count; ++r)
39
+ WP_API void bsr_matrix_from_triplets_host(
40
+ int block_size,
41
+ int scalar_size_in_bytes,
42
+ int row_count,
43
+ int col_count,
44
+ int nnz,
45
+ const int* tpl_nnz,
46
+ const int* tpl_rows,
47
+ const int* tpl_columns,
48
+ const void* tpl_values,
49
+ const uint64_t scalar_zero_mask,
50
+ bool masked_topology,
51
+ int* tpl_block_offsets,
52
+ int* tpl_block_indices,
53
+ int* bsr_offsets,
54
+ int* bsr_columns,
55
+ int* bsr_nnz,
56
+ void* bsr_nnz_event)
57
+ {
58
+ if (tpl_nnz != nullptr)
71
59
  {
72
- for (int c = 0; c < col_count; ++c)
73
- {
74
- dest[c * row_count + r] = src[r * col_count + c];
75
- }
60
+ nnz = *tpl_nnz;
76
61
  }
77
- }
78
62
 
79
- } // namespace
80
-
81
- template <typename T>
82
- int bsr_matrix_from_triplets_host(const int rows_per_block, const int cols_per_block, const int row_count,
83
- const int nnz, const int* tpl_rows, const int* tpl_columns, const T* tpl_values,
84
- const bool prune_numerical_zeros, const bool masked, int* bsr_offsets,
85
- int* bsr_columns, T* bsr_values)
86
- {
87
-
88
- // get specialized accumulator for common block sizes (1,1), (1,2), (1,3),
89
- // (2,2), (2,3), (3,3)
90
- const int block_size = rows_per_block * cols_per_block;
91
- void (*block_accumulate_func)(const T*, T*, int);
92
- bool (*block_is_zero_func)(const T*, int);
93
- switch (block_size)
63
+ // allocate temporary buffers if not provided
64
+ bool return_summed_blocks = tpl_block_offsets != nullptr && tpl_block_indices != nullptr;
65
+ if (!return_summed_blocks)
94
66
  {
95
- case 1:
96
- block_accumulate_func = bsr_fixed_block_accumulate<1, T>;
97
- block_is_zero_func = bsr_fixed_block_is_zero<1, T>;
98
- break;
99
- case 2:
100
- block_accumulate_func = bsr_fixed_block_accumulate<2, T>;
101
- block_is_zero_func = bsr_fixed_block_is_zero<2, T>;
102
- break;
103
- case 3:
104
- block_accumulate_func = bsr_fixed_block_accumulate<3, T>;
105
- block_is_zero_func = bsr_fixed_block_is_zero<3, T>;
106
- break;
107
- case 4:
108
- block_accumulate_func = bsr_fixed_block_accumulate<4, T>;
109
- block_is_zero_func = bsr_fixed_block_is_zero<4, T>;
110
- break;
111
- case 6:
112
- block_accumulate_func = bsr_fixed_block_accumulate<6, T>;
113
- block_is_zero_func = bsr_fixed_block_is_zero<6, T>;
114
- break;
115
- case 9:
116
- block_accumulate_func = bsr_fixed_block_accumulate<9, T>;
117
- block_is_zero_func = bsr_fixed_block_is_zero<9, T>;
118
- break;
119
- default:
120
- block_accumulate_func = bsr_dyn_block_accumulate<T>;
121
- block_is_zero_func = bsr_dyn_block_is_zero<T>;
67
+ tpl_block_offsets = static_cast<int*>(alloc_host(size_t(nnz) * sizeof(int)));
68
+ tpl_block_indices = static_cast<int*>(alloc_host(size_t(nnz) * sizeof(int)));
122
69
  }
123
70
 
124
- std::vector<int> block_indices(nnz);
125
- std::iota(block_indices.begin(), block_indices.end(), 0);
126
-
127
- // remove zero blocks and invalid row indices
71
+ std::iota(tpl_block_indices, tpl_block_indices + nnz, 0);
128
72
 
129
- auto discard_block = [&](int i)
73
+ // remove invalid indices / indices not in mask
74
+ auto discard_invalid_block = [&](int i) -> bool
130
75
  {
131
76
  const int row = tpl_rows[i];
132
- if (row < 0 || row >= row_count)
133
- {
134
- return true;
135
- }
136
-
137
- if (prune_numerical_zeros && tpl_values && block_is_zero_func(tpl_values + i * block_size, block_size))
77
+ const int col = tpl_columns[i];
78
+ if (row < 0 || row >= row_count || col < 0 || col >= col_count)
138
79
  {
139
80
  return true;
140
81
  }
141
82
 
142
- if (!masked)
83
+ if (!masked_topology)
143
84
  {
144
85
  return false;
145
86
  }
146
87
 
147
88
  const int* beg = bsr_columns + bsr_offsets[row];
148
89
  const int* end = bsr_columns + bsr_offsets[row + 1];
149
- const int col = tpl_columns[i];
150
90
  const int* block = std::lower_bound(beg, end, col);
151
91
  return block == end || *block != col;
152
92
  };
153
93
 
154
- block_indices.erase(std::remove_if(block_indices.begin(), block_indices.end(), discard_block), block_indices.end());
94
+ int* valid_indices_end = std::remove_if(tpl_block_indices, tpl_block_indices + nnz, discard_invalid_block);
155
95
 
96
+ // remove zero blocks
97
+ if (tpl_values != nullptr && scalar_zero_mask != 0)
98
+ {
99
+ switch (scalar_size_in_bytes)
100
+ {
101
+ case sizeof(uint8_t):
102
+ valid_indices_end = std::remove_if(tpl_block_indices, valid_indices_end, [block_size, tpl_values, scalar_zero_mask](uint32_t i) { return bsr_block_is_zero<uint8_t>(i, block_size, tpl_values, scalar_zero_mask); });
103
+ break;
104
+ case sizeof(uint16_t):
105
+ valid_indices_end = std::remove_if(tpl_block_indices, valid_indices_end, [block_size, tpl_values, scalar_zero_mask](uint32_t i) { return bsr_block_is_zero<uint16_t>(i, block_size, tpl_values, scalar_zero_mask); });
106
+ break;
107
+ case sizeof(uint32_t):
108
+ valid_indices_end = std::remove_if(tpl_block_indices, valid_indices_end, [block_size, tpl_values, scalar_zero_mask](uint32_t i) { return bsr_block_is_zero<uint32_t>(i, block_size, tpl_values, scalar_zero_mask); });
109
+ break;
110
+ case sizeof(uint64_t):
111
+ valid_indices_end = std::remove_if(tpl_block_indices, valid_indices_end, [block_size, tpl_values, scalar_zero_mask](uint32_t i) { return bsr_block_is_zero<uint64_t>(i, block_size, tpl_values, scalar_zero_mask); });
112
+ break;
113
+ }
114
+ }
115
+
156
116
  // sort block indices according to lexico order
157
- std::sort(block_indices.begin(), block_indices.end(), [tpl_rows, tpl_columns](int i, int j) -> bool
117
+ std::sort(tpl_block_indices, valid_indices_end, [tpl_rows, tpl_columns](int i, int j) -> bool
158
118
  { return tpl_rows[i] < tpl_rows[j] || (tpl_rows[i] == tpl_rows[j] && tpl_columns[i] < tpl_columns[j]); });
159
119
 
160
120
  // accumulate blocks at same locations, count blocks per row
@@ -162,107 +122,62 @@ int bsr_matrix_from_triplets_host(const int rows_per_block, const int cols_per_b
162
122
 
163
123
  int current_row = -1;
164
124
  int current_col = -1;
125
+ int current_block_idx = -1;
165
126
 
166
- // so that we get back to the start for the first block
167
- if (bsr_values)
168
- {
169
- bsr_values -= block_size;
170
- }
171
-
172
- for (int i = 0; i < block_indices.size(); ++i)
127
+ for (int *block = tpl_block_indices, *block_offset = tpl_block_offsets ; block != valid_indices_end ; ++ block)
173
128
  {
174
- int idx = block_indices[i];
129
+ int32_t idx = *block;
175
130
  int row = tpl_rows[idx];
176
131
  int col = tpl_columns[idx];
177
- const T* val = tpl_values + idx * block_size;
178
132
 
179
- if (row == current_row && col == current_col)
180
- {
181
- if (bsr_values)
182
- {
183
- block_accumulate_func(val, bsr_values, block_size);
184
- }
185
- }
186
- else
133
+ if (row != current_row || col != current_col)
187
134
  {
188
135
  *(bsr_columns++) = col;
189
136
 
190
- if (bsr_values)
191
- {
192
- bsr_values += block_size;
193
- std::copy_n(val, block_size, bsr_values);
194
- }
137
+ ++bsr_offsets[row + 1];
195
138
 
196
- bsr_offsets[row + 1]++;
139
+ if(current_row == -1) {
140
+ *block_offset = 0;
141
+ } else {
142
+ *(block_offset+1) = *block_offset;
143
+ ++block_offset;
144
+ }
197
145
 
198
146
  current_row = row;
199
147
  current_col = col;
200
148
  }
149
+
150
+ ++(*block_offset);
201
151
  }
202
152
 
203
153
  // build postfix sum of row counts
204
154
  std::partial_sum(bsr_offsets, bsr_offsets + row_count + 1, bsr_offsets);
205
155
 
206
- return bsr_offsets[row_count];
207
- }
208
-
209
- template <typename T>
210
- void bsr_transpose_host(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz_up,
211
- const int* bsr_offsets, const int* bsr_columns, const T* bsr_values,
212
- int* transposed_bsr_offsets, int* transposed_bsr_columns, T* transposed_bsr_values)
213
- {
214
- const int nnz = bsr_offsets[row_count];
215
- const int block_size = rows_per_block * cols_per_block;
156
+ if(!return_summed_blocks)
157
+ {
158
+ // free our temporary buffers
159
+ free_host(tpl_block_offsets);
160
+ free_host(tpl_block_indices);
161
+ }
216
162
 
217
- void (*block_transpose_func)(const T*, T*, int, int) = bsr_dyn_block_transpose<T>;
218
- switch (rows_per_block)
163
+ if (bsr_nnz != nullptr)
219
164
  {
220
- case 1:
221
- switch (cols_per_block)
222
- {
223
- case 1:
224
- block_transpose_func = bsr_fixed_block_transpose<1, 1, T>;
225
- break;
226
- case 2:
227
- block_transpose_func = bsr_fixed_block_transpose<1, 2, T>;
228
- break;
229
- case 3:
230
- block_transpose_func = bsr_fixed_block_transpose<1, 3, T>;
231
- break;
232
- }
233
- break;
234
- case 2:
235
- switch (cols_per_block)
236
- {
237
- case 1:
238
- block_transpose_func = bsr_fixed_block_transpose<2, 1, T>;
239
- break;
240
- case 2:
241
- block_transpose_func = bsr_fixed_block_transpose<2, 2, T>;
242
- break;
243
- case 3:
244
- block_transpose_func = bsr_fixed_block_transpose<2, 3, T>;
245
- break;
246
- }
247
- break;
248
- case 3:
249
- switch (cols_per_block)
250
- {
251
- case 1:
252
- block_transpose_func = bsr_fixed_block_transpose<3, 1, T>;
253
- break;
254
- case 2:
255
- block_transpose_func = bsr_fixed_block_transpose<3, 2, T>;
256
- break;
257
- case 3:
258
- block_transpose_func = bsr_fixed_block_transpose<3, 3, T>;
259
- break;
260
- }
261
- break;
165
+ *bsr_nnz = bsr_offsets[row_count];
262
166
  }
167
+ }
263
168
 
264
- std::vector<int> block_indices(nnz), bsr_rows(nnz);
265
- std::iota(block_indices.begin(), block_indices.end(), 0);
169
+ WP_API void bsr_transpose_host(
170
+ int row_count, int col_count, int nnz,
171
+ const int* bsr_offsets, const int* bsr_columns,
172
+ int* transposed_bsr_offsets,
173
+ int* transposed_bsr_columns,
174
+ int* block_indices
175
+ )
176
+ {
177
+ nnz = bsr_offsets[row_count];
178
+
179
+ std::vector<int> bsr_rows(nnz);
180
+ std::iota(block_indices, block_indices + nnz, 0);
266
181
 
267
182
  // Fill row indices from offsets
268
183
  for (int row = 0; row < row_count; ++row)
@@ -272,7 +187,7 @@ void bsr_transpose_host(int rows_per_block, int cols_per_block, int row_count, i
272
187
 
273
188
  // sort block indices according to (transposed) lexico order
274
189
  std::sort(
275
- block_indices.begin(), block_indices.end(), [&bsr_rows, bsr_columns](int i, int j) -> bool
190
+ block_indices, block_indices + nnz, [&bsr_rows, bsr_columns](int i, int j) -> bool
276
191
  { return bsr_columns[i] < bsr_columns[j] || (bsr_columns[i] == bsr_columns[j] && bsr_rows[i] < bsr_rows[j]); });
277
192
 
278
193
  // Count blocks per column and transpose blocks
@@ -286,93 +201,41 @@ void bsr_transpose_host(int rows_per_block, int cols_per_block, int row_count, i
286
201
 
287
202
  ++transposed_bsr_offsets[col + 1];
288
203
  transposed_bsr_columns[i] = row;
289
-
290
- if (transposed_bsr_values != nullptr)
291
- {
292
- const T* src_block = bsr_values + idx * block_size;
293
- T* dst_block = transposed_bsr_values + i * block_size;
294
- block_transpose_func(src_block, dst_block, rows_per_block, cols_per_block);
295
- }
296
204
  }
297
205
 
298
206
  // build postfix sum of column counts
299
207
  std::partial_sum(transposed_bsr_offsets, transposed_bsr_offsets + col_count + 1, transposed_bsr_offsets);
300
- }
301
-
302
- WP_API void bsr_matrix_from_triplets_float_host(int rows_per_block, int cols_per_block, int row_count, int nnz,
303
- int* tpl_rows, int* tpl_columns, void* tpl_values,
304
- bool prune_numerical_zeros, bool masked, int* bsr_offsets,
305
- int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
306
- {
307
- bsr_matrix_from_triplets_host<float>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
308
- static_cast<const float*>(tpl_values), prune_numerical_zeros, masked,
309
- bsr_offsets, bsr_columns, static_cast<float*>(bsr_values));
310
- if (bsr_nnz)
311
- {
312
- *bsr_nnz = bsr_offsets[row_count];
313
- }
314
- }
315
-
316
- WP_API void bsr_matrix_from_triplets_double_host(int rows_per_block, int cols_per_block, int row_count, int nnz,
317
- int* tpl_rows, int* tpl_columns, void* tpl_values,
318
- bool prune_numerical_zeros, bool masked, int* bsr_offsets,
319
- int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
320
- {
321
- bsr_matrix_from_triplets_host<double>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
322
- static_cast<const double*>(tpl_values), prune_numerical_zeros, masked,
323
- bsr_offsets, bsr_columns, static_cast<double*>(bsr_values));
324
- if (bsr_nnz)
325
- {
326
- *bsr_nnz = bsr_offsets[row_count];
327
- }
328
- }
329
208
 
330
- WP_API void bsr_transpose_float_host(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
331
- int* bsr_offsets, int* bsr_columns, void* bsr_values, int* transposed_bsr_offsets,
332
- int* transposed_bsr_columns, void* transposed_bsr_values)
333
- {
334
- bsr_transpose_host(rows_per_block, cols_per_block, row_count, col_count, nnz, bsr_offsets, bsr_columns,
335
- static_cast<const float*>(bsr_values), transposed_bsr_offsets, transposed_bsr_columns,
336
- static_cast<float*>(transposed_bsr_values));
337
- }
338
-
339
- WP_API void bsr_transpose_double_host(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
340
- int* bsr_offsets, int* bsr_columns, void* bsr_values, int* transposed_bsr_offsets,
341
- int* transposed_bsr_columns, void* transposed_bsr_values)
342
- {
343
- bsr_transpose_host(rows_per_block, cols_per_block, row_count, col_count, nnz, bsr_offsets, bsr_columns,
344
- static_cast<const double*>(bsr_values), transposed_bsr_offsets, transposed_bsr_columns,
345
- static_cast<double*>(transposed_bsr_values));
346
209
  }
347
210
 
348
211
  #if !WP_ENABLE_CUDA
349
- WP_API void bsr_matrix_from_triplets_float_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
350
- int* tpl_rows, int* tpl_columns, void* tpl_values,
351
- bool prune_numerical_zeros, bool masked, int* bsr_offsets,
352
- int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
353
- {
354
- }
212
+ WP_API void bsr_matrix_from_triplets_device(
213
+ int block_size,
214
+ int scalar_size_in_bytes,
215
+ int row_count,
216
+ int col_count,
217
+ int tpl_nnz_upper_bound,
218
+ const int* tpl_nnz,
219
+ const int* tpl_rows,
220
+ const int* tpl_columns,
221
+ const void* tpl_values,
222
+ const uint64_t scalar_zero_mask,
223
+ bool masked_topology,
224
+ int* summed_block_offsets,
225
+ int* summed_block_indices,
226
+ int* bsr_offsets,
227
+ int* bsr_columns,
228
+ int* bsr_nnz,
229
+ void* bsr_nnz_event) {}
230
+
231
+
232
+ WP_API void bsr_transpose_device(
233
+ int row_count, int col_count, int nnz,
234
+ const int* bsr_offsets, const int* bsr_columns,
235
+ int* transposed_bsr_offsets,
236
+ int* transposed_bsr_columns,
237
+ int* src_block_indices) {}
355
238
 
356
- WP_API void bsr_matrix_from_triplets_double_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
357
- int* tpl_rows, int* tpl_columns, void* tpl_values,
358
- bool prune_numerical_zeros, bool masked, int* bsr_offsets,
359
- int* bsr_columns, void* bsr_values, int* bsr_nnz,
360
- void* bsr_nnz_event)
361
- {
362
- }
363
239
 
364
- WP_API void bsr_transpose_float_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
365
- int* bsr_offsets, int* bsr_columns, void* bsr_values,
366
- int* transposed_bsr_offsets, int* transposed_bsr_columns,
367
- void* transposed_bsr_values)
368
- {
369
- }
370
-
371
- WP_API void bsr_transpose_double_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
372
- int* bsr_offsets, int* bsr_columns, void* bsr_values,
373
- int* transposed_bsr_offsets, int* transposed_bsr_columns,
374
- void* transposed_bsr_values)
375
- {
376
- }
377
240
 
378
241
  #endif