@novastera-oss/llamarn 0.3.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/proguard-rules.pro +12 -0
- package/android/src/main/cpp/include/llama.h +15 -47
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakePresets.json +11 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -0
- package/cpp/llama.cpp/README.md +4 -3
- package/cpp/llama.cpp/common/arg.cpp +45 -1
- package/cpp/llama.cpp/common/common.cpp +22 -6
- package/cpp/llama.cpp/common/common.h +18 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +500 -32
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +12 -13
- package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -1
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
- package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +8 -20
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +58 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +122 -16
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +5 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +14 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +64 -17
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -67
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +45 -62
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +28 -43
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +41 -56
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -47
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +31 -43
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +22 -37
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +73 -23
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -689
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +13 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +16 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +13 -3
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +407 -69
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +380 -83
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +295 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +131 -46
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +43 -43
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +287 -22
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +1 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +8 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +71 -16
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +4 -6
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +98 -0
- package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +75 -52
- package/cpp/llama.cpp/include/llama.h +15 -7
- package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
- package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
- package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
- package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +106 -0
- package/cpp/llama.cpp/src/llama-arch.h +5 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +76 -70
- package/cpp/llama.cpp/src/llama-batch.h +24 -18
- package/cpp/llama.cpp/src/llama-chat.cpp +43 -1
- package/cpp/llama.cpp/src/llama-chat.h +2 -0
- package/cpp/llama.cpp/src/llama-context.cpp +180 -106
- package/cpp/llama.cpp/src/llama-context.h +26 -16
- package/cpp/llama.cpp/src/llama-cparams.h +3 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +203 -39
- package/cpp/llama.cpp/src/llama-graph.h +147 -72
- package/cpp/llama.cpp/src/llama-hparams.cpp +40 -0
- package/cpp/llama.cpp/src/llama-hparams.h +10 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +11 -5
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +3 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +698 -302
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +89 -31
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +1 -0
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +16 -1
- package/cpp/llama.cpp/src/llama-model.cpp +1293 -312
- package/cpp/llama.cpp/src/llama-model.h +3 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +1 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +363 -8
- package/cpp/llama.cpp/src/llama-vocab.h +2 -0
- package/cpp/llama.cpp/src/unicode.cpp +207 -0
- package/cpp/llama.cpp/src/unicode.h +2 -0
- package/ios/include/common.h +18 -4
- package/ios/include/llama.h +15 -7
- package/ios/libs/llama.xcframework/Info.plist +15 -15
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3891
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -5095
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -5066
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3919
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +4 -4
|
@@ -23,31 +23,13 @@ typedef void (* fattn_kernel_t)(
|
|
|
23
23
|
const float m1,
|
|
24
24
|
const uint32_t n_head_log2,
|
|
25
25
|
const float logit_softcap,
|
|
26
|
-
const
|
|
27
|
-
|
|
28
|
-
const
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
const int ne13,
|
|
34
|
-
const int ne31,
|
|
35
|
-
const int ne32,
|
|
36
|
-
const int nb31,
|
|
37
|
-
const int nb32,
|
|
38
|
-
const int nb01,
|
|
39
|
-
const int nb02,
|
|
40
|
-
const int nb03,
|
|
41
|
-
const int nb11,
|
|
42
|
-
const int nb12,
|
|
43
|
-
const int nb13,
|
|
44
|
-
const int nb21,
|
|
45
|
-
const int nb22,
|
|
46
|
-
const int nb23,
|
|
47
|
-
const int ne0,
|
|
48
|
-
const int ne1,
|
|
49
|
-
const int ne2,
|
|
50
|
-
const int ne3);
|
|
26
|
+
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
|
27
|
+
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
|
28
|
+
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
|
29
|
+
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
|
30
|
+
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
|
31
|
+
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
|
32
|
+
const int32_t nb31, const int32_t nb32, const int64_t nb33);
|
|
51
33
|
|
|
52
34
|
typedef half (*vec_dot_KQ_f16_t)(
|
|
53
35
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
|
|
@@ -521,7 +503,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
|
|
|
521
503
|
template<int D, int ncols1, int ncols2> // D == head size
|
|
522
504
|
__launch_bounds__(D, 1)
|
|
523
505
|
static __global__ void flash_attn_stream_k_fixup(
|
|
524
|
-
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
|
|
506
|
+
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
|
|
525
507
|
constexpr int ncols = ncols1*ncols2;
|
|
526
508
|
|
|
527
509
|
const int bidx0 = blockIdx.x;
|
|
@@ -535,8 +517,8 @@ static __global__ void flash_attn_stream_k_fixup(
|
|
|
535
517
|
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
|
536
518
|
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
|
537
519
|
|
|
538
|
-
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
|
539
|
-
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
|
520
|
+
const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
521
|
+
const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
540
522
|
|
|
541
523
|
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
|
542
524
|
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
|
@@ -545,14 +527,15 @@ static __global__ void flash_attn_stream_k_fixup(
|
|
|
545
527
|
return;
|
|
546
528
|
}
|
|
547
529
|
|
|
548
|
-
const int
|
|
549
|
-
const int
|
|
530
|
+
const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
|
|
531
|
+
const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
|
|
532
|
+
const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
|
|
550
533
|
|
|
551
534
|
if (jt*ncols1 + j >= ne01) {
|
|
552
535
|
return;
|
|
553
536
|
}
|
|
554
537
|
|
|
555
|
-
dst += jt*ne02*(ncols1*D) +
|
|
538
|
+
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
|
|
556
539
|
|
|
557
540
|
// Load the partial result that needs a fixup:
|
|
558
541
|
float dst_val = 0.0f;
|
|
@@ -571,7 +554,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
|
|
571
554
|
int bidx = bidx0 - 1;
|
|
572
555
|
int kbc_stop = kbc0;
|
|
573
556
|
while(true) {
|
|
574
|
-
const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
|
557
|
+
const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
575
558
|
if (kbc == kbc_stop) { // Did not have any data.
|
|
576
559
|
bidx--;
|
|
577
560
|
kbc_stop = kbc;
|
|
@@ -617,16 +600,31 @@ static __global__ void flash_attn_combine_results(
|
|
|
617
600
|
const float2 * __restrict__ VKQ_meta,
|
|
618
601
|
float * __restrict__ dst,
|
|
619
602
|
const int parallel_blocks) {
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
603
|
+
// Dimension 0: threadIdx.x
|
|
604
|
+
// Dimension 1: blockIdx.x
|
|
605
|
+
// Dimension 2: blockIdx.y
|
|
606
|
+
// Dimension 3: blockIdx.z
|
|
607
|
+
// Memory layout is permuted with [0, 2, 1, 3]
|
|
608
|
+
|
|
609
|
+
const int ne01 = gridDim.x;
|
|
610
|
+
const int ne02 = gridDim.y;
|
|
611
|
+
|
|
612
|
+
const int col = blockIdx.x;
|
|
613
|
+
const int head = blockIdx.y;
|
|
614
|
+
const int sequence = blockIdx.z;
|
|
615
|
+
|
|
616
|
+
const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
|
|
617
|
+
|
|
618
|
+
VKQ_parts += j_dst_unrolled * parallel_blocks*D;
|
|
619
|
+
VKQ_meta += j_dst_unrolled * parallel_blocks;
|
|
620
|
+
dst += j_dst_unrolled * D;
|
|
623
621
|
|
|
624
622
|
const int tid = threadIdx.x;
|
|
625
623
|
__builtin_assume(tid < D);
|
|
626
624
|
|
|
627
625
|
extern __shared__ float2 meta[];
|
|
628
626
|
for (int i = tid; i < 2*parallel_blocks; i += D) {
|
|
629
|
-
((float *) meta)[i] = ((const float *)VKQ_meta) [
|
|
627
|
+
((float *) meta)[i] = ((const float *)VKQ_meta) [i];
|
|
630
628
|
}
|
|
631
629
|
|
|
632
630
|
__syncthreads();
|
|
@@ -644,11 +642,11 @@ static __global__ void flash_attn_combine_results(
|
|
|
644
642
|
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
|
645
643
|
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
|
|
646
644
|
|
|
647
|
-
VKQ_numerator += KQ_max_scale * VKQ_parts[l*
|
|
645
|
+
VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
|
|
648
646
|
VKQ_denominator += KQ_max_scale * meta[l].y;
|
|
649
647
|
}
|
|
650
648
|
|
|
651
|
-
dst[
|
|
649
|
+
dst[tid] = VKQ_numerator / VKQ_denominator;
|
|
652
650
|
}
|
|
653
651
|
|
|
654
652
|
[[noreturn]]
|
|
@@ -705,8 +703,6 @@ void launch_fattn(
|
|
|
705
703
|
|
|
706
704
|
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
|
707
705
|
|
|
708
|
-
GGML_ASSERT(Q->ne[3] == 1);
|
|
709
|
-
|
|
710
706
|
ggml_cuda_pool & pool = ctx.pool();
|
|
711
707
|
cudaStream_t main_stream = ctx.stream();
|
|
712
708
|
const int id = ggml_cuda_get_device();
|
|
@@ -729,33 +725,58 @@ void launch_fattn(
|
|
|
729
725
|
size_t nb23 = V ? V->nb[3] : nb13;
|
|
730
726
|
|
|
731
727
|
if (need_f16_K && K->type != GGML_TYPE_F16) {
|
|
732
|
-
GGML_ASSERT(ggml_is_contiguously_allocated(K));
|
|
733
|
-
K_f16.alloc(ggml_nelements(K));
|
|
734
|
-
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
|
|
735
|
-
to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
|
|
736
|
-
K_data = (char *) K_f16.ptr;
|
|
737
|
-
|
|
738
728
|
const size_t bs = ggml_blck_size(K->type);
|
|
739
729
|
const size_t ts = ggml_type_size(K->type);
|
|
740
730
|
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
731
|
+
K_f16.alloc(ggml_nelements(K));
|
|
732
|
+
if (ggml_is_contiguously_allocated(K)) {
|
|
733
|
+
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
|
|
734
|
+
to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
|
|
735
|
+
|
|
736
|
+
nb11 = nb11*bs*sizeof(half)/ts;
|
|
737
|
+
nb12 = nb12*bs*sizeof(half)/ts;
|
|
738
|
+
nb13 = nb13*bs*sizeof(half)/ts;
|
|
739
|
+
} else {
|
|
740
|
+
GGML_ASSERT(K->nb[0] == ts);
|
|
741
|
+
to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);
|
|
742
|
+
const int64_t s01 = nb11 / ts;
|
|
743
|
+
const int64_t s02 = nb12 / ts;
|
|
744
|
+
const int64_t s03 = nb13 / ts;
|
|
745
|
+
to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
|
|
746
|
+
|
|
747
|
+
nb11 = K->ne[0] * sizeof(half);
|
|
748
|
+
nb12 = K->ne[1] * nb11;
|
|
749
|
+
nb13 = K->ne[2] * nb12;
|
|
750
|
+
}
|
|
751
|
+
K_data = (char *) K_f16.ptr;
|
|
744
752
|
}
|
|
745
753
|
|
|
746
754
|
if (V && need_f16_V && V->type != GGML_TYPE_F16) {
|
|
747
|
-
GGML_ASSERT(ggml_is_contiguously_allocated(V));
|
|
748
|
-
V_f16.alloc(ggml_nelements(V));
|
|
749
|
-
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
|
750
|
-
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
|
|
751
|
-
V_data = (char *) V_f16.ptr;
|
|
752
|
-
|
|
753
755
|
const size_t bs = ggml_blck_size(V->type);
|
|
754
756
|
const size_t ts = ggml_type_size(V->type);
|
|
755
757
|
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
758
|
+
V_f16.alloc(ggml_nelements(V));
|
|
759
|
+
if (ggml_is_contiguously_allocated(V)) {
|
|
760
|
+
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
|
761
|
+
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
|
|
762
|
+
V_data = (char *) V_f16.ptr;
|
|
763
|
+
|
|
764
|
+
nb21 = nb21*bs*sizeof(half)/ts;
|
|
765
|
+
nb22 = nb22*bs*sizeof(half)/ts;
|
|
766
|
+
nb23 = nb23*bs*sizeof(half)/ts;
|
|
767
|
+
} else {
|
|
768
|
+
GGML_ASSERT(V->nb[0] == ts);
|
|
769
|
+
to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
|
|
770
|
+
const int64_t s01 = nb21 / ts;
|
|
771
|
+
const int64_t s02 = nb22 / ts;
|
|
772
|
+
const int64_t s03 = nb23 / ts;
|
|
773
|
+
to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
|
|
774
|
+
|
|
775
|
+
nb21 = V->ne[0] * sizeof(half);
|
|
776
|
+
nb22 = V->ne[1] * nb21;
|
|
777
|
+
nb23 = V->ne[2] * nb22;
|
|
778
|
+
}
|
|
779
|
+
V_data = (char *) V_f16.ptr;
|
|
759
780
|
}
|
|
760
781
|
|
|
761
782
|
int parallel_blocks = 1;
|
|
@@ -851,14 +872,11 @@ void launch_fattn(
|
|
|
851
872
|
mask ? ((const char *) mask->data) : nullptr,
|
|
852
873
|
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
|
853
874
|
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
|
854
|
-
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
|
855
|
-
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
|
856
|
-
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
|
|
857
|
-
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
|
|
858
|
-
Q->nb[1], Q->nb[2], Q->nb[3],
|
|
859
|
-
nb11, nb12, nb13,
|
|
875
|
+
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
|
|
876
|
+
K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
|
|
860
877
|
nb21, nb22, nb23,
|
|
861
|
-
|
|
878
|
+
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
|
|
879
|
+
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
|
|
862
880
|
);
|
|
863
881
|
CUDA_CHECK(cudaGetLastError());
|
|
864
882
|
|
|
@@ -869,11 +887,11 @@ void launch_fattn(
|
|
|
869
887
|
|
|
870
888
|
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
|
|
871
889
|
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
|
872
|
-
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
|
890
|
+
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
|
|
873
891
|
}
|
|
874
892
|
} else if (parallel_blocks > 1) {
|
|
875
893
|
const dim3 block_dim_combine(DV, 1, 1);
|
|
876
|
-
const dim3 blocks_num_combine(Q->ne[1],
|
|
894
|
+
const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
|
|
877
895
|
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
|
|
878
896
|
|
|
879
897
|
flash_attn_combine_results<DV>
|
|
@@ -408,7 +408,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
408
408
|
const int stride_K,
|
|
409
409
|
const int stride_V,
|
|
410
410
|
const int stride_mask,
|
|
411
|
-
const int jt,
|
|
412
411
|
half2 * const __restrict__ tile_Q,
|
|
413
412
|
half2 * const __restrict__ tile_K,
|
|
414
413
|
half2 * const __restrict__ tile_V,
|
|
@@ -455,7 +454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
455
454
|
cp_async_wait_all();
|
|
456
455
|
__syncthreads();
|
|
457
456
|
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
|
458
|
-
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
|
|
457
|
+
(V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
|
|
459
458
|
} else {
|
|
460
459
|
constexpr bool use_cp_async = nstages == 1;
|
|
461
460
|
if (ncols2 > 1 || mask_h2) {
|
|
@@ -471,7 +470,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
471
470
|
if (nstages <= 1) {
|
|
472
471
|
constexpr bool use_cp_async = nstages == 1;
|
|
473
472
|
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
|
474
|
-
(K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
|
|
473
|
+
(K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
|
|
475
474
|
if (use_cp_async) {
|
|
476
475
|
cp_async_wait_all();
|
|
477
476
|
}
|
|
@@ -715,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
715
714
|
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
|
|
716
715
|
}
|
|
717
716
|
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
|
718
|
-
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
|
|
717
|
+
(K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
|
|
719
718
|
}
|
|
720
719
|
}
|
|
721
720
|
|
|
@@ -732,7 +731,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
732
731
|
if (nstages <= 1 && i0_start < reusable_cutoff) {
|
|
733
732
|
constexpr bool use_cp_async = nstages == 1;
|
|
734
733
|
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
|
735
|
-
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
|
|
734
|
+
(V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
|
|
736
735
|
if (use_cp_async) {
|
|
737
736
|
cp_async_wait_all();
|
|
738
737
|
}
|
|
@@ -771,8 +770,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
771
770
|
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
|
|
772
771
|
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
|
|
773
772
|
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
|
|
774
|
-
GGML_UNUSED(stride_mask); GGML_UNUSED(
|
|
775
|
-
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
|
|
773
|
+
GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K);
|
|
776
774
|
GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
|
|
777
775
|
GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
|
|
778
776
|
GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
|
|
@@ -920,7 +918,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
920
918
|
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
|
|
921
919
|
}
|
|
922
920
|
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
|
923
|
-
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
|
|
921
|
+
(K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
|
|
924
922
|
}
|
|
925
923
|
|
|
926
924
|
// Iterate over ne11 == previous tokens:
|
|
@@ -928,13 +926,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
928
926
|
constexpr bool last_iter = false;
|
|
929
927
|
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
|
930
928
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
931
|
-
ne01, ne02, stride_K, stride_V, stride_mask,
|
|
929
|
+
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
|
932
930
|
}
|
|
933
931
|
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
|
934
932
|
constexpr bool last_iter = true;
|
|
935
933
|
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
|
936
934
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
937
|
-
ne01, ne02, stride_K, stride_V, stride_mask,
|
|
935
|
+
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
|
|
938
936
|
}
|
|
939
937
|
|
|
940
938
|
// With multi-stage loading there is no __syncthreads at the end of the iter,
|
|
@@ -1214,31 +1212,13 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1214
1212
|
const float m1,
|
|
1215
1213
|
const uint32_t n_head_log2,
|
|
1216
1214
|
const float logit_softcap,
|
|
1217
|
-
const
|
|
1218
|
-
|
|
1219
|
-
const
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
1224
|
-
const int ne13,
|
|
1225
|
-
const int ne31,
|
|
1226
|
-
const int ne32,
|
|
1227
|
-
const int nb31,
|
|
1228
|
-
const int nb32,
|
|
1229
|
-
const int nb01,
|
|
1230
|
-
const int nb02,
|
|
1231
|
-
const int nb03,
|
|
1232
|
-
const int nb11,
|
|
1233
|
-
const int nb12,
|
|
1234
|
-
const int nb13,
|
|
1235
|
-
const int nb21,
|
|
1236
|
-
const int nb22,
|
|
1237
|
-
const int nb23,
|
|
1238
|
-
const int ne0,
|
|
1239
|
-
const int ne1,
|
|
1240
|
-
const int ne2,
|
|
1241
|
-
const int ne3) {
|
|
1215
|
+
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
|
1216
|
+
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
|
1217
|
+
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
|
1218
|
+
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
|
1219
|
+
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
|
1220
|
+
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
|
1221
|
+
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
|
1242
1222
|
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
|
1243
1223
|
|
|
1244
1224
|
// Skip unused kernel variants for faster compilation:
|
|
@@ -1274,8 +1254,8 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1274
1254
|
constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
|
|
1275
1255
|
|
|
1276
1256
|
// kbc == k block continuous, current index in continuous ijk space.
|
|
1277
|
-
int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
|
1278
|
-
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
|
1257
|
+
int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
1258
|
+
const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
1279
1259
|
|
|
1280
1260
|
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
|
1281
1261
|
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
|
@@ -1285,18 +1265,19 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1285
1265
|
int kb0_start = kbc % iter_k;
|
|
1286
1266
|
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
|
|
1287
1267
|
while (kbc < kbc_stop && kb0_stop == iter_k) {
|
|
1288
|
-
const int
|
|
1289
|
-
const int
|
|
1268
|
+
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
|
|
1269
|
+
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
|
|
1270
|
+
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
|
|
1290
1271
|
|
|
1291
|
-
const float2 * Q_f2 = (const float2 *) (Q +
|
|
1292
|
-
const half2 * K_h2 = (const half2 *) (K + nb12*(
|
|
1272
|
+
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
|
|
1273
|
+
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
|
|
1293
1274
|
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
|
|
1294
|
-
(const half2 *) (mask +
|
|
1295
|
-
float2 * dstk = ((float2 *) dst) +
|
|
1275
|
+
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
|
|
1276
|
+
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
|
|
1296
1277
|
|
|
1297
|
-
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(
|
|
1278
|
+
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
|
|
1298
1279
|
|
|
1299
|
-
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias,
|
|
1280
|
+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
|
|
1300
1281
|
|
|
1301
1282
|
const int kb0_start_kernel = kb0_start * kb_niter;
|
|
1302
1283
|
const int kb0_stop_kernel = kb0_stop * kb_niter;
|
|
@@ -1325,18 +1306,19 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1325
1306
|
return;
|
|
1326
1307
|
}
|
|
1327
1308
|
|
|
1328
|
-
const int
|
|
1329
|
-
const int
|
|
1309
|
+
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
|
|
1310
|
+
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
|
|
1311
|
+
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
|
|
1330
1312
|
|
|
1331
|
-
const float2 * Q_f2 = (const float2 *) (Q +
|
|
1332
|
-
const half2 * K_h2 = (const half2 *) (K + nb12*(
|
|
1313
|
+
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
|
|
1314
|
+
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
|
|
1333
1315
|
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
|
|
1334
|
-
(const half2 *) (mask +
|
|
1335
|
-
float2 * dstk = ((float2 *) dst) +
|
|
1316
|
+
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
|
|
1317
|
+
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
|
|
1336
1318
|
|
|
1337
|
-
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(
|
|
1319
|
+
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
|
|
1338
1320
|
|
|
1339
|
-
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias,
|
|
1321
|
+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
|
|
1340
1322
|
|
|
1341
1323
|
const int kb0_start_kernel = kb0_start * kb_niter;
|
|
1342
1324
|
const int kb0_stop_kernel = kb0_stop * kb_niter;
|
|
@@ -1348,15 +1330,16 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1348
1330
|
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
|
1349
1331
|
#else
|
|
1350
1332
|
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
|
1351
|
-
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
|
|
1352
|
-
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
|
1353
|
-
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
1354
|
-
GGML_UNUSED(
|
|
1355
|
-
GGML_UNUSED(
|
|
1356
|
-
GGML_UNUSED(
|
|
1357
|
-
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
|
1358
|
-
GGML_UNUSED(
|
|
1359
|
-
GGML_UNUSED(
|
|
1333
|
+
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
|
|
1334
|
+
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
|
1335
|
+
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
1336
|
+
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
|
1337
|
+
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
|
1338
|
+
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
|
1339
|
+
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
|
1340
|
+
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
|
1341
|
+
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
|
1342
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
|
|
1360
1343
|
NO_DEVICE_CODE;
|
|
1361
1344
|
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
|
1362
1345
|
}
|
|
@@ -21,31 +21,13 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
21
21
|
const float m1,
|
|
22
22
|
const uint32_t n_head_log2,
|
|
23
23
|
const float logit_softcap,
|
|
24
|
-
const
|
|
25
|
-
|
|
26
|
-
const
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
const int ne13,
|
|
32
|
-
const int ne31,
|
|
33
|
-
const int ne32,
|
|
34
|
-
const int nb31,
|
|
35
|
-
const int nb32,
|
|
36
|
-
const int nb01,
|
|
37
|
-
const int nb02,
|
|
38
|
-
const int nb03,
|
|
39
|
-
const int nb11,
|
|
40
|
-
const int nb12,
|
|
41
|
-
const int nb13,
|
|
42
|
-
const int nb21,
|
|
43
|
-
const int nb22,
|
|
44
|
-
const int nb23,
|
|
45
|
-
const int ne0,
|
|
46
|
-
const int ne1,
|
|
47
|
-
const int ne2,
|
|
48
|
-
const int ne3) {
|
|
24
|
+
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
|
25
|
+
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
|
26
|
+
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
|
27
|
+
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
|
28
|
+
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
|
29
|
+
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
|
30
|
+
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
|
49
31
|
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
|
50
32
|
|
|
51
33
|
// Skip unused kernel variants for faster compilation:
|
|
@@ -62,15 +44,17 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
62
44
|
|
|
63
45
|
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
|
64
46
|
|
|
47
|
+
const int sequence = blockIdx.z / ne02;
|
|
48
|
+
const int head = blockIdx.z - sequence*ne02;
|
|
65
49
|
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
|
66
|
-
const float2 * Q_f2 = (const float2 *) (Q + nb02*
|
|
67
|
-
const half2 * K_h2 = (const half2 *) (K + nb12*(
|
|
68
|
-
const half2 * V_h2 = (const half2 *) (V + nb12*(
|
|
69
|
-
const half * maskh = (const half *) (mask +
|
|
50
|
+
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
|
51
|
+
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
|
52
|
+
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
|
53
|
+
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
|
70
54
|
|
|
71
55
|
const int stride_KV2 = nb11 / sizeof(half2);
|
|
72
56
|
|
|
73
|
-
const float slopef = get_alibi_slope(max_bias,
|
|
57
|
+
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
|
74
58
|
const half slopeh = __float2half(slopef);
|
|
75
59
|
|
|
76
60
|
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
|
@@ -123,7 +107,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
123
107
|
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
|
|
124
108
|
const int k_KQ = k_KQ_0 + threadIdx.x;
|
|
125
109
|
|
|
126
|
-
KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
|
|
110
|
+
KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
|
|
127
111
|
}
|
|
128
112
|
}
|
|
129
113
|
|
|
@@ -217,7 +201,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
217
201
|
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
|
218
202
|
const int i = i0 + threadIdx.x;
|
|
219
203
|
|
|
220
|
-
KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
|
|
204
|
+
KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
|
|
221
205
|
}
|
|
222
206
|
}
|
|
223
207
|
|
|
@@ -255,6 +239,8 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
255
239
|
__syncthreads();
|
|
256
240
|
}
|
|
257
241
|
|
|
242
|
+
float2 * dst2 = (float2 *) dst;
|
|
243
|
+
|
|
258
244
|
#pragma unroll
|
|
259
245
|
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
|
260
246
|
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
|
@@ -266,21 +252,21 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
266
252
|
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
|
|
267
253
|
kqsum_j = warp_reduce_sum((float)kqsum_j);
|
|
268
254
|
|
|
255
|
+
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
|
256
|
+
|
|
269
257
|
#pragma unroll
|
|
270
|
-
for (int i00 = 0; i00 < D; i00 +=
|
|
271
|
-
const int i0 = i00 +
|
|
258
|
+
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
|
|
259
|
+
const int i0 = i00 + threadIdx.x;
|
|
272
260
|
|
|
273
|
-
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/
|
|
261
|
+
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
|
|
274
262
|
if (gridDim.y == 1) {
|
|
275
263
|
dst_val /= __half2half2(kqsum_j);
|
|
276
264
|
}
|
|
277
|
-
|
|
278
|
-
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
|
|
279
|
-
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
|
|
265
|
+
dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
|
|
280
266
|
}
|
|
281
267
|
|
|
282
268
|
if (gridDim.y != 1 && threadIdx.x == 0) {
|
|
283
|
-
dst_meta[
|
|
269
|
+
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
|
284
270
|
}
|
|
285
271
|
}
|
|
286
272
|
#else
|
|
@@ -290,12 +276,11 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
290
276
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
291
277
|
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
|
292
278
|
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
|
293
|
-
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
|
294
|
-
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
279
|
+
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
|
280
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
295
281
|
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
|
296
282
|
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
|
297
|
-
GGML_UNUSED(nb23);
|
|
298
|
-
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
|
283
|
+
GGML_UNUSED(nb23);
|
|
299
284
|
NO_DEVICE_CODE;
|
|
300
285
|
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
|
301
286
|
}
|