warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +139 -0
- warp/__init__.pyi +1 -0
- warp/autograd.py +1142 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +557 -0
- warp/build_dll.py +405 -0
- warp/builtins.py +6855 -0
- warp/codegen.py +3969 -0
- warp/config.py +158 -0
- warp/constants.py +57 -0
- warp/context.py +6812 -0
- warp/dlpack.py +462 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +194 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +193 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +423 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +178 -0
- warp/examples/fem/example_convection_diffusion_dg.py +204 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +220 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_magnetostatics.py +240 -0
- warp/examples/fem/example_mixed_elasticity.py +291 -0
- warp/examples/fem/example_navier_stokes.py +261 -0
- warp/examples/fem/example_nonconforming_contact.py +298 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +352 -0
- warp/examples/fem/utils.py +1000 -0
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +266 -0
- warp/examples/optim/example_cloth_throw.py +228 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_drone.py +870 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +182 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
- warp/examples/optim/example_softbody_properties.py +400 -0
- warp/examples/optim/example_spring_cage.py +245 -0
- warp/examples/optim/example_trajectory.py +227 -0
- warp/examples/sim/example_cartpole.py +143 -0
- warp/examples/sim/example_cloth.py +225 -0
- warp/examples/sim/example_cloth_self_contact.py +322 -0
- warp/examples/sim/example_granular.py +130 -0
- warp/examples/sim/example_granular_collision_sdf.py +202 -0
- warp/examples/sim/example_jacobian_ik.py +244 -0
- warp/examples/sim/example_particle_chain.py +124 -0
- warp/examples/sim/example_quadruped.py +203 -0
- warp/examples/sim/example_rigid_chain.py +203 -0
- warp/examples/sim/example_rigid_contact.py +195 -0
- warp/examples/sim/example_rigid_force.py +133 -0
- warp/examples/sim/example_rigid_gyroscopic.py +115 -0
- warp/examples/sim/example_rigid_soft_contact.py +140 -0
- warp/examples/sim/example_soft_body.py +196 -0
- warp/examples/tile/example_tile_cholesky.py +87 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mlp.py +383 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/examples/tile/example_tile_walker.py +327 -0
- warp/fabric.py +355 -0
- warp/fem/__init__.py +106 -0
- warp/fem/adaptivity.py +508 -0
- warp/fem/cache.py +572 -0
- warp/fem/dirichlet.py +202 -0
- warp/fem/domain.py +411 -0
- warp/fem/field/__init__.py +125 -0
- warp/fem/field/field.py +619 -0
- warp/fem/field/nodal_field.py +326 -0
- warp/fem/field/restriction.py +37 -0
- warp/fem/field/virtual.py +848 -0
- warp/fem/geometry/__init__.py +32 -0
- warp/fem/geometry/adaptive_nanogrid.py +857 -0
- warp/fem/geometry/closest_point.py +84 -0
- warp/fem/geometry/deformed_geometry.py +221 -0
- warp/fem/geometry/element.py +776 -0
- warp/fem/geometry/geometry.py +362 -0
- warp/fem/geometry/grid_2d.py +392 -0
- warp/fem/geometry/grid_3d.py +452 -0
- warp/fem/geometry/hexmesh.py +911 -0
- warp/fem/geometry/nanogrid.py +571 -0
- warp/fem/geometry/partition.py +389 -0
- warp/fem/geometry/quadmesh.py +663 -0
- warp/fem/geometry/tetmesh.py +855 -0
- warp/fem/geometry/trimesh.py +806 -0
- warp/fem/integrate.py +2335 -0
- warp/fem/linalg.py +419 -0
- warp/fem/operator.py +293 -0
- warp/fem/polynomial.py +229 -0
- warp/fem/quadrature/__init__.py +17 -0
- warp/fem/quadrature/pic_quadrature.py +299 -0
- warp/fem/quadrature/quadrature.py +591 -0
- warp/fem/space/__init__.py +228 -0
- warp/fem/space/basis_function_space.py +468 -0
- warp/fem/space/basis_space.py +667 -0
- warp/fem/space/dof_mapper.py +251 -0
- warp/fem/space/function_space.py +309 -0
- warp/fem/space/grid_2d_function_space.py +177 -0
- warp/fem/space/grid_3d_function_space.py +227 -0
- warp/fem/space/hexmesh_function_space.py +257 -0
- warp/fem/space/nanogrid_function_space.py +201 -0
- warp/fem/space/partition.py +367 -0
- warp/fem/space/quadmesh_function_space.py +223 -0
- warp/fem/space/restriction.py +179 -0
- warp/fem/space/shape/__init__.py +143 -0
- warp/fem/space/shape/cube_shape_function.py +1105 -0
- warp/fem/space/shape/shape_function.py +133 -0
- warp/fem/space/shape/square_shape_function.py +926 -0
- warp/fem/space/shape/tet_shape_function.py +834 -0
- warp/fem/space/shape/triangle_shape_function.py +672 -0
- warp/fem/space/tetmesh_function_space.py +271 -0
- warp/fem/space/topology.py +424 -0
- warp/fem/space/trimesh_function_space.py +194 -0
- warp/fem/types.py +99 -0
- warp/fem/utils.py +420 -0
- warp/jax.py +187 -0
- warp/jax_experimental/__init__.py +16 -0
- warp/jax_experimental/custom_call.py +351 -0
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +244 -0
- warp/native/array.h +1145 -0
- warp/native/builtin.h +1800 -0
- warp/native/bvh.cpp +492 -0
- warp/native/bvh.cu +791 -0
- warp/native/bvh.h +554 -0
- warp/native/clang/clang.cpp +536 -0
- warp/native/coloring.cpp +613 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +362 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +646 -0
- warp/native/cuda_util.h +307 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +1878 -0
- warp/native/fabric.h +245 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +87 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1230 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +339 -0
- warp/native/marching.cpp +19 -0
- warp/native/marching.cu +514 -0
- warp/native/marching.h +19 -0
- warp/native/mat.h +2220 -0
- warp/native/mathdx.cpp +87 -0
- warp/native/matnn.h +343 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +404 -0
- warp/native/mesh.h +1980 -0
- warp/native/nanovdb/GridHandle.h +366 -0
- warp/native/nanovdb/HostBuffer.h +590 -0
- warp/native/nanovdb/NanoVDB.h +6624 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1371 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +139 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +364 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +53 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +277 -0
- warp/native/sort.h +33 -0
- warp/native/sparse.cpp +378 -0
- warp/native/sparse.cu +524 -0
- warp/native/spatial.h +657 -0
- warp/native/svd.h +702 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +2584 -0
- warp/native/tile_reduce.h +264 -0
- warp/native/vec.h +1426 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +67 -0
- warp/native/volume.h +969 -0
- warp/native/volume_builder.cu +477 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1082 -0
- warp/native/warp.cu +3636 -0
- warp/native/warp.h +381 -0
- warp/optim/__init__.py +17 -0
- warp/optim/adam.py +163 -0
- warp/optim/linear.py +1137 -0
- warp/optim/sgd.py +112 -0
- warp/paddle.py +407 -0
- warp/render/__init__.py +18 -0
- warp/render/render_opengl.py +3518 -0
- warp/render/render_usd.py +784 -0
- warp/render/utils.py +160 -0
- warp/sim/__init__.py +65 -0
- warp/sim/articulation.py +793 -0
- warp/sim/collide.py +2395 -0
- warp/sim/graph_coloring.py +300 -0
- warp/sim/import_mjcf.py +790 -0
- warp/sim/import_snu.py +227 -0
- warp/sim/import_urdf.py +579 -0
- warp/sim/import_usd.py +894 -0
- warp/sim/inertia.py +324 -0
- warp/sim/integrator.py +242 -0
- warp/sim/integrator_euler.py +1997 -0
- warp/sim/integrator_featherstone.py +2101 -0
- warp/sim/integrator_vbd.py +2048 -0
- warp/sim/integrator_xpbd.py +3292 -0
- warp/sim/model.py +4791 -0
- warp/sim/particles.py +121 -0
- warp/sim/render.py +427 -0
- warp/sim/utils.py +428 -0
- warp/sparse.py +2057 -0
- warp/stubs.py +3333 -0
- warp/tape.py +1203 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +634 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +200 -0
- warp/tests/geometry/test_hash_grid.py +221 -0
- warp/tests/geometry/test_marching_cubes.py +74 -0
- warp/tests/geometry/test_mesh.py +316 -0
- warp/tests/geometry/test_mesh_query_aabb.py +399 -0
- warp/tests/geometry/test_mesh_query_point.py +932 -0
- warp/tests/geometry/test_mesh_query_ray.py +311 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +729 -0
- warp/tests/interop/test_jax.py +371 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/sim/__init__.py +0 -0
- warp/tests/sim/disabled_kinematics.py +244 -0
- warp/tests/sim/flaky_test_sim_grad.py +290 -0
- warp/tests/sim/test_collision.py +604 -0
- warp/tests/sim/test_coloring.py +258 -0
- warp/tests/sim/test_model.py +224 -0
- warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
- warp/tests/sim/test_sim_kinematics.py +98 -0
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +163 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +2972 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +250 -0
- warp/tests/test_atomic.py +153 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +1298 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +810 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +252 -0
- warp/tests/test_context.py +42 -0
- warp/tests/test_copy.py +238 -0
- warp/tests/test_ctypes.py +638 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +97 -0
- warp/tests/test_examples.py +482 -0
- warp/tests/test_fabricarray.py +996 -0
- warp/tests/test_fast_math.py +74 -0
- warp/tests/test_fem.py +2003 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +454 -0
- warp/tests/test_future_annotations.py +98 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +73 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +193 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_mat.py +2089 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +178 -0
- warp/tests/test_mlp.py +282 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +44 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +339 -0
- warp/tests/test_quat.py +2315 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +302 -0
- warp/tests/test_rounding.py +185 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +105 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +582 -0
- warp/tests/test_spatial.py +2229 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +592 -0
- warp/tests/test_struct.py +734 -0
- warp/tests/test_tape.py +204 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +145 -0
- warp/tests/test_types.py +562 -0
- warp/tests/test_utils.py +588 -0
- warp/tests/test_vec.py +1487 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +780 -0
- warp/tests/tile/test_tile_load.py +407 -0
- warp/tests/tile/test_tile_mathdx.py +208 -0
- warp/tests/tile/test_tile_mlp.py +402 -0
- warp/tests/tile/test_tile_reduce.py +447 -0
- warp/tests/tile/test_tile_shared_memory.py +247 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +427 -0
- warp/tests/unittest_utils.py +468 -0
- warp/tests/walkthrough_debug.py +93 -0
- warp/thirdparty/__init__.py +0 -0
- warp/thirdparty/appdirs.py +598 -0
- warp/thirdparty/dlpack.py +145 -0
- warp/thirdparty/unittest_parallel.py +570 -0
- warp/torch.py +391 -0
- warp/types.py +5230 -0
- warp/utils.py +1137 -0
- warp_lang-1.7.0.dist-info/METADATA +516 -0
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- warp_lang-1.7.0.dist-info/WHEEL +5 -0
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/native/tile.h
ADDED
|
@@ -0,0 +1,2584 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
*/
|
|
17
|
+
|
|
18
|
+
#pragma once
|
|
19
|
+
|
|
20
|
+
#include "builtin.h"
|
|
21
|
+
|
|
22
|
+
#ifdef __clang__
|
|
23
|
+
// disable warnings related to C++17 extensions on CPU JIT builds
|
|
24
|
+
#pragma clang diagnostic push
|
|
25
|
+
#pragma clang diagnostic ignored "-Wc++17-extensions"
|
|
26
|
+
#endif // __clang__
|
|
27
|
+
|
|
28
|
+
// Check if the CUDA toolkit is available
|
|
29
|
+
#if WP_ENABLE_CUDA || defined(__CUDACC_RTC__)
|
|
30
|
+
|
|
31
|
+
// If NVRTC is being used, do not include extra headers (NVRTC has built-in float4)
|
|
32
|
+
#ifdef __CUDACC_RTC__
|
|
33
|
+
// NVRTC: Use built-in float4 (no need for extra definitions)
|
|
34
|
+
#else
|
|
35
|
+
// NVCC: Include vector_types.h to get float4
|
|
36
|
+
#include <cuda_runtime.h>
|
|
37
|
+
#endif
|
|
38
|
+
|
|
39
|
+
#else
|
|
40
|
+
// If CUDA is not available (e.g., macOS build), manually define float4
|
|
41
|
+
// Host-only stand-in for CUDA's float4 when no CUDA headers are available:
// four packed floats with 16-byte alignment.
struct alignas(16) float4
{
    float x;
    float y;
    float z;
    float w;
};
|
|
44
|
+
#endif
|
|
45
|
+
|
|
46
|
+
// only used while building the warp core library
|
|
47
|
+
#ifndef WP_TILE_BLOCK_DIM
|
|
48
|
+
#define WP_TILE_BLOCK_DIM 256
|
|
49
|
+
#endif
|
|
50
|
+
|
|
51
|
+
#if !defined(__CUDA_ARCH__)
|
|
52
|
+
#define WP_TILE_SHARED static
|
|
53
|
+
#define WP_TILE_SYNC void
|
|
54
|
+
|
|
55
|
+
#else
|
|
56
|
+
#define WP_TILE_SHARED __shared__
|
|
57
|
+
#define WP_TILE_SYNC __syncthreads
|
|
58
|
+
#endif
|
|
59
|
+
|
|
60
|
+
#if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__)
|
|
61
|
+
#if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__))
|
|
62
|
+
#define WP_PRAGMA_UNROLL _Pragma("unroll")
|
|
63
|
+
#define WP_PRAGMA_NO_UNROLL _Pragma("unroll 1")
|
|
64
|
+
#else
|
|
65
|
+
#define WP_PRAGMA_UNROLL #pragma unroll
|
|
66
|
+
#define WP_PRAGMA_NO_UNROLL #pragma unroll 1
|
|
67
|
+
#endif
|
|
68
|
+
|
|
69
|
+
#else
|
|
70
|
+
|
|
71
|
+
#define WP_PRAGMA_UNROLL
|
|
72
|
+
#define WP_PRAGMA_NO_UNROLL
|
|
73
|
+
|
|
74
|
+
#endif
|
|
75
|
+
|
|
76
|
+
#define WP_USE_ASYNC_PIPELINE 0
|
|
77
|
+
#define WP_USE_REGISTER_GEMM 0
|
|
78
|
+
|
|
79
|
+
#if defined(__CUDACC_RTC__)
|
|
80
|
+
#define WP_TILE_THREAD_IDX threadIdx.x
|
|
81
|
+
#else
|
|
82
|
+
#define WP_TILE_THREAD_IDX 0
|
|
83
|
+
#endif //
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
/* Tile Expressions
|
|
88
|
+
|
|
89
|
+
[ ] Tiles
|
|
90
|
+
[x] Register, Shared, Global
|
|
91
|
+
[ ] Layouts
|
|
92
|
+
[x] Simple
|
|
93
|
+
[ ] Cute
|
|
94
|
+
[x] Remove Alloc type from tile_shared_t
|
|
95
|
+
[x] wp.launch_tiled() helper
|
|
96
|
+
[ ] Creation
|
|
97
|
+
[x] zeros
|
|
98
|
+
[x] ones
|
|
99
|
+
[x] arange
|
|
100
|
+
[x] tile()
|
|
101
|
+
[x] untile()
|
|
102
|
+
[ ] fromfunction()
|
|
103
|
+
[ ] explicit storage
|
|
104
|
+
[ ] Load/Store
|
|
105
|
+
[ ] 1D load/store variants
|
|
106
|
+
[ ] max_coord option for non-aligned loads
|
|
107
|
+
[ ] Indexed load
|
|
108
|
+
[x] wp.tile_atomic_add()
|
|
109
|
+
[ ] Maps
|
|
110
|
+
[x] Support user functions
|
|
111
|
+
[x] Support built-in functions
|
|
112
|
+
[ ] Support for lambda functions
|
|
113
|
+
[ ] Infer tile_map() output from operator type (e.g.: dot for each element)
|
|
114
|
+
[ ] Reductions
|
|
115
|
+
[x] Sum
|
|
116
|
+
[x] Forward
|
|
117
|
+
[x] Reverse
|
|
118
|
+
[x] Min
|
|
119
|
+
[x] Max
|
|
120
|
+
[x] Custom
|
|
121
|
+
[x] MatMul
|
|
122
|
+
[x] Forward
|
|
123
|
+
[x] Reverse
|
|
124
|
+
[ ] Operators
|
|
125
|
+
[ ] +, -, *, /, @?
|
|
126
|
+
[ ] += for matmul, e.g.: c += a@b, or c = a@b
|
|
127
|
+
[ ] Reshape
|
|
128
|
+
[ ] Broadcasting
|
|
129
|
+
[ ] Transpose
|
|
130
|
+
[x] Shared
|
|
131
|
+
[ ] Register
|
|
132
|
+
[ ] Slice
|
|
133
|
+
[ ] Runtime
|
|
134
|
+
[x] Compile-time block dimensions
|
|
135
|
+
[x] Switch between SIMT / Tile based execution if `block_dim` not provided to wp.launch()
|
|
136
|
+
[ ] Examples
|
|
137
|
+
[ ] Point registration
|
|
138
|
+
[ ] GEMM
|
|
139
|
+
[ ] MLP
|
|
140
|
+
[ ] LayerNorm
|
|
141
|
+
[ ] SoftMax
|
|
142
|
+
[ ] GEMM
|
|
143
|
+
[ ] warp.sim (CRBA)
|
|
144
|
+
[ ] Batched MLP
|
|
145
|
+
[ ] Layer norm
|
|
146
|
+
[ ] FNO + Burgers equation
|
|
147
|
+
[ ] Stochastic financial modeling
|
|
148
|
+
[ ] Convolution: https://github.com/NVIDIA/MinkowskiEngine/blob/master/src/convolution_kernel.cu#L123
|
|
149
|
+
[ ] MeshCNN (Modulus, Oliver)
|
|
150
|
+
[ ] BioNemo (Ali)
|
|
151
|
+
[ ] Skinning (David/Or/Vismay)
|
|
152
|
+
[ ] warp.sim (VBD)
|
|
153
|
+
[ ] Error checking
|
|
154
|
+
[ ] Ensure functions passed to tile_map() are compatible with tile type
|
|
155
|
+
[ ] Ensure that args passed to tile ops are compatible
|
|
156
|
+
[ ] Ensure tile load/store operations don't go out of bounds of arrays in debug mode
|
|
157
|
+
|
|
158
|
+
*/
|
|
159
|
+
|
|
160
|
+
/*
|
|
161
|
+
Notes on shared memory synchronization
|
|
162
|
+
======================================
|
|
163
|
+
|
|
164
|
+
Currently operations that write to shared memory tiles (e.g.: tile_load())
|
|
165
|
+
must synchronize before they return through WP_TILE_SYNC(), this
|
|
166
|
+
ensures subsequent read operations from the tile do not cause a race condition.
|
|
167
|
+
|
|
168
|
+
For tile_shared_t adjoints, the gradient accumulation is done through shared
|
|
169
|
+
memory atomics, i.e.: atomic_add(), since for broadcast tiles multiple threads
|
|
170
|
+
may map to the same location. Synchronization is still required after these
|
|
171
|
+
updates, since subsequent operations e.g.: adj_tile_load() will store the
|
|
172
|
+
gradients to memory, and all updates must be visible at that point, e.g.:
|
|
173
|
+
|
|
174
|
+
a = wp.tile_load(...)
|
|
175
|
+
b = wp.tile_load(...)
|
|
176
|
+
c = wp.tile_matmul(a, b)
|
|
177
|
+
wp.tile_store(c)
|
|
178
|
+
|
|
179
|
+
// loads incoming adjoints from global -> shared
|
|
180
|
+
wp.adj_tile_store(c, adj_c)
|
|
181
|
+
// consumes adj_c, requires synchronization
|
|
182
|
+
wp.adj_tile_matmul(a, b, adj_a, adj_b, adj_c)
|
|
183
|
+
// consumes adj_b, requires synchronization
|
|
184
|
+
wp.adj_tile_load(..., adj_b)
|
|
185
|
+
// consumes adj_a, requires synchronization
|
|
186
|
+
wp.adj_tile_load(..., adj_a)
|
|
187
|
+
|
|
188
|
+
Generally synchronization to adjoint tiles will happen through the
|
|
189
|
+
tile_shared_t::add() and tile_shared_t::assign() functions automatically,
|
|
190
|
+
but in some cases e.g.: tile_matmul() it is done manually.
|
|
191
|
+
|
|
192
|
+
The current synchronization strategy is conservative, and can lead to more
|
|
193
|
+
synchronization than necessary. A more sophisticated strategy would be
|
|
194
|
+
to track the 'dirty' state of shared tiles, and synchronize only when
|
|
195
|
+
necessary. In addition, custom synchronization for e.g.: tile_load()
|
|
196
|
+
operations could be added through a SyncProvider template parameter on
|
|
197
|
+
the tile_shared_t type, for example to support barrier synchronization
|
|
198
|
+
for asynchronous global to shared loads.
|
|
199
|
+
*/
|
|
200
|
+
|
|
201
|
+
namespace wp
|
|
202
|
+
{
|
|
203
|
+
|
|
204
|
+
// Primary template
|
|
205
|
+
// Minimal compile-time type-equality trait.
// The general case reports two distinct types as different...
template <typename T, typename U>
struct is_same
{
    static constexpr bool value{ false };
};

// ...and this partial specialization matches only when both arguments are the same type.
template <typename T>
struct is_same<T, T>
{
    static constexpr bool value{ true };
};
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
template <int N>
|
|
218
|
+
struct tile_coord_t
|
|
219
|
+
{
|
|
220
|
+
int indices[N];
|
|
221
|
+
|
|
222
|
+
CUDA_CALLABLE inline int operator[](int i) const { assert(0 <= 1 && i < N); return indices[i]; }
|
|
223
|
+
CUDA_CALLABLE inline int& operator[](int i) { assert(0 <= 1 && i < N); return indices[i]; }
|
|
224
|
+
|
|
225
|
+
CUDA_CALLABLE inline tile_coord_t<N> operator + (const tile_coord_t<N>& c) const
|
|
226
|
+
{
|
|
227
|
+
tile_coord_t<N> out;
|
|
228
|
+
for (int i=0; i < N; ++i)
|
|
229
|
+
{
|
|
230
|
+
out.indices[i] = indices[i] + c.indices[i];
|
|
231
|
+
}
|
|
232
|
+
return out;
|
|
233
|
+
}
|
|
234
|
+
};
|
|
235
|
+
|
|
236
|
+
// This function deduces N = sizeof...(Ints)
|
|
237
|
+
template <typename... Ints>
|
|
238
|
+
constexpr tile_coord_t<sizeof...(Ints)> tile_coord(Ints... idxs)
|
|
239
|
+
{
|
|
240
|
+
constexpr int N = sizeof...(Ints);
|
|
241
|
+
|
|
242
|
+
// Create the result
|
|
243
|
+
tile_coord_t<N> result{};
|
|
244
|
+
|
|
245
|
+
// Capture all arguments in a local array
|
|
246
|
+
int arr[] = { static_cast<int>(idxs)... };
|
|
247
|
+
|
|
248
|
+
// C++14 or later: 'for' is allowed in a constexpr context
|
|
249
|
+
for (int i = 0; i < N; ++i)
|
|
250
|
+
{
|
|
251
|
+
result.indices[i] = arr[i];
|
|
252
|
+
}
|
|
253
|
+
|
|
254
|
+
return result;
|
|
255
|
+
}
|
|
256
|
+
|
|
257
|
+
// helpers to construct a coord from a set of indices
|
|
258
|
+
inline auto tile_coord(int i)
|
|
259
|
+
{
|
|
260
|
+
auto c = tile_coord_t<1>();
|
|
261
|
+
c.indices[0] = i;
|
|
262
|
+
return c;
|
|
263
|
+
}
|
|
264
|
+
|
|
265
|
+
inline auto tile_coord(int i, int j)
|
|
266
|
+
{
|
|
267
|
+
auto c = tile_coord_t<2>();
|
|
268
|
+
c.indices[0] = i;
|
|
269
|
+
c.indices[1] = j;
|
|
270
|
+
return c;
|
|
271
|
+
}
|
|
272
|
+
|
|
273
|
+
inline auto tile_coord(int i, int j, int k)
|
|
274
|
+
{
|
|
275
|
+
auto c = tile_coord_t<3>();
|
|
276
|
+
c.indices[0] = i;
|
|
277
|
+
c.indices[1] = j;
|
|
278
|
+
c.indices[2] = k;
|
|
279
|
+
return c;
|
|
280
|
+
}
|
|
281
|
+
|
|
282
|
+
inline auto tile_coord(int i, int j, int k, int l)
|
|
283
|
+
{
|
|
284
|
+
auto c = tile_coord_t<4>();
|
|
285
|
+
c.indices[0] = i;
|
|
286
|
+
c.indices[1] = j;
|
|
287
|
+
c.indices[2] = k;
|
|
288
|
+
c.indices[3] = l;
|
|
289
|
+
return c;
|
|
290
|
+
}
|
|
291
|
+
|
|
292
|
+
// represents a compile time int tuple for strides/shapes/coords
|
|
293
|
+
template <int... V>
struct tile_tuple_t
{
    // number of entries in the tuple (must be at least one)
    static constexpr int N = sizeof...(V);
    static_assert(N > 0, "Expected N > 0");

    static constexpr int data[N] = { V... };

    // Returns the i-th entry; i must lie in [0, N) (checked with assert()).
    static constexpr int dim(int i) { assert(0 <= i && i < N); return data[i]; }

    // Returns the product of all entries (the element count when the tuple is a shape).
    static constexpr int size()
    {
        int res = data[0];
        for (int i=1; i < N; ++i)
            res *= data[i];

        return res;
    }
};
|
|
311
|
+
|
|
312
|
+
// simple helper to compute strides from a shape up to 4d
|
|
313
|
+
template <typename Shape>
|
|
314
|
+
struct compute_strides;
|
|
315
|
+
|
|
316
|
+
// 1D
|
|
317
|
+
template <int D0>
|
|
318
|
+
struct compute_strides< tile_tuple_t<D0> > { using Stride = tile_tuple_t<1>; };
|
|
319
|
+
// 2D
|
|
320
|
+
template <int D0, int D1>
|
|
321
|
+
struct compute_strides< tile_tuple_t<D0, D1> > { using Stride = tile_tuple_t<D1, 1>; };
|
|
322
|
+
// 3D
|
|
323
|
+
template <int D0, int D1, int D2>
|
|
324
|
+
struct compute_strides< tile_tuple_t<D0, D1, D2> > { using Stride = tile_tuple_t<(D1 * D2), D2, 1>; };
|
|
325
|
+
// 4D
|
|
326
|
+
template <int D0, int D1, int D2, int D3>
|
|
327
|
+
struct compute_strides< tile_tuple_t<D0, D1, D2, D3> > { using Stride = tile_tuple_t<(D1 * D2 * D3), (D2 * D3), D3, 1>; };
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
// alias of tuple to represent shapes
|
|
331
|
+
template <int... V>
|
|
332
|
+
using tile_shape_t = tile_tuple_t<V...>;
|
|
333
|
+
|
|
334
|
+
// alias of tuple to represent stride
|
|
335
|
+
template <int... V>
|
|
336
|
+
using tile_stride_t = tile_tuple_t<V...>;
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
// represents a tile stored in global memory with dynamic strides
|
|
340
|
+
// used to represent the source and offset for tile loads to register/shared
|
|
341
|
+
template <typename T, typename Shape_>
|
|
342
|
+
struct tile_global_t
|
|
343
|
+
{
|
|
344
|
+
using Type = T;
|
|
345
|
+
using Shape = Shape_;
|
|
346
|
+
using Coord = tile_coord_t<Shape::N>;
|
|
347
|
+
|
|
348
|
+
array_t<T> data;
|
|
349
|
+
Coord offset;
|
|
350
|
+
|
|
351
|
+
tile_global_t(array_t<T>& a, const Coord& c) : data(a), offset(c)
|
|
352
|
+
{
|
|
353
|
+
}
|
|
354
|
+
|
|
355
|
+
inline CUDA_CALLABLE int index_from_coord(const Coord& coord) const
|
|
356
|
+
{
|
|
357
|
+
// element index
|
|
358
|
+
int index = 0;
|
|
359
|
+
|
|
360
|
+
WP_PRAGMA_UNROLL
|
|
361
|
+
for (int i=0; i < Shape::N; ++i)
|
|
362
|
+
{
|
|
363
|
+
// global = offset + coord
|
|
364
|
+
int c = offset[i] + coord[i];
|
|
365
|
+
index += data.strides[i]*c;
|
|
366
|
+
}
|
|
367
|
+
|
|
368
|
+
return index/sizeof(T);
|
|
369
|
+
}
|
|
370
|
+
|
|
371
|
+
inline CUDA_CALLABLE bool index(const Coord& coord, int& out) const
|
|
372
|
+
{
|
|
373
|
+
// element index
|
|
374
|
+
int index = 0;
|
|
375
|
+
|
|
376
|
+
WP_PRAGMA_UNROLL
|
|
377
|
+
for (int i=0; i < Shape::N; ++i)
|
|
378
|
+
{
|
|
379
|
+
// global = offset + coord
|
|
380
|
+
int c = offset[i] + coord[i];
|
|
381
|
+
|
|
382
|
+
// handle out of bounds case
|
|
383
|
+
if (c >= data.shape[i])
|
|
384
|
+
return false;
|
|
385
|
+
else
|
|
386
|
+
index += data.strides[i]*c;
|
|
387
|
+
}
|
|
388
|
+
|
|
389
|
+
// array strides are in bytes so we convert to elements
|
|
390
|
+
out = index / sizeof(T);
|
|
391
|
+
return true;
|
|
392
|
+
}
|
|
393
|
+
|
|
394
|
+
inline CUDA_CALLABLE T load(const Coord& coord) const
|
|
395
|
+
{
|
|
396
|
+
int i;
|
|
397
|
+
if (index(coord, i))
|
|
398
|
+
return data.data[i];
|
|
399
|
+
else
|
|
400
|
+
return T(0);
|
|
401
|
+
}
|
|
402
|
+
|
|
403
|
+
inline CUDA_CALLABLE T load_grad(const Coord& coord) const
|
|
404
|
+
{
|
|
405
|
+
int i;
|
|
406
|
+
if (index(coord, i))
|
|
407
|
+
return data.grad[i];
|
|
408
|
+
else
|
|
409
|
+
return T(0);
|
|
410
|
+
}
|
|
411
|
+
|
|
412
|
+
inline CUDA_CALLABLE void store(const Coord& coord, const T& x) const
|
|
413
|
+
{
|
|
414
|
+
int i;
|
|
415
|
+
if (index(coord, i))
|
|
416
|
+
data.data[i] = x;
|
|
417
|
+
}
|
|
418
|
+
|
|
419
|
+
inline CUDA_CALLABLE T atomic_add(const Coord& coord, const T& value) const
|
|
420
|
+
{
|
|
421
|
+
int i;
|
|
422
|
+
if (index(coord, i))
|
|
423
|
+
return wp::atomic_add(&data.data[i], value);
|
|
424
|
+
else
|
|
425
|
+
return T(0);
|
|
426
|
+
}
|
|
427
|
+
|
|
428
|
+
inline CUDA_CALLABLE T atomic_add_grad(const Coord& coord, const T& grad) const
|
|
429
|
+
{
|
|
430
|
+
int i;
|
|
431
|
+
if (index(coord, i))
|
|
432
|
+
return wp::atomic_add(&data.grad[i], grad);
|
|
433
|
+
else
|
|
434
|
+
return T(0);
|
|
435
|
+
}
|
|
436
|
+
};
|
|
437
|
+
|
|
438
|
+
template <typename Shape_>
|
|
439
|
+
struct tile_layout_register_t
|
|
440
|
+
{
|
|
441
|
+
using Shape = Shape_;
|
|
442
|
+
using Coord = tile_coord_t<Shape::N>;
|
|
443
|
+
|
|
444
|
+
static constexpr int Size = Shape::size();
|
|
445
|
+
static constexpr int NumRegs = (Size + WP_TILE_BLOCK_DIM - 1) / WP_TILE_BLOCK_DIM;
|
|
446
|
+
static constexpr bool Aligned = Size%WP_TILE_BLOCK_DIM == 0;
|
|
447
|
+
|
|
448
|
+
static inline CUDA_CALLABLE int linear_from_register(int reg)
|
|
449
|
+
{
|
|
450
|
+
return WP_TILE_THREAD_IDX + reg*WP_TILE_BLOCK_DIM;
|
|
451
|
+
}
|
|
452
|
+
|
|
453
|
+
static inline CUDA_CALLABLE int linear_from_coord(Coord c)
|
|
454
|
+
{
|
|
455
|
+
int linear = 0;
|
|
456
|
+
int stride = 1;
|
|
457
|
+
|
|
458
|
+
WP_PRAGMA_UNROLL
|
|
459
|
+
for (int i=Shape::N-1; i >= 0; --i)
|
|
460
|
+
{
|
|
461
|
+
linear += c[i] * stride;
|
|
462
|
+
stride *= Shape::dim(i);
|
|
463
|
+
}
|
|
464
|
+
return linear;
|
|
465
|
+
}
|
|
466
|
+
|
|
467
|
+
static inline CUDA_CALLABLE auto coord_from_linear(int linear)
|
|
468
|
+
{
|
|
469
|
+
Coord c;
|
|
470
|
+
|
|
471
|
+
WP_PRAGMA_UNROLL
|
|
472
|
+
for (int i=Shape::N-1; i >= 0; --i)
|
|
473
|
+
{
|
|
474
|
+
c[i] = linear%Shape::dim(i);
|
|
475
|
+
linear /= Shape::dim(i);
|
|
476
|
+
}
|
|
477
|
+
|
|
478
|
+
return c;
|
|
479
|
+
}
|
|
480
|
+
|
|
481
|
+
static inline CUDA_CALLABLE int thread_from_linear(int linear)
|
|
482
|
+
{
|
|
483
|
+
const int thread = linear%WP_TILE_BLOCK_DIM;
|
|
484
|
+
return thread;
|
|
485
|
+
}
|
|
486
|
+
|
|
487
|
+
static inline CUDA_CALLABLE int register_from_linear(int linear)
|
|
488
|
+
{
|
|
489
|
+
const int reg = linear/WP_TILE_BLOCK_DIM;
|
|
490
|
+
return reg;
|
|
491
|
+
}
|
|
492
|
+
|
|
493
|
+
static inline CUDA_CALLABLE bool valid(int linear)
|
|
494
|
+
{
|
|
495
|
+
if (Aligned || linear < Size)
|
|
496
|
+
return true;
|
|
497
|
+
else
|
|
498
|
+
return false;
|
|
499
|
+
}
|
|
500
|
+
|
|
501
|
+
};
|
|
502
|
+
|
|
503
|
+
// represents a tile stored in registers across a block
|
|
504
|
+
template <typename T, typename L>
|
|
505
|
+
struct tile_register_t
|
|
506
|
+
{
|
|
507
|
+
using Type = T;
|
|
508
|
+
using Layout = L;
|
|
509
|
+
|
|
510
|
+
T data[Layout::NumRegs];
|
|
511
|
+
|
|
512
|
+
inline CUDA_CALLABLE tile_register_t(T value=T(0.0))
|
|
513
|
+
{
|
|
514
|
+
// zero-initialize by default necessary for tile adjoints
|
|
515
|
+
// need to check if this results in worse codegen
|
|
516
|
+
// than doing adj_var = tile_zeros() explicitly
|
|
517
|
+
// in backwards pass and letting default constructor
|
|
518
|
+
// avoid initialization
|
|
519
|
+
|
|
520
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
521
|
+
data[i] = value;
|
|
522
|
+
}
|
|
523
|
+
|
|
524
|
+
inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape>& t)
|
|
525
|
+
{
|
|
526
|
+
copy_from_global(t);
|
|
527
|
+
return *this;
|
|
528
|
+
}
|
|
529
|
+
|
|
530
|
+
// define the += operator which is used during backward pass codegen
|
|
531
|
+
// when returning a register tile from a user defined function
|
|
532
|
+
inline CUDA_CALLABLE auto& operator += (tile_register_t<T, Layout>& rhs)
|
|
533
|
+
{
|
|
534
|
+
grad_add(rhs);
|
|
535
|
+
return *this;
|
|
536
|
+
}
|
|
537
|
+
|
|
538
|
+
inline CUDA_CALLABLE T& operator()(int reg)
|
|
539
|
+
{
|
|
540
|
+
assert(reg < Layout::NumRegs);
|
|
541
|
+
return data[reg];
|
|
542
|
+
}
|
|
543
|
+
|
|
544
|
+
inline CUDA_CALLABLE const T& operator()(int reg) const
|
|
545
|
+
{
|
|
546
|
+
assert(reg < Layout::NumRegs);
|
|
547
|
+
return data[reg];
|
|
548
|
+
}
|
|
549
|
+
|
|
550
|
+
inline CUDA_CALLABLE void assign(const tile_register_t<T, Layout>& tile)
|
|
551
|
+
{
|
|
552
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
553
|
+
data[i] = tile.data[i];
|
|
554
|
+
}
|
|
555
|
+
|
|
556
|
+
inline CUDA_CALLABLE void zero()
|
|
557
|
+
{
|
|
558
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
559
|
+
data[i] = T(0);
|
|
560
|
+
}
|
|
561
|
+
|
|
562
|
+
// extract a single tile element to a native type
|
|
563
|
+
template <typename Coord>
|
|
564
|
+
inline CUDA_CALLABLE Type extract(const Coord& c)
|
|
565
|
+
{
|
|
566
|
+
// map from logical coords (i, j) -> (thread, reg)
|
|
567
|
+
const int linear = Layout::linear_from_coord(c);
|
|
568
|
+
const int thread = Layout::thread_from_linear(linear);
|
|
569
|
+
const int reg = Layout::register_from_linear(linear);
|
|
570
|
+
|
|
571
|
+
WP_TILE_SHARED Type scratch;
|
|
572
|
+
|
|
573
|
+
// ensure any previously scheduled threads have finished reading from scratch
|
|
574
|
+
WP_TILE_SYNC();
|
|
575
|
+
|
|
576
|
+
if (WP_TILE_THREAD_IDX == thread)
|
|
577
|
+
{
|
|
578
|
+
scratch = data[reg];
|
|
579
|
+
}
|
|
580
|
+
|
|
581
|
+
// ensure extraction thread has updated smem
|
|
582
|
+
WP_TILE_SYNC();
|
|
583
|
+
|
|
584
|
+
return scratch;
|
|
585
|
+
}
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
// backward version of scalar extract
|
|
589
|
+
template <typename Coord>
|
|
590
|
+
inline CUDA_CALLABLE void adj_extract(const Coord& c, Type adj_ret)
|
|
591
|
+
{
|
|
592
|
+
// map from logical coords (i, j) -> (thread, reg)
|
|
593
|
+
const int linear = Layout::linear_from_coord(c);
|
|
594
|
+
const int thread = Layout::thread_from_linear(linear);
|
|
595
|
+
const int reg = Layout::register_from_linear(linear);
|
|
596
|
+
|
|
597
|
+
if (WP_TILE_THREAD_IDX == thread)
|
|
598
|
+
{
|
|
599
|
+
data[reg] += adj_ret;
|
|
600
|
+
}
|
|
601
|
+
}
|
|
602
|
+
|
|
603
|
+
inline CUDA_CALLABLE void print() const;
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
// return the in-register version of this tile (nop)
|
|
607
|
+
inline CUDA_CALLABLE auto& copy_to_register()
|
|
608
|
+
{
|
|
609
|
+
return *this;
|
|
610
|
+
}
|
|
611
|
+
|
|
612
|
+
inline CUDA_CALLABLE const auto& copy_to_register() const
|
|
613
|
+
{
|
|
614
|
+
return *this;
|
|
615
|
+
}
|
|
616
|
+
|
|
617
|
+
// apply a lambda to all valid entries in the tile
|
|
618
|
+
// Op should be a functor that takes a register index and tile_coord_t as input
|
|
619
|
+
template <typename Op>
|
|
620
|
+
void apply(Op op)
|
|
621
|
+
{
|
|
622
|
+
WP_PRAGMA_UNROLL
|
|
623
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
624
|
+
{
|
|
625
|
+
int linear = Layout::linear_from_register(i);
|
|
626
|
+
if (!Layout::valid(linear))
|
|
627
|
+
break;
|
|
628
|
+
|
|
629
|
+
auto c = Layout::coord_from_linear(linear);
|
|
630
|
+
op(i, c);
|
|
631
|
+
}
|
|
632
|
+
}
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
// in-place gradient zero
|
|
636
|
+
inline CUDA_CALLABLE void grad_zero()
|
|
637
|
+
{
|
|
638
|
+
zero();
|
|
639
|
+
}
|
|
640
|
+
|
|
641
|
+
// accumulate gradients onto this tile
|
|
642
|
+
inline CUDA_CALLABLE void grad_add(const tile_register_t<T, Layout>& tile)
|
|
643
|
+
{
|
|
644
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
645
|
+
data[i] += tile.data[i];
|
|
646
|
+
}
|
|
647
|
+
|
|
648
|
+
CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
|
|
649
|
+
{
|
|
650
|
+
apply([&](int reg, auto c) {data[reg] = global.load_grad(c);});
|
|
651
|
+
|
|
652
|
+
}
|
|
653
|
+
|
|
654
|
+
inline CUDA_CALLABLE auto& grad_to_register()
|
|
655
|
+
{
|
|
656
|
+
// nop for register tiles
|
|
657
|
+
return *this;
|
|
658
|
+
}
|
|
659
|
+
|
|
660
|
+
template <typename Global>
|
|
661
|
+
inline CUDA_CALLABLE void copy_to_global(const Global& dest)
|
|
662
|
+
{
|
|
663
|
+
apply([&](int reg, auto c) { dest.store(c, data[reg]); });
|
|
664
|
+
}
|
|
665
|
+
|
|
666
|
+
template <typename Global>
|
|
667
|
+
inline CUDA_CALLABLE void copy_from_global(const Global& src)
|
|
668
|
+
{
|
|
669
|
+
apply([&](int reg, auto c) { data[reg] = src.load(c); });
|
|
670
|
+
}
|
|
671
|
+
|
|
672
|
+
// add a register tile to a global array
|
|
673
|
+
template <typename Global>
|
|
674
|
+
inline CUDA_CALLABLE auto atomic_add(const Global& dest)
|
|
675
|
+
{
|
|
676
|
+
// allocate a tile to hold previous dest value
|
|
677
|
+
auto previous = *this;
|
|
678
|
+
|
|
679
|
+
apply([&](int reg, auto c) { previous.data[reg] = dest.atomic_add(c, data[reg]); });
|
|
680
|
+
return previous;
|
|
681
|
+
}
|
|
682
|
+
|
|
683
|
+
// add a register tile to the gradient of a global array
|
|
684
|
+
template <typename Global>
|
|
685
|
+
inline CUDA_CALLABLE auto atomic_add_grad(const Global& dest)
|
|
686
|
+
{
|
|
687
|
+
// allocate a tile to hold previous dest value
|
|
688
|
+
auto previous = *this;
|
|
689
|
+
|
|
690
|
+
apply([&](int reg, auto c) { previous.data[reg] = dest.atomic_add_grad(c, data[reg]); });
|
|
691
|
+
return previous;
|
|
692
|
+
}
|
|
693
|
+
};
|
|
694
|
+
|
|
695
|
+
|
|
696
|
+
// helper to allocate a register tile like another tile
|
|
697
|
+
// users can either specify a template explicitly or
|
|
698
|
+
// pass in another concrete instance
|
|
699
|
+
template<typename Tile>
|
|
700
|
+
auto tile_register_like(Tile* t=nullptr)
|
|
701
|
+
{
|
|
702
|
+
using T = typename Tile::Type;
|
|
703
|
+
using L = typename Tile::Layout;
|
|
704
|
+
|
|
705
|
+
return tile_register_t<T, tile_layout_register_t<typename L::Shape>>(T(0.0));
|
|
706
|
+
}
|
|
707
|
+
|
|
708
|
+
// helper to construct a register tile from a type and a list of dims
|
|
709
|
+
template <typename T, int... Dims>
|
|
710
|
+
auto tile_register()
|
|
711
|
+
{
|
|
712
|
+
return tile_register_t<T, tile_layout_register_t<tile_shape_t<Dims...>>>();
|
|
713
|
+
}
|
|
714
|
+
|
|
715
|
+
// round a byte count up (in magnitude) to the 16 byte tile allocation granularity
// the sign is preserved, so negative counts (used when releasing memory) are
// rounded exactly like the matching positive allocation
inline CUDA_CALLABLE int tile_align(int num_bytes)
{
    // note this must match the alignment value in Python types.py
    const int alignment = 16;

    const bool negative = num_bytes < 0;
    const int magnitude = negative ? -num_bytes : num_bytes;
    const int rounded = ((magnitude + alignment - 1) / alignment) * alignment;

    return negative ? -rounded : rounded;
}
|
|
725
|
+
|
|
726
|
+
// per-thread bump allocator over dynamic shared memory (CUDA) or a fixed static
// buffer (CPU)
//   init=true  : reset the calling thread's offset to zero, returns nullptr
//   check=true : assert the calling thread's offset is back to zero, returns nullptr
//   otherwise  : return a pointer at the current offset and advance the offset by
//                tile_align(num_bytes); a negative num_bytes moves it back
// NOTE(review): release is a plain decrement, so allocations presumably must be
// released in reverse (LIFO) order
inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false, bool check=false)
{
    // we maintain a per-thread offset into dynamic
    // shared memory that allows us to keep track of
    // current use across dynamic function calls
    WP_TILE_SHARED int smem_base[WP_TILE_BLOCK_DIM];

    if (init)
    {
        smem_base[WP_TILE_THREAD_IDX] = 0;
        return nullptr;
    }
    else if (check)
    {
        assert(smem_base[WP_TILE_THREAD_IDX] == 0);
        return nullptr;
    }
    else
    {
        const int offset = smem_base[WP_TILE_THREAD_IDX];

        // one entry per-thread so no need for synchronization
        smem_base[WP_TILE_THREAD_IDX] += tile_align(num_bytes);

#ifdef __CUDA_ARCH__
        extern __shared__ char dynamic_smem_base[];
#else
        // on CPU allocate a fixed 256k block to use for shared allocs
        // offsets are multiples of 16 (tile_align), so the base is explicitly
        // 16-byte aligned to make every returned tile pointer 16-byte aligned;
        // a plain char array carries no such guarantee
        static const int max_cpu_shared = 256*1024;
        alignas(16) static char dynamic_smem_base[max_cpu_shared];

        assert(smem_base[WP_TILE_THREAD_IDX] <= max_cpu_shared);
#endif
        return &(dynamic_smem_base[offset]);
    }
}
|
|
762
|
+
|
|
763
|
+
|
|
764
|
+
template <typename Shape_, typename Stride_= typename compute_strides<Shape_>::Stride>
|
|
765
|
+
struct tile_layout_strided_t
|
|
766
|
+
{
|
|
767
|
+
using Shape = Shape_;
|
|
768
|
+
using Stride = Stride_;
|
|
769
|
+
using Coord = tile_coord_t<Shape::N>;
|
|
770
|
+
|
|
771
|
+
static constexpr int Size = Shape::size();
|
|
772
|
+
static constexpr bool Aligned = Size%WP_TILE_BLOCK_DIM == 0;
|
|
773
|
+
|
|
774
|
+
static inline CUDA_CALLABLE auto coord_from_linear(int linear)
|
|
775
|
+
{
|
|
776
|
+
assert(linear < Size);
|
|
777
|
+
|
|
778
|
+
Coord c;
|
|
779
|
+
|
|
780
|
+
WP_PRAGMA_UNROLL
|
|
781
|
+
for (int d=Shape::N-1; d >= 0; --d)
|
|
782
|
+
{
|
|
783
|
+
c[d] = linear%Shape::dim(d);
|
|
784
|
+
linear /= Shape::dim(d);
|
|
785
|
+
}
|
|
786
|
+
|
|
787
|
+
return c;
|
|
788
|
+
}
|
|
789
|
+
|
|
790
|
+
static inline CUDA_CALLABLE int index_from_coord(Coord c)
|
|
791
|
+
{
|
|
792
|
+
int index = 0;
|
|
793
|
+
|
|
794
|
+
WP_PRAGMA_UNROLL
|
|
795
|
+
for (int d=0; d < Shape::N; ++d)
|
|
796
|
+
{
|
|
797
|
+
assert(c[d] < Shape::dim(d));
|
|
798
|
+
|
|
799
|
+
index += c[d]*Stride::dim(d);
|
|
800
|
+
}
|
|
801
|
+
|
|
802
|
+
return index;
|
|
803
|
+
}
|
|
804
|
+
|
|
805
|
+
// checks whether a strided layout is unique, i.e.: if memory locations are only
|
|
806
|
+
// every referred to by one element in the tile, this is a basic test that only
|
|
807
|
+
// checks for broadcast dimensions, it would be possible to do the full check
|
|
808
|
+
// using sorted shape/strides in Python and add it as a template parameter to the type
|
|
809
|
+
static constexpr bool is_unique()
|
|
810
|
+
{
|
|
811
|
+
constexpr int N = Shape::N;
|
|
812
|
+
|
|
813
|
+
// check for any broadcast dimensions
|
|
814
|
+
for (int i=0; i < N; ++i)
|
|
815
|
+
if (Stride::dim(i) == 0)
|
|
816
|
+
return false;
|
|
817
|
+
|
|
818
|
+
return true;
|
|
819
|
+
}
|
|
820
|
+
|
|
821
|
+
static constexpr bool Unique = is_unique();
|
|
822
|
+
|
|
823
|
+
static inline CUDA_CALLABLE bool valid(int linear)
|
|
824
|
+
{
|
|
825
|
+
return linear < Size;
|
|
826
|
+
}
|
|
827
|
+
|
|
828
|
+
};
|
|
829
|
+
|
|
830
|
+
|
|
831
|
+
template <typename T, typename L, bool Owner_=true>
|
|
832
|
+
struct tile_shared_t
|
|
833
|
+
{
|
|
834
|
+
// element type stored in the tile
using Type = T;
// compile-time layout: provides Shape, Size and the coordinate <-> storage index mapping
using Layout = L;
// when true, the destructor hands this tile's data/grad storage back to the
// per-thread shared memory allocator; non-owning instances release nothing
static constexpr bool Owner = Owner_;
|
|
837
|
+
|
|
838
|
+
struct Storage
|
|
839
|
+
{
|
|
840
|
+
T* ptr;
|
|
841
|
+
|
|
842
|
+
Storage(T* p) : ptr(p) {}
|
|
843
|
+
|
|
844
|
+
inline CUDA_CALLABLE T& operator()(typename Layout::Coord c)
|
|
845
|
+
{
|
|
846
|
+
assert(ptr);
|
|
847
|
+
|
|
848
|
+
int index = Layout::index_from_coord(c);
|
|
849
|
+
return ptr[index];
|
|
850
|
+
}
|
|
851
|
+
|
|
852
|
+
inline CUDA_CALLABLE const T& operator()(typename Layout::Coord c) const
|
|
853
|
+
{
|
|
854
|
+
assert(ptr);
|
|
855
|
+
|
|
856
|
+
int index = Layout::index_from_coord(c);
|
|
857
|
+
return ptr[index];
|
|
858
|
+
}
|
|
859
|
+
|
|
860
|
+
inline CUDA_CALLABLE T& operator()(int linear)
|
|
861
|
+
{
|
|
862
|
+
assert(ptr);
|
|
863
|
+
assert(Layout::valid(linear));
|
|
864
|
+
|
|
865
|
+
auto c = Layout::coord_from_linear(linear);
|
|
866
|
+
return (*this)(c);
|
|
867
|
+
}
|
|
868
|
+
|
|
869
|
+
inline CUDA_CALLABLE const T& operator()(int linear) const
|
|
870
|
+
{
|
|
871
|
+
assert(ptr);
|
|
872
|
+
assert(Layout::valid(linear));
|
|
873
|
+
|
|
874
|
+
auto c = Layout::coord_from_linear(linear);
|
|
875
|
+
return (*this)(c);
|
|
876
|
+
}
|
|
877
|
+
};
|
|
878
|
+
|
|
879
|
+
// storage for the tile values and (optionally) their gradients; either pointer
// may be null (the default constructor sets both to nullptr)
Storage data;
Storage grad;

// we need to track whether or not this tile's data has been initialized.
// once true, any re-initialization of data that follows needs a WP_TILE_SYNC()
// call to precede it, to allow threads that are still reading from this tile
// to complete their work. e.g. in a dynamic loop:
// for i in range(x):
//   tile = wp.tile_load(arr, i, TILE_SIZE, storage="shared")
//   # read from tile...
bool initialized;

// default initialization (non-initialized), no storage attached
inline CUDA_CALLABLE tile_shared_t() : data(nullptr), grad(nullptr), initialized(false)
{
}

// initialize from an existing tile's memory
// NOTE(review): when Owner is true the destructor returns Layout::Size*sizeof(T)
// bytes to tile_alloc_shared() for each non-null pointer, so owning instances
// presumably must be given memory obtained from tile_alloc_shared()
inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=nullptr, bool initialized=true) : data(data), grad(grad), initialized(initialized)
{
}

