warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +334 -0
- warp/__init__.pyi +5856 -0
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1077 -0
- warp/_src/build.py +620 -0
- warp/_src/build_dll.py +642 -0
- warp/_src/builtins.py +10555 -0
- warp/_src/codegen.py +4361 -0
- warp/_src/config.py +178 -0
- warp/_src/constants.py +59 -0
- warp/_src/context.py +8352 -0
- warp/_src/dlpack.py +464 -0
- warp/_src/fabric.py +362 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +510 -0
- warp/_src/fem/cache.py +689 -0
- warp/_src/fem/dirichlet.py +190 -0
- warp/_src/fem/domain.py +553 -0
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +703 -0
- warp/_src/fem/field/nodal_field.py +403 -0
- warp/_src/fem/field/restriction.py +39 -0
- warp/_src/fem/field/virtual.py +1021 -0
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
- warp/_src/fem/geometry/closest_point.py +99 -0
- warp/_src/fem/geometry/deformed_geometry.py +277 -0
- warp/_src/fem/geometry/element.py +854 -0
- warp/_src/fem/geometry/geometry.py +693 -0
- warp/_src/fem/geometry/grid_2d.py +478 -0
- warp/_src/fem/geometry/grid_3d.py +539 -0
- warp/_src/fem/geometry/hexmesh.py +956 -0
- warp/_src/fem/geometry/nanogrid.py +660 -0
- warp/_src/fem/geometry/partition.py +483 -0
- warp/_src/fem/geometry/quadmesh.py +597 -0
- warp/_src/fem/geometry/tetmesh.py +762 -0
- warp/_src/fem/geometry/trimesh.py +588 -0
- warp/_src/fem/integrate.py +2507 -0
- warp/_src/fem/linalg.py +385 -0
- warp/_src/fem/operator.py +398 -0
- warp/_src/fem/polynomial.py +231 -0
- warp/_src/fem/quadrature/__init__.py +17 -0
- warp/_src/fem/quadrature/pic_quadrature.py +318 -0
- warp/_src/fem/quadrature/quadrature.py +665 -0
- warp/_src/fem/space/__init__.py +248 -0
- warp/_src/fem/space/basis_function_space.py +499 -0
- warp/_src/fem/space/basis_space.py +681 -0
- warp/_src/fem/space/dof_mapper.py +253 -0
- warp/_src/fem/space/function_space.py +312 -0
- warp/_src/fem/space/grid_2d_function_space.py +179 -0
- warp/_src/fem/space/grid_3d_function_space.py +229 -0
- warp/_src/fem/space/hexmesh_function_space.py +255 -0
- warp/_src/fem/space/nanogrid_function_space.py +199 -0
- warp/_src/fem/space/partition.py +435 -0
- warp/_src/fem/space/quadmesh_function_space.py +222 -0
- warp/_src/fem/space/restriction.py +221 -0
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
- warp/_src/fem/space/shape/shape_function.py +134 -0
- warp/_src/fem/space/shape/square_shape_function.py +928 -0
- warp/_src/fem/space/shape/tet_shape_function.py +829 -0
- warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
- warp/_src/fem/space/tetmesh_function_space.py +270 -0
- warp/_src/fem/space/topology.py +461 -0
- warp/_src/fem/space/trimesh_function_space.py +193 -0
- warp/_src/fem/types.py +114 -0
- warp/_src/fem/utils.py +488 -0
- warp/_src/jax.py +188 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +389 -0
- warp/_src/jax_experimental/ffi.py +1286 -0
- warp/_src/jax_experimental/xla_ffi.py +658 -0
- warp/_src/marching_cubes.py +710 -0
- warp/_src/math.py +416 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +165 -0
- warp/_src/optim/linear.py +1608 -0
- warp/_src/optim/sgd.py +114 -0
- warp/_src/paddle.py +408 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +291 -0
- warp/_src/render/render_opengl.py +3638 -0
- warp/_src/render/render_usd.py +939 -0
- warp/_src/render/utils.py +162 -0
- warp/_src/sparse.py +2718 -0
- warp/_src/tape.py +1208 -0
- warp/_src/thirdparty/__init__.py +0 -0
- warp/_src/thirdparty/appdirs.py +598 -0
- warp/_src/thirdparty/dlpack.py +145 -0
- warp/_src/thirdparty/unittest_parallel.py +676 -0
- warp/_src/torch.py +393 -0
- warp/_src/types.py +5888 -0
- warp/_src/utils.py +1695 -0
- warp/autograd.py +33 -0
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +29 -0
- warp/build_dll.py +24 -0
- warp/codegen.py +24 -0
- warp/constants.py +24 -0
- warp/context.py +33 -0
- warp/dlpack.py +24 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +195 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +290 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/distributed/example_jacobi_mpi.py +506 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +469 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +181 -0
- warp/examples/fem/example_convection_diffusion_dg.py +225 -0
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +225 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +242 -0
- warp/examples/fem/example_mixed_elasticity.py +293 -0
- warp/examples/fem/example_navier_stokes.py +263 -0
- warp/examples/fem/example_nonconforming_contact.py +300 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +357 -0
- warp/examples/fem/utils.py +1047 -0
- warp/examples/interop/example_jax_callable.py +146 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +232 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +88 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/examples/tile/example_tile_mlp.py +385 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/fabric.py +24 -0
- warp/fem/__init__.py +173 -0
- warp/fem/adaptivity.py +26 -0
- warp/fem/cache.py +30 -0
- warp/fem/dirichlet.py +24 -0
- warp/fem/field/__init__.py +24 -0
- warp/fem/field/field.py +26 -0
- warp/fem/geometry/__init__.py +21 -0
- warp/fem/geometry/closest_point.py +31 -0
- warp/fem/linalg.py +38 -0
- warp/fem/operator.py +32 -0
- warp/fem/polynomial.py +29 -0
- warp/fem/space/__init__.py +22 -0
- warp/fem/space/basis_space.py +24 -0
- warp/fem/space/shape/__init__.py +68 -0
- warp/fem/space/topology.py +24 -0
- warp/fem/types.py +24 -0
- warp/fem/utils.py +32 -0
- warp/jax.py +29 -0
- warp/jax_experimental/__init__.py +29 -0
- warp/jax_experimental/custom_call.py +29 -0
- warp/jax_experimental/ffi.py +39 -0
- warp/jax_experimental/xla_ffi.py +24 -0
- warp/marching_cubes.py +24 -0
- warp/math.py +37 -0
- warp/native/array.h +1687 -0
- warp/native/builtin.h +2327 -0
- warp/native/bvh.cpp +562 -0
- warp/native/bvh.cu +826 -0
- warp/native/bvh.h +555 -0
- warp/native/clang/clang.cpp +541 -0
- warp/native/coloring.cpp +622 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +568 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +677 -0
- warp/native/cuda_util.h +313 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +2023 -0
- warp/native/fabric.h +246 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +89 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1253 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +348 -0
- warp/native/mat.h +5189 -0
- warp/native/mathdx.cpp +93 -0
- warp/native/matnn.h +221 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +406 -0
- warp/native/mesh.h +2097 -0
- warp/native/nanovdb/GridHandle.h +533 -0
- warp/native/nanovdb/HostBuffer.h +591 -0
- warp/native/nanovdb/NanoVDB.h +6246 -0
- warp/native/nanovdb/NodeManager.h +323 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1664 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +145 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +363 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +55 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +286 -0
- warp/native/sort.h +35 -0
- warp/native/sparse.cpp +241 -0
- warp/native/sparse.cu +435 -0
- warp/native/spatial.h +1306 -0
- warp/native/svd.h +727 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +4124 -0
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +838 -0
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +2199 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +68 -0
- warp/native/volume.h +970 -0
- warp/native/volume_builder.cu +483 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1143 -0
- warp/native/warp.cu +4604 -0
- warp/native/warp.h +358 -0
- warp/optim/__init__.py +20 -0
- warp/optim/adam.py +24 -0
- warp/optim/linear.py +35 -0
- warp/optim/sgd.py +24 -0
- warp/paddle.py +24 -0
- warp/py.typed +0 -0
- warp/render/__init__.py +22 -0
- warp/render/imgui_manager.py +29 -0
- warp/render/render_opengl.py +24 -0
- warp/render/render_usd.py +24 -0
- warp/render/utils.py +24 -0
- warp/sparse.py +51 -0
- warp/tape.py +24 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_aot.py +7 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_conditional_captures.py +1147 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +691 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +335 -0
- warp/tests/geometry/test_hash_grid.py +259 -0
- warp/tests/geometry/test_marching_cubes.py +294 -0
- warp/tests/geometry/test_mesh.py +318 -0
- warp/tests/geometry/test_mesh_query_aabb.py +392 -0
- warp/tests/geometry/test_mesh_query_point.py +935 -0
- warp/tests/geometry/test_mesh_query_ray.py +323 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +730 -0
- warp/tests/interop/test_jax.py +1673 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/test_adam.py +162 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +3756 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +303 -0
- warp/tests/test_atomic.py +336 -0
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +732 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +974 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +298 -0
- warp/tests/test_context.py +35 -0
- warp/tests/test_copy.py +319 -0
- warp/tests/test_ctypes.py +618 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +127 -0
- warp/tests/test_enum.py +136 -0
- warp/tests/test_examples.py +424 -0
- warp/tests/test_fabricarray.py +998 -0
- warp/tests/test_fast_math.py +72 -0
- warp/tests/test_fem.py +2204 -0
- warp/tests/test_fixedarray.py +229 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +501 -0
- warp/tests/test_future_annotations.py +100 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +103 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +223 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_map.py +526 -0
- warp/tests/test_mat.py +3515 -0
- warp/tests/test_mat_assign_copy.py +178 -0
- warp/tests/test_mat_constructors.py +573 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +212 -0
- warp/tests/test_module_aot.py +287 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +70 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +408 -0
- warp/tests/test_quat.py +2653 -0
- warp/tests/test_quat_assign_copy.py +145 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +303 -0
- warp/tests/test_rounding.py +157 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +133 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +845 -0
- warp/tests/test_spatial.py +2859 -0
- warp/tests/test_spatial_assign_copy.py +160 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +640 -0
- warp/tests/test_struct.py +901 -0
- warp/tests/test_tape.py +242 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +192 -0
- warp/tests/test_tuple.py +361 -0
- warp/tests/test_types.py +615 -0
- warp/tests/test_utils.py +594 -0
- warp/tests/test_vec.py +1408 -0
- warp/tests/test_vec_assign_copy.py +143 -0
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/test_version.py +75 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +1519 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +608 -0
- warp/tests/tile/test_tile_load.py +724 -0
- warp/tests/tile/test_tile_mathdx.py +156 -0
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_mlp.py +400 -0
- warp/tests/tile/test_tile_reduce.py +950 -0
- warp/tests/tile/test_tile_shared_memory.py +376 -0
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +430 -0
- warp/tests/unittest_utils.py +469 -0
- warp/tests/walkthrough_debug.py +95 -0
- warp/torch.py +24 -0
- warp/types.py +51 -0
- warp/utils.py +31 -0
- warp_lang-1.10.0.dist-info/METADATA +459 -0
- warp_lang-1.10.0.dist-info/RECORD +468 -0
- warp_lang-1.10.0.dist-info/WHEEL +5 -0
- warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp_lang-1.10.0.dist-info/top_level.txt +1 -0
warp/native/sort.cu
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
*/
|
|
17
|
+
|
|
18
|
+
#include "warp.h"
|
|
19
|
+
#include "cuda_util.h"
|
|
20
|
+
#include "sort.h"
|
|
21
|
+
|
|
22
|
+
#define THRUST_IGNORE_CUB_VERSION_CHECK
|
|
23
|
+
|
|
24
|
+
#include <cub/cub.cuh>
|
|
25
|
+
|
|
26
|
+
#include <unordered_map>
|
|
27
|
+
|
|
28
|
+
// temporary buffer for radix sort
|
|
29
|
+
// Describes one scratch allocation used as temporary storage by the sort routines.
struct RadixSortTemp
{
    // pointer to the scratch buffer; null until a buffer has been assigned
    void* mem = nullptr;
    // capacity of the scratch buffer in bytes
    size_t size = 0;
};
|
|
34
|
+
|
|
35
|
+
// Scratch buffers keyed by CUDA stream handle: every stream gets its own temp buffer so that sorts issued on different streams never share temporary storage (avoids race conditions). Entries are created on demand by the reserve functions and removed by radix_sort_release().
|
|
36
|
+
static std::unordered_map<void*, RadixSortTemp> g_radix_sort_temp_map;
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
template <typename KeyType>
|
|
40
|
+
void radix_sort_reserve_internal(void* context, int n, void** mem_out, size_t* size_out)
|
|
41
|
+
{
|
|
42
|
+
ContextGuard guard(context);
|
|
43
|
+
|
|
44
|
+
cub::DoubleBuffer<KeyType> d_keys;
|
|
45
|
+
cub::DoubleBuffer<int> d_values;
|
|
46
|
+
|
|
47
|
+
CUstream stream = static_cast<CUstream>(wp_cuda_stream_get_current());
|
|
48
|
+
|
|
49
|
+
// compute temporary memory required
|
|
50
|
+
size_t sort_temp_size;
|
|
51
|
+
check_cuda(cub::DeviceRadixSort::SortPairs(
|
|
52
|
+
NULL,
|
|
53
|
+
sort_temp_size,
|
|
54
|
+
d_keys,
|
|
55
|
+
d_values,
|
|
56
|
+
n, 0, sizeof(KeyType)*8,
|
|
57
|
+
stream));
|
|
58
|
+
|
|
59
|
+
RadixSortTemp& temp = g_radix_sort_temp_map[stream];
|
|
60
|
+
|
|
61
|
+
if (sort_temp_size > temp.size)
|
|
62
|
+
{
|
|
63
|
+
wp_free_device(WP_CURRENT_CONTEXT, temp.mem);
|
|
64
|
+
temp.mem = wp_alloc_device(WP_CURRENT_CONTEXT, sort_temp_size);
|
|
65
|
+
temp.size = sort_temp_size;
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
if (mem_out)
|
|
69
|
+
*mem_out = temp.mem;
|
|
70
|
+
if (size_out)
|
|
71
|
+
*size_out = temp.size;
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
|
|
75
|
+
{
|
|
76
|
+
radix_sort_reserve_internal<int>(context, n, mem_out, size_out);
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
void radix_sort_release(void* context, void* stream)
|
|
80
|
+
{
|
|
81
|
+
// release temporary buffer for the given stream, if it exists
|
|
82
|
+
auto it = g_radix_sort_temp_map.find(stream);
|
|
83
|
+
if (it != g_radix_sort_temp_map.end())
|
|
84
|
+
{
|
|
85
|
+
wp_free_device(context, it->second.mem);
|
|
86
|
+
g_radix_sort_temp_map.erase(it);
|
|
87
|
+
}
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
template <typename KeyType>
|
|
91
|
+
void radix_sort_pairs_device(void* context, KeyType* keys, int* values, int n)
|
|
92
|
+
{
|
|
93
|
+
ContextGuard guard(context);
|
|
94
|
+
|
|
95
|
+
cub::DoubleBuffer<KeyType> d_keys(keys, keys + n);
|
|
96
|
+
cub::DoubleBuffer<int> d_values(values, values + n);
|
|
97
|
+
|
|
98
|
+
RadixSortTemp temp;
|
|
99
|
+
radix_sort_reserve_internal<KeyType>(WP_CURRENT_CONTEXT, n, &temp.mem, &temp.size);
|
|
100
|
+
|
|
101
|
+
// sort
|
|
102
|
+
check_cuda(cub::DeviceRadixSort::SortPairs(
|
|
103
|
+
temp.mem,
|
|
104
|
+
temp.size,
|
|
105
|
+
d_keys,
|
|
106
|
+
d_values,
|
|
107
|
+
n, 0, sizeof(KeyType)*8,
|
|
108
|
+
(cudaStream_t)wp_cuda_stream_get_current()));
|
|
109
|
+
|
|
110
|
+
if (d_keys.Current() != keys)
|
|
111
|
+
wp_memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(KeyType)*n);
|
|
112
|
+
|
|
113
|
+
if (d_values.Current() != values)
|
|
114
|
+
wp_memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
void radix_sort_pairs_device(void* context, int* keys, int* values, int n)
|
|
118
|
+
{
|
|
119
|
+
radix_sort_pairs_device<int>(context, keys, values, n);
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
void radix_sort_pairs_device(void* context, float* keys, int* values, int n)
|
|
123
|
+
{
|
|
124
|
+
radix_sort_pairs_device<float>(context, keys, values, n);
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
void radix_sort_pairs_device(void* context, int64_t* keys, int* values, int n)
|
|
128
|
+
{
|
|
129
|
+
radix_sort_pairs_device<int64_t>(context, keys, values, n);
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
void wp_radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n)
|
|
133
|
+
{
|
|
134
|
+
radix_sort_pairs_device(
|
|
135
|
+
WP_CURRENT_CONTEXT,
|
|
136
|
+
reinterpret_cast<int *>(keys),
|
|
137
|
+
reinterpret_cast<int *>(values), n);
|
|
138
|
+
}
|
|
139
|
+
|
|
140
|
+
void wp_radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n)
|
|
141
|
+
{
|
|
142
|
+
radix_sort_pairs_device(
|
|
143
|
+
WP_CURRENT_CONTEXT,
|
|
144
|
+
reinterpret_cast<float *>(keys),
|
|
145
|
+
reinterpret_cast<int *>(values), n);
|
|
146
|
+
}
|
|
147
|
+
|
|
148
|
+
void wp_radix_sort_pairs_int64_device(uint64_t keys, uint64_t values, int n)
|
|
149
|
+
{
|
|
150
|
+
radix_sort_pairs_device(
|
|
151
|
+
WP_CURRENT_CONTEXT,
|
|
152
|
+
reinterpret_cast<int64_t *>(keys),
|
|
153
|
+
reinterpret_cast<int *>(values), n);
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
void segmented_sort_reserve(void* context, int n, int num_segments, void** mem_out, size_t* size_out)
|
|
157
|
+
{
|
|
158
|
+
ContextGuard guard(context);
|
|
159
|
+
|
|
160
|
+
cub::DoubleBuffer<int> d_keys;
|
|
161
|
+
cub::DoubleBuffer<int> d_values;
|
|
162
|
+
|
|
163
|
+
int* start_indices = NULL;
|
|
164
|
+
int* end_indices = NULL;
|
|
165
|
+
|
|
166
|
+
CUstream stream = static_cast<CUstream>(wp_cuda_stream_get_current());
|
|
167
|
+
|
|
168
|
+
// compute temporary memory required
|
|
169
|
+
size_t sort_temp_size;
|
|
170
|
+
check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
|
|
171
|
+
NULL,
|
|
172
|
+
sort_temp_size,
|
|
173
|
+
d_keys,
|
|
174
|
+
d_values,
|
|
175
|
+
n,
|
|
176
|
+
num_segments,
|
|
177
|
+
start_indices,
|
|
178
|
+
end_indices,
|
|
179
|
+
0,
|
|
180
|
+
32,
|
|
181
|
+
stream));
|
|
182
|
+
|
|
183
|
+
RadixSortTemp& temp = g_radix_sort_temp_map[stream];
|
|
184
|
+
|
|
185
|
+
if (sort_temp_size > temp.size)
|
|
186
|
+
{
|
|
187
|
+
wp_free_device(WP_CURRENT_CONTEXT, temp.mem);
|
|
188
|
+
temp.mem = wp_alloc_device(WP_CURRENT_CONTEXT, sort_temp_size);
|
|
189
|
+
temp.size = sort_temp_size;
|
|
190
|
+
}
|
|
191
|
+
|
|
192
|
+
if (mem_out)
|
|
193
|
+
*mem_out = temp.mem;
|
|
194
|
+
if (size_out)
|
|
195
|
+
*size_out = temp.size;
|
|
196
|
+
}
|
|
197
|
+
|
|
198
|
+
// segment_start_indices and segment_end_indices are arrays of length num_segments, where segment_start_indices[i] is the index of the first element
|
|
199
|
+
// in the i-th segment and segment_end_indices[i] is the index after the last element in the i-th segment
|
|
200
|
+
// https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedRadixSort.html
|
|
201
|
+
void segmented_sort_pairs_device(void* context, float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
|
|
202
|
+
{
|
|
203
|
+
ContextGuard guard(context);
|
|
204
|
+
|
|
205
|
+
cub::DoubleBuffer<float> d_keys(keys, keys + n);
|
|
206
|
+
cub::DoubleBuffer<int> d_values(values, values + n);
|
|
207
|
+
|
|
208
|
+
RadixSortTemp temp;
|
|
209
|
+
segmented_sort_reserve(WP_CURRENT_CONTEXT, n, num_segments, &temp.mem, &temp.size);
|
|
210
|
+
|
|
211
|
+
// sort
|
|
212
|
+
check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
|
|
213
|
+
temp.mem,
|
|
214
|
+
temp.size,
|
|
215
|
+
d_keys,
|
|
216
|
+
d_values,
|
|
217
|
+
n,
|
|
218
|
+
num_segments,
|
|
219
|
+
segment_start_indices,
|
|
220
|
+
segment_end_indices,
|
|
221
|
+
0,
|
|
222
|
+
32,
|
|
223
|
+
(cudaStream_t)wp_cuda_stream_get_current()));
|
|
224
|
+
|
|
225
|
+
if (d_keys.Current() != keys)
|
|
226
|
+
wp_memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(float)*n);
|
|
227
|
+
|
|
228
|
+
if (d_values.Current() != values)
|
|
229
|
+
wp_memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
|
|
230
|
+
}
|
|
231
|
+
|
|
232
|
+
void wp_segmented_sort_pairs_float_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
|
|
233
|
+
{
|
|
234
|
+
segmented_sort_pairs_device(
|
|
235
|
+
WP_CURRENT_CONTEXT,
|
|
236
|
+
reinterpret_cast<float *>(keys),
|
|
237
|
+
reinterpret_cast<int *>(values), n,
|
|
238
|
+
reinterpret_cast<int *>(segment_start_indices),
|
|
239
|
+
reinterpret_cast<int *>(segment_end_indices),
|
|
240
|
+
num_segments);
|
|
241
|
+
}
|
|
242
|
+
|
|
243
|
+
// segment_start_indices and segment_end_indices are arrays of length num_segments, where segment_start_indices[i] is the index of the first element
|
|
244
|
+
// in the i-th segment and segment_end_indices[i] is the index after the last element in the i-th segment
|
|
245
|
+
// https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedRadixSort.html
|
|
246
|
+
void segmented_sort_pairs_device(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
|
|
247
|
+
{
|
|
248
|
+
ContextGuard guard(context);
|
|
249
|
+
|
|
250
|
+
cub::DoubleBuffer<int> d_keys(keys, keys + n);
|
|
251
|
+
cub::DoubleBuffer<int> d_values(values, values + n);
|
|
252
|
+
|
|
253
|
+
RadixSortTemp temp;
|
|
254
|
+
segmented_sort_reserve(WP_CURRENT_CONTEXT, n, num_segments, &temp.mem, &temp.size);
|
|
255
|
+
|
|
256
|
+
// sort
|
|
257
|
+
check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
|
|
258
|
+
temp.mem,
|
|
259
|
+
temp.size,
|
|
260
|
+
d_keys,
|
|
261
|
+
d_values,
|
|
262
|
+
n,
|
|
263
|
+
num_segments,
|
|
264
|
+
segment_start_indices,
|
|
265
|
+
segment_end_indices,
|
|
266
|
+
0,
|
|
267
|
+
32,
|
|
268
|
+
(cudaStream_t)wp_cuda_stream_get_current()));
|
|
269
|
+
|
|
270
|
+
if (d_keys.Current() != keys)
|
|
271
|
+
wp_memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(float)*n);
|
|
272
|
+
|
|
273
|
+
if (d_values.Current() != values)
|
|
274
|
+
wp_memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
|
|
275
|
+
}
|
|
276
|
+
|
|
277
|
+
void wp_segmented_sort_pairs_int_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
|
|
278
|
+
{
|
|
279
|
+
segmented_sort_pairs_device(
|
|
280
|
+
WP_CURRENT_CONTEXT,
|
|
281
|
+
reinterpret_cast<int *>(keys),
|
|
282
|
+
reinterpret_cast<int *>(values), n,
|
|
283
|
+
reinterpret_cast<int *>(segment_start_indices),
|
|
284
|
+
reinterpret_cast<int *>(segment_end_indices),
|
|
285
|
+
num_segments);
|
|
286
|
+
}
|
warp/native/sort.h
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
*/
|
|
17
|
+
|
|
18
|
+
#pragma once
|
|
19
|
+
|
|
20
|
+
#include <stddef.h>
#include <stdint.h>  // int64_t is used by the 64-bit key overloads below

