@novastera-oss/llamarn 0.2.1 → 0.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +80 -14
- package/RNLlamaCpp.podspec +10 -3
- package/android/CMakeLists.txt +8 -0
- package/android/src/main/cpp/include/llama.h +62 -125
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/README.md +11 -3
- package/cpp/llama.cpp/build-xcframework.sh +1 -0
- package/cpp/llama.cpp/common/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/common/arg.cpp +153 -113
- package/cpp/llama.cpp/common/chat-parser.cpp +379 -0
- package/cpp/llama.cpp/common/chat-parser.h +117 -0
- package/cpp/llama.cpp/common/chat.cpp +847 -699
- package/cpp/llama.cpp/common/chat.h +73 -6
- package/cpp/llama.cpp/common/common.cpp +50 -82
- package/cpp/llama.cpp/common/common.h +21 -17
- package/cpp/llama.cpp/common/json-partial.cpp +255 -0
- package/cpp/llama.cpp/common/json-partial.h +37 -0
- package/cpp/llama.cpp/common/minja/chat-template.hpp +9 -5
- package/cpp/llama.cpp/common/minja/minja.hpp +69 -36
- package/cpp/llama.cpp/common/regex-partial.cpp +204 -0
- package/cpp/llama.cpp/common/regex-partial.h +56 -0
- package/cpp/llama.cpp/common/sampling.cpp +7 -8
- package/cpp/llama.cpp/convert_hf_to_gguf.py +453 -118
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +120 -68
- package/cpp/llama.cpp/ggml/CMakeLists.txt +2 -1
- package/cpp/llama.cpp/ggml/cmake/common.cmake +25 -0
- package/cpp/llama.cpp/ggml/include/ggml-opt.h +49 -28
- package/cpp/llama.cpp/ggml/include/ggml.h +26 -7
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +16 -10
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +604 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +54 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +50 -51
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +779 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +322 -100
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +117 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +220 -49
- package/cpp/llama.cpp/ggml/src/ggml-cuda/acc.cu +40 -26
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +11 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +15 -7
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +266 -64
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +49 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +48 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +2 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +5 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +7 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +10 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +99 -17
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +200 -2
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +972 -178
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +373 -190
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +101 -5
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +31 -33
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +29 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +4 -5
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +9 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +84 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +324 -129
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +31 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +95 -68
- package/cpp/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +69 -43
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -91
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -181
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +6 -152
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +2 -118
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +12 -1
- package/cpp/llama.cpp/ggml/src/ggml.c +107 -36
- package/cpp/llama.cpp/ggml/src/gguf.cpp +33 -33
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +100 -15
- package/cpp/llama.cpp/gguf-py/gguf/gguf_reader.py +1 -1
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +44 -12
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py +21 -10
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +5 -2
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +128 -31
- package/cpp/llama.cpp/gguf-py/gguf/utility.py +1 -1
- package/cpp/llama.cpp/gguf-py/pyproject.toml +1 -1
- package/cpp/llama.cpp/include/llama.h +62 -125
- package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +1 -1
- package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +1 -1
- package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +62 -0
- package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +85 -0
- package/cpp/llama.cpp/models/templates/README.md +2 -0
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +5 -1
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +5 -1
- package/cpp/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
- package/cpp/llama.cpp/requirements/requirements-gguf_editor_gui.txt +1 -1
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +6 -0
- package/cpp/llama.cpp/src/llama-arch.h +2 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +3 -1
- package/cpp/llama.cpp/src/llama-context.cpp +340 -123
- package/cpp/llama.cpp/src/llama-context.h +30 -0
- package/cpp/llama.cpp/src/llama-cparams.cpp +4 -0
- package/cpp/llama.cpp/src/llama-cparams.h +2 -0
- package/cpp/llama.cpp/src/llama-grammar.cpp +12 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +157 -247
- package/cpp/llama.cpp/src/llama-graph.h +52 -7
- package/cpp/llama.cpp/src/llama-hparams.cpp +17 -1
- package/cpp/llama.cpp/src/llama-hparams.h +37 -5
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +742 -481
- package/cpp/llama.cpp/src/llama-kv-cache.h +196 -99
- package/cpp/llama.cpp/src/llama-kv-cells.h +379 -0
- package/cpp/llama.cpp/src/llama-memory.h +4 -3
- package/cpp/llama.cpp/src/llama-model-loader.cpp +22 -17
- package/cpp/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/cpp/llama.cpp/src/llama-model-saver.h +37 -0
- package/cpp/llama.cpp/src/llama-model.cpp +529 -172
- package/cpp/llama.cpp/src/llama-model.h +6 -1
- package/cpp/llama.cpp/src/llama-quant.cpp +15 -13
- package/cpp/llama.cpp/src/llama-sampling.cpp +2 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +35 -8
- package/cpp/llama.cpp/src/llama-vocab.h +6 -0
- package/cpp/llama.cpp/src/llama.cpp +14 -0
- package/cpp/rn-completion.cpp +4 -2
- package/ios/include/chat.h +73 -6
- package/ios/include/common/minja/chat-template.hpp +9 -5
- package/ios/include/common/minja/minja.hpp +69 -36
- package/ios/include/common.h +21 -17
- package/ios/include/llama.h +62 -125
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4617 -4487
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3557 -3435
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3559 -3437
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4616 -4487
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4637 -4508
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3556 -3435
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4653 -4523
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4674 -4544
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3587 -3465
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/common/stb_image.h +0 -7988
- package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +0 -46
|
@@ -33,9 +33,30 @@ struct fattn_mma_f16_config< 64, 64> {
|
|
|
33
33
|
static constexpr int nwarps_max = 4;
|
|
34
34
|
static constexpr bool Q_in_reg = true;
|
|
35
35
|
static constexpr int nstages_target = 2;
|
|
36
|
-
|
|
37
|
-
static
|
|
38
|
-
|
|
36
|
+
|
|
37
|
+
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
|
38
|
+
return 32;
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
|
42
|
+
return 32;
|
|
43
|
+
}
|
|
44
|
+
|
|
45
|
+
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
|
46
|
+
return 32;
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
|
50
|
+
return 32;
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
|
54
|
+
return 32;
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
|
58
|
+
return 32;
|
|
59
|
+
}
|
|
39
60
|
};
|
|
40
61
|
|
|
41
62
|
template <>
|
|
@@ -44,9 +65,30 @@ struct fattn_mma_f16_config< 80, 80> {
|
|
|
44
65
|
static constexpr int nwarps_max = 4;
|
|
45
66
|
static constexpr bool Q_in_reg = true;
|
|
46
67
|
static constexpr int nstages_target = 2;
|
|
47
|
-
|
|
48
|
-
static
|
|
49
|
-
|
|
68
|
+
|
|
69
|
+
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
|
70
|
+
return 40;
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
|
74
|
+
return 40;
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
|
78
|
+
return 40;
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
|
82
|
+
return 40;
|
|
83
|
+
}
|
|
84
|
+
|
|
85
|
+
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
|
86
|
+
return 40;
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
|
90
|
+
return 40;
|
|
91
|
+
}
|
|
50
92
|
};
|
|
51
93
|
|
|
52
94
|
template <>
|
|
@@ -55,9 +97,30 @@ struct fattn_mma_f16_config< 96, 96> {
|
|
|
55
97
|
static constexpr int nwarps_max = 4;
|
|
56
98
|
static constexpr bool Q_in_reg = true;
|
|
57
99
|
static constexpr int nstages_target = 2;
|
|
58
|
-
|
|
59
|
-
static
|
|
60
|
-
|
|
100
|
+
|
|
101
|
+
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
|
102
|
+
return 48;
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
|
106
|
+
return 48;
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
|
110
|
+
return 48;
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
|
114
|
+
return 48;
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
|
118
|
+
return 48;
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
|
122
|
+
return 48;
|
|
123
|
+
}
|
|
61
124
|
};
|
|
62
125
|
|
|
63
126
|
template <>
|
|
@@ -66,9 +129,30 @@ struct fattn_mma_f16_config<112, 112> {
|
|
|
66
129
|
static constexpr int nwarps_max = 4;
|
|
67
130
|
static constexpr bool Q_in_reg = true;
|
|
68
131
|
static constexpr int nstages_target = 2;
|
|
69
|
-
|
|
70
|
-
static
|
|
71
|
-
|
|
132
|
+
|
|
133
|
+
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
|
134
|
+
return 56;
|
|
135
|
+
}
|
|
136
|
+
|
|
137
|
+
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
|
138
|
+
return 56;
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
|
142
|
+
return 56;
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
|
146
|
+
return 56;
|
|
147
|
+
}
|
|
148
|
+
|
|
149
|
+
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
|
150
|
+
return 56;
|
|
151
|
+
}
|
|
152
|
+
|
|
153
|
+
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
|
154
|
+
return 56;
|
|
155
|
+
}
|
|
72
156
|
};
|
|
73
157
|
|
|
74
158
|
template <>
|
|
@@ -77,9 +161,30 @@ struct fattn_mma_f16_config<128, 128> {
|
|
|
77
161
|
static constexpr int nwarps_max = 4;
|
|
78
162
|
static constexpr bool Q_in_reg = true;
|
|
79
163
|
static constexpr int nstages_target = 2;
|
|
80
|
-
|
|
81
|
-
static
|
|
82
|
-
|
|
164
|
+
|
|
165
|
+
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
|
166
|
+
return 64;
|
|
167
|
+
}
|
|
168
|
+
|
|
169
|
+
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
|
170
|
+
return 64;
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
|
174
|
+
return 64;
|
|
175
|
+
}
|
|
176
|
+
|
|
177
|
+
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
|
178
|
+
return 64;
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
|
182
|
+
return 64;
|
|
183
|
+
}
|
|
184
|
+
|
|
185
|
+
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
|
186
|
+
return 64;
|
|
187
|
+
}
|
|
83
188
|
};
|
|
84
189
|
|
|
85
190
|
template <>
|
|
@@ -88,9 +193,38 @@ struct fattn_mma_f16_config<256, 256> {
|
|
|
88
193
|
static constexpr int nwarps_max = 4;
|
|
89
194
|
static constexpr bool Q_in_reg = true;
|
|
90
195
|
static constexpr int nstages_target = 2;
|
|
91
|
-
|
|
92
|
-
static
|
|
93
|
-
|
|
196
|
+
|
|
197
|
+
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
|
|
198
|
+
return 128;
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
|
|
202
|
+
return 128;
|
|
203
|
+
}
|
|
204
|
+
|
|
205
|
+
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
|
|
206
|
+
return 128;
|
|
207
|
+
}
|
|
208
|
+
|
|
209
|
+
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
|
|
210
|
+
return 128;
|
|
211
|
+
}
|
|
212
|
+
|
|
213
|
+
static int get_nbatch_combine_host(const int cc, const int ncols) {
|
|
214
|
+
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
|
|
215
|
+
return ncols <= 16 ? 128 : 64;
|
|
216
|
+
}
|
|
217
|
+
return 64;
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
static constexpr __device__ int get_nbatch_combine_device(int ncols) {
|
|
221
|
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
|
222
|
+
return ncols <= 16 ? 128 : 64;
|
|
223
|
+
#else
|
|
224
|
+
GGML_UNUSED(ncols);
|
|
225
|
+
return 128;
|
|
226
|
+
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
|
227
|
+
}
|
|
94
228
|
};
|
|
95
229
|
|
|
96
230
|
template <>
|
|
@@ -99,9 +233,44 @@ struct fattn_mma_f16_config<576, 512> {
|
|
|
99
233
|
static constexpr int nwarps_max = 8;
|
|
100
234
|
static constexpr bool Q_in_reg = false;
|
|
101
235
|
static constexpr int nstages_target = 1;
|
|
102
|
-
|
|
103
|
-
static
|
|
104
|
-
|
|
236
|
+
|
|
237
|
+
static int get_nbatch_K2_host(const int cc, const int ncols) {
|
|
238
|
+
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
|
|
239
|
+
return ncols <= 16 ? 96 : 160;
|
|
240
|
+
}
|
|
241
|
+
return ncols <= 16 ? 288 : 160;
|
|
242
|
+
}
|
|
243
|
+
|
|
244
|
+
static constexpr __device__ int get_nbatch_K2_device(int ncols) {
|
|
245
|
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
|
246
|
+
return ncols <= 16 ? 96 : 160;
|
|
247
|
+
#else
|
|
248
|
+
return ncols <= 16 ? 288 : 160;
|
|
249
|
+
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
|
250
|
+
}
|
|
251
|
+
|
|
252
|
+
static int get_nbatch_V2_host(const int cc, const int ncols) {
|
|
253
|
+
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
|
|
254
|
+
return ncols <= 16 ? 64 : 128;
|
|
255
|
+
}
|
|
256
|
+
return ncols <= 16 ? 256 : 128;
|
|
257
|
+
}
|
|
258
|
+
|
|
259
|
+
static constexpr __device__ int get_nbatch_V2_device(int ncols) {
|
|
260
|
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
|
261
|
+
return ncols <= 16 ? 64 : 128;
|
|
262
|
+
#else
|
|
263
|
+
return ncols <= 16 ? 256 : 128;
|
|
264
|
+
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
|
265
|
+
}
|
|
266
|
+
|
|
267
|
+
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
|
|
268
|
+
return 128;
|
|
269
|
+
}
|
|
270
|
+
|
|
271
|
+
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
|
|
272
|
+
return 128;
|
|
273
|
+
}
|
|
105
274
|
};
|
|
106
275
|
|
|
107
276
|
// ------------------------------------------------------------------------------------------------------------------
|
|
@@ -120,7 +289,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
|
120
289
|
|
|
121
290
|
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
|
|
122
291
|
|
|
123
|
-
auto load = [&] __device__ (
|
|
292
|
+
auto load = [&] __device__ (auto n) {
|
|
124
293
|
const int stride_k = WARP_SIZE >> n;
|
|
125
294
|
const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
|
126
295
|
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
|
|
@@ -223,7 +392,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
|
|
223
392
|
}
|
|
224
393
|
}
|
|
225
394
|
|
|
226
|
-
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
|
|
395
|
+
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
|
|
227
396
|
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
228
397
|
const float2 * const __restrict__ Q_f2,
|
|
229
398
|
const half2 * const __restrict__ K_h2,
|
|
@@ -261,10 +430,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
261
430
|
constexpr int cols_per_warp = ntiles * tile_B::I;
|
|
262
431
|
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
|
263
432
|
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
|
433
|
+
constexpr int ncols = ncols1 * ncols2;
|
|
434
|
+
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
|
|
435
|
+
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
|
|
436
|
+
|
|
437
|
+
constexpr int stride_tile_Q = DKQ/2 + 4;
|
|
438
|
+
constexpr int stride_tile_K = nbatch_K2 + 4;
|
|
264
439
|
|
|
265
|
-
|
|
266
|
-
constexpr int
|
|
267
|
-
constexpr int stride_tile_V = c::nbatch_V2 + 4;
|
|
440
|
+
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
|
|
441
|
+
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
|
268
442
|
|
|
269
443
|
const int k_VKQ_0 = kb0 * c::nbatch_fa;
|
|
270
444
|
tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
|
|
@@ -275,12 +449,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
275
449
|
tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
|
|
276
450
|
|
|
277
451
|
if constexpr (nstages > 1) {
|
|
278
|
-
static_assert(
|
|
452
|
+
static_assert(!mla, "multi-stage loading not implemented for MLA");
|
|
453
|
+
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
|
|
279
454
|
constexpr bool use_cp_async = true;
|
|
280
455
|
cp_async_wait_all();
|
|
281
456
|
__syncthreads();
|
|
282
457
|
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
|
283
|
-
(V_h2 + k_VKQ_0*stride_V, tile_V,
|
|
458
|
+
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
|
|
284
459
|
} else {
|
|
285
460
|
constexpr bool use_cp_async = nstages == 1;
|
|
286
461
|
if (ncols2 > 1 || mask_h2) {
|
|
@@ -289,8 +464,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
289
464
|
}
|
|
290
465
|
|
|
291
466
|
#pragma unroll
|
|
292
|
-
for (int k0_start = 0; k0_start < DKQ/2; k0_start +=
|
|
293
|
-
const int k0_stop = k0_start +
|
|
467
|
+
for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
|
|
468
|
+
const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
|
|
294
469
|
const int k0_diff = k0_stop - k0_start;
|
|
295
470
|
|
|
296
471
|
if (nstages <= 1) {
|
|
@@ -537,16 +712,21 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
537
712
|
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
|
|
538
713
|
}
|
|
539
714
|
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
|
540
|
-
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K,
|
|
715
|
+
(K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
|
|
541
716
|
}
|
|
542
717
|
}
|
|
543
718
|
|
|
719
|
+
|
|
720
|
+
// For MLA K and V have the same data.
|
|
721
|
+
// Therefore, iterate over V in reverse and re-use the data if possible.
|
|
722
|
+
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
|
|
723
|
+
constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
|
|
544
724
|
#pragma unroll
|
|
545
|
-
for (int
|
|
546
|
-
const int
|
|
547
|
-
const int i0_diff
|
|
725
|
+
for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
|
|
726
|
+
const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
|
|
727
|
+
const int i0_diff = i0_stop - i0_start;
|
|
548
728
|
|
|
549
|
-
if (nstages <= 1) {
|
|
729
|
+
if (nstages <= 1 && i0_start < reusable_cutoff) {
|
|
550
730
|
constexpr bool use_cp_async = nstages == 1;
|
|
551
731
|
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
|
552
732
|
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
|
|
@@ -555,6 +735,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
555
735
|
}
|
|
556
736
|
__syncthreads();
|
|
557
737
|
}
|
|
738
|
+
const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
|
|
558
739
|
|
|
559
740
|
// Calculate VKQ tile:
|
|
560
741
|
#pragma unroll
|
|
@@ -565,7 +746,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
565
746
|
const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
|
|
566
747
|
|
|
567
748
|
tile_A A;
|
|
568
|
-
load_ldmatrix_trans(A,
|
|
749
|
+
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
|
569
750
|
if (ntiles == 1) {
|
|
570
751
|
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
|
|
571
752
|
} else {
|
|
@@ -591,12 +772,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
591
772
|
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
|
|
592
773
|
GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
|
|
593
774
|
GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
|
|
594
|
-
GGML_UNUSED(kb0);
|
|
775
|
+
GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
|
|
595
776
|
NO_DEVICE_CODE;
|
|
596
777
|
#endif // NEW_MMA_AVAILABLE
|
|
597
778
|
}
|
|
598
779
|
|
|
599
|
-
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
|
|
780
|
+
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
|
|
600
781
|
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
601
782
|
const float2 * const __restrict__ Q_f2,
|
|
602
783
|
const half2 * const __restrict__ K_h2,
|
|
@@ -632,13 +813,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
632
813
|
constexpr int cols_per_warp = ntiles * tile_B::I;
|
|
633
814
|
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
|
634
815
|
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
|
816
|
+
constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
|
|
817
|
+
constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
|
|
635
818
|
|
|
636
819
|
static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
|
|
637
820
|
|
|
638
|
-
constexpr int stride_tile_Q = DKQ/2
|
|
639
|
-
constexpr int stride_tile_K =
|
|
640
|
-
constexpr int stride_tile_V = c::nbatch_V2 + 4;
|
|
821
|
+
constexpr int stride_tile_Q = DKQ/2 + 4;
|
|
822
|
+
constexpr int stride_tile_K = nbatch_K2 + 4;
|
|
641
823
|
|
|
824
|
+
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
|
|
825
|
+
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
|
642
826
|
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
|
|
643
827
|
|
|
644
828
|
extern __shared__ half2 tile_Q[];
|
|
@@ -726,26 +910,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
726
910
|
|
|
727
911
|
// Preload mask and K data for first iteration when using cp_async with multiple stages:
|
|
728
912
|
if constexpr (nstages > 1) {
|
|
729
|
-
static_assert(
|
|
913
|
+
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
|
|
730
914
|
constexpr bool use_cp_async = true;
|
|
731
915
|
if (ncols2 > 1 || mask_h2) {
|
|
732
916
|
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
|
|
733
917
|
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
|
|
734
918
|
}
|
|
735
919
|
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
|
736
|
-
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K,
|
|
920
|
+
(K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
|
|
737
921
|
}
|
|
738
922
|
|
|
739
923
|
// Iterate over ne11 == previous tokens:
|
|
740
924
|
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
|
|
741
925
|
constexpr bool last_iter = false;
|
|
742
|
-
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
|
|
926
|
+
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
|
743
927
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
744
928
|
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
|
745
929
|
}
|
|
746
930
|
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
|
747
931
|
constexpr bool last_iter = true;
|
|
748
|
-
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
|
|
932
|
+
flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
|
|
749
933
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
|
750
934
|
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
|
|
751
935
|
}
|
|
@@ -774,7 +958,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
774
958
|
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
|
775
959
|
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
|
|
776
960
|
|
|
777
|
-
constexpr int nbatch_combine = c::
|
|
961
|
+
constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols);
|
|
778
962
|
constexpr int tile_stride = nbatch_combine + 4;
|
|
779
963
|
static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
|
|
780
964
|
|
|
@@ -895,6 +1079,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
895
1079
|
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
|
|
896
1080
|
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
|
|
897
1081
|
}
|
|
1082
|
+
} else if (np > 1) {
|
|
1083
|
+
// Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch.
|
|
1084
|
+
// Therefore, all other warps also need to execute a __syncthreads().
|
|
1085
|
+
// Otherwise the points at which warps synchronize with each other would become misaligned.
|
|
1086
|
+
__syncthreads();
|
|
898
1087
|
}
|
|
899
1088
|
|
|
900
1089
|
#pragma unroll
|
|
@@ -1007,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
|
1007
1196
|
#endif // NEW_MMA_AVAILABLE
|
|
1008
1197
|
}
|
|
1009
1198
|
|
|
1010
|
-
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap>
|
|
1199
|
+
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
|
|
1011
1200
|
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
|
1012
1201
|
static __global__ void flash_attn_ext_f16(
|
|
1013
1202
|
const char * __restrict__ Q,
|
|
@@ -1052,6 +1241,14 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1052
1241
|
NO_DEVICE_CODE;
|
|
1053
1242
|
return;
|
|
1054
1243
|
}
|
|
1244
|
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
|
1245
|
+
if (ncols1*ncols2 > 32) {
|
|
1246
|
+
NO_DEVICE_CODE;
|
|
1247
|
+
return;
|
|
1248
|
+
}
|
|
1249
|
+
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
|
1250
|
+
|
|
1251
|
+
static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
|
|
1055
1252
|
|
|
1056
1253
|
typedef fattn_mma_f16_config<DKQ, DV> c;
|
|
1057
1254
|
|
|
@@ -1062,9 +1259,10 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1062
1259
|
const int stride_Q1 = nb01 / sizeof(float2);
|
|
1063
1260
|
const int stride_Q2 = nb02 / sizeof(float2);
|
|
1064
1261
|
const int stride_K = nb11 / sizeof(half2);
|
|
1065
|
-
const int stride_V = nb21 / sizeof(half2);
|
|
1066
1262
|
const int stride_mask = nb31 / sizeof(half2);
|
|
1067
1263
|
|
|
1264
|
+
const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
|
|
1265
|
+
|
|
1068
1266
|
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
|
1069
1267
|
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
|
1070
1268
|
|
|
@@ -1087,10 +1285,11 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1087
1285
|
|
|
1088
1286
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
|
1089
1287
|
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
|
1090
|
-
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
|
1091
1288
|
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
|
1092
1289
|
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
|
1093
1290
|
|
|
1291
|
+
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
|
1292
|
+
|
|
1094
1293
|
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
|
1095
1294
|
|
|
1096
1295
|
const int kb0_start_kernel = kb0_start * kb_niter;
|
|
@@ -1099,12 +1298,12 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1099
1298
|
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
|
1100
1299
|
if (kb0_start == 0) {
|
|
1101
1300
|
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
|
1102
|
-
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
|
1301
|
+
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
|
1103
1302
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
|
1104
1303
|
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
|
1105
1304
|
} else {
|
|
1106
1305
|
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
|
1107
|
-
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
|
1306
|
+
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
|
1108
1307
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
|
1109
1308
|
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
|
1110
1309
|
}
|
|
@@ -1125,10 +1324,11 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1125
1324
|
|
|
1126
1325
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
|
1127
1326
|
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
|
1128
|
-
const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape
|
|
1129
1327
|
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
|
1130
1328
|
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
|
1131
1329
|
|
|
1330
|
+
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
|
1331
|
+
|
|
1132
1332
|
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
|
1133
1333
|
|
|
1134
1334
|
const int kb0_start_kernel = kb0_start * kb_niter;
|
|
@@ -1136,7 +1336,7 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1136
1336
|
|
|
1137
1337
|
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
|
1138
1338
|
constexpr bool needs_fixup = false;
|
|
1139
|
-
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
|
1339
|
+
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
|
1140
1340
|
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
|
1141
1341
|
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
|
1142
1342
|
#else
|
|
@@ -1162,10 +1362,6 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
|
1162
1362
|
|
|
1163
1363
|
typedef fattn_mma_f16_config<DKQ, DV> c;
|
|
1164
1364
|
|
|
1165
|
-
constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2;
|
|
1166
|
-
constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2;
|
|
1167
|
-
constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine;
|
|
1168
|
-
|
|
1169
1365
|
const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
|
|
1170
1366
|
|
|
1171
1367
|
constexpr int ncols = ncols1 * ncols2;
|
|
@@ -1175,15 +1371,21 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
|
1175
1371
|
constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
|
|
1176
1372
|
constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
|
|
1177
1373
|
|
|
1374
|
+
constexpr bool mla = DKQ == 576;
|
|
1375
|
+
|
|
1376
|
+
const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols);
|
|
1377
|
+
const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols);
|
|
1378
|
+
const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
|
|
1379
|
+
|
|
1178
1380
|
static_assert(DKQ % tile_B::J == 0, "bad DKQ");
|
|
1179
1381
|
static_assert(DV % tile_A::J == 0, "bad DV");
|
|
1180
1382
|
static_assert(ncols % cols_per_warp == 0, "bad ncols");
|
|
1181
1383
|
|
|
1182
|
-
const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(
|
|
1183
|
-
const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (
|
|
1184
|
-
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4)
|
|
1185
|
-
const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4)
|
|
1186
|
-
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4)
|
|
1384
|
+
const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
|
|
1385
|
+
const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
|
|
1386
|
+
const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
|
|
1387
|
+
const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
|
|
1388
|
+
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
|
|
1187
1389
|
|
|
1188
1390
|
const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
|
|
1189
1391
|
|
|
@@ -1197,7 +1399,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
|
1197
1399
|
fattn_kernel_t fattn_kernel;
|
|
1198
1400
|
if (logit_softcap == 0.0f) {
|
|
1199
1401
|
constexpr bool use_logit_softcap = false;
|
|
1200
|
-
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
|
|
1402
|
+
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
|
|
1201
1403
|
|
|
1202
1404
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
|
1203
1405
|
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
|
@@ -1208,7 +1410,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
|
1208
1410
|
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
|
1209
1411
|
} else {
|
|
1210
1412
|
constexpr bool use_logit_softcap = true;
|
|
1211
|
-
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
|
|
1413
|
+
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
|
|
1212
1414
|
|
|
1213
1415
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
|
|
1214
1416
|
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|