// owning tiles release their data/grad storage by requesting a negative
// allocation of the same size; non-owning tiles (Owner == false) release nothing
inline CUDA_CALLABLE ~tile_shared_t()
{
    if (Owner)
    {
        // update our per-thread shared memory allocator
        if (data.ptr)
            tile_alloc_shared(-Layout::Size*int(sizeof(T)));

        if (grad.ptr)
            tile_alloc_shared(-Layout::Size*int(sizeof(T)));
    }
}
|
|
913
|
+
|
|
914
|
+
// assign from a register tile
// forwards to assign(), which copies the valid register elements into shared
// storage and synchronizes the block
template <typename Tile>
inline CUDA_CALLABLE auto& operator=(const Tile& t)
{
    assign(t);
    return *this;
}
|
|
921
|
+
|
|
922
|
+
|
|
923
|
+
/*
|
|
924
|
+
// construct from another shared tile, this constructor
|
|
925
|
+
// is invoked for reshape operations like `wp.tile_transpose()`
|
|
926
|
+
template <typename OtherT, typename OtherLayout>
|
|
927
|
+
inline CUDA_CALLABLE auto& operator=(const tile_shared_t<OtherT, OtherLayout>& rhs)
|
|
928
|
+
{
|
|
929
|
+
using OtherTile = tile_shared_t<OtherT, OtherLayout>;
|
|
930
|
+
|
|
931
|
+
// check dimensions are compatible
|
|
932
|
+
static_assert(Size == OtherTile::Size, "Expected Size == OtherTile::Size");
|
|
933
|
+
|
|
934
|
+
// alias tile directly
|
|
935
|
+
data = rhs.data;
|
|
936
|
+
grad = rhs.grad;
|
|
937
|
+
initialized = rhs.initialized;
|
|
938
|
+
|
|
939
|
+
return *this;
|
|
940
|
+
}
|
|
941
|
+
*/
|
|
942
|
+
|
|
943
|
+
// assign from a global tile (load)
inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape>& t)
{
    copy_from_global(t);
    return *this;
}

