warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (179)
  1. warp/__init__.py +7 -1
  2. warp/bin/libwarp-clang.dylib +0 -0
  3. warp/bin/libwarp.dylib +0 -0
  4. warp/build.py +410 -0
  5. warp/build_dll.py +6 -14
  6. warp/builtins.py +452 -362
  7. warp/codegen.py +179 -119
  8. warp/config.py +42 -6
  9. warp/context.py +490 -271
  10. warp/dlpack.py +8 -6
  11. warp/examples/assets/nonuniform.usd +0 -0
  12. warp/examples/assets/nvidia_logo.png +0 -0
  13. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  14. warp/examples/core/example_sample_mesh.py +300 -0
  15. warp/examples/fem/example_apic_fluid.py +1 -1
  16. warp/examples/fem/example_burgers.py +2 -2
  17. warp/examples/fem/example_deformed_geometry.py +1 -1
  18. warp/examples/fem/example_distortion_energy.py +1 -1
  19. warp/examples/fem/example_magnetostatics.py +6 -6
  20. warp/examples/fem/utils.py +9 -3
  21. warp/examples/interop/example_jax_callable.py +116 -0
  22. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  23. warp/examples/interop/example_jax_kernel.py +205 -0
  24. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  25. warp/examples/tile/example_tile_matmul.py +2 -4
  26. warp/fem/__init__.py +11 -1
  27. warp/fem/adaptivity.py +4 -4
  28. warp/fem/field/nodal_field.py +22 -68
  29. warp/fem/field/virtual.py +62 -23
  30. warp/fem/geometry/adaptive_nanogrid.py +9 -10
  31. warp/fem/geometry/closest_point.py +1 -1
  32. warp/fem/geometry/deformed_geometry.py +5 -2
  33. warp/fem/geometry/geometry.py +5 -0
  34. warp/fem/geometry/grid_2d.py +12 -12
  35. warp/fem/geometry/grid_3d.py +12 -15
  36. warp/fem/geometry/hexmesh.py +5 -7
  37. warp/fem/geometry/nanogrid.py +9 -11
  38. warp/fem/geometry/quadmesh.py +13 -13
  39. warp/fem/geometry/tetmesh.py +3 -4
  40. warp/fem/geometry/trimesh.py +3 -8
  41. warp/fem/integrate.py +262 -93
  42. warp/fem/linalg.py +5 -5
  43. warp/fem/quadrature/pic_quadrature.py +37 -22
  44. warp/fem/quadrature/quadrature.py +194 -25
  45. warp/fem/space/__init__.py +1 -1
  46. warp/fem/space/basis_function_space.py +4 -2
  47. warp/fem/space/basis_space.py +25 -18
  48. warp/fem/space/hexmesh_function_space.py +2 -2
  49. warp/fem/space/partition.py +6 -2
  50. warp/fem/space/quadmesh_function_space.py +8 -8
  51. warp/fem/space/shape/cube_shape_function.py +23 -23
  52. warp/fem/space/shape/square_shape_function.py +12 -12
  53. warp/fem/space/shape/triangle_shape_function.py +1 -1
  54. warp/fem/space/tetmesh_function_space.py +3 -3
  55. warp/fem/space/trimesh_function_space.py +2 -2
  56. warp/fem/utils.py +12 -6
  57. warp/jax.py +14 -1
  58. warp/jax_experimental/__init__.py +16 -0
  59. warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
  60. warp/jax_experimental/ffi.py +698 -0
  61. warp/jax_experimental/xla_ffi.py +602 -0
  62. warp/math.py +89 -0
  63. warp/native/array.h +13 -0
  64. warp/native/builtin.h +29 -3
  65. warp/native/bvh.cpp +3 -1
  66. warp/native/bvh.cu +42 -14
  67. warp/native/bvh.h +2 -1
  68. warp/native/clang/clang.cpp +30 -3
  69. warp/native/cuda_util.cpp +14 -0
  70. warp/native/cuda_util.h +2 -0
  71. warp/native/exports.h +68 -63
  72. warp/native/intersect.h +26 -26
  73. warp/native/intersect_adj.h +33 -33
  74. warp/native/marching.cu +1 -1
  75. warp/native/mat.h +513 -9
  76. warp/native/mesh.h +10 -10
  77. warp/native/quat.h +99 -11
  78. warp/native/rand.h +6 -0
  79. warp/native/sort.cpp +122 -59
  80. warp/native/sort.cu +152 -15
  81. warp/native/sort.h +8 -1
  82. warp/native/sparse.cpp +43 -22
  83. warp/native/sparse.cu +52 -17
  84. warp/native/svd.h +116 -0
  85. warp/native/tile.h +301 -105
  86. warp/native/tile_reduce.h +46 -3
  87. warp/native/vec.h +68 -7
  88. warp/native/volume.cpp +85 -113
  89. warp/native/volume_builder.cu +25 -10
  90. warp/native/volume_builder.h +6 -0
  91. warp/native/warp.cpp +5 -6
  92. warp/native/warp.cu +99 -10
  93. warp/native/warp.h +19 -10
  94. warp/optim/linear.py +10 -10
  95. warp/sim/articulation.py +4 -4
  96. warp/sim/collide.py +21 -10
  97. warp/sim/import_mjcf.py +449 -155
  98. warp/sim/import_urdf.py +32 -12
  99. warp/sim/integrator_euler.py +5 -5
  100. warp/sim/integrator_featherstone.py +3 -10
  101. warp/sim/integrator_vbd.py +207 -2
  102. warp/sim/integrator_xpbd.py +5 -5
  103. warp/sim/model.py +42 -13
  104. warp/sim/utils.py +2 -2
  105. warp/sparse.py +642 -555
  106. warp/stubs.py +216 -19
  107. warp/tests/__main__.py +0 -15
  108. warp/tests/cuda/__init__.py +0 -0
  109. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  110. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  111. warp/tests/geometry/__init__.py +0 -0
  112. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  113. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  114. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  115. warp/tests/interop/__init__.py +0 -0
  116. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  117. warp/tests/sim/__init__.py +0 -0
  118. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  119. warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
  120. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  121. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  122. warp/tests/sim/test_vbd.py +597 -0
  123. warp/tests/test_bool.py +1 -1
  124. warp/tests/test_examples.py +28 -36
  125. warp/tests/test_fem.py +23 -4
  126. warp/tests/test_linear_solvers.py +0 -11
  127. warp/tests/test_mat.py +233 -79
  128. warp/tests/test_mat_scalar_ops.py +4 -4
  129. warp/tests/test_overwrite.py +0 -60
  130. warp/tests/test_quat.py +67 -46
  131. warp/tests/test_rand.py +44 -37
  132. warp/tests/test_sparse.py +47 -6
  133. warp/tests/test_spatial.py +75 -0
  134. warp/tests/test_static.py +1 -1
  135. warp/tests/test_utils.py +84 -4
  136. warp/tests/test_vec.py +46 -34
  137. warp/tests/tile/__init__.py +0 -0
  138. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  139. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
  140. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  141. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  142. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  143. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  144. warp/tests/unittest_serial.py +1 -0
  145. warp/tests/unittest_suites.py +45 -59
  146. warp/tests/unittest_utils.py +2 -1
  147. warp/thirdparty/unittest_parallel.py +3 -1
  148. warp/types.py +110 -658
  149. warp/utils.py +137 -72
  150. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
  151. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
  152. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  153. warp/examples/optim/example_walker.py +0 -317
  154. warp/native/cutlass_gemm.cpp +0 -43
  155. warp/native/cutlass_gemm.cu +0 -382
  156. warp/tests/test_matmul.py +0 -511
  157. warp/tests/test_matmul_lite.py +0 -411
  158. warp/tests/test_vbd.py +0 -386
  159. warp/tests/unused_test_misc.py +0 -77
  160. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  161. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  162. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  163. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  164. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  165. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  166. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  167. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  168. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  169. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  170. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  171. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  172. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  173. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  174. /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
  175. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  176. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  177. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  178. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
  179. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/native/sparse.cpp CHANGED
@@ -81,7 +81,8 @@ template <typename T> void bsr_dyn_block_transpose(const T* src, T* dest, int ro
 template <typename T>
 int bsr_matrix_from_triplets_host(const int rows_per_block, const int cols_per_block, const int row_count,
                                   const int nnz, const int* tpl_rows, const int* tpl_columns, const T* tpl_values,
-                                  const bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns, T* bsr_values)
+                                  const bool prune_numerical_zeros, const bool masked, int* bsr_offsets,
+                                  int* bsr_columns, T* bsr_values)
 {
 
     // get specialized accumulator for common block sizes (1,1), (1,2), (1,3),
@@ -124,14 +125,33 @@ int bsr_matrix_from_triplets_host(const int rows_per_block, const int cols_per_b
     std::iota(block_indices.begin(), block_indices.end(), 0);
 
     // remove zero blocks and invalid row indices
-    block_indices.erase(std::remove_if(block_indices.begin(), block_indices.end(),
-                                       [&](int i)
-                                       {
-                                           return tpl_rows[i] < 0 || tpl_rows[i] >= row_count ||
-                                                  (prune_numerical_zeros && tpl_values &&
-                                                   block_is_zero_func(tpl_values + i * block_size, block_size));
-                                       }),
-                        block_indices.end());
+
+    auto discard_block = [&](int i)
+    {
+        const int row = tpl_rows[i];
+        if (row < 0 || row >= row_count)
+        {
+            return true;
+        }
+
+        if (prune_numerical_zeros && tpl_values && block_is_zero_func(tpl_values + i * block_size, block_size))
+        {
+            return true;
+        }
+
+        if (!masked)
+        {
+            return false;
+        }
+
+        const int* beg = bsr_columns + bsr_offsets[row];
+        const int* end = bsr_columns + bsr_offsets[row + 1];
+        const int col = tpl_columns[i];
+        const int* block = std::lower_bound(beg, end, col);
+        return block == end || *block != col;
+    };
+
+    block_indices.erase(std::remove_if(block_indices.begin(), block_indices.end(), discard_block), block_indices.end());
 
     // sort block indices according to lexico order
     std::sort(block_indices.begin(), block_indices.end(), [tpl_rows, tpl_columns](int i, int j) -> bool
@@ -281,12 +301,12 @@ void bsr_transpose_host(int rows_per_block, int cols_per_block, int row_count, i
 
 WP_API void bsr_matrix_from_triplets_float_host(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                                 int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                                bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
-                                                void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+                                                bool prune_numerical_zeros, bool masked, int* bsr_offsets,
+                                                int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
     bsr_matrix_from_triplets_host<float>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
-                                         static_cast<const float*>(tpl_values), prune_numerical_zeros, bsr_offsets,
-                                         bsr_columns, static_cast<float*>(bsr_values));
+                                         static_cast<const float*>(tpl_values), prune_numerical_zeros, masked,
+                                         bsr_offsets, bsr_columns, static_cast<float*>(bsr_values));
     if (bsr_nnz)
     {
         *bsr_nnz = bsr_offsets[row_count];
@@ -295,12 +315,12 @@ WP_API void bsr_matrix_from_triplets_float_host(int rows_per_block, int cols_per
 
 WP_API void bsr_matrix_from_triplets_double_host(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                                  int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                                 bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
-                                                 void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+                                                 bool prune_numerical_zeros, bool masked, int* bsr_offsets,
+                                                 int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
     bsr_matrix_from_triplets_host<double>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
-                                          static_cast<const double*>(tpl_values), prune_numerical_zeros, bsr_offsets,
-                                          bsr_columns, static_cast<double*>(bsr_values));
+                                          static_cast<const double*>(tpl_values), prune_numerical_zeros, masked,
+                                          bsr_offsets, bsr_columns, static_cast<double*>(bsr_values));
     if (bsr_nnz)
     {
         *bsr_nnz = bsr_offsets[row_count];
@@ -327,16 +347,17 @@ WP_API void bsr_transpose_double_host(int rows_per_block, int cols_per_block, in
 
 #if !WP_ENABLE_CUDA
 WP_API void bsr_matrix_from_triplets_float_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
-                                                  int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                                  bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
-                                                  void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+                                                  int* tpl_rows, int* tpl_columns, void* tpl_values,
+                                                  bool prune_numerical_zeros, bool masked, int* bsr_offsets,
+                                                  int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
 }
 
 WP_API void bsr_matrix_from_triplets_double_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                                    int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                                   bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
-                                                   void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+                                                   bool prune_numerical_zeros, bool masked, int* bsr_offsets,
+                                                   int* bsr_columns, void* bsr_values, int* bsr_nnz,
+                                                   void* bsr_nnz_event)
 {
 }
 
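The new `masked` argument changes how triplets are filtered: when it is set, `discard_block` keeps a triplet only if its `(row, col)` block already exists in the topology currently stored in `bsr_offsets`/`bsr_columns`, located with `std::lower_bound` over that row's sorted column range. Below is a minimal standalone sketch of that membership test; the topology and queries are hypothetical illustration data, not part of the Warp sources.

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Returns true when block (row, col) exists in a CSR-style topology
// (row offsets over sorted column indices), mirroring the std::lower_bound
// test performed by discard_block in masked mode.
bool block_in_mask(const std::vector<int>& offsets, const std::vector<int>& columns, int row, int col)
{
    const int* beg = columns.data() + offsets[row];
    const int* end = columns.data() + offsets[row + 1];
    const int* block = std::lower_bound(beg, end, col);
    return block != end && *block == col;
}

int main()
{
    // 2-row topology: row 0 has columns {1, 3}, row 1 has column {0}
    std::vector<int> offsets = {0, 2, 3};
    std::vector<int> columns = {1, 3, 0};

    std::printf("%d\n", block_in_mask(offsets, columns, 0, 3)); // 1: triplet kept in masked mode
    std::printf("%d\n", block_in_mask(offsets, columns, 0, 2)); // 0: triplet discarded in masked mode
    return 0;
}
```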
warp/native/sparse.cu CHANGED
@@ -61,10 +61,41 @@ template <typename T> struct BsrBlockIsNotZero
     }
 };
 
+struct BsrBlockInMask
+{
+    const int* bsr_offsets;
+    const int* bsr_columns;
+
+    CUDA_CALLABLE_DEVICE bool operator()(int row, int col) const
+    {
+        if (bsr_offsets == nullptr)
+            return true;
+
+        int lower = bsr_offsets[row];
+        int upper = bsr_offsets[row + 1] - 1;
+
+        while (lower < upper)
+        {
+            const int mid = lower + (upper - lower) / 2;
+
+            if (bsr_columns[mid] < col)
+            {
+                lower = mid + 1;
+            }
+            else
+            {
+                upper = mid;
+            }
+        }
+
+        return lower == upper && (bsr_columns[lower] == col);
+    }
+};
+
 template <typename T>
 __global__ void bsr_fill_triplet_key_values(const int nnz, const int nrow, const int* tpl_rows, const int* tpl_columns,
-                                            const BsrBlockIsNotZero<T> nonZero, uint32_t* block_indices,
-                                            BsrRowCol* tpl_row_col)
+                                            const BsrBlockIsNotZero<T> nonZero, const BsrBlockInMask mask,
+                                            uint32_t* block_indices, BsrRowCol* tpl_row_col)
 {
     int block = blockIdx.x * blockDim.x + threadIdx.x;
     if (block >= nnz)
@@ -74,7 +105,8 @@ __global__ void bsr_fill_triplet_key_values(const int nnz, const int nrow, const
     const int col = tpl_columns[block];
     const bool is_valid = row >= 0 && row < nrow;
 
-    const BsrRowCol row_col = is_valid && nonZero(block) ? bsr_combine_row_col(row, col) : PRUNED_ROWCOL;
+    const BsrRowCol row_col =
+        is_valid && nonZero(block) && mask(row, col) ? bsr_combine_row_col(row, col) : PRUNED_ROWCOL;
     tpl_row_col[block] = row_col;
     block_indices[block] = block;
 }
@@ -122,7 +154,7 @@ __global__ void bsr_find_row_offsets(uint32_t row_count, const T* d_nnz, const B
 }
 
 template <typename T>
-__global__ void bsr_merge_blocks(const uint32_t* d_nnz, int block_size, const uint32_t* block_offsets,
+__global__ void bsr_merge_blocks(const int* d_nnz, int block_size, const uint32_t* block_offsets,
                                  const uint32_t* sorted_block_indices, const BsrRowCol* unique_row_cols,
                                  const T* tpl_values, int* bsr_cols, T* bsr_values)
 
@@ -163,8 +195,8 @@ __global__ void bsr_merge_blocks(const uint32_t* d_nnz, int block_size, const ui
 template <typename T>
 void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_per_block, const int row_count,
                                      const int nnz, const int* tpl_rows, const int* tpl_columns, const T* tpl_values,
-                                     const bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
-                                     T* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+                                     const bool prune_numerical_zeros, const bool masked, int* bsr_offsets,
+                                     int* bsr_columns, T* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
     const int block_size = rows_per_block * cols_per_block;
 
@@ -186,8 +218,9 @@ void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_pe
 
     // Combine rows and columns so we can sort on them both
     BsrBlockIsNotZero<T> isNotZero{block_size, prune_numerical_zeros ? tpl_values : nullptr};
+    BsrBlockInMask mask{masked ? bsr_offsets : nullptr, bsr_columns};
     wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_triplet_key_values, nnz,
-                     (nnz, row_count, tpl_rows, tpl_columns, isNotZero, d_keys.Current(), d_values.Current()));
+                     (nnz, row_count, tpl_rows, tpl_columns, isNotZero, mask, d_keys.Current(), d_values.Current()));
 
     // Sort
     {
@@ -214,7 +247,7 @@ void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_pe
 
     if (bsr_nnz)
     {
-        // Copy nnz to host, and record an event for the competed transfer if desired
+        // Copy nnz to host, and record an event for the completed transfer if desired
 
         memcpy_d2h(WP_CURRENT_CONTEXT, bsr_nnz, bsr_offsets + row_count, sizeof(int), stream);
 
@@ -236,7 +269,7 @@ void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_pe
 
     // Accumulate repeated blocks and set column indices
     wp_launch_device(WP_CURRENT_CONTEXT, bsr_merge_blocks, nnz,
-                     (unique_triplet_count, block_size, d_keys.Alternate(), d_keys.Current(), d_values.Alternate(),
+                     (bsr_offsets + row_count, block_size, d_keys.Alternate(), d_keys.Current(), d_values.Alternate(),
                       tpl_values, bsr_columns, bsr_values));
 }
 
@@ -452,22 +485,24 @@ void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
 
 void bsr_matrix_from_triplets_float_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                            int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                           bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
+                                           bool prune_numerical_zeros, bool masked, int* bsr_offsets, int* bsr_columns,
                                            void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
-    return bsr_matrix_from_triplets_device<float>(
-        rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns, static_cast<const float*>(tpl_values),
-        prune_numerical_zeros, bsr_offsets, bsr_columns, static_cast<float*>(bsr_values), bsr_nnz, bsr_nnz_event);
+    return bsr_matrix_from_triplets_device<float>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
+                                                  static_cast<const float*>(tpl_values), prune_numerical_zeros, masked,
+                                                  bsr_offsets, bsr_columns, static_cast<float*>(bsr_values), bsr_nnz,
+                                                  bsr_nnz_event);
 }
 
 void bsr_matrix_from_triplets_double_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                             int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                            bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
+                                            bool prune_numerical_zeros, bool masked, int* bsr_offsets, int* bsr_columns,
                                             void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
-    return bsr_matrix_from_triplets_device<double>(
-        rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns, static_cast<const double*>(tpl_values),
-        prune_numerical_zeros, bsr_offsets, bsr_columns, static_cast<double*>(bsr_values), bsr_nnz, bsr_nnz_event);
+    return bsr_matrix_from_triplets_device<double>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows,
+                                                   tpl_columns, static_cast<const double*>(tpl_values),
+                                                   prune_numerical_zeros, masked, bsr_offsets, bsr_columns,
+                                                   static_cast<double*>(bsr_values), bsr_nnz, bsr_nnz_event);
}
 
 void bsr_transpose_float_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
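On the CUDA path the same test runs per triplet inside `bsr_fill_triplet_key_values`: `BsrBlockInMask` binary-searches the existing row's columns and blocks outside the mask are mapped to `PRUNED_ROWCOL`, so the subsequent sort/unique pass drops them. To make the flag concrete, here is a hypothetical host-side call sequence, not taken from the Warp sources: it assumes 1x1 blocks, that the semantics inferred from `discard_block` above hold (masked assembly keeps only blocks present in the topology already stored in `bsr_offsets`/`bsr_columns`), and that the program links against the library exporting `bsr_matrix_from_triplets_float_host` with the signature shown in this diff.

```cpp
#include <vector>

// Declaration copied from the exported signature in this diff (assumed C linkage).
extern "C" void bsr_matrix_from_triplets_float_host(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                                    int* tpl_rows, int* tpl_columns, void* tpl_values,
                                                    bool prune_numerical_zeros, bool masked, int* bsr_offsets,
                                                    int* bsr_columns, void* bsr_values, int* bsr_nnz,
                                                    void* bsr_nnz_event);

int main()
{
    const int row_count = 2;

    // Pass 1: build the topology from an initial triplet batch (masked = false).
    std::vector<int> rows = {0, 0, 1};
    std::vector<int> cols = {1, 3, 0};
    std::vector<float> vals = {1.f, 2.f, 3.f};

    std::vector<int> bsr_offsets(row_count + 1);
    std::vector<int> bsr_columns(rows.size());
    std::vector<float> bsr_values(rows.size());
    int nnz = 0;

    bsr_matrix_from_triplets_float_host(1, 1, row_count, (int)rows.size(), rows.data(), cols.data(), vals.data(),
                                        /*prune_numerical_zeros*/ false, /*masked*/ false, bsr_offsets.data(),
                                        bsr_columns.data(), bsr_values.data(), &nnz, nullptr);

    // Pass 2: re-assemble values from a new batch, discarding any triplet whose
    // (row, col) block is absent from the topology built in pass 1 (masked = true).
    std::vector<int> rows2 = {0, 0, 1};
    std::vector<int> cols2 = {1, 2, 0}; // block (0, 2) is outside the mask and should be dropped
    std::vector<float> vals2 = {4.f, 5.f, 6.f};

    bsr_matrix_from_triplets_float_host(1, 1, row_count, (int)rows2.size(), rows2.data(), cols2.data(), vals2.data(),
                                        false, /*masked*/ true, bsr_offsets.data(), bsr_columns.data(),
                                        bsr_values.data(), &nnz, nullptr);
    return 0;
}
```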
warp/native/svd.h CHANGED
@@ -432,6 +432,62 @@ void _svd(// input A
           );
 }
 
+
+template<typename Type>
+inline CUDA_CALLABLE
+void _svd_2(// input A
+            Type a11, Type a12,
+            Type a21, Type a22,
+            // output U
+            Type &u11, Type &u12,
+            Type &u21, Type &u22,
+            // output S
+            Type &s11, Type &s12,
+            Type &s21, Type &s22,
+            // output V
+            Type &v11, Type &v12,
+            Type &v21, Type &v22)
+{
+    // Step 1: Compute ATA
+    Type ATA11 = a11 * a11 + a21 * a21;
+    Type ATA12 = a11 * a12 + a21 * a22;
+    Type ATA22 = a12 * a12 + a22 * a22;
+
+    // Step 2: Eigenanalysis
+    Type trace = ATA11 + ATA22;
+    Type det = ATA11 * ATA22 - ATA12 * ATA12;
+    Type sqrt_term = sqrt(trace * trace - Type(4.0) * det);
+    Type lambda1 = (trace + sqrt_term) * Type(0.5);
+    Type lambda2 = (trace - sqrt_term) * Type(0.5);
+
+    // Step 3: Singular values
+    Type sigma1 = sqrt(lambda1);
+    Type sigma2 = sqrt(lambda2);
+
+    // Step 4: Eigenvectors (find V)
+    Type v1x = ATA12, v1y = lambda1 - ATA11; // For first eigenvector
+    Type v2x = ATA12, v2y = lambda2 - ATA11; // For second eigenvector
+    Type norm1 = sqrt(v1x * v1x + v1y * v1y);
+    Type norm2 = sqrt(v2x * v2x + v2y * v2y);
+
+    v11 = v1x / norm1; v12 = v2x / norm2;
+    v21 = v1y / norm1; v22 = v2y / norm2;
+
+    // Step 5: Compute U
+    Type inv_sigma1 = (sigma1 > Type(1e-6)) ? Type(1.0) / sigma1 : Type(0.0);
+    Type inv_sigma2 = (sigma2 > Type(1e-6)) ? Type(1.0) / sigma2 : Type(0.0);
+
+    u11 = (a11 * v11 + a12 * v21) * inv_sigma1;
+    u12 = (a11 * v12 + a12 * v22) * inv_sigma2;
+    u21 = (a21 * v11 + a22 * v21) * inv_sigma1;
+    u22 = (a21 * v12 + a22 * v22) * inv_sigma2;
+
+    // Step 6: Set S
+    s11 = sigma1; s12 = Type(0.0);
+    s21 = Type(0.0); s22 = sigma2;
+}
+
+
 template<typename Type>
 inline CUDA_CALLABLE void svd3(const mat_t<3,3,Type>& A, mat_t<3,3,Type>& U, vec_t<3,Type>& sigma, mat_t<3,3,Type>& V) {
     Type s12, s13, s21, s23, s31, s32;
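In equation form, `_svd_2` builds the 2x2 decomposition A = U Σ Vᵀ from the eigen-decomposition of AᵀA:

$$
\lambda_{1,2} = \tfrac{1}{2}\Big(\operatorname{tr}(A^{\top}\!A) \pm \sqrt{\operatorname{tr}(A^{\top}\!A)^2 - 4\det(A^{\top}\!A)}\Big),
\qquad \sigma_i = \sqrt{\lambda_i},
$$
$$
v_i \propto \begin{pmatrix}(A^{\top}\!A)_{12}\\ \lambda_i - (A^{\top}\!A)_{11}\end{pmatrix},
\qquad u_i = \frac{1}{\sigma_i}\,A\,v_i \quad(\text{zeroed when } \sigma_i \le 10^{-6}),
$$

so that Σ = diag(σ₁, σ₂) with σ₁ ≥ σ₂, V = [v₁ v₂], and U = [u₁ u₂].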
@@ -492,6 +548,66 @@ inline CUDA_CALLABLE void adj_svd3(const mat_t<3,3,Type>& A,
     adj_A = adj_A + (u_term + v_term + sigma_term);
 }
 
+template<typename Type>
+inline CUDA_CALLABLE void svd2(const mat_t<2,2,Type>& A, mat_t<2,2,Type>& U, vec_t<2,Type>& sigma, mat_t<2,2,Type>& V) {
+    Type s12, s21;
+    _svd_2(A.data[0][0], A.data[0][1],
+           A.data[1][0], A.data[1][1],
+
+           U.data[0][0], U.data[0][1],
+           U.data[1][0], U.data[1][1],
+
+           sigma[0], s12,
+           s21, sigma[1],
+
+           V.data[0][0], V.data[0][1],
+           V.data[1][0], V.data[1][1]);
+}
+
+template<typename Type>
+inline CUDA_CALLABLE void adj_svd2(const mat_t<2,2,Type>& A,
+                                   const mat_t<2,2,Type>& U,
+                                   const vec_t<2,Type>& sigma,
+                                   const mat_t<2,2,Type>& V,
+                                   mat_t<2,2,Type>& adj_A,
+                                   const mat_t<2,2,Type>& adj_U,
+                                   const vec_t<2,Type>& adj_sigma,
+                                   const mat_t<2,2,Type>& adj_V) {
+    Type s1_squared = sigma[0] * sigma[0];
+    Type s2_squared = sigma[1] * sigma[1];
+
+    // Compute inverse of (s1^2 - s2^2) if possible, use small epsilon to prevent division by zero
+    Type F01 = Type(1) / min(s2_squared - s1_squared, Type(-1e-6f));
+
+    // Construct the matrix F for the adjoint
+    mat_t<2,2,Type> F = mat_t<2,2,Type>(0.0, F01,
+                                        -F01, 0.0);
+
+    // Create a matrix to handle the adjoint of the singular values (diagonal matrix)
+    mat_t<2,2,Type> adj_sigma_mat = mat_t<2,2,Type>(adj_sigma[0], 0.0,
+                                                    0.0, adj_sigma[1]);
+
+    // Matrix for handling singular values (diagonal matrix with sigma values)
+    mat_t<2,2,Type> s_mat = mat_t<2,2,Type>(sigma[0], 0.0,
+                                            0.0, sigma[1]);
+
+    // Compute the transpose of U and V
+    mat_t<2,2,Type> UT = transpose(U);
+    mat_t<2,2,Type> VT = transpose(V);
+
+    // Compute the term for sigma (diagonal matrix of adjoint singular values)
+    mat_t<2,2,Type> sigma_term = mul(U, mul(adj_sigma_mat, VT));
+
+    // Compute the adjoint contributions for U (left singular vectors)
+    mat_t<2,2,Type> u_term = mul(mul(U, mul(cw_mul(F, (mul(UT, adj_U) - mul(transpose(adj_U), U))), s_mat)), VT);
+
+    // Compute the adjoint contributions for V (right singular vectors)
+    mat_t<2,2,Type> v_term = mul(U, mul(s_mat, mul(cw_mul(F, (mul(VT, adj_V) - mul(transpose(adj_V), V))), VT)));
+
+    // Combine the terms to compute the adjoint of A
+    adj_A = adj_A + (u_term + v_term + sigma_term);
+}
+
 
 template<typename Type>
 inline CUDA_CALLABLE void qr3(const mat_t<3,3,Type>& A, mat_t<3,3,Type>& Q, mat_t<3,3,Type>& R) {
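For reference, the reverse-mode rule that `adj_svd2` accumulates (and that the existing `adj_svd3` applies in 3x3 form) is the standard SVD backward formula:

$$
\bar A \mathrel{+}= U\Big(\operatorname{diag}(\bar\sigma)
+ \big(F \circ (U^{\top}\bar U - \bar U^{\top} U)\big)\,\Sigma
+ \Sigma\,\big(F \circ (V^{\top}\bar V - \bar V^{\top} V)\big)\Big) V^{\top},
$$

where Σ = diag(σ₁, σ₂), ∘ is the element-wise product, and F₁₂ = −F₂₁ = 1/(σ₂² − σ₁²) with the denominator clamped to at most −10⁻⁶ (as in the code above) so the term stays finite when the singular values coincide; F₁₁ = F₂₂ = 0.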