@novastera-oss/llamarn 0.2.9 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/proguard-rules.pro +12 -0
- package/android/src/main/cpp/include/llama.h +15 -47
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/CMakePresets.json +11 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -0
- package/cpp/llama.cpp/README.md +8 -8
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +62 -1
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +22 -6
- package/cpp/llama.cpp/common/common.h +22 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
- package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
- package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
- package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
- package/cpp/llama.cpp/include/llama.h +15 -47
- package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
- package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
- package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
- package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
- package/cpp/llama.cpp/src/llama-arch.h +23 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
- package/cpp/llama.cpp/src/llama-batch.h +31 -18
- package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +180 -106
- package/cpp/llama.cpp/src/llama-context.h +26 -16
- package/cpp/llama.cpp/src/llama-cparams.h +3 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
- package/cpp/llama.cpp/src/llama-graph.h +184 -122
- package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
- package/cpp/llama.cpp/src/llama-hparams.h +13 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
- package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +3 -0
- package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
- package/cpp/llama.cpp/src/llama-model.h +21 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
- package/cpp/llama.cpp/src/llama-vocab.h +43 -0
- package/cpp/llama.cpp/src/unicode.cpp +207 -0
- package/cpp/llama.cpp/src/unicode.h +2 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +22 -4
- package/ios/include/llama.h +15 -47
- package/ios/libs/llama.xcframework/Info.plist +13 -13
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +4 -4
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -6,7 +6,7 @@
|
|
|
6
6
|
|
|
7
7
|
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
|
|
8
8
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
9
|
-
__launch_bounds__(nwarps*WARP_SIZE,
|
|
9
|
+
__launch_bounds__(nwarps*WARP_SIZE, 2)
|
|
10
10
|
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
11
11
|
static __global__ void flash_attn_tile_ext_f32(
|
|
12
12
|
const char * __restrict__ Q,
|
|
@@ -21,29 +21,13 @@ static __global__ void flash_attn_tile_ext_f32(
|
|
|
21
21
|
const float m1,
|
|
22
22
|
const uint32_t n_head_log2,
|
|
23
23
|
const float logit_softcap,
|
|
24
|
-
const
|
|
25
|
-
|
|
26
|
-
const
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
const int ne13,
|
|
32
|
-
const int ne31,
|
|
33
|
-
const int nb31,
|
|
34
|
-
const int nb01,
|
|
35
|
-
const int nb02,
|
|
36
|
-
const int nb03,
|
|
37
|
-
const int nb11,
|
|
38
|
-
const int nb12,
|
|
39
|
-
const int nb13,
|
|
40
|
-
const int nb21,
|
|
41
|
-
const int nb22,
|
|
42
|
-
const int nb23,
|
|
43
|
-
const int ne0,
|
|
44
|
-
const int ne1,
|
|
45
|
-
const int ne2,
|
|
46
|
-
const int ne3) {
|
|
24
|
+
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
|
25
|
+
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
|
26
|
+
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
|
27
|
+
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
|
28
|
+
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
|
29
|
+
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
|
30
|
+
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
|
47
31
|
#ifdef FLASH_ATTN_AVAILABLE
|
|
48
32
|
|
|
49
33
|
// Skip unused kernel variants for faster compilation:
|
|
@@ -53,17 +37,16 @@ static __global__ void flash_attn_tile_ext_f32(
|
|
|
53
37
|
#endif // FP16_MMA_AVAILABLE
|
|
54
38
|
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
|
55
39
|
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
|
56
|
-
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
|
|
57
|
-
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
|
40
|
+
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
|
|
41
|
+
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
|
58
42
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
59
|
-
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
|
60
|
-
GGML_UNUSED(
|
|
61
|
-
GGML_UNUSED(
|
|
62
|
-
GGML_UNUSED(
|
|
63
|
-
GGML_UNUSED(
|
|
64
|
-
GGML_UNUSED(
|
|
65
|
-
GGML_UNUSED(
|
|
66
|
-
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
|
43
|
+
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
|
44
|
+
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
|
45
|
+
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
|
46
|
+
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
|
47
|
+
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
|
48
|
+
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
|
49
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
|
|
67
50
|
NO_DEVICE_CODE;
|
|
68
51
|
return;
|
|
69
52
|
}
|
|
@@ -72,15 +55,17 @@ static __global__ void flash_attn_tile_ext_f32(
|
|
|
72
55
|
|
|
73
56
|
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
|
74
57
|
|
|
58
|
+
const int sequence = blockIdx.z / ne02;
|
|
59
|
+
const int head = blockIdx.z - sequence*ne02;
|
|
75
60
|
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
|
76
|
-
const float2 * Q_f2 = (const float2 *) (Q + nb02*
|
|
77
|
-
const half2 * K_h2 = (const half2 *) (K + nb12*(
|
|
78
|
-
const half2 * V_h2 = (const half2 *) (V + nb12*(
|
|
79
|
-
const half * maskh = (const half *)
|
|
61
|
+
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
|
62
|
+
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
|
63
|
+
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
|
64
|
+
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
|
80
65
|
|
|
81
66
|
const int stride_KV2 = nb11 / sizeof(half2);
|
|
82
67
|
|
|
83
|
-
const float slope = get_alibi_slope(max_bias,
|
|
68
|
+
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
|
84
69
|
|
|
85
70
|
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
|
86
71
|
|
|
@@ -129,7 +114,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
|
|
129
114
|
|
|
130
115
|
#pragma unroll
|
|
131
116
|
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
|
|
132
|
-
const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
|
|
117
|
+
const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
|
|
133
118
|
KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
|
|
134
119
|
KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
|
|
135
120
|
}
|
|
@@ -225,8 +210,9 @@ static __global__ void flash_attn_tile_ext_f32(
|
|
|
225
210
|
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
|
226
211
|
const int i = i0 + threadIdx.x;
|
|
227
212
|
|
|
228
|
-
|
|
229
|
-
KV_tmp2[k*(D/2) + i].
|
|
213
|
+
const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
|
|
214
|
+
KV_tmp2[k*(D/2) + i].x = __low2float(tmp);
|
|
215
|
+
KV_tmp2[k*(D/2) + i].y = __high2float(tmp);
|
|
230
216
|
}
|
|
231
217
|
}
|
|
232
218
|
|
|
@@ -263,6 +249,8 @@ static __global__ void flash_attn_tile_ext_f32(
|
|
|
263
249
|
__syncthreads();
|
|
264
250
|
}
|
|
265
251
|
|
|
252
|
+
float2 * dst2 = (float2 *) dst;
|
|
253
|
+
|
|
266
254
|
#pragma unroll
|
|
267
255
|
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
|
268
256
|
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
|
@@ -274,37 +262,36 @@ static __global__ void flash_attn_tile_ext_f32(
|
|
|
274
262
|
float kqsum_j = kqsum[j_VKQ_0/nwarps];
|
|
275
263
|
kqsum_j = warp_reduce_sum(kqsum_j);
|
|
276
264
|
|
|
265
|
+
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
|
266
|
+
|
|
277
267
|
#pragma unroll
|
|
278
|
-
for (int i00 = 0; i00 < D; i00 +=
|
|
279
|
-
const int i0 = i00 +
|
|
268
|
+
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
|
|
269
|
+
const int i0 = i00 + threadIdx.x;
|
|
280
270
|
|
|
281
|
-
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/
|
|
271
|
+
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
|
|
282
272
|
if (gridDim.y == 1) {
|
|
283
273
|
dst_val.x /= kqsum_j;
|
|
284
274
|
dst_val.y /= kqsum_j;
|
|
285
275
|
}
|
|
286
|
-
|
|
287
|
-
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
|
|
288
|
-
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
|
|
276
|
+
dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
|
|
289
277
|
}
|
|
290
278
|
|
|
291
279
|
if (gridDim.y != 1 && threadIdx.x == 0) {
|
|
292
|
-
dst_meta[
|
|
280
|
+
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
|
293
281
|
}
|
|
294
282
|
}
|
|
295
283
|
#else
|
|
296
284
|
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
|
297
|
-
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
|
|
298
|
-
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
|
285
|
+
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
|
|
286
|
+
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
|
299
287
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
300
|
-
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
|
301
|
-
GGML_UNUSED(
|
|
302
|
-
GGML_UNUSED(
|
|
303
|
-
GGML_UNUSED(
|
|
304
|
-
GGML_UNUSED(
|
|
305
|
-
GGML_UNUSED(
|
|
306
|
-
GGML_UNUSED(
|
|
307
|
-
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
|
288
|
+
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
|
289
|
+
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
|
290
|
+
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
|
291
|
+
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
|
292
|
+
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
|
293
|
+
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
|
294
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
|
|
308
295
|
NO_DEVICE_CODE;
|
|
309
296
|
#endif // FLASH_ATTN_AVAILABLE
|
|
310
297
|
}
|
|
@@ -18,29 +18,13 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
|
18
18
|
const float m1,
|
|
19
19
|
const uint32_t n_head_log2,
|
|
20
20
|
const float logit_softcap,
|
|
21
|
-
const
|
|
22
|
-
|
|
23
|
-
const
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
const int ne13,
|
|
29
|
-
const int ne31,
|
|
30
|
-
const int nb31,
|
|
31
|
-
const int nb01,
|
|
32
|
-
const int nb02,
|
|
33
|
-
const int nb03,
|
|
34
|
-
const int nb11,
|
|
35
|
-
const int nb12,
|
|
36
|
-
const int nb13,
|
|
37
|
-
const int nb21,
|
|
38
|
-
const int nb22,
|
|
39
|
-
const int nb23,
|
|
40
|
-
const int ne0,
|
|
41
|
-
const int ne1,
|
|
42
|
-
const int ne2,
|
|
43
|
-
const int ne3) {
|
|
21
|
+
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
|
22
|
+
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
|
23
|
+
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
|
24
|
+
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
|
25
|
+
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
|
26
|
+
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
|
27
|
+
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
|
44
28
|
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
|
45
29
|
|
|
46
30
|
// Skip unused kernel variants for faster compilation:
|
|
@@ -63,14 +47,16 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
|
63
47
|
|
|
64
48
|
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
|
65
49
|
|
|
50
|
+
const int sequence = blockIdx.z / ne02;
|
|
51
|
+
const int head = blockIdx.z - sequence*ne02;
|
|
66
52
|
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
|
67
|
-
Q += nb02*
|
|
68
|
-
K += nb12*(
|
|
69
|
-
V += nb22*(
|
|
53
|
+
Q += nb03*sequence + nb02* head + nb01*ic0;
|
|
54
|
+
K += nb13*sequence + nb12*(head / gqa_ratio);
|
|
55
|
+
V += nb23*sequence + nb22*(head / gqa_ratio);
|
|
70
56
|
|
|
71
|
-
const half * maskh = (const half
|
|
57
|
+
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
|
72
58
|
|
|
73
|
-
const float slopef = get_alibi_slope(max_bias,
|
|
59
|
+
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
|
74
60
|
const half slopeh = __float2half(slopef);
|
|
75
61
|
|
|
76
62
|
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
|
@@ -185,13 +171,16 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
|
185
171
|
|
|
186
172
|
half2 VKQ[ncols] = {{0.0f, 0.0f}};
|
|
187
173
|
|
|
174
|
+
K += blockIdx.y*D * nb11;
|
|
175
|
+
V += blockIdx.y*D * nb21;
|
|
176
|
+
maskh += blockIdx.y*D;
|
|
188
177
|
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
|
|
189
178
|
// Calculate KQ tile and keep track of new maximum KQ values:
|
|
190
179
|
|
|
191
180
|
if (mask) {
|
|
192
181
|
#pragma unroll
|
|
193
182
|
for (int j = 0; j < ncols; ++j) {
|
|
194
|
-
maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 +
|
|
183
|
+
maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
|
|
195
184
|
}
|
|
196
185
|
|
|
197
186
|
__syncthreads();
|
|
@@ -238,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
|
238
227
|
|
|
239
228
|
#pragma unroll
|
|
240
229
|
for (int j = 0; j < ncols; ++j) {
|
|
241
|
-
half sum = vec_dot_KQ(K +
|
|
230
|
+
half sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
|
|
242
231
|
sum = warp_reduce_sum((float)sum);
|
|
243
232
|
|
|
244
233
|
if (use_logit_softcap) {
|
|
@@ -294,14 +283,18 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
|
294
283
|
}
|
|
295
284
|
|
|
296
285
|
half2 V_k;
|
|
297
|
-
reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (
|
|
298
|
-
reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (
|
|
286
|
+
reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid);
|
|
287
|
+
reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid);
|
|
299
288
|
#pragma unroll
|
|
300
289
|
for (int j = 0; j < ncols; ++j) {
|
|
301
290
|
VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
|
|
302
291
|
}
|
|
303
292
|
}
|
|
304
293
|
|
|
294
|
+
K += gridDim.y*D * nb11;
|
|
295
|
+
V += gridDim.y*D * nb21;
|
|
296
|
+
maskh += gridDim.y*D;
|
|
297
|
+
|
|
305
298
|
__syncthreads();
|
|
306
299
|
}
|
|
307
300
|
|
|
@@ -328,26 +321,24 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
|
328
321
|
if (gridDim.y == 1) {
|
|
329
322
|
dst_val /= kqsum[j_VKQ];
|
|
330
323
|
}
|
|
331
|
-
|
|
332
|
-
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
|
|
324
|
+
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
|
|
333
325
|
}
|
|
334
326
|
|
|
335
327
|
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
|
336
|
-
dst_meta[((ic0 + tid)*
|
|
328
|
+
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
|
337
329
|
}
|
|
338
330
|
#else
|
|
339
331
|
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
|
340
|
-
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
|
|
341
|
-
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
|
332
|
+
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
|
|
333
|
+
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
|
342
334
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
343
|
-
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
|
344
|
-
GGML_UNUSED(
|
|
345
|
-
GGML_UNUSED(
|
|
346
|
-
GGML_UNUSED(
|
|
347
|
-
GGML_UNUSED(
|
|
348
|
-
GGML_UNUSED(
|
|
349
|
-
GGML_UNUSED(
|
|
350
|
-
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
|
335
|
+
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
|
336
|
+
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
|
337
|
+
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
|
338
|
+
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
|
339
|
+
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
|
340
|
+
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
|
341
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
|
|
351
342
|
NO_DEVICE_CODE;
|
|
352
343
|
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
|
353
344
|
}
|
|
@@ -18,29 +18,13 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
|
18
18
|
const float m1,
|
|
19
19
|
const uint32_t n_head_log2,
|
|
20
20
|
const float logit_softcap,
|
|
21
|
-
const
|
|
22
|
-
|
|
23
|
-
const
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
const int ne13,
|
|
29
|
-
const int ne31,
|
|
30
|
-
const int nb31,
|
|
31
|
-
const int nb01,
|
|
32
|
-
const int nb02,
|
|
33
|
-
const int nb03,
|
|
34
|
-
const int nb11,
|
|
35
|
-
const int nb12,
|
|
36
|
-
const int nb13,
|
|
37
|
-
const int nb21,
|
|
38
|
-
const int nb22,
|
|
39
|
-
const int nb23,
|
|
40
|
-
const int ne0,
|
|
41
|
-
const int ne1,
|
|
42
|
-
const int ne2,
|
|
43
|
-
const int ne3) {
|
|
21
|
+
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
|
22
|
+
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
|
23
|
+
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
|
24
|
+
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
|
25
|
+
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
|
26
|
+
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
|
27
|
+
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
|
44
28
|
#ifdef FLASH_ATTN_AVAILABLE
|
|
45
29
|
|
|
46
30
|
// Skip unused kernel variants for faster compilation:
|
|
@@ -51,12 +35,11 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
|
51
35
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
52
36
|
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
|
53
37
|
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
|
54
|
-
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
|
55
|
-
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
38
|
+
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
|
39
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
56
40
|
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
|
57
41
|
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
|
58
|
-
GGML_UNUSED(nb23);
|
|
59
|
-
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
|
42
|
+
GGML_UNUSED(nb23);
|
|
60
43
|
NO_DEVICE_CODE;
|
|
61
44
|
return;
|
|
62
45
|
}
|
|
@@ -75,13 +58,16 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
|
75
58
|
|
|
76
59
|
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
|
77
60
|
|
|
61
|
+
const int sequence = blockIdx.z / ne02;
|
|
62
|
+
const int head = blockIdx.z - sequence*ne02;
|
|
78
63
|
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
|
79
|
-
Q += nb02*
|
|
80
|
-
K += nb12*(
|
|
81
|
-
V += nb22*(
|
|
82
|
-
const half * maskh = (const half *) mask + ne11*ic0;
|
|
64
|
+
Q += nb03*sequence + nb02* head + nb01*ic0;
|
|
65
|
+
K += nb13*sequence + nb12*(head / gqa_ratio);
|
|
66
|
+
V += nb23*sequence + nb22*(head / gqa_ratio);
|
|
83
67
|
|
|
84
|
-
const
|
|
68
|
+
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
|
69
|
+
|
|
70
|
+
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
|
85
71
|
|
|
86
72
|
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
|
87
73
|
constexpr int nwarps = D / WARP_SIZE;
|
|
@@ -191,13 +177,16 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
|
191
177
|
|
|
192
178
|
float VKQ[ncols] = {0.0f};
|
|
193
179
|
|
|
180
|
+
K += blockIdx.y*D * nb11;
|
|
181
|
+
V += blockIdx.y*D * nb21;
|
|
182
|
+
maskh += blockIdx.y*D;
|
|
194
183
|
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
|
|
195
184
|
// Calculate KQ tile and keep track of new maximum KQ values:
|
|
196
185
|
|
|
197
186
|
if (mask) {
|
|
198
187
|
#pragma unroll
|
|
199
188
|
for (int j = 0; j < ncols; ++j) {
|
|
200
|
-
maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 +
|
|
189
|
+
maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
|
|
201
190
|
}
|
|
202
191
|
|
|
203
192
|
__syncthreads();
|
|
@@ -239,7 +228,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
|
239
228
|
|
|
240
229
|
#pragma unroll
|
|
241
230
|
for (int j = 0; j < ncols; ++j) {
|
|
242
|
-
float sum = vec_dot_KQ(K +
|
|
231
|
+
float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
|
|
243
232
|
sum = warp_reduce_sum(sum);
|
|
244
233
|
|
|
245
234
|
if (use_logit_softcap) {
|
|
@@ -290,13 +279,17 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
|
290
279
|
break;
|
|
291
280
|
}
|
|
292
281
|
|
|
293
|
-
const float V_ki = dequantize_1_v(V +
|
|
282
|
+
const float V_ki = dequantize_1_v(V + k*nb21, tid);
|
|
294
283
|
#pragma unroll
|
|
295
284
|
for (int j = 0; j < ncols; ++j) {
|
|
296
285
|
VKQ[j] += V_ki*KQ[j*D + k];
|
|
297
286
|
}
|
|
298
287
|
}
|
|
299
288
|
|
|
289
|
+
K += gridDim.y*D * nb11;
|
|
290
|
+
V += gridDim.y*D * nb21;
|
|
291
|
+
maskh += gridDim.y*D;
|
|
292
|
+
|
|
300
293
|
__syncthreads();
|
|
301
294
|
}
|
|
302
295
|
|
|
@@ -323,24 +316,24 @@ static __global__ void flash_attn_vec_ext_f32(
|
|
|
323
316
|
if (gridDim.y == 1) {
|
|
324
317
|
dst_val /= kqsum[j_VKQ];
|
|
325
318
|
}
|
|
326
|
-
|
|
327
|
-
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
|
|
319
|
+
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
|
|
328
320
|
}
|
|
329
321
|
|
|
330
322
|
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
|
331
|
-
dst_meta[((ic0 + tid)*
|
|
323
|
+
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
|
332
324
|
}
|
|
333
325
|
#else
|
|
334
326
|
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
|
335
327
|
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
|
336
328
|
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
|
337
|
-
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
338
|
-
GGML_UNUSED(
|
|
339
|
-
GGML_UNUSED(
|
|
340
|
-
GGML_UNUSED(
|
|
341
|
-
GGML_UNUSED(
|
|
342
|
-
GGML_UNUSED(
|
|
343
|
-
GGML_UNUSED(
|
|
329
|
+
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
330
|
+
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
|
331
|
+
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
|
332
|
+
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
|
333
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
|
|
334
|
+
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
|
335
|
+
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
|
336
|
+
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
|
344
337
|
NO_DEVICE_CODE;
|
|
345
338
|
#endif // FLASH_ATTN_AVAILABLE
|
|
346
339
|
}
|
|
@@ -37,29 +37,13 @@ static __global__ void flash_attn_ext_f16(
|
|
|
37
37
|
const float m1,
|
|
38
38
|
const uint32_t n_head_log2,
|
|
39
39
|
const float logit_softcap,
|
|
40
|
-
const
|
|
41
|
-
|
|
42
|
-
const
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
const int ne13,
|
|
48
|
-
const int ne31,
|
|
49
|
-
const int nb31,
|
|
50
|
-
const int nb01,
|
|
51
|
-
const int nb02,
|
|
52
|
-
const int nb03,
|
|
53
|
-
const int nb11,
|
|
54
|
-
const int nb12,
|
|
55
|
-
const int nb13,
|
|
56
|
-
const int nb21,
|
|
57
|
-
const int nb22,
|
|
58
|
-
const int nb23,
|
|
59
|
-
const int ne0,
|
|
60
|
-
const int ne1,
|
|
61
|
-
const int ne2,
|
|
62
|
-
const int ne3) {
|
|
40
|
+
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
|
41
|
+
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
|
42
|
+
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
|
43
|
+
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
|
44
|
+
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
|
45
|
+
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
|
46
|
+
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
|
63
47
|
#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
|
|
64
48
|
// Skip unused kernel variants for faster compilation:
|
|
65
49
|
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
|
@@ -93,17 +77,19 @@ static __global__ void flash_attn_ext_f16(
|
|
|
93
77
|
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
|
|
94
78
|
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
|
|
95
79
|
|
|
80
|
+
const int sequence = blockIdx.z / ne02;
|
|
81
|
+
const int head = blockIdx.z - sequence*ne02;
|
|
96
82
|
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
|
97
|
-
const float * Q_f = (const float *) (Q + nb02*
|
|
98
|
-
const half * K_h = (const half *) (K + nb12*(
|
|
99
|
-
const half * V_h = (const half *) (V + nb12*(
|
|
100
|
-
const half * maskh = (const half *)
|
|
101
|
-
const half2 * mask2 = (const half2 *)
|
|
83
|
+
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
|
84
|
+
const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
|
85
|
+
const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
|
86
|
+
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
|
87
|
+
const half2 * mask2 = (const half2 *) maskh;
|
|
102
88
|
|
|
103
89
|
const int stride_Q = nb01 / sizeof(float);
|
|
104
90
|
const int stride_KV = nb11 / sizeof(half);
|
|
105
91
|
|
|
106
|
-
const float slopef = get_alibi_slope(max_bias,
|
|
92
|
+
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
|
107
93
|
const half slopeh = __float2half(slopef);
|
|
108
94
|
const half2 slope2 = make_half2(slopef, slopef);
|
|
109
95
|
|
|
@@ -191,7 +177,7 @@ static __global__ void flash_attn_ext_f16(
|
|
|
191
177
|
#pragma unroll
|
|
192
178
|
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
|
|
193
179
|
frag_a_K K_a;
|
|
194
|
-
wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
|
180
|
+
wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
|
195
181
|
#pragma unroll
|
|
196
182
|
for (int j = 0; j < ncols/frag_n; ++j) {
|
|
197
183
|
wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
|
|
@@ -338,7 +324,7 @@ static __global__ void flash_attn_ext_f16(
|
|
|
338
324
|
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
|
339
325
|
|
|
340
326
|
frag_a_V v_a;
|
|
341
|
-
wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
|
327
|
+
wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
|
342
328
|
#pragma unroll
|
|
343
329
|
for (int j = 0; j < ncols/frag_n; ++j) {
|
|
344
330
|
wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
|
|
@@ -398,7 +384,6 @@ static __global__ void flash_attn_ext_f16(
|
|
|
398
384
|
if (ic0 + j_VKQ >= ne01) {
|
|
399
385
|
return;
|
|
400
386
|
}
|
|
401
|
-
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
|
402
387
|
|
|
403
388
|
float KQ_rowsum_j;
|
|
404
389
|
if (std::is_same<KQ_acc_t, float>::value) {
|
|
@@ -407,6 +392,8 @@ static __global__ void flash_attn_ext_f16(
|
|
|
407
392
|
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
|
|
408
393
|
}
|
|
409
394
|
|
|
395
|
+
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
|
396
|
+
|
|
410
397
|
#pragma unroll
|
|
411
398
|
for (int i0 = 0; i0 < D; i0 += warp_size) {
|
|
412
399
|
const int i = i0 + threadIdx.x;
|
|
@@ -417,7 +404,7 @@ static __global__ void flash_attn_ext_f16(
|
|
|
417
404
|
if (gridDim.y == 1) {
|
|
418
405
|
dst_val /= KQ_rowsum_j;
|
|
419
406
|
}
|
|
420
|
-
dst[
|
|
407
|
+
dst[j_dst_unrolled*D + i] = dst_val;
|
|
421
408
|
}
|
|
422
409
|
|
|
423
410
|
if (gridDim.y == 1 || threadIdx.x != 0) {
|
|
@@ -431,7 +418,7 @@ static __global__ void flash_attn_ext_f16(
|
|
|
431
418
|
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
|
|
432
419
|
}
|
|
433
420
|
dst_meta_val.y = KQ_rowsum_j;
|
|
434
|
-
dst_meta[
|
|
421
|
+
dst_meta[j_dst_unrolled] = dst_meta_val;
|
|
435
422
|
}
|
|
436
423
|
#else
|
|
437
424
|
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
|
@@ -440,10 +427,10 @@ static __global__ void flash_attn_ext_f16(
|
|
|
440
427
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
441
428
|
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
|
442
429
|
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
|
443
|
-
GGML_UNUSED(ne31); GGML_UNUSED(
|
|
430
|
+
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31);
|
|
431
|
+
GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
444
432
|
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
|
445
433
|
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
|
446
|
-
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
|
447
434
|
NO_DEVICE_CODE;
|
|
448
435
|
#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
|
|
449
436
|
}
|
|
@@ -280,22 +280,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|
|
280
280
|
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
|
|
281
281
|
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
|
282
282
|
|
|
283
|
-
if (GGML_CUDA_CC_IS_AMD(cc)) {
|
|
284
283
|
#if defined(GGML_HIP_ROCWMMA_FATTN)
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
return;
|
|
288
|
-
}
|
|
289
|
-
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
|
|
290
|
-
|
|
291
|
-
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
|
292
|
-
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
|
293
|
-
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
|
294
|
-
} else {
|
|
295
|
-
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
|
296
|
-
}
|
|
284
|
+
if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
|
|
285
|
+
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
|
297
286
|
return;
|
|
298
287
|
}
|
|
288
|
+
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
|
|
299
289
|
|
|
300
290
|
if (!fast_fp16_available(cc)) {
|
|
301
291
|
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
|
@@ -168,6 +168,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
|
|
|
168
168
|
get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
|
|
169
169
|
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
|
170
170
|
break;
|
|
171
|
+
case GGML_TYPE_I32:
|
|
172
|
+
get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d,
|
|
173
|
+
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
|
174
|
+
break;
|
|
171
175
|
case GGML_TYPE_BF16:
|
|
172
176
|
get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
|
|
173
177
|
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
|
@@ -210,6 +214,10 @@ void get_rows_cuda(
|
|
|
210
214
|
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
|
|
211
215
|
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
|
212
216
|
break;
|
|
217
|
+
case GGML_TYPE_I32:
|
|
218
|
+
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d,
|
|
219
|
+
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
|
220
|
+
break;
|
|
213
221
|
case GGML_TYPE_F16:
|
|
214
222
|
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
|
|
215
223
|
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|