// assign from a constant value
// each thread fills the elements i = thread, thread + WP_TILE_BLOCK_DIM, ...
inline CUDA_CALLABLE auto& operator=(const T& x)
{
    // sync if we are re-initializing data so that any threads that are still
    // reading from this tile can complete their work, e.g.: if re-assigning
    // to a tile during a dynamic loop
    if (initialized)
        WP_TILE_SYNC();

    for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
        data(i) = x;

    initialized = true;
    // make the new contents visible to all threads before returning
    WP_TILE_SYNC();
    return *this;
}

// in-place zero
// NOTE(review): unlike operator=(const T&) there is no leading WP_TILE_SYNC()
// when the tile is already initialized, and `initialized` is not updated;
// confirm callers never zero a tile other threads may still be reading
inline CUDA_CALLABLE void zero()
{
    for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
        data(i) = T(0);

    WP_TILE_SYNC();
}
|
|
975
|
+
|
|
976
|
+
// extract a single tile element to a native type
// (plain read from shared storage, no synchronization is performed here)
inline CUDA_CALLABLE Type extract(const typename Layout::Coord& c)
{
    return data(c);
}

// backward of scalar extraction
// adds adj_ret to the gradient of element c, then synchronizes
inline CUDA_CALLABLE void adj_extract(const typename Layout::Coord& c, Type adj_ret)
{
    // since multiple threads may extract the same element
    // we need to accumulate using atomic operations
    wp::atomic_add(&grad(c), adj_ret);

    WP_TILE_SYNC();
}
|
|
991
|
+
|
|
992
|
+
|
|
993
|
+
// copy register tile to shared
// each thread writes the elements held in its own registers
template <typename Tile>
inline CUDA_CALLABLE void assign(const Tile& tile)
{
    // wait for readers of the previous contents before overwriting them
    if (initialized)
        WP_TILE_SYNC();

    WP_PRAGMA_UNROLL
    for (int i=0; i < Tile::Layout::NumRegs; ++i)
    {
        const int linear = Tile::Layout::linear_from_register(i);

        // handle case where tile size is not
        // aligned to block dimensions
        if (!Tile::Layout::valid(linear))
            break;

        data(linear) = tile.data[i];
    }

    initialized = true;
    WP_TILE_SYNC();
}

