@novastera-oss/llamarn 0.2.1 → 0.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +80 -14
- package/RNLlamaCpp.podspec +10 -3
- package/android/CMakeLists.txt +8 -0
- package/android/src/main/cpp/include/llama.h +62 -125
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/README.md +11 -3
- package/cpp/llama.cpp/build-xcframework.sh +1 -0
- package/cpp/llama.cpp/common/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/common/arg.cpp +153 -113
- package/cpp/llama.cpp/common/chat-parser.cpp +379 -0
- package/cpp/llama.cpp/common/chat-parser.h +117 -0
- package/cpp/llama.cpp/common/chat.cpp +847 -699
- package/cpp/llama.cpp/common/chat.h +73 -6
- package/cpp/llama.cpp/common/common.cpp +50 -82
- package/cpp/llama.cpp/common/common.h +21 -17
- package/cpp/llama.cpp/common/json-partial.cpp +255 -0
- package/cpp/llama.cpp/common/json-partial.h +37 -0
- package/cpp/llama.cpp/common/minja/chat-template.hpp +9 -5
- package/cpp/llama.cpp/common/minja/minja.hpp +69 -36
- package/cpp/llama.cpp/common/regex-partial.cpp +204 -0
- package/cpp/llama.cpp/common/regex-partial.h +56 -0
- package/cpp/llama.cpp/common/sampling.cpp +7 -8
- package/cpp/llama.cpp/convert_hf_to_gguf.py +453 -118
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +120 -68
- package/cpp/llama.cpp/ggml/CMakeLists.txt +2 -1
- package/cpp/llama.cpp/ggml/cmake/common.cmake +25 -0
- package/cpp/llama.cpp/ggml/include/ggml-opt.h +49 -28
- package/cpp/llama.cpp/ggml/include/ggml.h +26 -7
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +16 -10
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +604 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +54 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +50 -51
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +779 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +322 -100
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +117 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +220 -49
- package/cpp/llama.cpp/ggml/src/ggml-cuda/acc.cu +40 -26
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +11 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +15 -7
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +266 -64
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +49 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +48 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +2 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +5 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +7 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +10 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +99 -17
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +200 -2
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +972 -178
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +373 -190
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +101 -5
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +31 -33
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +29 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +4 -5
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +9 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +84 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +324 -129
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +31 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +95 -68
- package/cpp/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +69 -43
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -91
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -181
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +6 -152
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +2 -118
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +12 -1
- package/cpp/llama.cpp/ggml/src/ggml.c +107 -36
- package/cpp/llama.cpp/ggml/src/gguf.cpp +33 -33
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +100 -15
- package/cpp/llama.cpp/gguf-py/gguf/gguf_reader.py +1 -1
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +44 -12
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py +21 -10
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +5 -2
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +128 -31
- package/cpp/llama.cpp/gguf-py/gguf/utility.py +1 -1
- package/cpp/llama.cpp/gguf-py/pyproject.toml +1 -1
- package/cpp/llama.cpp/include/llama.h +62 -125
- package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +1 -1
- package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +62 -0
- package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +85 -0
- package/cpp/llama.cpp/models/templates/README.md +2 -0
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +5 -1
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +5 -1
- package/cpp/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
- package/cpp/llama.cpp/requirements/requirements-gguf_editor_gui.txt +1 -1
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +6 -0
- package/cpp/llama.cpp/src/llama-arch.h +2 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +3 -1
- package/cpp/llama.cpp/src/llama-context.cpp +340 -123
- package/cpp/llama.cpp/src/llama-context.h +30 -0
- package/cpp/llama.cpp/src/llama-cparams.cpp +4 -0
- package/cpp/llama.cpp/src/llama-cparams.h +2 -0
- package/cpp/llama.cpp/src/llama-grammar.cpp +12 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +157 -247
- package/cpp/llama.cpp/src/llama-graph.h +52 -7
- package/cpp/llama.cpp/src/llama-hparams.cpp +17 -1
- package/cpp/llama.cpp/src/llama-hparams.h +37 -5
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +742 -481
- package/cpp/llama.cpp/src/llama-kv-cache.h +196 -99
- package/cpp/llama.cpp/src/llama-kv-cells.h +379 -0
- package/cpp/llama.cpp/src/llama-memory.h +4 -3
- package/cpp/llama.cpp/src/llama-model-loader.cpp +22 -17
- package/cpp/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/cpp/llama.cpp/src/llama-model-saver.h +37 -0
- package/cpp/llama.cpp/src/llama-model.cpp +529 -172
- package/cpp/llama.cpp/src/llama-model.h +6 -1
- package/cpp/llama.cpp/src/llama-quant.cpp +15 -13
- package/cpp/llama.cpp/src/llama-sampling.cpp +2 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +35 -8
- package/cpp/llama.cpp/src/llama-vocab.h +6 -0
- package/cpp/llama.cpp/src/llama.cpp +14 -0
- package/cpp/rn-completion.cpp +4 -2
- package/ios/include/chat.h +73 -6
- package/ios/include/common/minja/chat-template.hpp +9 -5
- package/ios/include/common/minja/minja.hpp +69 -36
- package/ios/include/common.h +21 -17
- package/ios/include/llama.h +62 -125
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4617 -4487
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3557 -3435
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3559 -3437
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4616 -4487
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4637 -4508
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3556 -3435
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4653 -4523
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4674 -4544
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3587 -3465
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/common/stb_image.h +0 -7988
- package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +0 -46
|
@@ -5,18 +5,35 @@ find_package (Threads REQUIRED)
|
|
|
5
5
|
|
|
6
6
|
if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
|
7
7
|
add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
|
8
|
+
message(STATUS "Enabling coopmat glslc support")
|
|
8
9
|
endif()
|
|
9
10
|
if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
|
10
11
|
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
|
12
|
+
message(STATUS "Enabling coopmat2 glslc support")
|
|
11
13
|
endif()
|
|
12
14
|
if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
13
15
|
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
16
|
+
message(STATUS "Enabling dot glslc support")
|
|
14
17
|
endif()
|
|
15
18
|
if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
|
16
19
|
add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
|
20
|
+
message(STATUS "Enabling bfloat16 glslc support")
|
|
17
21
|
endif()
|
|
22
|
+
|
|
18
23
|
set(TARGET vulkan-shaders-gen)
|
|
19
24
|
add_executable(${TARGET} vulkan-shaders-gen.cpp)
|
|
20
25
|
install(TARGETS ${TARGET} RUNTIME)
|
|
21
26
|
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
|
22
27
|
target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
|
|
28
|
+
|
|
29
|
+
# Configure output directories for MSVC builds
|
|
30
|
+
if(MSVC)
|
|
31
|
+
# Get the main project's runtime output directory if possible
|
|
32
|
+
if(DEFINED CMAKE_RUNTIME_OUTPUT_DIRECTORY)
|
|
33
|
+
foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES})
|
|
34
|
+
string(TOUPPER ${CONFIG} CONFIG)
|
|
35
|
+
set_target_properties(${TARGET} PROPERTIES
|
|
36
|
+
RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
|
|
37
|
+
endforeach()
|
|
38
|
+
endif()
|
|
39
|
+
endif()
|
|
@@ -9,59 +9,13 @@
|
|
|
9
9
|
#extension GL_KHR_shader_subgroup_shuffle : enable
|
|
10
10
|
|
|
11
11
|
#include "types.comp"
|
|
12
|
+
#include "flash_attn_base.comp"
|
|
12
13
|
|
|
13
|
-
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
|
14
|
-
|
|
15
|
-
layout (constant_id = 1) const uint32_t Br = 1;
|
|
16
|
-
layout (constant_id = 2) const uint32_t Bc = 32;
|
|
17
|
-
layout (constant_id = 3) const uint32_t D = 32;
|
|
18
|
-
|
|
19
|
-
layout (constant_id = 5) const uint32_t D_split = 16;
|
|
20
14
|
const uint32_t D_per_thread = D / D_split;
|
|
21
15
|
|
|
22
|
-
const uint32_t cols_per_iter =
|
|
16
|
+
const uint32_t cols_per_iter = WorkGroupSize / D_split;
|
|
23
17
|
const uint32_t cols_per_thread = Bc / cols_per_iter;
|
|
24
18
|
|
|
25
|
-
layout (push_constant) uniform parameter {
|
|
26
|
-
uint32_t N;
|
|
27
|
-
uint32_t KV;
|
|
28
|
-
|
|
29
|
-
uint32_t ne1;
|
|
30
|
-
uint32_t ne2;
|
|
31
|
-
uint32_t ne3;
|
|
32
|
-
|
|
33
|
-
uint32_t neq2;
|
|
34
|
-
uint32_t neq3;
|
|
35
|
-
uint32_t nek2;
|
|
36
|
-
uint32_t nek3;
|
|
37
|
-
uint32_t nev2;
|
|
38
|
-
uint32_t nev3;
|
|
39
|
-
uint32_t nem1;
|
|
40
|
-
|
|
41
|
-
uint32_t nb01;
|
|
42
|
-
uint32_t nb02;
|
|
43
|
-
uint32_t nb03;
|
|
44
|
-
uint32_t nb11;
|
|
45
|
-
uint32_t nb12;
|
|
46
|
-
uint32_t nb13;
|
|
47
|
-
uint32_t nb21;
|
|
48
|
-
uint32_t nb22;
|
|
49
|
-
uint32_t nb23;
|
|
50
|
-
uint32_t nb31;
|
|
51
|
-
|
|
52
|
-
float scale;
|
|
53
|
-
float max_bias;
|
|
54
|
-
float logit_softcap;
|
|
55
|
-
|
|
56
|
-
uint32_t mask;
|
|
57
|
-
uint32_t n_head_log2;
|
|
58
|
-
float m0;
|
|
59
|
-
float m1;
|
|
60
|
-
|
|
61
|
-
uint32_t gqa_ratio;
|
|
62
|
-
uint32_t split_kv;
|
|
63
|
-
uint32_t k_num;
|
|
64
|
-
} p;
|
|
65
19
|
|
|
66
20
|
layout (binding = 0) readonly buffer Q {float data_q[];};
|
|
67
21
|
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
|
|
@@ -70,39 +24,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
|
|
|
70
24
|
layout (binding = 2) readonly buffer V {float16_t data_v[];};
|
|
71
25
|
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
|
|
72
26
|
layout (binding = 3) readonly buffer M {float16_t data_m[];};
|
|
73
|
-
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
|
74
|
-
|
|
75
|
-
#if defined(A_TYPE_PACKED16)
|
|
76
|
-
#define BINDING_IDX_K 0
|
|
77
|
-
#define BINDING_IDX_V 1
|
|
78
|
-
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
|
|
79
|
-
#endif
|
|
80
|
-
|
|
81
|
-
#if defined(DATA_A_Q4_0)
|
|
82
|
-
#define BLOCK_BYTE_SIZE 18
|
|
83
|
-
|
|
84
|
-
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
|
85
|
-
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
|
86
|
-
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
|
87
|
-
uint shift = (iqs & 0x10) >> 2;
|
|
88
|
-
vui_lo >>= shift;
|
|
89
|
-
vui_hi >>= shift;
|
|
90
|
-
|
|
91
|
-
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
|
|
92
|
-
}
|
|
93
|
-
#endif
|
|
94
|
-
|
|
95
|
-
#if defined(DATA_A_Q8_0)
|
|
96
|
-
#define BLOCK_BYTE_SIZE 34
|
|
97
|
-
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
|
98
|
-
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
|
99
|
-
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
|
100
|
-
|
|
101
|
-
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
|
|
102
|
-
}
|
|
103
|
-
#endif
|
|
104
|
-
|
|
105
|
-
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
|
106
27
|
|
|
107
28
|
// Store the output when doing grouped query attention.
|
|
108
29
|
// Rows index by Q's dimension 2, and the first N rows are valid.
|
|
@@ -113,29 +34,8 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
|
|
|
113
34
|
return elem;
|
|
114
35
|
}
|
|
115
36
|
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
{
|
|
119
|
-
if (r < N && c == 0) {
|
|
120
|
-
uint32_t offset = iq2 + r;
|
|
121
|
-
data_o[o_offset + offset] = D_TYPE(elem);
|
|
122
|
-
}
|
|
123
|
-
return elem;
|
|
124
|
-
}
|
|
125
|
-
|
|
126
|
-
// Load the slope matrix, indexed by Q's dimension 2.
|
|
127
|
-
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
|
|
128
|
-
{
|
|
129
|
-
const uint32_t h = iq2 + (r % p.gqa_ratio);
|
|
130
|
-
|
|
131
|
-
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
|
|
132
|
-
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
|
|
133
|
-
|
|
134
|
-
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
|
135
|
-
}
|
|
136
|
-
|
|
137
|
-
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
|
|
138
|
-
shared vec4 tmpshv4[gl_WorkGroupSize.x];
|
|
37
|
+
shared FLOAT_TYPE tmpsh[WorkGroupSize];
|
|
38
|
+
shared vec4 tmpshv4[WorkGroupSize];
|
|
139
39
|
|
|
140
40
|
shared float masksh[Bc][Br];
|
|
141
41
|
shared vec4 Qf[Br][D / 4];
|
|
@@ -145,58 +45,12 @@ void main() {
|
|
|
145
45
|
init_iq_shmem(gl_WorkGroupSize);
|
|
146
46
|
#endif
|
|
147
47
|
|
|
148
|
-
|
|
149
|
-
const uint32_t N = p.N;
|
|
150
|
-
const uint32_t KV = p.KV;
|
|
48
|
+
init_indices();
|
|
151
49
|
|
|
50
|
+
const uint32_t tid = gl_LocalInvocationIndex;
|
|
152
51
|
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
|
|
153
52
|
const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
|
|
154
53
|
|
|
155
|
-
uint32_t i = gl_WorkGroupID.x;
|
|
156
|
-
uint32_t split_k_index = 0;
|
|
157
|
-
|
|
158
|
-
if (p.k_num > 1) {
|
|
159
|
-
i = 0;
|
|
160
|
-
split_k_index = gl_WorkGroupID.x;
|
|
161
|
-
}
|
|
162
|
-
|
|
163
|
-
const uint32_t Tr = CEIL_DIV(N, Br);
|
|
164
|
-
|
|
165
|
-
const uint32_t start_j = split_k_index * p.split_kv / Bc;
|
|
166
|
-
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
|
|
167
|
-
|
|
168
|
-
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
|
|
169
|
-
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
|
|
170
|
-
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
|
|
171
|
-
const uint32_t iq3 = gl_WorkGroupID.z;
|
|
172
|
-
|
|
173
|
-
// broadcast factors
|
|
174
|
-
const uint32_t rk2 = p.neq2/p.nek2;
|
|
175
|
-
const uint32_t rk3 = p.neq3/p.nek3;
|
|
176
|
-
|
|
177
|
-
const uint32_t rv2 = p.neq2/p.nev2;
|
|
178
|
-
const uint32_t rv3 = p.neq3/p.nev3;
|
|
179
|
-
|
|
180
|
-
// k indices
|
|
181
|
-
const uint32_t ik3 = iq3 / rk3;
|
|
182
|
-
const uint32_t ik2 = iq2 / rk2;
|
|
183
|
-
|
|
184
|
-
// v indices
|
|
185
|
-
const uint32_t iv3 = iq3 / rv3;
|
|
186
|
-
const uint32_t iv2 = iq2 / rv2;
|
|
187
|
-
|
|
188
|
-
// nb?1 are already divided by the type size and are in units of elements.
|
|
189
|
-
// When using grouped query attention, Q is indexed by iq2, so the stride
|
|
190
|
-
// should be nb02 (which is in bytes).
|
|
191
|
-
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
|
192
|
-
uint32_t k_stride = p.nb11;
|
|
193
|
-
uint32_t v_stride = p.nb21;
|
|
194
|
-
// When using grouped query attention, all rows use the same mask (stride 0).
|
|
195
|
-
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
|
|
196
|
-
// that prevents the compiler from folding the "&" through the select
|
|
197
|
-
// and breaking the alignment detection.
|
|
198
|
-
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
|
199
|
-
|
|
200
54
|
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
|
201
55
|
|
|
202
56
|
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
|
|
2
|
+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
|
3
|
+
|
|
4
|
+
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
|
|
5
|
+
layout (constant_id = 1) const uint32_t Br = 1;
|
|
6
|
+
layout (constant_id = 2) const uint32_t Bc = 32;
|
|
7
|
+
layout (constant_id = 3) const uint32_t D = 32;
|
|
8
|
+
layout (constant_id = 4) const uint32_t Clamp = 0;
|
|
9
|
+
layout (constant_id = 5) const uint32_t D_split = 16;
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
layout (push_constant) uniform parameter {
|
|
13
|
+
uint32_t N;
|
|
14
|
+
uint32_t KV;
|
|
15
|
+
|
|
16
|
+
uint32_t ne1;
|
|
17
|
+
uint32_t ne2;
|
|
18
|
+
uint32_t ne3;
|
|
19
|
+
|
|
20
|
+
uint32_t neq2;
|
|
21
|
+
uint32_t neq3;
|
|
22
|
+
uint32_t nek2;
|
|
23
|
+
uint32_t nek3;
|
|
24
|
+
uint32_t nev2;
|
|
25
|
+
uint32_t nev3;
|
|
26
|
+
uint32_t nem1;
|
|
27
|
+
|
|
28
|
+
uint32_t nb01;
|
|
29
|
+
uint32_t nb02;
|
|
30
|
+
uint32_t nb03;
|
|
31
|
+
uint32_t nb11;
|
|
32
|
+
uint32_t nb12;
|
|
33
|
+
uint32_t nb13;
|
|
34
|
+
uint32_t nb21;
|
|
35
|
+
uint32_t nb22;
|
|
36
|
+
uint32_t nb23;
|
|
37
|
+
uint32_t nb31;
|
|
38
|
+
|
|
39
|
+
float scale;
|
|
40
|
+
float max_bias;
|
|
41
|
+
float logit_softcap;
|
|
42
|
+
|
|
43
|
+
uint32_t mask;
|
|
44
|
+
uint32_t n_head_log2;
|
|
45
|
+
float m0;
|
|
46
|
+
float m1;
|
|
47
|
+
|
|
48
|
+
uint32_t gqa_ratio;
|
|
49
|
+
uint32_t split_kv;
|
|
50
|
+
uint32_t k_num;
|
|
51
|
+
} p;
|
|
52
|
+
|
|
53
|
+
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
|
54
|
+
|
|
55
|
+
#if defined(A_TYPE_PACKED16)
|
|
56
|
+
#define BINDING_IDX_K 0
|
|
57
|
+
#define BINDING_IDX_V 1
|
|
58
|
+
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
|
|
59
|
+
#endif
|
|
60
|
+
|
|
61
|
+
#if defined(DATA_A_Q4_0)
|
|
62
|
+
#define BLOCK_BYTE_SIZE 18
|
|
63
|
+
|
|
64
|
+
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
|
65
|
+
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
|
66
|
+
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
|
67
|
+
uint shift = (iqs & 0x10) >> 2;
|
|
68
|
+
vui_lo >>= shift;
|
|
69
|
+
vui_hi >>= shift;
|
|
70
|
+
|
|
71
|
+
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
|
|
72
|
+
}
|
|
73
|
+
#endif
|
|
74
|
+
|
|
75
|
+
#if defined(DATA_A_Q8_0)
|
|
76
|
+
#define BLOCK_BYTE_SIZE 34
|
|
77
|
+
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
|
78
|
+
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
|
79
|
+
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
|
80
|
+
|
|
81
|
+
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
|
|
82
|
+
}
|
|
83
|
+
#endif
|
|
84
|
+
|
|
85
|
+
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
// Store column zero. This is used to save per-row m and L values for split_k.
|
|
89
|
+
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
|
90
|
+
{
|
|
91
|
+
if (r < N && c == 0) {
|
|
92
|
+
uint32_t offset = iq2 + r;
|
|
93
|
+
data_o[o_offset + offset] = D_TYPE(elem);
|
|
94
|
+
}
|
|
95
|
+
return elem;
|
|
96
|
+
}
|
|
97
|
+
|
|
98
|
+
// Load the slope matrix, indexed by Q's dimension 2.
|
|
99
|
+
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
|
|
100
|
+
{
|
|
101
|
+
const uint32_t h = iq2 + (r % p.gqa_ratio);
|
|
102
|
+
|
|
103
|
+
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
|
|
104
|
+
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
|
|
105
|
+
|
|
106
|
+
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
|
|
110
|
+
iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
|
|
111
|
+
q_stride, k_stride, v_stride, m_stride;
|
|
112
|
+
|
|
113
|
+
void init_indices()
|
|
114
|
+
{
|
|
115
|
+
N = p.N;
|
|
116
|
+
KV = p.KV;
|
|
117
|
+
|
|
118
|
+
i = gl_WorkGroupID.x;
|
|
119
|
+
split_k_index = 0;
|
|
120
|
+
|
|
121
|
+
if (p.k_num > 1) {
|
|
122
|
+
i = 0;
|
|
123
|
+
split_k_index = gl_WorkGroupID.x;
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
Tr = CEIL_DIV(N, Br);
|
|
127
|
+
|
|
128
|
+
start_j = split_k_index * p.split_kv / Bc;
|
|
129
|
+
end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
|
|
130
|
+
|
|
131
|
+
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
|
|
132
|
+
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
|
|
133
|
+
iq2 = gl_WorkGroupID.y * p.gqa_ratio;
|
|
134
|
+
iq3 = gl_WorkGroupID.z;
|
|
135
|
+
|
|
136
|
+
// broadcast factors
|
|
137
|
+
rk2 = p.neq2/p.nek2;
|
|
138
|
+
rk3 = p.neq3/p.nek3;
|
|
139
|
+
|
|
140
|
+
rv2 = p.neq2/p.nev2;
|
|
141
|
+
rv3 = p.neq3/p.nev3;
|
|
142
|
+
|
|
143
|
+
// k indices
|
|
144
|
+
ik3 = iq3 / rk3;
|
|
145
|
+
ik2 = iq2 / rk2;
|
|
146
|
+
|
|
147
|
+
// v indices
|
|
148
|
+
iv3 = iq3 / rv3;
|
|
149
|
+
iv2 = iq2 / rv2;
|
|
150
|
+
|
|
151
|
+
// nb?1 are already divided by the type size and are in units of elements.
|
|
152
|
+
// When using grouped query attention, Q is indexed by iq2, so the stride
|
|
153
|
+
// should be nb02 (which is in bytes).
|
|
154
|
+
q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
|
155
|
+
k_stride = p.nb11;
|
|
156
|
+
v_stride = p.nb21;
|
|
157
|
+
// When using grouped query attention, all rows use the same mask (stride 0).
|
|
158
|
+
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
|
|
159
|
+
// that prevents the compiler from folding the "&" through the select
|
|
160
|
+
// and breaking the alignment detection.
|
|
161
|
+
m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
|
162
|
+
}
|