@novastera-oss/llamarn 0.2.9 → 0.3.1
This diff shows the changes between publicly available versions of this package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/proguard-rules.pro +12 -0
- package/android/src/main/cpp/include/llama.h +15 -47
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/CMakePresets.json +11 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -0
- package/cpp/llama.cpp/README.md +8 -8
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +62 -1
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +22 -6
- package/cpp/llama.cpp/common/common.h +22 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
- package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
- package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
- package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
- package/cpp/llama.cpp/include/llama.h +15 -47
- package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
- package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
- package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
- package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
- package/cpp/llama.cpp/src/llama-arch.h +23 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
- package/cpp/llama.cpp/src/llama-batch.h +31 -18
- package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +180 -106
- package/cpp/llama.cpp/src/llama-context.h +26 -16
- package/cpp/llama.cpp/src/llama-cparams.h +3 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
- package/cpp/llama.cpp/src/llama-graph.h +184 -122
- package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
- package/cpp/llama.cpp/src/llama-hparams.h +13 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
- package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +3 -0
- package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
- package/cpp/llama.cpp/src/llama-model.h +21 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
- package/cpp/llama.cpp/src/llama-vocab.h +43 -0
- package/cpp/llama.cpp/src/unicode.cpp +207 -0
- package/cpp/llama.cpp/src/unicode.h +2 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +22 -4
- package/ios/include/llama.h +15 -47
- package/ios/libs/llama.xcframework/Info.plist +13 -13
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +4 -4
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -123,13 +123,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|
|
123
123
|
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
|
|
124
124
|
|
|
125
125
|
if (nbytes_shared <= smpbo) {
|
|
126
|
-
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
|
127
|
-
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
128
|
-
if (!shared_memory_limit_raised[id]) {
|
|
129
|
-
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
|
|
130
|
-
shared_memory_limit_raised[id] = true;
|
|
131
|
-
}
|
|
132
|
-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
|
126
|
+
CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo);
|
|
133
127
|
cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
|
134
128
|
} else {
|
|
135
129
|
cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
|
@@ -175,13 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
|
|
|
175
169
|
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
|
176
170
|
|
|
177
171
|
if (nbytes_shared <= smpbo) {
|
|
178
|
-
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
|
179
|
-
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
180
|
-
if (!shared_memory_limit_raised[id]) {
|
|
181
|
-
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
|
|
182
|
-
shared_memory_limit_raised[id] = true;
|
|
183
|
-
}
|
|
184
|
-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
|
172
|
+
CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);
|
|
185
173
|
cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
|
|
186
174
|
} else {
|
|
187
175
|
cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
|
|
@@ -23,29 +23,13 @@ typedef void (* fattn_kernel_t)(
|
|
|
23
23
|
const float m1,
|
|
24
24
|
const uint32_t n_head_log2,
|
|
25
25
|
const float logit_softcap,
|
|
26
|
-
const
|
|
27
|
-
|
|
28
|
-
const
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
const int ne13,
|
|
34
|
-
const int ne31,
|
|
35
|
-
const int nb31,
|
|
36
|
-
const int nb01,
|
|
37
|
-
const int nb02,
|
|
38
|
-
const int nb03,
|
|
39
|
-
const int nb11,
|
|
40
|
-
const int nb12,
|
|
41
|
-
const int nb13,
|
|
42
|
-
const int nb21,
|
|
43
|
-
const int nb22,
|
|
44
|
-
const int nb23,
|
|
45
|
-
const int ne0,
|
|
46
|
-
const int ne1,
|
|
47
|
-
const int ne2,
|
|
48
|
-
const int ne3);
|
|
26
|
+
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
|
27
|
+
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
|
28
|
+
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
|
29
|
+
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
|
30
|
+
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
|
31
|
+
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
|
32
|
+
const int32_t nb31, const int32_t nb32, const int64_t nb33);
|
|
49
33
|
|
|
50
34
|
typedef half (*vec_dot_KQ_f16_t)(
|
|
51
35
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
|
|
@@ -519,7 +503,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
|
|
|
519
503
|
template<int D, int ncols1, int ncols2> // D == head size
|
|
520
504
|
__launch_bounds__(D, 1)
|
|
521
505
|
static __global__ void flash_attn_stream_k_fixup(
|
|
522
|
-
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
|
|
506
|
+
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
|
|
523
507
|
constexpr int ncols = ncols1*ncols2;
|
|
524
508
|
|
|
525
509
|
const int bidx0 = blockIdx.x;
|
|
@@ -533,8 +517,8 @@ static __global__ void flash_attn_stream_k_fixup(
|
|
|
533
517
|
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
|
534
518
|
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
|
535
519
|
|
|
536
|
-
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
|
537
|
-
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
|
520
|
+
const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
521
|
+
const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
538
522
|
|
|
539
523
|
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
|
540
524
|
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
|
@@ -543,14 +527,15 @@ static __global__ void flash_attn_stream_k_fixup(
|
|
|
543
527
|
return;
|
|
544
528
|
}
|
|
545
529
|
|
|
546
|
-
const int
|
|
547
|
-
const int
|
|
530
|
+
const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
|
|
531
|
+
const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
|
|
532
|
+
const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
|
|
548
533
|
|
|
549
534
|
if (jt*ncols1 + j >= ne01) {
|
|
550
535
|
return;
|
|
551
536
|
}
|
|
552
537
|
|
|
553
|
-
dst += jt*ne02*(ncols1*D) +
|
|
538
|
+
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
|
|
554
539
|
|
|
555
540
|
// Load the partial result that needs a fixup:
|
|
556
541
|
float dst_val = 0.0f;
|
|
@@ -569,7 +554,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
|
|
569
554
|
int bidx = bidx0 - 1;
|
|
570
555
|
int kbc_stop = kbc0;
|
|
571
556
|
while(true) {
|
|
572
|
-
const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
|
557
|
+
const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
573
558
|
if (kbc == kbc_stop) { // Did not have any data.
|
|
574
559
|
bidx--;
|
|
575
560
|
kbc_stop = kbc;
|
|
@@ -615,16 +600,31 @@ static __global__ void flash_attn_combine_results(
|
|
|
615
600
|
const float2 * __restrict__ VKQ_meta,
|
|
616
601
|
float * __restrict__ dst,
|
|
617
602
|
const int parallel_blocks) {
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
603
|
+
// Dimension 0: threadIdx.x
|
|
604
|
+
// Dimension 1: blockIdx.x
|
|
605
|
+
// Dimension 2: blockIdx.y
|
|
606
|
+
// Dimension 3: blockIdx.z
|
|
607
|
+
// Memory layout is permuted with [0, 2, 1, 3]
|
|
608
|
+
|
|
609
|
+
const int ne01 = gridDim.x;
|
|
610
|
+
const int ne02 = gridDim.y;
|
|
611
|
+
|
|
612
|
+
const int col = blockIdx.x;
|
|
613
|
+
const int head = blockIdx.y;
|
|
614
|
+
const int sequence = blockIdx.z;
|
|
615
|
+
|
|
616
|
+
const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
|
|
617
|
+
|
|
618
|
+
VKQ_parts += j_dst_unrolled * parallel_blocks*D;
|
|
619
|
+
VKQ_meta += j_dst_unrolled * parallel_blocks;
|
|
620
|
+
dst += j_dst_unrolled * D;
|
|
621
621
|
|
|
622
622
|
const int tid = threadIdx.x;
|
|
623
623
|
__builtin_assume(tid < D);
|
|
624
624
|
|
|
625
625
|
extern __shared__ float2 meta[];
|
|
626
626
|
for (int i = tid; i < 2*parallel_blocks; i += D) {
|
|
627
|
-
((float *) meta)[i] = ((const float *)VKQ_meta) [
|
|
627
|
+
((float *) meta)[i] = ((const float *)VKQ_meta) [i];
|
|
628
628
|
}
|
|
629
629
|
|
|
630
630
|
__syncthreads();
|
|
@@ -642,11 +642,11 @@ static __global__ void flash_attn_combine_results(
|
|
|
642
642
|
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
|
643
643
|
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
|
|
644
644
|
|
|
645
|
-
VKQ_numerator += KQ_max_scale * VKQ_parts[l*
|
|
645
|
+
VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
|
|
646
646
|
VKQ_denominator += KQ_max_scale * meta[l].y;
|
|
647
647
|
}
|
|
648
648
|
|
|
649
|
-
dst[
|
|
649
|
+
dst[tid] = VKQ_numerator / VKQ_denominator;
|
|
650
650
|
}
|
|
651
651
|
|
|
652
652
|
[[noreturn]]
|
|
@@ -703,8 +703,6 @@ void launch_fattn(
|
|
|
703
703
|
|
|
704
704
|
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
|
705
705
|
|
|
706
|
-
GGML_ASSERT(Q->ne[3] == 1);
|
|
707
|
-
|
|
708
706
|
ggml_cuda_pool & pool = ctx.pool();
|
|
709
707
|
cudaStream_t main_stream = ctx.stream();
|
|
710
708
|
const int id = ggml_cuda_get_device();
|
|
@@ -727,33 +725,58 @@ void launch_fattn(
|
|
|
727
725
|
size_t nb23 = V ? V->nb[3] : nb13;
|
|
728
726
|
|
|
729
727
|
if (need_f16_K && K->type != GGML_TYPE_F16) {
|
|
730
|
-
GGML_ASSERT(ggml_is_contiguously_allocated(K));
|
|
731
|
-
K_f16.alloc(ggml_nelements(K));
|
|
732
|
-
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
|
|
733
|
-
to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
|
|
734
|
-
K_data = (char *) K_f16.ptr;
|
|
735
|
-
|
|
736
728
|
const size_t bs = ggml_blck_size(K->type);
|
|
737
729
|
const size_t ts = ggml_type_size(K->type);
|
|
738
730
|
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
731
|
+
K_f16.alloc(ggml_nelements(K));
|
|
732
|
+
if (ggml_is_contiguously_allocated(K)) {
|
|
733
|
+
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
|
|
734
|
+
to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
|
|
735
|
+
|
|
736
|
+
nb11 = nb11*bs*sizeof(half)/ts;
|
|
737
|
+
nb12 = nb12*bs*sizeof(half)/ts;
|
|
738
|
+
nb13 = nb13*bs*sizeof(half)/ts;
|
|
739
|
+
} else {
|
|
740
|
+
GGML_ASSERT(K->nb[0] == ts);
|
|
741
|
+
to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);
|
|
742
|
+
const int64_t s01 = nb11 / ts;
|
|
743
|
+
const int64_t s02 = nb12 / ts;
|
|
744
|
+
const int64_t s03 = nb13 / ts;
|
|
745
|
+
to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
|
|
746
|
+
|
|
747
|
+
nb11 = K->ne[0] * sizeof(half);
|
|
748
|
+
nb12 = K->ne[1] * nb11;
|
|
749
|
+
nb13 = K->ne[2] * nb12;
|
|
750
|
+
}
|
|
751
|
+
K_data = (char *) K_f16.ptr;
|
|
742
752
|
}
|
|
743
753
|
|
|
744
754
|
if (V && need_f16_V && V->type != GGML_TYPE_F16) {
|
|
745
|
-
GGML_ASSERT(ggml_is_contiguously_allocated(V));
|
|
746
|
-
V_f16.alloc(ggml_nelements(V));
|
|
747
|
-
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
|
748
|
-
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
|
|
749
|
-
V_data = (char *) V_f16.ptr;
|
|
750
|
-
|
|
751
755
|
const size_t bs = ggml_blck_size(V->type);
|
|
752
756
|
const size_t ts = ggml_type_size(V->type);
|
|
753
757
|
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
758
|
+
V_f16.alloc(ggml_nelements(V));
|
|
759
|
+
if (ggml_is_contiguously_allocated(V)) {
|
|
760
|
+
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
|
761
|
+
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
|
|
762
|
+
V_data = (char *) V_f16.ptr;
|
|
763
|
+
|
|
764
|
+
nb21 = nb21*bs*sizeof(half)/ts;
|
|
765
|
+
nb22 = nb22*bs*sizeof(half)/ts;
|
|
766
|
+
nb23 = nb23*bs*sizeof(half)/ts;
|
|
767
|
+
} else {
|
|
768
|
+
GGML_ASSERT(V->nb[0] == ts);
|
|
769
|
+
to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
|
|
770
|
+
const int64_t s01 = nb21 / ts;
|
|
771
|
+
const int64_t s02 = nb22 / ts;
|
|
772
|
+
const int64_t s03 = nb23 / ts;
|
|
773
|
+
to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
|
|
774
|
+
|
|
775
|
+
nb21 = V->ne[0] * sizeof(half);
|
|
776
|
+
nb22 = V->ne[1] * nb21;
|
|
777
|
+
nb23 = V->ne[2] * nb22;
|
|
778
|
+
}
|
|
779
|
+
V_data = (char *) V_f16.ptr;
|
|
757
780
|
}
|
|
758
781
|
|
|
759
782
|
int parallel_blocks = 1;
|
|
@@ -849,13 +872,11 @@ void launch_fattn(
|
|
|
849
872
|
mask ? ((const char *) mask->data) : nullptr,
|
|
850
873
|
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
|
851
874
|
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
|
852
|
-
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
|
853
|
-
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
|
854
|
-
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
|
855
|
-
Q->nb[1], Q->nb[2], Q->nb[3],
|
|
856
|
-
nb11, nb12, nb13,
|
|
875
|
+
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
|
|
876
|
+
K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
|
|
857
877
|
nb21, nb22, nb23,
|
|
858
|
-
|
|
878
|
+
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
|
|
879
|
+
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
|
|
859
880
|
);
|
|
860
881
|
CUDA_CHECK(cudaGetLastError());
|
|
861
882
|
|
|
@@ -866,11 +887,11 @@ void launch_fattn(
|
|
|
866
887
|
|
|
867
888
|
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
|
|
868
889
|
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
|
869
|
-
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
|
890
|
+
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
|
|
870
891
|
}
|
|
871
892
|
} else if (parallel_blocks > 1) {
|
|
872
893
|
const dim3 block_dim_combine(DV, 1, 1);
|
|
873
|
-
const dim3 blocks_num_combine(Q->ne[1],
|
|
894
|
+
const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
|
|
874
895
|
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
|
|
875
896
|
|
|
876
897
|
flash_attn_combine_results<DV>
|
|
@@ -408,7 +408,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
408
408
|
const int stride_K,
|
|
409
409
|
const int stride_V,
|
|
410
410
|
const int stride_mask,
|
|
411
|
-
const int jt,
|
|
412
411
|
half2 * const __restrict__ tile_Q,
|
|
413
412
|
half2 * const __restrict__ tile_K,
|
|
414
413
|
half2 * const __restrict__ tile_V,
|
|
@@ -455,7 +454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
455
454
|
cp_async_wait_all();
|
|
456
455
|
__syncthreads();
|
|
457
456
|
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
|
458
|
-
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
|
|
457
|
+
(V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
|
|
459
458
|
} else {
|
|
460
459
|
constexpr bool use_cp_async = nstages == 1;
|
|
461
460
|
if (ncols2 > 1 || mask_h2) {
|
|
@@ -471,7 +470,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
471
470
|
if (nstages <= 1) {
|
|
472
471
|
constexpr bool use_cp_async = nstages == 1;
|
|
473
472
|
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
|
474
|
-
(K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
|
|
473
|
+
(K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
|
|
475
474
|
if (use_cp_async) {
|
|
476
475
|
cp_async_wait_all();
|
|
477
476
|
}
|
|
@@ -715,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
715
714
|
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
|
|
716
715
|
}
|
|
717
716
|
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
|
718
|
-
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
|
|
717
|
+
(K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
|
|
719
718
|
}
|
|
720
719
|
}
|
|
721
720
|
|
|
@@ -732,7 +731,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
732
731
|
if (nstages <= 1 && i0_start < reusable_cutoff) {
|
|
733
732
|
constexpr bool use_cp_async = nstages == 1;
|
|
734
733
|
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
|
735
|
-
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
|
|
734
|
+
(V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
|
|
736
735
|
if (use_cp_async) {
|
|
737
736
|
cp_async_wait_all();
|
|
738
737
|
}
|
|
@@ -771,8 +770,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
771
770
|
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
|
|
772
771
|
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
|
|
773
772
|
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
|
|
774
|
-
GGML_UNUSED(stride_mask); GGML_UNUSED(
|
|
775
|
-
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
|
|
773
|
+
GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K);
|
|
776
774
|
GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
|
|
777
775
|
GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
|
|
778
776
|
GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
|
|
@@ -920,7 +918,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
920
918
|
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
|
|
921
919
|
}
|
|
922
920
|
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
|
923
|
-
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
|
|
921
|
+
(K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
|
|
924
922
|
}
|
|
925
923
|
|
|
926
924
|
// Iterate over ne11 == previous tokens:
|
|
@@ -928,13 +926,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
928
926
|
constexpr bool last_iter = false;
|
|
929
927
|
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
|
930
928
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
931
|
-
ne01, ne02, stride_K, stride_V, stride_mask,
|
|
929
|
+
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
|
932
930
|
}
|
|
933
931
|
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
|
934
932
|
constexpr bool last_iter = true;
|
|
935
933
|
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
|
936
934
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
937
|
-
ne01, ne02, stride_K, stride_V, stride_mask,
|
|
935
|
+
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
|
|
938
936
|
}
|
|
939
937
|
|
|
940
938
|
// With multi-stage loading there is no __syncthreads at the end of the iter,
|
|
@@ -1214,29 +1212,13 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1214
1212
|
const float m1,
|
|
1215
1213
|
const uint32_t n_head_log2,
|
|
1216
1214
|
const float logit_softcap,
|
|
1217
|
-
const
|
|
1218
|
-
|
|
1219
|
-
const
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
1224
|
-
const int ne13,
|
|
1225
|
-
const int ne31,
|
|
1226
|
-
const int nb31,
|
|
1227
|
-
const int nb01,
|
|
1228
|
-
const int nb02,
|
|
1229
|
-
const int nb03,
|
|
1230
|
-
const int nb11,
|
|
1231
|
-
const int nb12,
|
|
1232
|
-
const int nb13,
|
|
1233
|
-
const int nb21,
|
|
1234
|
-
const int nb22,
|
|
1235
|
-
const int nb23,
|
|
1236
|
-
const int ne0,
|
|
1237
|
-
const int ne1,
|
|
1238
|
-
const int ne2,
|
|
1239
|
-
const int ne3) {
|
|
1215
|
+
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
|
1216
|
+
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
|
1217
|
+
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
|
1218
|
+
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
|
1219
|
+
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
|
1220
|
+
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
|
1221
|
+
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
|
1240
1222
|
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
|
1241
1223
|
|
|
1242
1224
|
// Skip unused kernel variants for faster compilation:
|
|
@@ -1272,8 +1254,8 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1272
1254
|
constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
|
|
1273
1255
|
|
|
1274
1256
|
// kbc == k block continuous, current index in continuous ijk space.
|
|
1275
|
-
int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
|
1276
|
-
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
|
1257
|
+
int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
1258
|
+
const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
|
1277
1259
|
|
|
1278
1260
|
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
|
1279
1261
|
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
|
@@ -1283,17 +1265,19 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1283
1265
|
int kb0_start = kbc % iter_k;
|
|
1284
1266
|
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
|
|
1285
1267
|
while (kbc < kbc_stop && kb0_stop == iter_k) {
|
|
1286
|
-
const int
|
|
1287
|
-
const int
|
|
1268
|
+
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
|
|
1269
|
+
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
|
|
1270
|
+
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
|
|
1288
1271
|
|
|
1289
|
-
const float2 * Q_f2 = (const float2 *) (Q +
|
|
1290
|
-
const half2 * K_h2 = (const half2 *) (K + nb12*(
|
|
1291
|
-
const half2 * mask_h2 = ncols2
|
|
1292
|
-
|
|
1272
|
+
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
|
|
1273
|
+
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
|
|
1274
|
+
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
|
|
1275
|
+
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
|
|
1276
|
+
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
|
|
1293
1277
|
|
|
1294
|
-
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(
|
|
1278
|
+
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
|
|
1295
1279
|
|
|
1296
|
-
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias,
|
|
1280
|
+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
|
|
1297
1281
|
|
|
1298
1282
|
const int kb0_start_kernel = kb0_start * kb_niter;
|
|
1299
1283
|
const int kb0_stop_kernel = kb0_stop * kb_niter;
|
|
@@ -1322,17 +1306,19 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1322
1306
|
return;
|
|
1323
1307
|
}
|
|
1324
1308
|
|
|
1325
|
-
const int
|
|
1326
|
-
const int
|
|
1309
|
+
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
|
|
1310
|
+
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
|
|
1311
|
+
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
|
|
1327
1312
|
|
|
1328
|
-
const float2 * Q_f2 = (const float2 *) (Q +
|
|
1329
|
-
const half2 * K_h2 = (const half2 *) (K + nb12*(
|
|
1330
|
-
const half2 * mask_h2 = ncols2
|
|
1331
|
-
|
|
1313
|
+
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
|
|
1314
|
+
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
|
|
1315
|
+
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
|
|
1316
|
+
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
|
|
1317
|
+
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
|
|
1332
1318
|
|
|
1333
|
-
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(
|
|
1319
|
+
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
|
|
1334
1320
|
|
|
1335
|
-
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias,
|
|
1321
|
+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
|
|
1336
1322
|
|
|
1337
1323
|
const int kb0_start_kernel = kb0_start * kb_niter;
|
|
1338
1324
|
const int kb0_stop_kernel = kb0_stop * kb_niter;
|
|
@@ -1344,15 +1330,16 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1344
1330
|
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
|
1345
1331
|
#else
|
|
1346
1332
|
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
|
1347
|
-
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
|
|
1348
|
-
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
|
1349
|
-
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
1350
|
-
GGML_UNUSED(
|
|
1351
|
-
GGML_UNUSED(
|
|
1352
|
-
GGML_UNUSED(
|
|
1353
|
-
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
|
1354
|
-
GGML_UNUSED(
|
|
1355
|
-
GGML_UNUSED(
|
|
1333
|
+
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
|
|
1334
|
+
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
|
1335
|
+
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
1336
|
+
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
|
1337
|
+
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
|
1338
|
+
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
|
1339
|
+
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
|
1340
|
+
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
|
1341
|
+
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
|
1342
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
|
|
1356
1343
|
NO_DEVICE_CODE;
|
|
1357
1344
|
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
|
1358
1345
|
}
|
|
@@ -6,7 +6,7 @@
|
|
|
6
6
|
|
|
7
7
|
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
|
|
8
8
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
9
|
-
__launch_bounds__(nwarps*WARP_SIZE,
|
|
9
|
+
__launch_bounds__(nwarps*WARP_SIZE, 2)
|
|
10
10
|
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
11
11
|
static __global__ void flash_attn_tile_ext_f16(
|
|
12
12
|
const char * __restrict__ Q,
|
|
@@ -21,29 +21,13 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
21
21
|
const float m1,
|
|
22
22
|
const uint32_t n_head_log2,
|
|
23
23
|
const float logit_softcap,
|
|
24
|
-
const
|
|
25
|
-
|
|
26
|
-
const
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
const int ne13,
|
|
32
|
-
const int ne31,
|
|
33
|
-
const int nb31,
|
|
34
|
-
const int nb01,
|
|
35
|
-
const int nb02,
|
|
36
|
-
const int nb03,
|
|
37
|
-
const int nb11,
|
|
38
|
-
const int nb12,
|
|
39
|
-
const int nb13,
|
|
40
|
-
const int nb21,
|
|
41
|
-
const int nb22,
|
|
42
|
-
const int nb23,
|
|
43
|
-
const int ne0,
|
|
44
|
-
const int ne1,
|
|
45
|
-
const int ne2,
|
|
46
|
-
const int ne3) {
|
|
24
|
+
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
|
25
|
+
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
|
26
|
+
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
|
27
|
+
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
|
28
|
+
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
|
29
|
+
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
|
30
|
+
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
|
47
31
|
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
|
48
32
|
|
|
49
33
|
// Skip unused kernel variants for faster compilation:
|
|
@@ -60,15 +44,17 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
60
44
|
|
|
61
45
|
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
|
62
46
|
|
|
47
|
+
const int sequence = blockIdx.z / ne02;
|
|
48
|
+
const int head = blockIdx.z - sequence*ne02;
|
|
63
49
|
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
|
64
|
-
const float2 * Q_f2 = (const float2 *) (Q + nb02*
|
|
65
|
-
const half2 * K_h2 = (const half2 *) (K + nb12*(
|
|
66
|
-
const half2 * V_h2 = (const half2 *) (V + nb12*(
|
|
67
|
-
const half * maskh = (const half *)
|
|
50
|
+
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
|
51
|
+
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
|
52
|
+
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
|
53
|
+
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
|
68
54
|
|
|
69
55
|
const int stride_KV2 = nb11 / sizeof(half2);
|
|
70
56
|
|
|
71
|
-
const float slopef = get_alibi_slope(max_bias,
|
|
57
|
+
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
|
72
58
|
const half slopeh = __float2half(slopef);
|
|
73
59
|
|
|
74
60
|
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
|
@@ -121,7 +107,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
121
107
|
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
|
|
122
108
|
const int k_KQ = k_KQ_0 + threadIdx.x;
|
|
123
109
|
|
|
124
|
-
KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
|
|
110
|
+
KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
|
|
125
111
|
}
|
|
126
112
|
}
|
|
127
113
|
|
|
@@ -215,7 +201,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
215
201
|
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
|
216
202
|
const int i = i0 + threadIdx.x;
|
|
217
203
|
|
|
218
|
-
KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
|
|
204
|
+
KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
|
|
219
205
|
}
|
|
220
206
|
}
|
|
221
207
|
|
|
@@ -253,6 +239,8 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
253
239
|
__syncthreads();
|
|
254
240
|
}
|
|
255
241
|
|
|
242
|
+
float2 * dst2 = (float2 *) dst;
|
|
243
|
+
|
|
256
244
|
#pragma unroll
|
|
257
245
|
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
|
258
246
|
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
|
@@ -264,21 +252,21 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
264
252
|
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
|
|
265
253
|
kqsum_j = warp_reduce_sum((float)kqsum_j);
|
|
266
254
|
|
|
255
|
+
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
|
256
|
+
|
|
267
257
|
#pragma unroll
|
|
268
|
-
for (int i00 = 0; i00 < D; i00 +=
|
|
269
|
-
const int i0 = i00 +
|
|
258
|
+
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
|
|
259
|
+
const int i0 = i00 + threadIdx.x;
|
|
270
260
|
|
|
271
|
-
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/
|
|
261
|
+
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
|
|
272
262
|
if (gridDim.y == 1) {
|
|
273
263
|
dst_val /= __half2half2(kqsum_j);
|
|
274
264
|
}
|
|
275
|
-
|
|
276
|
-
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
|
|
277
|
-
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
|
|
265
|
+
dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
|
|
278
266
|
}
|
|
279
267
|
|
|
280
268
|
if (gridDim.y != 1 && threadIdx.x == 0) {
|
|
281
|
-
dst_meta[
|
|
269
|
+
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
|
282
270
|
}
|
|
283
271
|
}
|
|
284
272
|
#else
|
|
@@ -288,12 +276,11 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
288
276
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
289
277
|
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
|
290
278
|
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
|
291
|
-
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
|
292
|
-
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
279
|
+
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
|
280
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
293
281
|
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
|
294
282
|
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
|
295
|
-
GGML_UNUSED(nb23);
|
|
296
|
-
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
|
283
|
+
GGML_UNUSED(nb23);
|
|
297
284
|
NO_DEVICE_CODE;
|
|
298
285
|
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
|
299
286
|
}
|