// Scratch-storage management for the radix sort.
// NOTE(review): the exact reserve/release contract is defined in the
// implementation files, which are not part of this header -- confirm there
// before relying on the optional out-parameters.
void radix_sort_reserve(void* context, int n, void** mem_out=NULL, size_t* size_out=NULL);
void radix_sort_release(void* context, void* stream);

// Key/value pair sorts over n entries. Overloads exist for int, float and
// int64_t keys; values are always int. The *_host variants take no context,
// the *_device variants take the context as their first argument.
void radix_sort_pairs_host(int* keys, int* values, int n);
void radix_sort_pairs_host(float* keys, int* values, int n);
void radix_sort_pairs_host(int64_t* keys, int* values, int n);
void radix_sort_pairs_device(void* context, int* keys, int* values, int n);
void radix_sort_pairs_device(void* context, float* keys, int* values, int n);
void radix_sort_pairs_device(void* context, int64_t* keys, int* values, int n);

// Segmented key/value pair sorts: num_segments segments, each described by an
// entry of segment_start_indices and segment_end_indices.
void segmented_sort_pairs_host(float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
void segmented_sort_pairs_device(void* context, float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
// NOTE(review): unlike the float-key host overload above, this host
// declaration takes a context argument -- verify that it matches the
// definition; left unchanged here to keep the declared interface stable.
void segmented_sort_pairs_host(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
void segmented_sort_pairs_device(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
|
warp/native/sparse.cpp
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
*/
|
|
17
|
+
|
|
18
|
+
#include "warp.h"
|
|
19
|
+
|
|
20
|
+
#include <algorithm>
|
|
21
|
+
#include <cstddef>
|
|
22
|
+
#include <numeric>
|
|
23
|
+
#include <vector>
|
|
24
|
+
|
|
25
|
+
namespace
{

// Returns true when every scalar of one block is zero under a bit mask.
//
// T                 unsigned integer type with the same width as one scalar
// block_idx         index of the block inside `values`
// block_size        number of scalars per block
// values            base pointer of the scalar array, reinterpreted as const T*
// scalar_zero_mask  bit mask (truncated to T); a scalar counts as zero when
//                   none of the masked bits are set
//
// An empty block (block_size == 0) is reported as zero.
template <typename T> bool bsr_block_is_zero(int block_idx, int block_size, const void* values, const uint64_t scalar_zero_mask)
{
    // Widen before multiplying: block_idx * block_size evaluated in int can
    // overflow once the value array holds more than INT_MAX scalars.
    const std::ptrdiff_t first_scalar = static_cast<std::ptrdiff_t>(block_idx) * block_size;

    const T* block_values = static_cast<const T*>(values) + first_scalar;
    const T zero_mask = static_cast<T>(scalar_zero_mask);

    return std::all_of(block_values, block_values + block_size, [zero_mask](T v) { return (v & zero_mask) == T(0); });
}

} // namespace
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
// Builds BSR row offsets and column indices from a list of block triplets (host path).
//
// Steps performed:
//   1. take the triplet count from *tpl_nnz when that pointer is non-null, else from nnz;
//   2. discard triplets whose row/column fall outside [0, row_count) x [0, col_count),
//      and, when masked_topology is set, triplets whose (row, column) is not found in
//      the incoming bsr_offsets / bsr_columns;
//   3. when tpl_values is non-null and scalar_zero_mask is non-zero, discard triplets
//      whose block is entirely zero under the mask (scalar widths of 1, 2, 4 and 8
//      bytes are handled; any other width skips this step);
//   4. sort the surviving triplet indices by (row, column);
//   5. emit one bsr_columns entry per distinct (row, column), count blocks per row and
//      turn the counts into offsets with an inclusive running sum.
//
// tpl_block_indices receives the sorted triplet indices and tpl_block_offsets the
// running number of triplets covered after each distinct block. They are returned to
// the caller only when both pointers are supplied; otherwise temporary host buffers
// are allocated and released internally. When bsr_nnz is non-null it receives
// bsr_offsets[row_count]. bsr_nnz_event is not used by this function.
WP_API void wp_bsr_matrix_from_triplets_host(
    int block_size,
    int scalar_size_in_bytes,
    int row_count,
    int col_count,
    int nnz,
    const int* tpl_nnz,
    const int* tpl_rows,
    const int* tpl_columns,
    const void* tpl_values,
    const uint64_t scalar_zero_mask,
    bool masked_topology,
    int* tpl_block_offsets,
    int* tpl_block_indices,
    int* bsr_offsets,
    int* bsr_columns,
    int* bsr_nnz,
    void* bsr_nnz_event)
{
    // an explicit count buffer takes precedence over the by-value count
    nnz = (tpl_nnz != nullptr) ? *tpl_nnz : nnz;

    // fall back to scratch buffers unless the caller supplied both outputs
    const bool return_summed_blocks = (tpl_block_offsets != nullptr) && (tpl_block_indices != nullptr);
    if (!return_summed_blocks)
    {
        const size_t scratch_bytes = size_t(nnz) * sizeof(int);
        tpl_block_offsets = static_cast<int*>(wp_alloc_host(scratch_bytes));
        tpl_block_indices = static_cast<int*>(wp_alloc_host(scratch_bytes));
    }

    int* const indices_begin = tpl_block_indices;
    std::iota(indices_begin, indices_begin + nnz, 0);

    // drop out-of-range triplets and, for masked topology, those absent from the mask
    auto is_rejected = [&](int i) -> bool
    {
        const int row = tpl_rows[i];
        const int col = tpl_columns[i];

        const bool in_bounds = (row >= 0) && (row < row_count) && (col >= 0) && (col < col_count);
        if (!in_bounds)
        {
            return true;
        }
        if (!masked_topology)
        {
            return false;
        }

        // look the column up among the existing entries of this row
        const int* row_first = bsr_columns + bsr_offsets[row];
        const int* row_last = bsr_columns + bsr_offsets[row + 1];
        const int* hit = std::lower_bound(row_first, row_last, col);
        return hit == row_last || *hit != col;
    };

    int* valid_indices_end = std::remove_if(indices_begin, indices_begin + nnz, is_rejected);

    // drop blocks that are entirely zero
    if (tpl_values != nullptr && scalar_zero_mask != 0)
    {
        // select the zero test matching the scalar width
        bool (*block_is_zero)(int, int, const void*, uint64_t) = nullptr;
        switch (scalar_size_in_bytes)
        {
        case sizeof(uint8_t):
            block_is_zero = &bsr_block_is_zero<uint8_t>;
            break;
        case sizeof(uint16_t):
            block_is_zero = &bsr_block_is_zero<uint16_t>;
            break;
        case sizeof(uint32_t):
            block_is_zero = &bsr_block_is_zero<uint32_t>;
            break;
        case sizeof(uint64_t):
            block_is_zero = &bsr_block_is_zero<uint64_t>;
            break;
        }

        if (block_is_zero != nullptr)
        {
            valid_indices_end = std::remove_if(
                indices_begin, valid_indices_end,
                [=](uint32_t i) { return block_is_zero(i, block_size, tpl_values, scalar_zero_mask); });
        }
    }

    // order the remaining triplets by row first, then by column
    std::sort(indices_begin, valid_indices_end, [tpl_rows, tpl_columns](int lhs, int rhs) -> bool
    {
        if (tpl_rows[lhs] != tpl_rows[rhs])
        {
            return tpl_rows[lhs] < tpl_rows[rhs];
        }
        return tpl_columns[lhs] < tpl_columns[rhs];
    });

    // merge triplets sharing a location and count the distinct blocks of each row
    std::fill(bsr_offsets, bsr_offsets + row_count + 1, 0);

    int last_row = -1;
    int last_col = -1;
    int* summed_count = tpl_block_offsets;

    for (const int* cursor = indices_begin; cursor != valid_indices_end; ++cursor)
    {
        const int triplet = *cursor;
        const int row = tpl_rows[triplet];
        const int col = tpl_columns[triplet];

        const bool same_location = (row == last_row) && (col == last_col);
        if (!same_location)
        {
            // first triplet of a new block: record its column and bump the row count
            *bsr_columns = col;
            ++bsr_columns;

            bsr_offsets[row + 1] += 1;

            if (last_row == -1)
            {
                // very first block: start the running count from zero
                *summed_count = 0;
            }
            else
            {
                // carry the running count over to the next slot
                summed_count[1] = summed_count[0];
                ++summed_count;
            }

            last_row = row;
            last_col = col;
        }

        *summed_count += 1;
    }

    // inclusive running sum turns per-row counts into row offsets
    std::partial_sum(bsr_offsets, bsr_offsets + row_count + 1, bsr_offsets);

    if (!return_summed_blocks)
    {
        // scratch buffers were allocated above; release them
        wp_free_host(tpl_block_offsets);
        wp_free_host(tpl_block_indices);
    }

    if (bsr_nnz != nullptr)
    {
        *bsr_nnz = bsr_offsets[row_count];
    }
}
|
|
168
|
+
|
|
169
|
+
// Computes the topology of the transpose of a BSR matrix (host path).
//
// The block count is read from bsr_offsets[row_count]; the nnz argument is overwritten
// with that value. On return:
//   - block_indices holds the source block indices ordered by (column, row);
//   - transposed_bsr_columns[i] is the source row of the i-th block in that order;
//   - transposed_bsr_offsets holds the inclusive running sum of per-column block counts
//     (col_count + 1 entries).
WP_API void wp_bsr_transpose_host(
    int row_count, int col_count, int nnz,
    const int* bsr_offsets, const int* bsr_columns,
    int* transposed_bsr_offsets,
    int* transposed_bsr_columns,
    int* block_indices
)
{
    nnz = bsr_offsets[row_count];

    std::vector<int> bsr_rows(nnz);

    int* const order_begin = block_indices;
    int* const order_end = block_indices + nnz;
    std::iota(order_begin, order_end, 0);

    // expand the row offsets into one explicit row index per block
    for (int row = 0; row < row_count; ++row)
    {
        const int first_block = bsr_offsets[row];
        const int last_block = bsr_offsets[row + 1];
        std::fill(bsr_rows.begin() + first_block, bsr_rows.begin() + last_block, row);
    }

    // order the blocks by column first, then by row
    std::sort(order_begin, order_end, [&bsr_rows, bsr_columns](int lhs, int rhs) -> bool
    {
        if (bsr_columns[lhs] != bsr_columns[rhs])
        {
            return bsr_columns[lhs] < bsr_columns[rhs];
        }
        return bsr_rows[lhs] < bsr_rows[rhs];
    });

    // count the blocks of each column and emit the transposed column (= source row) indices
    std::fill(transposed_bsr_offsets, transposed_bsr_offsets + col_count + 1, 0);

    int* out_column = transposed_bsr_columns;
    for (const int* cursor = order_begin; cursor != order_end; ++cursor, ++out_column)
    {
        const int source_block = *cursor;
        const int source_row = bsr_rows[source_block];
        const int source_col = bsr_columns[source_block];

        transposed_bsr_offsets[source_col + 1] += 1;
        *out_column = source_row;
    }

    // inclusive running sum turns per-column counts into offsets
    std::partial_sum(transposed_bsr_offsets, transposed_bsr_offsets + col_count + 1, transposed_bsr_offsets);
}
|
|
210
|
+
|
|
211
|
+
#if !WP_ENABLE_CUDA
|
|
212
|
+
// Placeholder compiled only when WP_ENABLE_CUDA is off: the body is empty, so every
// argument is ignored and no output buffer is written.
WP_API void wp_bsr_matrix_from_triplets_device(
    int block_size, int scalar_size_in_bytes,
    int row_count, int col_count,
    int tpl_nnz_upper_bound, const int* tpl_nnz,
    const int* tpl_rows, const int* tpl_columns, const void* tpl_values,
    const uint64_t scalar_zero_mask,
    bool masked_topology,
    int* summed_block_offsets, int* summed_block_indices,
    int* bsr_offsets, int* bsr_columns,
    int* bsr_nnz, void* bsr_nnz_event)
{
}
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
// Placeholder compiled only when WP_ENABLE_CUDA is off: the body is empty, so every
// argument is ignored and no output buffer is written.
WP_API void wp_bsr_transpose_device(
    int row_count,
    int col_count,
    int nnz,
    const int* bsr_offsets,
    const int* bsr_columns,
    int* transposed_bsr_offsets,
    int* transposed_bsr_columns,
    int* src_block_indices)
{
}
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
#endif
|