@novastera-oss/llamarn 0.2.9 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/README.md +4 -5
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +17 -0
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.h +4 -0
- package/cpp/llama.cpp/convert_hf_to_gguf.py +745 -6
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
- package/cpp/llama.cpp/ggml/CMakeLists.txt +7 -2
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1203 -163
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +185 -79
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +64 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +35 -9
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +167 -39
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +254 -57
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +505 -40
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +60 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +711 -292
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
- package/cpp/llama.cpp/ggml/src/ggml.c +382 -61
- package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +209 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +73 -21
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
- package/cpp/llama.cpp/include/llama.h +0 -40
- package/cpp/llama.cpp/src/llama-arch.cpp +210 -3
- package/cpp/llama.cpp/src/llama-arch.h +18 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +27 -1
- package/cpp/llama.cpp/src/llama-batch.h +8 -1
- package/cpp/llama.cpp/src/llama-chat.cpp +15 -0
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-graph.cpp +119 -184
- package/cpp/llama.cpp/src/llama-graph.h +47 -60
- package/cpp/llama.cpp/src/llama-hparams.cpp +7 -1
- package/cpp/llama.cpp/src/llama-hparams.h +3 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +28 -18
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +4 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +214 -65
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +62 -24
- package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +9 -4
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +20 -10
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +3 -0
- package/cpp/llama.cpp/src/llama-model.cpp +2530 -685
- package/cpp/llama.cpp/src/llama-model.h +18 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +1 -0
- package/cpp/llama.cpp/src/llama-vocab.cpp +13 -2
- package/cpp/llama.cpp/src/llama-vocab.h +41 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +4 -0
- package/ios/include/llama.h +0 -40
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5055 -4886
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3766
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4890
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4861
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3764
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5091 -4922
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4897
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3794
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -40
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -20,6 +20,9 @@
|
|
|
20
20
|
|
|
21
21
|
static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
|
|
22
22
|
|
|
23
|
+
// Work buffer size for im2col operations in CONV2D
|
|
24
|
+
#define GGML_IM2COL_WORK_SIZE (16 * 1024 * 1024)
|
|
25
|
+
|
|
23
26
|
#ifdef __cplusplus
|
|
24
27
|
extern "C" {
|
|
25
28
|
#endif
|
|
@@ -65,6 +68,7 @@ void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struc
|
|
|
65
68
|
void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
66
69
|
void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
67
70
|
void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
71
|
+
void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
68
72
|
void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
69
73
|
void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
70
74
|
void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
@@ -94,6 +98,7 @@ void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, st
|
|
|
94
98
|
void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
95
99
|
void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
96
100
|
void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
101
|
+
void ggml_compute_forward_glu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
97
102
|
void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
98
103
|
void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
99
104
|
void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
@@ -106,6 +111,7 @@ void ggml_compute_forward_custom(const struct ggml_compute_params * params, stru
|
|
|
106
111
|
void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
107
112
|
void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
108
113
|
void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
114
|
+
void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
|
109
115
|
|
|
110
116
|
#ifdef __cplusplus
|
|
111
117
|
}
|
|
@@ -189,7 +189,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
|
|
|
189
189
|
#define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
|
|
190
190
|
#define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
|
|
191
191
|
#define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
|
|
192
|
-
#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c)
|
|
192
|
+
#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
|
|
193
193
|
#define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
|
|
194
194
|
#define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
|
|
195
195
|
#define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
|
|
@@ -37,35 +37,35 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
|
|
|
37
37
|
for (int i = 0; i < np; i += ggml_f32_step) {
|
|
38
38
|
ax1 = GGML_F32_VEC_LOAD(x + i);
|
|
39
39
|
ay1 = GGML_F32_VEC_LOAD(y + i);
|
|
40
|
-
sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
|
|
40
|
+
sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
|
|
41
41
|
|
|
42
42
|
ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
|
|
43
43
|
ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
|
|
44
|
-
sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2);
|
|
44
|
+
sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2);
|
|
45
45
|
|
|
46
46
|
ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
|
|
47
47
|
ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
|
|
48
|
-
sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3);
|
|
48
|
+
sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3);
|
|
49
49
|
|
|
50
50
|
ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
|
|
51
51
|
ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
|
|
52
|
-
sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4);
|
|
52
|
+
sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4);
|
|
53
53
|
|
|
54
54
|
ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
|
|
55
55
|
ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
|
|
56
|
-
sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5);
|
|
56
|
+
sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5);
|
|
57
57
|
|
|
58
58
|
ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
|
|
59
59
|
ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
|
|
60
|
-
sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6);
|
|
60
|
+
sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6);
|
|
61
61
|
|
|
62
62
|
ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
|
|
63
63
|
ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
|
|
64
|
-
sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7);
|
|
64
|
+
sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7);
|
|
65
65
|
|
|
66
66
|
ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
|
|
67
67
|
ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
|
|
68
|
-
sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8);
|
|
68
|
+
sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8);
|
|
69
69
|
}
|
|
70
70
|
// leftovers
|
|
71
71
|
// Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop
|
|
@@ -73,7 +73,7 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
|
|
|
73
73
|
for (int i = np; i < np2; i += ggml_f32_epr) {
|
|
74
74
|
ax1 = GGML_F32_VEC_LOAD(x + i);
|
|
75
75
|
ay1 = GGML_F32_VEC_LOAD(y + i);
|
|
76
|
-
sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
|
|
76
|
+
sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
|
|
77
77
|
}
|
|
78
78
|
// maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
|
|
79
79
|
if (np2 < n) {
|
|
@@ -254,6 +254,30 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
|
|
|
254
254
|
}
|
|
255
255
|
}
|
|
256
256
|
|
|
257
|
+
void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g) {
|
|
258
|
+
int i = 0;
|
|
259
|
+
#if defined(__AVX512F__) && defined(__AVX512DQ__)
|
|
260
|
+
for (; i + 15 < n; i += 16) {
|
|
261
|
+
_mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i)));
|
|
262
|
+
}
|
|
263
|
+
#elif defined(__AVX2__) && defined(__FMA__)
|
|
264
|
+
for (; i + 7 < n; i += 8) {
|
|
265
|
+
_mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i)));
|
|
266
|
+
}
|
|
267
|
+
#elif defined(__SSE2__)
|
|
268
|
+
for (; i + 3 < n; i += 4) {
|
|
269
|
+
_mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
|
|
270
|
+
}
|
|
271
|
+
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
|
272
|
+
for (; i + 3 < n; i += 4) {
|
|
273
|
+
vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
|
|
274
|
+
}
|
|
275
|
+
#endif
|
|
276
|
+
for (; i < n; ++i) {
|
|
277
|
+
y[i] = ggml_silu_f32(x[i]) * g[i];
|
|
278
|
+
}
|
|
279
|
+
}
|
|
280
|
+
|
|
257
281
|
ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
|
|
258
282
|
int i = 0;
|
|
259
283
|
ggml_float sum = 0;
|
|
@@ -163,49 +163,49 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
|
|
|
163
163
|
|
|
164
164
|
ax1 = GGML_F32_VEC_LOAD(x + i);
|
|
165
165
|
ay1 = GGML_F32_VEC_LOAD(y + i);
|
|
166
|
-
ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
|
|
166
|
+
ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
|
|
167
167
|
|
|
168
168
|
GGML_F32_VEC_STORE(y + i, ay1);
|
|
169
169
|
|
|
170
170
|
ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
|
|
171
171
|
ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
|
|
172
|
-
ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2);
|
|
172
|
+
ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);
|
|
173
173
|
|
|
174
174
|
GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
|
|
175
175
|
|
|
176
176
|
ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
|
|
177
177
|
ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
|
|
178
|
-
ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3);
|
|
178
|
+
ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);
|
|
179
179
|
|
|
180
180
|
GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);
|
|
181
181
|
|
|
182
182
|
ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
|
|
183
183
|
ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
|
|
184
|
-
ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4);
|
|
184
|
+
ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);
|
|
185
185
|
|
|
186
186
|
GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);
|
|
187
187
|
|
|
188
188
|
ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
|
|
189
189
|
ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
|
|
190
|
-
ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5);
|
|
190
|
+
ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);
|
|
191
191
|
|
|
192
192
|
GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);
|
|
193
193
|
|
|
194
194
|
ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
|
|
195
195
|
ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
|
|
196
|
-
ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6);
|
|
196
|
+
ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);
|
|
197
197
|
|
|
198
198
|
GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);
|
|
199
199
|
|
|
200
200
|
ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
|
|
201
201
|
ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
|
|
202
|
-
ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7);
|
|
202
|
+
ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);
|
|
203
203
|
|
|
204
204
|
GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);
|
|
205
205
|
|
|
206
206
|
ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
|
|
207
207
|
ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
|
|
208
|
-
ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8);
|
|
208
|
+
ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);
|
|
209
209
|
|
|
210
210
|
GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
|
|
211
211
|
}
|
|
@@ -215,7 +215,7 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
|
|
|
215
215
|
for (int i = np; i < np2; i += ggml_f32_epr) {
|
|
216
216
|
ax1 = GGML_F32_VEC_LOAD(x + i);
|
|
217
217
|
ay1 = GGML_F32_VEC_LOAD(y + i);
|
|
218
|
-
ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
|
|
218
|
+
ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
|
|
219
219
|
|
|
220
220
|
GGML_F32_VEC_STORE(y + i, ay1);
|
|
221
221
|
}
|
|
@@ -351,6 +351,45 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int
|
|
|
351
351
|
#endif
|
|
352
352
|
}
|
|
353
353
|
|
|
354
|
+
inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {
|
|
355
|
+
#if defined(GGML_USE_ACCELERATE)
|
|
356
|
+
vDSP_vsmsa(x, 1, &s, &b, y, 1, n);
|
|
357
|
+
#elif defined(GGML_SIMD)
|
|
358
|
+
#if defined(__ARM_FEATURE_SVE)
|
|
359
|
+
// scalar ; TODO: Write SVE code
|
|
360
|
+
for (int i = 0; i < n; ++i) {
|
|
361
|
+
y[i] = x[i]*s + b;
|
|
362
|
+
}
|
|
363
|
+
#else
|
|
364
|
+
const int np = (n & ~(GGML_F32_STEP - 1));
|
|
365
|
+
|
|
366
|
+
GGML_F32_VEC vs = GGML_F32_VEC_SET1(s);
|
|
367
|
+
GGML_F32_VEC vb = GGML_F32_VEC_SET1(b);
|
|
368
|
+
|
|
369
|
+
GGML_F32_VEC ay[GGML_F32_ARR];
|
|
370
|
+
|
|
371
|
+
for (int i = 0; i < np; i += GGML_F32_STEP) {
|
|
372
|
+
for (int j = 0; j < GGML_F32_ARR; j++) {
|
|
373
|
+
ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
|
|
374
|
+
ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb);
|
|
375
|
+
|
|
376
|
+
GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
|
|
377
|
+
}
|
|
378
|
+
}
|
|
379
|
+
|
|
380
|
+
// leftovers
|
|
381
|
+
for (int i = np; i < n; ++i) {
|
|
382
|
+
y[i] = x[i]*s + b;
|
|
383
|
+
}
|
|
384
|
+
#endif
|
|
385
|
+
#else
|
|
386
|
+
// scalar
|
|
387
|
+
for (int i = 0; i < n; ++i) {
|
|
388
|
+
y[i] = x[i]*s + b;
|
|
389
|
+
}
|
|
390
|
+
#endif
|
|
391
|
+
}
|
|
392
|
+
|
|
354
393
|
//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
|
|
355
394
|
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
|
|
356
395
|
#if defined(GGML_USE_ACCELERATE)
|
|
@@ -905,6 +944,100 @@ inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, con
|
|
|
905
944
|
}
|
|
906
945
|
}
|
|
907
946
|
|
|
947
|
+
inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) {
|
|
948
|
+
for (int i = 0; i < n; ++i) {
|
|
949
|
+
y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f;
|
|
950
|
+
}
|
|
951
|
+
}
|
|
952
|
+
|
|
953
|
+
inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
|
|
954
|
+
for (int i = 0; i < n; ++i) {
|
|
955
|
+
float v = GGML_CPU_FP16_TO_FP32(x[i]);
|
|
956
|
+
y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v * GGML_CPU_FP16_TO_FP32(g[i]) : 0.f);
|
|
957
|
+
}
|
|
958
|
+
}
|
|
959
|
+
|
|
960
|
+
#ifdef GGML_GELU_FP16
|
|
961
|
+
inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
|
|
962
|
+
uint16_t t;
|
|
963
|
+
for (int i = 0; i < n; ++i) {
|
|
964
|
+
if (x[i] <= -10.0f) {
|
|
965
|
+
y[i] = 0.0f;
|
|
966
|
+
} else if (x[i] >= 10.0f) {
|
|
967
|
+
y[i] = x[i] * g[i];
|
|
968
|
+
} else {
|
|
969
|
+
ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
|
|
970
|
+
memcpy(&t, &fp16, sizeof(uint16_t));
|
|
971
|
+
y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[t]) * g[i];
|
|
972
|
+
}
|
|
973
|
+
}
|
|
974
|
+
}
|
|
975
|
+
#else
|
|
976
|
+
inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
|
|
977
|
+
for (int i = 0; i < n; ++i) {
|
|
978
|
+
y[i] = ggml_gelu_f32(x[i]) * g[i];
|
|
979
|
+
}
|
|
980
|
+
}
|
|
981
|
+
#endif
|
|
982
|
+
|
|
983
|
+
inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
|
|
984
|
+
const uint16_t * i16 = (const uint16_t *) x;
|
|
985
|
+
for (int i = 0; i < n; ++i) {
|
|
986
|
+
float v = GGML_CPU_FP16_TO_FP32(g[i]);
|
|
987
|
+
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v);
|
|
988
|
+
}
|
|
989
|
+
}
|
|
990
|
+
|
|
991
|
+
void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g);
|
|
992
|
+
|
|
993
|
+
inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
|
|
994
|
+
for (int i = 0; i < n; ++i) {
|
|
995
|
+
float v = GGML_CPU_FP16_TO_FP32(x[i]);
|
|
996
|
+
float w = GGML_CPU_FP16_TO_FP32(g[i]);
|
|
997
|
+
y[i] = GGML_CPU_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
|
|
998
|
+
}
|
|
999
|
+
}
|
|
1000
|
+
|
|
1001
|
+
inline static void ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {
|
|
1002
|
+
for (int i = 0; i < n; ++i) {
|
|
1003
|
+
float xi = x[i];
|
|
1004
|
+
y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i];
|
|
1005
|
+
}
|
|
1006
|
+
}
|
|
1007
|
+
|
|
1008
|
+
inline static void ggml_vec_geglu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
|
|
1009
|
+
for (int i = 0; i < n; ++i) {
|
|
1010
|
+
float xi = GGML_CPU_FP16_TO_FP32(x[i]);
|
|
1011
|
+
float gi = GGML_CPU_FP16_TO_FP32(g[i]);
|
|
1012
|
+
y[i] = GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi);
|
|
1013
|
+
}
|
|
1014
|
+
}
|
|
1015
|
+
|
|
1016
|
+
#ifdef GGML_GELU_QUICK_FP16
|
|
1017
|
+
inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
|
|
1018
|
+
uint16_t t;
|
|
1019
|
+
for (int i = 0; i < n; ++i) {
|
|
1020
|
+
ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
|
|
1021
|
+
memcpy(&t, &fp16, sizeof(uint16_t));
|
|
1022
|
+
y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]) * g[i];
|
|
1023
|
+
}
|
|
1024
|
+
}
|
|
1025
|
+
#else
|
|
1026
|
+
inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
|
|
1027
|
+
for (int i = 0; i < n; ++i) {
|
|
1028
|
+
y[i] = ggml_gelu_quick_f32(x[i]) * g[i];
|
|
1029
|
+
}
|
|
1030
|
+
}
|
|
1031
|
+
#endif
|
|
1032
|
+
|
|
1033
|
+
inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
|
|
1034
|
+
const uint16_t * i16 = (const uint16_t *) x;
|
|
1035
|
+
for (int i = 0; i < n; ++i) {
|
|
1036
|
+
float v = GGML_CPU_FP16_TO_FP32(g[i]);
|
|
1037
|
+
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[i16[i]]) * v);
|
|
1038
|
+
}
|
|
1039
|
+
}
|
|
1040
|
+
|
|
908
1041
|
inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
|
|
909
1042
|
#ifndef GGML_USE_ACCELERATE
|
|
910
1043
|
ggml_float sum = 0.0;
|
|
@@ -175,6 +175,23 @@ static const char * cu_get_error_str(CUresult err) {
|
|
|
175
175
|
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
|
|
176
176
|
#endif
|
|
177
177
|
|
|
178
|
+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
|
179
|
+
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
|
|
180
|
+
do { \
|
|
181
|
+
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \
|
|
182
|
+
const int id = ggml_cuda_get_device(); \
|
|
183
|
+
if (!shared_memory_limit_raised[id]) { \
|
|
184
|
+
CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
|
|
185
|
+
shared_memory_limit_raised[id] = true; \
|
|
186
|
+
} \
|
|
187
|
+
} while (0)
|
|
188
|
+
#else
|
|
189
|
+
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
|
|
190
|
+
do { \
|
|
191
|
+
GGML_UNUSED(nbytes); \
|
|
192
|
+
} while (0)
|
|
193
|
+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
|
194
|
+
|
|
178
195
|
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
|
|
179
196
|
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
|
|
180
197
|
#else
|
|
@@ -728,3 +728,25 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
|
|
|
728
728
|
return nullptr;
|
|
729
729
|
}
|
|
730
730
|
}
|
|
731
|
+
|
|
732
|
+
to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
|
|
733
|
+
switch (type) {
|
|
734
|
+
case GGML_TYPE_F32:
|
|
735
|
+
return convert_unary_cuda<float, nv_bfloat16>;
|
|
736
|
+
case GGML_TYPE_F16:
|
|
737
|
+
return convert_unary_cuda<half, nv_bfloat16>;
|
|
738
|
+
default:
|
|
739
|
+
return nullptr;
|
|
740
|
+
}
|
|
741
|
+
}
|
|
742
|
+
|
|
743
|
+
to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
|
|
744
|
+
switch (type) {
|
|
745
|
+
case GGML_TYPE_F16:
|
|
746
|
+
return convert_unary_cuda<half, float>;
|
|
747
|
+
case GGML_TYPE_BF16:
|
|
748
|
+
return convert_unary_cuda<nv_bfloat16, float>;
|
|
749
|
+
default:
|
|
750
|
+
return nullptr;
|
|
751
|
+
}
|
|
752
|
+
}
|
|
@@ -22,5 +22,10 @@ using to_t_nc_cuda_t = void (*)(const void * x, T * y,
|
|
|
22
22
|
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
|
|
23
23
|
int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
|
|
24
24
|
|
|
25
|
+
typedef to_t_nc_cuda_t<float> to_fp32_nc_cuda_t;
|
|
25
26
|
typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
|
|
27
|
+
typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
|
|
28
|
+
|
|
29
|
+
to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
|
|
26
30
|
to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
|
|
31
|
+
to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
|
|
@@ -123,13 +123,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|
|
123
123
|
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
|
|
124
124
|
|
|
125
125
|
if (nbytes_shared <= smpbo) {
|
|
126
|
-
|
|
127
|
-
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
128
|
-
if (!shared_memory_limit_raised[id]) {
|
|
129
|
-
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
|
|
130
|
-
shared_memory_limit_raised[id] = true;
|
|
131
|
-
}
|
|
132
|
-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
|
126
|
+
CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo);
|
|
133
127
|
cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
|
134
128
|
} else {
|
|
135
129
|
cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
|
@@ -175,13 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
|
|
|
175
169
|
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
|
176
170
|
|
|
177
171
|
if (nbytes_shared <= smpbo) {
|
|
178
|
-
|
|
179
|
-
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
180
|
-
if (!shared_memory_limit_raised[id]) {
|
|
181
|
-
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
|
|
182
|
-
shared_memory_limit_raised[id] = true;
|
|
183
|
-
}
|
|
184
|
-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
|
172
|
+
CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);
|
|
185
173
|
cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
|
|
186
174
|
} else {
|
|
187
175
|
cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
|
|
@@ -32,7 +32,9 @@ typedef void (* fattn_kernel_t)(
|
|
|
32
32
|
const int ne12,
|
|
33
33
|
const int ne13,
|
|
34
34
|
const int ne31,
|
|
35
|
+
const int ne32,
|
|
35
36
|
const int nb31,
|
|
37
|
+
const int nb32,
|
|
36
38
|
const int nb01,
|
|
37
39
|
const int nb02,
|
|
38
40
|
const int nb03,
|
|
@@ -851,7 +853,8 @@ void launch_fattn(
|
|
|
851
853
|
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
|
852
854
|
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
|
853
855
|
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
|
854
|
-
mask ? mask->ne[1] : 0, mask ?
|
|
856
|
+
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
|
|
857
|
+
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
|
|
855
858
|
Q->nb[1], Q->nb[2], Q->nb[3],
|
|
856
859
|
nb11, nb12, nb13,
|
|
857
860
|
nb21, nb22, nb23,
|
|
@@ -1223,7 +1223,9 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1223
1223
|
const int ne12,
|
|
1224
1224
|
const int ne13,
|
|
1225
1225
|
const int ne31,
|
|
1226
|
+
const int ne32,
|
|
1226
1227
|
const int nb31,
|
|
1228
|
+
const int nb32,
|
|
1227
1229
|
const int nb01,
|
|
1228
1230
|
const int nb02,
|
|
1229
1231
|
const int nb03,
|
|
@@ -1288,7 +1290,8 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1288
1290
|
|
|
1289
1291
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
|
1290
1292
|
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
|
1291
|
-
const half2 * mask_h2 = ncols2
|
|
1293
|
+
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
|
|
1294
|
+
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
|
|
1292
1295
|
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
|
1293
1296
|
|
|
1294
1297
|
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
|
@@ -1327,7 +1330,8 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1327
1330
|
|
|
1328
1331
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
|
1329
1332
|
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
|
1330
|
-
const half2 * mask_h2 = ncols2
|
|
1333
|
+
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
|
|
1334
|
+
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
|
|
1331
1335
|
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
|
1332
1336
|
|
|
1333
1337
|
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
|
@@ -1348,8 +1352,8 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1348
1352
|
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
|
1349
1353
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
|
|
1350
1354
|
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
|
|
1351
|
-
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
|
1352
|
-
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
|
1355
|
+
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
|
1356
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
|
1353
1357
|
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
|
|
1354
1358
|
GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
|
1355
1359
|
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
|
@@ -6,7 +6,7 @@
|
|
|
6
6
|
|
|
7
7
|
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
|
|
8
8
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
9
|
-
__launch_bounds__(nwarps*WARP_SIZE,
|
|
9
|
+
__launch_bounds__(nwarps*WARP_SIZE, 2)
|
|
10
10
|
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
11
11
|
static __global__ void flash_attn_tile_ext_f16(
|
|
12
12
|
const char * __restrict__ Q,
|
|
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
30
30
|
const int ne12,
|
|
31
31
|
const int ne13,
|
|
32
32
|
const int ne31,
|
|
33
|
+
const int ne32,
|
|
33
34
|
const int nb31,
|
|
35
|
+
const int nb32,
|
|
34
36
|
const int nb01,
|
|
35
37
|
const int nb02,
|
|
36
38
|
const int nb03,
|
|
@@ -64,7 +66,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
64
66
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
|
|
65
67
|
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
|
|
66
68
|
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
|
67
|
-
const half * maskh = (const half *)
|
|
69
|
+
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
|
68
70
|
|
|
69
71
|
const int stride_KV2 = nb11 / sizeof(half2);
|
|
70
72
|
|
|
@@ -288,8 +290,8 @@ static __global__ void flash_attn_tile_ext_f16(
|
|
|
288
290
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
289
291
|
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
|
290
292
|
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
|
291
|
-
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
|
292
|
-
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
293
|
+
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
|
294
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
293
295
|
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
|
294
296
|
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
|
295
297
|
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
|
@@ -6,7 +6,7 @@
|
|
|
6
6
|
|
|
7
7
|
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
|
|
8
8
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
9
|
-
__launch_bounds__(nwarps*WARP_SIZE,
|
|
9
|
+
__launch_bounds__(nwarps*WARP_SIZE, 2)
|
|
10
10
|
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
|
11
11
|
static __global__ void flash_attn_tile_ext_f32(
|
|
12
12
|
const char * __restrict__ Q,
|
|
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f32(
|
|
|
30
30
|
const int ne12,
|
|
31
31
|
const int ne13,
|
|
32
32
|
const int ne31,
|
|
33
|
+
const int ne32,
|
|
33
34
|
const int nb31,
|
|
35
|
+
const int nb32,
|
|
34
36
|
const int nb01,
|
|
35
37
|
const int nb02,
|
|
36
38
|
const int nb03,
|
|
@@ -58,8 +60,8 @@ static __global__ void flash_attn_tile_ext_f32(
|
|
|
58
60
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
59
61
|
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
|
60
62
|
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
|
61
|
-
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
|
62
|
-
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
63
|
+
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
|
64
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
63
65
|
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
|
64
66
|
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
|
65
67
|
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
|
@@ -76,7 +78,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
|
|
76
78
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
|
|
77
79
|
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
|
|
78
80
|
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
|
79
|
-
const half * maskh = (const half *)
|
|
81
|
+
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
|
80
82
|
|
|
81
83
|
const int stride_KV2 = nb11 / sizeof(half2);
|
|
82
84
|
|
|
@@ -297,14 +299,14 @@ static __global__ void flash_attn_tile_ext_f32(
|
|
|
297
299
|
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
|
298
300
|
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
|
299
301
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
300
|
-
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
|
301
|
-
GGML_UNUSED(
|
|
302
|
-
GGML_UNUSED(
|
|
303
|
-
GGML_UNUSED(nb31); GGML_UNUSED(
|
|
304
|
-
GGML_UNUSED(
|
|
305
|
-
GGML_UNUSED(
|
|
306
|
-
GGML_UNUSED(
|
|
307
|
-
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
|
302
|
+
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
|
303
|
+
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
|
304
|
+
GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
|
305
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32);
|
|
306
|
+
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
|
307
|
+
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
|
308
|
+
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
|
309
|
+
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
|
308
310
|
NO_DEVICE_CODE;
|
|
309
311
|
#endif // FLASH_ATTN_AVAILABLE
|
|
310
312
|
}
|
|
@@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
|
27
27
|
const int ne12,
|
|
28
28
|
const int ne13,
|
|
29
29
|
const int ne31,
|
|
30
|
+
const int ne32,
|
|
30
31
|
const int nb31,
|
|
32
|
+
const int nb32,
|
|
31
33
|
const int nb01,
|
|
32
34
|
const int nb02,
|
|
33
35
|
const int nb03,
|
|
@@ -68,7 +70,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
|
68
70
|
K += nb12*(blockIdx.z / gqa_ratio);
|
|
69
71
|
V += nb22*(blockIdx.z / gqa_ratio);
|
|
70
72
|
|
|
71
|
-
const half * maskh = (const half
|
|
73
|
+
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
|
72
74
|
|
|
73
75
|
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
|
74
76
|
const half slopeh = __float2half(slopef);
|
|
@@ -342,8 +344,8 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
|
342
344
|
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
|
343
345
|
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
|
344
346
|
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
|
345
|
-
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
|
346
|
-
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
347
|
+
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
|
348
|
+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
|
347
349
|
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
|
348
350
|
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
|
349
351
|
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|