@novastera-oss/llamarn 0.3.1 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +86 -3
- package/RNLlamaCpp.podspec +1 -1
- package/android/CMakeLists.txt +11 -3
- package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
- package/android/src/main/cpp/include/llama.h +53 -114
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +2 -10
- package/cpp/PureCppImpl.cpp +71 -4
- package/cpp/SystemUtils.cpp +3 -7
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -1
- package/cpp/llama.cpp/Makefile +6 -1605
- package/cpp/llama.cpp/README.md +5 -1
- package/cpp/llama.cpp/common/arg.cpp +230 -51
- package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
- package/cpp/llama.cpp/common/chat.cpp +539 -8
- package/cpp/llama.cpp/common/chat.h +8 -1
- package/cpp/llama.cpp/common/common.cpp +60 -15
- package/cpp/llama.cpp/common/common.h +64 -15
- package/cpp/llama.cpp/common/speculative.cpp +135 -54
- package/cpp/llama.cpp/common/speculative.h +8 -1
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
- package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
- package/cpp/llama.cpp/flake.nix +0 -5
- package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
- package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
- package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
- package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
- package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
- package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
- package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
- package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
- package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
- package/cpp/llama.cpp/include/llama.h +53 -114
- package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
- package/cpp/llama.cpp/models/templates/README.md +2 -1
- package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
- package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
- package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
- package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
- package/cpp/llama.cpp/src/llama-adapter.h +3 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
- package/cpp/llama.cpp/src/llama-arch.h +18 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
- package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +61 -252
- package/cpp/llama.cpp/src/llama-context.h +10 -15
- package/cpp/llama.cpp/src/llama-cparams.h +0 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
- package/cpp/llama.cpp/src/llama-graph.h +90 -51
- package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
- package/cpp/llama.cpp/src/llama-hparams.h +21 -6
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
- package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
- package/cpp/llama.cpp/src/llama-memory.h +13 -10
- package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
- package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
- package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
- package/cpp/llama.cpp/src/llama-model.h +28 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
- package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
- package/cpp/llama.cpp/src/llama-vocab.h +1 -0
- package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
- package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
- package/cpp/rn-completion.cpp +3 -27
- package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
- package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
- package/ios/include/chat.h +8 -1
- package/ios/include/common/minja/chat-template.hpp +16 -7
- package/ios/include/common/minja/minja.hpp +47 -12
- package/ios/include/common.h +64 -15
- package/ios/include/llama.h +53 -114
- package/ios/include/speculative.h +8 -1
- package/ios/libs/llama.xcframework/Info.plist +18 -18
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/lib/module/NativeRNLlamaCpp.js.map +1 -1
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
- package/package.json +1 -2
- package/src/NativeRNLlamaCpp.ts +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56
package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu:

```diff
@@ -5,14 +5,16 @@
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
 template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
-#if !
+#if !defined(GGML_USE_HIP)
 __launch_bounds__(nwarps*WARP_SIZE, 2)
-#endif // !
+#endif // !defined(GGML_USE_HIP)
 static __global__ void flash_attn_tile_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -47,10 +49,11 @@ static __global__ void flash_attn_tile_ext_f16(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2
-    const half2  * K_h2
-    const half2  * V_h2
-    const half   * maskh
+    const float2 * Q_f2   = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half2  * K_h2   = (const half2  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half2  * V_h2   = (const half2  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh  = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float  * sinksf = (const float  *) (sinks);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
@@ -90,7 +93,8 @@ static __global__ void flash_attn_tile_ext_f16(
 
     __syncthreads();
 
-
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         half kqmax_new[ncols/nwarps];
```
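
For readers skimming the hunk above: the new `KV_max` argument carries an optional per-block upper bound on how many KV positions are actually live, so the kernel iterates up to that bound instead of always scanning all `ne11` positions. A minimal scalar sketch of the same idea in C++ (the `kv_bounds` and `attend_sum` names are illustrative, not part of ggml):

```cpp
#include <cstdio>

// Sketch: an optional per-block bound shortens the KV scan.
// kv_bounds == nullptr means no bound is known, so scan the full length.
float attend_sum(const float *scores, const int *kv_bounds, int block, int kv_len) {
    const int k_max = kv_bounds ? kv_bounds[block] : kv_len;
    float acc = 0.0f;
    for (int k = 0; k < k_max; ++k) { // positions >= k_max are fully masked out
        acc += scores[k];
    }
    return acc;
}

int main() {
    const float scores[8] = {1, 1, 1, 1, 0, 0, 0, 0}; // masked tail contributes 0
    const int   bounds[1] = {4};
    printf("%.1f %.1f\n", attend_sum(scores, bounds, 0, 8),
                          attend_sum(scores, nullptr, 0, 8)); // same value, fewer iterations
}
```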
package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu (continued):

```diff
@@ -239,6 +243,31 @@ static __global__ void flash_attn_tile_ext_f16(
         __syncthreads();
     }
 
+    // Attention sink: adjust running max and sum once per head
+    if (sinksf && blockIdx.y == 0) {
+        const half sink = __float2half(sinksf[head]);
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j));
+            kqmax[j0/nwarps] = kqmax_new_j;
+
+            const half val = hexp(sink - kqmax[j0/nwarps]);
+            kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
+            if (threadIdx.x == 0) {
+                kqsum[j0/nwarps].x = __hadd(__low2half(kqsum[j0/nwarps]), val);
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
+            }
+        }
+    }
+
     float2 * dst2 = (float2 *) dst;
 
 #pragma unroll
@@ -270,17 +299,15 @@ static __global__ void flash_attn_tile_ext_f16(
         }
     }
 #else
-
-
-
-
-
-
-
-
-
-    GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-    GGML_UNUSED(nb23);
+    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+        max_bias, m0, m1, n_head_log2, logit_softcap,
+        ne00, ne01, ne02, ne03,
+        nb01, nb02, nb03,
+        ne10, ne11, ne12, ne13,
+        nb11, nb12, nb13,
+        nb21, nb22, nb23,
+        ne31, ne32, ne33,
+        nb31, nb32, nb33);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
```
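
For context on the sink block added in the hunk above (and its f32/vec counterparts below): an attention sink acts as one extra logit $s$ per head that joins the softmax normalization but contributes no value vector. Writing $m$ for the running maximum over the KV logits $q \cdot k_i$, the adjusted output is

$$
m' = \max(m,\, s), \qquad
O = \frac{\sum_i e^{\,q \cdot k_i - m'}\, v_i}{\,e^{\,s - m'} + \sum_i e^{\,q \cdot k_i - m'}\,}
$$

which is why the kernels rescale `kqsum` and `VKQ` by `exp(m - m')` and let a single thread add `exp(s - m')` to the sum: the numerator only needs rescaling, while the denominator gains the sink term.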
package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu:

```diff
@@ -5,14 +5,16 @@
 #define FATTN_KQ_STRIDE_TILE_F32 32
 
 template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
-#if !
+#if !defined(GGML_USE_HIP)
 __launch_bounds__(nwarps*WARP_SIZE, 2)
-#endif // !
+#endif // !defined(GGML_USE_HIP)
 static __global__ void flash_attn_tile_ext_f32(
         const char * __restrict__ Q,
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -36,17 +38,15 @@ static __global__ void flash_attn_tile_ext_f32(
         return;
 #endif // FP16_MMA_AVAILABLE
     if (use_logit_softcap && !(D == 128 || D == 256)) {
-
-
-
-
-
-
-
-
-
-        GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
-        GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
+        GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+            max_bias, m0, m1, n_head_log2, logit_softcap,
+            ne00, ne01, ne02, ne03,
+            nb01, nb02, nb03,
+            ne10, ne11, ne12, ne13,
+            nb11, nb12, nb13,
+            nb21, nb22, nb23,
+            ne31, ne32, ne33,
+            nb31, nb32, nb33);
         NO_DEVICE_CODE;
         return;
     }
@@ -58,10 +58,11 @@ static __global__ void flash_attn_tile_ext_f32(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2
-    const half2  * K_h2
-    const half2  * V_h2
-    const half   * maskh
+    const float2 * Q_f2   = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half2  * K_h2   = (const half2  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half2  * V_h2   = (const half2  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh  = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float  * sinksf = (const float  *) (sinks);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
@@ -99,7 +100,8 @@ static __global__ void flash_attn_tile_ext_f32(
 
     __syncthreads();
 
-
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         float kqmax_new[ncols/nwarps];
@@ -249,6 +251,33 @@ static __global__ void flash_attn_tile_ext_f32(
         __syncthreads();
     }
 
+
+    // Attention sink: adjust running max and sum once per head
+    if (sinksf && blockIdx.y == 0) {
+        const float sink = sinksf[head];
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
+            kqmax[j0/nwarps] = kqmax_new_j;
+
+            const float val = expf(sink - kqmax[j0/nwarps]);
+            kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
+            if (threadIdx.x == 0) {
+                kqsum[j0/nwarps] += val;
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
+                VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
+            }
+        }
+    }
+
     float2 * dst2 = (float2 *) dst;
 
 #pragma unroll
@@ -281,17 +310,15 @@ static __global__ void flash_attn_tile_ext_f32(
         }
     }
 #else
-
-
-
-
-
-
-
-
-
-    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
+    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+        max_bias, m0, m1, n_head_log2, logit_softcap,
+        ne00, ne01, ne02, ne03,
+        nb01, nb02, nb03,
+        ne10, ne11, ne12, ne13,
+        nb11, nb12, nb13,
+        nb21, nb22, nb23,
+        ne31, ne32, ne33,
+        nb31, nb32, nb33);
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
```
package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh:

```diff
@@ -1,6 +1,12 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 
+// Currently llvm with the amdgcn target does not support unrolling loops
+// that contain a break that cannot be resolved at compile time.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
 template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
 #ifndef GGML_USE_HIP
 __launch_bounds__(D, 1)
@@ -10,6 +16,8 @@ static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -54,7 +62,8 @@ static __global__ void flash_attn_vec_ext_f16(
     K += nb13*sequence + nb12*(head / gqa_ratio);
     V += nb23*sequence + nb22*(head / gqa_ratio);
 
-    const half
+    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float * sinksf = (const float *) (sinks);
 
     const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
@@ -68,11 +77,12 @@ static __global__ void flash_attn_vec_ext_f16(
     half2 * KQ2 = (half2 *) KQ;
 
     half kqmax[ncols];
+    half kqsum[ncols];
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
         kqmax[j] = -HALF_MAX_HALF;
+        kqsum[j] = 0.0f;
     }
-    half kqsum[ncols] = {0.0f};
 
     __shared__ half kqmax_shared[ncols][WARP_SIZE];
     __shared__ half kqsum_shared[ncols][WARP_SIZE];
@@ -171,10 +181,14 @@ static __global__ void flash_attn_vec_ext_f16(
 
     half2 VKQ[ncols] = {{0.0f, 0.0f}};
 
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
     K += blockIdx.y*D * nb11;
     V += blockIdx.y*D * nb21;
     maskh += blockIdx.y*D;
-    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 <
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
+             // Increment pointers after each loop:
+             K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
+
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         if (mask) {
@@ -182,29 +196,7 @@ static __global__ void flash_attn_vec_ext_f16(
             for (int j = 0; j < ncols; ++j) {
                 maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
            }
-
             __syncthreads();
-
-            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
-            // In such cases, skip the KV slice.
-            // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
-#ifndef GGML_USE_HIP
-            bool skip = true;
-#pragma unroll
-            for (int j = 0; j < ncols; ++j) {
-#pragma unroll
-                for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-                    const int i = i0 + threadIdx.x;
-
-                    const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]);
-                    skip = skip && isinf(tmp.x) && isinf(tmp.y);
-                }
-            }
-            if (__all_sync(0xFFFFFFFF, skip)) {
-                __syncthreads();
-                continue;
-            }
-#endif // GGML_USE_HIP
         }
 
         // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
@@ -291,9 +283,38 @@ static __global__ void flash_attn_vec_ext_f16(
             }
         }
 
-
-
-
+        __syncthreads();
+    }
+
+    if (sinksf && blockIdx.y == 0) {
+        const half sink = __float2half(sinksf[head]);
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            half kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
+
+            const half val = hexp(sink - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale;
+
+            if (tid == 0) {
+                kqsum[j] += val;
+            }
+
+            VKQ[j] *= __half2half2(KQ_max_scale);
+        }
 
         __syncthreads();
     }
@@ -328,20 +349,21 @@ static __global__ void flash_attn_vec_ext_f16(
         dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
-
-
-
-
-
-
-
-
-
-    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
+    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+        max_bias, m0, m1, n_head_log2, logit_softcap,
+        ne00, ne01, ne02, ne03,
+        nb01, nb02, nb03,
+        ne10, ne11, ne12, ne13,
+        nb11, nb12, nb13,
+        nb21, nb22, nb23,
+        ne31, ne32, ne33,
+        nb31, nb32, nb33);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
 
 template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
 void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
```
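
The same rescale recurs in the f32 variant below. As a reference, here is the update in scalar C++ for a single query row with running state (max `m`, exp-sum `s`, accumulator `acc`); this is a sketch of the math the kernels implement, with illustrative names, not ggml API:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

// Running online-softmax state for one query row.
struct RowState { float m, s, acc; };

// Fold an attention-sink logit into the running max/sum, rescaling the
// previously accumulated terms; the sink adds no V contribution.
void apply_sink(RowState &r, float sink) {
    const float m_new = std::max(r.m, sink);
    const float scale = std::exp(r.m - m_new);    // rescale old contributions
    r.s   = r.s * scale + std::exp(sink - m_new); // sink joins the denominator
    r.acc = r.acc * scale;                        // numerator only rescales
    r.m   = m_new;
}

int main() {
    RowState r = {2.0f, 3.5f, 1.25f}; // state after scanning the KV positions
    apply_sink(r, 4.0f);
    printf("m=%.3f s=%.3f acc=%.3f\n", r.m, r.s, r.acc);
}
```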
package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh:

```diff
@@ -1,6 +1,12 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 
+// Currently llvm with the amdgcn target does not support unrolling loops
+// that contain a break that cannot be resolved at compile time.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
 template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
 #ifndef GGML_USE_HIP
 __launch_bounds__(D, 1)
@@ -10,6 +16,8 @@ static __global__ void flash_attn_vec_ext_f32(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -29,17 +37,15 @@ static __global__ void flash_attn_vec_ext_f32(
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
-
-
-
-
-
-
-
-
-
-        GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-        GGML_UNUSED(nb23);
+        GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+            max_bias, m0, m1, n_head_log2, logit_softcap,
+            ne00, ne01, ne02, ne03,
+            nb01, nb02, nb03,
+            ne10, ne11, ne12, ne13,
+            nb11, nb12, nb13,
+            nb21, nb22, nb23,
+            ne31, ne32, ne33,
+            nb31, nb32, nb33);
         NO_DEVICE_CODE;
         return;
     }
@@ -65,7 +71,8 @@ static __global__ void flash_attn_vec_ext_f32(
     K += nb13*sequence + nb12*(head / gqa_ratio);
     V += nb23*sequence + nb22*(head / gqa_ratio);
 
-    const half
+    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float * sinksf = (const float *) (sinks);
 
     const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
@@ -81,11 +88,12 @@ static __global__ void flash_attn_vec_ext_f32(
     }
 
     float kqmax[ncols];
+    float kqsum[ncols];
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
         kqmax[j] = -FLT_MAX/2.0f;
+        kqsum[j] = 0.0f;
     }
-    float kqsum[ncols] = {0.0f};
 
     __shared__ float kqmax_shared[ncols][WARP_SIZE];
     __shared__ float kqsum_shared[ncols][WARP_SIZE];
@@ -177,10 +185,14 @@ static __global__ void flash_attn_vec_ext_f32(
 
     float VKQ[ncols] = {0.0f};
 
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
     K += blockIdx.y*D * nb11;
     V += blockIdx.y*D * nb21;
     maskh += blockIdx.y*D;
-    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 <
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
+             // Increment pointers after each loop:
+             K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
+
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         if (mask) {
@@ -188,28 +200,7 @@ static __global__ void flash_attn_vec_ext_f32(
             for (int j = 0; j < ncols; ++j) {
                 maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
             }
-
             __syncthreads();
-
-            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
-            // In such cases, skip the KV slice.
-            // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
-#ifndef GGML_USE_HIP
-            bool skip = true;
-#pragma unroll
-            for (int j = 0; j < ncols; ++j) {
-#pragma unroll
-                for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
-                    const int i = i0 + threadIdx.x;
-
-                    skip = skip && isinf(maskf_shared[j*D + i]);
-                }
-            }
-            if (__all_sync(0xFFFFFFFF, skip)) {
-                __syncthreads();
-                continue;
-            }
-#endif // GGML_USE_HIP
         }
 
         float kqmax_new_arr[ncols];
@@ -286,9 +277,38 @@ static __global__ void flash_attn_vec_ext_f32(
             }
         }
 
-
-
-
+        __syncthreads();
+    }
+
+    if (sinksf && blockIdx.y == 0) {
+        const float sink = sinksf[head];
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            float kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
+
+            const float val = expf(sink - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale;
+
+            if (tid == 0) {
+                kqsum[j] += val;
+            }
+
+            VKQ[j] *= KQ_max_scale;
+        }
 
         __syncthreads();
     }
@@ -323,20 +343,21 @@ static __global__ void flash_attn_vec_ext_f32(
         dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
-
-
-
-
-
-
-
-
-
-    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
-    GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
+    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+        max_bias, m0, m1, n_head_log2, logit_softcap,
+        ne00, ne01, ne02, ne03,
+        nb01, nb02, nb03,
+        ne10, ne11, ne12, ne13,
+        nb11, nb12, nb13,
+        nb21, nb22, nb23,
+        ne31, ne32, ne33,
+        nb31, nb32, nb33);
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
 
 template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
 void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
```
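
A final mechanical pattern worth noting: all four files collapse long runs of `GGML_UNUSED(x);` into a single variadic `GGML_UNUSED_VARS(...)` on the `NO_DEVICE_CODE` fallback path. A generic sketch of how such a macro can be built in C++ (illustrative only, not ggml's actual definition):

```cpp
#include <cstdio>

// Mark any number of variables as deliberately unused in one statement:
// forwarding them into an empty inline function "uses" them without codegen.
template <typename... Args>
inline void unused_vars(Args &&...) {}

#define UNUSED_VARS(...) unused_vars(__VA_ARGS__)

int main() {
    int n = 42; float scale = 0.5f; const char *tag = "fattn";
    UNUSED_VARS(n, scale, tag); // silences -Wunused-* for all three at once
    puts("ok");
}
```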