warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +139 -0
- warp/__init__.pyi +1 -0
- warp/autograd.py +1142 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +557 -0
- warp/build_dll.py +405 -0
- warp/builtins.py +6855 -0
- warp/codegen.py +3969 -0
- warp/config.py +158 -0
- warp/constants.py +57 -0
- warp/context.py +6812 -0
- warp/dlpack.py +462 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +194 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +193 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +423 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +178 -0
- warp/examples/fem/example_convection_diffusion_dg.py +204 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +220 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_magnetostatics.py +240 -0
- warp/examples/fem/example_mixed_elasticity.py +291 -0
- warp/examples/fem/example_navier_stokes.py +261 -0
- warp/examples/fem/example_nonconforming_contact.py +298 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +352 -0
- warp/examples/fem/utils.py +1000 -0
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +266 -0
- warp/examples/optim/example_cloth_throw.py +228 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_drone.py +870 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +182 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
- warp/examples/optim/example_softbody_properties.py +400 -0
- warp/examples/optim/example_spring_cage.py +245 -0
- warp/examples/optim/example_trajectory.py +227 -0
- warp/examples/sim/example_cartpole.py +143 -0
- warp/examples/sim/example_cloth.py +225 -0
- warp/examples/sim/example_cloth_self_contact.py +322 -0
- warp/examples/sim/example_granular.py +130 -0
- warp/examples/sim/example_granular_collision_sdf.py +202 -0
- warp/examples/sim/example_jacobian_ik.py +244 -0
- warp/examples/sim/example_particle_chain.py +124 -0
- warp/examples/sim/example_quadruped.py +203 -0
- warp/examples/sim/example_rigid_chain.py +203 -0
- warp/examples/sim/example_rigid_contact.py +195 -0
- warp/examples/sim/example_rigid_force.py +133 -0
- warp/examples/sim/example_rigid_gyroscopic.py +115 -0
- warp/examples/sim/example_rigid_soft_contact.py +140 -0
- warp/examples/sim/example_soft_body.py +196 -0
- warp/examples/tile/example_tile_cholesky.py +87 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mlp.py +383 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/examples/tile/example_tile_walker.py +327 -0
- warp/fabric.py +355 -0
- warp/fem/__init__.py +106 -0
- warp/fem/adaptivity.py +508 -0
- warp/fem/cache.py +572 -0
- warp/fem/dirichlet.py +202 -0
- warp/fem/domain.py +411 -0
- warp/fem/field/__init__.py +125 -0
- warp/fem/field/field.py +619 -0
- warp/fem/field/nodal_field.py +326 -0
- warp/fem/field/restriction.py +37 -0
- warp/fem/field/virtual.py +848 -0
- warp/fem/geometry/__init__.py +32 -0
- warp/fem/geometry/adaptive_nanogrid.py +857 -0
- warp/fem/geometry/closest_point.py +84 -0
- warp/fem/geometry/deformed_geometry.py +221 -0
- warp/fem/geometry/element.py +776 -0
- warp/fem/geometry/geometry.py +362 -0
- warp/fem/geometry/grid_2d.py +392 -0
- warp/fem/geometry/grid_3d.py +452 -0
- warp/fem/geometry/hexmesh.py +911 -0
- warp/fem/geometry/nanogrid.py +571 -0
- warp/fem/geometry/partition.py +389 -0
- warp/fem/geometry/quadmesh.py +663 -0
- warp/fem/geometry/tetmesh.py +855 -0
- warp/fem/geometry/trimesh.py +806 -0
- warp/fem/integrate.py +2335 -0
- warp/fem/linalg.py +419 -0
- warp/fem/operator.py +293 -0
- warp/fem/polynomial.py +229 -0
- warp/fem/quadrature/__init__.py +17 -0
- warp/fem/quadrature/pic_quadrature.py +299 -0
- warp/fem/quadrature/quadrature.py +591 -0
- warp/fem/space/__init__.py +228 -0
- warp/fem/space/basis_function_space.py +468 -0
- warp/fem/space/basis_space.py +667 -0
- warp/fem/space/dof_mapper.py +251 -0
- warp/fem/space/function_space.py +309 -0
- warp/fem/space/grid_2d_function_space.py +177 -0
- warp/fem/space/grid_3d_function_space.py +227 -0
- warp/fem/space/hexmesh_function_space.py +257 -0
- warp/fem/space/nanogrid_function_space.py +201 -0
- warp/fem/space/partition.py +367 -0
- warp/fem/space/quadmesh_function_space.py +223 -0
- warp/fem/space/restriction.py +179 -0
- warp/fem/space/shape/__init__.py +143 -0
- warp/fem/space/shape/cube_shape_function.py +1105 -0
- warp/fem/space/shape/shape_function.py +133 -0
- warp/fem/space/shape/square_shape_function.py +926 -0
- warp/fem/space/shape/tet_shape_function.py +834 -0
- warp/fem/space/shape/triangle_shape_function.py +672 -0
- warp/fem/space/tetmesh_function_space.py +271 -0
- warp/fem/space/topology.py +424 -0
- warp/fem/space/trimesh_function_space.py +194 -0
- warp/fem/types.py +99 -0
- warp/fem/utils.py +420 -0
- warp/jax.py +187 -0
- warp/jax_experimental/__init__.py +16 -0
- warp/jax_experimental/custom_call.py +351 -0
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +244 -0
- warp/native/array.h +1145 -0
- warp/native/builtin.h +1800 -0
- warp/native/bvh.cpp +492 -0
- warp/native/bvh.cu +791 -0
- warp/native/bvh.h +554 -0
- warp/native/clang/clang.cpp +536 -0
- warp/native/coloring.cpp +613 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +362 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +646 -0
- warp/native/cuda_util.h +307 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +1878 -0
- warp/native/fabric.h +245 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +87 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1230 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +339 -0
- warp/native/marching.cpp +19 -0
- warp/native/marching.cu +514 -0
- warp/native/marching.h +19 -0
- warp/native/mat.h +2220 -0
- warp/native/mathdx.cpp +87 -0
- warp/native/matnn.h +343 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +404 -0
- warp/native/mesh.h +1980 -0
- warp/native/nanovdb/GridHandle.h +366 -0
- warp/native/nanovdb/HostBuffer.h +590 -0
- warp/native/nanovdb/NanoVDB.h +6624 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1371 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +139 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +364 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +53 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +277 -0
- warp/native/sort.h +33 -0
- warp/native/sparse.cpp +378 -0
- warp/native/sparse.cu +524 -0
- warp/native/spatial.h +657 -0
- warp/native/svd.h +702 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +2584 -0
- warp/native/tile_reduce.h +264 -0
- warp/native/vec.h +1426 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +67 -0
- warp/native/volume.h +969 -0
- warp/native/volume_builder.cu +477 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1082 -0
- warp/native/warp.cu +3636 -0
- warp/native/warp.h +381 -0
- warp/optim/__init__.py +17 -0
- warp/optim/adam.py +163 -0
- warp/optim/linear.py +1137 -0
- warp/optim/sgd.py +112 -0
- warp/paddle.py +407 -0
- warp/render/__init__.py +18 -0
- warp/render/render_opengl.py +3518 -0
- warp/render/render_usd.py +784 -0
- warp/render/utils.py +160 -0
- warp/sim/__init__.py +65 -0
- warp/sim/articulation.py +793 -0
- warp/sim/collide.py +2395 -0
- warp/sim/graph_coloring.py +300 -0
- warp/sim/import_mjcf.py +790 -0
- warp/sim/import_snu.py +227 -0
- warp/sim/import_urdf.py +579 -0
- warp/sim/import_usd.py +894 -0
- warp/sim/inertia.py +324 -0
- warp/sim/integrator.py +242 -0
- warp/sim/integrator_euler.py +1997 -0
- warp/sim/integrator_featherstone.py +2101 -0
- warp/sim/integrator_vbd.py +2048 -0
- warp/sim/integrator_xpbd.py +3292 -0
- warp/sim/model.py +4791 -0
- warp/sim/particles.py +121 -0
- warp/sim/render.py +427 -0
- warp/sim/utils.py +428 -0
- warp/sparse.py +2057 -0
- warp/stubs.py +3333 -0
- warp/tape.py +1203 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +634 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +200 -0
- warp/tests/geometry/test_hash_grid.py +221 -0
- warp/tests/geometry/test_marching_cubes.py +74 -0
- warp/tests/geometry/test_mesh.py +316 -0
- warp/tests/geometry/test_mesh_query_aabb.py +399 -0
- warp/tests/geometry/test_mesh_query_point.py +932 -0
- warp/tests/geometry/test_mesh_query_ray.py +311 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +729 -0
- warp/tests/interop/test_jax.py +371 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/sim/__init__.py +0 -0
- warp/tests/sim/disabled_kinematics.py +244 -0
- warp/tests/sim/flaky_test_sim_grad.py +290 -0
- warp/tests/sim/test_collision.py +604 -0
- warp/tests/sim/test_coloring.py +258 -0
- warp/tests/sim/test_model.py +224 -0
- warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
- warp/tests/sim/test_sim_kinematics.py +98 -0
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +163 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +2972 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +250 -0
- warp/tests/test_atomic.py +153 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +1298 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +810 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +252 -0
- warp/tests/test_context.py +42 -0
- warp/tests/test_copy.py +238 -0
- warp/tests/test_ctypes.py +638 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +97 -0
- warp/tests/test_examples.py +482 -0
- warp/tests/test_fabricarray.py +996 -0
- warp/tests/test_fast_math.py +74 -0
- warp/tests/test_fem.py +2003 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +454 -0
- warp/tests/test_future_annotations.py +98 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +73 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +193 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_mat.py +2089 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +178 -0
- warp/tests/test_mlp.py +282 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +44 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +339 -0
- warp/tests/test_quat.py +2315 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +302 -0
- warp/tests/test_rounding.py +185 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +105 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +582 -0
- warp/tests/test_spatial.py +2229 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +592 -0
- warp/tests/test_struct.py +734 -0
- warp/tests/test_tape.py +204 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +145 -0
- warp/tests/test_types.py +562 -0
- warp/tests/test_utils.py +588 -0
- warp/tests/test_vec.py +1487 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +780 -0
- warp/tests/tile/test_tile_load.py +407 -0
- warp/tests/tile/test_tile_mathdx.py +208 -0
- warp/tests/tile/test_tile_mlp.py +402 -0
- warp/tests/tile/test_tile_reduce.py +447 -0
- warp/tests/tile/test_tile_shared_memory.py +247 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +427 -0
- warp/tests/unittest_utils.py +468 -0
- warp/tests/walkthrough_debug.py +93 -0
- warp/thirdparty/__init__.py +0 -0
- warp/thirdparty/appdirs.py +598 -0
- warp/thirdparty/dlpack.py +145 -0
- warp/thirdparty/unittest_parallel.py +570 -0
- warp/torch.py +391 -0
- warp/types.py +5230 -0
- warp/utils.py +1137 -0
- warp_lang-1.7.0.dist-info/METADATA +516 -0
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- warp_lang-1.7.0.dist-info/WHEEL +5 -0
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp_lang-1.7.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
*/
|
|
17
|
+
|
|
18
|
+
#pragma once
|
|
19
|
+
|
|
20
|
+
#include "tile.h"
|
|
21
|
+
|
|
22
|
+
// Number of lanes per warp assumed by this header: used as the shuffle width and
// to derive warp/lane indices and the starting offset of the reduction loops.
#define WP_TILE_WARP_SIZE 32
|
|
23
|
+
|
|
24
|
+
namespace wp
|
|
25
|
+
{
|
|
26
|
+
|
|
27
|
+
#if defined(__CUDA_ARCH__)
|
|
28
|
+
|
|
29
|
+
template <typename T>
|
|
30
|
+
inline CUDA_CALLABLE T warp_shuffle_down(T val, int offset, int mask)
|
|
31
|
+
{
|
|
32
|
+
typedef unsigned int Word;
|
|
33
|
+
|
|
34
|
+
union
|
|
35
|
+
{
|
|
36
|
+
T output;
|
|
37
|
+
Word output_storage;
|
|
38
|
+
};
|
|
39
|
+
|
|
40
|
+
union
|
|
41
|
+
{
|
|
42
|
+
T input;
|
|
43
|
+
Word input_storage;
|
|
44
|
+
};
|
|
45
|
+
|
|
46
|
+
input = val;
|
|
47
|
+
|
|
48
|
+
Word* dest = reinterpret_cast<Word*>(&output);
|
|
49
|
+
Word* src = reinterpret_cast<Word*>(&input);
|
|
50
|
+
|
|
51
|
+
unsigned int shuffle_word;
|
|
52
|
+
|
|
53
|
+
constexpr int word_count = (sizeof(T) + sizeof(Word) - 1) / sizeof(Word);
|
|
54
|
+
|
|
55
|
+
WP_PRAGMA_UNROLL
|
|
56
|
+
for (int i=0; i < word_count; ++i)
|
|
57
|
+
{
|
|
58
|
+
shuffle_word = __shfl_down_sync(mask, src[i], offset, WP_TILE_WARP_SIZE);
|
|
59
|
+
dest[i] = shuffle_word;
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
return output;
|
|
63
|
+
}
|
|
64
|
+
|
|
65
|
+
template <typename T, typename Op>
|
|
66
|
+
inline CUDA_CALLABLE T warp_reduce(T val, Op f, unsigned int mask)
|
|
67
|
+
{
|
|
68
|
+
T sum = val;
|
|
69
|
+
|
|
70
|
+
if (mask == 0xFFFFFFFF)
|
|
71
|
+
{
|
|
72
|
+
// handle case where entire warp is active
|
|
73
|
+
for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
|
|
74
|
+
{
|
|
75
|
+
sum = f(sum, warp_shuffle_down(sum, offset, mask));
|
|
76
|
+
}
|
|
77
|
+
}
|
|
78
|
+
else
|
|
79
|
+
{
|
|
80
|
+
// handle partial warp case
|
|
81
|
+
for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
|
|
82
|
+
{
|
|
83
|
+
T shfl_val = warp_shuffle_down(sum, offset, mask);
|
|
84
|
+
if ((mask & (1 << ((threadIdx.x + offset)%WP_TILE_WARP_SIZE))) != 0)
|
|
85
|
+
sum = f(sum, shfl_val);
|
|
86
|
+
}
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
return sum;
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
// non-axis version which computes sum
|
|
93
|
+
// across the entire tile using the whole block
|
|
94
|
+
template <typename Tile, typename Op>
|
|
95
|
+
auto tile_reduce_impl(Op f, Tile& t)
|
|
96
|
+
{
|
|
97
|
+
using T = typename Tile::Type;
|
|
98
|
+
|
|
99
|
+
auto input = t.copy_to_register();
|
|
100
|
+
auto output = tile_register_t<T, tile_layout_register_t<tile_shape_t<1>>>();
|
|
101
|
+
|
|
102
|
+
const int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1)/WP_TILE_WARP_SIZE;
|
|
103
|
+
const int warp_index = threadIdx.x/WP_TILE_WARP_SIZE;
|
|
104
|
+
const int lane_index = threadIdx.x%WP_TILE_WARP_SIZE;
|
|
105
|
+
|
|
106
|
+
T thread_sum = input.data[0];
|
|
107
|
+
|
|
108
|
+
using Layout = typename decltype(input)::Layout;
|
|
109
|
+
|
|
110
|
+
// thread reduction
|
|
111
|
+
WP_PRAGMA_UNROLL
|
|
112
|
+
for (int i=1; i < Layout::NumRegs; ++i)
|
|
113
|
+
{
|
|
114
|
+
int linear = Layout::linear_from_register(i);
|
|
115
|
+
if (!Layout::valid(linear))
|
|
116
|
+
break;
|
|
117
|
+
|
|
118
|
+
thread_sum = f(thread_sum, input.data[i]);
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
// ensure that only threads with at least one valid item participate in the reduction
|
|
122
|
+
unsigned int mask = __ballot_sync(__activemask(), Layout::valid(Layout::linear_from_register(0)));
|
|
123
|
+
|
|
124
|
+
// warp reduction
|
|
125
|
+
T warp_sum = warp_reduce(thread_sum, f, mask);
|
|
126
|
+
|
|
127
|
+
// fixed size scratch pad for partial results in shared memory
|
|
128
|
+
WP_TILE_SHARED T partials[warp_count];
|
|
129
|
+
|
|
130
|
+
// count of active warps
|
|
131
|
+
WP_TILE_SHARED int active_warps;
|
|
132
|
+
if (threadIdx.x == 0)
|
|
133
|
+
active_warps = 0;
|
|
134
|
+
|
|
135
|
+
// ensure active_warps is initialized
|
|
136
|
+
WP_TILE_SYNC();
|
|
137
|
+
|
|
138
|
+
if (lane_index == 0)
|
|
139
|
+
{
|
|
140
|
+
partials[warp_index] = warp_sum;
|
|
141
|
+
atomicAdd(&active_warps, 1);
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
// ensure partials are ready
|
|
145
|
+
WP_TILE_SYNC();
|
|
146
|
+
|
|
147
|
+
// reduce across block, todo: use warp_reduce() here
|
|
148
|
+
if (threadIdx.x == 0)
|
|
149
|
+
{
|
|
150
|
+
T block_sum = partials[0];
|
|
151
|
+
|
|
152
|
+
WP_PRAGMA_UNROLL
|
|
153
|
+
for (int i=1; i < active_warps; ++i)
|
|
154
|
+
block_sum = f(block_sum, partials[i]);
|
|
155
|
+
|
|
156
|
+
output.data[0] = block_sum;
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
return output;
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
#else
|
|
163
|
+
|
|
164
|
+
// CPU implementation
|
|
165
|
+
|
|
166
|
+
template <typename Tile, typename Op>
|
|
167
|
+
auto tile_reduce_impl(Op f, Tile& t)
|
|
168
|
+
{
|
|
169
|
+
using T = typename Tile::Type;
|
|
170
|
+
|
|
171
|
+
auto input = t.copy_to_register();
|
|
172
|
+
auto output = tile_register_t<T, tile_layout_register_t<tile_shape_t<1>>>();
|
|
173
|
+
|
|
174
|
+
using Layout = typename decltype(input)::Layout;
|
|
175
|
+
|
|
176
|
+
T sum = input.data[0];
|
|
177
|
+
|
|
178
|
+
WP_PRAGMA_UNROLL
|
|
179
|
+
for (int i=1; i < Layout::NumRegs; ++i)
|
|
180
|
+
{
|
|
181
|
+
int linear = Layout::linear_from_register(i);
|
|
182
|
+
if (!Layout::valid(linear))
|
|
183
|
+
break;
|
|
184
|
+
|
|
185
|
+
sum = f(sum, input.data[i]);
|
|
186
|
+
}
|
|
187
|
+
|
|
188
|
+
output.data[0] = sum;
|
|
189
|
+
return output;
|
|
190
|
+
}
|
|
191
|
+
|
|
192
|
+
#endif // !defined(__CUDA_ARCH__)
|
|
193
|
+
|
|
194
|
+
// Adjoint of tile_reduce_impl().
// todo: gradients of general purpose reductions are not implemented, so this is
// deliberately a no-op.
inline void adj_tile_reduce_impl() {}
|
|
198
|
+
|
|
199
|
+
// entry point for Python code-gen, wraps op in a lambda to perform overload resolution
|
|
200
|
+
// Expands to a tile_reduce_impl() call in which `op` is wrapped in a generic
// two-argument lambda that forwards to op(x, y).
#define tile_reduce(op, t) tile_reduce_impl([](auto x, auto y) { return op(x, y);}, t)
|
|
201
|
+
// Adjoint counterpart of tile_reduce(): all five arguments are discarded and the
// expansion is a plain call to adj_tile_reduce_impl().
#define adj_tile_reduce(op, a, adj_op, adj_a, adj_ret) adj_tile_reduce_impl()
|
|
202
|
+
|
|
203
|
+
// convenience methods for specific reductions
|
|
204
|
+
|
|
205
|
+
// Reduces tile `t` with add(); returns whatever tile_reduce_impl() returns.
template <typename Tile>
auto tile_sum(Tile& t)
{
    // written out form of tile_reduce(add, t)
    return tile_reduce_impl([](auto x, auto y) { return add(x, y); }, t);
}
|
|
210
|
+
|
|
211
|
+
// special case adjoint for summation
//
// Adds the adjoint of the result, adj_ret.data[0], to the adjoint of the input tile.
//
//   t       - input tile of the forward pass (not read here)
//   adj_t   - adjoint of the input tile, accumulated into
//   adj_ret - adjoint of the reduction result; only data[0] is read
template <typename Tile, typename AdjTile>
void adj_tile_sum(Tile& t, Tile& adj_t, AdjTile& adj_ret)
{
    using T = typename Tile::Type;

#if defined(__CUDA_ARCH__)

    // broadcast incoming adjoint to block: thread 0 stages it, everyone else waits
    WP_TILE_SHARED T adj_scalar;
    if (WP_TILE_THREAD_IDX == 0)
    {
        adj_scalar = adj_ret.data[0];
    }

    WP_TILE_SYNC();

    // broadcast scalar across input dimensions (note zero strides)
    using BroadcastLayout = tile_layout_strided_t<typename Tile::Layout::Shape, tile_stride_t<0, 0>>;
    using BroadcastTile = tile_shared_t<T, BroadcastLayout, false>;

    auto adj_broadcast = BroadcastTile(&adj_scalar, nullptr).copy_to_register();
    adj_t.grad_add(adj_broadcast);

#else

    // host path: add the result's adjoint to each of the Layout::Size entries of adj_t
    for (int i = 0; i < Tile::Layout::Size; ++i)
    {
        adj_t(i) += adj_ret.data[0];
    }

#endif
}
|
|
237
|
+
|
|
238
|
+
// Reduces tile `t` with max(); returns whatever tile_reduce_impl() returns.
template <typename Tile>
auto tile_max(Tile& t)
{
    // written out form of tile_reduce(max, t)
    return tile_reduce_impl([](auto x, auto y) { return max(x, y); }, t);
}
|
|
243
|
+
|
|
244
|
+
// Adjoint of tile_max().
// todo: not implemented - no gradient is propagated and none of the arguments are touched.
template <typename Tile, typename AdjTile>
void adj_tile_max(Tile& t, Tile& adj_t, AdjTile& adj_ret)
{
    // intentionally unused until the adjoint is implemented
    (void)t;
    (void)adj_t;
    (void)adj_ret;
}
|
|
249
|
+
|
|
250
|
+
// Reduces tile `t` with min(); returns whatever tile_reduce_impl() returns.
template <typename Tile>
auto tile_min(Tile& t)
{
    // written out form of tile_reduce(min, t)
    return tile_reduce_impl([](auto x, auto y) { return min(x, y); }, t);
}
|
|
255
|
+
|
|
256
|
+
// Adjoint of tile_min().
// todo: not implemented - no gradient is propagated and none of the arguments are touched.
template <typename Tile, typename AdjTile>
void adj_tile_min(Tile& t, Tile& adj_t, AdjTile& adj_ret)
{
    // intentionally unused until the adjoint is implemented
    (void)t;
    (void)adj_t;
    (void)adj_ret;
}
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
} // namespace wp
|