warp-lang 1.2.2-py3-none-manylinux2014_x86_64.whl → 1.3.1-py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (193)
  1. warp/__init__.py +8 -6
  2. warp/autograd.py +823 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +6 -2
  5. warp/builtins.py +1412 -888
  6. warp/codegen.py +503 -166
  7. warp/config.py +48 -18
  8. warp/context.py +400 -198
  9. warp/dlpack.py +8 -0
  10. warp/examples/assets/bunny.usd +0 -0
  11. warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
  12. warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
  13. warp/examples/benchmarks/benchmark_launches.py +1 -1
  14. warp/examples/core/example_cupy.py +78 -0
  15. warp/examples/fem/example_apic_fluid.py +17 -36
  16. warp/examples/fem/example_burgers.py +9 -18
  17. warp/examples/fem/example_convection_diffusion.py +7 -17
  18. warp/examples/fem/example_convection_diffusion_dg.py +27 -47
  19. warp/examples/fem/example_deformed_geometry.py +11 -22
  20. warp/examples/fem/example_diffusion.py +7 -18
  21. warp/examples/fem/example_diffusion_3d.py +24 -28
  22. warp/examples/fem/example_diffusion_mgpu.py +7 -14
  23. warp/examples/fem/example_magnetostatics.py +190 -0
  24. warp/examples/fem/example_mixed_elasticity.py +111 -80
  25. warp/examples/fem/example_navier_stokes.py +30 -34
  26. warp/examples/fem/example_nonconforming_contact.py +290 -0
  27. warp/examples/fem/example_stokes.py +17 -32
  28. warp/examples/fem/example_stokes_transfer.py +12 -21
  29. warp/examples/fem/example_streamlines.py +350 -0
  30. warp/examples/fem/utils.py +936 -0
  31. warp/fabric.py +5 -2
  32. warp/fem/__init__.py +13 -3
  33. warp/fem/cache.py +161 -11
  34. warp/fem/dirichlet.py +37 -28
  35. warp/fem/domain.py +105 -14
  36. warp/fem/field/__init__.py +14 -3
  37. warp/fem/field/field.py +454 -11
  38. warp/fem/field/nodal_field.py +33 -18
  39. warp/fem/geometry/deformed_geometry.py +50 -15
  40. warp/fem/geometry/hexmesh.py +12 -24
  41. warp/fem/geometry/nanogrid.py +106 -31
  42. warp/fem/geometry/quadmesh_2d.py +6 -11
  43. warp/fem/geometry/tetmesh.py +103 -61
  44. warp/fem/geometry/trimesh_2d.py +98 -47
  45. warp/fem/integrate.py +231 -186
  46. warp/fem/operator.py +14 -9
  47. warp/fem/quadrature/pic_quadrature.py +35 -9
  48. warp/fem/quadrature/quadrature.py +119 -32
  49. warp/fem/space/basis_space.py +98 -22
  50. warp/fem/space/collocated_function_space.py +3 -1
  51. warp/fem/space/function_space.py +7 -2
  52. warp/fem/space/grid_2d_function_space.py +3 -3
  53. warp/fem/space/grid_3d_function_space.py +4 -4
  54. warp/fem/space/hexmesh_function_space.py +3 -2
  55. warp/fem/space/nanogrid_function_space.py +12 -14
  56. warp/fem/space/partition.py +45 -47
  57. warp/fem/space/restriction.py +19 -16
  58. warp/fem/space/shape/cube_shape_function.py +91 -3
  59. warp/fem/space/shape/shape_function.py +7 -0
  60. warp/fem/space/shape/square_shape_function.py +32 -0
  61. warp/fem/space/shape/tet_shape_function.py +11 -7
  62. warp/fem/space/shape/triangle_shape_function.py +10 -1
  63. warp/fem/space/topology.py +116 -42
  64. warp/fem/types.py +8 -1
  65. warp/fem/utils.py +301 -83
  66. warp/native/array.h +16 -0
  67. warp/native/builtin.h +0 -15
  68. warp/native/cuda_util.cpp +14 -6
  69. warp/native/exports.h +1348 -1308
  70. warp/native/quat.h +79 -0
  71. warp/native/rand.h +27 -4
  72. warp/native/sparse.cpp +83 -81
  73. warp/native/sparse.cu +381 -453
  74. warp/native/vec.h +64 -0
  75. warp/native/volume.cpp +40 -49
  76. warp/native/volume_builder.cu +2 -3
  77. warp/native/volume_builder.h +12 -17
  78. warp/native/warp.cu +3 -3
  79. warp/native/warp.h +69 -59
  80. warp/render/render_opengl.py +17 -9
  81. warp/sim/articulation.py +117 -17
  82. warp/sim/collide.py +35 -29
  83. warp/sim/model.py +123 -18
  84. warp/sim/render.py +3 -1
  85. warp/sparse.py +867 -203
  86. warp/stubs.py +312 -541
  87. warp/tape.py +29 -1
  88. warp/tests/disabled_kinematics.py +1 -1
  89. warp/tests/test_adam.py +1 -1
  90. warp/tests/test_arithmetic.py +1 -1
  91. warp/tests/test_array.py +58 -1
  92. warp/tests/test_array_reduce.py +1 -1
  93. warp/tests/test_async.py +1 -1
  94. warp/tests/test_atomic.py +1 -1
  95. warp/tests/test_bool.py +1 -1
  96. warp/tests/test_builtins_resolution.py +1 -1
  97. warp/tests/test_bvh.py +6 -1
  98. warp/tests/test_closest_point_edge_edge.py +1 -1
  99. warp/tests/test_codegen.py +91 -1
  100. warp/tests/test_compile_consts.py +1 -1
  101. warp/tests/test_conditional.py +1 -1
  102. warp/tests/test_copy.py +1 -1
  103. warp/tests/test_ctypes.py +1 -1
  104. warp/tests/test_dense.py +1 -1
  105. warp/tests/test_devices.py +1 -1
  106. warp/tests/test_dlpack.py +1 -1
  107. warp/tests/test_examples.py +33 -4
  108. warp/tests/test_fabricarray.py +5 -2
  109. warp/tests/test_fast_math.py +1 -1
  110. warp/tests/test_fem.py +213 -6
  111. warp/tests/test_fp16.py +1 -1
  112. warp/tests/test_func.py +1 -1
  113. warp/tests/test_future_annotations.py +90 -0
  114. warp/tests/test_generics.py +1 -1
  115. warp/tests/test_grad.py +1 -1
  116. warp/tests/test_grad_customs.py +1 -1
  117. warp/tests/test_grad_debug.py +247 -0
  118. warp/tests/test_hash_grid.py +6 -1
  119. warp/tests/test_implicit_init.py +354 -0
  120. warp/tests/test_import.py +1 -1
  121. warp/tests/test_indexedarray.py +1 -1
  122. warp/tests/test_intersect.py +1 -1
  123. warp/tests/test_jax.py +1 -1
  124. warp/tests/test_large.py +1 -1
  125. warp/tests/test_launch.py +1 -1
  126. warp/tests/test_lerp.py +1 -1
  127. warp/tests/test_linear_solvers.py +1 -1
  128. warp/tests/test_lvalue.py +1 -1
  129. warp/tests/test_marching_cubes.py +5 -2
  130. warp/tests/test_mat.py +34 -35
  131. warp/tests/test_mat_lite.py +2 -1
  132. warp/tests/test_mat_scalar_ops.py +1 -1
  133. warp/tests/test_math.py +1 -1
  134. warp/tests/test_matmul.py +20 -16
  135. warp/tests/test_matmul_lite.py +1 -1
  136. warp/tests/test_mempool.py +1 -1
  137. warp/tests/test_mesh.py +5 -2
  138. warp/tests/test_mesh_query_aabb.py +1 -1
  139. warp/tests/test_mesh_query_point.py +1 -1
  140. warp/tests/test_mesh_query_ray.py +1 -1
  141. warp/tests/test_mlp.py +1 -1
  142. warp/tests/test_model.py +1 -1
  143. warp/tests/test_module_hashing.py +77 -1
  144. warp/tests/test_modules_lite.py +1 -1
  145. warp/tests/test_multigpu.py +1 -1
  146. warp/tests/test_noise.py +1 -1
  147. warp/tests/test_operators.py +1 -1
  148. warp/tests/test_options.py +1 -1
  149. warp/tests/test_overwrite.py +542 -0
  150. warp/tests/test_peer.py +1 -1
  151. warp/tests/test_pinned.py +1 -1
  152. warp/tests/test_print.py +1 -1
  153. warp/tests/test_quat.py +15 -1
  154. warp/tests/test_rand.py +1 -1
  155. warp/tests/test_reload.py +1 -1
  156. warp/tests/test_rounding.py +1 -1
  157. warp/tests/test_runlength_encode.py +1 -1
  158. warp/tests/test_scalar_ops.py +95 -0
  159. warp/tests/test_sim_grad.py +1 -1
  160. warp/tests/test_sim_kinematics.py +1 -1
  161. warp/tests/test_smoothstep.py +1 -1
  162. warp/tests/test_sparse.py +82 -15
  163. warp/tests/test_spatial.py +1 -1
  164. warp/tests/test_special_values.py +2 -11
  165. warp/tests/test_streams.py +11 -1
  166. warp/tests/test_struct.py +1 -1
  167. warp/tests/test_tape.py +1 -1
  168. warp/tests/test_torch.py +194 -1
  169. warp/tests/test_transient_module.py +1 -1
  170. warp/tests/test_types.py +1 -1
  171. warp/tests/test_utils.py +1 -1
  172. warp/tests/test_vec.py +15 -63
  173. warp/tests/test_vec_lite.py +2 -1
  174. warp/tests/test_vec_scalar_ops.py +65 -1
  175. warp/tests/test_verify_fp.py +1 -1
  176. warp/tests/test_volume.py +28 -2
  177. warp/tests/test_volume_write.py +1 -1
  178. warp/tests/unittest_serial.py +1 -1
  179. warp/tests/unittest_suites.py +9 -1
  180. warp/tests/walkthrough_debug.py +1 -1
  181. warp/thirdparty/unittest_parallel.py +2 -5
  182. warp/torch.py +103 -41
  183. warp/types.py +341 -224
  184. warp/utils.py +11 -2
  185. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/METADATA +99 -46
  186. warp_lang-1.3.1.dist-info/RECORD +368 -0
  187. warp/examples/fem/bsr_utils.py +0 -378
  188. warp/examples/fem/mesh_utils.py +0 -133
  189. warp/examples/fem/plot_utils.py +0 -292
  190. warp_lang-1.2.2.dist-info/RECORD +0 -359
  191. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/LICENSE.md +0 -0
  192. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/WHEEL +0 -0
  193. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/top_level.txt +0 -0
warp/native/quat.h CHANGED
@@ -36,6 +36,42 @@ struct quat_t
 
     // real part
     Type w;
+
+    inline CUDA_CALLABLE Type operator[](int index) const
+    {
+        switch (index)
+        {
+            case 0:
+                return x;
+            case 1:
+                return y;
+            case 2:
+                return z;
+            case 3:
+                return w;
+            default:
+                assert(0);
+                return x;
+        }
+    }
+
+    inline CUDA_CALLABLE Type& operator[](int index)
+    {
+        switch (index)
+        {
+            case 0:
+                return x;
+            case 1:
+                return y;
+            case 2:
+                return z;
+            case 3:
+                return w;
+            default:
+                assert(0);
+                return x;
+        }
+    }
 };
 
 using quat = quat_t<float>;
@@ -400,6 +436,49 @@ inline CUDA_CALLABLE Type extract(const quat_t<Type>& a, int idx)
     else {return a.w;}
 }
 
+template<typename Type>
+inline CUDA_CALLABLE Type* index(quat_t<Type>& q, int idx)
+{
+#ifndef NDEBUG
+    if (idx < 0 || idx > 3)
+    {
+        printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
+        assert(0);
+    }
+#endif
+
+    return &q[idx];
+}
+
+template<typename Type>
+inline CUDA_CALLABLE Type* indexref(quat_t<Type>* q, int idx)
+{
+#ifndef NDEBUG
+    if (idx < 0 || idx > 3)
+    {
+        printf("quat store %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
+        assert(0);
+    }
+#endif
+
+    return &((*q)[idx]);
+}
+
+template<typename Type>
+inline CUDA_CALLABLE void adj_index(quat_t<Type>& q, int idx,
+                                    quat_t<Type>& adj_q, int adj_idx, const Type& adj_value)
+{
+    // nop
+}
+
+
+template<typename Type>
+inline CUDA_CALLABLE void adj_indexref(quat_t<Type>* q, int idx,
+                                       quat_t<Type>& adj_q, int adj_idx, const Type& adj_value)
+{
+    // nop
+}
+
 template<typename Type>
 CUDA_CALLABLE inline quat_t<Type> lerp(const quat_t<Type>& a, const quat_t<Type>& b, Type t)
 {
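The two hunks above add per-component indexing to quat_t: operator[] for reads and writes, plus the index/indexref helpers used by generated kernel code, with no-op adjoints. The snippet below is a standalone sketch, not part of the diff or of Warp's sources; it mirrors the added operator[] on a minimal stand-in struct to show the read and write paths.

    // Standalone sketch: a minimal struct that mirrors the operator[] added to quat_t above.
    #include <cassert>
    #include <cstdio>

    struct quatf
    {
        float x, y, z, w; // imaginary parts x, y, z and real part w

        float operator[](int index) const
        {
            switch (index)
            {
                case 0: return x;
                case 1: return y;
                case 2: return z;
                case 3: return w;
                default: assert(0); return x;
            }
        }

        float& operator[](int index)
        {
            switch (index)
            {
                case 0: return x;
                case 1: return y;
                case 2: return z;
                case 3: return w;
                default: assert(0); return x;
            }
        }
    };

    int main()
    {
        quatf q{0.f, 0.f, 0.f, 1.f}; // identity rotation
        q[2] = 0.5f;                 // write the z component through the indexer
        printf("q = (%f, %f, %f, %f)\n", q[0], q[1], q[2], q[3]);
        return 0;
    }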
warp/native/rand.h CHANGED
@@ -13,13 +13,24 @@
 #define M_PI_F 3.14159265358979323846f
 #endif
 
-#ifndef LOG_EPSILON
-#define LOG_EPSILON 5.96e-8f
+/*
+ * Please first read the randf comment. randf returns values uniformly distributed in the range [0.f, 1.f - 2^-24] in equal intervals of size 2^-24.
+ * randn computes sqrt(-2.f * log(x)). For this to return a real value, log(x) < 0.f (we exclude 0.f as a precaution) and therefore x < 1.f.
+ * For it to be finite, x > 0.f. So x must be in (0.f, 1.f). We define RANDN_EPSILON to be 2^-24 truncated to 5.96e-8f and add it to the range of randf,
+ * giving the domain [RANDN_EPSILON, 1.f - 2^-24 + RANDN_EPSILON], which satisfies the requirement that x is in (0.f, 1.f).
+ */
+
+#ifndef RANDN_EPSILON
+#define RANDN_EPSILON 5.96e-8f
 #endif
 
 namespace wp
 {
 
+/*
+ * Mark Jarzynski and Marc Olano, Hash Functions for GPU Rendering, Journal of Computer
+ * Graphics Techniques (JCGT), vol. 9, no. 3, 20–38, 2020
+ */
 inline CUDA_CALLABLE uint32 rand_pcg(uint32 state)
 {
     uint32 b = state * 747796405u + 2891336453u;
@@ -33,11 +44,20 @@ inline CUDA_CALLABLE uint32 rand_init(int seed, int offset) { return rand_pcg(ui
 inline CUDA_CALLABLE int randi(uint32& state) { state = rand_pcg(state); return int(state); }
 inline CUDA_CALLABLE int randi(uint32& state, int min, int max) { state = rand_pcg(state); return state % (max - min) + min; }
 
+/*
+ * We want to ensure randf adheres to a uniform distribution over [0, 1). The set of all possible float32 (IEEE 754 standard) values is not uniformly distributed however.
+ * On the other hand, for a given sign and exponent, the mantissa of the float32 representation is uniformly distributed.
+ * Fixing an exponent of -1, we can craft a uniform distribution using the sign bit and 23-bit mantissa that spans the domain [0, 1) in 2^24 equal intervals.
+ * We can map 2^24 unique unsigned integers to these 2^24 intervals, so if our random number generator returns values in the range [0, 2^24) without bias,
+ * we can ensure that our float distribution in the range [0, 1) is also without bias.
+ * Our random number generator returns values in the range [0, 2^32), so we bit shift a random unsigned int 8 places, and then make the assumption that the remaining bit strings
+ * are uniformly distributed. After dividing by 2^24, randf returns values uniformly distributed in the range [0.f, 1.f - 2^-24].
+ */
 inline CUDA_CALLABLE float randf(uint32& state) { state = rand_pcg(state); return (state >> 8) * (1.0f / 16777216.0f); }
 inline CUDA_CALLABLE float randf(uint32& state, float min, float max) { return (max - min) * randf(state) + min; }
 
 // Box-Muller method
-inline CUDA_CALLABLE float randn(uint32& state) { return sqrt(-2.f * log(randf(state) + LOG_EPSILON)) * cos(2.f * M_PI_F * randf(state)); }
+inline CUDA_CALLABLE float randn(uint32& state) { return sqrt(-2.f * log(randf(state) + RANDN_EPSILON)) * cos(2.f * M_PI_F * randf(state)); }
 
 inline CUDA_CALLABLE void adj_rand_init(int seed, int& adj_seed, float adj_ret) {}
 inline CUDA_CALLABLE void adj_rand_init(int seed, int offset, int& adj_seed, int& adj_offset, float adj_ret) {}
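The new comments above explain why randf maps the top 24 bits of the PCG state to [0.f, 1.f - 2^-24], and why adding RANDN_EPSILON keeps the Box-Muller log() argument strictly inside (0, 1). The short standalone check below (not part of the diff) works through that arithmetic with the same constants.

    // Standalone sketch: bounds of randf and of the randn log() argument, per the comments above.
    #include <cstdint>
    #include <cstdio>

    int main()
    {
        const float inv_2_24       = 1.0f / 16777216.0f; // 2^-24, the divisor used by randf
        const float randn_epsilon  = 5.96e-8f;           // RANDN_EPSILON (2^-24 truncated)

        const uint32_t k_min = 0u;                       // smallest 24-bit value after state >> 8
        const uint32_t k_max = (1u << 24) - 1u;          // largest 24-bit value

        const float randf_min = k_min * inv_2_24;        // 0.f
        const float randf_max = k_max * inv_2_24;        // 1.f - 2^-24

        // Domain passed to log() inside randn: strictly positive and strictly below 1.
        const float log_arg_min = randf_min + randn_epsilon;
        const float log_arg_max = randf_max + randn_epsilon;

        printf("randf range:             [%.9g, %.9g]\n", randf_min, randf_max);
        printf("randn log() argument in: [%.9g, %.9g]\n", log_arg_min, log_arg_max);
        return 0;
    }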
@@ -56,6 +76,9 @@ inline CUDA_CALLABLE int sample_cdf(uint32& state, const array_t<float>& cdf)
     return lower_bound<float>(cdf, u);
 }
 
+/*
+ * uniform sampling methods for various geometries
+ */
 inline CUDA_CALLABLE vec2 sample_triangle(uint32& state)
 {
     float r = sqrt(randf(state));
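The context line float r = sqrt(randf(state)) is the usual square-root warp of a uniform variate onto a triangle. The standalone sketch below is not Warp code (it uses <random> for brevity, and the exact barycentric convention inside sample_triangle may differ); it shows one common form of that mapping over the reference triangle (0,0), (1,0), (0,1).

    // Standalone sketch: sqrt-based uniform sampling of a triangle.
    #include <cmath>
    #include <cstdio>
    #include <random>

    int main()
    {
        std::mt19937 gen(42);
        std::uniform_real_distribution<float> uni(0.0f, 1.0f);

        for (int i = 0; i < 5; ++i)
        {
            const float r = std::sqrt(uni(gen));
            const float u = 1.0f - r;      // barycentric weight of A
            const float v = r * uni(gen);  // barycentric weight of B
            // P = u*A + v*B + (1-u-v)*C with A=(0,0), B=(1,0), C=(0,1)
            const float px = v;
            const float py = 1.0f - u - v;
            printf("sample %d: (%f, %f)\n", i, px, py);
        }
        return 0;
    }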
@@ -301,4 +324,4 @@ inline CUDA_CALLABLE void random_poisson_mult(uint32& state, float lam, uint32&
 inline CUDA_CALLABLE void adj_random_poisson(uint32& state, float lam, uint32& adj_state, float& adj_lam, const uint32& adj_ret) {}
 inline CUDA_CALLABLE void adj_poisson(uint32& state, float lam, uint32& adj_state, float& adj_lam, const uint32& adj_ret) {}
 
-} // namespace wp
+} // namespace wp
warp/native/sparse.cpp CHANGED
@@ -10,17 +10,17 @@ namespace
 // Specialized is_zero and accumulation function for common block sizes
 // Rely on compiler to unroll loops when block size is known
 
-template <int N, typename T> bool bsr_fixed_block_is_zero(const T *val, int value_size)
+template <int N, typename T> bool bsr_fixed_block_is_zero(const T* val, int value_size)
 {
     return std::all_of(val, val + N, [](float v) { return v == T(0); });
 }
 
-template <typename T> bool bsr_dyn_block_is_zero(const T *val, int value_size)
+template <typename T> bool bsr_dyn_block_is_zero(const T* val, int value_size)
 {
     return std::all_of(val, val + value_size, [](float v) { return v == T(0); });
 }
 
-template <int N, typename T> void bsr_fixed_block_accumulate(const T *val, T *sum, int value_size)
+template <int N, typename T> void bsr_fixed_block_accumulate(const T* val, T* sum, int value_size)
 {
     for (int i = 0; i < N; ++i, ++val, ++sum)
     {
@@ -28,7 +28,7 @@ template <int N, typename T> void bsr_fixed_block_accumulate(const T *val, T *su
     }
 }
 
-template <typename T> void bsr_dyn_block_accumulate(const T *val, T *sum, int value_size)
+template <typename T> void bsr_dyn_block_accumulate(const T* val, T* sum, int value_size)
 {
     for (int i = 0; i < value_size; ++i, ++val, ++sum)
     {
@@ -37,7 +37,7 @@ template <typename T> void bsr_dyn_block_accumulate(const T *val, T *sum, int va
 }
 
 template <int Rows, int Cols, typename T>
-void bsr_fixed_block_transpose(const T *src, T *dest, int row_count, int col_count)
+void bsr_fixed_block_transpose(const T* src, T* dest, int row_count, int col_count)
 {
     for (int r = 0; r < Rows; ++r)
     {
@@ -48,7 +48,7 @@ void bsr_fixed_block_transpose(const T *src, T *dest, int row_count, int col_cou
     }
 }
 
-template <typename T> void bsr_dyn_block_transpose(const T *src, T *dest, int row_count, int col_count)
+template <typename T> void bsr_dyn_block_transpose(const T* src, T* dest, int row_count, int col_count)
 {
     for (int r = 0; r < row_count; ++r)
     {
@@ -63,15 +63,15 @@ template <typename T> void bsr_dyn_block_transpose(const T *src, T *dest, int ro
 
 template <typename T>
 int bsr_matrix_from_triplets_host(const int rows_per_block, const int cols_per_block, const int row_count,
-                                  const int nnz, const int *tpl_rows, const int *tpl_columns, const T *tpl_values,
-                                  int *bsr_offsets, int *bsr_columns, T *bsr_values)
+                                  const int nnz, const int* tpl_rows, const int* tpl_columns, const T* tpl_values,
+                                  const bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns, T* bsr_values)
 {
 
     // get specialized accumulator for common block sizes (1,1), (1,2), (1,3),
     // (2,2), (2,3), (3,3)
     const int block_size = rows_per_block * cols_per_block;
-    void (*block_accumulate_func)(const T *, T *, int);
-    bool (*block_is_zero_func)(const T *, int);
+    void (*block_accumulate_func)(const T*, T*, int);
+    bool (*block_is_zero_func)(const T*, int);
     switch (block_size)
     {
     case 1:
@@ -106,20 +106,19 @@ int bsr_matrix_from_triplets_host(const int rows_per_block, const int cols_per_b
     std::vector<int> block_indices(nnz);
     std::iota(block_indices.begin(), block_indices.end(), 0);
 
-    // remove zero block indices
-    if (tpl_values)
-    {
-        block_indices.erase(std::remove_if(block_indices.begin(), block_indices.end(),
-                                           [block_is_zero_func, tpl_values, block_size](int i) {
-                                               return block_is_zero_func(tpl_values + i * block_size, block_size);
-                                           }),
-                            block_indices.end());
-    }
+    // remove zero blocks and invalid row indices
+    block_indices.erase(std::remove_if(block_indices.begin(), block_indices.end(),
+                                       [&](int i)
+                                       {
+                                           return tpl_rows[i] < 0 || tpl_rows[i] >= row_count ||
+                                                  (prune_numerical_zeros && tpl_values &&
+                                                   block_is_zero_func(tpl_values + i * block_size, block_size));
+                                       }),
+                        block_indices.end());
 
     // sort block indices according to lexico order
-    std::sort(block_indices.begin(), block_indices.end(), [tpl_rows, tpl_columns](int i, int j) -> bool {
-        return tpl_rows[i] < tpl_rows[j] || (tpl_rows[i] == tpl_rows[j] && tpl_columns[i] < tpl_columns[j]);
-    });
+    std::sort(block_indices.begin(), block_indices.end(), [tpl_rows, tpl_columns](int i, int j) -> bool
+              { return tpl_rows[i] < tpl_rows[j] || (tpl_rows[i] == tpl_rows[j] && tpl_columns[i] < tpl_columns[j]); });
 
     // accumulate blocks at same locations, count blocks per row
     std::fill_n(bsr_offsets, row_count + 1, 0);
@@ -138,7 +137,7 @@ int bsr_matrix_from_triplets_host(const int rows_per_block, const int cols_per_b
         int idx = block_indices[i];
         int row = tpl_rows[idx];
         int col = tpl_columns[idx];
-        const T *val = tpl_values + idx * block_size;
+        const T* val = tpl_values + idx * block_size;
 
         if (row == current_row && col == current_col)
         {
@@ -171,14 +170,14 @@ int bsr_matrix_from_triplets_host(const int rows_per_block, const int cols_per_b
 }
 
 template <typename T>
-void bsr_transpose_host(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
-                        const int *bsr_offsets, const int *bsr_columns, const T *bsr_values,
-                        int *transposed_bsr_offsets, int *transposed_bsr_columns, T *transposed_bsr_values)
+void bsr_transpose_host(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz_up,
+                        const int* bsr_offsets, const int* bsr_columns, const T* bsr_values,
+                        int* transposed_bsr_offsets, int* transposed_bsr_columns, T* transposed_bsr_values)
 {
-
+    const int nnz = bsr_offsets[row_count];
     const int block_size = rows_per_block * cols_per_block;
 
-    void (*block_transpose_func)(const T *, T *, int, int) = bsr_dyn_block_transpose<T>;
+    void (*block_transpose_func)(const T*, T*, int, int) = bsr_dyn_block_transpose<T>;
     switch (rows_per_block)
     {
     case 1:
@@ -235,9 +234,9 @@ void bsr_transpose_host(int rows_per_block, int cols_per_block, int row_count, i
     }
 
     // sort block indices according to (transposed) lexico order
-    std::sort(block_indices.begin(), block_indices.end(), [&bsr_rows, bsr_columns](int i, int j) -> bool {
-        return bsr_columns[i] < bsr_columns[j] || (bsr_columns[i] == bsr_columns[j] && bsr_rows[i] < bsr_rows[j]);
-    });
+    std::sort(
+        block_indices.begin(), block_indices.end(), [&bsr_rows, bsr_columns](int i, int j) -> bool
+        { return bsr_columns[i] < bsr_columns[j] || (bsr_columns[i] == bsr_columns[j] && bsr_rows[i] < bsr_rows[j]); });
 
     // Count blocks per column and transpose blocks
     std::fill_n(transposed_bsr_offsets, col_count + 1, 0);
@@ -251,88 +250,91 @@ void bsr_transpose_host(int rows_per_block, int cols_per_block, int row_count, i
         ++transposed_bsr_offsets[col + 1];
         transposed_bsr_columns[i] = row;
 
-        const T *src_block = bsr_values + idx * block_size;
-        T *dst_block = transposed_bsr_values + i * block_size;
-        block_transpose_func(src_block, dst_block, rows_per_block, cols_per_block);
+        if (transposed_bsr_values != nullptr)
+        {
+            const T* src_block = bsr_values + idx * block_size;
+            T* dst_block = transposed_bsr_values + i * block_size;
+            block_transpose_func(src_block, dst_block, rows_per_block, cols_per_block);
+        }
     }
 
     // build postfix sum of column counts
     std::partial_sum(transposed_bsr_offsets, transposed_bsr_offsets + col_count + 1, transposed_bsr_offsets);
 }
 
-WP_API int bsr_matrix_from_triplets_float_host(int rows_per_block, int cols_per_block, int row_count, int nnz,
-                                               uint64_t tpl_rows, uint64_t tpl_columns, uint64_t tpl_values,
-                                               uint64_t bsr_offsets, uint64_t bsr_columns, uint64_t bsr_values)
+WP_API void bsr_matrix_from_triplets_float_host(int rows_per_block, int cols_per_block, int row_count, int nnz,
+                                                int* tpl_rows, int* tpl_columns, void* tpl_values,
+                                                bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
+                                                void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
-    return bsr_matrix_from_triplets_host(
-        rows_per_block, cols_per_block, row_count, nnz, reinterpret_cast<const int *>(tpl_rows),
-        reinterpret_cast<const int *>(tpl_columns), reinterpret_cast<const float *>(tpl_values),
-        reinterpret_cast<int *>(bsr_offsets), reinterpret_cast<int *>(bsr_columns),
-        reinterpret_cast<float *>(bsr_values));
+    bsr_matrix_from_triplets_host<float>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
+                                         static_cast<const float*>(tpl_values), prune_numerical_zeros, bsr_offsets,
+                                         bsr_columns, static_cast<float*>(bsr_values));
+    if (bsr_nnz)
+    {
+        *bsr_nnz = bsr_offsets[row_count];
+    }
 }
 
-WP_API int bsr_matrix_from_triplets_double_host(int rows_per_block, int cols_per_block, int row_count, int nnz,
-                                                uint64_t tpl_rows, uint64_t tpl_columns, uint64_t tpl_values,
-                                                uint64_t bsr_offsets, uint64_t bsr_columns, uint64_t bsr_values)
+WP_API void bsr_matrix_from_triplets_double_host(int rows_per_block, int cols_per_block, int row_count, int nnz,
+                                                 int* tpl_rows, int* tpl_columns, void* tpl_values,
+                                                 bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
+                                                 void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
-    return bsr_matrix_from_triplets_host(
-        rows_per_block, cols_per_block, row_count, nnz, reinterpret_cast<const int *>(tpl_rows),
-        reinterpret_cast<const int *>(tpl_columns), reinterpret_cast<const double *>(tpl_values),
-        reinterpret_cast<int *>(bsr_offsets), reinterpret_cast<int *>(bsr_columns),
-        reinterpret_cast<double *>(bsr_values));
+    bsr_matrix_from_triplets_host<double>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
+                                          static_cast<const double*>(tpl_values), prune_numerical_zeros, bsr_offsets,
+                                          bsr_columns, static_cast<double*>(bsr_values));
+    if (bsr_nnz)
+    {
+        *bsr_nnz = bsr_offsets[row_count];
+    }
 }
 
 WP_API void bsr_transpose_float_host(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
-                                     uint64_t bsr_offsets, uint64_t bsr_columns, uint64_t bsr_values,
-                                     uint64_t transposed_bsr_offsets, uint64_t transposed_bsr_columns,
-                                     uint64_t transposed_bsr_values)
+                                     int* bsr_offsets, int* bsr_columns, void* bsr_values, int* transposed_bsr_offsets,
+                                     int* transposed_bsr_columns, void* transposed_bsr_values)
 {
-    bsr_transpose_host(rows_per_block, cols_per_block, row_count, col_count, nnz,
-                       reinterpret_cast<const int *>(bsr_offsets), reinterpret_cast<const int *>(bsr_columns),
-                       reinterpret_cast<const float *>(bsr_values), reinterpret_cast<int *>(transposed_bsr_offsets),
-                       reinterpret_cast<int *>(transposed_bsr_columns),
-                       reinterpret_cast<float *>(transposed_bsr_values));
+    bsr_transpose_host(rows_per_block, cols_per_block, row_count, col_count, nnz, bsr_offsets, bsr_columns,
+                       static_cast<const float*>(bsr_values), transposed_bsr_offsets, transposed_bsr_columns,
+                       static_cast<float*>(transposed_bsr_values));
 }
 
 WP_API void bsr_transpose_double_host(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
-                                      uint64_t bsr_offsets, uint64_t bsr_columns, uint64_t bsr_values,
-                                      uint64_t transposed_bsr_offsets, uint64_t transposed_bsr_columns,
-                                      uint64_t transposed_bsr_values)
+                                      int* bsr_offsets, int* bsr_columns, void* bsr_values, int* transposed_bsr_offsets,
+                                      int* transposed_bsr_columns, void* transposed_bsr_values)
 {
-    bsr_transpose_host(rows_per_block, cols_per_block, row_count, col_count, nnz,
-                       reinterpret_cast<const int *>(bsr_offsets), reinterpret_cast<const int *>(bsr_columns),
-                       reinterpret_cast<const double *>(bsr_values), reinterpret_cast<int *>(transposed_bsr_offsets),
-                       reinterpret_cast<int *>(transposed_bsr_columns),
-                       reinterpret_cast<double *>(transposed_bsr_values));
+    bsr_transpose_host(rows_per_block, cols_per_block, row_count, col_count, nnz, bsr_offsets, bsr_columns,
+                       static_cast<const double*>(bsr_values), transposed_bsr_offsets, transposed_bsr_columns,
+                       static_cast<double*>(transposed_bsr_values));
 }
 
 #if !WP_ENABLE_CUDA
-WP_API int bsr_matrix_from_triplets_float_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
-                                                 uint64_t tpl_rows, uint64_t tpl_columns, uint64_t tpl_values,
-                                                 uint64_t bsr_offsets, uint64_t bsr_columns, uint64_t bsr_values)
+WP_API void bsr_matrix_from_triplets_float_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
+                                                  int* tpl_rows, int* tpl_columns, void* tpl_values,
+                                                  bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
+                                                  void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
-    return 0;
 }
 
-WP_API int bsr_matrix_from_triplets_double_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
-                                                  uint64_t tpl_rows, uint64_t tpl_columns, uint64_t tpl_values,
-                                                  uint64_t bsr_offsets, uint64_t bsr_columns, uint64_t bsr_values)
+WP_API void bsr_matrix_from_triplets_double_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
+                                                   int* tpl_rows, int* tpl_columns, void* tpl_values,
+                                                   bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
+                                                   void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
-    return 0;
 }
 
 WP_API void bsr_transpose_float_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
-                                       uint64_t bsr_offsets, uint64_t bsr_columns, uint64_t bsr_values,
-                                       uint64_t transposed_bsr_offsets, uint64_t transposed_bsr_columns,
-                                       uint64_t transposed_bsr_values)
+                                       int* bsr_offsets, int* bsr_columns, void* bsr_values,
+                                       int* transposed_bsr_offsets, int* transposed_bsr_columns,
+                                       void* transposed_bsr_values)
 {
 }
 
 WP_API void bsr_transpose_double_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
-                                        uint64_t bsr_offsets, uint64_t bsr_columns, uint64_t bsr_values,
-                                        uint64_t transposed_bsr_offsets, uint64_t transposed_bsr_columns,
-                                        uint64_t transposed_bsr_values)
+                                        int* bsr_offsets, int* bsr_columns, void* bsr_values,
+                                        int* transposed_bsr_offsets, int* transposed_bsr_columns,
+                                        void* transposed_bsr_values)
 {
 }
 
-#endif
+#endif
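Beyond the pointer-style cleanup, the sparse.cpp changes add a prune_numerical_zeros flag, drop triplets whose row index is out of range, and report the resulting block count through the new bsr_nnz output instead of a return value. The standalone sketch below is not part of the diff and is not Warp's exported entry point; it uses 1x1 blocks (effectively CSR) and illustrative data to walk through the same filter / sort / accumulate / prefix-sum pipeline the host routine implements.

    // Standalone sketch: triplet-to-CSR construction in the style of bsr_matrix_from_triplets_host.
    #include <algorithm>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    int main()
    {
        const int row_count = 3;
        const bool prune_numerical_zeros = true;

        std::vector<int>   tpl_rows = {2, 0, 0, 5, 1, 0};                // row 5 is out of range and dropped
        std::vector<int>   tpl_cols = {1, 2, 0, 0, 1, 2};
        std::vector<float> tpl_vals = {4.f, 1.f, 3.f, 9.f, 0.f, 2.f};    // the 0.f entry is pruned

        std::vector<int> idx(tpl_rows.size());
        std::iota(idx.begin(), idx.end(), 0);

        // remove invalid rows and (optionally) numerical zeros
        idx.erase(std::remove_if(idx.begin(), idx.end(),
                                 [&](int i)
                                 {
                                     return tpl_rows[i] < 0 || tpl_rows[i] >= row_count ||
                                            (prune_numerical_zeros && tpl_vals[i] == 0.f);
                                 }),
                  idx.end());

        // sort by (row, col) in lexicographic order
        std::sort(idx.begin(), idx.end(), [&](int i, int j)
                  { return tpl_rows[i] < tpl_rows[j] || (tpl_rows[i] == tpl_rows[j] && tpl_cols[i] < tpl_cols[j]); });

        // accumulate duplicates and count entries per row
        std::vector<int> offsets(row_count + 1, 0), columns;
        std::vector<float> values;
        int current_row = -1, current_col = -1;
        for (int i : idx)
        {
            if (tpl_rows[i] == current_row && tpl_cols[i] == current_col)
            {
                values.back() += tpl_vals[i];   // merge duplicate entry
            }
            else
            {
                current_row = tpl_rows[i];
                current_col = tpl_cols[i];
                columns.push_back(current_col);
                values.push_back(tpl_vals[i]);
                ++offsets[current_row + 1];     // per-row count
            }
        }

        // prefix-sum the per-row counts into CSR offsets
        std::partial_sum(offsets.begin(), offsets.end(), offsets.begin());

        const int nnz = offsets[row_count];     // what the new bsr_nnz output reports
        printf("nnz = %d\noffsets:", nnz);
        for (int o : offsets) printf(" %d", o);
        printf("\ncolumns:");
        for (int c : columns) printf(" %d", c);
        printf("\nvalues:");
        for (float v : values) printf(" %g", v);
        printf("\n");
        // Expected: nnz = 3, offsets 0 2 2 3, columns 0 2 1, values 3 3 4
        return 0;
    }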