// in-place gradient zero
inline CUDA_CALLABLE void grad_zero()
{
    for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
        grad(i) = T(0);

    WP_TILE_SYNC();
}
|
|
1025
|
+
|
|
1026
|
+
|
|
1027
|
+
// accumulate gradients onto this tile
// (from a register tile: each thread adds the elements held in its registers)
template <typename Tile>
inline CUDA_CALLABLE void grad_add(const Tile& tile)
{
    WP_PRAGMA_UNROLL
    for (int i=0; i < Tile::Layout::NumRegs; ++i)
    {
        const int linear = Tile::Layout::linear_from_register(i);

        // handle case where tile size is not
        // aligned to block dimensions
        if (!Tile::Layout::valid(linear))
            break;

        // if the destination layout is unique (no broadcast dimensions)
        // then we can use regular non-atomic accumulation
        if (Layout::Unique)
            grad(linear) += tile.data[i];
        else
            // use shared memory atomics to accumulate gradients
            // since for broadcast tiles (e.g.: a bias vector) multiple incoming threads
            // may map to a single location in shared memory
            wp::atomic_add(&grad(linear), tile.data[i]);

    }

    WP_TILE_SYNC();
}

// accumulate gradient onto this tile from a global array
// (out-of-bounds entries contribute whatever load_grad returns for them)
CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
{
    WP_PRAGMA_UNROLL
    for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
    {
        auto c = Layout::coord_from_linear(i);
        T g = global.load_grad(c);

        if (Layout::Unique)
        {
            // if the destination layout is unique (no broadcast dimensions)
            // then we can use regular non-atomic accumulation
            grad(c) += g;
        }
        else
        {
            // use shared memory atomics to accumulate gradients
            // since for broadcast tiles (e.g.: a bias vector) multiple incoming threads
            // may map to a single location in shared memory
            wp::atomic_add(&grad(c), g);
        }
    }

    WP_TILE_SYNC();
}
|
|
1082
|
+
|
|
1083
|
+
// copy this shared tile's gradient into a new register tile
// (no synchronization is performed)
inline CUDA_CALLABLE auto grad_to_register()
{
    using Tile = tile_register_t<T, tile_layout_register_t<typename Layout::Shape>>;
    Tile out;

    WP_PRAGMA_UNROLL
    for (int i=0; i < Tile::Layout::NumRegs; ++i)
    {
        const int linear = Tile::Layout::linear_from_register(i);

        // stop at the first padding slot
        if (!Tile::Layout::valid(linear))
            break;

        out(i) = grad(linear);
    }

    return out;
}

// copy shared tile to register
// (no synchronization is performed; padding slots keep the zero initial value)
inline CUDA_CALLABLE auto copy_to_register() const
{

    auto out = tile_register_like(this);

    // note: shadows the shared tile's Layout with the register layout of `out`
    using Layout = typename decltype(out)::Layout;

    WP_PRAGMA_UNROLL
    for (int i=0; i < Layout::NumRegs; ++i)
    {
        const int linear = Layout::linear_from_register(i);

        if (!Layout::valid(linear))
            break;

        out(i) = data(linear);
    }

    return out;
}
|
|
1124
|
+
|
|
1125
|
+
template <typename Global>
|
|
1126
|
+
inline CUDA_CALLABLE void copy_to_global(const Global& dest)
|
|
1127
|
+
{
|
|
1128
|
+
|
|
1129
|
+
#if defined(__CUDA_ARCH__)
|
|
1130
|
+
// vectorized loads for specific input/output shapes
|
|
1131
|
+
if constexpr (Layout::Shape::N == 2)
|
|
1132
|
+
{
|
|
1133
|
+
constexpr int lastdim = Layout::Shape::N-1;
|
|
1134
|
+
constexpr bool contiguous_src = Layout::Stride::dim(lastdim) == 1;
|
|
1135
|
+
const bool contiguous_dest = dest.data.strides[lastdim] == sizeof(T);
|
|
1136
|
+
const int elements = (dest.data.shape[lastdim] - dest.offset[lastdim]);
|
|
1137
|
+
const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
|
|
1138
|
+
|
|
1139
|
+
float4* dest128 = (float4*)&dest.data.data[dest.index_from_coord(tile_coord(0,0))];
|
|
1140
|
+
const bool aligned_dst = (uint64_t)(dest128)%sizeof(float4) == 0;
|
|
1141
|
+
|
|
1142
|
+
if (contiguous_dest && contiguous_src && aligned_size && aligned_dst)
|
|
1143
|
+
{
|
|
1144
|
+
constexpr int M = Layout::Shape::dim(0);
|
|
1145
|
+
constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
|
|
1146
|
+
|
|
1147
|
+
// alias of shared tile with 128bit type
|
|
1148
|
+
using SrcLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
|
|
1149
|
+
tile_shared_t<float4, SrcLayout> src128((float4*)data.ptr);
|
|
1150
|
+
|
|
1151
|
+
assert(((uint64_t)(data.ptr))%sizeof(float4) == 0);
|
|
1152
|
+
assert(((uint64_t)(dest128))%sizeof(float4) == 0);
|
|
1153
|
+
|
|
1154
|
+
const int stride_i = dest.data.strides[0]/sizeof(float4);
|
|
1155
|
+
const int stride_j = 1;
|
|
1156
|
+
|
|
1157
|
+
WP_PRAGMA_UNROLL
|
|
1158
|
+
for (int i=WP_TILE_THREAD_IDX; i < SrcLayout::Size; i += WP_TILE_BLOCK_DIM)
|
|
1159
|
+
{
|
|
1160
|
+
auto c = SrcLayout::coord_from_linear(i);
|
|
1161
|
+
|
|
1162
|
+
dest128[stride_i*c[0] + stride_j*c[1]] = src128.data(i);
|
|
1163
|
+
}
|
|
1164
|
+
|
|
1165
|
+
return;
|
|
1166
|
+
}
|
|
1167
|
+
}
|
|
1168
|
+
|
|
1169
|
+
#endif //defined(__CUDA_ARCH__)
|
|
1170
|
+
|
|
1171
|
+
// scalar bounds checked path
|
|
1172
|
+
WP_PRAGMA_UNROLL
|
|
1173
|
+
for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
|
|
1174
|
+
{
|
|
1175
|
+
auto c = Layout::coord_from_linear(i);
|
|
1176
|
+
dest.store(c, data(i));
|
|
1177
|
+
}
|
|
1178
|
+
}
|
|
1179
|
+
|
|
1180
|
+
// copy 16 bytes from global to shared memory
// on sm_80+ this issues an asynchronous cp.async; the data is only guaranteed to
// have arrived after cp_async_commit_and_wait_all_128(). Other builds fall back
// to a plain 128-bit load/store through a register
inline CUDA_CALLABLE void cp_async_global_to_shared_128(float4* shared_dest, const float4* global_src)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)

    unsigned long long saddr = 0ULL;
    unsigned long long gaddr = 0ULL;

    // convert generic pointers to shared / global state-space addresses
    asm volatile("cvta.to.shared.u64 %0, %1;" : "=l"(saddr) : "l"(shared_dest));
    asm volatile("cvta.to.global.u64 %0, %1;" : "=l"(gaddr) : "l"(global_src));

    // Use cp.async on newer architectures
    asm volatile(
        "cp.async.ca.shared.global [%0], [%1], 16;\n"
        :
        : "l"(saddr), "l"(gaddr)
    );
#else
    // use regular load/store through register on older arches
    *shared_dest = *global_src;
#endif
}

// commit all outstanding cp.async copies issued by the calling thread and wait
// for them to complete (no-op where cp.async is not used)
// NOTE: this only waits on the calling thread's copies; a block-level sync is
// still needed before other threads read the data
inline CUDA_CALLABLE void cp_async_commit_and_wait_all_128()
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
    asm volatile(
        "cp.async.commit_group;\n"
        "cp.async.wait_group 0;\n" ::);
