warp-lang 1.6.2-py3-none-win_amd64.whl → 1.7.0-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +7 -1
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +410 -0
- warp/build_dll.py +6 -14
- warp/builtins.py +452 -362
- warp/codegen.py +179 -119
- warp/config.py +42 -6
- warp/context.py +490 -271
- warp/dlpack.py +8 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +2 -2
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_magnetostatics.py +6 -6
- warp/examples/fem/utils.py +9 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_matmul.py +2 -4
- warp/fem/__init__.py +11 -1
- warp/fem/adaptivity.py +4 -4
- warp/fem/field/nodal_field.py +22 -68
- warp/fem/field/virtual.py +62 -23
- warp/fem/geometry/adaptive_nanogrid.py +9 -10
- warp/fem/geometry/closest_point.py +1 -1
- warp/fem/geometry/deformed_geometry.py +5 -2
- warp/fem/geometry/geometry.py +5 -0
- warp/fem/geometry/grid_2d.py +12 -12
- warp/fem/geometry/grid_3d.py +12 -15
- warp/fem/geometry/hexmesh.py +5 -7
- warp/fem/geometry/nanogrid.py +9 -11
- warp/fem/geometry/quadmesh.py +13 -13
- warp/fem/geometry/tetmesh.py +3 -4
- warp/fem/geometry/trimesh.py +3 -8
- warp/fem/integrate.py +262 -93
- warp/fem/linalg.py +5 -5
- warp/fem/quadrature/pic_quadrature.py +37 -22
- warp/fem/quadrature/quadrature.py +194 -25
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +4 -2
- warp/fem/space/basis_space.py +25 -18
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +6 -2
- warp/fem/space/quadmesh_function_space.py +8 -8
- warp/fem/space/shape/cube_shape_function.py +23 -23
- warp/fem/space/shape/square_shape_function.py +12 -12
- warp/fem/space/shape/triangle_shape_function.py +1 -1
- warp/fem/space/tetmesh_function_space.py +3 -3
- warp/fem/space/trimesh_function_space.py +2 -2
- warp/fem/utils.py +12 -6
- warp/jax.py +14 -1
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +89 -0
- warp/native/array.h +13 -0
- warp/native/builtin.h +29 -3
- warp/native/bvh.cpp +3 -1
- warp/native/bvh.cu +42 -14
- warp/native/bvh.h +2 -1
- warp/native/clang/clang.cpp +30 -3
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/exports.h +68 -63
- warp/native/intersect.h +26 -26
- warp/native/intersect_adj.h +33 -33
- warp/native/marching.cu +1 -1
- warp/native/mat.h +513 -9
- warp/native/mesh.h +10 -10
- warp/native/quat.h +99 -11
- warp/native/rand.h +6 -0
- warp/native/sort.cpp +122 -59
- warp/native/sort.cu +152 -15
- warp/native/sort.h +8 -1
- warp/native/sparse.cpp +43 -22
- warp/native/sparse.cu +52 -17
- warp/native/svd.h +116 -0
- warp/native/tile.h +301 -105
- warp/native/tile_reduce.h +46 -3
- warp/native/vec.h +68 -7
- warp/native/volume.cpp +85 -113
- warp/native/volume_builder.cu +25 -10
- warp/native/volume_builder.h +6 -0
- warp/native/warp.cpp +5 -6
- warp/native/warp.cu +99 -10
- warp/native/warp.h +19 -10
- warp/optim/linear.py +10 -10
- warp/sim/articulation.py +4 -4
- warp/sim/collide.py +21 -10
- warp/sim/import_mjcf.py +449 -155
- warp/sim/import_urdf.py +32 -12
- warp/sim/integrator_euler.py +5 -5
- warp/sim/integrator_featherstone.py +3 -10
- warp/sim/integrator_vbd.py +207 -2
- warp/sim/integrator_xpbd.py +5 -5
- warp/sim/model.py +42 -13
- warp/sim/utils.py +2 -2
- warp/sparse.py +642 -555
- warp/stubs.py +216 -19
- warp/tests/__main__.py +0 -15
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
- warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
- warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
- warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
- warp/tests/{test_model.py → sim/test_model.py} +40 -0
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_bool.py +1 -1
- warp/tests/test_examples.py +28 -36
- warp/tests/test_fem.py +23 -4
- warp/tests/test_linear_solvers.py +0 -11
- warp/tests/test_mat.py +233 -79
- warp/tests/test_mat_scalar_ops.py +4 -4
- warp/tests/test_overwrite.py +0 -60
- warp/tests/test_quat.py +67 -46
- warp/tests/test_rand.py +44 -37
- warp/tests/test_sparse.py +47 -6
- warp/tests/test_spatial.py +75 -0
- warp/tests/test_static.py +1 -1
- warp/tests/test_utils.py +84 -4
- warp/tests/test_vec.py +46 -34
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
- warp/tests/unittest_serial.py +1 -0
- warp/tests/unittest_suites.py +45 -59
- warp/tests/unittest_utils.py +2 -1
- warp/thirdparty/unittest_parallel.py +3 -1
- warp/types.py +110 -658
- warp/utils.py +137 -72
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp/examples/optim/example_walker.py +0 -317
- warp/native/cutlass_gemm.cpp +0 -43
- warp/native/cutlass_gemm.cu +0 -382
- warp/tests/test_matmul.py +0 -511
- warp/tests/test_matmul_lite.py +0 -411
- warp/tests/test_vbd.py +0 -386
- warp/tests/unused_test_misc.py +0 -77
- /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
- /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
- /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
- /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
- /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
- /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
- /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
- /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
- /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
- /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
- /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
- /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
- /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
- /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
- /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
- /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
- /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
- /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/native/sparse.cpp
CHANGED
@@ -81,7 +81,8 @@ template <typename T> void bsr_dyn_block_transpose(const T* src, T* dest, int ro
 template <typename T>
 int bsr_matrix_from_triplets_host(const int rows_per_block, const int cols_per_block, const int row_count,
                                   const int nnz, const int* tpl_rows, const int* tpl_columns, const T* tpl_values,
-                                  const bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns, T* bsr_values)
+                                  const bool prune_numerical_zeros, const bool masked, int* bsr_offsets,
+                                  int* bsr_columns, T* bsr_values)
 {
 
     // get specialized accumulator for common block sizes (1,1), (1,2), (1,3),
@@ -124,14 +125,33 @@ int bsr_matrix_from_triplets_host(const int rows_per_block, const int cols_per_b
     std::iota(block_indices.begin(), block_indices.end(), 0);
 
     // remove zero blocks and invalid row indices
-
-
-
-
-
-
-
-
+
+    auto discard_block = [&](int i)
+    {
+        const int row = tpl_rows[i];
+        if (row < 0 || row >= row_count)
+        {
+            return true;
+        }
+
+        if (prune_numerical_zeros && tpl_values && block_is_zero_func(tpl_values + i * block_size, block_size))
+        {
+            return true;
+        }
+
+        if (!masked)
+        {
+            return false;
+        }
+
+        const int* beg = bsr_columns + bsr_offsets[row];
+        const int* end = bsr_columns + bsr_offsets[row + 1];
+        const int col = tpl_columns[i];
+        const int* block = std::lower_bound(beg, end, col);
+        return block == end || *block != col;
+    };
+
+    block_indices.erase(std::remove_if(block_indices.begin(), block_indices.end(), discard_block), block_indices.end());
 
     // sort block indices according to lexico order
     std::sort(block_indices.begin(), block_indices.end(), [tpl_rows, tpl_columns](int i, int j) -> bool
@@ -281,12 +301,12 @@ void bsr_transpose_host(int rows_per_block, int cols_per_block, int row_count, i
 
 WP_API void bsr_matrix_from_triplets_float_host(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                                 int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                                bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
-                                                void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+                                                bool prune_numerical_zeros, bool masked, int* bsr_offsets,
+                                                int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
     bsr_matrix_from_triplets_host<float>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
-                                         static_cast<const float*>(tpl_values), prune_numerical_zeros, bsr_offsets,
-                                         bsr_columns, static_cast<float*>(bsr_values));
+                                         static_cast<const float*>(tpl_values), prune_numerical_zeros, masked,
+                                         bsr_offsets, bsr_columns, static_cast<float*>(bsr_values));
     if (bsr_nnz)
     {
         *bsr_nnz = bsr_offsets[row_count];
@@ -295,12 +315,12 @@ WP_API void bsr_matrix_from_triplets_float_host(int rows_per_block, int cols_per
 
 WP_API void bsr_matrix_from_triplets_double_host(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                                  int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                                 bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
-                                                 void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+                                                 bool prune_numerical_zeros, bool masked, int* bsr_offsets,
+                                                 int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
     bsr_matrix_from_triplets_host<double>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
-                                          static_cast<const double*>(tpl_values), prune_numerical_zeros, bsr_offsets,
-                                          bsr_columns, static_cast<double*>(bsr_values));
+                                          static_cast<const double*>(tpl_values), prune_numerical_zeros, masked,
+                                          bsr_offsets, bsr_columns, static_cast<double*>(bsr_values));
     if (bsr_nnz)
     {
         *bsr_nnz = bsr_offsets[row_count];
@@ -327,16 +347,17 @@ WP_API void bsr_transpose_double_host(int rows_per_block, int cols_per_block, in
 
 #if !WP_ENABLE_CUDA
 WP_API void bsr_matrix_from_triplets_float_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
-                                                  int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                                  bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
-                                                  void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+                                                  int* tpl_rows, int* tpl_columns, void* tpl_values,
+                                                  bool prune_numerical_zeros, bool masked, int* bsr_offsets,
+                                                  int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
 }
 
 WP_API void bsr_matrix_from_triplets_double_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                                    int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                                   bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
-                                                   void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+                                                   bool prune_numerical_zeros, bool masked, int* bsr_offsets,
+                                                   int* bsr_columns, void* bsr_values, int* bsr_nnz,
+                                                   void* bsr_nnz_event)
 {
 }
 
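The masked path above turns triplet assembly into an update of a fixed sparsity pattern: a triplet survives only if its (row, column) block already exists in the bsr_offsets/bsr_columns topology, located with std::lower_bound over the row's sorted column segment; invalid rows and (optionally) numerically zero blocks are dropped as before. Below is a minimal host-side sketch of just that discard test, using hypothetical data and helper names rather than the Warp API:

#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the discard test added to bsr_matrix_from_triplets_host:
// keep a triplet only if its row is valid and, in masked mode, the (row, col)
// block is already present in the CSR-style topology (columns sorted per row).
static bool discard_triplet(int row, int col, int row_count, bool masked,
                            const std::vector<int>& offsets, const std::vector<int>& columns)
{
    if (row < 0 || row >= row_count)
        return true; // invalid row index

    if (!masked)
        return false; // unmasked assembly keeps every valid triplet

    const int* beg = columns.data() + offsets[row];
    const int* end = columns.data() + offsets[row + 1];
    const int* it = std::lower_bound(beg, end, col);
    return it == end || *it != col; // discard blocks outside the mask
}

int main()
{
    // Two-row topology: row 0 owns columns {1, 3}, row 1 owns column {0}.
    const std::vector<int> offsets = {0, 2, 3};
    const std::vector<int> columns = {1, 3, 0};

    std::printf("%d\n", discard_triplet(0, 3, 2, true, offsets, columns)); // 0: inside the mask
    std::printf("%d\n", discard_triplet(0, 2, 2, true, offsets, columns)); // 1: outside the mask
    std::printf("%d\n", discard_triplet(5, 0, 2, true, offsets, columns)); // 1: row out of range
    return 0;
}

Because each row's columns are sorted, the membership test costs O(log nnz(row)) per triplet, so masked assembly can never allocate a block outside the existing pattern.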
warp/native/sparse.cu
CHANGED
@@ -61,10 +61,41 @@ template <typename T> struct BsrBlockIsNotZero
     }
 };
 
+struct BsrBlockInMask
+{
+    const int* bsr_offsets;
+    const int* bsr_columns;
+
+    CUDA_CALLABLE_DEVICE bool operator()(int row, int col) const
+    {
+        if (bsr_offsets == nullptr)
+            return true;
+
+        int lower = bsr_offsets[row];
+        int upper = bsr_offsets[row + 1] - 1;
+
+        while (lower < upper)
+        {
+            const int mid = lower + (upper - lower) / 2;
+
+            if (bsr_columns[mid] < col)
+            {
+                lower = mid + 1;
+            }
+            else
+            {
+                upper = mid;
+            }
+        }
+
+        return lower == upper && (bsr_columns[lower] == col);
+    }
+};
+
 template <typename T>
 __global__ void bsr_fill_triplet_key_values(const int nnz, const int nrow, const int* tpl_rows, const int* tpl_columns,
-                                            const BsrBlockIsNotZero<T> nonZero, uint32_t* block_indices,
-                                            BsrRowCol* tpl_row_col)
+                                            const BsrBlockIsNotZero<T> nonZero, const BsrBlockInMask mask,
+                                            uint32_t* block_indices, BsrRowCol* tpl_row_col)
 {
     int block = blockIdx.x * blockDim.x + threadIdx.x;
     if (block >= nnz)
@@ -74,7 +105,8 @@ __global__ void bsr_fill_triplet_key_values(const int nnz, const int nrow, const
     const int col = tpl_columns[block];
     const bool is_valid = row >= 0 && row < nrow;
 
-    const BsrRowCol row_col = is_valid && nonZero(block) ? bsr_combine_row_col(row, col) : PRUNED_ROWCOL;
+    const BsrRowCol row_col =
+        is_valid && nonZero(block) && mask(row, col) ? bsr_combine_row_col(row, col) : PRUNED_ROWCOL;
     tpl_row_col[block] = row_col;
     block_indices[block] = block;
 }
@@ -122,7 +154,7 @@ __global__ void bsr_find_row_offsets(uint32_t row_count, const T* d_nnz, const B
 }
 
 template <typename T>
-__global__ void bsr_merge_blocks(const uint32_t* d_nnz, int block_size, const uint32_t* block_offsets,
+__global__ void bsr_merge_blocks(const int* d_nnz, int block_size, const uint32_t* block_offsets,
                                  const uint32_t* sorted_block_indices, const BsrRowCol* unique_row_cols,
                                  const T* tpl_values, int* bsr_cols, T* bsr_values)
 
@@ -163,8 +195,8 @@ __global__ void bsr_merge_blocks(const uint32_t* d_nnz, int block_size, const ui
 template <typename T>
 void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_per_block, const int row_count,
                                      const int nnz, const int* tpl_rows, const int* tpl_columns, const T* tpl_values,
-                                     const bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
-                                     T* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+                                     const bool prune_numerical_zeros, const bool masked, int* bsr_offsets,
+                                     int* bsr_columns, T* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
     const int block_size = rows_per_block * cols_per_block;
 
@@ -186,8 +218,9 @@ void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_pe
 
     // Combine rows and columns so we can sort on them both
     BsrBlockIsNotZero<T> isNotZero{block_size, prune_numerical_zeros ? tpl_values : nullptr};
+    BsrBlockInMask mask{masked ? bsr_offsets : nullptr, bsr_columns};
     wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_triplet_key_values, nnz,
-                     (nnz, row_count, tpl_rows, tpl_columns, isNotZero, d_keys.Current(), d_values.Current()));
+                     (nnz, row_count, tpl_rows, tpl_columns, isNotZero, mask, d_keys.Current(), d_values.Current()));
 
     // Sort
     {
@@ -214,7 +247,7 @@ void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_pe
 
     if (bsr_nnz)
     {
-        // Copy nnz to host, and record an event for the
+        // Copy nnz to host, and record an event for the completed transfer if desired
 
         memcpy_d2h(WP_CURRENT_CONTEXT, bsr_nnz, bsr_offsets + row_count, sizeof(int), stream);
 
@@ -236,7 +269,7 @@ void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_pe
 
     // Accumulate repeated blocks and set column indices
     wp_launch_device(WP_CURRENT_CONTEXT, bsr_merge_blocks, nnz,
-                     (
+                     (bsr_offsets + row_count, block_size, d_keys.Alternate(), d_keys.Current(), d_values.Alternate(),
                       tpl_values, bsr_columns, bsr_values));
 }
 
@@ -452,22 +485,24 @@ void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
 
 void bsr_matrix_from_triplets_float_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                            int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                           bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
+                                           bool prune_numerical_zeros, bool masked, int* bsr_offsets, int* bsr_columns,
                                            void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
-    return bsr_matrix_from_triplets_device<float>(
-        rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns, static_cast<const float*>(tpl_values),
-        prune_numerical_zeros, bsr_offsets, bsr_columns, static_cast<float*>(bsr_values), bsr_nnz, bsr_nnz_event);
+    return bsr_matrix_from_triplets_device<float>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
+                                                  static_cast<const float*>(tpl_values), prune_numerical_zeros, masked,
+                                                  bsr_offsets, bsr_columns, static_cast<float*>(bsr_values), bsr_nnz,
+                                                  bsr_nnz_event);
 }
 
 void bsr_matrix_from_triplets_double_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                             int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                            bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
+                                            bool prune_numerical_zeros, bool masked, int* bsr_offsets, int* bsr_columns,
                                             void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
-    return bsr_matrix_from_triplets_device<double>(
-        rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns, static_cast<const double*>(tpl_values),
-        prune_numerical_zeros, bsr_offsets, bsr_columns, static_cast<double*>(bsr_values), bsr_nnz, bsr_nnz_event);
+    return bsr_matrix_from_triplets_device<double>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows,
+                                                   tpl_columns, static_cast<const double*>(tpl_values),
+                                                   prune_numerical_zeros, masked, bsr_offsets, bsr_columns,
+                                                   static_cast<double*>(bsr_values), bsr_nnz, bsr_nnz_event);
 }
 
 void bsr_transpose_float_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
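BsrBlockInMask gives device code the same membership test as a hand-rolled bisection over the inclusive index range [bsr_offsets[row], bsr_offsets[row + 1] - 1]; a null bsr_offsets disables masking, and empty rows fail the final lower == upper check because lower starts past upper. A host-side replica of that exact loop, checked exhaustively against std::lower_bound on hypothetical data:

#include <algorithm>
#include <cassert>
#include <vector>

// Host copy of the BsrBlockInMask bisection from sparse.cu (mask always on here).
static bool in_mask(const int* bsr_offsets, const int* bsr_columns, int row, int col)
{
    int lower = bsr_offsets[row];
    int upper = bsr_offsets[row + 1] - 1; // inclusive upper bound

    while (lower < upper)
    {
        const int mid = lower + (upper - lower) / 2; // overflow-safe midpoint

        if (bsr_columns[mid] < col)
            lower = mid + 1; // a match, if any, lies strictly to the right of mid
        else
            upper = mid;     // bsr_columns[mid] >= col, so mid may still be the match
    }

    // empty rows give lower > upper, so the first test short-circuits safely
    return lower == upper && bsr_columns[lower] == col;
}

int main()
{
    const std::vector<int> offsets = {0, 3, 3, 6};       // row 1 is empty
    const std::vector<int> columns = {0, 2, 5, 1, 4, 7}; // sorted within each row

    for (int row = 0; row < 3; ++row)
        for (int col = 0; col < 10; ++col)
        {
            const int* beg = columns.data() + offsets[row];
            const int* end = columns.data() + offsets[row + 1];
            const int* it = std::lower_bound(beg, end, col);
            assert(in_mask(offsets.data(), columns.data(), row, col) == (it != end && *it == col));
        }
    return 0;
}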
warp/native/svd.h
CHANGED
@@ -432,6 +432,62 @@ void _svd(// input A
     );
 }
 
+
+template<typename Type>
+inline CUDA_CALLABLE
+void _svd_2(// input A
+            Type a11, Type a12,
+            Type a21, Type a22,
+            // output U
+            Type &u11, Type &u12,
+            Type &u21, Type &u22,
+            // output S
+            Type &s11, Type &s12,
+            Type &s21, Type &s22,
+            // output V
+            Type &v11, Type &v12,
+            Type &v21, Type &v22)
+{
+    // Step 1: Compute ATA
+    Type ATA11 = a11 * a11 + a21 * a21;
+    Type ATA12 = a11 * a12 + a21 * a22;
+    Type ATA22 = a12 * a12 + a22 * a22;
+
+    // Step 2: Eigenanalysis
+    Type trace = ATA11 + ATA22;
+    Type det = ATA11 * ATA22 - ATA12 * ATA12;
+    Type sqrt_term = sqrt(trace * trace - Type(4.0) * det);
+    Type lambda1 = (trace + sqrt_term) * Type(0.5);
+    Type lambda2 = (trace - sqrt_term) * Type(0.5);
+
+    // Step 3: Singular values
+    Type sigma1 = sqrt(lambda1);
+    Type sigma2 = sqrt(lambda2);
+
+    // Step 4: Eigenvectors (find V)
+    Type v1x = ATA12, v1y = lambda1 - ATA11; // For first eigenvector
+    Type v2x = ATA12, v2y = lambda2 - ATA11; // For second eigenvector
+    Type norm1 = sqrt(v1x * v1x + v1y * v1y);
+    Type norm2 = sqrt(v2x * v2x + v2y * v2y);
+
+    v11 = v1x / norm1; v12 = v2x / norm2;
+    v21 = v1y / norm1; v22 = v2y / norm2;
+
+    // Step 5: Compute U
+    Type inv_sigma1 = (sigma1 > Type(1e-6)) ? Type(1.0) / sigma1 : Type(0.0);
+    Type inv_sigma2 = (sigma2 > Type(1e-6)) ? Type(1.0) / sigma2 : Type(0.0);
+
+    u11 = (a11 * v11 + a12 * v21) * inv_sigma1;
+    u12 = (a11 * v12 + a12 * v22) * inv_sigma2;
+    u21 = (a21 * v11 + a22 * v21) * inv_sigma1;
+    u22 = (a21 * v12 + a22 * v22) * inv_sigma2;
+
+    // Step 6: Set S
+    s11 = sigma1; s12 = Type(0.0);
+    s21 = Type(0.0); s22 = sigma2;
+}
+
+
 template<typename Type>
 inline CUDA_CALLABLE void svd3(const mat_t<3,3,Type>& A, mat_t<3,3,Type>& U, vec_t<3,Type>& sigma, mat_t<3,3,Type>& V) {
     Type s12, s13, s21, s23, s31, s32;
@@ -492,6 +548,66 @@ inline CUDA_CALLABLE void adj_svd3(const mat_t<3,3,Type>& A,
     adj_A = adj_A + (u_term + v_term + sigma_term);
 }
 
+template<typename Type>
+inline CUDA_CALLABLE void svd2(const mat_t<2,2,Type>& A, mat_t<2,2,Type>& U, vec_t<2,Type>& sigma, mat_t<2,2,Type>& V) {
+    Type s12, s21;
+    _svd_2(A.data[0][0], A.data[0][1],
+           A.data[1][0], A.data[1][1],
+
+           U.data[0][0], U.data[0][1],
+           U.data[1][0], U.data[1][1],
+
+           sigma[0], s12,
+           s21, sigma[1],
+
+           V.data[0][0], V.data[0][1],
+           V.data[1][0], V.data[1][1]);
+}
+
+template<typename Type>
+inline CUDA_CALLABLE void adj_svd2(const mat_t<2,2,Type>& A,
+                                   const mat_t<2,2,Type>& U,
+                                   const vec_t<2,Type>& sigma,
+                                   const mat_t<2,2,Type>& V,
+                                   mat_t<2,2,Type>& adj_A,
+                                   const mat_t<2,2,Type>& adj_U,
+                                   const vec_t<2,Type>& adj_sigma,
+                                   const mat_t<2,2,Type>& adj_V) {
+    Type s1_squared = sigma[0] * sigma[0];
+    Type s2_squared = sigma[1] * sigma[1];
+
+    // Compute inverse of (s1^2 - s2^2) if possible, use small epsilon to prevent division by zero
+    Type F01 = Type(1) / min(s2_squared - s1_squared, Type(-1e-6f));
+
+    // Construct the matrix F for the adjoint
+    mat_t<2,2,Type> F = mat_t<2,2,Type>(0.0, F01,
+                                        -F01, 0.0);
+
+    // Create a matrix to handle the adjoint of the singular values (diagonal matrix)
+    mat_t<2,2,Type> adj_sigma_mat = mat_t<2,2,Type>(adj_sigma[0], 0.0,
+                                                    0.0, adj_sigma[1]);
+
+    // Matrix for handling singular values (diagonal matrix with sigma values)
+    mat_t<2,2,Type> s_mat = mat_t<2,2,Type>(sigma[0], 0.0,
+                                            0.0, sigma[1]);
+
+    // Compute the transpose of U and V
+    mat_t<2,2,Type> UT = transpose(U);
+    mat_t<2,2,Type> VT = transpose(V);
+
+    // Compute the term for sigma (diagonal matrix of adjoint singular values)
+    mat_t<2,2,Type> sigma_term = mul(U, mul(adj_sigma_mat, VT));
+
+    // Compute the adjoint contributions for U (left singular vectors)
+    mat_t<2,2,Type> u_term = mul(mul(U, mul(cw_mul(F, (mul(UT, adj_U) - mul(transpose(adj_U), U))), s_mat)), VT);
+
+    // Compute the adjoint contributions for V (right singular vectors)
+    mat_t<2,2,Type> v_term = mul(U, mul(s_mat, mul(cw_mul(F, (mul(VT, adj_V) - mul(transpose(adj_V), V))), VT)));
+
+    // Combine the terms to compute the adjoint of A
+    adj_A = adj_A + (u_term + v_term + sigma_term);
+}
+
 
 template<typename Type>
 inline CUDA_CALLABLE void qr3(const mat_t<3,3,Type>& A, mat_t<3,3,Type>& Q, mat_t<3,3,Type>& R) {