@novastera-oss/llamarn 0.2.1 → 0.2.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +80 -14
- package/RNLlamaCpp.podspec +10 -3
- package/android/CMakeLists.txt +8 -0
- package/android/src/main/cpp/include/llama.h +62 -125
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/PureCppImpl.cpp +9 -27
- package/cpp/SystemUtils.h +2 -2
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/README.md +11 -3
- package/cpp/llama.cpp/build-xcframework.sh +1 -0
- package/cpp/llama.cpp/common/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/common/arg.cpp +153 -113
- package/cpp/llama.cpp/common/chat-parser.cpp +379 -0
- package/cpp/llama.cpp/common/chat-parser.h +117 -0
- package/cpp/llama.cpp/common/chat.cpp +847 -699
- package/cpp/llama.cpp/common/chat.h +73 -6
- package/cpp/llama.cpp/common/common.cpp +50 -82
- package/cpp/llama.cpp/common/common.h +21 -17
- package/cpp/llama.cpp/common/json-partial.cpp +255 -0
- package/cpp/llama.cpp/common/json-partial.h +37 -0
- package/cpp/llama.cpp/common/minja/chat-template.hpp +9 -5
- package/cpp/llama.cpp/common/minja/minja.hpp +69 -36
- package/cpp/llama.cpp/common/regex-partial.cpp +204 -0
- package/cpp/llama.cpp/common/regex-partial.h +56 -0
- package/cpp/llama.cpp/common/sampling.cpp +7 -8
- package/cpp/llama.cpp/convert_hf_to_gguf.py +453 -118
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +120 -68
- package/cpp/llama.cpp/ggml/CMakeLists.txt +2 -1
- package/cpp/llama.cpp/ggml/cmake/common.cmake +25 -0
- package/cpp/llama.cpp/ggml/include/ggml-opt.h +49 -28
- package/cpp/llama.cpp/ggml/include/ggml.h +26 -7
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +16 -10
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +604 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +54 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +50 -51
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +779 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +322 -100
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +117 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +220 -49
- package/cpp/llama.cpp/ggml/src/ggml-cuda/acc.cu +40 -26
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +11 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +15 -7
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +266 -64
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +49 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +48 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +2 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +5 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +7 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +10 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +99 -17
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +200 -2
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +972 -178
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +373 -190
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +101 -5
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +31 -33
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +29 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +4 -5
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +9 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +84 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +324 -129
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +31 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +95 -68
- package/cpp/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +69 -43
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -91
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -181
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +6 -152
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +2 -118
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +12 -1
- package/cpp/llama.cpp/ggml/src/ggml.c +107 -36
- package/cpp/llama.cpp/ggml/src/gguf.cpp +33 -33
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +100 -15
- package/cpp/llama.cpp/gguf-py/gguf/gguf_reader.py +1 -1
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +44 -12
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py +21 -10
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +5 -2
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +128 -31
- package/cpp/llama.cpp/gguf-py/gguf/utility.py +1 -1
- package/cpp/llama.cpp/gguf-py/pyproject.toml +1 -1
- package/cpp/llama.cpp/include/llama.h +62 -125
- package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +1 -1
- package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +62 -0
- package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +85 -0
- package/cpp/llama.cpp/models/templates/README.md +2 -0
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +5 -1
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +5 -1
- package/cpp/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
- package/cpp/llama.cpp/requirements/requirements-gguf_editor_gui.txt +1 -1
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +6 -0
- package/cpp/llama.cpp/src/llama-arch.h +2 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +3 -1
- package/cpp/llama.cpp/src/llama-context.cpp +340 -123
- package/cpp/llama.cpp/src/llama-context.h +30 -0
- package/cpp/llama.cpp/src/llama-cparams.cpp +4 -0
- package/cpp/llama.cpp/src/llama-cparams.h +2 -0
- package/cpp/llama.cpp/src/llama-grammar.cpp +12 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +157 -247
- package/cpp/llama.cpp/src/llama-graph.h +52 -7
- package/cpp/llama.cpp/src/llama-hparams.cpp +17 -1
- package/cpp/llama.cpp/src/llama-hparams.h +37 -5
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +742 -481
- package/cpp/llama.cpp/src/llama-kv-cache.h +196 -99
- package/cpp/llama.cpp/src/llama-kv-cells.h +379 -0
- package/cpp/llama.cpp/src/llama-memory.h +4 -3
- package/cpp/llama.cpp/src/llama-model-loader.cpp +22 -17
- package/cpp/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/cpp/llama.cpp/src/llama-model-saver.h +37 -0
- package/cpp/llama.cpp/src/llama-model.cpp +529 -172
- package/cpp/llama.cpp/src/llama-model.h +6 -1
- package/cpp/llama.cpp/src/llama-quant.cpp +15 -13
- package/cpp/llama.cpp/src/llama-sampling.cpp +2 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +35 -8
- package/cpp/llama.cpp/src/llama-vocab.h +6 -0
- package/cpp/llama.cpp/src/llama.cpp +14 -0
- package/cpp/rn-completion.cpp +60 -5
- package/ios/include/chat.h +73 -6
- package/ios/include/common/minja/chat-template.hpp +9 -5
- package/ios/include/common/minja/minja.hpp +69 -36
- package/ios/include/common.h +21 -17
- package/ios/include/llama.h +62 -125
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4617 -4487
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3557 -3435
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3559 -3437
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4616 -4487
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4637 -4508
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3556 -3435
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4653 -4523
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4674 -4544
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3587 -3465
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/common/stb_image.h +0 -7988
- package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +0 -46
|
@@ -0,0 +1,360 @@
|
|
|
1
|
+
#version 450
|
|
2
|
+
|
|
3
|
+
#extension GL_EXT_control_flow_attributes : enable
|
|
4
|
+
#extension GL_EXT_shader_16bit_storage : require
|
|
5
|
+
|
|
6
|
+
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
|
7
|
+
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
|
8
|
+
|
|
9
|
+
#extension GL_KHR_shader_subgroup_basic : enable
|
|
10
|
+
#extension GL_KHR_memory_scope_semantics : enable
|
|
11
|
+
#extension GL_KHR_cooperative_matrix : enable
|
|
12
|
+
|
|
13
|
+
#include "types.comp"
|
|
14
|
+
#include "flash_attn_base.comp"
|
|
15
|
+
|
|
16
|
+
const uint32_t D_per_thread = D / D_split; // per-thread share of the D dimension; main() keeps D_per_thread / 4 vec4 accumulators per row (D, D_split are defined outside this view)
|
|
17
|
+
const uint32_t row_split = 4; // the workgroup is divided into this many row groups (see threads_per_rowgroup / row_tid in main)
|
|
18
|
+
const uint32_t rows_per_thread = Br / row_split; // rows of the Br-row tile handled by each thread: tile_row(r) = row_tid * rows_per_thread + r
|
|
19
|
+
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; // number of distinct col_tid values inside one row group
|
|
20
|
+
const uint32_t cols_per_thread = Bc / cols_per_iter; // trip count of the per-thread column loops in main(); thread visits columns c * cols_per_iter + col_tid
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
layout (binding = 0) readonly buffer Q {float data_q[];}; // binding 0, scalar float element type
|
|
24
|
+
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; // binding 0 again, declared with vec4 elements; main() reads Q through this one
|
|
25
|
+
layout (binding = 1) readonly buffer K {float16_t data_k[];}; // binding 1, scalar float16 element type
|
|
26
|
+
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; // binding 1 again, f16vec4 elements; read in main() when BLOCK_SIZE is not > 1
|
|
27
|
+
layout (binding = 2) readonly buffer V {float16_t data_v[];}; // binding 2, scalar float16 element type
|
|
28
|
+
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; // binding 2 again, f16vec4 elements; read in main() when BLOCK_SIZE is not > 1
|
|
29
|
+
layout (binding = 3) readonly buffer M {float16_t data_m[];}; // mask values; only read in main() when p.mask != 0, scaled by slope[r]
|
|
30
|
+
|
|
31
|
+
// Store one output element when doing grouped query attention: writes elem to data_o and returns it unchanged.
|
|
32
|
+
// Rows are indexed by Q's dimension 2, and the first N rows are valid. r/c are the row/column of the element; o_offset is a base offset into data_o; iq2 is added to r to form the output row.
|
|
33
|
+
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) // N is not referenced in this body; callers in main() check tile_row(r) < N before calling
|
|
34
|
+
{
|
|
35
|
+
uint32_t offset = (iq2 + r) * D + c; // row (iq2 + r) with a row pitch of D elements, column c
|
|
36
|
+
data_o[o_offset + offset] = D_TYPE(elem); // data_o is declared outside this view (presumably in flash_attn_base.comp -- confirm)
|
|
37
|
+
return elem; // the stored value is handed back unmodified
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd (they size the coopmat types declared in main)
|
|
41
|
+
const uint32_t MatBr = 16; // column count of the SfMat accumulator and of the QMat operand
|
|
42
|
+
const uint32_t MatBc = 16; // row count of the SfMat accumulator and of the KMat operand
|
|
43
|
+
|
|
44
|
+
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; // one slot per invocation; scratch for the cross-thread max and sum reductions after the j loop
|
|
45
|
+
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x]; // one slot per invocation; scratch for the cross-thread reduction of the Of accumulators
|
|
46
|
+
|
|
47
|
+
const uint32_t qstride = D / 4 + 2; // in units of f16vec4. NOTE(review): the "+ 2" looks like per-row padding -- rationale not shown here, confirm
|
|
48
|
+
shared f16vec4 Qf[Br * qstride]; // Q tile, multiplied by p.scale and converted to f16 when loaded at the top of main(); indexed [r * qstride + d]
|
|
49
|
+
|
|
50
|
+
// Avoid padding for D==256 to make it fit in 48KB shmem. (The test is D <= 128: those sizes get Br + 8, larger D gets an unpadded stride of Br.)
|
|
51
|
+
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
|
|
52
|
+
shared ACC_TYPE sfsh[Bc * sfshstride]; // receives SfMat via coopMatStore; afterwards indexed as [c * sfshstride + r]
|
|
53
|
+
|
|
54
|
+
const uint32_t kshstride = D / 4 + 2; // in units of f16vec4. NOTE(review): same "+ 2" as qstride -- presumably padding, confirm
|
|
55
|
+
shared f16vec4 ksh[Bc * kshstride]; // K tile for the current j iteration; indexed [c * kshstride + d]
|
|
56
|
+
|
|
57
|
+
shared float slope[Br]; // per-row mask multiplier: perElemOpComputeSlope(...) when p.max_bias > 0, otherwise 1.0
|
|
58
|
+
|
|
59
|
+
void main() {
|
|
60
|
+
#ifdef NEEDS_INIT_IQ_SHMEM
|
|
61
|
+
init_iq_shmem(gl_WorkGroupSize);
|
|
62
|
+
#endif
|
|
63
|
+
|
|
64
|
+
init_indices();
|
|
65
|
+
|
|
66
|
+
const uint32_t tid = gl_LocalInvocationIndex;
|
|
67
|
+
|
|
68
|
+
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
|
|
69
|
+
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
|
|
70
|
+
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
|
|
71
|
+
const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
|
|
72
|
+
|
|
73
|
+
#define tile_row(r) (row_tid * rows_per_thread + (r))
|
|
74
|
+
|
|
75
|
+
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
|
76
|
+
|
|
77
|
+
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
|
|
78
|
+
uint32_t d = (idx + tid) % (D / 4);
|
|
79
|
+
uint32_t r = (idx + tid) / (D / 4);
|
|
80
|
+
if (r < Br && d < D / 4 &&
|
|
81
|
+
i * Br + r < N) {
|
|
82
|
+
Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
|
|
83
|
+
}
|
|
84
|
+
}
|
|
85
|
+
barrier();
|
|
86
|
+
|
|
87
|
+
ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4];
|
|
88
|
+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
|
89
|
+
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
90
|
+
Of[r][d] = ACC_TYPEV4(0.0);
|
|
91
|
+
}
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
float Lf[rows_per_thread], Mf[rows_per_thread];
|
|
95
|
+
|
|
96
|
+
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
|
|
97
|
+
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
|
|
98
|
+
|
|
99
|
+
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
100
|
+
Lf[r] = 0;
|
|
101
|
+
Mf[r] = NEG_FLT_MAX_OVER_2;
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
// ALiBi
|
|
105
|
+
if (p.max_bias > 0.0f) {
|
|
106
|
+
if (tid < Br) {
|
|
107
|
+
uint r = tid;
|
|
108
|
+
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
|
|
109
|
+
}
|
|
110
|
+
barrier();
|
|
111
|
+
} else {
|
|
112
|
+
if (tid < Br) {
|
|
113
|
+
uint r = tid;
|
|
114
|
+
slope[r] = 1.0;
|
|
115
|
+
}
|
|
116
|
+
barrier();
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
#if BLOCK_SIZE > 1
|
|
120
|
+
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
|
|
121
|
+
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
|
|
122
|
+
#else
|
|
123
|
+
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
|
|
124
|
+
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
|
125
|
+
#endif
|
|
126
|
+
|
|
127
|
+
[[dont_unroll]]
|
|
128
|
+
for (uint32_t j = start_j; j < end_j; ++j) {
|
|
129
|
+
|
|
130
|
+
[[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) {
|
|
131
|
+
uint32_t d = (idx + tid) % (D / 4);
|
|
132
|
+
uint32_t c = (idx + tid) / (D / 4);
|
|
133
|
+
if (c < Bc && d < D / 4) {
|
|
134
|
+
#if BLOCK_SIZE > 1
|
|
135
|
+
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
|
136
|
+
uint ib = coord / BLOCK_SIZE;
|
|
137
|
+
uint iqs = (coord % BLOCK_SIZE);
|
|
138
|
+
f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
|
|
139
|
+
#else
|
|
140
|
+
f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
|
141
|
+
#endif
|
|
142
|
+
|
|
143
|
+
ksh[c * kshstride + d] = K_Tf;
|
|
144
|
+
}
|
|
145
|
+
}
|
|
146
|
+
barrier();
|
|
147
|
+
|
|
148
|
+
// K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br
|
|
149
|
+
// Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
|
|
150
|
+
// This is written transposed in order to allow for N being 8 if implementations need it
|
|
151
|
+
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
|
|
152
|
+
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
|
|
153
|
+
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
|
|
154
|
+
|
|
155
|
+
for (uint32_t d = 0; d < D / 16; ++d) {
|
|
156
|
+
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
|
|
157
|
+
|
|
158
|
+
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
|
|
159
|
+
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
|
|
160
|
+
|
|
161
|
+
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
uint coord = gl_SubgroupID * MatBc * sfshstride;
|
|
165
|
+
coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
|
|
166
|
+
barrier();
|
|
167
|
+
|
|
168
|
+
if (p.logit_softcap != 0.0f) {
|
|
169
|
+
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
|
170
|
+
uint32_t c = (idx + tid) / Br;
|
|
171
|
+
uint32_t r = (idx + tid) % Br;
|
|
172
|
+
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
|
173
|
+
sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
|
|
174
|
+
}
|
|
175
|
+
}
|
|
176
|
+
barrier();
|
|
177
|
+
}
|
|
178
|
+
|
|
179
|
+
if (p.mask != 0) {
|
|
180
|
+
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
|
181
|
+
uint32_t c = (idx + tid) % Bc;
|
|
182
|
+
uint32_t r = (idx + tid) / Bc;
|
|
183
|
+
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
|
184
|
+
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
|
|
185
|
+
}
|
|
186
|
+
}
|
|
187
|
+
barrier();
|
|
188
|
+
}
|
|
189
|
+
|
|
190
|
+
float eMf[rows_per_thread];
|
|
191
|
+
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
192
|
+
float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride];
|
|
193
|
+
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
|
194
|
+
rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
|
|
195
|
+
}
|
|
196
|
+
float Moldf = Mf[r];
|
|
197
|
+
|
|
198
|
+
// M = max(rowmax, Mold)
|
|
199
|
+
// P = e^(S - M)
|
|
200
|
+
// eM = e^(Mold - M)
|
|
201
|
+
Mf[r] = max(rowmaxf, Moldf);
|
|
202
|
+
eMf[r] = exp(Moldf - Mf[r]);
|
|
203
|
+
}
|
|
204
|
+
|
|
205
|
+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
|
206
|
+
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
207
|
+
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
|
|
208
|
+
}
|
|
209
|
+
}
|
|
210
|
+
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
211
|
+
Lf[r] = eMf[r]*Lf[r];
|
|
212
|
+
}
|
|
213
|
+
|
|
214
|
+
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
|
215
|
+
float Pf[rows_per_thread];
|
|
216
|
+
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
217
|
+
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
|
|
218
|
+
Lf[r] += Pf[r];
|
|
219
|
+
}
|
|
220
|
+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
|
221
|
+
#if BLOCK_SIZE > 1
|
|
222
|
+
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
|
223
|
+
uint ib = coord / BLOCK_SIZE;
|
|
224
|
+
uint iqs = (coord % BLOCK_SIZE);
|
|
225
|
+
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
|
226
|
+
#else
|
|
227
|
+
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
|
|
228
|
+
#endif
|
|
229
|
+
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
230
|
+
Of[r][d] += float16_t(Pf[r]) * ACC_TYPEV4(Vf);
|
|
231
|
+
}
|
|
232
|
+
}
|
|
233
|
+
}
|
|
234
|
+
|
|
235
|
+
barrier();
|
|
236
|
+
}
|
|
237
|
+
|
|
238
|
+
// reduce across threads
|
|
239
|
+
|
|
240
|
+
float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
|
|
241
|
+
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
242
|
+
FLOAT_TYPE M = Mf[r];
|
|
243
|
+
tmpsh[tid] = M;
|
|
244
|
+
// Compute max across the row
|
|
245
|
+
barrier();
|
|
246
|
+
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
|
|
247
|
+
M = max(M, tmpsh[tid ^ s]);
|
|
248
|
+
barrier();
|
|
249
|
+
tmpsh[tid] = M;
|
|
250
|
+
barrier();
|
|
251
|
+
}
|
|
252
|
+
rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
|
|
253
|
+
barrier();
|
|
254
|
+
}
|
|
255
|
+
|
|
256
|
+
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
257
|
+
Moldf[r] = Mf[r];
|
|
258
|
+
|
|
259
|
+
// M = max(rowmax, Mold)
|
|
260
|
+
// eM = e^(Mold - M)
|
|
261
|
+
Mf[r] = max(rowmaxf[r], Moldf[r]);
|
|
262
|
+
eMf[r] = exp(Moldf[r] - Mf[r]);
|
|
263
|
+
|
|
264
|
+
Lf[r] = eMf[r]*Lf[r];
|
|
265
|
+
}
|
|
266
|
+
|
|
267
|
+
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
268
|
+
FLOAT_TYPE L = Lf[r];
|
|
269
|
+
tmpsh[tid] = L;
|
|
270
|
+
// Compute sum across the row
|
|
271
|
+
barrier();
|
|
272
|
+
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
|
|
273
|
+
L += tmpsh[tid ^ s];
|
|
274
|
+
barrier();
|
|
275
|
+
tmpsh[tid] = L;
|
|
276
|
+
barrier();
|
|
277
|
+
}
|
|
278
|
+
Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
|
|
279
|
+
barrier();
|
|
280
|
+
}
|
|
281
|
+
|
|
282
|
+
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
283
|
+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
|
284
|
+
|
|
285
|
+
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
|
|
286
|
+
tmpshv4[tid] = Of[r][d];
|
|
287
|
+
|
|
288
|
+
barrier();
|
|
289
|
+
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
|
|
290
|
+
Of[r][d] += tmpshv4[tid ^ s];
|
|
291
|
+
barrier();
|
|
292
|
+
tmpshv4[tid] = Of[r][d];
|
|
293
|
+
barrier();
|
|
294
|
+
}
|
|
295
|
+
Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup];
|
|
296
|
+
barrier();
|
|
297
|
+
}
|
|
298
|
+
}
|
|
299
|
+
|
|
300
|
+
// If there is split_k, then the split_k resolve shader does the final
|
|
301
|
+
// division by L. Store the intermediate O value and per-row m and L values.
|
|
302
|
+
if (p.k_num > 1) {
|
|
303
|
+
uint32_t o_offset = D * p.ne1 * split_k_index;
|
|
304
|
+
|
|
305
|
+
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
306
|
+
if (tile_row(r) < N) {
|
|
307
|
+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
|
308
|
+
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
|
309
|
+
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
|
310
|
+
}
|
|
311
|
+
}
|
|
312
|
+
}
|
|
313
|
+
}
|
|
314
|
+
|
|
315
|
+
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
|
|
316
|
+
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
317
|
+
if (tile_row(r) < N) {
|
|
318
|
+
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
|
319
|
+
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
|
|
320
|
+
}
|
|
321
|
+
}
|
|
322
|
+
|
|
323
|
+
return;
|
|
324
|
+
}
|
|
325
|
+
|
|
326
|
+
float Lfrcp[rows_per_thread];
|
|
327
|
+
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
328
|
+
Lfrcp[r] = 1.0 / Lf[r];
|
|
329
|
+
}
|
|
330
|
+
|
|
331
|
+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
|
332
|
+
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
333
|
+
Of[r][d] *= float16_t(Lfrcp[r]);
|
|
334
|
+
}
|
|
335
|
+
}
|
|
336
|
+
|
|
337
|
+
uint32_t o_offset = iq3*p.ne2*p.ne1;
|
|
338
|
+
|
|
339
|
+
if (p.gqa_ratio > 1) {
|
|
340
|
+
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
341
|
+
if (tile_row(r) < N) {
|
|
342
|
+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
|
343
|
+
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
|
344
|
+
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
|
345
|
+
}
|
|
346
|
+
}
|
|
347
|
+
}
|
|
348
|
+
}
|
|
349
|
+
} else {
|
|
350
|
+
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
351
|
+
if (i * Br + tile_row(r) < N) {
|
|
352
|
+
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
|
353
|
+
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
|
354
|
+
data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
|
355
|
+
}
|
|
356
|
+
}
|
|
357
|
+
}
|
|
358
|
+
}
|
|
359
|
+
}
|
|
360
|
+
}
|
|
@@ -18,62 +18,12 @@
|
|
|
18
18
|
|
|
19
19
|
#include "types.comp"
|
|
20
20
|
#include "dequant_funcs_cm2.comp"
|
|
21
|
-
|
|
22
|
-
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
|
23
|
-
|
|
24
|
-
layout (constant_id = 1) const uint32_t Br = 32;
|
|
25
|
-
layout (constant_id = 2) const uint32_t Bc = 32;
|
|
26
|
-
layout (constant_id = 3) const uint32_t D = 32;
|
|
27
|
-
layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;
|
|
28
|
-
|
|
29
|
-
layout (push_constant) uniform parameter {
|
|
30
|
-
uint32_t N;
|
|
31
|
-
uint32_t KV;
|
|
32
|
-
|
|
33
|
-
uint32_t ne1;
|
|
34
|
-
uint32_t ne2;
|
|
35
|
-
uint32_t ne3;
|
|
36
|
-
|
|
37
|
-
uint32_t neq2;
|
|
38
|
-
uint32_t neq3;
|
|
39
|
-
uint32_t nek2;
|
|
40
|
-
uint32_t nek3;
|
|
41
|
-
uint32_t nev2;
|
|
42
|
-
uint32_t nev3;
|
|
43
|
-
uint32_t nem1;
|
|
44
|
-
|
|
45
|
-
uint32_t nb01;
|
|
46
|
-
uint32_t nb02;
|
|
47
|
-
uint32_t nb03;
|
|
48
|
-
uint32_t nb11;
|
|
49
|
-
uint32_t nb12;
|
|
50
|
-
uint32_t nb13;
|
|
51
|
-
uint32_t nb21;
|
|
52
|
-
uint32_t nb22;
|
|
53
|
-
uint32_t nb23;
|
|
54
|
-
uint32_t nb31;
|
|
55
|
-
|
|
56
|
-
float scale;
|
|
57
|
-
float max_bias;
|
|
58
|
-
float logit_softcap;
|
|
59
|
-
|
|
60
|
-
uint32_t mask;
|
|
61
|
-
uint32_t n_head_log2;
|
|
62
|
-
float m0;
|
|
63
|
-
float m1;
|
|
64
|
-
|
|
65
|
-
uint32_t gqa_ratio;
|
|
66
|
-
uint32_t split_kv;
|
|
67
|
-
uint32_t k_num;
|
|
68
|
-
} p;
|
|
21
|
+
#include "flash_attn_base.comp"
|
|
69
22
|
|
|
70
23
|
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
|
|
71
24
|
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
|
|
72
25
|
layout (binding = 2) readonly buffer V {uint8_t data_v[];};
|
|
73
26
|
layout (binding = 3) readonly buffer M {uint8_t data_m[];};
|
|
74
|
-
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
|
75
|
-
|
|
76
|
-
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
|
77
27
|
|
|
78
28
|
ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
|
|
79
29
|
return max(x, y);
|
|
@@ -118,67 +68,12 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
|
|
|
118
68
|
return elem;
|
|
119
69
|
}
|
|
120
70
|
|
|
121
|
-
// Store column zero. This is used to save per-row m and L values for split_k.
|
|
122
|
-
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
|
123
|
-
{
|
|
124
|
-
if (r < N && c == 0) {
|
|
125
|
-
uint32_t offset = iq2 + r;
|
|
126
|
-
data_o[o_offset + offset] = D_TYPE(elem);
|
|
127
|
-
}
|
|
128
|
-
return elem;
|
|
129
|
-
}
|
|
130
|
-
|
|
131
|
-
// Load the slope matrix, indexed by Q's dimension 2.
|
|
132
|
-
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
|
|
133
|
-
{
|
|
134
|
-
const uint32_t h = iq2 + (r % p.gqa_ratio);
|
|
135
|
-
|
|
136
|
-
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
|
|
137
|
-
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
|
|
138
|
-
|
|
139
|
-
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
|
140
|
-
}
|
|
141
|
-
|
|
142
71
|
void main() {
|
|
143
72
|
#ifdef NEEDS_INIT_IQ_SHMEM
|
|
144
73
|
init_iq_shmem(gl_WorkGroupSize);
|
|
145
74
|
#endif
|
|
146
75
|
|
|
147
|
-
|
|
148
|
-
const uint32_t KV = p.KV;
|
|
149
|
-
|
|
150
|
-
uint32_t i = gl_WorkGroupID.x;
|
|
151
|
-
uint32_t split_k_index = 0;
|
|
152
|
-
|
|
153
|
-
if (p.k_num > 1) {
|
|
154
|
-
i = 0;
|
|
155
|
-
split_k_index = gl_WorkGroupID.x;
|
|
156
|
-
}
|
|
157
|
-
|
|
158
|
-
const uint32_t Tr = CEIL_DIV(N, Br);
|
|
159
|
-
|
|
160
|
-
const uint32_t start_j = split_k_index * p.split_kv / Bc;
|
|
161
|
-
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
|
|
162
|
-
|
|
163
|
-
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
|
|
164
|
-
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
|
|
165
|
-
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
|
|
166
|
-
const uint32_t iq3 = gl_WorkGroupID.z;
|
|
167
|
-
|
|
168
|
-
// broadcast factors
|
|
169
|
-
const uint32_t rk2 = p.neq2/p.nek2;
|
|
170
|
-
const uint32_t rk3 = p.neq3/p.nek3;
|
|
171
|
-
|
|
172
|
-
const uint32_t rv2 = p.neq2/p.nev2;
|
|
173
|
-
const uint32_t rv3 = p.neq3/p.nev3;
|
|
174
|
-
|
|
175
|
-
// k indices
|
|
176
|
-
const uint32_t ik3 = iq3 / rk3;
|
|
177
|
-
const uint32_t ik2 = iq2 / rk2;
|
|
178
|
-
|
|
179
|
-
// v indices
|
|
180
|
-
const uint32_t iv3 = iq3 / rv3;
|
|
181
|
-
const uint32_t iv2 = iq2 / rv2;
|
|
76
|
+
init_indices();
|
|
182
77
|
|
|
183
78
|
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
|
184
79
|
tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
|
|
@@ -195,17 +90,6 @@ void main() {
|
|
|
195
90
|
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
|
|
196
91
|
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
|
|
197
92
|
|
|
198
|
-
// nb?1 are already divided by the type size and are in units of elements.
|
|
199
|
-
// When using grouped query attention, Q is indexed by iq2, so the stride
|
|
200
|
-
// should be nb02 (which is in bytes).
|
|
201
|
-
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
|
202
|
-
uint32_t k_stride = p.nb11;
|
|
203
|
-
uint32_t v_stride = p.nb21;
|
|
204
|
-
// When using grouped query attention, all rows use the same mask (stride 0).
|
|
205
|
-
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
|
|
206
|
-
// that prevents the compiler from folding the "&" through the select
|
|
207
|
-
// and breaking the alignment detection.
|
|
208
|
-
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
|
209
93
|
// hint to the compiler that strides are aligned for the aligned variant of the shader
|
|
210
94
|
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
|
|
211
95
|
{
|
|
@@ -7,7 +7,7 @@
|
|
|
7
7
|
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
|
8
8
|
#endif
|
|
9
9
|
#if defined(DATA_A_IQ1_M)
|
|
10
|
-
#extension
|
|
10
|
+
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
|
|
11
11
|
#endif
|
|
12
12
|
|
|
13
13
|
#if defined(DATA_A_BF16) && defined(COOPMAT)
|
|
@@ -215,7 +215,7 @@ static std::mutex compile_count_mutex;
|
|
|
215
215
|
static std::condition_variable compile_count_cond;
|
|
216
216
|
|
|
217
217
|
void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
|
|
218
|
-
std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "
|
|
218
|
+
std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
|
|
219
219
|
std::string out_fname = join_paths(output_dir, name + ".spv");
|
|
220
220
|
std::string in_path = join_paths(input_dir, in_fname);
|
|
221
221
|
|
|
@@ -424,6 +424,7 @@ void process_shaders() {
|
|
|
424
424
|
// flash attention
|
|
425
425
|
for (const auto& f16acc : {false, true}) {
|
|
426
426
|
std::string acctype = f16acc ? "float16_t" : "float";
|
|
427
|
+
std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
|
|
427
428
|
|
|
428
429
|
for (const auto& tname : type_names) {
|
|
429
430
|
if (tname == "f32") {
|
|
@@ -440,6 +441,16 @@ void process_shaders() {
|
|
|
440
441
|
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
|
441
442
|
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
|
|
442
443
|
}
|
|
444
|
+
#endif
|
|
445
|
+
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
|
446
|
+
if (tname == "f16") {
|
|
447
|
+
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
|
448
|
+
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
|
449
|
+
} else if (tname == "q4_0" || tname == "q8_0") {
|
|
450
|
+
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
|
451
|
+
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
|
452
|
+
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
|
453
|
+
}
|
|
443
454
|
#endif
|
|
444
455
|
if (tname == "f16") {
|
|
445
456
|
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|