#endif
}
|
|
1210
|
+
|
|
1211
|
+
template <typename Global>
|
|
1212
|
+
inline CUDA_CALLABLE void copy_from_global(const Global& src)
|
|
1213
|
+
{
|
|
1214
|
+
if (initialized)
|
|
1215
|
+
WP_TILE_SYNC();
|
|
1216
|
+
|
|
1217
|
+
#if defined(__CUDA_ARCH__)
|
|
1218
|
+
|
|
1219
|
+
// vectorized loads for specific input/output shapes
|
|
1220
|
+
if constexpr (Layout::Shape::N == 2)
|
|
1221
|
+
{
|
|
1222
|
+
constexpr int lastdim = Layout::Shape::N-1;
|
|
1223
|
+
constexpr bool contiguous_dest = Layout::Stride::dim(lastdim) == 1;
|
|
1224
|
+
const bool contiguous_src = src.data.strides[lastdim] == sizeof(T);
|
|
1225
|
+
const int elements = (src.data.shape[lastdim] - src.offset[lastdim]);
|
|
1226
|
+
const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
|
|
1227
|
+
|
|
1228
|
+
float4* src128 = (float4*)&src.data.data[src.index_from_coord(tile_coord(0,0))];
|
|
1229
|
+
const bool aligned_src = (uint64_t)(src128)%sizeof(float4) == 0;
|
|
1230
|
+
|
|
1231
|
+
if (contiguous_dest && contiguous_src && aligned_size && aligned_src)
|
|
1232
|
+
{
|
|
1233
|
+
constexpr int M = Layout::Shape::dim(0);
|
|
1234
|
+
constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
|
|
1235
|
+
|
|
1236
|
+
// alias of shared tile with 128bit type
|
|
1237
|
+
using DestLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
|
|
1238
|
+
tile_shared_t<float4, DestLayout> dest128((float4*)data.ptr);
|
|
1239
|
+
|
|
1240
|
+
assert(((uint64_t)(dest128.data.ptr))%sizeof(float4) == 0);
|
|
1241
|
+
assert(((uint64_t)(src128))%sizeof(float4) == 0);
|
|
1242
|
+
|
|
1243
|
+
const int stride_i = src.data.strides[0]/sizeof(float4);
|
|
1244
|
+
const int stride_j = 1;
|
|
1245
|
+
|
|
1246
|
+
WP_PRAGMA_UNROLL
|
|
1247
|
+
for (int i=WP_TILE_THREAD_IDX; i < DestLayout::Size; i += WP_TILE_BLOCK_DIM)
|
|
1248
|
+
{
|
|
1249
|
+
auto c = DestLayout::coord_from_linear(i);
|
|
1250
|
+
|
|
1251
|
+
#if WP_USE_ASYNC_PIPELINE
|
|
1252
|
+
cp_async_global_to_shared_128(&dest128.data(i), &src128[stride_i*c[0] + stride_j*c[1]]);
|
|
1253
|
+
#else
|
|
1254
|
+
dest128.data(i) = src128[stride_i*c[0] + stride_j*c[1]];
|
|
1255
|
+
#endif // WP_USE_ASYNC_PIPELINE
|
|
1256
|
+
}
|
|
1257
|
+
|
|
1258
|
+
#if WP_USE_ASYNC_PIPELINE
|
|
1259
|
+
cp_async_commit_and_wait_all_128();
|
|
1260
|
+
#endif // WP_USE_ASYNC_PIPELINE
|
|
1261
|
+
|
|
1262
|
+
initialized = true;
|
|
1263
|
+
WP_TILE_SYNC();
|
|
1264
|
+
return;
|
|
1265
|
+
}
|
|
1266
|
+
}
|
|
1267
|
+
|
|
1268
|
+
#endif //defined(__CUDA_ARCH__)
|
|
1269
|
+
|
|
1270
|
+
// scalar bounds checked path
|
|
1271
|
+
WP_PRAGMA_UNROLL
|
|
1272
|
+
for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
|
|
1273
|
+
{
|
|
1274
|
+
auto c = Layout::coord_from_linear(i);
|
|
1275
|
+
data(i) = src.load(c);
|
|
1276
|
+
}
|
|
1277
|
+
|
|
1278
|
+
initialized = true;
|
|
1279
|
+
WP_TILE_SYNC();
|
|
1280
|
+
}
|
|
1281
|
+
|
|
1282
|
+
template <typename Global>
|
|
1283
|
+
inline CUDA_CALLABLE auto atomic_add(Global& dest)
|
|
1284
|
+
{
|
|
1285
|
+
copy_to_register().atomic_add(dest);
|
|
1286
|
+
}
|
|
1287
|
+
|
|
1288
|
+
template <typename Global>
|
|
1289
|
+
inline CUDA_CALLABLE auto atomic_add_grad(Global& dest)
|
|
1290
|
+
{
|
|
1291
|
+
grad_to_register().atomic_add_grad(dest);
|
|
1292
|
+
}
|
|
1293
|
+
|
|
1294
|
+
// overload for integral types
|
|
1295
|
+
inline CUDA_CALLABLE void print_value(int x) const
|
|
1296
|
+
{
|
|
1297
|
+
printf("%d", x);
|
|
1298
|
+
}
|
|
1299
|
+
|
|
1300
|
+
// overload for floating point types
|
|
1301
|
+
template <typename ValueType>
|
|
1302
|
+
inline CUDA_CALLABLE void print_value(ValueType x) const
|
|
1303
|
+
{
|
|
1304
|
+
printf("%g", x);
|
|
1305
|
+
}
|
|
1306
|
+
|
|
1307
|
+
template <int Level = 0>
|
|
1308
|
+
inline CUDA_CALLABLE void print_values(const Storage& storage, int index=0) const
|
|
1309
|
+
{
|
|
1310
|
+
using Shape = typename Layout::Shape;
|
|
1311
|
+
|
|
1312
|
+
if constexpr (Level < Shape::N)
|
|
1313
|
+
{
|
|
1314
|
+
if constexpr (Level == Shape::N - 1)
|
|
1315
|
+
{
|
|
1316
|
+
// Special handling for 1D case
|
|
1317
|
+
printf("[");
|
|
1318
|
+
for (int i = 0; i < Shape::dim(Level); ++i)
|
|
1319
|
+
{
|
|
1320
|
+
print_value(storage(index + i));
|
|
1321
|
+
|
|
1322
|
+
if (i < Shape::dim(Level) - 1)
|
|
1323
|
+
{
|
|
1324
|
+
printf(" ");
|
|
1325
|
+
}
|
|
1326
|
+
}
|
|
1327
|
+
printf("]");
|
|
1328
|
+
}
|
|
1329
|
+
else if constexpr (Level == Shape::N - 2)
|
|
1330
|
+
{
|
|
1331
|
+
// Special handling for 2D case
|
|
1332
|
+
printf("[");
|
|
1333
|
+
for (int i = 0; i < Shape::dim(Level); ++i)
|
|
1334
|
+
{
|
|
1335
|
+
printf("[");
|
|
1336
|
+
for (int j=0; j < Shape::dim(Level+1); ++j)
|
|
1337
|
+
{
|
|
1338
|
+
print_value(storage(index));
|
|
1339
|
+
|
|
1340
|
+
if (j < Shape::dim(Level+1) - 1)
|
|
1341
|
+
{
|
|
1342
|
+
printf(" ");
|
|
1343
|
+
}
|
|
1344
|
+
|
|
1345
|
+
++index;
|
|
1346
|
+
}
|
|
1347
|
+
|
|
1348
|
+
printf("]");
|
|
1349
|
+
|
|
1350
|
+
// next row
|
|
1351
|
+
if (i < Shape::dim(Level)-1)
|
|
1352
|
+
{
|
|
1353
|
+
printf("\n");
|
|
1354
|
+
|
|
1355
|
+
// indent next row
|
|
1356
|
+
for (int i=0; i <= Shape::N-2; ++i)
|
|
1357
|
+
printf(" ");
|
|
1358
|
+
|
|
1359
|
+
}
|
|
1360
|
+
}
|
|
1361
|
+
printf("]");
|
|
1362
|
+
}
|
|
1363
|
+
else
|
|
1364
|
+
{
|
|
1365
|
+
printf("[");
|
|
1366
|
+
for (int i = 0; i < Shape::dim(Level); ++i)
|
|
1367
|
+
{
|
|
1368
|
+
print_values<Level + 1>(storage, index + i * Shape::dim(Level));
|
|
1369
|
+
if (i < Shape::dim(Level) - 1)
|
|
1370
|
+
{
|
|
1371
|
+
printf("\n\n");
|
|
1372
|
+
|
|
1373
|
+
// indent next row
|
|
1374
|
+
for (int i=0; i <= Level; ++i)
|
|
1375
|
+
printf(" ");
|
|
1376
|
+
}
|
|
1377
|
+
}
|
|
1378
|
+
printf("]");
|
|
1379
|
+
}
|
|
1380
|
+
}
|
|
1381
|
+
}
|
|
1382
|
+
|
|
1383
|
+
// Print this shared tile's values (or its gradient when `reverse` is true),
// followed by a " = tile(shape=(...), storage=shared)" summary line.
// Only the thread with WP_TILE_THREAD_IDX == 0 prints; the others do nothing.
inline CUDA_CALLABLE void print(bool reverse=false) const
{
    if (WP_TILE_THREAD_IDX == 0)
    {
        // choose which buffer to dump
        if (!reverse)
            print_values(data);
        else
            print_values(grad);

        printf(" = tile(shape=(");
        for (int d = 0; d < Layout::Shape::N; ++d)
        {
            // comma goes before every dimension except the first
            if (d > 0)
                printf(",");
            printf("%d", Layout::Shape::dim(d));
        }

        printf("), storage=shared)\n");
    }
}
|
|
1403
|
+
};
|
|
1404
|
+
|
|
1405
|
+
|
|
1406
|
+
template <typename T, typename L>
|
|
1407
|
+
void tile_register_t<T, L>::print() const
|
|
1408
|
+
{
|
|
1409
|
+
// create a temporary shared tile so that
|
|
1410
|
+
// we can print it deterministically
|
|
1411
|
+
WP_TILE_SHARED T smem[L::Size];
|
|
1412
|
+
tile_shared_t<T, tile_layout_strided_t<typename L::Shape>, false> scratch(smem, nullptr);
|
|
1413
|
+
|
|
1414
|
+
scratch.assign(*this);
|
|
1415
|
+
|
|
1416
|
+
WP_TILE_SYNC();
|
|
1417
|
+
|
|
1418
|
+
if (WP_TILE_THREAD_IDX == 0)
|
|
1419
|
+
{
|
|
1420
|
+
scratch.print_values(scratch.data, 0);
|
|
1421
|
+
|
|
1422
|
+
printf(" = tile(shape=(");
|
|
1423
|
+
for (int i=0; i < L::Shape::N; ++i)
|
|
1424
|
+
{
|
|
1425
|
+
printf("%d", L::Shape::dim(i));
|
|
1426
|
+
if (i != L::Shape::N-1)
|
|
1427
|
+
printf(",");
|
|
1428
|
+
}
|
|
1429
|
+
|
|
1430
|
+
printf("), storage=register)\n");
|
|
1431
|
+
}
|
|
1432
|
+
|
|
1433
|
+
WP_TILE_SYNC();
|
|
1434
|
+
}
|
|
1435
|
+
|
|
1436
|
+
// print entry points
|
|
1437
|
+
template <typename T, typename L>
|
|
1438
|
+
inline CUDA_CALLABLE void print(const tile_register_t<T, L>& t) { t.print(); }
|
|
1439
|
+
template <typename T, typename L, bool Owner>
|
|
1440
|
+
inline CUDA_CALLABLE void print(const tile_shared_t<T, L, Owner>& t) { t.print(); }
|
|
1441
|
+
|
|
1442
|
+
template <typename T, typename L, bool O>
|
|
1443
|
+
inline CUDA_CALLABLE int len(const tile_shared_t<T, L, O>& t)
|
|
1444
|
+
{
|
|
1445
|
+
return L::Shape::dim(0);
|
|
1446
|
+
}
|
|
1447
|
+
|
|
1448
|
+
template <typename T, typename L, bool O, typename AdjTile>
|
|
1449
|
+
inline CUDA_CALLABLE void adj_len(const tile_shared_t<T,L,O>& t, const AdjTile& a, int& adj_ret)
|
|
1450
|
+
{
|
|
1451
|
+
}
|
|
1452
|
+
|
|
1453
|
+
template <typename T, typename L>
|
|
1454
|
+
inline CUDA_CALLABLE int len(const tile_register_t<T, L>& t)
|
|
1455
|
+
{
|
|
1456
|
+
return L::Shape::dim(0);
|
|
1457
|
+
}
|
|
1458
|
+
|
|
1459
|
+
template <typename T, typename L, typename AdjTile>
|
|
1460
|
+
inline CUDA_CALLABLE void adj_len(const tile_register_t<T,L>& t, const AdjTile& a, int& adj_ret)
|
|
1461
|
+
{
|
|
1462
|
+
}
|
|
1463
|
+
|
|
1464
|
+
|
|
1465
|
+
template <typename T, typename L>
|
|
1466
|
+
inline CUDA_CALLABLE void adj_print(const tile_register_t<T, L>& t, const tile_register_t<T, L>& a) { a.print(); }
|
|
1467
|
+
template <typename T, typename L, bool Owner>
|
|
1468
|
+
inline CUDA_CALLABLE void adj_print(const tile_shared_t<T, L, Owner>& t, const tile_shared_t<T, L, Owner>& a) { a.print(true); }
|
|
1469
|
+
|
|
1470
|
+
|
|
1471
|
+
|
|
1472
|
+
// helpers to allocate shared tiles
|
|
1473
|
+
template <typename T, typename Shape, bool RequiresGrad>
|
|
1474
|
+
inline CUDA_CALLABLE auto tile_alloc_empty()
|
|
1475
|
+
|
|
1476
|
+
{ constexpr int size = Shape::size();
|
|
1477
|
+
T* data = (T*)tile_alloc_shared(size*sizeof(T));
|
|
1478
|
+
T* grad = nullptr;
|
|
1479
|
+
|
|
1480
|
+
#if FP_CHECK
|
|
1481
|
+
|
|
1482
|
+
// initialize tile to quiet nan
|
|
1483
|
+
uint32_t qnanbits = 0x7FC00000;
|
|
1484
|
+
float qnan = *(float*)(&qnanbits);
|
|
1485
|
+
|
|
1486
|
+
for (int i=WP_TILE_THREAD_IDX; i < size; i+= WP_TILE_BLOCK_DIM)
|
|
1487
|
+
data[i] = T(qnan);
|
|
1488
|
+
|
|
1489
|
+
WP_TILE_SYNC();
|
|
1490
|
+
|
|
1491
|
+
#endif // FP_CHECK
|
|
1492
|
+
|
|
1493
|
+
|
|
1494
|
+
if (RequiresGrad)
|
|
1495
|
+
{
|
|
1496
|
+
grad = (T*)tile_alloc_shared(size*sizeof(T));
|
|
1497
|
+
|
|
1498
|
+
for (int i=WP_TILE_THREAD_IDX; i < size; i+= WP_TILE_BLOCK_DIM)
|
|
1499
|
+
grad[i] = T(0);
|
|
1500
|
+
|
|
1501
|
+
WP_TILE_SYNC();
|
|
1502
|
+
}
|
|
1503
|
+
|
|
1504
|
+
return tile_shared_t<T, tile_layout_strided_t<Shape>>(data, grad);
|
|
1505
|
+
}
|
|
1506
|
+
|
|
1507
|
+
|
|
1508
|
+
//-----------------------------------------------------------------------------------------------------
|
|
1509
|
+
// High level entry points for each op (correspond to one Warp builtin)
|
|
1510
|
+
|
|
1511
|
+
// construct a tile from a local SIMT value (one per-thread)
|
|
1512
|
+
template <typename T>
|
|
1513
|
+
inline CUDA_CALLABLE auto tile(const T& x)
|
|
1514
|
+
{
|
|
1515
|
+
tile_register_t<T, tile_layout_register_t<tile_shape_t<WP_TILE_BLOCK_DIM>>> result;
|
|
1516
|
+
|
|
1517
|
+
using Layout = typename decltype(result)::Layout;
|
|
1518
|
+
static_assert(Layout::NumRegs == 1, "Expected Layout::NumRegs == 1");
|
|
1519
|
+
|
|
1520
|
+
result.data[0] = x;
|
|
1521
|
+
return result;
|
|
1522
|
+
}
|
|
1523
|
+
|
|
1524
|
+
// overload for constructing a tile from a per-thread vector
|
|
1525
|
+
template <typename T, unsigned Length>
|
|
1526
|
+
inline CUDA_CALLABLE auto tile(const wp::vec_t<Length, T>& x)
|
|
1527
|
+
{
|
|
1528
|
+
tile_register_t<T, tile_layout_register_t<tile_shape_t<Length, WP_TILE_BLOCK_DIM>>> result;
|
|
1529
|
+
|
|
1530
|
+
using Layout = typename decltype(result)::Layout;
|
|
1531
|
+
static_assert(Layout::NumRegs == Length, "Expected Layout::NumRegs == Length");
|
|
1532
|
+
|
|
1533
|
+
for (int i=0; i < Length; ++i)
|
|
1534
|
+
result.data[i] = x[i];
|
|
1535
|
+
|
|
1536
|
+
return result;
|
|
1537
|
+
}
|
|
1538
|
+
|
|
1539
|
+
// construct a tile from a local SIMT value (one per-thread)
|
|
1540
|
+
template <typename T, typename AdjTile>
|
|
1541
|
+
inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
|
|
1542
|
+
{
|
|
1543
|
+
static_assert(AdjTile::Layout::Shape::N == 1, "Expected AdjTile::Layout::Shape::N == 1");
|
|
1544
|
+
static_assert(AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM, "Expected AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM");
|
|
1545
|
+
|
|
1546
|
+
auto adj_reg = adj_ret.copy_to_register();
|
|
1547
|
+
|
|
1548
|
+
adj_x += adj_reg.data[0];
|
|
1549
|
+
}
|
|
1550
|
+
|
|
1551
|
+
template <typename T, unsigned Length, typename AdjTile>
|
|
1552
|
+
inline CUDA_CALLABLE void adj_tile(const wp::vec_t<Length, T>& x, wp::vec_t<Length, T>& adj_x, AdjTile& adj_ret)
|
|
1553
|
+
{
|
|
1554
|
+
static_assert(AdjTile::Layout::Shape::N == 2, "Expected AdjTile::Layout::Shape::N == 2");
|
|
1555
|
+
static_assert(AdjTile::Layout::Shape::dim(0) == Length, "Expected AdjTile::Layout::Shape::dim(0) == Length");
|
|
1556
|
+
static_assert(AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM, "Expected AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM");
|
|
1557
|
+
|
|
1558
|
+
auto adj_reg = adj_ret.copy_to_register();
|
|
1559
|
+
|
|
1560
|
+
for (int i=0; i < Length; ++i)
|
|
1561
|
+
adj_x[i] += adj_reg.data[i];
|
|
1562
|
+
}
|
|
1563
|
+
|
|
1564
|
+
template <typename Tile>
|
|
1565
|
+
inline CUDA_CALLABLE auto untile(Tile& tile)
|
|
1566
|
+
{
|
|
1567
|
+
// code-gen should have set the tile to
|
|
1568
|
+
// have exactly the block dimension so
|
|
1569
|
+
// there is exactly one value per-thread
|
|
1570
|
+
auto reg = tile.copy_to_register();
|
|
1571
|
+
|
|
1572
|
+
constexpr int N = Tile::Layout::Shape::N;
|
|
1573
|
+
|
|
1574
|
+
// scalar case
|
|
1575
|
+
if constexpr(N == 1)
|
|
1576
|
+
{
|
|
1577
|
+
return reg.data[0];
|
|
1578
|
+
}
|
|
1579
|
+
|
|
1580
|
+
// vector case
|
|
1581
|
+
if constexpr(N == 2)
|
|
1582
|
+
{
|
|
1583
|
+
constexpr int Length = Tile::Layout::Shape::dim(0);
|
|
1584
|
+
wp::vec_t<Length, typename Tile::Type> v;
|
|
1585
|
+
for (int i=0; i < Length; ++i)
|
|
1586
|
+
v[i] = reg.data[i];
|
|
1587
|
+
|
|
1588
|
+
return v;
|
|
1589
|
+
}
|
|
1590
|
+
}
|
|
1591
|
+
|
|
1592
|
+
template <typename Tile, typename Value>
|
|
1593
|
+
inline CUDA_CALLABLE void adj_untile(Tile& tile, Tile& adj_tile, Value& adj_ret)
|
|
1594
|
+
{
|
|
1595
|
+
auto adj = adj_tile.copy_to_register();
|
|
1596
|
+
|
|
1597
|
+
constexpr int N = Tile::Layout::Shape::N;
|
|
1598
|
+
|
|
1599
|
+
// scalar case
|
|
1600
|
+
if constexpr(N == 1)
|
|
1601
|
+
{
|
|
1602
|
+
adj.data[0] += adj_ret;
|
|
1603
|
+
}
|
|
1604
|
+
|
|
1605
|
+
// vector case
|
|
1606
|
+
if constexpr(N == 2)
|
|
1607
|
+
{
|
|
1608
|
+
constexpr int Length = Tile::Layout::Shape::dim(0);
|
|
1609
|
+
for (int i=0; i < Length; ++i)
|
|
1610
|
+
adj.data[i] += adj_ret[i];
|
|
1611
|
+
}
|
|
1612
|
+
|
|
1613
|
+
adj_tile.assign(adj);
|
|
1614
|
+
}
|
|
1615
|
+
|
|
1616
|
+
// zero initialized tile
|
|
1617
|
+
template <typename T, unsigned... Shape>
|
|
1618
|
+
inline CUDA_CALLABLE auto tile_zeros()
|
|
1619
|
+
{
|
|
1620
|
+
// tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
|
|
1621
|
+
return T(0);
|
|
1622
|
+
}
|
|
1623
|
+
|
|
1624
|
+
// one-initialized tile
|
|
1625
|
+
template <typename T, unsigned... Shape>
|
|
1626
|
+
inline CUDA_CALLABLE auto tile_ones()
|
|
1627
|
+
{
|
|
1628
|
+
// tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
|
|
1629
|
+
return T(1);
|
|
1630
|
+
}
|
|
1631
|
+
|
|
1632
|
+
// tile with evenly spaced values
|
|
1633
|
+
template <typename T, int Len>
|
|
1634
|
+
inline CUDA_CALLABLE auto tile_arange(T start, T stop, T step)
|
|
1635
|
+
{
|
|
1636
|
+
auto out = tile_register<T, Len>();
|
|
1637
|
+
|
|
1638
|
+
using Layout = typename decltype(out)::Layout;
|
|
1639
|
+
|
|
1640
|
+
WP_PRAGMA_UNROLL
|
|
1641
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
1642
|
+
{
|
|
1643
|
+
const int linear = Layout::linear_from_register(i);
|
|
1644
|
+
|
|
1645
|
+
// handle case where tile size is not
|
|
1646
|
+
// aligned to block dimensions
|
|
1647
|
+
if (!Layout::valid(linear))
|
|
1648
|
+
break;
|
|
1649
|
+
|
|
1650
|
+
out.data[i] = start + linear*step;
|
|
1651
|
+
}
|
|
1652
|
+
|
|
1653
|
+
return out;
|
|
1654
|
+
}
|
|
1655
|
+
|
|
1656
|
+
template <typename T, typename AdjTile>
|
|
1657
|
+
inline CUDA_CALLABLE void adj_tile_arange(T start, T stop, T step,
|
|
1658
|
+
T& adj_start, T& adj_stop, T& adj_step, AdjTile& adj_ret) {}
|
|
1659
|
+
|
|
1660
|
+
// entry point for load operations, these just return a reference to a global memory array + coordinate
|
|
1661
|
+
template <unsigned... Shape, typename... Indices, typename T>
|
|
1662
|
+
inline CUDA_CALLABLE auto tile_load(array_t<T>& src, Indices... offset)
|
|
1663
|
+
{
|
|
1664
|
+
return tile_global_t<T, tile_shape_t<Shape...>>(src, tile_coord(offset...));
|
|
1665
|
+
}
|
|
1666
|
+
|
|
1667
|
+
// // entry point for tile store operations
|
|
1668
|
+
// template <typename... Indices, typename T, typename Tile>
|
|
1669
|
+
// inline CUDA_CALLABLE void tile_store(array_t<T>& dest, Tile& src, Indices... x)
|
|
1670
|
+
// {
|
|
1671
|
+
// src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x)));
|
|
1672
|
+
// }
|
|
1673
|
+
|
|
1674
|
+
// entry point for tile store operations
|
|
1675
|
+
template <typename T, typename Tile>
|
|
1676
|
+
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x))); }
|
|
1677
|
+
template <typename T, typename Tile>
|
|
1678
|
+
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y))); }
|
|
1679
|
+
template <typename T, typename Tile>
|
|
1680
|
+
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z))); }
|
|
1681
|
+
template <typename T, typename Tile>
|
|
1682
|
+
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z, w))); }
|
|
1683
|
+
|
|
1684
|
+
|
|
1685
|
+
|
|
1686
|
+
template <typename T, typename Tile>
|
|
1687
|
+
inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, Tile& src) { return src.atomic_add(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x))); }
|
|
1688
|
+
template <typename T, typename Tile>
|
|
1689
|
+
inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, Tile& src) { return src.atomic_add(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y)));}
|
|
1690
|
+
template <typename T, typename Tile>
|
|
1691
|
+
inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, int z, Tile& src) { return src.atomic_add(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z)));}
|
|
1692
|
+
template <typename T, typename Tile>
|
|
1693
|
+
inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, int z, int w, Tile& src) { return src.atomic_add(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z, w)));}
|
|
1694
|
+
|
|
1695
|
+
|
|
1696
|
+
//-------------------------------------
|
|
1697
|
+
// Adjoints
|
|
1698
|
+
|
|
1699
|
+
template <typename T, typename AdjTile, typename Coord>
|
|
1700
|
+
inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, Coord c,
|
|
1701
|
+
array_t<T>& adj_src, Coord adj_c,
|
|
1702
|
+
AdjTile& adj_ret)
|
|
1703
|
+
{
|
|
1704
|
+
tile_global_t<T, typename AdjTile::Layout::Shape> dest(src, c);
|
|
1705
|
+
|
|
1706
|
+
// we allow users to override grad of src
|
|
1707
|
+
if (adj_src.data)
|
|
1708
|
+
dest.data.grad = adj_src.data;
|
|
1709
|
+
|
|
1710
|
+
adj_ret.atomic_add_grad(dest);
|
|
1711
|
+
}
|
|
1712
|
+
|
|
1713
|
+
|
|
1714
|
+
template <typename T, typename AdjTile>
|
|
1715
|
+
inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, array_t<T>& adj_src, int adj_x, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x), adj_src, tile_coord(0), adj_ret); }
|
|
1716
|
+
template <typename T, typename AdjTile>
|
|
1717
|
+
inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, array_t<T>& adj_src, int adj_x, int adj_y, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x, y), adj_src, tile_coord(0,0), adj_ret); }
|
|
1718
|
+
template <typename T, typename AdjTile>
|
|
1719
|
+
inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, int z, array_t<T>& adj_src, int adj_x, int adj_y, int adj_z, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x, y, z), adj_src, tile_coord(0,0,0), adj_ret); }
|
|
1720
|
+
template <typename T, typename AdjTile>
|
|
1721
|
+
inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, int z, int w, array_t<T>& adj_src, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x, y, z, w), adj_src, tile_coord(0,0,0,0), adj_ret); }
|
|
1722
|
+
|
|
1723
|
+
|
|
1724
|
+
|
|
1725
|
+
template <typename T, typename Tile, typename AdjTile, typename Coord>
|
|
1726
|
+
inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, Coord c, Tile& t, array_t<T>& adj_dest, Coord adj_c, AdjTile& adj_t)
|
|
1727
|
+
{
|
|
1728
|
+
tile_global_t<T, typename AdjTile::Layout::Shape> src(dest, c);
|
|
1729
|
+
|
|
1730
|
+
// we allow users to override grad of src
|
|
1731
|
+
if (adj_dest.data)
|
|
1732
|
+
src.data.grad = adj_dest.data;
|
|
1733
|
+
|
|
1734
|
+
if (src.data.grad == nullptr)
|
|
1735
|
+
return;
|
|
1736
|
+
|
|
1737
|
+
adj_t.grad_add(src);
|
|
1738
|
+
}
|
|
1739
|
+
|
|
1740
|
+
template <typename T, typename Tile, typename AdjTile>
|
|
1741
|
+
inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, Tile& t, array_t<T>& adj_dest, int adj_x, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x), t, adj_dest, tile_coord(0), adj_t); }
|
|
1742
|
+
template <typename T, typename Tile, typename AdjTile>
|
|
1743
|
+
inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x, y), t, adj_dest, tile_coord(0,0), adj_t); }
|
|
1744
|
+
template <typename T, typename Tile, typename AdjTile>
|
|
1745
|
+
inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, int z, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x, y, z), t, adj_dest, tile_coord(0,0,0), adj_t); }
|
|
1746
|
+
template <typename T, typename Tile, typename AdjTile>
|
|
1747
|
+
inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x, y, z, w), t, adj_dest, tile_coord(0,0,0,0), adj_t); }
|
|
1748
|
+
|
|
1749
|
+
|
|
1750
|
+
|
|
1751
|
+
// adj_tile_atomic_add is an alias for adj_tile_store
|
|
1752
|
+
template <typename T, typename Tile, typename AdjTile, typename AdjRet>
|
|
1753
|
+
inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, Tile& t, array_t<T>& adj_dest, int adj_x, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x), t, adj_dest, tile_coord(adj_x), adj_t); }
|
|
1754
|
+
template <typename T, typename Tile, typename AdjTile, typename AdjRet>
|
|
1755
|
+
inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x, y), t, adj_dest, tile_coord(adj_x, adj_y), adj_t); }
|
|
1756
|
+
template <typename T, typename Tile, typename AdjTile, typename AdjRet>
|
|
1757
|
+
inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, int z, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x, y, z), t, adj_dest, tile_coord(adj_x, adj_y, adj_z), adj_t); }
|
|
1758
|
+
template <typename T, typename Tile, typename AdjTile, typename AdjRet>
|
|
1759
|
+
inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x, y, z, w), t, adj_dest, tile_coord(adj_x, adj_y, adj_z, adj_w), adj_t); }
|
|
1760
|
+
|
|
1761
|
+
|
|
1762
|
+
// unary map
|
|
1763
|
+
template <typename Tile, typename Fwd>
|
|
1764
|
+
inline CUDA_CALLABLE auto tile_map(Fwd op,
|
|
1765
|
+
Tile &a)
|
|
1766
|
+
{
|
|
1767
|
+
auto out = tile_register_like<Tile>();
|
|
1768
|
+
auto a_reg = a.copy_to_register();
|
|
1769
|
+
|
|
1770
|
+
using Layout = typename decltype(out)::Layout;
|
|
1771
|
+
|
|
1772
|
+
WP_PRAGMA_UNROLL
|
|
1773
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
1774
|
+
{
|
|
1775
|
+
out.data[i] = op(a_reg.data[i]);
|
|
1776
|
+
}
|
|
1777
|
+
|
|
1778
|
+
return out;
|
|
1779
|
+
}
|
|
1780
|
+
|
|
1781
|
+
|
|
1782
|
+
template <typename Tile, typename AdjTile, typename Fwd, typename Adj>
|
|
1783
|
+
inline CUDA_CALLABLE void adj_tile_map(Fwd op,
|
|
1784
|
+
Tile& a,
|
|
1785
|
+
Adj adj_op,
|
|
1786
|
+
Tile& adj_a,
|
|
1787
|
+
AdjTile& adj_ret)
|
|
1788
|
+
{
|
|
1789
|
+
auto a_reg = a.copy_to_register();
|
|
1790
|
+
auto adj_a_reg = tile_register_like<Tile>();
|
|
1791
|
+
auto adj_ret_reg = adj_ret.grad_to_register();
|
|
1792
|
+
|
|
1793
|
+
using Layout = typename decltype(a_reg)::Layout;
|
|
1794
|
+
|
|
1795
|
+
WP_PRAGMA_UNROLL
|
|
1796
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
1797
|
+
{
|
|
1798
|
+
adj_op(a_reg.data[i], adj_a_reg.data[i], adj_ret_reg.data[i]);
|
|
1799
|
+
}
|
|
1800
|
+
|
|
1801
|
+
// write adjoints back
|
|
1802
|
+
adj_a.grad_add(adj_a_reg);
|
|
1803
|
+
}
|
|
1804
|
+
|
|
1805
|
+
// binary map
|
|
1806
|
+
template <typename TileA, typename TileB, typename Fwd>
|
|
1807
|
+
inline CUDA_CALLABLE auto tile_map(Fwd op,
|
|
1808
|
+
TileA& a,
|
|
1809
|
+
TileB& b)
|
|
1810
|
+
{
|
|
1811
|
+
auto out = tile_register_like<TileA>();
|
|
1812
|
+
|
|
1813
|
+
auto a_reg = a.copy_to_register();
|
|
1814
|
+
auto b_reg = b.copy_to_register();
|
|
1815
|
+
|
|
1816
|
+
using Layout = typename decltype(out)::Layout;
|
|
1817
|
+
|
|
1818
|
+
WP_PRAGMA_UNROLL
|
|
1819
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
1820
|
+
{
|
|
1821
|
+
out.data[i] = op(a_reg.data[i], b_reg.data[i]);
|
|
1822
|
+
}
|
|
1823
|
+
|
|
1824
|
+
return out;
|
|
1825
|
+
}
|
|
1826
|
+
|
|
1827
|
+
|
|
1828
|
+
template <typename TileA, typename TileB, typename Fwd, typename Adj, typename AdjTile>
|
|
1829
|
+
inline CUDA_CALLABLE void adj_tile_map(Fwd op,
|
|
1830
|
+
TileA &a,
|
|
1831
|
+
TileB &b,
|
|
1832
|
+
Adj adj_op,
|
|
1833
|
+
TileA &adj_a,
|
|
1834
|
+
TileB &adj_b,
|
|
1835
|
+
AdjTile &adj_ret)
|
|
1836
|
+
{
|
|
1837
|
+
auto a_reg = a.copy_to_register();
|
|
1838
|
+
auto b_reg = b.copy_to_register();
|
|
1839
|
+
|
|
1840
|
+
// allocate storage for adjoints
|
|
1841
|
+
auto adj_a_reg = tile_register_like<TileA>();
|
|
1842
|
+
auto adj_b_reg = tile_register_like<TileB>();
|
|
1843
|
+
|
|
1844
|
+
auto adj_ret_reg = adj_ret.grad_to_register();
|
|
1845
|
+
|
|
1846
|
+
using Layout = typename decltype(a_reg)::Layout;
|
|
1847
|
+
|
|
1848
|
+
WP_PRAGMA_UNROLL
|
|
1849
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
1850
|
+
{
|
|
1851
|
+
adj_op(a_reg.data[i], b_reg.data[i], adj_a_reg.data[i], adj_b_reg.data[i], adj_ret_reg.data[i]);
|
|
1852
|
+
}
|
|
1853
|
+
|
|
1854
|
+
adj_a.grad_add(adj_a_reg);
|
|
1855
|
+
adj_b.grad_add(adj_b_reg);
|
|
1856
|
+
}
|
|
1857
|
+
|
|
1858
|
+
// wrap the operator in a lambda so that we don't have to do overload resolution for things like e.g.: wp.sin()
// this is important because many of the builtin operators don't follow particular conventions on references for
// the `adj_ret` parameter, which means it's not possible to figure out the overload we need using simple casting
//
// NOTE: the lambda parameter names below must not coincide with any macro
// parameter name. Otherwise the preprocessor substitutes the caller's argument
// text into the lambda's parameter list. The adjoint lambdas previously named
// a parameter `adj_ret`, which is also a macro parameter; calling with e.g.
// `adj_c` silently renamed it, and any non-identifier argument would fail to
// compile. They now use `adj_r`.
#define tile_unary_map(op, a) tile_map([](auto x) { return op(x);}, a)
#define adj_tile_unary_map(op, a, adj_op, adj_a, adj_ret) adj_tile_map([](auto x) { return op(x);}, a, [](auto x, auto& adj_x, auto adj_r) { adj_op(x, adj_x, adj_r);}, adj_a, adj_ret)

