warp-lang 1.6.1-py3-none-win_amd64.whl → 1.7.0-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- warp/__init__.py +21 -7
- warp/autograd.py +14 -6
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +424 -6
- warp/build_dll.py +20 -20
- warp/builtins.py +467 -368
- warp/codegen.py +193 -125
- warp/config.py +56 -12
- warp/constants.py +14 -6
- warp/context.py +524 -277
- warp/dlpack.py +22 -12
- warp/examples/__init__.py +14 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_api.py +14 -6
- warp/examples/benchmarks/benchmark_cloth.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
- warp/examples/benchmarks/benchmark_gemm.py +82 -48
- warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
- warp/examples/benchmarks/benchmark_launches.py +14 -6
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +14 -6
- warp/examples/core/example_cupy.py +14 -6
- warp/examples/core/example_dem.py +14 -6
- warp/examples/core/example_fluid.py +14 -6
- warp/examples/core/example_graph_capture.py +14 -6
- warp/examples/core/example_marching_cubes.py +14 -6
- warp/examples/core/example_mesh.py +14 -6
- warp/examples/core/example_mesh_intersect.py +14 -6
- warp/examples/core/example_nvdb.py +14 -6
- warp/examples/core/example_raycast.py +14 -6
- warp/examples/core/example_raymarch.py +14 -6
- warp/examples/core/example_render_opengl.py +14 -6
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +14 -6
- warp/examples/core/example_torch.py +14 -6
- warp/examples/core/example_wave.py +14 -6
- warp/examples/fem/example_adaptive_grid.py +14 -6
- warp/examples/fem/example_apic_fluid.py +15 -7
- warp/examples/fem/example_burgers.py +16 -8
- warp/examples/fem/example_convection_diffusion.py +14 -6
- warp/examples/fem/example_convection_diffusion_dg.py +14 -6
- warp/examples/fem/example_deformed_geometry.py +15 -7
- warp/examples/fem/example_diffusion.py +14 -6
- warp/examples/fem/example_diffusion_3d.py +14 -6
- warp/examples/fem/example_diffusion_mgpu.py +14 -6
- warp/examples/fem/example_distortion_energy.py +15 -7
- warp/examples/fem/example_magnetostatics.py +20 -12
- warp/examples/fem/example_mixed_elasticity.py +14 -6
- warp/examples/fem/example_navier_stokes.py +14 -6
- warp/examples/fem/example_nonconforming_contact.py +14 -6
- warp/examples/fem/example_stokes.py +14 -6
- warp/examples/fem/example_stokes_transfer.py +14 -6
- warp/examples/fem/example_streamlines.py +14 -6
- warp/examples/fem/utils.py +24 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +14 -6
- warp/examples/optim/example_cloth_throw.py +14 -6
- warp/examples/optim/example_diffray.py +14 -6
- warp/examples/optim/example_drone.py +14 -6
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +14 -6
- warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
- warp/examples/optim/example_softbody_properties.py +14 -6
- warp/examples/optim/example_spring_cage.py +14 -6
- warp/examples/optim/example_trajectory.py +14 -6
- warp/examples/sim/example_cartpole.py +14 -6
- warp/examples/sim/example_cloth.py +14 -6
- warp/examples/sim/example_cloth_self_contact.py +14 -6
- warp/examples/sim/example_granular.py +14 -6
- warp/examples/sim/example_granular_collision_sdf.py +14 -6
- warp/examples/sim/example_jacobian_ik.py +14 -6
- warp/examples/sim/example_particle_chain.py +14 -6
- warp/examples/sim/example_quadruped.py +14 -6
- warp/examples/sim/example_rigid_chain.py +14 -6
- warp/examples/sim/example_rigid_contact.py +14 -6
- warp/examples/sim/example_rigid_force.py +14 -6
- warp/examples/sim/example_rigid_gyroscopic.py +14 -6
- warp/examples/sim/example_rigid_soft_contact.py +14 -6
- warp/examples/sim/example_soft_body.py +14 -6
- warp/examples/tile/example_tile_cholesky.py +14 -6
- warp/examples/tile/example_tile_convolution.py +14 -6
- warp/examples/tile/example_tile_fft.py +14 -6
- warp/examples/tile/example_tile_filtering.py +14 -6
- warp/examples/tile/example_tile_matmul.py +16 -10
- warp/examples/tile/example_tile_mlp.py +14 -6
- warp/examples/tile/example_tile_nbody.py +14 -6
- warp/examples/tile/example_tile_walker.py +14 -6
- warp/fabric.py +15 -0
- warp/fem/__init__.py +26 -1
- warp/fem/adaptivity.py +19 -4
- warp/fem/cache.py +15 -0
- warp/fem/dirichlet.py +15 -0
- warp/fem/domain.py +15 -0
- warp/fem/field/__init__.py +15 -0
- warp/fem/field/field.py +15 -0
- warp/fem/field/nodal_field.py +37 -68
- warp/fem/field/restriction.py +15 -0
- warp/fem/field/virtual.py +77 -23
- warp/fem/geometry/__init__.py +15 -0
- warp/fem/geometry/adaptive_nanogrid.py +24 -10
- warp/fem/geometry/closest_point.py +16 -1
- warp/fem/geometry/deformed_geometry.py +20 -2
- warp/fem/geometry/element.py +15 -0
- warp/fem/geometry/geometry.py +20 -0
- warp/fem/geometry/grid_2d.py +27 -12
- warp/fem/geometry/grid_3d.py +27 -15
- warp/fem/geometry/hexmesh.py +20 -7
- warp/fem/geometry/nanogrid.py +24 -11
- warp/fem/geometry/partition.py +15 -0
- warp/fem/geometry/quadmesh.py +28 -13
- warp/fem/geometry/tetmesh.py +18 -4
- warp/fem/geometry/trimesh.py +18 -8
- warp/fem/integrate.py +277 -93
- warp/fem/linalg.py +20 -5
- warp/fem/operator.py +15 -0
- warp/fem/polynomial.py +15 -0
- warp/fem/quadrature/__init__.py +15 -0
- warp/fem/quadrature/pic_quadrature.py +52 -22
- warp/fem/quadrature/quadrature.py +209 -25
- warp/fem/space/__init__.py +16 -1
- warp/fem/space/basis_function_space.py +19 -2
- warp/fem/space/basis_space.py +40 -18
- warp/fem/space/dof_mapper.py +15 -0
- warp/fem/space/function_space.py +15 -0
- warp/fem/space/grid_2d_function_space.py +15 -0
- warp/fem/space/grid_3d_function_space.py +15 -0
- warp/fem/space/hexmesh_function_space.py +17 -2
- warp/fem/space/nanogrid_function_space.py +15 -0
- warp/fem/space/partition.py +21 -2
- warp/fem/space/quadmesh_function_space.py +23 -8
- warp/fem/space/restriction.py +15 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +38 -23
- warp/fem/space/shape/shape_function.py +15 -0
- warp/fem/space/shape/square_shape_function.py +27 -12
- warp/fem/space/shape/tet_shape_function.py +15 -0
- warp/fem/space/shape/triangle_shape_function.py +16 -1
- warp/fem/space/tetmesh_function_space.py +18 -3
- warp/fem/space/topology.py +15 -0
- warp/fem/space/trimesh_function_space.py +17 -2
- warp/fem/types.py +15 -0
- warp/fem/utils.py +27 -6
- warp/jax.py +28 -7
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +103 -6
- warp/native/array.h +28 -6
- warp/native/builtin.h +44 -9
- warp/native/bvh.cpp +18 -7
- warp/native/bvh.cu +57 -20
- warp/native/bvh.h +17 -7
- warp/native/clang/clang.cpp +45 -9
- warp/native/coloring.cpp +15 -6
- warp/native/crt.cpp +15 -6
- warp/native/crt.h +15 -6
- warp/native/cuda_crt.h +15 -6
- warp/native/cuda_util.cpp +29 -6
- warp/native/cuda_util.h +17 -6
- warp/native/error.cpp +15 -6
- warp/native/error.h +15 -6
- warp/native/exports.h +85 -63
- warp/native/fabric.h +15 -6
- warp/native/hashgrid.cpp +15 -6
- warp/native/hashgrid.cu +15 -6
- warp/native/hashgrid.h +15 -6
- warp/native/initializer_array.h +15 -6
- warp/native/intersect.h +41 -32
- warp/native/intersect_adj.h +48 -39
- warp/native/intersect_tri.h +17 -0
- warp/native/marching.cpp +16 -0
- warp/native/marching.cu +16 -7
- warp/native/marching.h +17 -0
- warp/native/mat.h +528 -15
- warp/native/mathdx.cpp +15 -6
- warp/native/matnn.h +15 -6
- warp/native/mesh.cpp +15 -6
- warp/native/mesh.cu +15 -6
- warp/native/mesh.h +25 -16
- warp/native/noise.h +15 -6
- warp/native/quat.h +114 -17
- warp/native/rand.h +21 -6
- warp/native/range.h +15 -6
- warp/native/reduce.cpp +15 -6
- warp/native/reduce.cu +15 -6
- warp/native/runlength_encode.cpp +15 -6
- warp/native/runlength_encode.cu +15 -6
- warp/native/scan.cpp +15 -6
- warp/native/scan.cu +15 -6
- warp/native/scan.h +15 -6
- warp/native/solid_angle.h +17 -0
- warp/native/sort.cpp +137 -65
- warp/native/sort.cu +167 -21
- warp/native/sort.h +23 -7
- warp/native/sparse.cpp +58 -28
- warp/native/sparse.cu +67 -23
- warp/native/spatial.h +15 -6
- warp/native/svd.h +131 -6
- warp/native/temp_buffer.h +15 -6
- warp/native/tile.h +316 -111
- warp/native/tile_reduce.h +61 -9
- warp/native/vec.h +83 -13
- warp/native/volume.cpp +100 -119
- warp/native/volume.cu +15 -6
- warp/native/volume.h +15 -6
- warp/native/volume_builder.cu +40 -16
- warp/native/volume_builder.h +21 -6
- warp/native/volume_impl.h +15 -6
- warp/native/warp.cpp +20 -12
- warp/native/warp.cu +114 -16
- warp/native/warp.h +34 -16
- warp/optim/__init__.py +14 -6
- warp/optim/adam.py +14 -6
- warp/optim/linear.py +25 -10
- warp/optim/sgd.py +14 -6
- warp/paddle.py +14 -6
- warp/render/__init__.py +14 -6
- warp/render/render_opengl.py +14 -6
- warp/render/render_usd.py +14 -6
- warp/render/utils.py +14 -6
- warp/sim/__init__.py +14 -7
- warp/sim/articulation.py +18 -10
- warp/sim/collide.py +35 -16
- warp/sim/graph_coloring.py +14 -6
- warp/sim/import_mjcf.py +463 -162
- warp/sim/import_snu.py +14 -7
- warp/sim/import_urdf.py +46 -18
- warp/sim/import_usd.py +14 -7
- warp/sim/inertia.py +14 -6
- warp/sim/integrator.py +14 -6
- warp/sim/integrator_euler.py +19 -11
- warp/sim/integrator_featherstone.py +17 -16
- warp/sim/integrator_vbd.py +222 -8
- warp/sim/integrator_xpbd.py +19 -11
- warp/sim/model.py +56 -19
- warp/sim/particles.py +14 -6
- warp/sim/render.py +14 -6
- warp/sim/utils.py +17 -2
- warp/sparse.py +657 -555
- warp/stubs.py +231 -19
- warp/tape.py +14 -6
- warp/tests/aux_test_class_kernel.py +14 -6
- warp/tests/aux_test_compile_consts_dummy.py +14 -6
- warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
- warp/tests/aux_test_dependent.py +14 -6
- warp/tests/aux_test_grad_customs.py +14 -6
- warp/tests/aux_test_instancing_gc.py +14 -6
- warp/tests/aux_test_module_unload.py +14 -6
- warp/tests/aux_test_name_clash1.py +14 -6
- warp/tests/aux_test_name_clash2.py +14 -6
- warp/tests/aux_test_unresolved_func.py +14 -6
- warp/tests/aux_test_unresolved_symbol.py +14 -6
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_async.py → cuda/test_async.py} +14 -6
- warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
- warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
- warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
- warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
- warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
- warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
- warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
- warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
- warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
- warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
- warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
- warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
- warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
- warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
- warp/tests/run_coverage_serial.py +14 -6
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
- warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
- warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
- warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
- warp/tests/{test_model.py → sim/test_model.py} +55 -7
- warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +14 -6
- warp/tests/test_arithmetic.py +14 -6
- warp/tests/test_array.py +14 -6
- warp/tests/test_array_reduce.py +14 -6
- warp/tests/test_assert.py +14 -6
- warp/tests/test_atomic.py +14 -6
- warp/tests/test_bool.py +15 -7
- warp/tests/test_builtins_resolution.py +14 -6
- warp/tests/test_closest_point_edge_edge.py +14 -6
- warp/tests/test_codegen.py +14 -6
- warp/tests/test_codegen_instancing.py +14 -6
- warp/tests/test_compile_consts.py +14 -6
- warp/tests/test_conditional.py +14 -6
- warp/tests/test_context.py +14 -6
- warp/tests/test_copy.py +14 -6
- warp/tests/test_ctypes.py +14 -6
- warp/tests/test_dense.py +14 -6
- warp/tests/test_devices.py +14 -6
- warp/tests/test_examples.py +42 -42
- warp/tests/test_fabricarray.py +14 -6
- warp/tests/test_fast_math.py +14 -6
- warp/tests/test_fem.py +37 -10
- warp/tests/test_fp16.py +14 -6
- warp/tests/test_func.py +14 -6
- warp/tests/test_future_annotations.py +14 -6
- warp/tests/test_generics.py +14 -6
- warp/tests/test_grad.py +14 -6
- warp/tests/test_grad_customs.py +14 -6
- warp/tests/test_grad_debug.py +14 -6
- warp/tests/test_implicit_init.py +14 -6
- warp/tests/test_import.py +14 -6
- warp/tests/test_indexedarray.py +14 -6
- warp/tests/test_intersect.py +14 -6
- warp/tests/test_iter.py +14 -6
- warp/tests/test_large.py +14 -6
- warp/tests/test_launch.py +14 -6
- warp/tests/test_lerp.py +14 -6
- warp/tests/test_linear_solvers.py +15 -11
- warp/tests/test_lvalue.py +14 -6
- warp/tests/test_mat.py +247 -85
- warp/tests/test_mat_lite.py +14 -6
- warp/tests/test_mat_scalar_ops.py +18 -10
- warp/tests/test_math.py +14 -6
- warp/tests/test_mlp.py +14 -6
- warp/tests/test_module_hashing.py +14 -6
- warp/tests/test_modules_lite.py +14 -6
- warp/tests/test_noise.py +14 -6
- warp/tests/test_operators.py +14 -6
- warp/tests/test_options.py +14 -6
- warp/tests/test_overwrite.py +15 -60
- warp/tests/test_print.py +14 -6
- warp/tests/test_quat.py +81 -52
- warp/tests/test_rand.py +58 -43
- warp/tests/test_reload.py +14 -6
- warp/tests/test_rounding.py +14 -6
- warp/tests/test_runlength_encode.py +14 -6
- warp/tests/test_scalar_ops.py +14 -6
- warp/tests/test_smoothstep.py +14 -6
- warp/tests/test_snippet.py +15 -0
- warp/tests/test_sparse.py +61 -12
- warp/tests/test_spatial.py +89 -6
- warp/tests/test_special_values.py +14 -6
- warp/tests/test_static.py +15 -7
- warp/tests/test_struct.py +14 -6
- warp/tests/test_tape.py +14 -6
- warp/tests/test_transient_module.py +14 -6
- warp/tests/test_triangle_closest_point.py +14 -6
- warp/tests/test_types.py +14 -6
- warp/tests/test_utils.py +98 -10
- warp/tests/test_vec.py +60 -40
- warp/tests/test_vec_lite.py +14 -6
- warp/tests/test_vec_scalar_ops.py +14 -6
- warp/tests/test_verify_fp.py +14 -6
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
- warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
- warp/tests/unittest_serial.py +15 -6
- warp/tests/unittest_suites.py +59 -65
- warp/tests/unittest_utils.py +16 -7
- warp/tests/walkthrough_debug.py +14 -6
- warp/thirdparty/unittest_parallel.py +15 -8
- warp/torch.py +14 -6
- warp/types.py +124 -664
- warp/utils.py +151 -78
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp/examples/optim/example_walker.py +0 -309
- warp/native/cutlass_gemm.cpp +0 -34
- warp/native/cutlass_gemm.cu +0 -373
- warp/tests/test_matmul.py +0 -503
- warp/tests/test_matmul_lite.py +0 -403
- warp/tests/test_vbd.py +0 -378
- warp/tests/unused_test_misc.py +0 -69
- warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
- warp_lang-1.6.1.dist-info/RECORD +0 -419
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/native/sparse.cpp
CHANGED
@@ -1,9 +1,18 @@
-
- * NVIDIA CORPORATION
- *
- *
- *
- *
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
  */
 
 #include "warp.h"
@@ -72,7 +81,8 @@ template <typename T> void bsr_dyn_block_transpose(const T* src, T* dest, int ro
 template <typename T>
 int bsr_matrix_from_triplets_host(const int rows_per_block, const int cols_per_block, const int row_count,
                                   const int nnz, const int* tpl_rows, const int* tpl_columns, const T* tpl_values,
-                                  const bool prune_numerical_zeros,
+                                  const bool prune_numerical_zeros, const bool masked, int* bsr_offsets,
+                                  int* bsr_columns, T* bsr_values)
 {
 
     // get specialized accumulator for common block sizes (1,1), (1,2), (1,3),
@@ -115,14 +125,33 @@ int bsr_matrix_from_triplets_host(const int rows_per_block, const int cols_per_b
     std::iota(block_indices.begin(), block_indices.end(), 0);
 
     // remove zero blocks and invalid row indices
-
-
-
-
-
-
-
-
+
+    auto discard_block = [&](int i)
+    {
+        const int row = tpl_rows[i];
+        if (row < 0 || row >= row_count)
+        {
+            return true;
+        }
+
+        if (prune_numerical_zeros && tpl_values && block_is_zero_func(tpl_values + i * block_size, block_size))
+        {
+            return true;
+        }
+
+        if (!masked)
+        {
+            return false;
+        }
+
+        const int* beg = bsr_columns + bsr_offsets[row];
+        const int* end = bsr_columns + bsr_offsets[row + 1];
+        const int col = tpl_columns[i];
+        const int* block = std::lower_bound(beg, end, col);
+        return block == end || *block != col;
+    };
+
+    block_indices.erase(std::remove_if(block_indices.begin(), block_indices.end(), discard_block), block_indices.end());
 
     // sort block indices according to lexico order
     std::sort(block_indices.begin(), block_indices.end(), [tpl_rows, tpl_columns](int i, int j) -> bool
@@ -272,12 +301,12 @@ void bsr_transpose_host(int rows_per_block, int cols_per_block, int row_count, i
 
 WP_API void bsr_matrix_from_triplets_float_host(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                                 int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                                bool prune_numerical_zeros,
-                                                void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+                                                bool prune_numerical_zeros, bool masked, int* bsr_offsets,
+                                                int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
     bsr_matrix_from_triplets_host<float>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
-                                         static_cast<const float*>(tpl_values), prune_numerical_zeros,
-                                         bsr_columns, static_cast<float*>(bsr_values));
+                                         static_cast<const float*>(tpl_values), prune_numerical_zeros, masked,
+                                         bsr_offsets, bsr_columns, static_cast<float*>(bsr_values));
     if (bsr_nnz)
     {
         *bsr_nnz = bsr_offsets[row_count];
@@ -286,12 +315,12 @@ WP_API void bsr_matrix_from_triplets_float_host(int rows_per_block, int cols_per
 
 WP_API void bsr_matrix_from_triplets_double_host(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                                  int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                                 bool prune_numerical_zeros,
-                                                 void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+                                                 bool prune_numerical_zeros, bool masked, int* bsr_offsets,
+                                                 int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
     bsr_matrix_from_triplets_host<double>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
-                                          static_cast<const double*>(tpl_values), prune_numerical_zeros,
-                                          bsr_columns, static_cast<double*>(bsr_values));
+                                          static_cast<const double*>(tpl_values), prune_numerical_zeros, masked,
+                                          bsr_offsets, bsr_columns, static_cast<double*>(bsr_values));
     if (bsr_nnz)
     {
         *bsr_nnz = bsr_offsets[row_count];
@@ -318,16 +347,17 @@ WP_API void bsr_transpose_double_host(int rows_per_block, int cols_per_block, in
 
 #if !WP_ENABLE_CUDA
 WP_API void bsr_matrix_from_triplets_float_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
-
-
-
+                                                  int* tpl_rows, int* tpl_columns, void* tpl_values,
+                                                  bool prune_numerical_zeros, bool masked, int* bsr_offsets,
+                                                  int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
 }
 
 WP_API void bsr_matrix_from_triplets_double_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                                    int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                                   bool prune_numerical_zeros,
-                                                   void* bsr_values, int* bsr_nnz,
+                                                   bool prune_numerical_zeros, bool masked, int* bsr_offsets,
+                                                   int* bsr_columns, void* bsr_values, int* bsr_nnz,
+                                                   void* bsr_nnz_event)
 {
 }
 
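The substantive change in this file is a new masked assembly mode: `bsr_matrix_from_triplets_host` now takes `masked`, `bsr_offsets`, and `bsr_columns` parameters, and when `masked` is true it discards any triplet whose (row, column) block is not already present in the existing sparsity pattern. Because the column indices within each row segment are sorted, the membership test reduces to a binary search. A minimal host-side sketch of that test (the `in_mask` helper and `main` driver are illustrative, not Warp API):

// Minimal sketch of the masked-mode membership test: a triplet block
// (row, col) survives only if it already exists in the target pattern
// described by CSR-style offsets/columns arrays.
#include <algorithm>
#include <cstdio>
#include <vector>

bool in_mask(const std::vector<int>& bsr_offsets, const std::vector<int>& bsr_columns, int row, int col)
{
    // columns are sorted within each row segment, so a binary search suffices
    const int* beg = bsr_columns.data() + bsr_offsets[row];
    const int* end = bsr_columns.data() + bsr_offsets[row + 1];
    const int* block = std::lower_bound(beg, end, col);
    return block != end && *block == col;
}

int main()
{
    // 2-row pattern: row 0 holds block columns {0, 2}, row 1 holds {1}
    const std::vector<int> bsr_offsets = {0, 2, 3};
    const std::vector<int> bsr_columns = {0, 2, 1};

    std::printf("%d\n", in_mask(bsr_offsets, bsr_columns, 0, 2)); // 1 -> kept
    std::printf("%d\n", in_mask(bsr_offsets, bsr_columns, 1, 2)); // 0 -> discarded
    return 0;
}

This keeps the masked path at O(log k) per triplet for a row with k existing blocks, rather than rebuilding or re-merging the pattern.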
warp/native/sparse.cu
CHANGED
@@ -1,9 +1,18 @@
-
- * NVIDIA CORPORATION
- *
- *
- *
- *
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
  */
 
 #include "cuda_util.h"
@@ -52,10 +61,41 @@ template <typename T> struct BsrBlockIsNotZero
     }
 };
 
+struct BsrBlockInMask
+{
+    const int* bsr_offsets;
+    const int* bsr_columns;
+
+    CUDA_CALLABLE_DEVICE bool operator()(int row, int col) const
+    {
+        if (bsr_offsets == nullptr)
+            return true;
+
+        int lower = bsr_offsets[row];
+        int upper = bsr_offsets[row + 1] - 1;
+
+        while (lower < upper)
+        {
+            const int mid = lower + (upper - lower) / 2;
+
+            if (bsr_columns[mid] < col)
+            {
+                lower = mid + 1;
+            }
+            else
+            {
+                upper = mid;
+            }
+        }
+
+        return lower == upper && (bsr_columns[lower] == col);
+    }
+};
+
 template <typename T>
 __global__ void bsr_fill_triplet_key_values(const int nnz, const int nrow, const int* tpl_rows, const int* tpl_columns,
-                                            const BsrBlockIsNotZero<T> nonZero,
-                                            BsrRowCol* tpl_row_col)
+                                            const BsrBlockIsNotZero<T> nonZero, const BsrBlockInMask mask,
+                                            uint32_t* block_indices, BsrRowCol* tpl_row_col)
 {
     int block = blockIdx.x * blockDim.x + threadIdx.x;
     if (block >= nnz)
@@ -65,7 +105,8 @@ __global__ void bsr_fill_triplet_key_values(const int nnz, const int nrow, const
     const int col = tpl_columns[block];
     const bool is_valid = row >= 0 && row < nrow;
 
-    const BsrRowCol row_col =
+    const BsrRowCol row_col =
+        is_valid && nonZero(block) && mask(row, col) ? bsr_combine_row_col(row, col) : PRUNED_ROWCOL;
     tpl_row_col[block] = row_col;
     block_indices[block] = block;
 }
@@ -113,7 +154,7 @@ __global__ void bsr_find_row_offsets(uint32_t row_count, const T* d_nnz, const B
 }
 
 template <typename T>
-__global__ void bsr_merge_blocks(const
+__global__ void bsr_merge_blocks(const int* d_nnz, int block_size, const uint32_t* block_offsets,
                                  const uint32_t* sorted_block_indices, const BsrRowCol* unique_row_cols,
                                  const T* tpl_values, int* bsr_cols, T* bsr_values)
 
@@ -154,8 +195,8 @@ __global__ void bsr_merge_blocks(const uint32_t* d_nnz, int block_size, const ui
 template <typename T>
 void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_per_block, const int row_count,
                                      const int nnz, const int* tpl_rows, const int* tpl_columns, const T* tpl_values,
-                                     const bool prune_numerical_zeros,
-                                     T* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+                                     const bool prune_numerical_zeros, const bool masked, int* bsr_offsets,
+                                     int* bsr_columns, T* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
     const int block_size = rows_per_block * cols_per_block;
 
@@ -177,8 +218,9 @@ void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_pe
 
     // Combine rows and columns so we can sort on them both
     BsrBlockIsNotZero<T> isNotZero{block_size, prune_numerical_zeros ? tpl_values : nullptr};
+    BsrBlockInMask mask{masked ? bsr_offsets : nullptr, bsr_columns};
     wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_triplet_key_values, nnz,
-                     (nnz, row_count, tpl_rows, tpl_columns, isNotZero, d_keys.Current(), d_values.Current()));
+                     (nnz, row_count, tpl_rows, tpl_columns, isNotZero, mask, d_keys.Current(), d_values.Current()));
 
     // Sort
     {
@@ -205,7 +247,7 @@ void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_pe
 
     if (bsr_nnz)
    {
-        // Copy nnz to host, and record an event for the
+        // Copy nnz to host, and record an event for the completed transfer if desired
 
         memcpy_d2h(WP_CURRENT_CONTEXT, bsr_nnz, bsr_offsets + row_count, sizeof(int), stream);
 
@@ -227,7 +269,7 @@ void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_pe
 
     // Accumulate repeated blocks and set column indices
     wp_launch_device(WP_CURRENT_CONTEXT, bsr_merge_blocks, nnz,
-                     (
+                     (bsr_offsets + row_count, block_size, d_keys.Alternate(), d_keys.Current(), d_values.Alternate(),
                       tpl_values, bsr_columns, bsr_values));
 }
 
@@ -443,22 +485,24 @@ void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
 
 void bsr_matrix_from_triplets_float_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                            int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                           bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
+                                           bool prune_numerical_zeros, bool masked, int* bsr_offsets, int* bsr_columns,
                                            void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
-    return bsr_matrix_from_triplets_device<float>(
-
-
+    return bsr_matrix_from_triplets_device<float>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
+                                                  static_cast<const float*>(tpl_values), prune_numerical_zeros, masked,
+                                                  bsr_offsets, bsr_columns, static_cast<float*>(bsr_values), bsr_nnz,
+                                                  bsr_nnz_event);
 }
 
 void bsr_matrix_from_triplets_double_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                             int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                            bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
+                                            bool prune_numerical_zeros, bool masked, int* bsr_offsets, int* bsr_columns,
                                             void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
-    return bsr_matrix_from_triplets_device<double>(
-
-
+    return bsr_matrix_from_triplets_device<double>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows,
                                                    tpl_columns, static_cast<const double*>(tpl_values),
+                                                   prune_numerical_zeros, masked, bsr_offsets, bsr_columns,
+                                                   static_cast<double*>(bsr_values), bsr_nnz, bsr_nnz_event);
 }
 
 void bsr_transpose_float_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
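The CUDA path mirrors the host change by threading a `BsrBlockInMask` functor through the triplet-fill kernel; its hand-rolled while-loop is a lower-bound binary search written out explicitly rather than relying on `std::lower_bound` in device code. A small host-side harness, assuming only that `bsr_columns` is sorted within each row segment (the harness itself is hypothetical, not part of the package), cross-checks the loop against `std::lower_bound`:

// Host-side sketch cross-checking the functor's loop against std::lower_bound.
#include <algorithm>
#include <cassert>
#include <vector>

bool mask_contains(const std::vector<int>& offsets, const std::vector<int>& columns, int row, int col)
{
    int lower = offsets[row];
    int upper = offsets[row + 1] - 1; // an empty row leaves lower > upper

    while (lower < upper)
    {
        const int mid = lower + (upper - lower) / 2; // overflow-safe midpoint
        if (columns[mid] < col)
            lower = mid + 1;
        else
            upper = mid;
    }
    return lower == upper && columns[lower] == col;
}

int main()
{
    const std::vector<int> offsets = {0, 3, 3, 5}; // row 1 is empty
    const std::vector<int> columns = {1, 4, 7, 0, 9};

    for (int row = 0; row < 3; ++row)
        for (int col = 0; col < 10; ++col)
        {
            const auto beg = columns.begin() + offsets[row];
            const auto end = columns.begin() + offsets[row + 1];
            const auto it = std::lower_bound(beg, end, col);
            assert(mask_contains(offsets, columns, row, col) == (it != end && *it == col));
        }
    return 0;
}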
warp/native/spatial.h
CHANGED
@@ -1,9 +1,18 @@
-
- * NVIDIA CORPORATION
- *
- *
- *
- *
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
  */
 
 #pragma once
warp/native/svd.h
CHANGED
@@ -1,9 +1,18 @@
-
- * NVIDIA CORPORATION
- *
- *
- *
- *
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
  */
 
 // The MIT License (MIT)
@@ -423,6 +432,62 @@ void _svd(// input A
     );
 }
 
+
+template<typename Type>
+inline CUDA_CALLABLE
+void _svd_2(// input A
+            Type a11, Type a12,
+            Type a21, Type a22,
+            // output U
+            Type &u11, Type &u12,
+            Type &u21, Type &u22,
+            // output S
+            Type &s11, Type &s12,
+            Type &s21, Type &s22,
+            // output V
+            Type &v11, Type &v12,
+            Type &v21, Type &v22)
+{
+    // Step 1: Compute ATA
+    Type ATA11 = a11 * a11 + a21 * a21;
+    Type ATA12 = a11 * a12 + a21 * a22;
+    Type ATA22 = a12 * a12 + a22 * a22;
+
+    // Step 2: Eigenanalysis
+    Type trace = ATA11 + ATA22;
+    Type det = ATA11 * ATA22 - ATA12 * ATA12;
+    Type sqrt_term = sqrt(trace * trace - Type(4.0) * det);
+    Type lambda1 = (trace + sqrt_term) * Type(0.5);
+    Type lambda2 = (trace - sqrt_term) * Type(0.5);
+
+    // Step 3: Singular values
+    Type sigma1 = sqrt(lambda1);
+    Type sigma2 = sqrt(lambda2);
+
+    // Step 4: Eigenvectors (find V)
+    Type v1x = ATA12, v1y = lambda1 - ATA11; // For first eigenvector
+    Type v2x = ATA12, v2y = lambda2 - ATA11; // For second eigenvector
+    Type norm1 = sqrt(v1x * v1x + v1y * v1y);
+    Type norm2 = sqrt(v2x * v2x + v2y * v2y);
+
+    v11 = v1x / norm1; v12 = v2x / norm2;
+    v21 = v1y / norm1; v22 = v2y / norm2;
+
+    // Step 5: Compute U
+    Type inv_sigma1 = (sigma1 > Type(1e-6)) ? Type(1.0) / sigma1 : Type(0.0);
+    Type inv_sigma2 = (sigma2 > Type(1e-6)) ? Type(1.0) / sigma2 : Type(0.0);
+
+    u11 = (a11 * v11 + a12 * v21) * inv_sigma1;
+    u12 = (a11 * v12 + a12 * v22) * inv_sigma2;
+    u21 = (a21 * v11 + a22 * v21) * inv_sigma1;
+    u22 = (a21 * v12 + a22 * v22) * inv_sigma2;
+
+    // Step 6: Set S
+    s11 = sigma1; s12 = Type(0.0);
+    s21 = Type(0.0); s22 = sigma2;
+}
+
+
 template<typename Type>
 inline CUDA_CALLABLE void svd3(const mat_t<3,3,Type>& A, mat_t<3,3,Type>& U, vec_t<3,Type>& sigma, mat_t<3,3,Type>& V) {
     Type s12, s13, s21, s23, s31, s32;
@@ -483,6 +548,66 @@ inline CUDA_CALLABLE void adj_svd3(const mat_t<3,3,Type>& A,
     adj_A = adj_A + (u_term + v_term + sigma_term);
 }
 
+template<typename Type>
+inline CUDA_CALLABLE void svd2(const mat_t<2,2,Type>& A, mat_t<2,2,Type>& U, vec_t<2,Type>& sigma, mat_t<2,2,Type>& V) {
+    Type s12, s21;
+    _svd_2(A.data[0][0], A.data[0][1],
+           A.data[1][0], A.data[1][1],
+
+           U.data[0][0], U.data[0][1],
+           U.data[1][0], U.data[1][1],
+
+           sigma[0], s12,
+           s21, sigma[1],
+
+           V.data[0][0], V.data[0][1],
+           V.data[1][0], V.data[1][1]);
+}
+
+template<typename Type>
+inline CUDA_CALLABLE void adj_svd2(const mat_t<2,2,Type>& A,
+                                   const mat_t<2,2,Type>& U,
+                                   const vec_t<2,Type>& sigma,
+                                   const mat_t<2,2,Type>& V,
+                                   mat_t<2,2,Type>& adj_A,
+                                   const mat_t<2,2,Type>& adj_U,
+                                   const vec_t<2,Type>& adj_sigma,
+                                   const mat_t<2,2,Type>& adj_V) {
+    Type s1_squared = sigma[0] * sigma[0];
+    Type s2_squared = sigma[1] * sigma[1];
+
+    // Compute inverse of (s1^2 - s2^2) if possible, use small epsilon to prevent division by zero
+    Type F01 = Type(1) / min(s2_squared - s1_squared, Type(-1e-6f));
+
+    // Construct the matrix F for the adjoint
+    mat_t<2,2,Type> F = mat_t<2,2,Type>(0.0, F01,
+                                        -F01, 0.0);
+
+    // Create a matrix to handle the adjoint of the singular values (diagonal matrix)
+    mat_t<2,2,Type> adj_sigma_mat = mat_t<2,2,Type>(adj_sigma[0], 0.0,
+                                                    0.0, adj_sigma[1]);
+
+    // Matrix for handling singular values (diagonal matrix with sigma values)
+    mat_t<2,2,Type> s_mat = mat_t<2,2,Type>(sigma[0], 0.0,
+                                            0.0, sigma[1]);
+
+    // Compute the transpose of U and V
+    mat_t<2,2,Type> UT = transpose(U);
+    mat_t<2,2,Type> VT = transpose(V);
+
+    // Compute the term for sigma (diagonal matrix of adjoint singular values)
+    mat_t<2,2,Type> sigma_term = mul(U, mul(adj_sigma_mat, VT));
+
+    // Compute the adjoint contributions for U (left singular vectors)
+    mat_t<2,2,Type> u_term = mul(mul(U, mul(cw_mul(F, (mul(UT, adj_U) - mul(transpose(adj_U), U))), s_mat)), VT);
+
+    // Compute the adjoint contributions for V (right singular vectors)
+    mat_t<2,2,Type> v_term = mul(U, mul(s_mat, mul(cw_mul(F, (mul(VT, adj_V) - mul(transpose(adj_V), V))), VT)));
+
+    // Combine the terms to compute the adjoint of A
+    adj_A = adj_A + (u_term + v_term + sigma_term);
+}
+
 
 template<typename Type>
 inline CUDA_CALLABLE void qr3(const mat_t<3,3,Type>& A, mat_t<3,3,Type>& Q, mat_t<3,3,Type>& R) {
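The new `_svd_2`/`svd2`/`adj_svd2` trio adds a closed-form 2x2 SVD: form A^T A, take its eigenvalues from the trace and determinant, use the normalized eigenvectors as the columns of V, and recover U = A V S^-1; `adj_svd2` then reuses the same adjoint structure as `adj_svd3` with the single 2x2 coupling term F01 = 1/(sigma2^2 - sigma1^2). A standalone double-precision sketch of the same route (the `svd2x2` name and `main` check are illustrative; it assumes A^T A is not already diagonal and the singular values are distinct, the degenerate cases the header instead guards with small epsilons):

#include <cmath>
#include <cstdio>

// Same route as _svd_2: eigen-decompose A^T A, then U = A * V * S^{-1}.
void svd2x2(double a11, double a12, double a21, double a22,
            double U[2][2], double S[2], double V[2][2])
{
    const double m11 = a11 * a11 + a21 * a21; // A^T A, symmetric 2x2
    const double m12 = a11 * a12 + a21 * a22;
    const double m22 = a12 * a12 + a22 * a22;

    // closed-form eigenvalues from trace and determinant
    const double tr = m11 + m22;
    const double det = m11 * m22 - m12 * m12;
    const double disc = std::sqrt(std::fmax(tr * tr - 4.0 * det, 0.0));
    const double l1 = 0.5 * (tr + disc);
    const double l2 = 0.5 * (tr - disc);

    // singular values are the square roots of the eigenvalues
    S[0] = std::sqrt(l1);
    S[1] = std::sqrt(std::fmax(l2, 0.0));

    // normalized eigenvectors of A^T A form the columns of V
    const double n1 = std::hypot(m12, l1 - m11);
    const double n2 = std::hypot(m12, l2 - m11);
    V[0][0] = m12 / n1;        V[0][1] = m12 / n2;
    V[1][0] = (l1 - m11) / n1; V[1][1] = (l2 - m11) / n2;

    // U = A V S^{-1}, column by column
    U[0][0] = (a11 * V[0][0] + a12 * V[1][0]) / S[0];
    U[1][0] = (a21 * V[0][0] + a22 * V[1][0]) / S[0];
    U[0][1] = (a11 * V[0][1] + a12 * V[1][1]) / S[1];
    U[1][1] = (a21 * V[0][1] + a22 * V[1][1]) / S[1];
}

int main()
{
    const double A[2][2] = {{3.0, 1.0}, {2.0, 2.0}};
    double U[2][2], S[2], V[2][2];
    svd2x2(A[0][0], A[0][1], A[1][0], A[1][1], U, S, V);

    // check the reconstruction A = U * diag(S) * V^T
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
        {
            const double r = U[i][0] * S[0] * V[j][0] + U[i][1] * S[1] * V[j][1];
            std::printf("A[%d][%d] = %g, U S V^T = %g\n", i, j, A[i][j], r);
        }
    return 0;
}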
warp/native/temp_buffer.h
CHANGED
@@ -1,9 +1,18 @@
-
- * NVIDIA CORPORATION
- *
- *
- *
- *
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
  */
 
 #pragma once