warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +334 -0
- warp/__init__.pyi +5856 -0
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1077 -0
- warp/_src/build.py +620 -0
- warp/_src/build_dll.py +642 -0
- warp/_src/builtins.py +10555 -0
- warp/_src/codegen.py +4361 -0
- warp/_src/config.py +178 -0
- warp/_src/constants.py +59 -0
- warp/_src/context.py +8352 -0
- warp/_src/dlpack.py +464 -0
- warp/_src/fabric.py +362 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +510 -0
- warp/_src/fem/cache.py +689 -0
- warp/_src/fem/dirichlet.py +190 -0
- warp/_src/fem/domain.py +553 -0
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +703 -0
- warp/_src/fem/field/nodal_field.py +403 -0
- warp/_src/fem/field/restriction.py +39 -0
- warp/_src/fem/field/virtual.py +1021 -0
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
- warp/_src/fem/geometry/closest_point.py +99 -0
- warp/_src/fem/geometry/deformed_geometry.py +277 -0
- warp/_src/fem/geometry/element.py +854 -0
- warp/_src/fem/geometry/geometry.py +693 -0
- warp/_src/fem/geometry/grid_2d.py +478 -0
- warp/_src/fem/geometry/grid_3d.py +539 -0
- warp/_src/fem/geometry/hexmesh.py +956 -0
- warp/_src/fem/geometry/nanogrid.py +660 -0
- warp/_src/fem/geometry/partition.py +483 -0
- warp/_src/fem/geometry/quadmesh.py +597 -0
- warp/_src/fem/geometry/tetmesh.py +762 -0
- warp/_src/fem/geometry/trimesh.py +588 -0
- warp/_src/fem/integrate.py +2507 -0
- warp/_src/fem/linalg.py +385 -0
- warp/_src/fem/operator.py +398 -0
- warp/_src/fem/polynomial.py +231 -0
- warp/_src/fem/quadrature/__init__.py +17 -0
- warp/_src/fem/quadrature/pic_quadrature.py +318 -0
- warp/_src/fem/quadrature/quadrature.py +665 -0
- warp/_src/fem/space/__init__.py +248 -0
- warp/_src/fem/space/basis_function_space.py +499 -0
- warp/_src/fem/space/basis_space.py +681 -0
- warp/_src/fem/space/dof_mapper.py +253 -0
- warp/_src/fem/space/function_space.py +312 -0
- warp/_src/fem/space/grid_2d_function_space.py +179 -0
- warp/_src/fem/space/grid_3d_function_space.py +229 -0
- warp/_src/fem/space/hexmesh_function_space.py +255 -0
- warp/_src/fem/space/nanogrid_function_space.py +199 -0
- warp/_src/fem/space/partition.py +435 -0
- warp/_src/fem/space/quadmesh_function_space.py +222 -0
- warp/_src/fem/space/restriction.py +221 -0
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
- warp/_src/fem/space/shape/shape_function.py +134 -0
- warp/_src/fem/space/shape/square_shape_function.py +928 -0
- warp/_src/fem/space/shape/tet_shape_function.py +829 -0
- warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
- warp/_src/fem/space/tetmesh_function_space.py +270 -0
- warp/_src/fem/space/topology.py +461 -0
- warp/_src/fem/space/trimesh_function_space.py +193 -0
- warp/_src/fem/types.py +114 -0
- warp/_src/fem/utils.py +488 -0
- warp/_src/jax.py +188 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +389 -0
- warp/_src/jax_experimental/ffi.py +1286 -0
- warp/_src/jax_experimental/xla_ffi.py +658 -0
- warp/_src/marching_cubes.py +710 -0
- warp/_src/math.py +416 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +165 -0
- warp/_src/optim/linear.py +1608 -0
- warp/_src/optim/sgd.py +114 -0
- warp/_src/paddle.py +408 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +291 -0
- warp/_src/render/render_opengl.py +3638 -0
- warp/_src/render/render_usd.py +939 -0
- warp/_src/render/utils.py +162 -0
- warp/_src/sparse.py +2718 -0
- warp/_src/tape.py +1208 -0
- warp/_src/thirdparty/__init__.py +0 -0
- warp/_src/thirdparty/appdirs.py +598 -0
- warp/_src/thirdparty/dlpack.py +145 -0
- warp/_src/thirdparty/unittest_parallel.py +676 -0
- warp/_src/torch.py +393 -0
- warp/_src/types.py +5888 -0
- warp/_src/utils.py +1695 -0
- warp/autograd.py +33 -0
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +29 -0
- warp/build_dll.py +24 -0
- warp/codegen.py +24 -0
- warp/constants.py +24 -0
- warp/context.py +33 -0
- warp/dlpack.py +24 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +195 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +290 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/distributed/example_jacobi_mpi.py +506 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +469 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +181 -0
- warp/examples/fem/example_convection_diffusion_dg.py +225 -0
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +225 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +242 -0
- warp/examples/fem/example_mixed_elasticity.py +293 -0
- warp/examples/fem/example_navier_stokes.py +263 -0
- warp/examples/fem/example_nonconforming_contact.py +300 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +357 -0
- warp/examples/fem/utils.py +1047 -0
- warp/examples/interop/example_jax_callable.py +146 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +232 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +88 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/examples/tile/example_tile_mlp.py +385 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/fabric.py +24 -0
- warp/fem/__init__.py +173 -0
- warp/fem/adaptivity.py +26 -0
- warp/fem/cache.py +30 -0
- warp/fem/dirichlet.py +24 -0
- warp/fem/field/__init__.py +24 -0
- warp/fem/field/field.py +26 -0
- warp/fem/geometry/__init__.py +21 -0
- warp/fem/geometry/closest_point.py +31 -0
- warp/fem/linalg.py +38 -0
- warp/fem/operator.py +32 -0
- warp/fem/polynomial.py +29 -0
- warp/fem/space/__init__.py +22 -0
- warp/fem/space/basis_space.py +24 -0
- warp/fem/space/shape/__init__.py +68 -0
- warp/fem/space/topology.py +24 -0
- warp/fem/types.py +24 -0
- warp/fem/utils.py +32 -0
- warp/jax.py +29 -0
- warp/jax_experimental/__init__.py +29 -0
- warp/jax_experimental/custom_call.py +29 -0
- warp/jax_experimental/ffi.py +39 -0
- warp/jax_experimental/xla_ffi.py +24 -0
- warp/marching_cubes.py +24 -0
- warp/math.py +37 -0
- warp/native/array.h +1687 -0
- warp/native/builtin.h +2327 -0
- warp/native/bvh.cpp +562 -0
- warp/native/bvh.cu +826 -0
- warp/native/bvh.h +555 -0
- warp/native/clang/clang.cpp +541 -0
- warp/native/coloring.cpp +622 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +568 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +677 -0
- warp/native/cuda_util.h +313 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +2023 -0
- warp/native/fabric.h +246 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +89 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1253 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +348 -0
- warp/native/mat.h +5189 -0
- warp/native/mathdx.cpp +93 -0
- warp/native/matnn.h +221 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +406 -0
- warp/native/mesh.h +2097 -0
- warp/native/nanovdb/GridHandle.h +533 -0
- warp/native/nanovdb/HostBuffer.h +591 -0
- warp/native/nanovdb/NanoVDB.h +6246 -0
- warp/native/nanovdb/NodeManager.h +323 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1664 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +145 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +363 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +55 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +286 -0
- warp/native/sort.h +35 -0
- warp/native/sparse.cpp +241 -0
- warp/native/sparse.cu +435 -0
- warp/native/spatial.h +1306 -0
- warp/native/svd.h +727 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +4124 -0
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +838 -0
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +2199 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +68 -0
- warp/native/volume.h +970 -0
- warp/native/volume_builder.cu +483 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1143 -0
- warp/native/warp.cu +4604 -0
- warp/native/warp.h +358 -0
- warp/optim/__init__.py +20 -0
- warp/optim/adam.py +24 -0
- warp/optim/linear.py +35 -0
- warp/optim/sgd.py +24 -0
- warp/paddle.py +24 -0
- warp/py.typed +0 -0
- warp/render/__init__.py +22 -0
- warp/render/imgui_manager.py +29 -0
- warp/render/render_opengl.py +24 -0
- warp/render/render_usd.py +24 -0
- warp/render/utils.py +24 -0
- warp/sparse.py +51 -0
- warp/tape.py +24 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_aot.py +7 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_conditional_captures.py +1147 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +691 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +335 -0
- warp/tests/geometry/test_hash_grid.py +259 -0
- warp/tests/geometry/test_marching_cubes.py +294 -0
- warp/tests/geometry/test_mesh.py +318 -0
- warp/tests/geometry/test_mesh_query_aabb.py +392 -0
- warp/tests/geometry/test_mesh_query_point.py +935 -0
- warp/tests/geometry/test_mesh_query_ray.py +323 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +730 -0
- warp/tests/interop/test_jax.py +1673 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/test_adam.py +162 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +3756 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +303 -0
- warp/tests/test_atomic.py +336 -0
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +732 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +974 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +298 -0
- warp/tests/test_context.py +35 -0
- warp/tests/test_copy.py +319 -0
- warp/tests/test_ctypes.py +618 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +127 -0
- warp/tests/test_enum.py +136 -0
- warp/tests/test_examples.py +424 -0
- warp/tests/test_fabricarray.py +998 -0
- warp/tests/test_fast_math.py +72 -0
- warp/tests/test_fem.py +2204 -0
- warp/tests/test_fixedarray.py +229 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +501 -0
- warp/tests/test_future_annotations.py +100 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +103 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +223 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_map.py +526 -0
- warp/tests/test_mat.py +3515 -0
- warp/tests/test_mat_assign_copy.py +178 -0
- warp/tests/test_mat_constructors.py +573 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +212 -0
- warp/tests/test_module_aot.py +287 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +70 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +408 -0
- warp/tests/test_quat.py +2653 -0
- warp/tests/test_quat_assign_copy.py +145 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +303 -0
- warp/tests/test_rounding.py +157 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +133 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +845 -0
- warp/tests/test_spatial.py +2859 -0
- warp/tests/test_spatial_assign_copy.py +160 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +640 -0
- warp/tests/test_struct.py +901 -0
- warp/tests/test_tape.py +242 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +192 -0
- warp/tests/test_tuple.py +361 -0
- warp/tests/test_types.py +615 -0
- warp/tests/test_utils.py +594 -0
- warp/tests/test_vec.py +1408 -0
- warp/tests/test_vec_assign_copy.py +143 -0
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/test_version.py +75 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +1519 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +608 -0
- warp/tests/tile/test_tile_load.py +724 -0
- warp/tests/tile/test_tile_mathdx.py +156 -0
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_mlp.py +400 -0
- warp/tests/tile/test_tile_reduce.py +950 -0
- warp/tests/tile/test_tile_shared_memory.py +376 -0
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +430 -0
- warp/tests/unittest_utils.py +469 -0
- warp/tests/walkthrough_debug.py +95 -0
- warp/torch.py +24 -0
- warp/types.py +51 -0
- warp/utils.py +31 -0
- warp_lang-1.10.0.dist-info/METADATA +459 -0
- warp_lang-1.10.0.dist-info/RECORD +468 -0
- warp_lang-1.10.0.dist-info/WHEEL +5 -0
- warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp_lang-1.10.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,838 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
*/
|
|
17
|
+
|
|
18
|
+
#pragma once
|
|
19
|
+
|
|
20
|
+
#include "tile.h"
|
|
21
|
+
|
|
22
|
+
#ifdef __clang__
|
|
23
|
+
// disable warnings related to C++17 extensions on CPU JIT builds
|
|
24
|
+
#pragma clang diagnostic push
|
|
25
|
+
#pragma clang diagnostic ignored "-Wc++17-extensions"
|
|
26
|
+
#endif // __clang__
|
|
27
|
+
|
|
28
|
+
#define WP_TILE_WARP_SIZE 32
|
|
29
|
+
|
|
30
|
+
namespace wp
|
|
31
|
+
{
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
// Running-argmax helper: given the current champion value/index and a new
// candidate value/index, return the index of whichever value is larger.
// Ties keep the existing champion (strict greater-than comparison).
template <typename T>
int argmax_tracker(T champion_value, T current_value, int champion_index, int current_index)
{
    if (current_value > champion_value)
    {
        return current_index;
    }
    return champion_index;
}
|
|
39
|
+
|
|
40
|
+
// Running-argmin helper: given the current champion value/index and a new
// candidate value/index, return the index of whichever value is smaller.
// Ties keep the existing champion (strict less-than comparison).
template <typename T>
int argmin_tracker(T champion_value, T current_value, int champion_index, int current_index)
{
    if (current_value < champion_value)
    {
        return current_index;
    }
    return champion_index;
}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
#if defined(__CUDA_ARCH__)
|
|
48
|
+
|
|
49
|
+
// Shuffle a value of arbitrary size down the warp by `offset` lanes, moving
// it one 32-bit word at a time through __shfl_down_sync.
// `mask` selects the participating lanes (as produced by __ballot_sync);
// the shuffle width is pinned to WP_TILE_WARP_SIZE.
// NOTE(review): relies on T being safely round-trippable through raw 32-bit
// words (trivially copyable) — TODO confirm for all tile element types.
template <typename T>
inline CUDA_CALLABLE T warp_shuffle_down(T val, int offset, int mask)
{
    typedef unsigned int Word;

    // anonymous unions: members are injected into the enclosing scope, and
    // the union forces `output`/`input` storage to be at least Word-aligned
    // so it can be reinterpreted as an array of 32-bit words below
    union
    {
        T output;
        Word output_storage;
    };

    union
    {
        T input;
        Word input_storage;
    };

    input = val;

    Word* dest = reinterpret_cast<Word*>(&output);
    Word* src = reinterpret_cast<Word*>(&input);

    unsigned int shuffle_word;

    // number of 32-bit words needed to cover T (rounded up)
    constexpr int word_count = (sizeof(T) + sizeof(Word) - 1) / sizeof(Word);

    WP_PRAGMA_UNROLL
    for (int i=0; i < word_count; ++i)
    {
        // shuffle each word of the value from lane (lane + offset)
        shuffle_word = __shfl_down_sync(mask, src[i], offset, WP_TILE_WARP_SIZE);
        dest[i] = shuffle_word;
    }

    return output;
}
|
|
84
|
+
|
|
85
|
+
// vector overload: shuffle each component down the warp independently,
// avoiding the word-wise reinterpretation used by the generic version
template <unsigned Length, typename T>
inline CUDA_CALLABLE wp::vec_t<Length, T> warp_shuffle_down(wp::vec_t<Length, T> val, int offset, int mask)
{
    wp::vec_t<Length, T> result;

    for (unsigned i=0; i < Length; ++i)
        result[i] = __shfl_down_sync(mask, val[i], offset, WP_TILE_WARP_SIZE);

    return result;
}
|
|
96
|
+
|
|
97
|
+
// matrix overload: shuffle each element down the warp independently,
// avoiding the word-wise reinterpretation used by the generic version
template <unsigned Rows, unsigned Cols, typename T>
inline CUDA_CALLABLE wp::mat_t<Rows, Cols, T> warp_shuffle_down(wp::mat_t<Rows, Cols, T> val, int offset, int mask)
{
    wp::mat_t<Rows, Cols, T> result;

    for (unsigned i=0; i < Rows; ++i)
        for (unsigned j=0; j < Cols; ++j)
            result.data[i][j] = __shfl_down_sync(mask, val.data[i][j], offset, WP_TILE_WARP_SIZE);

    return result;
}
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
// Tree-reduce `val` across the warp lanes selected by `mask` using the
// binary op `f` (shuffle-down reduction; log2(warp size) steps).
// `mask` must describe exactly the set of lanes calling this function
// (e.g. the result of __ballot_sync); callers read the result from the
// first participating lane.
template <typename T, typename Op>
inline CUDA_CALLABLE T warp_reduce(T val, Op f, unsigned int mask)
{
    T sum = val;

    if (mask == 0xFFFFFFFF)
    {
        // handle case where entire warp is active
        for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
        {
            sum = f(sum, warp_shuffle_down(sum, offset, mask));
        }
    }
    else
    {
        // handle partial warp case - works for contiguous masks
        for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
        {
            T shfl_val = warp_shuffle_down(sum, offset, mask);
            // only fold in the shuffled value if the source lane
            // (this lane + offset) is itself part of the active mask;
            // otherwise the shuffle result is garbage and must be skipped
            if ((mask & (1 << ((threadIdx.x + offset)%WP_TILE_WARP_SIZE))) != 0)
                sum = f(sum, shfl_val);
        }
    }

    return sum;
}
|
|
137
|
+
|
|
138
|
+
// Pairs a reduced value with an associated index; returned by
// warp_reduce_tracked so index-tracking reductions (see argmax_tracker /
// argmin_tracker) can report both the extreme value and where it came from.
template <typename T>
struct ValueAndIndex
{
    T value;    // reduced value
    int index;  // index tracked alongside `value`
};
|
|
144
|
+
|
|
145
|
+
// Like warp_reduce, but additionally carries an index alongside the value.
// At every reduction step, `track(cur_val, other_val, cur_idx, other_idx)`
// chooses which index survives (e.g. argmax_tracker / argmin_tracker),
// then `f` combines the values. Returns both in a ValueAndIndex; callers
// read the result from the first participating lane.
// `mask` must describe exactly the set of lanes calling this function.
template <typename T, typename Op, typename OpTrack>
inline CUDA_CALLABLE ValueAndIndex<T> warp_reduce_tracked(T val, int idx, Op f, OpTrack track, unsigned int mask)
{
    T sum = val;
    int index = idx;

    if (mask == 0xFFFFFFFF)
    {
        // handle case where entire warp is active
        for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
        {
            auto shfl_val = warp_shuffle_down(sum, offset, mask);
            int shfl_idx = warp_shuffle_down(index, offset, mask);
            // track() must see `sum` BEFORE f() updates it, so the index
            // decision and the value combine use a consistent pair
            index = track(sum, shfl_val, index, shfl_idx);
            sum = f(sum, shfl_val);
        }
    }
    else
    {
        // handle partial warp case
        for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
        {
            T shfl_val = warp_shuffle_down(sum, offset, mask);
            int shfl_index = warp_shuffle_down(index, offset, mask);
            // skip contributions whose source lane is outside the mask
            if ((mask & (1 << ((threadIdx.x + offset)%WP_TILE_WARP_SIZE))) != 0)
            {
                index = track(sum, shfl_val, index, shfl_index);
                sum = f(sum, shfl_val);
            }
        }
    }

    ValueAndIndex<T> result;
    result.value = sum;
    result.index = index;

    return result;
}
|
|
183
|
+
|
|
184
|
+
// combines per-thread reduction results across warps and the entire block
|
|
185
|
+
// assumes each thread has already reduced its local data to thread_sum
|
|
186
|
+
// returns the block-wide reduced value (only valid in thread 0)
|
|
187
|
+
template <typename T, typename Op>
|
|
188
|
+
inline CUDA_CALLABLE T block_combine_thread_results(T thread_sum, bool thread_has_data, Op f,
|
|
189
|
+
T* partials, int& active_warps)
|
|
190
|
+
{
|
|
191
|
+
constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;
|
|
192
|
+
const int warp_index = threadIdx.x / WP_TILE_WARP_SIZE;
|
|
193
|
+
const int lane_index = threadIdx.x % WP_TILE_WARP_SIZE;
|
|
194
|
+
|
|
195
|
+
// determine which threads have data
|
|
196
|
+
unsigned int mask = __ballot_sync(0xFFFFFFFF, thread_has_data);
|
|
197
|
+
bool warp_is_active = mask != 0;
|
|
198
|
+
|
|
199
|
+
// warp reduction
|
|
200
|
+
T warp_sum;
|
|
201
|
+
if (thread_has_data)
|
|
202
|
+
warp_sum = warp_reduce(thread_sum, f, mask);
|
|
203
|
+
|
|
204
|
+
// lane 0 of each active warp writes to shared memory and increments counter
|
|
205
|
+
if (lane_index == 0 && warp_is_active)
|
|
206
|
+
{
|
|
207
|
+
partials[warp_index] = warp_sum;
|
|
208
|
+
atomicAdd(&active_warps, 1);
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
// sync to ensure all warps have written their partials
|
|
212
|
+
WP_TILE_SYNC();
|
|
213
|
+
|
|
214
|
+
// thread 0 performs final reduction across active warps
|
|
215
|
+
T block_sum;
|
|
216
|
+
if (threadIdx.x == 0)
|
|
217
|
+
{
|
|
218
|
+
block_sum = partials[0];
|
|
219
|
+
|
|
220
|
+
for (int w = 1; w < active_warps; ++w)
|
|
221
|
+
{
|
|
222
|
+
block_sum = f(block_sum, partials[w]);
|
|
223
|
+
}
|
|
224
|
+
}
|
|
225
|
+
|
|
226
|
+
return block_sum;
|
|
227
|
+
}
|
|
228
|
+
|
|
229
|
+
// non-axis version which computes sum
// across the entire tile using the whole block
// (i.e. reduces every element of `t` with `f` into a 1-element register
// tile; the result is written by thread 0 only)
template <typename Tile, typename Op>
auto tile_reduce_impl(Op f, Tile& t)
{
    using T = typename Tile::Type;

    auto input = t.copy_to_register();
    auto output = tile_register_t<T, tile_layout_register_t<tile_shape_t<1>>>();

    // one shared-memory partial per warp in the block
    constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;

    using Layout = typename decltype(input)::Layout;

    // step 1: each thread reduces its own registers locally
    T thread_sum = input.data[0];
    // register 0 may itself be out of range when the tile is smaller than
    // the block; track whether this thread holds any valid element
    bool thread_has_data = Layout::valid(Layout::linear_from_register(0));

    WP_PRAGMA_UNROLL
    for (int i=1; i < Layout::NumRegs; ++i)
    {
        int linear = Layout::linear_from_register(i);
        // registers are assigned in linear order, so the first invalid
        // register means all remaining ones are invalid too
        if (!Layout::valid(linear))
            break;

        thread_sum = f(thread_sum, input.data[i]);
    }

    // shared memory for cross-warp reduction
    __shared__ T partials[warp_count];
    __shared__ int active_warps;

    if (threadIdx.x == 0)
        active_warps = 0;

    // make the zeroed counter visible to all warps before they increment it
    WP_TILE_SYNC();

    // step 2-3: combine thread results across warps and block
    T block_sum = block_combine_thread_results(thread_sum, thread_has_data, f, partials, active_warps);

    // only thread 0 holds a valid block_sum; it owns the scalar output
    if (threadIdx.x == 0)
        output.data[0] = block_sum;

    return output;
}
|
|
274
|
+
|
|
275
|
+
// Reduce tile `t` along compile-time dimension `Axis` with binary op `f`,
// returning a register tile whose shape is the input shape with `Axis` removed.
// The reduction strategy is chosen at compile time from the length of the
// reduced dimension:
//   <= 32  : Tier 1, one thread per output slice
//   <= 256 : Tier 2, one warp per output slice
//   else   : Tier 3, the whole block cooperates on each slice in turn
// All tiers stage results in a shared-memory buffer, then copy to registers.
template <int Axis, typename Op, typename Tile>
auto tile_reduce_axis_impl(Op f, Tile& t)
{
    using T = typename Tile::Type;
    using InputShape = typename Tile::Layout::Shape;
    using OutputShape = typename tile_shape_remove_dim<Axis, InputShape>::type;

    constexpr int reduce_dim_size = InputShape::dim(Axis);
    constexpr int output_size = OutputShape::size();

    // special case: 1D input delegates to block-wide tile_reduce_impl for optimal performance
    if constexpr (InputShape::N == 1)
    {
        return tile_reduce_impl(f, t);
    }

    // shared memory buffer for the output (used by all tiers)
    __shared__ T output_buffer[output_size];

    // create output layout for coordinate conversion (used by all tiers)
    using OutputLayout = tile_layout_strided_t<OutputShape>;

    if constexpr (reduce_dim_size <= 32)
    {
        // Tier 1: Single thread per output element (optimal for small reductions)

        // each thread processes output elements, performing reduction along the axis
        for (int out_idx = WP_TILE_THREAD_IDX; out_idx < output_size; out_idx += WP_TILE_BLOCK_DIM)
        {
            // convert output linear index to output coordinates
            auto out_coord = OutputLayout::coord_from_linear(out_idx);

            // initialize accumulator with first element along the reduction axis
            T accumulator = t.data(tile_coord_insert_axis<Axis>(out_coord, 0));

            // reduce across the axis
            for (int i = 1; i < reduce_dim_size; ++i)
            {
                accumulator = f(accumulator, t.data(tile_coord_insert_axis<Axis>(out_coord, i)));
            }

            // store to output buffer
            output_buffer[out_idx] = accumulator;
        }

        // sync before reading output
        WP_TILE_SYNC();
    }
    else if constexpr (reduce_dim_size <= 256)
    {
        // Tier 2: Warp-based reduction (one warp per output element)
        constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;
        const int warp_index = threadIdx.x / WP_TILE_WARP_SIZE;
        const int lane_index = threadIdx.x % WP_TILE_WARP_SIZE;

        // number of warp-sized chunks needed to cover the reduced dimension
        constexpr int chunks_per_slice = (reduce_dim_size + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;

        // shared memory: one accumulator per warp
        __shared__ T warp_partials[warp_count];

        // each warp processes output slices
        for (int out_idx = warp_index; out_idx < output_size; out_idx += warp_count)
        {
            auto out_coord = OutputLayout::coord_from_linear(out_idx);

            // process the reduction axis in chunks of the warp size
            for (int chunk = 0; chunk < chunks_per_slice; ++chunk)
            {
                int axis_idx = chunk * WP_TILE_WARP_SIZE + lane_index;
                bool valid = axis_idx < reduce_dim_size;

                // left uninitialized for invalid lanes; excluded via the ballot mask below
                T val;
                if (valid)
                {
                    auto in_coord = tile_coord_insert_axis<Axis>(out_coord, axis_idx);
                    val = t.data(in_coord);
                }

                // warp reduce this chunk (only valid lanes participate)
                unsigned int mask = __ballot_sync(0xFFFFFFFF, valid);
                T chunk_result = warp_reduce(val, f, mask);

                // lane 0 accumulates the chunk result; no sync needed since only
                // this warp's lane 0 ever touches warp_partials[warp_index]
                if (lane_index == 0)
                {
                    if (chunk == 0)
                        warp_partials[warp_index] = chunk_result;
                    else
                        warp_partials[warp_index] = f(warp_partials[warp_index], chunk_result);
                }
            }

            // lane 0 writes final result for this output element
            if (lane_index == 0)
                output_buffer[out_idx] = warp_partials[warp_index];
        }

        // sync before reading output
        WP_TILE_SYNC();
    }
    else
    {
        // Tier 3: Block-level reduction (entire block collaborates on each output element)
        constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;

        // shared memory for cross-warp reduction
        __shared__ T partials[warp_count];
        __shared__ int active_warps;

        // process each output element sequentially with full block cooperation
        for (int out_idx = 0; out_idx < output_size; ++out_idx)
        {
            auto out_coord = OutputLayout::coord_from_linear(out_idx);

            // step 1: each thread reduces its strided subset of the slice locally
            bool thread_has_data = threadIdx.x < reduce_dim_size;

            // NOTE(review): thread_sum stays uninitialized when thread_has_data is
            // false — presumably block_combine_thread_results excludes inactive
            // threads via the flag; confirm against its definition.
            T thread_sum;

            if (thread_has_data)
            {
                // initialize with first element
                auto in_coord = tile_coord_insert_axis<Axis>(out_coord, threadIdx.x);
                thread_sum = t.data(in_coord);

                // reduce remaining elements with stride
                for (int i = threadIdx.x + WP_TILE_BLOCK_DIM; i < reduce_dim_size; i += WP_TILE_BLOCK_DIM)
                {
                    auto in_coord = tile_coord_insert_axis<Axis>(out_coord, i);
                    T val = t.data(in_coord);
                    thread_sum = f(thread_sum, val);
                }
            }

            // initialize active warp counter
            if (threadIdx.x == 0)
                active_warps = 0;

            WP_TILE_SYNC();

            // step 2-3: combine thread results across warps and block
            T block_sum = block_combine_thread_results(thread_sum, thread_has_data, f, partials, active_warps);

            if (threadIdx.x == 0)
                output_buffer[out_idx] = block_sum;

            // sync before next output element
            WP_TILE_SYNC();
        }
    }

    // copy from shared memory buffer to register tile (common to all tiers)
    auto output = tile_register_t<T, tile_layout_register_t<OutputShape>>();
    using OutputRegLayout = typename decltype(output)::Layout;

    WP_PRAGMA_UNROLL
    for (int i = 0; i < OutputRegLayout::NumRegs; ++i)
    {
        int linear = OutputRegLayout::linear_from_register(i);
        if (!OutputRegLayout::valid(linear))
            break;

        output.data[i] = output_buffer[linear];
    }

    return output;
}
|
|
441
|
+
|
|
442
|
+
// Non-axis arg-reduction across the entire tile using the whole block.
// `f` combines two values (e.g. max), `track` returns the linear index of
// whichever of its two operands wins. Returns a 1-element int register tile
// holding (on thread 0) the linear index of the winning element.
template <typename Tile, typename Op, typename OpTrack>
auto tile_arg_reduce_impl(Op f, OpTrack track, Tile& t)
{
    using T = typename Tile::Type;

    auto input = t.copy_to_register();
    auto output = tile_register_t<int, tile_layout_register_t<tile_shape_t<1>>>();

    const int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1)/WP_TILE_WARP_SIZE;
    const int warp_index = threadIdx.x/WP_TILE_WARP_SIZE;
    const int lane_index = threadIdx.x%WP_TILE_WARP_SIZE;

    using Layout = typename decltype(input)::Layout;

    // start from this thread's first register; -1 if it owns no elements
    int champion_index = Layout::NumRegs > 0 ? Layout::linear_from_register(0) : -1;
    T thread_sum = input.data[0];

    // thread reduction
    WP_PRAGMA_UNROLL
    for (int i=1; i < Layout::NumRegs; ++i)
    {
        int linear = Layout::linear_from_register(i);
        if (!Layout::valid(linear))
            break;

        // update the champion first, then fold the value into the running result
        champion_index = track(thread_sum, input.data[i], champion_index, linear);
        thread_sum = f(thread_sum, input.data[i]);
    }

    // ensure that only threads with at least one valid item participate in the reduction
    unsigned int mask = __ballot_sync(__activemask(), Layout::valid(Layout::linear_from_register(0)));
    bool warp_is_active = mask != 0;

    // warp reduction
    ValueAndIndex<T> warp_sum = warp_reduce_tracked(thread_sum, champion_index, f, track, mask);

    // fixed size scratch pad for partial results in shared memory
    __shared__ T partials[warp_count];
    __shared__ int partials_idx[warp_count];

    // count of active warps
    __shared__ int active_warps;
    if (threadIdx.x == 0)
        active_warps = 0;

    // ensure active_warps is initialized
    WP_TILE_SYNC();

    if (lane_index == 0 && warp_is_active)
    {
        partials[warp_index] = warp_sum.value;
        partials_idx[warp_index] = warp_sum.index;
        atomicAdd(&active_warps, 1);
    }

    // ensure partials are ready
    WP_TILE_SYNC();

    // reduce across block, todo: use warp_reduce() here
    if (threadIdx.x == 0)
    {
        T block_sum = partials[0];
        int block_champion_index = partials_idx[0];

        // NOTE(review): the loop bound is runtime-valued, so the unroll pragma
        // is likely a no-op here.
        WP_PRAGMA_UNROLL
        for (int i=1; i < active_warps; ++i)
        {
            block_champion_index = track(block_sum, partials[i], block_champion_index, partials_idx[i]);
            block_sum = f(block_sum, partials[i]);
        }

        output.data[0] = block_champion_index;
    }

    return output;
}
|
|
520
|
+
|
|
521
|
+
#else
|
|
522
|
+
|
|
523
|
+
// CPU implementation
|
|
524
|
+
|
|
525
|
+
// CPU implementation: reduce every element of the tile with binary op `f`,
// returning a 1-element register tile holding the result. Iterates the
// register layout in order, stopping at the first invalid linear index.
template <typename Tile, typename Op>
auto tile_reduce_impl(Op f, Tile& t)
{
    using T = typename Tile::Type;

    auto input = t.copy_to_register();
    auto output = tile_register_t<T, tile_layout_register_t<tile_shape_t<1>>>();

    using Layout = typename decltype(input)::Layout;

    // seed with the first element; remaining registers are folded in below
    T sum = input.data[0];

    WP_PRAGMA_UNROLL
    for (int i=1; i < Layout::NumRegs; ++i)
    {
        int linear = Layout::linear_from_register(i);
        if (!Layout::valid(linear))
            break;

        sum = f(sum, input.data[i]);
    }

    output.data[0] = sum;
    return output;
}
|
|
550
|
+
|
|
551
|
+
// CPU implementation: reduce tile `t` along compile-time dimension `Axis`
// with binary op `f`, returning a register tile of the remaining dimensions.
// Single-threaded, so each output element is computed with a plain loop —
// no thread coordination needed.
template <int Axis, typename Op, typename Tile>
auto tile_reduce_axis_impl(Op f, Tile& t)
{
    using T = typename Tile::Type;
    using InputShape = typename Tile::Layout::Shape;
    using OutputShape = typename tile_shape_remove_dim<Axis, InputShape>::type;

    constexpr int reduce_dim_size = InputShape::dim(Axis);

    // CPU version - work directly with register tiles, no thread coordination needed
    auto input = t.copy_to_register();
    auto output = tile_register_t<T, tile_layout_register_t<OutputShape>>();
    using OutputLayout = typename decltype(output)::Layout;

    // iterate through each output element and reduce along the axis
    constexpr int output_size = OutputShape::size();
    for (int out_idx = 0; out_idx < output_size; ++out_idx)
    {
        T accumulator;

        // special case for 1D input (reduces to single value)
        if constexpr (InputShape::N == 1)
        {
            accumulator = input.data[0];
            for (int i = 1; i < reduce_dim_size; ++i)
            {
                // input is in registers, linear access
                accumulator = f(accumulator, input.data[i]);
            }
        }
        else
        {
            // multi-dimensional case
            auto out_coord = OutputLayout::coord_from_linear(out_idx);

            // get input coordinates by inserting axis values
            auto coord_0 = tile_coord_insert_axis<Axis>(out_coord, 0);
            int input_linear_0 = tile_layout_register_t<InputShape>::linear_from_coord(coord_0);
            int input_reg_0 = tile_layout_register_t<InputShape>::register_from_linear(input_linear_0);
            accumulator = input.data[input_reg_0];

            // reduce across the axis
            for (int i = 1; i < reduce_dim_size; ++i)
            {
                auto coord_i = tile_coord_insert_axis<Axis>(out_coord, i);
                int input_linear_i = tile_layout_register_t<InputShape>::linear_from_coord(coord_i);
                int input_reg_i = tile_layout_register_t<InputShape>::register_from_linear(input_linear_i);
                accumulator = f(accumulator, input.data[input_reg_i]);
            }
        }

        // store to output register
        int output_reg = OutputLayout::register_from_linear(out_idx);
        output.data[output_reg] = accumulator;
    }

    return output;
}
|
|
609
|
+
|
|
610
|
+
template <typename Tile, typename Op, typename OpTrack>
|
|
611
|
+
auto tile_arg_reduce_impl(Op f, OpTrack track, Tile& t)
|
|
612
|
+
{
|
|
613
|
+
using T = typename Tile::Type;
|
|
614
|
+
|
|
615
|
+
auto input = t.copy_to_register();
|
|
616
|
+
auto output = tile_register_t<int, tile_layout_register_t<tile_shape_t<1>>>();
|
|
617
|
+
|
|
618
|
+
using Layout = typename decltype(input)::Layout;
|
|
619
|
+
|
|
620
|
+
int champion_index = Layout::NumRegs > 0 ? Layout::linear_from_register(0) : -1;
|
|
621
|
+
T sum = input.data[0];
|
|
622
|
+
|
|
623
|
+
WP_PRAGMA_UNROLL
|
|
624
|
+
for (int i=1; i < Layout::NumRegs; ++i)
|
|
625
|
+
{
|
|
626
|
+
int linear = Layout::linear_from_register(i);
|
|
627
|
+
if (!Layout::valid(linear))
|
|
628
|
+
break;
|
|
629
|
+
|
|
630
|
+
champion_index = track(sum, input.data[i], champion_index, linear);
|
|
631
|
+
sum = f(sum, input.data[i]);
|
|
632
|
+
}
|
|
633
|
+
|
|
634
|
+
output.data[0] = champion_index;
|
|
635
|
+
return output;
|
|
636
|
+
}
|
|
637
|
+
|
|
638
|
+
#endif // !defined(__CUDA_ARCH__)
|
|
639
|
+
|
|
640
|
+
// Adjoint for tile_reduce(); intentionally a no-op until general-purpose
// reduction gradients are implemented.
inline void adj_tile_reduce_impl()
{
    // todo: general purpose reduction gradients not implemented
}
|
|
644
|
+
|
|
645
|
+
// Adjoint for tile_reduce_axis(); intentionally a no-op until axis-specific
// reduction gradients are implemented.
inline void adj_tile_reduce_axis_impl()
{
    // todo: axis-specific reduction gradients not implemented
}
|
|
649
|
+
|
|
650
|
+
// entry point for Python code-gen, wraps op in a lambda to perform overload resolution
#define tile_reduce(op, t) tile_reduce_impl([](auto x, auto y) { return op(x, y);}, t)
#define adj_tile_reduce(op, t, adj_op, adj_t, adj_ret) adj_tile_reduce_impl()

// arg-reduction entry point: op selects the winning value, opTrack tracks its linear index
#define tile_arg_reduce(op, opTrack, t) tile_arg_reduce_impl([](auto x, auto y) { return op(x, y);}, [](auto a, auto b, auto c, auto d) { return opTrack(a, b, c, d); }, t)
#define adj_tile_arg_reduce(op, t, adj_op, adj_t, adj_ret) adj_tile_arg_reduce_impl()

// axis-specific reduction entry points
#define tile_reduce_axis(op, t, axis) tile_reduce_axis_impl<axis>([](auto x, auto y) { return op(x, y);}, t)
#define adj_tile_reduce_axis(op, t, axis, adj_op, adj_t, adj_axis, adj_ret) adj_tile_reduce_axis_impl()
|
|
660
|
+
|
|
661
|
+
// convenience methods for specific reductions
|
|
662
|
+
|
|
663
|
+
// whole-tile sum
|
|
664
|
+
// Whole-tile sum: folds every element of `t` with add(), returning a
// 1-element tile holding the total. Equivalent to tile_reduce(add, t),
// written here as a direct call into the reduction implementation.
template <typename Tile>
auto tile_sum(Tile& t)
{
    return tile_reduce_impl([](auto lhs, auto rhs) { return add(lhs, rhs); }, t);
}
|
|
669
|
+
|
|
670
|
+
// special case adjoint for summation
|
|
671
|
+
// Special-case adjoint for the whole-tile sum: the incoming scalar adjoint
// is broadcast unchanged to every element of the input tile's gradient
// (d(sum)/d(x_i) == 1 for all i).
template <typename Tile, typename AdjTile>
void adj_tile_sum(Tile& t, Tile& adj_t, AdjTile& adj_ret)
{
    using T = typename Tile::Type;

    auto adj_reg = adj_ret.grad_to_register();

#if !defined(__CUDA_ARCH__)
    // CPU: the adjoint scalar is directly available in this thread's register
    T scratch = adj_reg.data[0];
#else
    // broadcast incoming adjoint to block
    __shared__ T scratch;
    if (WP_TILE_THREAD_IDX == 0)
        scratch = adj_reg.data[0];

    WP_TILE_SYNC();
#endif

    // accumulate the broadcast adjoint into a register tile shaped like the input
    auto adj_ret_reg = tile_register_like<Tile>();
    using Layout = typename decltype(adj_ret_reg)::Layout;
    for (int i=0; i < Layout::NumRegs; ++i)
    {
        adj_ret_reg.data[i] += scratch;
    }
    adj_t.grad_add(adj_ret_reg);
}
|
|
697
|
+
|
|
698
|
+
// axis-specific sum
|
|
699
|
+
template <int Axis, typename Tile>
|
|
700
|
+
auto tile_sum(Tile& t)
|
|
701
|
+
{
|
|
702
|
+
return tile_reduce_axis_impl<Axis>([](auto x, auto y) { return add(x, y); }, t);
|
|
703
|
+
}
|
|
704
|
+
|
|
705
|
+
// special case adjoint for axis-specific summation
|
|
706
|
+
// Special-case adjoint for the axis-specific sum: the adjoint of the reduced
// output is broadcast back along `Axis` (stride 0 on that dimension) and
// accumulated into the input tile's gradient. Each rank from 1 to 4 is
// handled explicitly with the appropriate contiguous strides.
// NOTE(review): ranks > 4 fall through all branches and accumulate nothing —
// presumably higher ranks are rejected upstream; confirm.
template<int Axis, typename Tile, typename AdjTile>
void adj_tile_sum(Tile& t, Tile& adj_t, AdjTile& adj_ret)
{
    using InputShape = typename Tile::Layout::Shape;

    if constexpr (InputShape::N == 1)
    {
        // 1D -> scalar case: broadcast scalar to 1D
        auto broadcasted = tile_broadcast<InputShape::dim(0), 0>(adj_ret);
        tile_add_inplace(adj_t, broadcasted);
    }
    else if constexpr (InputShape::N == 2)
    {
        if constexpr (Axis == 0)
        {
            // broadcast from (D1,) to (D0, D1) with strides (0, 1)
            auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), 0, 1>(adj_ret);
            tile_add_inplace(adj_t, broadcasted);
        }
        else // Axis == 1
        {
            // broadcast from (D0,) to (D0, D1) with strides (1, 0)
            auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), 1, 0>(adj_ret);
            tile_add_inplace(adj_t, broadcasted);
        }
    }
    else if constexpr (InputShape::N == 3)
    {
        if constexpr (Axis == 0)
        {
            // broadcast from (D1, D2) to (D0, D1, D2) with strides (0, D2, 1)
            auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), 0, InputShape::dim(2), 1>(adj_ret);
            tile_add_inplace(adj_t, broadcasted);
        }
        else if constexpr (Axis == 1)
        {
            // broadcast from (D0, D2) to (D0, D1, D2) with strides (D2, 0, 1)
            auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(2), 0, 1>(adj_ret);
            tile_add_inplace(adj_t, broadcasted);
        }
        else // Axis == 2
        {
            // broadcast from (D0, D1) to (D0, D1, D2) with strides (D1, 1, 0)
            auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(1), 1, 0>(adj_ret);
            tile_add_inplace(adj_t, broadcasted);
        }
    }
    else if constexpr (InputShape::N == 4)
    {
        if constexpr (Axis == 0)
        {
            // broadcast from (D1, D2, D3) to (D0, D1, D2, D3) with strides (0, D2*D3, D3, 1)
            auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(3), 0, InputShape::dim(2)*InputShape::dim(3), InputShape::dim(3), 1>(adj_ret);
            tile_add_inplace(adj_t, broadcasted);
        }
        else if constexpr (Axis == 1)
        {
            // broadcast from (D0, D2, D3) to (D0, D1, D2, D3) with strides (D2*D3, 0, D3, 1)
            auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(3), InputShape::dim(2)*InputShape::dim(3), 0, InputShape::dim(3), 1>(adj_ret);
            tile_add_inplace(adj_t, broadcasted);
        }
        else if constexpr (Axis == 2)
        {
            // broadcast from (D0, D1, D3) to (D0, D1, D2, D3) with strides (D1*D3, D3, 0, 1)
            auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(3), InputShape::dim(1)*InputShape::dim(3), InputShape::dim(3), 0, 1>(adj_ret);
            tile_add_inplace(adj_t, broadcasted);
        }
        else // Axis == 3
        {
            // broadcast from (D0, D1, D2) to (D0, D1, D2, D3) with strides (D1*D2, D2, 1, 0)
            auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(3), InputShape::dim(1)*InputShape::dim(2), InputShape::dim(2), 1, 0>(adj_ret);
            tile_add_inplace(adj_t, broadcasted);
        }
    }
}
|
|
781
|
+
|
|
782
|
+
// Whole-tile maximum: folds every element of `t` with max(), returning a
// 1-element tile. Equivalent to tile_reduce(max, t).
template <typename Tile>
auto tile_max(Tile& t)
{
    return tile_reduce_impl([](auto lhs, auto rhs) { return max(lhs, rhs); }, t);
}
|
|
787
|
+
|
|
788
|
+
// Adjoint for tile_max(); intentionally a no-op until max-reduction
// gradients are implemented.
template <typename Tile, typename AdjTile>
void adj_tile_max(Tile& t, Tile& adj_t, AdjTile& adj_ret)
{
    // todo: not implemented
}
|
|
793
|
+
|
|
794
|
+
// Whole-tile minimum: folds every element of `t` with min(), returning a
// 1-element tile. Equivalent to tile_reduce(min, t).
template <typename Tile>
auto tile_min(Tile& t)
{
    return tile_reduce_impl([](auto lhs, auto rhs) { return min(lhs, rhs); }, t);
}
|
|
799
|
+
|
|
800
|
+
// Adjoint for tile_min(); intentionally a no-op until min-reduction
// gradients are implemented.
template <typename Tile, typename AdjTile>
void adj_tile_min(Tile& t, Tile& adj_t, AdjTile& adj_ret)
{
    // todo: not implemented
}
|
|
805
|
+
|
|
806
|
+
|
|
807
|
+
|
|
808
|
+
// Index of the maximum element across the whole tile, as a 1-element int
// tile. Equivalent to tile_arg_reduce(max, argmax_tracker, t).
template <typename Tile>
auto tile_argmax(Tile& t)
{
    auto value_op = [](auto lhs, auto rhs) { return max(lhs, rhs); };
    auto index_op = [](auto cur, auto cand, auto cur_idx, auto cand_idx) { return argmax_tracker(cur, cand, cur_idx, cand_idx); };
    return tile_arg_reduce_impl(value_op, index_op, t);
}
|
|
813
|
+
|
|
814
|
+
// Adjoint for tile_argmax(); intentionally a no-op — the returned index is
// integer-valued, and gradients are not implemented.
template <typename Tile, typename AdjTile>
void adj_tile_argmax(Tile& t, Tile& adj_t, AdjTile& adj_ret)
{
    // todo: not implemented
}
|
|
819
|
+
|
|
820
|
+
// Index of the minimum element across the whole tile, as a 1-element int
// tile. Equivalent to tile_arg_reduce(min, argmin_tracker, t).
template <typename Tile>
auto tile_argmin(Tile& t)
{
    auto value_op = [](auto lhs, auto rhs) { return min(lhs, rhs); };
    auto index_op = [](auto cur, auto cand, auto cur_idx, auto cand_idx) { return argmin_tracker(cur, cand, cur_idx, cand_idx); };
    return tile_arg_reduce_impl(value_op, index_op, t);
}
|
|
825
|
+
|
|
826
|
+
// Adjoint for tile_argmin(); intentionally a no-op — the returned index is
// integer-valued, and gradients are not implemented.
template <typename Tile, typename AdjTile>
void adj_tile_argmin(Tile& t, Tile& adj_t, AdjTile& adj_ret)
{
    // todo: not implemented
}
|
|
831
|
+
|
|
832
|
+
|
|
833
|
+
} // namespace wp
|
|
834
|
+
|
|
835
|
+
|
|
836
|
+
#ifdef __clang__
|
|
837
|
+
#pragma clang diagnostic pop
|
|
838
|
+
#endif
|