#define tile_binary_map(op, a, b) tile_map([](auto x, auto y) { return op(x, y);}, a, b)
#define adj_tile_binary_map(op, a, b, adj_op, adj_a, adj_b, adj_ret) adj_tile_map([](auto x, auto y) { return op(x, y);}, a, b, [](auto x, auto y, auto& adj_x, auto& adj_y, auto adj_r) { adj_op(x, y, adj_x, adj_y, adj_r);}, adj_a, adj_b, adj_ret)
|
|
1866
|
+
|
|
1867
|
+
// -tile (unary neg)
|
|
1868
|
+
template <typename Tile>
|
|
1869
|
+
inline CUDA_CALLABLE auto tile_neg(Tile& a) { return tile_unary_map(wp::neg, a); }
|
|
1870
|
+
|
|
1871
|
+
template <typename Tile, typename AdjTile>
|
|
1872
|
+
inline CUDA_CALLABLE void adj_tile_neg(Tile& a, Tile& adj_a, AdjTile& adj_ret) { adj_tile_unary_map(wp::neg, a, wp::adj_neg, adj_a, adj_ret); }
|
|
1873
|
+
|
|
1874
|
+
|
|
1875
|
+
// tile + tile
|
|
1876
|
+
template <typename TileA, typename TileB>
|
|
1877
|
+
inline CUDA_CALLABLE auto tile_add(TileA& a, TileB& b)
|
|
1878
|
+
{
|
|
1879
|
+
return tile_binary_map(add, a, b);
|
|
1880
|
+
}
|
|
1881
|
+
|
|
1882
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
|
|
1883
|
+
inline CUDA_CALLABLE void adj_tile_add(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
|
|
1884
|
+
{
|
|
1885
|
+
adj_tile_binary_map(add, a, b, adj_add, adj_a, adj_b, adj_c);
|
|
1886
|
+
}
|
|
1887
|
+
|
|
1888
|
+
// tile - tile
|
|
1889
|
+
template <typename TileA, typename TileB>
|
|
1890
|
+
inline CUDA_CALLABLE auto tile_sub(TileA& a, TileB& b)
|
|
1891
|
+
{
|
|
1892
|
+
return tile_binary_map(sub, a, b);
|
|
1893
|
+
}
|
|
1894
|
+
|
|
1895
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
|
|
1896
|
+
inline CUDA_CALLABLE void adj_tile_sub(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
|
|
1897
|
+
{
|
|
1898
|
+
adj_tile_binary_map(sub, a, b, adj_sub, adj_a, adj_b, adj_c);
|
|
1899
|
+
}
|
|
1900
|
+
|
|
1901
|
+
|
|
1902
|
+
// tile*scalar
|
|
1903
|
+
template <typename Tile>
|
|
1904
|
+
inline CUDA_CALLABLE auto tile_mul(Tile& a, const typename Tile::Type& s)
|
|
1905
|
+
{
|
|
1906
|
+
// promote scalar to a constant tile
|
|
1907
|
+
auto s_tile = tile_register_t<typename Tile::Type, tile_layout_register_t<typename Tile::Layout::Shape>>(s);
|
|
1908
|
+
|
|
1909
|
+
return tile_binary_map(mul, a, s_tile);
|
|
1910
|
+
}
|
|
1911
|
+
|
|
1912
|
+
template <typename Tile, typename AdjTile>
|
|
1913
|
+
inline CUDA_CALLABLE void adj_tile_mul(Tile& a, const typename Tile::Type& s,
|
|
1914
|
+
Tile& adj_a, typename Tile::Type& adj_s,
|
|
1915
|
+
AdjTile& adj_c)
|
|
1916
|
+
{
|
|
1917
|
+
auto s_tile = tile_register_like<Tile>();
|
|
1918
|
+
auto adj_s_tile = tile_register_like<Tile>();
|
|
1919
|
+
|
|
1920
|
+
using Layout = typename decltype(adj_s_tile)::Layout;
|
|
1921
|
+
|
|
1922
|
+
// initialize to constant
|
|
1923
|
+
s_tile = s;
|
|
1924
|
+
|
|
1925
|
+
adj_tile_binary_map(mul, a, s_tile, adj_mul, adj_a, adj_s_tile, adj_c);
|
|
1926
|
+
|
|
1927
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
1928
|
+
{
|
|
1929
|
+
adj_s += adj_s_tile.data[i];
|
|
1930
|
+
}
|
|
1931
|
+
}
|
|
1932
|
+
|
|
1933
|
+
|
|
1934
|
+
// scalar*tile
|
|
1935
|
+
template <typename Tile>
|
|
1936
|
+
inline CUDA_CALLABLE auto tile_mul(const typename Tile::Type& s, Tile& a)
|
|
1937
|
+
{
|
|
1938
|
+
return tile_mul(a, s);
|
|
1939
|
+
}
|
|
1940
|
+
|
|
1941
|
+
template <typename Tile, typename AdjTile>
|
|
1942
|
+
inline CUDA_CALLABLE void adj_tile_mul(const typename Tile::Type& s, Tile& a,
|
|
1943
|
+
typename Tile::Type& adj_s, Tile& adj_a,
|
|
1944
|
+
AdjTile& adj_c)
|
|
1945
|
+
{
|
|
1946
|
+
adj_tile_mul(a, s, adj_a, adj_s, adj_c);
|
|
1947
|
+
}
|
|
1948
|
+
|
|
1949
|
+
|
|
1950
|
+
template<typename Tile>
|
|
1951
|
+
typename Tile::Type tile_extract(Tile& t, int i) { return t.extract(tile_coord(i)); }
|
|
1952
|
+
template<typename Tile>
|
|
1953
|
+
typename Tile::Type tile_extract(Tile& t, int i, int j) { return t.extract(tile_coord(i,j)); }
|
|
1954
|
+
template<typename Tile>
|
|
1955
|
+
typename Tile::Type tile_extract(Tile& t, int i, int j, int k) { return t.extract(tile_coord(i,j,k)); }
|
|
1956
|
+
template<typename Tile>
|
|
1957
|
+
typename Tile::Type tile_extract(Tile& t, int i, int j, int k, int l) { return t.extract(tile_coord(i,j,k,l)); }
|
|
1958
|
+
|
|
1959
|
+
|
|
1960
|
+
template<typename Tile, typename AdjTile>
|
|
1961
|
+
void adj_tile_extract(Tile& t, int i, AdjTile& adj_t, int adj_i, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i), adj_ret); }
|
|
1962
|
+
template<typename Tile, typename AdjTile>
|
|
1963
|
+
void adj_tile_extract(Tile& t, int i, int j, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j), adj_ret); }
|
|
1964
|
+
template<typename Tile, typename AdjTile>
|
|
1965
|
+
void adj_tile_extract(Tile& t, int i, int j, int k, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j, k), adj_ret); }
|
|
1966
|
+
template<typename Tile, typename AdjTile>
|
|
1967
|
+
void adj_tile_extract(Tile& t, int i, int j, int k, int l, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j, k, l), adj_ret); }
|
|
1968
|
+
|
|
1969
|
+
|
|
1970
|
+
namespace partitioned_gemm
|
|
1971
|
+
{
|
|
1972
|
+
|
|
1973
|
+
template <typename T>
|
|
1974
|
+
inline CUDA_CALLABLE const T& index(const T* __restrict__ p, int i, int j, int stride)
|
|
1975
|
+
{
|
|
1976
|
+
return p[i*stride + j];
|
|
1977
|
+
}
|
|
1978
|
+
|
|
1979
|
+
template <typename T>
|
|
1980
|
+
inline CUDA_CALLABLE T& index(T* __restrict__ p, int i, int j, int stride)
|
|
1981
|
+
{
|
|
1982
|
+
return p[i*stride + j];
|
|
1983
|
+
}
|
|
1984
|
+
|
|
1985
|
+
// View of a 2D tile as a grid of PartitionM x PartitionN sub-blocks.
// `data` aliases the tile's underlying buffer. `shape` holds the number of
// sub-blocks along each dimension (integer division, so any remainder rows or
// columns are not covered).
template <int PartitionM, int PartitionN, typename Tile>
struct partition_t
{
    static constexpr int M = PartitionM;
    static constexpr int N = PartitionN;

    // row pitch used when indexing `data` (second dimension of the tile)
    static constexpr int Stride = Tile::Layout::Shape::dim(1);

    using T = typename Tile::Type;

    inline partition_t(Tile& A)
        : data(A.data.ptr)
    {
        // todo: do ceil div for non-multiples of M,N
        shape[0] = Tile::Layout::Shape::dim(0)/PartitionM;
        shape[1] = Tile::Layout::Shape::dim(1)/PartitionN;
    }

    // underlying data
    T* data;

    // partition dimensions: number of sub-blocks along rows / columns
    int shape[2];
};
|
|
2009
|
+
|
|
2010
|
+
template <typename Partition>
|
|
2011
|
+
inline int partition_size(const Partition& part)
|
|
2012
|
+
{
|
|
2013
|
+
return part.shape[0]*part.shape[1];
|
|
2014
|
+
}
|
|
2015
|
+
|
|
2016
|
+
// returns the x, y coordinates of a tile given a linear index
|
|
2017
|
+
template <typename Partition>
|
|
2018
|
+
inline void partition_coord(const Partition& part, const int t, int& i, int& j)
|
|
2019
|
+
{
|
|
2020
|
+
i = t/part.shape[1];
|
|
2021
|
+
j = t%part.shape[1];
|
|
2022
|
+
}
|
|
2023
|
+
|
|
2024
|
+
template <typename Partition>
|
|
2025
|
+
inline auto partition_load(const Partition& tile, int i, int j)
|
|
2026
|
+
{
|
|
2027
|
+
mat_t<Partition::M, Partition::N, typename Partition::T> out;
|
|
2028
|
+
|
|
2029
|
+
const int tile_i = i*Partition::M;
|
|
2030
|
+
const int tile_j = j*Partition::N;
|
|
2031
|
+
|
|
2032
|
+
WP_PRAGMA_UNROLL
|
|
2033
|
+
for (int i=0; i < Partition::M; ++i)
|
|
2034
|
+
{
|
|
2035
|
+
WP_PRAGMA_UNROLL
|
|
2036
|
+
for (int j=0; j < Partition::N; ++j)
|
|
2037
|
+
{
|
|
2038
|
+
out.data[i][j] = partitioned_gemm::index(tile.data, tile_i + i, tile_j + j, Partition::Stride);
|
|
2039
|
+
}
|
|
2040
|
+
}
|
|
2041
|
+
|
|
2042
|
+
return out;
|
|
2043
|
+
}
|
|
2044
|
+
|
|
2045
|
+
template <typename Partition, typename Value>
|
|
2046
|
+
inline void partition_store(const Partition& tile, int i, int j, const Value& value)
|
|
2047
|
+
{
|
|
2048
|
+
const int tile_i = Partition::M*i;
|
|
2049
|
+
const int tile_j = Partition::N*j;
|
|
2050
|
+
|
|
2051
|
+
WP_PRAGMA_UNROLL
|
|
2052
|
+
for (int i=0; i < Partition::M; ++i)
|
|
2053
|
+
{
|
|
2054
|
+
WP_PRAGMA_UNROLL
|
|
2055
|
+
for (int j=0; j < Partition::N; ++j)
|
|
2056
|
+
{
|
|
2057
|
+
index(tile.data, tile_i + i, tile_j + j, Partition::Stride) = value.data[i][j];
|
|
2058
|
+
}
|
|
2059
|
+
}
|
|
2060
|
+
}
|
|
2061
|
+
|
|
2062
|
+
|
|
2063
|
+
template <typename TileA, typename TileB, typename TileC>
|
|
2064
|
+
inline CUDA_CALLABLE void matmul(TileA& A, TileB& B, TileC& out)
|
|
2065
|
+
{
|
|
2066
|
+
const int TILE_M = 4;
|
|
2067
|
+
const int TILE_N = 4;
|
|
2068
|
+
const int TILE_K = 4;
|
|
2069
|
+
|
|
2070
|
+
auto A_tile = partition_t<TILE_M, TILE_K, TileA>(A);
|
|
2071
|
+
auto B_tile = partition_t<TILE_K, TILE_N, TileB>(B);
|
|
2072
|
+
auto C_tile = partition_t<TILE_M, TILE_N, TileC>(out);
|
|
2073
|
+
|
|
2074
|
+
//static_assert(is_same<typename TileA::Type, typename TileB::Type>::value);
|
|
2075
|
+
|
|
2076
|
+
const int length = partition_size(C_tile);
|
|
2077
|
+
|
|
2078
|
+
for (int t=WP_TILE_THREAD_IDX; t < length; t += WP_TILE_BLOCK_DIM)
|
|
2079
|
+
{
|
|
2080
|
+
int i, j;
|
|
2081
|
+
partition_coord(C_tile, t, i, j);
|
|
2082
|
+
|
|
2083
|
+
// accumulator
|
|
2084
|
+
auto sum = partition_load(C_tile, i, j);
|
|
2085
|
+
|
|
2086
|
+
WP_PRAGMA_UNROLL
|
|
2087
|
+
for (int k=0; k < A_tile.shape[1]; k++)
|
|
2088
|
+
{
|
|
2089
|
+
const auto a = partition_load(A_tile, i, k);
|
|
2090
|
+
const auto b = partition_load(B_tile, k, j);
|
|
2091
|
+
|
|
2092
|
+
sum += mul(a, b);
|
|
2093
|
+
}
|
|
2094
|
+
|
|
2095
|
+
partition_store(C_tile, i, j, sum);
|
|
2096
|
+
}
|
|
2097
|
+
}
|
|
2098
|
+
|
|
2099
|
+
template <typename LayoutA, typename LayoutB, typename LayoutC, typename StorageA, typename StorageB, typename StorageC, typename T>
|
|
2100
|
+
inline CUDA_CALLABLE void scalar_matmul(const StorageA& A, const StorageB& B, StorageC& C, T scale)
|
|
2101
|
+
{
|
|
2102
|
+
for (int t=WP_TILE_THREAD_IDX; t < LayoutC::Size; t += WP_TILE_BLOCK_DIM)
|
|
2103
|
+
{
|
|
2104
|
+
auto coord = LayoutC::coord_from_linear(t);
|
|
2105
|
+
|
|
2106
|
+
int i = coord[0];
|
|
2107
|
+
int j = coord[1];
|
|
2108
|
+
|
|
2109
|
+
// accumulator
|
|
2110
|
+
auto sum = C(coord)*scale;
|
|
2111
|
+
|
|
2112
|
+
WP_PRAGMA_UNROLL
|
|
2113
|
+
for (int k=0; k < LayoutA::Shape::dim(1); k++)
|
|
2114
|
+
{
|
|
2115
|
+
const auto a = A(tile_coord(i, k));
|
|
2116
|
+
const auto b = B(tile_coord(k, j));
|
|
2117
|
+
|
|
2118
|
+
sum = muladd<decltype(sum)>(a, b, sum);
|
|
2119
|
+
}
|
|
2120
|
+
|
|
2121
|
+
C(coord) = sum;
|
|
2122
|
+
}
|
|
2123
|
+
}
|
|
2124
|
+
|
|
2125
|
+
template <typename TileA, typename TileL>
|
|
2126
|
+
inline CUDA_CALLABLE void scalar_cholesky(TileA& A, TileL& L)
|
|
2127
|
+
{
|
|
2128
|
+
using T = typename TileA::Type;
|
|
2129
|
+
constexpr int n = TileA::Layout::Shape::dim(1);
|
|
2130
|
+
|
|
2131
|
+
for (int j=0; j < n; ++j)
|
|
2132
|
+
{
|
|
2133
|
+
T s = A.data(tile_coord(j, j));
|
|
2134
|
+
|
|
2135
|
+
for (int k=0; k < j; ++k)
|
|
2136
|
+
{
|
|
2137
|
+
T r = L.data(tile_coord(j, k));
|
|
2138
|
+
s -= r * r;
|
|
2139
|
+
}
|
|
2140
|
+
|
|
2141
|
+
s = wp::sqrt(s);
|
|
2142
|
+
T invS = 1.0 / s;
|
|
2143
|
+
|
|
2144
|
+
L.data(tile_coord(j, j)) = s;
|
|
2145
|
+
|
|
2146
|
+
for (int i=j+1; i < n; ++i)
|
|
2147
|
+
{
|
|
2148
|
+
s = A.data(tile_coord(i, j));
|
|
2149
|
+
|
|
2150
|
+
for (int k=0; k < j; ++k)
|
|
2151
|
+
{
|
|
2152
|
+
s -= L.data(tile_coord(i, k)) * L.data(tile_coord(j, k));
|
|
2153
|
+
}
|
|
2154
|
+
|
|
2155
|
+
L.data(tile_coord(i, j)) = s * invS;
|
|
2156
|
+
}
|
|
2157
|
+
|
|
2158
|
+
// zero out upper triangular portion
|
|
2159
|
+
for (int k=j+1; k < n; ++k)
|
|
2160
|
+
{
|
|
2161
|
+
L.data(tile_coord(j,k)) = T(0.0);
|
|
2162
|
+
}
|
|
2163
|
+
}
|
|
2164
|
+
}
|
|
2165
|
+
|
|
2166
|
+
template <typename TileL, typename TileX, typename TileY>
|
|
2167
|
+
inline CUDA_CALLABLE void scalar_cholesky_solve(TileL& L, TileX& X, TileY& Y)
|
|
2168
|
+
{
|
|
2169
|
+
using T = typename TileL::Type;
|
|
2170
|
+
constexpr int n = TileL::Layout::Shape::dim(1);
|
|
2171
|
+
|
|
2172
|
+
for (int i=0; i < n; ++i)
|
|
2173
|
+
{
|
|
2174
|
+
T s = Y.data(tile_coord(i));
|
|
2175
|
+
|
|
2176
|
+
for (int j=0; j < i; ++j)
|
|
2177
|
+
s -= L.data(tile_coord(i,j)) * X.data(tile_coord(j));
|
|
2178
|
+
|
|
2179
|
+
X.data(tile_coord(i)) = s / L.data(tile_coord(i, i));
|
|
2180
|
+
}
|
|
2181
|
+
|
|
2182
|
+
for (int i=n-1; i >= 0; --i)
|
|
2183
|
+
{
|
|
2184
|
+
T s = X.data(tile_coord(i));
|
|
2185
|
+
|
|
2186
|
+
for (int j=i+1; j < n; ++j)
|
|
2187
|
+
s -= L.data(tile_coord(j, i)) * X.data(tile_coord(j));
|
|
2188
|
+
|
|
2189
|
+
X.data(tile_coord(i)) = s / L.data(tile_coord(i, i));
|
|
2190
|
+
}
|
|
2191
|
+
}
|
|
2192
|
+
|
|
2193
|
+
} // namespace partition_gemm
|
|
2194
|
+
|
|
2195
|
+
|
|
2196
|
+
template <int Add, typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
|
|
2197
|
+
TileC& tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C)
|
|
2198
|
+
{
|
|
2199
|
+
using ShapeA = typename TileA::Layout::Shape;
|
|
2200
|
+
using ShapeB = typename TileB::Layout::Shape;
|
|
2201
|
+
using ShapeC = typename TileC::Layout::Shape;
|
|
2202
|
+
|
|
2203
|
+
static_assert(ShapeA::N == 2, "Expected ShapeA::N == 2");
|
|
2204
|
+
static_assert(ShapeB::N == 2, "Expected ShapeB::N == 2");
|
|
2205
|
+
static_assert(ShapeC::N == 2, "Expected ShapeC::N == 2");
|
|
2206
|
+
|
|
2207
|
+
static_assert(ShapeA::dim(1) == ShapeB::dim(0), "Expected ShapeA::dim(1) == ShapeB::dim(0)");
|
|
2208
|
+
static_assert(ShapeC::dim(0) == ShapeA::dim(0), "Expected ShapeC::dim(0) == ShapeA::dim(0)");
|
|
2209
|
+
static_assert(ShapeC::dim(1) == ShapeB::dim(1), "Expected ShapeC::dim(1) == ShapeB::dim(1)");
|
|
2210
|
+
|
|
2211
|
+
|
|
2212
|
+
using T = typename TileA::Type;
|
|
2213
|
+
|
|
2214
|
+
#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
|
|
2215
|
+
partitioned_gemm::scalar_matmul<typename TileA::Layout, typename TileB::Layout, typename TileC::Layout>(A.data, B.data, C.data, T(Add));
|
|
2216
|
+
#else
|
|
2217
|
+
fun_forward(T(1.0), A.data.ptr, B.data.ptr, T(Add), C.data.ptr);
|
|
2218
|
+
#endif
|
|
2219
|
+
|
|
2220
|
+
WP_TILE_SYNC();
|
|
2221
|
+
|
|
2222
|
+
return C;
|
|
2223
|
+
}
|
|
2224
|
+
|
|
2225
|
+
|
|
2226
|
+
// backward for the wp.tile_matmul(a, b, out) syntax
|
|
2227
|
+
template <typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
|
|
2228
|
+
void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C,
|
|
2229
|
+
Fwd adj_fun_forward, AdjA adj_fun_backward_A, AdjB adj_fun_backward_B, TileA& adj_A, TileB& adj_B, TileC& adj_C)
|
|
2230
|
+
{
|
|
2231
|
+
using T = typename TileA::Type;
|
|
2232
|
+
|
|
2233
|
+
#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
|
|
2234
|
+
auto At = tile_transpose(A);
|
|
2235
|
+
auto Bt = tile_transpose(B);
|
|
2236
|
+
|
|
2237
|
+
partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, T(1.0));
|
|
2238
|
+
partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, T(1.0));
|
|
2239
|
+
#else
|
|
2240
|
+
fun_backward_A(T(1.0), adj_C.grad.ptr, B.data.ptr, T(1.0), adj_A.grad.ptr);
|
|
2241
|
+
fun_backward_B(T(1.0), A.data.ptr, adj_C.grad.ptr, T(1.0), adj_B.grad.ptr);
|
|
2242
|
+
#endif
|
|
2243
|
+
|
|
2244
|
+
WP_TILE_SYNC();
|
|
2245
|
+
}
|
|
2246
|
+
|
|
2247
|
+
// backward for the out = wp.tile_matmul(a, b) syntax
// adj_ret is not read. The gradient accumulation is exactly that of the
// wp.tile_matmul(a, b, out) overload above, so forward to it rather than
// duplicating the body.
template <typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C,
                     Fwd adj_fun_forward, AdjA adj_fun_backward_A, AdjB adj_fun_backward_B, TileA& adj_A, TileB& adj_B, TileC& adj_C, TileC& adj_ret)
{
    adj_tile_matmul(fun_forward, fun_backward_A, fun_backward_B, A, B, C,
                    adj_fun_forward, adj_fun_backward_A, adj_fun_backward_B, adj_A, adj_B, adj_C);
}
|
|
2267
|
+
|
|
2268
|
+
#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0

