warp-lang 1.6.1__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +21 -7
- warp/autograd.py +14 -6
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +424 -6
- warp/build_dll.py +20 -20
- warp/builtins.py +467 -368
- warp/codegen.py +193 -125
- warp/config.py +56 -12
- warp/constants.py +14 -6
- warp/context.py +524 -277
- warp/dlpack.py +22 -12
- warp/examples/__init__.py +14 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_api.py +14 -6
- warp/examples/benchmarks/benchmark_cloth.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
- warp/examples/benchmarks/benchmark_gemm.py +82 -48
- warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
- warp/examples/benchmarks/benchmark_launches.py +14 -6
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +14 -6
- warp/examples/core/example_cupy.py +14 -6
- warp/examples/core/example_dem.py +14 -6
- warp/examples/core/example_fluid.py +14 -6
- warp/examples/core/example_graph_capture.py +14 -6
- warp/examples/core/example_marching_cubes.py +14 -6
- warp/examples/core/example_mesh.py +14 -6
- warp/examples/core/example_mesh_intersect.py +14 -6
- warp/examples/core/example_nvdb.py +14 -6
- warp/examples/core/example_raycast.py +14 -6
- warp/examples/core/example_raymarch.py +14 -6
- warp/examples/core/example_render_opengl.py +14 -6
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +14 -6
- warp/examples/core/example_torch.py +14 -6
- warp/examples/core/example_wave.py +14 -6
- warp/examples/fem/example_adaptive_grid.py +14 -6
- warp/examples/fem/example_apic_fluid.py +15 -7
- warp/examples/fem/example_burgers.py +16 -8
- warp/examples/fem/example_convection_diffusion.py +14 -6
- warp/examples/fem/example_convection_diffusion_dg.py +14 -6
- warp/examples/fem/example_deformed_geometry.py +15 -7
- warp/examples/fem/example_diffusion.py +14 -6
- warp/examples/fem/example_diffusion_3d.py +14 -6
- warp/examples/fem/example_diffusion_mgpu.py +14 -6
- warp/examples/fem/example_distortion_energy.py +15 -7
- warp/examples/fem/example_magnetostatics.py +20 -12
- warp/examples/fem/example_mixed_elasticity.py +14 -6
- warp/examples/fem/example_navier_stokes.py +14 -6
- warp/examples/fem/example_nonconforming_contact.py +14 -6
- warp/examples/fem/example_stokes.py +14 -6
- warp/examples/fem/example_stokes_transfer.py +14 -6
- warp/examples/fem/example_streamlines.py +14 -6
- warp/examples/fem/utils.py +24 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +14 -6
- warp/examples/optim/example_cloth_throw.py +14 -6
- warp/examples/optim/example_diffray.py +14 -6
- warp/examples/optim/example_drone.py +14 -6
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +14 -6
- warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
- warp/examples/optim/example_softbody_properties.py +14 -6
- warp/examples/optim/example_spring_cage.py +14 -6
- warp/examples/optim/example_trajectory.py +14 -6
- warp/examples/sim/example_cartpole.py +14 -6
- warp/examples/sim/example_cloth.py +14 -6
- warp/examples/sim/example_cloth_self_contact.py +14 -6
- warp/examples/sim/example_granular.py +14 -6
- warp/examples/sim/example_granular_collision_sdf.py +14 -6
- warp/examples/sim/example_jacobian_ik.py +14 -6
- warp/examples/sim/example_particle_chain.py +14 -6
- warp/examples/sim/example_quadruped.py +14 -6
- warp/examples/sim/example_rigid_chain.py +14 -6
- warp/examples/sim/example_rigid_contact.py +14 -6
- warp/examples/sim/example_rigid_force.py +14 -6
- warp/examples/sim/example_rigid_gyroscopic.py +14 -6
- warp/examples/sim/example_rigid_soft_contact.py +14 -6
- warp/examples/sim/example_soft_body.py +14 -6
- warp/examples/tile/example_tile_cholesky.py +14 -6
- warp/examples/tile/example_tile_convolution.py +14 -6
- warp/examples/tile/example_tile_fft.py +14 -6
- warp/examples/tile/example_tile_filtering.py +14 -6
- warp/examples/tile/example_tile_matmul.py +16 -10
- warp/examples/tile/example_tile_mlp.py +14 -6
- warp/examples/tile/example_tile_nbody.py +14 -6
- warp/examples/tile/example_tile_walker.py +14 -6
- warp/fabric.py +15 -0
- warp/fem/__init__.py +26 -1
- warp/fem/adaptivity.py +19 -4
- warp/fem/cache.py +15 -0
- warp/fem/dirichlet.py +15 -0
- warp/fem/domain.py +15 -0
- warp/fem/field/__init__.py +15 -0
- warp/fem/field/field.py +15 -0
- warp/fem/field/nodal_field.py +37 -68
- warp/fem/field/restriction.py +15 -0
- warp/fem/field/virtual.py +77 -23
- warp/fem/geometry/__init__.py +15 -0
- warp/fem/geometry/adaptive_nanogrid.py +24 -10
- warp/fem/geometry/closest_point.py +16 -1
- warp/fem/geometry/deformed_geometry.py +20 -2
- warp/fem/geometry/element.py +15 -0
- warp/fem/geometry/geometry.py +20 -0
- warp/fem/geometry/grid_2d.py +27 -12
- warp/fem/geometry/grid_3d.py +27 -15
- warp/fem/geometry/hexmesh.py +20 -7
- warp/fem/geometry/nanogrid.py +24 -11
- warp/fem/geometry/partition.py +15 -0
- warp/fem/geometry/quadmesh.py +28 -13
- warp/fem/geometry/tetmesh.py +18 -4
- warp/fem/geometry/trimesh.py +18 -8
- warp/fem/integrate.py +277 -93
- warp/fem/linalg.py +20 -5
- warp/fem/operator.py +15 -0
- warp/fem/polynomial.py +15 -0
- warp/fem/quadrature/__init__.py +15 -0
- warp/fem/quadrature/pic_quadrature.py +52 -22
- warp/fem/quadrature/quadrature.py +209 -25
- warp/fem/space/__init__.py +16 -1
- warp/fem/space/basis_function_space.py +19 -2
- warp/fem/space/basis_space.py +40 -18
- warp/fem/space/dof_mapper.py +15 -0
- warp/fem/space/function_space.py +15 -0
- warp/fem/space/grid_2d_function_space.py +15 -0
- warp/fem/space/grid_3d_function_space.py +15 -0
- warp/fem/space/hexmesh_function_space.py +17 -2
- warp/fem/space/nanogrid_function_space.py +15 -0
- warp/fem/space/partition.py +21 -2
- warp/fem/space/quadmesh_function_space.py +23 -8
- warp/fem/space/restriction.py +15 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +38 -23
- warp/fem/space/shape/shape_function.py +15 -0
- warp/fem/space/shape/square_shape_function.py +27 -12
- warp/fem/space/shape/tet_shape_function.py +15 -0
- warp/fem/space/shape/triangle_shape_function.py +16 -1
- warp/fem/space/tetmesh_function_space.py +18 -3
- warp/fem/space/topology.py +15 -0
- warp/fem/space/trimesh_function_space.py +17 -2
- warp/fem/types.py +15 -0
- warp/fem/utils.py +27 -6
- warp/jax.py +28 -7
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +103 -6
- warp/native/array.h +28 -6
- warp/native/builtin.h +44 -9
- warp/native/bvh.cpp +18 -7
- warp/native/bvh.cu +57 -20
- warp/native/bvh.h +17 -7
- warp/native/clang/clang.cpp +45 -9
- warp/native/coloring.cpp +15 -6
- warp/native/crt.cpp +15 -6
- warp/native/crt.h +15 -6
- warp/native/cuda_crt.h +15 -6
- warp/native/cuda_util.cpp +29 -6
- warp/native/cuda_util.h +17 -6
- warp/native/error.cpp +15 -6
- warp/native/error.h +15 -6
- warp/native/exports.h +85 -63
- warp/native/fabric.h +15 -6
- warp/native/hashgrid.cpp +15 -6
- warp/native/hashgrid.cu +15 -6
- warp/native/hashgrid.h +15 -6
- warp/native/initializer_array.h +15 -6
- warp/native/intersect.h +41 -32
- warp/native/intersect_adj.h +48 -39
- warp/native/intersect_tri.h +17 -0
- warp/native/marching.cpp +16 -0
- warp/native/marching.cu +16 -7
- warp/native/marching.h +17 -0
- warp/native/mat.h +528 -15
- warp/native/mathdx.cpp +15 -6
- warp/native/matnn.h +15 -6
- warp/native/mesh.cpp +15 -6
- warp/native/mesh.cu +15 -6
- warp/native/mesh.h +25 -16
- warp/native/noise.h +15 -6
- warp/native/quat.h +114 -17
- warp/native/rand.h +21 -6
- warp/native/range.h +15 -6
- warp/native/reduce.cpp +15 -6
- warp/native/reduce.cu +15 -6
- warp/native/runlength_encode.cpp +15 -6
- warp/native/runlength_encode.cu +15 -6
- warp/native/scan.cpp +15 -6
- warp/native/scan.cu +15 -6
- warp/native/scan.h +15 -6
- warp/native/solid_angle.h +17 -0
- warp/native/sort.cpp +137 -65
- warp/native/sort.cu +167 -21
- warp/native/sort.h +23 -7
- warp/native/sparse.cpp +58 -28
- warp/native/sparse.cu +67 -23
- warp/native/spatial.h +15 -6
- warp/native/svd.h +131 -6
- warp/native/temp_buffer.h +15 -6
- warp/native/tile.h +316 -111
- warp/native/tile_reduce.h +61 -9
- warp/native/vec.h +83 -13
- warp/native/volume.cpp +100 -119
- warp/native/volume.cu +15 -6
- warp/native/volume.h +15 -6
- warp/native/volume_builder.cu +40 -16
- warp/native/volume_builder.h +21 -6
- warp/native/volume_impl.h +15 -6
- warp/native/warp.cpp +20 -12
- warp/native/warp.cu +114 -16
- warp/native/warp.h +34 -16
- warp/optim/__init__.py +14 -6
- warp/optim/adam.py +14 -6
- warp/optim/linear.py +25 -10
- warp/optim/sgd.py +14 -6
- warp/paddle.py +14 -6
- warp/render/__init__.py +14 -6
- warp/render/render_opengl.py +14 -6
- warp/render/render_usd.py +14 -6
- warp/render/utils.py +14 -6
- warp/sim/__init__.py +14 -7
- warp/sim/articulation.py +18 -10
- warp/sim/collide.py +35 -16
- warp/sim/graph_coloring.py +14 -6
- warp/sim/import_mjcf.py +463 -162
- warp/sim/import_snu.py +14 -7
- warp/sim/import_urdf.py +46 -18
- warp/sim/import_usd.py +14 -7
- warp/sim/inertia.py +14 -6
- warp/sim/integrator.py +14 -6
- warp/sim/integrator_euler.py +19 -11
- warp/sim/integrator_featherstone.py +17 -16
- warp/sim/integrator_vbd.py +222 -8
- warp/sim/integrator_xpbd.py +19 -11
- warp/sim/model.py +56 -19
- warp/sim/particles.py +14 -6
- warp/sim/render.py +14 -6
- warp/sim/utils.py +17 -2
- warp/sparse.py +657 -555
- warp/stubs.py +231 -19
- warp/tape.py +14 -6
- warp/tests/aux_test_class_kernel.py +14 -6
- warp/tests/aux_test_compile_consts_dummy.py +14 -6
- warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
- warp/tests/aux_test_dependent.py +14 -6
- warp/tests/aux_test_grad_customs.py +14 -6
- warp/tests/aux_test_instancing_gc.py +14 -6
- warp/tests/aux_test_module_unload.py +14 -6
- warp/tests/aux_test_name_clash1.py +14 -6
- warp/tests/aux_test_name_clash2.py +14 -6
- warp/tests/aux_test_unresolved_func.py +14 -6
- warp/tests/aux_test_unresolved_symbol.py +14 -6
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_async.py → cuda/test_async.py} +14 -6
- warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
- warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
- warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
- warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
- warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
- warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
- warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
- warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
- warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
- warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
- warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
- warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
- warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
- warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
- warp/tests/run_coverage_serial.py +14 -6
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
- warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
- warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
- warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
- warp/tests/{test_model.py → sim/test_model.py} +55 -7
- warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +14 -6
- warp/tests/test_arithmetic.py +14 -6
- warp/tests/test_array.py +14 -6
- warp/tests/test_array_reduce.py +14 -6
- warp/tests/test_assert.py +14 -6
- warp/tests/test_atomic.py +14 -6
- warp/tests/test_bool.py +15 -7
- warp/tests/test_builtins_resolution.py +14 -6
- warp/tests/test_closest_point_edge_edge.py +14 -6
- warp/tests/test_codegen.py +14 -6
- warp/tests/test_codegen_instancing.py +14 -6
- warp/tests/test_compile_consts.py +14 -6
- warp/tests/test_conditional.py +14 -6
- warp/tests/test_context.py +14 -6
- warp/tests/test_copy.py +14 -6
- warp/tests/test_ctypes.py +14 -6
- warp/tests/test_dense.py +14 -6
- warp/tests/test_devices.py +14 -6
- warp/tests/test_examples.py +42 -42
- warp/tests/test_fabricarray.py +14 -6
- warp/tests/test_fast_math.py +14 -6
- warp/tests/test_fem.py +37 -10
- warp/tests/test_fp16.py +14 -6
- warp/tests/test_func.py +14 -6
- warp/tests/test_future_annotations.py +14 -6
- warp/tests/test_generics.py +14 -6
- warp/tests/test_grad.py +14 -6
- warp/tests/test_grad_customs.py +14 -6
- warp/tests/test_grad_debug.py +14 -6
- warp/tests/test_implicit_init.py +14 -6
- warp/tests/test_import.py +14 -6
- warp/tests/test_indexedarray.py +14 -6
- warp/tests/test_intersect.py +14 -6
- warp/tests/test_iter.py +14 -6
- warp/tests/test_large.py +14 -6
- warp/tests/test_launch.py +14 -6
- warp/tests/test_lerp.py +14 -6
- warp/tests/test_linear_solvers.py +15 -11
- warp/tests/test_lvalue.py +14 -6
- warp/tests/test_mat.py +247 -85
- warp/tests/test_mat_lite.py +14 -6
- warp/tests/test_mat_scalar_ops.py +18 -10
- warp/tests/test_math.py +14 -6
- warp/tests/test_mlp.py +14 -6
- warp/tests/test_module_hashing.py +14 -6
- warp/tests/test_modules_lite.py +14 -6
- warp/tests/test_noise.py +14 -6
- warp/tests/test_operators.py +14 -6
- warp/tests/test_options.py +14 -6
- warp/tests/test_overwrite.py +15 -60
- warp/tests/test_print.py +14 -6
- warp/tests/test_quat.py +81 -52
- warp/tests/test_rand.py +58 -43
- warp/tests/test_reload.py +14 -6
- warp/tests/test_rounding.py +14 -6
- warp/tests/test_runlength_encode.py +14 -6
- warp/tests/test_scalar_ops.py +14 -6
- warp/tests/test_smoothstep.py +14 -6
- warp/tests/test_snippet.py +15 -0
- warp/tests/test_sparse.py +61 -12
- warp/tests/test_spatial.py +89 -6
- warp/tests/test_special_values.py +14 -6
- warp/tests/test_static.py +15 -7
- warp/tests/test_struct.py +14 -6
- warp/tests/test_tape.py +14 -6
- warp/tests/test_transient_module.py +14 -6
- warp/tests/test_triangle_closest_point.py +14 -6
- warp/tests/test_types.py +14 -6
- warp/tests/test_utils.py +98 -10
- warp/tests/test_vec.py +60 -40
- warp/tests/test_vec_lite.py +14 -6
- warp/tests/test_vec_scalar_ops.py +14 -6
- warp/tests/test_verify_fp.py +14 -6
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
- warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
- warp/tests/unittest_serial.py +15 -6
- warp/tests/unittest_suites.py +59 -65
- warp/tests/unittest_utils.py +16 -7
- warp/tests/walkthrough_debug.py +14 -6
- warp/thirdparty/unittest_parallel.py +15 -8
- warp/torch.py +14 -6
- warp/types.py +124 -664
- warp/utils.py +151 -78
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp/examples/optim/example_walker.py +0 -309
- warp/native/cutlass_gemm.cpp +0 -34
- warp/native/cutlass_gemm.cu +0 -373
- warp/tests/test_matmul.py +0 -503
- warp/tests/test_matmul_lite.py +0 -403
- warp/tests/test_vbd.py +0 -378
- warp/tests/unused_test_misc.py +0 -69
- warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
- warp_lang-1.6.1.dist-info/RECORD +0 -419
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/native/tile.h
CHANGED
|
@@ -1,18 +1,57 @@
|
|
|
1
|
-
|
|
2
|
-
* NVIDIA CORPORATION
|
|
3
|
-
*
|
|
4
|
-
*
|
|
5
|
-
*
|
|
6
|
-
*
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
7
16
|
*/
|
|
8
17
|
|
|
9
18
|
#pragma once
|
|
10
19
|
|
|
11
20
|
#include "builtin.h"
|
|
12
21
|
|
|
22
|
+
#ifdef __clang__
|
|
23
|
+
// disable warnings related to C++17 extensions on CPU JIT builds
|
|
24
|
+
#pragma clang diagnostic push
|
|
25
|
+
#pragma clang diagnostic ignored "-Wc++17-extensions"
|
|
26
|
+
#endif // __clang__
|
|
27
|
+
|
|
28
|
+
// Check if the CUDA toolkit is available
|
|
29
|
+
#if WP_ENABLE_CUDA || defined(__CUDACC_RTC__)
|
|
30
|
+
|
|
31
|
+
// If NVRTC is being used, do not include extra headers (NVRTC has built-in float4)
|
|
32
|
+
#ifdef __CUDACC_RTC__
|
|
33
|
+
// NVRTC: Use built-in float4 (no need for extra definitions)
|
|
34
|
+
#else
|
|
35
|
+
// NVCC: Include vector_types.h to get float4
|
|
36
|
+
#include <cuda_runtime.h>
|
|
37
|
+
#endif
|
|
38
|
+
|
|
39
|
+
#else
|
|
40
|
+
// If CUDA is not available (e.g., macOS build), manually define float4
|
|
41
|
+
struct alignas(16) float4 {
|
|
42
|
+
float x, y, z, w;
|
|
43
|
+
};
|
|
44
|
+
#endif
|
|
45
|
+
|
|
46
|
+
// only used while building the warp core library
|
|
47
|
+
#ifndef WP_TILE_BLOCK_DIM
|
|
48
|
+
#define WP_TILE_BLOCK_DIM 256
|
|
49
|
+
#endif
|
|
50
|
+
|
|
13
51
|
#if !defined(__CUDA_ARCH__)
|
|
14
52
|
#define WP_TILE_SHARED static
|
|
15
53
|
#define WP_TILE_SYNC void
|
|
54
|
+
|
|
16
55
|
#else
|
|
17
56
|
#define WP_TILE_SHARED __shared__
|
|
18
57
|
#define WP_TILE_SYNC __syncthreads
|
|
@@ -37,6 +76,14 @@
|
|
|
37
76
|
#define WP_USE_ASYNC_PIPELINE 0
|
|
38
77
|
#define WP_USE_REGISTER_GEMM 0
|
|
39
78
|
|
|
79
|
+
#if defined(__CUDACC_RTC__)
|
|
80
|
+
#define WP_TILE_THREAD_IDX threadIdx.x
|
|
81
|
+
#else
|
|
82
|
+
#define WP_TILE_THREAD_IDX 0
|
|
83
|
+
#endif //
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
|
|
40
87
|
/* Tile Expressions
|
|
41
88
|
|
|
42
89
|
[ ] Tiles
|
|
@@ -208,14 +255,14 @@ constexpr tile_coord_t<sizeof...(Ints)> tile_coord(Ints... idxs)
|
|
|
208
255
|
}
|
|
209
256
|
|
|
210
257
|
// helpers to construct a coord from a set of indices
|
|
211
|
-
auto tile_coord(int i)
|
|
258
|
+
inline auto tile_coord(int i)
|
|
212
259
|
{
|
|
213
260
|
auto c = tile_coord_t<1>();
|
|
214
261
|
c.indices[0] = i;
|
|
215
262
|
return c;
|
|
216
263
|
}
|
|
217
264
|
|
|
218
|
-
auto tile_coord(int i, int j)
|
|
265
|
+
inline auto tile_coord(int i, int j)
|
|
219
266
|
{
|
|
220
267
|
auto c = tile_coord_t<2>();
|
|
221
268
|
c.indices[0] = i;
|
|
@@ -223,7 +270,7 @@ auto tile_coord(int i, int j)
|
|
|
223
270
|
return c;
|
|
224
271
|
}
|
|
225
272
|
|
|
226
|
-
auto tile_coord(int i, int j, int k)
|
|
273
|
+
inline auto tile_coord(int i, int j, int k)
|
|
227
274
|
{
|
|
228
275
|
auto c = tile_coord_t<3>();
|
|
229
276
|
c.indices[0] = i;
|
|
@@ -232,7 +279,7 @@ auto tile_coord(int i, int j, int k)
|
|
|
232
279
|
return c;
|
|
233
280
|
}
|
|
234
281
|
|
|
235
|
-
auto tile_coord(int i, int j, int k, int l)
|
|
282
|
+
inline auto tile_coord(int i, int j, int k, int l)
|
|
236
283
|
{
|
|
237
284
|
auto c = tile_coord_t<4>();
|
|
238
285
|
c.indices[0] = i;
|
|
@@ -247,7 +294,7 @@ template <int... V>
|
|
|
247
294
|
struct tile_tuple_t
|
|
248
295
|
{
|
|
249
296
|
static constexpr int N = sizeof...(V);
|
|
250
|
-
static_assert(N > 0);
|
|
297
|
+
static_assert(N > 0, "Expected N > 0");
|
|
251
298
|
|
|
252
299
|
static constexpr int data[N] = { V... };
|
|
253
300
|
|
|
@@ -400,7 +447,7 @@ struct tile_layout_register_t
|
|
|
400
447
|
|
|
401
448
|
static inline CUDA_CALLABLE int linear_from_register(int reg)
|
|
402
449
|
{
|
|
403
|
-
return
|
|
450
|
+
return WP_TILE_THREAD_IDX + reg*WP_TILE_BLOCK_DIM;
|
|
404
451
|
}
|
|
405
452
|
|
|
406
453
|
static inline CUDA_CALLABLE int linear_from_coord(Coord c)
|
|
@@ -500,15 +547,6 @@ struct tile_register_t
|
|
|
500
547
|
return data[reg];
|
|
501
548
|
}
|
|
502
549
|
|
|
503
|
-
// Returns the number of valid registers for this tile
|
|
504
|
-
// i.e.: how many registers map to a valid coordinate.
|
|
505
|
-
// When a tile's size is not aligned to the block dimension
|
|
506
|
-
// some of the trailing registers may lie outside the valid range
|
|
507
|
-
inline CUDA_CALLABLE int valid() const
|
|
508
|
-
{
|
|
509
|
-
return (int)floor(float(Size - threadIdx.x - 1)/WP_TILE_BLOCK_DIM) + 1;
|
|
510
|
-
}
|
|
511
|
-
|
|
512
550
|
inline CUDA_CALLABLE void assign(const tile_register_t<T, Layout>& tile)
|
|
513
551
|
{
|
|
514
552
|
for (int i=0; i < Layout::NumRegs; ++i)
|
|
@@ -535,7 +573,7 @@ struct tile_register_t
|
|
|
535
573
|
// ensure any previously scheduled threads have finished reading from scratch
|
|
536
574
|
WP_TILE_SYNC();
|
|
537
575
|
|
|
538
|
-
if (
|
|
576
|
+
if (WP_TILE_THREAD_IDX == thread)
|
|
539
577
|
{
|
|
540
578
|
scratch = data[reg];
|
|
541
579
|
}
|
|
@@ -556,7 +594,7 @@ struct tile_register_t
|
|
|
556
594
|
const int thread = Layout::thread_from_linear(linear);
|
|
557
595
|
const int reg = Layout::register_from_linear(linear);
|
|
558
596
|
|
|
559
|
-
if (
|
|
597
|
+
if (WP_TILE_THREAD_IDX == thread)
|
|
560
598
|
{
|
|
561
599
|
data[reg] += adj_ret;
|
|
562
600
|
}
|
|
@@ -659,7 +697,7 @@ struct tile_register_t
|
|
|
659
697
|
// users can either specify a template explicitly or
|
|
660
698
|
// pass in another concrete instance
|
|
661
699
|
template<typename Tile>
|
|
662
|
-
auto tile_register_like(Tile* t=
|
|
700
|
+
auto tile_register_like(Tile* t=nullptr)
|
|
663
701
|
{
|
|
664
702
|
using T = typename Tile::Type;
|
|
665
703
|
using L = typename Tile::Layout;
|
|
@@ -685,26 +723,39 @@ inline CUDA_CALLABLE int tile_align(int num_bytes)
|
|
|
685
723
|
return sign * ((num_bytes_abs + alignment - 1) / alignment) * alignment;
|
|
686
724
|
}
|
|
687
725
|
|
|
688
|
-
inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false)
|
|
726
|
+
inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false, bool check=false)
|
|
689
727
|
{
|
|
690
728
|
// we maintain a per-thread offset into dynamic
|
|
691
729
|
// shared memory that allows us to keep track of
|
|
692
730
|
// current use across dynamic function calls
|
|
693
|
-
|
|
731
|
+
WP_TILE_SHARED int smem_base[WP_TILE_BLOCK_DIM];
|
|
694
732
|
|
|
695
733
|
if (init)
|
|
696
734
|
{
|
|
697
|
-
smem_base[
|
|
698
|
-
return
|
|
735
|
+
smem_base[WP_TILE_THREAD_IDX] = 0;
|
|
736
|
+
return nullptr;
|
|
737
|
+
}
|
|
738
|
+
else if (check)
|
|
739
|
+
{
|
|
740
|
+
assert(smem_base[WP_TILE_THREAD_IDX] == 0);
|
|
741
|
+
return nullptr;
|
|
699
742
|
}
|
|
700
743
|
else
|
|
701
744
|
{
|
|
702
|
-
const int offset = smem_base[
|
|
745
|
+
const int offset = smem_base[WP_TILE_THREAD_IDX];
|
|
703
746
|
|
|
704
747
|
// one entry per-thread so no need for synchronization
|
|
705
|
-
smem_base[
|
|
748
|
+
smem_base[WP_TILE_THREAD_IDX] += tile_align(num_bytes);
|
|
706
749
|
|
|
750
|
+
#ifdef __CUDA_ARCH__
|
|
707
751
|
extern __shared__ char dynamic_smem_base[];
|
|
752
|
+
#else
|
|
753
|
+
// on CPU allocate a fixed 256k block to use for shared allocs
|
|
754
|
+
static const int max_cpu_shared = 256*1024;
|
|
755
|
+
static char dynamic_smem_base[max_cpu_shared];
|
|
756
|
+
|
|
757
|
+
assert(smem_base[WP_TILE_THREAD_IDX] <= max_cpu_shared);
|
|
758
|
+
#endif
|
|
708
759
|
return &(dynamic_smem_base[offset]);
|
|
709
760
|
}
|
|
710
761
|
}
|
|
@@ -838,12 +889,12 @@ struct tile_shared_t
|
|
|
838
889
|
bool initialized;
|
|
839
890
|
|
|
840
891
|
// default initialization (non-initialized)
|
|
841
|
-
inline CUDA_CALLABLE tile_shared_t() : data(
|
|
892
|
+
inline CUDA_CALLABLE tile_shared_t() : data(nullptr), grad(nullptr), initialized(false)
|
|
842
893
|
{
|
|
843
894
|
}
|
|
844
895
|
|
|
845
896
|
// initialize from an existing tile's memory
|
|
846
|
-
inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=
|
|
897
|
+
inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=nullptr, bool initialized=true) : data(data), grad(grad), initialized(initialized)
|
|
847
898
|
{
|
|
848
899
|
}
|
|
849
900
|
|
|
@@ -869,6 +920,7 @@ struct tile_shared_t
|
|
|
869
920
|
}
|
|
870
921
|
|
|
871
922
|
|
|
923
|
+
/*
|
|
872
924
|
// construct from another shared tile, this constructor
|
|
873
925
|
// is invoked for reshape operations like `wp.tile_transpose()`
|
|
874
926
|
template <typename OtherT, typename OtherLayout>
|
|
@@ -877,7 +929,7 @@ struct tile_shared_t
|
|
|
877
929
|
using OtherTile = tile_shared_t<OtherT, OtherLayout>;
|
|
878
930
|
|
|
879
931
|
// check dimensions are compatible
|
|
880
|
-
static_assert(Size == OtherTile::Size);
|
|
932
|
+
static_assert(Size == OtherTile::Size, "Expected Size == OtherTile::Size");
|
|
881
933
|
|
|
882
934
|
// alias tile directly
|
|
883
935
|
data = rhs.data;
|
|
@@ -886,6 +938,7 @@ struct tile_shared_t
|
|
|
886
938
|
|
|
887
939
|
return *this;
|
|
888
940
|
}
|
|
941
|
+
*/
|
|
889
942
|
|
|
890
943
|
// assign from a global tile (load)
|
|
891
944
|
inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape>& t)
|
|
@@ -903,7 +956,7 @@ struct tile_shared_t
|
|
|
903
956
|
if (initialized)
|
|
904
957
|
WP_TILE_SYNC();
|
|
905
958
|
|
|
906
|
-
for (int i=
|
|
959
|
+
for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
|
|
907
960
|
data(i) = x;
|
|
908
961
|
|
|
909
962
|
initialized = true;
|
|
@@ -914,7 +967,7 @@ struct tile_shared_t
|
|
|
914
967
|
// in-place zero
|
|
915
968
|
inline CUDA_CALLABLE void zero()
|
|
916
969
|
{
|
|
917
|
-
for (int i=
|
|
970
|
+
for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
|
|
918
971
|
data(i) = T(0);
|
|
919
972
|
|
|
920
973
|
WP_TILE_SYNC();
|
|
@@ -964,7 +1017,7 @@ struct tile_shared_t
|
|
|
964
1017
|
// in-place gradient zero
|
|
965
1018
|
inline CUDA_CALLABLE void grad_zero()
|
|
966
1019
|
{
|
|
967
|
-
for (int i=
|
|
1020
|
+
for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
|
|
968
1021
|
grad(i) = T(0);
|
|
969
1022
|
|
|
970
1023
|
WP_TILE_SYNC();
|
|
@@ -1004,7 +1057,7 @@ struct tile_shared_t
|
|
|
1004
1057
|
CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
|
|
1005
1058
|
{
|
|
1006
1059
|
WP_PRAGMA_UNROLL
|
|
1007
|
-
for (int i=
|
|
1060
|
+
for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
|
|
1008
1061
|
{
|
|
1009
1062
|
auto c = Layout::coord_from_linear(i);
|
|
1010
1063
|
T g = global.load_grad(c);
|
|
@@ -1072,6 +1125,8 @@ struct tile_shared_t
|
|
|
1072
1125
|
template <typename Global>
|
|
1073
1126
|
inline CUDA_CALLABLE void copy_to_global(const Global& dest)
|
|
1074
1127
|
{
|
|
1128
|
+
|
|
1129
|
+
#if defined(__CUDA_ARCH__)
|
|
1075
1130
|
// vectorized loads for specific input/output shapes
|
|
1076
1131
|
if constexpr (Layout::Shape::N == 2)
|
|
1077
1132
|
{
|
|
@@ -1100,7 +1155,7 @@ struct tile_shared_t
|
|
|
1100
1155
|
const int stride_j = 1;
|
|
1101
1156
|
|
|
1102
1157
|
WP_PRAGMA_UNROLL
|
|
1103
|
-
for (int i=
|
|
1158
|
+
for (int i=WP_TILE_THREAD_IDX; i < SrcLayout::Size; i += WP_TILE_BLOCK_DIM)
|
|
1104
1159
|
{
|
|
1105
1160
|
auto c = SrcLayout::coord_from_linear(i);
|
|
1106
1161
|
|
|
@@ -1111,17 +1166,18 @@ struct tile_shared_t
|
|
|
1111
1166
|
}
|
|
1112
1167
|
}
|
|
1113
1168
|
|
|
1169
|
+
#endif //defined(__CUDA_ARCH__)
|
|
1170
|
+
|
|
1114
1171
|
// scalar bounds checked path
|
|
1115
1172
|
WP_PRAGMA_UNROLL
|
|
1116
|
-
for (int i=
|
|
1173
|
+
for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
|
|
1117
1174
|
{
|
|
1118
1175
|
auto c = Layout::coord_from_linear(i);
|
|
1119
1176
|
dest.store(c, data(i));
|
|
1120
1177
|
}
|
|
1121
1178
|
}
|
|
1122
1179
|
|
|
1123
|
-
|
|
1124
|
-
void cp_async_global_to_shared_128(float4* shared_dest, const float4* global_src)
|
|
1180
|
+
inline CUDA_CALLABLE void cp_async_global_to_shared_128(float4* shared_dest, const float4* global_src)
|
|
1125
1181
|
{
|
|
1126
1182
|
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
|
1127
1183
|
|
|
@@ -1143,8 +1199,7 @@ struct tile_shared_t
|
|
|
1143
1199
|
#endif
|
|
1144
1200
|
}
|
|
1145
1201
|
|
|
1146
|
-
|
|
1147
|
-
void cp_async_commit_and_wait_all_128()
|
|
1202
|
+
inline CUDA_CALLABLE void cp_async_commit_and_wait_all_128()
|
|
1148
1203
|
{
|
|
1149
1204
|
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
|
1150
1205
|
asm volatile(
|
|
@@ -1159,6 +1214,8 @@ struct tile_shared_t
|
|
|
1159
1214
|
if (initialized)
|
|
1160
1215
|
WP_TILE_SYNC();
|
|
1161
1216
|
|
|
1217
|
+
#if defined(__CUDA_ARCH__)
|
|
1218
|
+
|
|
1162
1219
|
// vectorized loads for specific input/output shapes
|
|
1163
1220
|
if constexpr (Layout::Shape::N == 2)
|
|
1164
1221
|
{
|
|
@@ -1187,7 +1244,7 @@ struct tile_shared_t
|
|
|
1187
1244
|
const int stride_j = 1;
|
|
1188
1245
|
|
|
1189
1246
|
WP_PRAGMA_UNROLL
|
|
1190
|
-
for (int i=
|
|
1247
|
+
for (int i=WP_TILE_THREAD_IDX; i < DestLayout::Size; i += WP_TILE_BLOCK_DIM)
|
|
1191
1248
|
{
|
|
1192
1249
|
auto c = DestLayout::coord_from_linear(i);
|
|
1193
1250
|
|
|
@@ -1208,9 +1265,11 @@ struct tile_shared_t
|
|
|
1208
1265
|
}
|
|
1209
1266
|
}
|
|
1210
1267
|
|
|
1268
|
+
#endif //defined(__CUDA_ARCH__)
|
|
1269
|
+
|
|
1211
1270
|
// scalar bounds checked path
|
|
1212
1271
|
WP_PRAGMA_UNROLL
|
|
1213
|
-
for (int i=
|
|
1272
|
+
for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
|
|
1214
1273
|
{
|
|
1215
1274
|
auto c = Layout::coord_from_linear(i);
|
|
1216
1275
|
data(i) = src.load(c);
|
|
@@ -1323,7 +1382,7 @@ struct tile_shared_t
|
|
|
1323
1382
|
|
|
1324
1383
|
inline CUDA_CALLABLE void print(bool reverse=false) const
|
|
1325
1384
|
{
|
|
1326
|
-
if (
|
|
1385
|
+
if (WP_TILE_THREAD_IDX != 0)
|
|
1327
1386
|
return;
|
|
1328
1387
|
|
|
1329
1388
|
if (reverse)
|
|
@@ -1350,13 +1409,13 @@ void tile_register_t<T, L>::print() const
|
|
|
1350
1409
|
// create a temporary shared tile so that
|
|
1351
1410
|
// we can print it deterministically
|
|
1352
1411
|
WP_TILE_SHARED T smem[L::Size];
|
|
1353
|
-
tile_shared_t<T, tile_layout_strided_t<typename L::Shape
|
|
1412
|
+
tile_shared_t<T, tile_layout_strided_t<typename L::Shape>, false> scratch(smem, nullptr);
|
|
1354
1413
|
|
|
1355
1414
|
scratch.assign(*this);
|
|
1356
1415
|
|
|
1357
1416
|
WP_TILE_SYNC();
|
|
1358
1417
|
|
|
1359
|
-
if (
|
|
1418
|
+
if (WP_TILE_THREAD_IDX == 0)
|
|
1360
1419
|
{
|
|
1361
1420
|
scratch.print_values(scratch.data, 0);
|
|
1362
1421
|
|
|
@@ -1383,7 +1442,7 @@ inline CUDA_CALLABLE void print(const tile_shared_t<T, L, Owner>& t) { t.print()
|
|
|
1383
1442
|
template <typename T, typename L, bool O>
|
|
1384
1443
|
inline CUDA_CALLABLE int len(const tile_shared_t<T, L, O>& t)
|
|
1385
1444
|
{
|
|
1386
|
-
return
|
|
1445
|
+
return L::Shape::dim(0);
|
|
1387
1446
|
}
|
|
1388
1447
|
|
|
1389
1448
|
template <typename T, typename L, bool O, typename AdjTile>
|
|
@@ -1394,7 +1453,7 @@ inline CUDA_CALLABLE void adj_len(const tile_shared_t<T,L,O>& t, const AdjTile&
|
|
|
1394
1453
|
template <typename T, typename L>
|
|
1395
1454
|
inline CUDA_CALLABLE int len(const tile_register_t<T, L>& t)
|
|
1396
1455
|
{
|
|
1397
|
-
return
|
|
1456
|
+
return L::Shape::dim(0);
|
|
1398
1457
|
}
|
|
1399
1458
|
|
|
1400
1459
|
template <typename T, typename L, typename AdjTile>
|
|
@@ -1416,12 +1475,16 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
|
|
|
1416
1475
|
|
|
1417
1476
|
{ constexpr int size = Shape::size();
|
|
1418
1477
|
T* data = (T*)tile_alloc_shared(size*sizeof(T));
|
|
1419
|
-
T* grad =
|
|
1478
|
+
T* grad = nullptr;
|
|
1420
1479
|
|
|
1421
1480
|
#if FP_CHECK
|
|
1422
1481
|
|
|
1423
|
-
|
|
1424
|
-
|
|
1482
|
+
// initialize tile to quiet nan
|
|
1483
|
+
uint32_t qnanbits = 0x7FC00000;
|
|
1484
|
+
float qnan = *(float*)(&qnanbits);
|
|
1485
|
+
|
|
1486
|
+
for (int i=WP_TILE_THREAD_IDX; i < size; i+= WP_TILE_BLOCK_DIM)
|
|
1487
|
+
data[i] = T(qnan);
|
|
1425
1488
|
|
|
1426
1489
|
WP_TILE_SYNC();
|
|
1427
1490
|
|
|
@@ -1432,7 +1495,7 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
|
|
|
1432
1495
|
{
|
|
1433
1496
|
grad = (T*)tile_alloc_shared(size*sizeof(T));
|
|
1434
1497
|
|
|
1435
|
-
for (int i=
|
|
1498
|
+
for (int i=WP_TILE_THREAD_IDX; i < size; i+= WP_TILE_BLOCK_DIM)
|
|
1436
1499
|
grad[i] = T(0);
|
|
1437
1500
|
|
|
1438
1501
|
WP_TILE_SYNC();
|
|
@@ -1441,30 +1504,6 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
|
|
|
1441
1504
|
return tile_shared_t<T, tile_layout_strided_t<Shape>>(data, grad);
|
|
1442
1505
|
}
|
|
1443
1506
|
|
|
1444
|
-
template <typename T, int M, int N, bool RequiresGrad>
|
|
1445
|
-
inline CUDA_CALLABLE auto tile_alloc_zeros()
|
|
1446
|
-
{
|
|
1447
|
-
// compute the total storage required for the tile (may be different from M*N) for broadcast tiles
|
|
1448
|
-
constexpr int Len = M*N;
|
|
1449
|
-
T* data = (T*)tile_alloc_shared(Len*sizeof(T));
|
|
1450
|
-
T* grad = NULL;
|
|
1451
|
-
|
|
1452
|
-
for (int i=threadIdx.x; i < Len; i+= WP_TILE_BLOCK_DIM)
|
|
1453
|
-
data[i] = T(0);
|
|
1454
|
-
|
|
1455
|
-
if (RequiresGrad)
|
|
1456
|
-
{
|
|
1457
|
-
grad = (T*)tile_alloc_shared(Len*sizeof(T));
|
|
1458
|
-
|
|
1459
|
-
for (int i=threadIdx.x; i < Len; i+= WP_TILE_BLOCK_DIM)
|
|
1460
|
-
grad[i] = T(0);
|
|
1461
|
-
}
|
|
1462
|
-
|
|
1463
|
-
WP_TILE_SYNC();
|
|
1464
|
-
|
|
1465
|
-
return tile_shared_t<T, tile_layout_strided_t<tile_shape_t<M, N>>(data, grad);
|
|
1466
|
-
}
|
|
1467
|
-
|
|
1468
1507
|
|
|
1469
1508
|
//-----------------------------------------------------------------------------------------------------
|
|
1470
1509
|
// High level entry points for each op (correspond to one Warp builtin)
|
|
@@ -1476,7 +1515,7 @@ inline CUDA_CALLABLE auto tile(const T& x)
|
|
|
1476
1515
|
tile_register_t<T, tile_layout_register_t<tile_shape_t<WP_TILE_BLOCK_DIM>>> result;
|
|
1477
1516
|
|
|
1478
1517
|
using Layout = typename decltype(result)::Layout;
|
|
1479
|
-
static_assert(Layout::NumRegs == 1);
|
|
1518
|
+
static_assert(Layout::NumRegs == 1, "Expected Layout::NumRegs == 1");
|
|
1480
1519
|
|
|
1481
1520
|
result.data[0] = x;
|
|
1482
1521
|
return result;
|
|
@@ -1489,7 +1528,7 @@ inline CUDA_CALLABLE auto tile(const wp::vec_t<Length, T>& x)
|
|
|
1489
1528
|
tile_register_t<T, tile_layout_register_t<tile_shape_t<Length, WP_TILE_BLOCK_DIM>>> result;
|
|
1490
1529
|
|
|
1491
1530
|
using Layout = typename decltype(result)::Layout;
|
|
1492
|
-
static_assert(Layout::NumRegs == Length);
|
|
1531
|
+
static_assert(Layout::NumRegs == Length, "Expected Layout::NumRegs == Length");
|
|
1493
1532
|
|
|
1494
1533
|
for (int i=0; i < Length; ++i)
|
|
1495
1534
|
result.data[i] = x[i];
|
|
@@ -1501,8 +1540,8 @@ inline CUDA_CALLABLE auto tile(const wp::vec_t<Length, T>& x)
|
|
|
1501
1540
|
template <typename T, typename AdjTile>
|
|
1502
1541
|
inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
|
|
1503
1542
|
{
|
|
1504
|
-
static_assert(AdjTile::Layout::Shape::N == 1);
|
|
1505
|
-
static_assert(AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM);
|
|
1543
|
+
static_assert(AdjTile::Layout::Shape::N == 1, "Expected AdjTile::Layout::Shape::N == 1");
|
|
1544
|
+
static_assert(AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM, "Expected AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM");
|
|
1506
1545
|
|
|
1507
1546
|
auto adj_reg = adj_ret.copy_to_register();
|
|
1508
1547
|
|
|
@@ -1512,9 +1551,9 @@ inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
|
|
|
1512
1551
|
template <typename T, unsigned Length, typename AdjTile>
|
|
1513
1552
|
inline CUDA_CALLABLE void adj_tile(const wp::vec_t<Length, T>& x, wp::vec_t<Length, T>& adj_x, AdjTile& adj_ret)
|
|
1514
1553
|
{
|
|
1515
|
-
static_assert(AdjTile::Layout::Shape::N == 2);
|
|
1516
|
-
static_assert(AdjTile::Layout::Shape::dim(0) == Length);
|
|
1517
|
-
static_assert(AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM);
|
|
1554
|
+
static_assert(AdjTile::Layout::Shape::N == 2, "Expected AdjTile::Layout::Shape::N == 2");
|
|
1555
|
+
static_assert(AdjTile::Layout::Shape::dim(0) == Length, "Expected AdjTile::Layout::Shape::dim(0) == Length");
|
|
1556
|
+
static_assert(AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM, "Expected AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM");
|
|
1518
1557
|
|
|
1519
1558
|
auto adj_reg = adj_ret.copy_to_register();
|
|
1520
1559
|
|
|
@@ -1692,7 +1731,7 @@ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, Coord c, Tile& t, arr
|
|
|
1692
1731
|
if (adj_dest.data)
|
|
1693
1732
|
src.data.grad = adj_dest.data;
|
|
1694
1733
|
|
|
1695
|
-
if (src.data.grad ==
|
|
1734
|
+
if (src.data.grad == nullptr)
|
|
1696
1735
|
return;
|
|
1697
1736
|
|
|
1698
1737
|
adj_t.grad_add(src);
|
|
@@ -1927,7 +1966,6 @@ void adj_tile_extract(Tile& t, int i, int j, int k, AdjTile& adj_t, int adj_i, i
|
|
|
1927
1966
|
template<typename Tile, typename AdjTile>
|
|
1928
1967
|
void adj_tile_extract(Tile& t, int i, int j, int k, int l, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j, k, l), adj_ret); }
|
|
1929
1968
|
|
|
1930
|
-
#if WP_USE_REGISTER_GEMM
|
|
1931
1969
|
|
|
1932
1970
|
namespace partitioned_gemm
|
|
1933
1971
|
{
|
|
@@ -2033,9 +2071,11 @@ inline CUDA_CALLABLE void matmul(TileA& A, TileB& B, TileC& out)
|
|
|
2033
2071
|
auto B_tile = partition_t<TILE_K, TILE_N, TileB>(B);
|
|
2034
2072
|
auto C_tile = partition_t<TILE_M, TILE_N, TileC>(out);
|
|
2035
2073
|
|
|
2074
|
+
//static_assert(is_same<typename TileA::Type, typename TileB::Type>::value);
|
|
2075
|
+
|
|
2036
2076
|
const int length = partition_size(C_tile);
|
|
2037
2077
|
|
|
2038
|
-
for (int t=
|
|
2078
|
+
for (int t=WP_TILE_THREAD_IDX; t < length; t += WP_TILE_BLOCK_DIM)
|
|
2039
2079
|
{
|
|
2040
2080
|
int i, j;
|
|
2041
2081
|
partition_coord(C_tile, t, i, j);
|
|
@@ -2055,10 +2095,102 @@ inline CUDA_CALLABLE void matmul(TileA& A, TileB& B, TileC& out)
|
|
|
2055
2095
|
partition_store(C_tile, i, j, sum);
|
|
2056
2096
|
}
|
|
2057
2097
|
}
|
|
2058
|
-
|
|
2059
|
-
} // namespace partition_gemm
|
|
2060
2098
|
|
|
2061
|
-
|
|
2099
|
+
template <typename LayoutA, typename LayoutB, typename LayoutC, typename StorageA, typename StorageB, typename StorageC, typename T>
|
|
2100
|
+
inline CUDA_CALLABLE void scalar_matmul(const StorageA& A, const StorageB& B, StorageC& C, T scale)
|
|
2101
|
+
{
|
|
2102
|
+
for (int t=WP_TILE_THREAD_IDX; t < LayoutC::Size; t += WP_TILE_BLOCK_DIM)
|
|
2103
|
+
{
|
|
2104
|
+
auto coord = LayoutC::coord_from_linear(t);
|
|
2105
|
+
|
|
2106
|
+
int i = coord[0];
|
|
2107
|
+
int j = coord[1];
|
|
2108
|
+
|
|
2109
|
+
// accumulator
|
|
2110
|
+
auto sum = C(coord)*scale;
|
|
2111
|
+
|
|
2112
|
+
WP_PRAGMA_UNROLL
|
|
2113
|
+
for (int k=0; k < LayoutA::Shape::dim(1); k++)
|
|
2114
|
+
{
|
|
2115
|
+
const auto a = A(tile_coord(i, k));
|
|
2116
|
+
const auto b = B(tile_coord(k, j));
|
|
2117
|
+
|
|
2118
|
+
sum = muladd<decltype(sum)>(a, b, sum);
|
|
2119
|
+
}
|
|
2120
|
+
|
|
2121
|
+
C(coord) = sum;
|
|
2122
|
+
}
|
|
2123
|
+
}
|
|
2124
|
+
|
|
2125
|
+
template <typename TileA, typename TileL>
|
|
2126
|
+
inline CUDA_CALLABLE void scalar_cholesky(TileA& A, TileL& L)
|
|
2127
|
+
{
|
|
2128
|
+
using T = typename TileA::Type;
|
|
2129
|
+
constexpr int n = TileA::Layout::Shape::dim(1);
|
|
2130
|
+
|
|
2131
|
+
for (int j=0; j < n; ++j)
|
|
2132
|
+
{
|
|
2133
|
+
T s = A.data(tile_coord(j, j));
|
|
2134
|
+
|
|
2135
|
+
for (int k=0; k < j; ++k)
|
|
2136
|
+
{
|
|
2137
|
+
T r = L.data(tile_coord(j, k));
|
|
2138
|
+
s -= r * r;
|
|
2139
|
+
}
|
|
2140
|
+
|
|
2141
|
+
s = wp::sqrt(s);
|
|
2142
|
+
T invS = 1.0 / s;
|
|
2143
|
+
|
|
2144
|
+
L.data(tile_coord(j, j)) = s;
|
|
2145
|
+
|
|
2146
|
+
for (int i=j+1; i < n; ++i)
|
|
2147
|
+
{
|
|
2148
|
+
s = A.data(tile_coord(i, j));
|
|
2149
|
+
|
|
2150
|
+
for (int k=0; k < j; ++k)
|
|
2151
|
+
{
|
|
2152
|
+
s -= L.data(tile_coord(i, k)) * L.data(tile_coord(j, k));
|
|
2153
|
+
}
|
|
2154
|
+
|
|
2155
|
+
L.data(tile_coord(i, j)) = s * invS;
|
|
2156
|
+
}
|
|
2157
|
+
|
|
2158
|
+
// zero out upper triangular portion
|
|
2159
|
+
for (int k=j+1; k < n; ++k)
|
|
2160
|
+
{
|
|
2161
|
+
L.data(tile_coord(j,k)) = T(0.0);
|
|
2162
|
+
}
|
|
2163
|
+
}
|
|
2164
|
+
}
|
|
2165
|
+
|
|
2166
|
+
template <typename TileL, typename TileX, typename TileY>
|
|
2167
|
+
inline CUDA_CALLABLE void scalar_cholesky_solve(TileL& L, TileX& X, TileY& Y)
|
|
2168
|
+
{
|
|
2169
|
+
using T = typename TileL::Type;
|
|
2170
|
+
constexpr int n = TileL::Layout::Shape::dim(1);
|
|
2171
|
+
|
|
2172
|
+
for (int i=0; i < n; ++i)
|
|
2173
|
+
{
|
|
2174
|
+
T s = Y.data(tile_coord(i));
|
|
2175
|
+
|
|
2176
|
+
for (int j=0; j < i; ++j)
|
|
2177
|
+
s -= L.data(tile_coord(i,j)) * X.data(tile_coord(j));
|
|
2178
|
+
|
|
2179
|
+
X.data(tile_coord(i)) = s / L.data(tile_coord(i, i));
|
|
2180
|
+
}
|
|
2181
|
+
|
|
2182
|
+
for (int i=n-1; i >= 0; --i)
|
|
2183
|
+
{
|
|
2184
|
+
T s = X.data(tile_coord(i));
|
|
2185
|
+
|
|
2186
|
+
for (int j=i+1; j < n; ++j)
|
|
2187
|
+
s -= L.data(tile_coord(j, i)) * X.data(tile_coord(j));
|
|
2188
|
+
|
|
2189
|
+
X.data(tile_coord(i)) = s / L.data(tile_coord(i, i));
|
|
2190
|
+
}
|
|
2191
|
+
}
|
|
2192
|
+
|
|
2193
|
+
} // namespace partition_gemm
|
|
2062
2194
|
|
|
2063
2195
|
|
|
2064
2196
|
template <int Add, typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
|
|
@@ -2068,19 +2200,19 @@ TileC& tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, Ti
|
|
|
2068
2200
|
using ShapeB = typename TileB::Layout::Shape;
|
|
2069
2201
|
using ShapeC = typename TileC::Layout::Shape;
|
|
2070
2202
|
|
|
2071
|
-
static_assert(ShapeA::N == 2);
|
|
2072
|
-
static_assert(ShapeB::N == 2);
|
|
2073
|
-
static_assert(ShapeC::N == 2);
|
|
2203
|
+
static_assert(ShapeA::N == 2, "Expected ShapeA::N == 2");
|
|
2204
|
+
static_assert(ShapeB::N == 2, "Expected ShapeB::N == 2");
|
|
2205
|
+
static_assert(ShapeC::N == 2, "Expected ShapeC::N == 2");
|
|
2074
2206
|
|
|
2075
|
-
static_assert(ShapeA::dim(1) == ShapeB::dim(0));
|
|
2076
|
-
static_assert(ShapeC::dim(0) == ShapeA::dim(0));
|
|
2077
|
-
static_assert(ShapeC::dim(1) == ShapeB::dim(1));
|
|
2207
|
+
static_assert(ShapeA::dim(1) == ShapeB::dim(0), "Expected ShapeA::dim(1) == ShapeB::dim(0)");
|
|
2208
|
+
static_assert(ShapeC::dim(0) == ShapeA::dim(0), "Expected ShapeC::dim(0) == ShapeA::dim(0)");
|
|
2209
|
+
static_assert(ShapeC::dim(1) == ShapeB::dim(1), "Expected ShapeC::dim(1) == ShapeB::dim(1)");
|
|
2078
2210
|
|
|
2079
2211
|
|
|
2080
2212
|
using T = typename TileA::Type;
|
|
2081
2213
|
|
|
2082
|
-
#if
|
|
2083
|
-
partitioned_gemm::
|
|
2214
|
+
#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
|
|
2215
|
+
partitioned_gemm::scalar_matmul<typename TileA::Layout, typename TileB::Layout, typename TileC::Layout>(A.data, B.data, C.data, T(Add));
|
|
2084
2216
|
#else
|
|
2085
2217
|
fun_forward(T(1.0), A.data.ptr, B.data.ptr, T(Add), C.data.ptr);
|
|
2086
2218
|
#endif
|
|
@@ -2090,6 +2222,7 @@ TileC& tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, Ti
|
|
|
2090
2222
|
return C;
|
|
2091
2223
|
}
|
|
2092
2224
|
|
|
2225
|
+
|
|
2093
2226
|
// backward for the wp.tile_matmul(a, b, out) syntax
|
|
2094
2227
|
template <typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
|
|
2095
2228
|
void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C,
|
|
@@ -2097,8 +2230,17 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
|
|
|
2097
2230
|
{
|
|
2098
2231
|
using T = typename TileA::Type;
|
|
2099
2232
|
|
|
2233
|
+
#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
|
|
2234
|
+
auto At = tile_transpose(A);
|
|
2235
|
+
auto Bt = tile_transpose(B);
|
|
2236
|
+
|
|
2237
|
+
partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, T(1.0));
|
|
2238
|
+
partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, T(1.0));
|
|
2239
|
+
#else
|
|
2100
2240
|
fun_backward_A(T(1.0), adj_C.grad.ptr, B.data.ptr, T(1.0), adj_A.grad.ptr);
|
|
2101
2241
|
fun_backward_B(T(1.0), A.data.ptr, adj_C.grad.ptr, T(1.0), adj_B.grad.ptr);
|
|
2242
|
+
#endif
|
|
2243
|
+
|
|
2102
2244
|
WP_TILE_SYNC();
|
|
2103
2245
|
}
|
|
2104
2246
|
|
|
@@ -2109,11 +2251,30 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
|
|
|
2109
2251
|
{
|
|
2110
2252
|
using T = typename TileA::Type;
|
|
2111
2253
|
|
|
2254
|
+
#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
|
|
2255
|
+
auto At = tile_transpose(A);
|
|
2256
|
+
auto Bt = tile_transpose(B);
|
|
2257
|
+
|
|
2258
|
+
partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, T(1.0));
|
|
2259
|
+
partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, T(1.0));
|
|
2260
|
+
#else
|
|
2112
2261
|
fun_backward_A(T(1.0), adj_C.grad.ptr, B.data.ptr, T(1.0), adj_A.grad.ptr);
|
|
2113
2262
|
fun_backward_B(T(1.0), A.data.ptr, adj_C.grad.ptr, T(1.0), adj_B.grad.ptr);
|
|
2263
|
+
#endif
|
|
2264
|
+
|
|
2114
2265
|
WP_TILE_SYNC();
|
|
2115
2266
|
}
|
|
2116
2267
|
|
|
2268
|
+
#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
|
|
2269
|
+
|
|
2270
|
+
#define tile_fft()
|
|
2271
|
+
#define tile_ifft()
|
|
2272
|
+
|
|
2273
|
+
#define adj_tile_fft()
|
|
2274
|
+
#define adj_tile_ifft()
|
|
2275
|
+
|
|
2276
|
+
#else
|
|
2277
|
+
|
|
2117
2278
|
// TODO(lcambier): use a properly overaligned complex type that matches cuFFTDx's expectation
|
|
2118
2279
|
// and remove the need for __align__(16) dtypes data[...]
|
|
2119
2280
|
#define tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout) \
|
|
@@ -2149,12 +2310,21 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
|
|
|
2149
2310
|
tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, adj_Xinout); \
|
|
2150
2311
|
} while (0)
|
|
2151
2312
|
|
|
2313
|
+
#endif // !defined(__CUDA_ARCH__)
|
|
2314
|
+
|
|
2152
2315
|
template <typename Fwd, typename TileA, typename TileL>
|
|
2153
2316
|
TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
|
|
2154
2317
|
{
|
|
2155
2318
|
// Copy to L
|
|
2156
2319
|
L = A;
|
|
2157
2320
|
|
|
2321
|
+
#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
|
|
2322
|
+
|
|
2323
|
+
partitioned_gemm::scalar_cholesky(A, L);
|
|
2324
|
+
|
|
2325
|
+
#else
|
|
2326
|
+
|
|
2327
|
+
|
|
2158
2328
|
// Call cholesky on L
|
|
2159
2329
|
WP_TILE_SYNC();
|
|
2160
2330
|
|
|
@@ -2165,7 +2335,7 @@ TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
|
|
|
2165
2335
|
// Zero-out the upper triangular part of L
|
|
2166
2336
|
|
|
2167
2337
|
WP_PRAGMA_UNROLL
|
|
2168
|
-
for (int i=
|
|
2338
|
+
for (int i=WP_TILE_THREAD_IDX; i < TileL::Layout::Size; i += WP_TILE_BLOCK_DIM)
|
|
2169
2339
|
{
|
|
2170
2340
|
auto c = TileL::Layout::coord_from_linear(i);
|
|
2171
2341
|
|
|
@@ -2174,7 +2344,9 @@ TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
|
|
|
2174
2344
|
}
|
|
2175
2345
|
|
|
2176
2346
|
WP_TILE_SYNC();
|
|
2177
|
-
|
|
2347
|
+
|
|
2348
|
+
#endif
|
|
2349
|
+
|
|
2178
2350
|
return L;
|
|
2179
2351
|
}
|
|
2180
2352
|
|
|
@@ -2191,6 +2363,12 @@ TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y)
|
|
|
2191
2363
|
|
|
2192
2364
|
Y = X;
|
|
2193
2365
|
|
|
2366
|
+
#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
|
|
2367
|
+
|
|
2368
|
+
partitioned_gemm::scalar_cholesky_solve(L, X, Y);
|
|
2369
|
+
|
|
2370
|
+
#else
|
|
2371
|
+
|
|
2194
2372
|
// Call cholesky solve on L & y
|
|
2195
2373
|
|
|
2196
2374
|
WP_TILE_SYNC();
|
|
@@ -2199,6 +2377,8 @@ TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y)
|
|
|
2199
2377
|
|
|
2200
2378
|
WP_TILE_SYNC();
|
|
2201
2379
|
|
|
2380
|
+
#endif
|
|
2381
|
+
|
|
2202
2382
|
return Y;
|
|
2203
2383
|
}
|
|
2204
2384
|
|
|
@@ -2211,7 +2391,7 @@ TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y)
|
|
|
2211
2391
|
template <typename Tile>
|
|
2212
2392
|
inline CUDA_CALLABLE auto tile_transpose(Tile& t)
|
|
2213
2393
|
{
|
|
2214
|
-
static_assert(Tile::Layout::Shape::N == 2);
|
|
2394
|
+
static_assert(Tile::Layout::Shape::N == 2, "Expected Tile::Layout::Shape::N == 2");
|
|
2215
2395
|
|
|
2216
2396
|
// alias incoming tile
|
|
2217
2397
|
constexpr int M = Tile::Layout::Shape::dim(0);
|
|
@@ -2232,13 +2412,34 @@ inline CUDA_CALLABLE void adj_tile_transpose(Tile& t, Tile& adj_t, AdjTile& adj_
|
|
|
2232
2412
|
adj_t.assign(tile_add(a,b));
|
|
2233
2413
|
}
|
|
2234
2414
|
|
|
2415
|
+
template <int N, int StrideN, typename Tile>
|
|
2416
|
+
inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
|
|
2417
|
+
{
|
|
2418
|
+
// alias incoming tile with new strides
|
|
2419
|
+
return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<N>, tile_stride_t<StrideN>>, false>(t.data.ptr, t.grad.ptr);
|
|
2420
|
+
}
|
|
2421
|
+
|
|
2235
2422
|
template <int M, int N, int StrideM, int StrideN, typename Tile>
|
|
2236
2423
|
inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
|
|
2237
|
-
{
|
|
2424
|
+
{
|
|
2238
2425
|
// alias incoming tile with new strides
|
|
2239
2426
|
return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N>, tile_stride_t<StrideM, StrideN>>, false>(t.data.ptr, t.grad.ptr);
|
|
2240
2427
|
}
|
|
2241
2428
|
|
|
2429
|
+
template <int M, int N, int O, int StrideM, int StrideN, int StrideO, typename Tile>
|
|
2430
|
+
inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
|
|
2431
|
+
{
|
|
2432
|
+
// alias incoming tile with new strides
|
|
2433
|
+
return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N, O>, tile_stride_t<StrideM, StrideN, StrideO>>, false>(t.data.ptr, t.grad.ptr);
|
|
2434
|
+
}
|
|
2435
|
+
|
|
2436
|
+
template <int M, int N, int O, int P, int StrideM, int StrideN, int StrideO, int StrideP, typename Tile>
|
|
2437
|
+
inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
|
|
2438
|
+
{
|
|
2439
|
+
// alias incoming tile with new strides
|
|
2440
|
+
return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N, O, P>, tile_stride_t<StrideM, StrideN, StrideO, StrideP>>, false>(t.data.ptr, t.grad.ptr);
|
|
2441
|
+
}
|
|
2442
|
+
|
|
2242
2443
|
template <typename Tile, typename AdjTile>
|
|
2243
2444
|
inline CUDA_CALLABLE void adj_tile_broadcast(Tile& t, Tile& adj_t, AdjTile& adj_ret)
|
|
2244
2445
|
{
|
|
@@ -2252,7 +2453,7 @@ inline CUDA_CALLABLE auto tile_view(Tile& t, Indices... indices)
|
|
|
2252
2453
|
|
|
2253
2454
|
// return new tile with same strides
|
|
2254
2455
|
typename Tile::Type* data_ptr = &t.data(c);
|
|
2255
|
-
typename Tile::Type* grad_ptr =
|
|
2456
|
+
typename Tile::Type* grad_ptr = nullptr;
|
|
2256
2457
|
|
|
2257
2458
|
if (t.grad.ptr)
|
|
2258
2459
|
grad_ptr = &t.grad(c);
|
|
@@ -2297,7 +2498,7 @@ inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, const Coord& offs
|
|
|
2297
2498
|
{
|
|
2298
2499
|
using Layout = typename TileB::Layout;
|
|
2299
2500
|
|
|
2300
|
-
for (int t=
|
|
2501
|
+
for (int t=WP_TILE_THREAD_IDX; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
|
|
2301
2502
|
{
|
|
2302
2503
|
auto c = Layout::coord_from_linear(t);
|
|
2303
2504
|
dest.data(c + offset) = src.data(c);
|
|
@@ -2312,7 +2513,7 @@ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, Coord offset,
|
|
|
2312
2513
|
{
|
|
2313
2514
|
using Layout = typename TileB::Layout;
|
|
2314
2515
|
|
|
2315
|
-
for (int t=
|
|
2516
|
+
for (int t=WP_TILE_THREAD_IDX; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
|
|
2316
2517
|
{
|
|
2317
2518
|
auto c = Layout::coord_from_linear(t);
|
|
2318
2519
|
src.grad(c) += dest.grad(c + offset);
|
|
@@ -2351,14 +2552,14 @@ inline CUDA_CALLABLE TileC& tile_diag_add(TileA& a, TileB& b, TileC& c)
|
|
|
2351
2552
|
using ShapeB = typename TileB::Layout::Shape;
|
|
2352
2553
|
using ShapeC = typename TileC::Layout::Shape;
|
|
2353
2554
|
|
|
2354
|
-
static_assert(ShapeA::dim(0) == ShapeA::dim(1));
|
|
2355
|
-
static_assert(ShapeB::dim(0) == ShapeA::dim(0));
|
|
2356
|
-
static_assert(ShapeC::dim(0) == ShapeA::dim(0));
|
|
2357
|
-
static_assert(ShapeC::dim(0) == ShapeC::dim(1));
|
|
2555
|
+
static_assert(ShapeA::dim(0) == ShapeA::dim(1), "Expected ShapeA::dim(0) == ShapeA::dim(1)");
|
|
2556
|
+
static_assert(ShapeB::dim(0) == ShapeA::dim(0), "Expected ShapeB::dim(0) == ShapeA::dim(0)");
|
|
2557
|
+
static_assert(ShapeC::dim(0) == ShapeA::dim(0), "Expected ShapeC::dim(0) == ShapeA::dim(0)");
|
|
2558
|
+
static_assert(ShapeC::dim(0) == ShapeC::dim(1), "Expected ShapeC::dim(0) == ShapeC::dim(1)");
|
|
2358
2559
|
|
|
2359
2560
|
c = a;
|
|
2360
2561
|
|
|
2361
|
-
for (int t=
|
|
2562
|
+
for (int t=WP_TILE_THREAD_IDX; t < ShapeA::dim(0); t += WP_TILE_BLOCK_DIM)
|
|
2362
2563
|
{
|
|
2363
2564
|
c.data(tile_coord(t, t)) += b.data(tile_coord(t));
|
|
2364
2565
|
}
|
|
@@ -2377,3 +2578,7 @@ inline CUDA_CALLABLE void adj_tile_diag_add(TileA& a, TileB& b, TileC& c, AdjTil
|
|
|
2377
2578
|
|
|
2378
2579
|
} // namespace wp
|
|
2379
2580
|
|
|
2581
|
+
|
|
2582
|
+
#ifdef __clang__
|
|
2583
|
+
#pragma clang diagnostic pop
|
|
2584
|
+
#endif
|