// Without MathDx (or outside device code) the FFT entry points are zero-argument
// macros that expand to nothing.
// NOTE(review): a call that passes arguments would not preprocess against these
// definitions, so codegen presumably never emits FFT calls in this
// configuration; confirm.
#define tile_fft()
#define tile_ifft()

#define adj_tile_fft()
#define adj_tile_ifft()

#else

// TODO(lcambier): use a properly overaligned complex type that matches cuFFTDx's expectation
// and remove the need for __align__(16) dtypes data[...]

// Applies the generated FFT `function_name` in place to `batch_size` consecutive
// slices of `ept` elements of Xinout.data. Each slice is staged through the
// aligned local array `data`. The function also receives a scratch buffer
// obtained from wp::tile_alloc_shared(shared_memory_size) (presumably bytes),
// which is handed back afterwards with a negative-size call.
// Macro arguments are parenthesized where they appear inside expressions, so
// that expression arguments (e.g. `a + b`) expand correctly. In particular, the
// release call negates the whole size, not only its first term.
#define tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout) \
    do { \
        void function_name(dtype*, dtype*); \
        char* buffer = (char*)wp::tile_alloc_shared((shared_memory_size)); \
        __align__(16) dtype data[(ept)]; \
        for (int b = 0; b < (int)(batch_size); b++) { \
            dtype* inout = (Xinout).data + (int)b * (int)(ept); \
            memcpy(data, inout, sizeof(dtype) * (ept)); \
            function_name(data, (dtype*)buffer); \
            memcpy(inout, data, sizeof(dtype) * (ept)); \
            WP_TILE_SYNC(); \
        } \
        wp::tile_alloc_shared(-(shared_memory_size)); \
    } while (0)

#define tile_ifft tile_fft

// adj_function_name, adj_dtype, adj_shared_memory_size, adj_batch_size, adj_ept are all ignored.
// Each adjoint runs the opposite entry point in place on adj_Xinout. Because
// tile_ifft is an alias of tile_fft, the transform actually applied is whatever
// the passed `function_name` implements.
// NOTE(review): confirm that `function_name` as received here performs the
// direction the adjoint needs.
#define adj_tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout, \
                     adj_function_name, adj_dtype, adj_shared_memory_size, adj_batch_size, adj_ept, \
                     adj_Xinout) \
    do { \
        tile_ifft(function_name, dtype, shared_memory_size, batch_size, ept, adj_Xinout); \
    } while (0)

#define adj_tile_ifft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout, \
                      adj_function_name, adj_dtype, adj_shared_memory_size, adj_batch_size, adj_ept, \
                      adj_Xinout) \
    do { \
        tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, adj_Xinout); \
    } while (0)

#endif // !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
|
|
2314
|
+
|
|
2315
|
+
template <typename Fwd, typename TileA, typename TileL>
|
|
2316
|
+
TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
|
|
2317
|
+
{
|
|
2318
|
+
// Copy to L
|
|
2319
|
+
L = A;
|
|
2320
|
+
|
|
2321
|
+
#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
|
|
2322
|
+
|
|
2323
|
+
partitioned_gemm::scalar_cholesky(A, L);
|
|
2324
|
+
|
|
2325
|
+
#else
|
|
2326
|
+
|
|
2327
|
+
|
|
2328
|
+
// Call cholesky on L
|
|
2329
|
+
WP_TILE_SYNC();
|
|
2330
|
+
|
|
2331
|
+
fun_forward(L.data.ptr, TileL::Layout::Shape::dim(0));
|
|
2332
|
+
|
|
2333
|
+
WP_TILE_SYNC();
|
|
2334
|
+
|
|
2335
|
+
// Zero-out the upper triangular part of L
|
|
2336
|
+
|
|
2337
|
+
WP_PRAGMA_UNROLL
|
|
2338
|
+
for (int i=WP_TILE_THREAD_IDX; i < TileL::Layout::Size; i += WP_TILE_BLOCK_DIM)
|
|
2339
|
+
{
|
|
2340
|
+
auto c = TileL::Layout::coord_from_linear(i);
|
|
2341
|
+
|
|
2342
|
+
if(c[0] < c[1])
|
|
2343
|
+
L.data(c) = 0.0;
|
|
2344
|
+
}
|
|
2345
|
+
|
|
2346
|
+
WP_TILE_SYNC();
|
|
2347
|
+
|
|
2348
|
+
#endif
|
|
2349
|
+
|
|
2350
|
+
return L;
|
|
2351
|
+
}
|
|
2352
|
+
|
|
2353
|
+
// Backward pass of tile_cholesky: not implemented. It asserts with a message
// that names the missing adjoint.
// NOTE(review): when NDEBUG is defined the assert compiles away and no gradient
// is propagated.
#define adj_tile_cholesky(function_name, A, L, \
                          adj_function_name, adj_A, adj_L, adj_ret) \
    do { \
        assert(false && "adj_tile_cholesky(): backward pass is not implemented"); \
    } while (0)
|
|
2358
|
+
|
|
2359
|
+
template <typename Fwd, typename TileL, typename TileX, typename TileY>
|
|
2360
|
+
TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y)
|
|
2361
|
+
{
|
|
2362
|
+
// Copy x to y
|
|
2363
|
+
|
|
2364
|
+
Y = X;
|
|
2365
|
+
|
|
2366
|
+
#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
|
|
2367
|
+
|
|
2368
|
+
partitioned_gemm::scalar_cholesky_solve(L, X, Y);
|
|
2369
|
+
|
|
2370
|
+
#else
|
|
2371
|
+
|
|
2372
|
+
// Call cholesky solve on L & y
|
|
2373
|
+
|
|
2374
|
+
WP_TILE_SYNC();
|
|
2375
|
+
|
|
2376
|
+
fun_forward(L.data.ptr, Y.data.ptr); \
|
|
2377
|
+
|
|
2378
|
+
WP_TILE_SYNC();
|
|
2379
|
+
|
|
2380
|
+
#endif
|
|
2381
|
+
|
|
2382
|
+
return Y;
|
|
2383
|
+
}
|
|
2384
|
+
|
|
2385
|
+
// Backward pass of tile_cholesky_solve: not implemented. It asserts with a
// message that names the missing adjoint.
// NOTE(review): when NDEBUG is defined the assert compiles away and no gradient
// is propagated.
#define adj_tile_cholesky_solve(function_name, L, X, Y, \
                                adj_function_name, adj_L, adj_X, adj_Y, adj_ret) \
    do { \
        assert(false && "adj_tile_cholesky_solve(): backward pass is not implemented"); \
    } while (0)
|
|
2390
|
+
|
|
2391
|
+
template <typename Tile>
|
|
2392
|
+
inline CUDA_CALLABLE auto tile_transpose(Tile& t)
|
|
2393
|
+
{
|
|
2394
|
+
static_assert(Tile::Layout::Shape::N == 2, "Expected Tile::Layout::Shape::N == 2");
|
|
2395
|
+
|
|
2396
|
+
// alias incoming tile
|
|
2397
|
+
constexpr int M = Tile::Layout::Shape::dim(0);
|
|
2398
|
+
constexpr int N = Tile::Layout::Shape::dim(1);
|
|
2399
|
+
|
|
2400
|
+
constexpr int StrideM = Tile::Layout::Stride::dim(0);
|
|
2401
|
+
constexpr int StrideN = Tile::Layout::Stride::dim(1);
|
|
2402
|
+
|
|
2403
|
+
return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<N,M>, tile_stride_t<StrideN, StrideM>>, false>(t.data.ptr, t.grad.ptr);
|
|
2404
|
+
}
|
|
2405
|
+
|
|
2406
|
+
template <typename Tile, typename AdjTile>
|
|
2407
|
+
inline CUDA_CALLABLE void adj_tile_transpose(Tile& t, Tile& adj_t, AdjTile& adj_ret)
|
|
2408
|
+
{
|
|
2409
|
+
auto a = tile_transpose(adj_ret);
|
|
2410
|
+
auto b = adj_t;
|
|
2411
|
+
|
|
2412
|
+
adj_t.assign(tile_add(a,b));
|
|
2413
|
+
}
|
|
2414
|
+
|
|
2415
|
+
template <int N, int StrideN, typename Tile>
|
|
2416
|
+
inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
|
|
2417
|
+
{
|
|
2418
|
+
// alias incoming tile with new strides
|
|
2419
|
+
return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<N>, tile_stride_t<StrideN>>, false>(t.data.ptr, t.grad.ptr);
|
|
2420
|
+
}
|
|
2421
|
+
|
|
2422
|
+
template <int M, int N, int StrideM, int StrideN, typename Tile>
|
|
2423
|
+
inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
|
|
2424
|
+
{
|
|
2425
|
+
// alias incoming tile with new strides
|
|
2426
|
+
return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N>, tile_stride_t<StrideM, StrideN>>, false>(t.data.ptr, t.grad.ptr);
|
|
2427
|
+
}
|
|
2428
|
+
|
|
2429
|
+
template <int M, int N, int O, int StrideM, int StrideN, int StrideO, typename Tile>
|
|
2430
|
+
inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
|
|
2431
|
+
{
|
|
2432
|
+
// alias incoming tile with new strides
|
|
2433
|
+
return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N, O>, tile_stride_t<StrideM, StrideN, StrideO>>, false>(t.data.ptr, t.grad.ptr);
|
|
2434
|
+
}
|
|
2435
|
+
|
|
2436
|
+
template <int M, int N, int O, int P, int StrideM, int StrideN, int StrideO, int StrideP, typename Tile>
|
|
2437
|
+
inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
|
|
2438
|
+
{
|
|
2439
|
+
// alias incoming tile with new strides
|
|
2440
|
+
return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N, O, P>, tile_stride_t<StrideM, StrideN, StrideO, StrideP>>, false>(t.data.ptr, t.grad.ptr);
|
|
2441
|
+
}
|
|
2442
|
+
|
|
2443
|
+
template <typename Tile, typename AdjTile>
|
|
2444
|
+
inline CUDA_CALLABLE void adj_tile_broadcast(Tile& t, Tile& adj_t, AdjTile& adj_ret)
|
|
2445
|
+
{
|
|
2446
|
+
// nop, since memory is aliased grads already accumulated
|
|
2447
|
+
}
|
|
2448
|
+
|
|
2449
|
+
template <typename ReturnType, typename Tile, typename... Indices>
|
|
2450
|
+
inline CUDA_CALLABLE auto tile_view(Tile& t, Indices... indices)
|
|
2451
|
+
{
|
|
2452
|
+
auto c = tile_coord(indices...);
|
|
2453
|
+
|
|
2454
|
+
// return new tile with same strides
|
|
2455
|
+
typename Tile::Type* data_ptr = &t.data(c);
|
|
2456
|
+
typename Tile::Type* grad_ptr = nullptr;
|
|
2457
|
+
|
|
2458
|
+
if (t.grad.ptr)
|
|
2459
|
+
grad_ptr = &t.grad(c);
|
|
2460
|
+
|
|
2461
|
+
return ReturnType(data_ptr, grad_ptr);
|
|
2462
|
+
}
|
|
2463
|
+
|
|
2464
|
+
|
|
2465
|
+
template <typename TileA, typename Scalar>
|
|
2466
|
+
inline CUDA_CALLABLE void assign(TileA& dest, int i, const Scalar& src)
|
|
2467
|
+
{
|
|
2468
|
+
dest.data(tile_coord(i)) = src;
|
|
2469
|
+
WP_TILE_SYNC();
|
|
2470
|
+
}
|
|
2471
|
+
|
|
2472
|
+
template <typename TileA, typename Scalar>
|
|
2473
|
+
inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, const Scalar& src)
|
|
2474
|
+
{
|
|
2475
|
+
dest.data(tile_coord(i, j)) = src;
|
|
2476
|
+
WP_TILE_SYNC();
|
|
2477
|
+
}
|
|
2478
|
+
|
|
2479
|
+
template <typename TileA, typename Scalar>
|
|
2480
|
+
inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, const Scalar& src)
|
|
2481
|
+
{
|
|
2482
|
+
dest.data(tile_coord(i, j, k)) = src;
|
|
2483
|
+
WP_TILE_SYNC();
|
|
2484
|
+
}
|
|
2485
|
+
|
|
2486
|
+
template <typename TileA, typename Scalar>
|
|
2487
|
+
inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, const Scalar& src)
|
|
2488
|
+
{
|
|
2489
|
+
dest.data(tile_coord(i, j, k, l)) = src;
|
|
2490
|
+
WP_TILE_SYNC();
|
|
2491
|
+
}
|
|
2492
|
+
|
|
2493
|
+
|
|
2494
|
+
|
|
2495
|
+
|
|
2496
|
+
template <typename TileA, typename TileB, typename Coord>
|
|
2497
|
+
inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, const Coord& offset)
|
|
2498
|
+
{
|
|
2499
|
+
using Layout = typename TileB::Layout;
|
|
2500
|
+
|
|
2501
|
+
for (int t=WP_TILE_THREAD_IDX; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
|
|
2502
|
+
{
|
|
2503
|
+
auto c = Layout::coord_from_linear(t);
|
|
2504
|
+
dest.data(c + offset) = src.data(c);
|
|
2505
|
+
}
|
|
2506
|
+
|
|
2507
|
+
WP_TILE_SYNC();
|
|
2508
|
+
}
|
|
2509
|
+
|
|
2510
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename Coord, typename AdjCoord>
|
|
2511
|
+
inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, Coord offset,
|
|
2512
|
+
AdjTileA& adj_dest, AdjTileB& adj_src, AdjCoord adj_offset)
|
|
2513
|
+
{
|
|
2514
|
+
using Layout = typename TileB::Layout;
|
|
2515
|
+
|
|
2516
|
+
for (int t=WP_TILE_THREAD_IDX; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
|
|
2517
|
+
{
|
|
2518
|
+
auto c = Layout::coord_from_linear(t);
|
|
2519
|
+
src.grad(c) += dest.grad(c + offset);
|
|
2520
|
+
}
|
|
2521
|
+
|
|
2522
|
+
WP_TILE_SYNC();
|
|
2523
|
+
}
|
|
2524
|
+
|
|
2525
|
+
|
|
2526
|
+
// codegen entry points, which emit calls like `tile_assign(dest, src, i, j, k)`
|
|
2527
|
+
// a better approach here would be for codegen to just directly generate `tile_assign(dest, src, tile_coord(i, j, k))`
|
|
2528
|
+
// i.e.: call the above implementation methods directly, then we could remove these overloads
|
|
2529
|
+
template <typename TileA, typename TileB>
|
|
2530
|
+
inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i) { tile_assign(dest, src, tile_coord(i)); }
|
|
2531
|
+
template <typename TileA, typename TileB>
|
|
2532
|
+
inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i, int j) { tile_assign(dest, src, tile_coord(i, j)); }
|
|
2533
|
+
template <typename TileA, typename TileB>
|
|
2534
|
+
inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i, int j, int k) { tile_assign(dest, src, tile_coord(i, j, k)); }
|
|
2535
|
+
template <typename TileA, typename TileB>
|
|
2536
|
+
inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i, int j, int k, int l) { tile_assign(dest, src, tile_coord(i, j, k, l)); }
|
|
2537
|
+
|
|
2538
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
|
|
2539
|
+
inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, AdjTileA& adj_dest, AdjTileB& adj_src, int) { adj_tile_assign(dest, src, tile_coord(i), adj_dest, adj_src, tile_coord(0)); }
|
|
2540
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
|
|
2541
|
+
inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, int j, AdjTileA& adj_dest, AdjTileB& adj_src, int, int) { adj_tile_assign(dest, src, tile_coord(i,j), adj_dest, adj_src, tile_coord(0)); }
|
|
2542
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
|
|
2543
|
+
inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, int j, int k, AdjTileA& adj_dest, AdjTileB& adj_src, int, int, int) { adj_tile_assign(dest, src, tile_coord(i,j,k), adj_dest, adj_src, tile_coord(0)); }
|
|
2544
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
|
|
2545
|
+
inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, int j, int k, int l, AdjTileA& adj_dest, AdjTileB& adj_src, int, int, int, int) { adj_tile_assign(dest, src, tile_coord(i,j,k,l), adj_dest, adj_src, tile_coord(0)); }
|
|
2546
|
+
|
|
2547
|
+
|
|
2548
|
+
template <typename TileA, typename TileB, typename TileC>
|
|
2549
|
+
inline CUDA_CALLABLE TileC& tile_diag_add(TileA& a, TileB& b, TileC& c)
|
|
2550
|
+
{
|
|
2551
|
+
using ShapeA = typename TileA::Layout::Shape;
|
|
2552
|
+
using ShapeB = typename TileB::Layout::Shape;
|
|
2553
|
+
using ShapeC = typename TileC::Layout::Shape;
|
|
2554
|
+
|
|
2555
|
+
static_assert(ShapeA::dim(0) == ShapeA::dim(1), "Expected ShapeA::dim(0) == ShapeA::dim(1)");
|
|
2556
|
+
static_assert(ShapeB::dim(0) == ShapeA::dim(0), "Expected ShapeB::dim(0) == ShapeA::dim(0)");
|
|
2557
|
+
static_assert(ShapeC::dim(0) == ShapeA::dim(0), "Expected ShapeC::dim(0) == ShapeA::dim(0)");
|
|
2558
|
+
static_assert(ShapeC::dim(0) == ShapeC::dim(1), "Expected ShapeC::dim(0) == ShapeC::dim(1)");
|
|
2559
|
+
|
|
2560
|
+
c = a;
|
|
2561
|
+
|
|
2562
|
+
for (int t=WP_TILE_THREAD_IDX; t < ShapeA::dim(0); t += WP_TILE_BLOCK_DIM)
|
|
2563
|
+
{
|
|
2564
|
+
c.data(tile_coord(t, t)) += b.data(tile_coord(t));
|
|
2565
|
+
}
|
|
2566
|
+
|
|
2567
|
+
WP_TILE_SYNC();
|
|
2568
|
+
|
|
2569
|
+
return c;
|
|
2570
|
+
}
|
|
2571
|
+
|
|
2572
|
+
template <typename TileA, typename TileB, typename TileC, typename AdjTileA, typename AdjTileB, typename AdjTileC>
|
|
2573
|
+
inline CUDA_CALLABLE void adj_tile_diag_add(TileA& a, TileB& b, TileC& c, AdjTileA& adj_a, AdjTileB& adj_b, AdjTileC& adj_c, AdjTileC& adj_ret)
|
|
2574
|
+
{
|
|
2575
|
+
assert(false);
|
|
2576
|
+
}
|
|
2577
|
+
|
|
2578
|
+
|
|
2579
|
+
} // namespace wp
|
|
2580
|
+
|
|
2581
|
+
|
|
2582
|
+
#ifdef __clang__
|
|
2583
|
+
#pragma clang diagnostic pop
|
|
2584
|
+
#endif
|