@novastera-oss/llamarn 0.3.1 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +86 -3
- package/RNLlamaCpp.podspec +1 -1
- package/android/CMakeLists.txt +11 -3
- package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
- package/android/src/main/cpp/include/llama.h +53 -114
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +2 -10
- package/cpp/PureCppImpl.cpp +71 -4
- package/cpp/SystemUtils.cpp +3 -7
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -1
- package/cpp/llama.cpp/Makefile +6 -1605
- package/cpp/llama.cpp/README.md +5 -1
- package/cpp/llama.cpp/common/arg.cpp +230 -51
- package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
- package/cpp/llama.cpp/common/chat.cpp +539 -8
- package/cpp/llama.cpp/common/chat.h +8 -1
- package/cpp/llama.cpp/common/common.cpp +60 -15
- package/cpp/llama.cpp/common/common.h +64 -15
- package/cpp/llama.cpp/common/speculative.cpp +135 -54
- package/cpp/llama.cpp/common/speculative.h +8 -1
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
- package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
- package/cpp/llama.cpp/flake.nix +0 -5
- package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
- package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
- package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
- package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
- package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
- package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
- package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
- package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
- package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
- package/cpp/llama.cpp/include/llama.h +53 -114
- package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
- package/cpp/llama.cpp/models/templates/README.md +2 -1
- package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
- package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
- package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
- package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
- package/cpp/llama.cpp/src/llama-adapter.h +3 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
- package/cpp/llama.cpp/src/llama-arch.h +18 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
- package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +61 -252
- package/cpp/llama.cpp/src/llama-context.h +10 -15
- package/cpp/llama.cpp/src/llama-cparams.h +0 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
- package/cpp/llama.cpp/src/llama-graph.h +90 -51
- package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
- package/cpp/llama.cpp/src/llama-hparams.h +21 -6
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
- package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
- package/cpp/llama.cpp/src/llama-memory.h +13 -10
- package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
- package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
- package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
- package/cpp/llama.cpp/src/llama-model.h +28 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
- package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
- package/cpp/llama.cpp/src/llama-vocab.h +1 -0
- package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
- package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
- package/cpp/rn-completion.cpp +3 -27
- package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
- package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
- package/ios/include/chat.h +8 -1
- package/ios/include/common/minja/chat-template.hpp +16 -7
- package/ios/include/common/minja/minja.hpp +47 -12
- package/ios/include/common.h +64 -15
- package/ios/include/llama.h +53 -114
- package/ios/include/speculative.h +8 -1
- package/ios/libs/llama.xcframework/Info.plist +18 -18
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/lib/module/NativeRNLlamaCpp.js.map +1 -1
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
- package/package.json +1 -2
- package/src/NativeRNLlamaCpp.ts +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56
|
@@ -46,10 +46,8 @@ struct llama_context {
|
|
|
46
46
|
|
|
47
47
|
llama_memory_t get_memory() const;
|
|
48
48
|
|
|
49
|
-
// return true
|
|
50
|
-
|
|
51
|
-
bool kv_self_update(bool optimize);
|
|
52
|
-
void kv_self_defrag_sched();
|
|
49
|
+
// return true if the memory was updated
|
|
50
|
+
bool memory_update(bool optimize);
|
|
53
51
|
|
|
54
52
|
enum llama_pooling_type pooling_type() const;
|
|
55
53
|
|
|
@@ -111,9 +109,9 @@ struct llama_context {
|
|
|
111
109
|
size_t state_get_data( uint8_t * dst, size_t size);
|
|
112
110
|
size_t state_set_data(const uint8_t * src, size_t size);
|
|
113
111
|
|
|
114
|
-
size_t state_seq_get_size(llama_seq_id seq_id);
|
|
115
|
-
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
|
|
116
|
-
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
|
|
112
|
+
size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags);
|
|
113
|
+
size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags);
|
|
114
|
+
size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags);
|
|
117
115
|
|
|
118
116
|
bool state_load_file(
|
|
119
117
|
const char * filepath,
|
|
@@ -152,6 +150,7 @@ struct llama_context {
|
|
|
152
150
|
|
|
153
151
|
void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
|
|
154
152
|
|
|
153
|
+
// TODO: more flexible combinations of logical/physical batch size and context size
|
|
155
154
|
void opt_epoch(
|
|
156
155
|
ggml_opt_dataset_t dataset,
|
|
157
156
|
ggml_opt_result_t result_train,
|
|
@@ -212,8 +211,8 @@ private:
|
|
|
212
211
|
size_t state_write_data(llama_io_write_i & io);
|
|
213
212
|
size_t state_read_data (llama_io_read_i & io);
|
|
214
213
|
|
|
215
|
-
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
|
|
216
|
-
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);
|
|
214
|
+
size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
|
|
215
|
+
size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
|
|
217
216
|
|
|
218
217
|
//
|
|
219
218
|
// members
|
|
@@ -229,9 +228,6 @@ private:
|
|
|
229
228
|
|
|
230
229
|
std::unique_ptr<llama_memory_i> memory;
|
|
231
230
|
|
|
232
|
-
// TODO: temporary, until the llama_kv_self_defrag() API is removed
|
|
233
|
-
bool memory_force_optimize = false;
|
|
234
|
-
|
|
235
231
|
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
|
236
232
|
size_t logits_size = 0; // capacity (of floats) for logits
|
|
237
233
|
float * logits = nullptr;
|
|
@@ -287,9 +283,8 @@ private:
|
|
|
287
283
|
|
|
288
284
|
bool has_evaluated_once = false;
|
|
289
285
|
|
|
290
|
-
// env:
|
|
291
|
-
|
|
292
|
-
bool supports_set_rows = false;
|
|
286
|
+
// env: LLAMA_GRAPH_REUSE_DISABLE
|
|
287
|
+
bool graph_reuse_disable = false;
|
|
293
288
|
|
|
294
289
|
// perf
|
|
295
290
|
mutable int64_t t_start_us = 0;
|
|
@@ -4,8 +4,8 @@
|
|
|
4
4
|
#include "llama-batch.h"
|
|
5
5
|
#include "llama-cparams.h"
|
|
6
6
|
|
|
7
|
-
#include "llama-kv-cache
|
|
8
|
-
#include "llama-kv-cache-
|
|
7
|
+
#include "llama-kv-cache.h"
|
|
8
|
+
#include "llama-kv-cache-iswa.h"
|
|
9
9
|
#include "llama-memory-hybrid.h"
|
|
10
10
|
#include "llama-memory-recurrent.h"
|
|
11
11
|
|
|
@@ -188,38 +188,23 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
|
|
|
188
188
|
|
|
189
189
|
void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
|
190
190
|
const int64_t n_tokens = ubatch->n_tokens;
|
|
191
|
-
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
|
192
191
|
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
|
|
193
192
|
|
|
194
193
|
if (cparams.embeddings && (
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
194
|
+
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
|
|
195
|
+
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
|
|
196
|
+
cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
|
|
197
|
+
)) {
|
|
198
198
|
GGML_ASSERT(cls);
|
|
199
199
|
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
|
|
200
200
|
|
|
201
201
|
uint32_t * data = (uint32_t *) cls->data;
|
|
202
202
|
memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
|
|
203
203
|
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
|
207
|
-
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
|
208
|
-
|
|
209
|
-
data[seq_idx] = i;
|
|
210
|
-
}
|
|
211
|
-
}
|
|
212
|
-
}
|
|
204
|
+
std::vector<int> target_pos(n_seqs_unq, -1);
|
|
205
|
+
std::vector<int> target_row(n_seqs_unq, -1);
|
|
213
206
|
|
|
214
|
-
|
|
215
|
-
GGML_ASSERT(cls);
|
|
216
|
-
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
|
|
217
|
-
|
|
218
|
-
uint32_t * data = (uint32_t *) cls->data;
|
|
219
|
-
memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
|
|
220
|
-
|
|
221
|
-
std::vector<int> last_pos(n_seqs_unq, -1);
|
|
222
|
-
std::vector<int> last_row(n_seqs_unq, -1);
|
|
207
|
+
bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST;
|
|
223
208
|
|
|
224
209
|
for (int i = 0; i < n_tokens; ++i) {
|
|
225
210
|
const llama_pos pos = ubatch->pos[i];
|
|
@@ -228,16 +213,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
|
|
228
213
|
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
|
229
214
|
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
|
230
215
|
|
|
231
|
-
if (
|
|
232
|
-
|
|
233
|
-
|
|
216
|
+
if (
|
|
217
|
+
(target_pos[seq_idx] == -1) ||
|
|
218
|
+
( last && pos >= target_pos[seq_idx]) ||
|
|
219
|
+
(!last && pos < target_pos[seq_idx])
|
|
220
|
+
) {
|
|
221
|
+
target_pos[seq_idx] = pos;
|
|
222
|
+
target_row[seq_idx] = i;
|
|
234
223
|
}
|
|
235
224
|
}
|
|
236
225
|
}
|
|
237
226
|
|
|
238
227
|
for (int s = 0; s < n_seqs_unq; ++s) {
|
|
239
|
-
if (
|
|
240
|
-
data[s] =
|
|
228
|
+
if (target_row[s] >= 0) {
|
|
229
|
+
data[s] = target_row[s];
|
|
241
230
|
}
|
|
242
231
|
}
|
|
243
232
|
}
|
|
@@ -288,7 +277,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|
|
288
277
|
for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
|
|
289
278
|
const llama_seq_id s0 = ubatch->seq_id[i0][0];
|
|
290
279
|
|
|
291
|
-
// TODO: reimplement this like in
|
|
280
|
+
// TODO: reimplement this like in llama_kv_cache
|
|
292
281
|
if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
|
|
293
282
|
if (hparams.use_alibi) {
|
|
294
283
|
f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
|
|
@@ -305,15 +294,15 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|
|
305
294
|
}
|
|
306
295
|
}
|
|
307
296
|
|
|
308
|
-
void
|
|
297
|
+
void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
|
|
309
298
|
mctx->set_input_k_idxs(self_k_idxs, ubatch);
|
|
310
299
|
mctx->set_input_v_idxs(self_v_idxs, ubatch);
|
|
311
300
|
|
|
312
301
|
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
|
313
302
|
}
|
|
314
303
|
|
|
315
|
-
bool
|
|
316
|
-
const auto * mctx = static_cast<const
|
|
304
|
+
bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
|
|
305
|
+
const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
|
|
317
306
|
|
|
318
307
|
this->mctx = mctx;
|
|
319
308
|
|
|
@@ -325,12 +314,10 @@ bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params)
|
|
|
325
314
|
res &= self_kq_mask->ne[0] == mctx->get_n_kv();
|
|
326
315
|
res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
|
|
327
316
|
|
|
328
|
-
res &= mctx->get_supports_set_rows(); // TODO: tmp
|
|
329
|
-
|
|
330
317
|
return res;
|
|
331
318
|
}
|
|
332
319
|
|
|
333
|
-
void
|
|
320
|
+
void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
|
|
334
321
|
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
|
|
335
322
|
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
|
|
336
323
|
|
|
@@ -342,8 +329,8 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
|
|
|
342
329
|
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
|
343
330
|
}
|
|
344
331
|
|
|
345
|
-
bool
|
|
346
|
-
const auto * mctx = static_cast<const
|
|
332
|
+
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
|
333
|
+
const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
|
|
347
334
|
|
|
348
335
|
this->mctx = mctx;
|
|
349
336
|
|
|
@@ -361,8 +348,6 @@ bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & pa
|
|
|
361
348
|
res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
|
|
362
349
|
res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
|
|
363
350
|
|
|
364
|
-
res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp
|
|
365
|
-
|
|
366
351
|
return res;
|
|
367
352
|
}
|
|
368
353
|
|
|
@@ -751,6 +736,8 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
|
751
736
|
cur = ggml_reglu(ctx0, cur);
|
|
752
737
|
cb(cur, "ffn_reglu", il);
|
|
753
738
|
} break;
|
|
739
|
+
default:
|
|
740
|
+
GGML_ABORT("fatal error");
|
|
754
741
|
}
|
|
755
742
|
|
|
756
743
|
if (gate && type_gate == LLM_FFN_PAR) {
|
|
@@ -760,8 +747,8 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
|
760
747
|
|
|
761
748
|
if (down) {
|
|
762
749
|
cur = build_lora_mm(down, cur);
|
|
763
|
-
if (arch == LLM_ARCH_GLM4) {
|
|
764
|
-
// GLM4
|
|
750
|
+
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
|
|
751
|
+
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
|
|
765
752
|
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
|
766
753
|
}
|
|
767
754
|
}
|
|
@@ -796,13 +783,64 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
796
783
|
bool scale_w,
|
|
797
784
|
float w_scale,
|
|
798
785
|
llama_expert_gating_func_type gating_op,
|
|
799
|
-
int il
|
|
786
|
+
int il,
|
|
787
|
+
ggml_tensor * probs_in) const {
|
|
788
|
+
return build_moe_ffn(
|
|
789
|
+
cur,
|
|
790
|
+
gate_inp, /* gate_inp_b */ nullptr,
|
|
791
|
+
up_exps, /* up_exps_b */ nullptr,
|
|
792
|
+
gate_exps, /* gate_exps_b */ nullptr,
|
|
793
|
+
down_exps, /* down_exps_b */ nullptr,
|
|
794
|
+
exp_probs_b,
|
|
795
|
+
n_expert,
|
|
796
|
+
n_expert_used,
|
|
797
|
+
type_op,
|
|
798
|
+
norm_w,
|
|
799
|
+
scale_w,
|
|
800
|
+
w_scale,
|
|
801
|
+
gating_op,
|
|
802
|
+
il,
|
|
803
|
+
probs_in
|
|
804
|
+
);
|
|
805
|
+
}
|
|
806
|
+
|
|
807
|
+
ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
808
|
+
ggml_tensor * cur,
|
|
809
|
+
ggml_tensor * gate_inp,
|
|
810
|
+
ggml_tensor * gate_inp_b,
|
|
811
|
+
ggml_tensor * up_exps,
|
|
812
|
+
ggml_tensor * up_exps_b,
|
|
813
|
+
ggml_tensor * gate_exps,
|
|
814
|
+
ggml_tensor * gate_exps_b,
|
|
815
|
+
ggml_tensor * down_exps,
|
|
816
|
+
ggml_tensor * down_exps_b,
|
|
817
|
+
ggml_tensor * exp_probs_b,
|
|
818
|
+
int64_t n_expert,
|
|
819
|
+
int64_t n_expert_used,
|
|
820
|
+
llm_ffn_op_type type_op,
|
|
821
|
+
bool norm_w,
|
|
822
|
+
bool scale_w,
|
|
823
|
+
float w_scale,
|
|
824
|
+
llama_expert_gating_func_type gating_op,
|
|
825
|
+
int il,
|
|
826
|
+
ggml_tensor * probs_in) const {
|
|
800
827
|
const int64_t n_embd = cur->ne[0];
|
|
801
828
|
const int64_t n_tokens = cur->ne[1];
|
|
802
829
|
const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
|
|
803
830
|
|
|
804
|
-
ggml_tensor * logits =
|
|
805
|
-
|
|
831
|
+
ggml_tensor * logits = nullptr;
|
|
832
|
+
|
|
833
|
+
if (probs_in == nullptr) {
|
|
834
|
+
logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
|
|
835
|
+
cb(logits, "ffn_moe_logits", il);
|
|
836
|
+
} else {
|
|
837
|
+
logits = probs_in;
|
|
838
|
+
}
|
|
839
|
+
|
|
840
|
+
if (gate_inp_b) {
|
|
841
|
+
logits = ggml_add(ctx0, logits, gate_inp_b);
|
|
842
|
+
cb(logits, "ffn_moe_logits_biased", il);
|
|
843
|
+
}
|
|
806
844
|
|
|
807
845
|
ggml_tensor * probs = nullptr;
|
|
808
846
|
switch (gating_op) {
|
|
@@ -814,6 +852,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
814
852
|
{
|
|
815
853
|
probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
|
|
816
854
|
} break;
|
|
855
|
+
case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
|
|
856
|
+
{
|
|
857
|
+
probs = logits; // [n_expert, n_tokens]
|
|
858
|
+
} break;
|
|
817
859
|
default:
|
|
818
860
|
GGML_ABORT("fatal error");
|
|
819
861
|
}
|
|
@@ -842,6 +884,13 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
842
884
|
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
|
|
843
885
|
cb(weights, "ffn_moe_weights", il);
|
|
844
886
|
|
|
887
|
+
if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
|
|
888
|
+
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
|
|
889
|
+
weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
|
|
890
|
+
weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
|
|
891
|
+
cb(weights, "ffn_moe_weights_softmax", il);
|
|
892
|
+
}
|
|
893
|
+
|
|
845
894
|
if (norm_w) {
|
|
846
895
|
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
|
|
847
896
|
|
|
@@ -870,6 +919,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
870
919
|
ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
|
871
920
|
cb(up, "ffn_moe_up", il);
|
|
872
921
|
|
|
922
|
+
if (up_exps_b) {
|
|
923
|
+
up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
|
|
924
|
+
cb(up, "ffn_moe_up_biased", il);
|
|
925
|
+
}
|
|
926
|
+
|
|
873
927
|
ggml_tensor * experts = nullptr;
|
|
874
928
|
if (gate_exps) {
|
|
875
929
|
cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
|
@@ -878,6 +932,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
878
932
|
cur = up;
|
|
879
933
|
}
|
|
880
934
|
|
|
935
|
+
if (gate_exps_b) {
|
|
936
|
+
cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
|
|
937
|
+
cb(cur, "ffn_moe_gate_biased", il);
|
|
938
|
+
}
|
|
939
|
+
|
|
881
940
|
switch (type_op) {
|
|
882
941
|
case LLM_FFN_SILU:
|
|
883
942
|
if (gate_exps) {
|
|
@@ -895,6 +954,22 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
895
954
|
cur = ggml_gelu(ctx0, cur);
|
|
896
955
|
cb(cur, "ffn_moe_gelu", il);
|
|
897
956
|
} break;
|
|
957
|
+
case LLM_FFN_SWIGLU_OAI_MOE:
|
|
958
|
+
{
|
|
959
|
+
// TODO: move to hparams?
|
|
960
|
+
constexpr float alpha = 1.702f;
|
|
961
|
+
constexpr float limit = 7.0f;
|
|
962
|
+
cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
|
|
963
|
+
cb(cur, "ffn_moe_swiglu_oai", il);
|
|
964
|
+
} break;
|
|
965
|
+
case LLM_FFN_RELU:
|
|
966
|
+
if (gate_exps) {
|
|
967
|
+
cur = ggml_reglu_split(ctx0, cur, up);
|
|
968
|
+
cb(cur, "ffn_moe_reglu", il);
|
|
969
|
+
} else {
|
|
970
|
+
cur = ggml_relu(ctx0, cur);
|
|
971
|
+
cb(cur, "ffn_moe_relu", il);
|
|
972
|
+
} break;
|
|
898
973
|
default:
|
|
899
974
|
GGML_ABORT("fatal error");
|
|
900
975
|
}
|
|
@@ -902,6 +977,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
|
902
977
|
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
|
903
978
|
cb(experts, "ffn_moe_down", il);
|
|
904
979
|
|
|
980
|
+
if (down_exps_b) {
|
|
981
|
+
experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
|
|
982
|
+
cb(experts, "ffn_moe_down_biased", il);
|
|
983
|
+
}
|
|
984
|
+
|
|
905
985
|
if (!weight_before_ffn) {
|
|
906
986
|
experts = ggml_mul(ctx0, experts, weights);
|
|
907
987
|
cb(cur, "ffn_moe_weighted", il);
|
|
@@ -1102,7 +1182,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
|
|
|
1102
1182
|
}
|
|
1103
1183
|
|
|
1104
1184
|
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
|
|
1105
|
-
const auto * mctx_cur = static_cast<const
|
|
1185
|
+
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
|
|
1106
1186
|
|
|
1107
1187
|
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
|
|
1108
1188
|
|
|
@@ -1139,6 +1219,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|
|
1139
1219
|
ggml_tensor * v,
|
|
1140
1220
|
ggml_tensor * kq_b,
|
|
1141
1221
|
ggml_tensor * kq_mask,
|
|
1222
|
+
ggml_tensor * sinks,
|
|
1142
1223
|
ggml_tensor * v_mla,
|
|
1143
1224
|
float kq_scale) const {
|
|
1144
1225
|
const bool v_trans = v->nb[1] > v->nb[2];
|
|
@@ -1176,7 +1257,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|
|
1176
1257
|
cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
|
|
1177
1258
|
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
|
1178
1259
|
|
|
1179
|
-
|
|
1260
|
+
ggml_flash_attn_ext_add_sinks(cur, sinks);
|
|
1261
|
+
ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
|
|
1180
1262
|
|
|
1181
1263
|
if (v_mla) {
|
|
1182
1264
|
#if 0
|
|
@@ -1224,6 +1306,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|
|
1224
1306
|
}
|
|
1225
1307
|
|
|
1226
1308
|
kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
|
|
1309
|
+
ggml_soft_max_add_sinks(kq, sinks);
|
|
1227
1310
|
|
|
1228
1311
|
if (!v_trans) {
|
|
1229
1312
|
// note: avoid this branch
|
|
@@ -1273,6 +1356,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1273
1356
|
ggml_tensor * k_cur,
|
|
1274
1357
|
ggml_tensor * v_cur,
|
|
1275
1358
|
ggml_tensor * kq_b,
|
|
1359
|
+
ggml_tensor * sinks,
|
|
1276
1360
|
ggml_tensor * v_mla,
|
|
1277
1361
|
float kq_scale,
|
|
1278
1362
|
int il) const {
|
|
@@ -1288,13 +1372,13 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1288
1372
|
|
|
1289
1373
|
// [TAG_NO_CACHE_PAD]
|
|
1290
1374
|
// TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
|
|
1291
|
-
assert(!ubatch.equal_seqs());
|
|
1375
|
+
assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));
|
|
1292
1376
|
|
|
1293
1377
|
ggml_tensor * q = q_cur;
|
|
1294
1378
|
ggml_tensor * k = k_cur;
|
|
1295
1379
|
ggml_tensor * v = v_cur;
|
|
1296
1380
|
|
|
1297
|
-
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
|
1381
|
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
|
|
1298
1382
|
cb(cur, "kqv_out", il);
|
|
1299
1383
|
|
|
1300
1384
|
if (wo) {
|
|
@@ -1312,17 +1396,17 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1312
1396
|
return cur;
|
|
1313
1397
|
}
|
|
1314
1398
|
|
|
1315
|
-
static std::unique_ptr<
|
|
1399
|
+
static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
|
|
1316
1400
|
ggml_context * ctx0,
|
|
1317
1401
|
const llama_ubatch & ubatch,
|
|
1318
1402
|
const llama_hparams & hparams,
|
|
1319
1403
|
const llama_cparams & cparams,
|
|
1320
|
-
const
|
|
1404
|
+
const llama_kv_cache_context * mctx_cur) {
|
|
1321
1405
|
|
|
1322
|
-
auto inp = std::make_unique<
|
|
1406
|
+
auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
|
|
1323
1407
|
|
|
1324
1408
|
{
|
|
1325
|
-
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use
|
|
1409
|
+
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
|
|
1326
1410
|
|
|
1327
1411
|
const auto n_kv = mctx_cur->get_n_kv();
|
|
1328
1412
|
const auto n_tokens = ubatch.n_tokens;
|
|
@@ -1340,22 +1424,23 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
|
|
|
1340
1424
|
return inp;
|
|
1341
1425
|
}
|
|
1342
1426
|
|
|
1343
|
-
|
|
1344
|
-
const auto * mctx_cur = static_cast<const
|
|
1427
|
+
llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
|
|
1428
|
+
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
|
|
1345
1429
|
|
|
1346
|
-
auto inp =
|
|
1430
|
+
auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
|
|
1347
1431
|
|
|
1348
|
-
return (
|
|
1432
|
+
return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
|
|
1349
1433
|
}
|
|
1350
1434
|
|
|
1351
1435
|
ggml_tensor * llm_graph_context::build_attn(
|
|
1352
|
-
|
|
1436
|
+
llm_graph_input_attn_kv * inp,
|
|
1353
1437
|
ggml_tensor * wo,
|
|
1354
1438
|
ggml_tensor * wo_b,
|
|
1355
1439
|
ggml_tensor * q_cur,
|
|
1356
1440
|
ggml_tensor * k_cur,
|
|
1357
1441
|
ggml_tensor * v_cur,
|
|
1358
1442
|
ggml_tensor * kq_b,
|
|
1443
|
+
ggml_tensor * sinks,
|
|
1359
1444
|
ggml_tensor * v_mla,
|
|
1360
1445
|
float kq_scale,
|
|
1361
1446
|
int il) const {
|
|
@@ -1382,13 +1467,13 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1382
1467
|
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
|
1383
1468
|
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
|
1384
1469
|
|
|
1385
|
-
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
|
1470
|
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
|
|
1386
1471
|
cb(cur, "kqv_out", il);
|
|
1387
1472
|
|
|
1388
1473
|
if (wo) {
|
|
1389
1474
|
cur = build_lora_mm(wo, cur);
|
|
1390
|
-
if (arch == LLM_ARCH_GLM4) {
|
|
1391
|
-
// GLM4
|
|
1475
|
+
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
|
|
1476
|
+
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
|
|
1392
1477
|
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
|
1393
1478
|
}
|
|
1394
1479
|
}
|
|
@@ -1401,13 +1486,14 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1401
1486
|
}
|
|
1402
1487
|
|
|
1403
1488
|
ggml_tensor * llm_graph_context::build_attn(
|
|
1404
|
-
|
|
1489
|
+
llm_graph_input_attn_kv_iswa * inp,
|
|
1405
1490
|
ggml_tensor * wo,
|
|
1406
1491
|
ggml_tensor * wo_b,
|
|
1407
1492
|
ggml_tensor * q_cur,
|
|
1408
1493
|
ggml_tensor * k_cur,
|
|
1409
1494
|
ggml_tensor * v_cur,
|
|
1410
1495
|
ggml_tensor * kq_b,
|
|
1496
|
+
ggml_tensor * sinks,
|
|
1411
1497
|
ggml_tensor * v_mla,
|
|
1412
1498
|
float kq_scale,
|
|
1413
1499
|
int il) const {
|
|
@@ -1448,7 +1534,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1448
1534
|
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
|
1449
1535
|
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
|
1450
1536
|
|
|
1451
|
-
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
|
1537
|
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
|
|
1452
1538
|
cb(cur, "kqv_out", il);
|
|
1453
1539
|
|
|
1454
1540
|
if (wo) {
|
|
@@ -1487,6 +1573,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1487
1573
|
ggml_tensor * k_cur,
|
|
1488
1574
|
ggml_tensor * v_cur,
|
|
1489
1575
|
ggml_tensor * kq_b,
|
|
1576
|
+
ggml_tensor * sinks,
|
|
1490
1577
|
ggml_tensor * v_mla,
|
|
1491
1578
|
float kq_scale,
|
|
1492
1579
|
int il) const {
|
|
@@ -1502,7 +1589,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1502
1589
|
ggml_tensor * k = k_cur;
|
|
1503
1590
|
ggml_tensor * v = v_cur;
|
|
1504
1591
|
|
|
1505
|
-
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
|
1592
|
+
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale);
|
|
1506
1593
|
cb(cur, "kqv_out", il);
|
|
1507
1594
|
|
|
1508
1595
|
if (wo) {
|
|
@@ -1523,10 +1610,10 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
|
1523
1610
|
// TODO: maybe separate the inner implementation into a separate function
|
|
1524
1611
|
// like with the non-sliding window equivalent
|
|
1525
1612
|
// once sliding-window hybrid caches are a thing.
|
|
1526
|
-
|
|
1527
|
-
const auto * mctx_cur = static_cast<const
|
|
1613
|
+
llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
|
|
1614
|
+
const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
|
|
1528
1615
|
|
|
1529
|
-
auto inp = std::make_unique<
|
|
1616
|
+
auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
|
|
1530
1617
|
|
|
1531
1618
|
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
|
1532
1619
|
|
|
@@ -1543,7 +1630,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|
|
1543
1630
|
}
|
|
1544
1631
|
|
|
1545
1632
|
{
|
|
1546
|
-
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use
|
|
1633
|
+
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
|
|
1547
1634
|
|
|
1548
1635
|
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
|
|
1549
1636
|
|
|
@@ -1556,21 +1643,22 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|
|
1556
1643
|
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
|
1557
1644
|
}
|
|
1558
1645
|
|
|
1559
|
-
return (
|
|
1646
|
+
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
|
|
1560
1647
|
}
|
|
1561
1648
|
|
|
1562
1649
|
ggml_tensor * llm_graph_context::build_rs(
|
|
1563
1650
|
ggml_tensor * s,
|
|
1564
|
-
ggml_tensor *
|
|
1651
|
+
ggml_tensor * state_copy_main,
|
|
1652
|
+
ggml_tensor * state_copy_extra,
|
|
1565
1653
|
int32_t state_size,
|
|
1566
1654
|
int32_t n_seqs,
|
|
1567
|
-
uint32_t
|
|
1568
|
-
uint32_t
|
|
1569
|
-
uint32_t
|
|
1655
|
+
uint32_t n_rs,
|
|
1656
|
+
uint32_t rs_head,
|
|
1657
|
+
uint32_t rs_size,
|
|
1570
1658
|
int32_t rs_zero,
|
|
1571
1659
|
const llm_graph_get_rows_fn & get_state_rows) const {
|
|
1572
1660
|
|
|
1573
|
-
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size,
|
|
1661
|
+
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
|
|
1574
1662
|
|
|
1575
1663
|
// Clear a single state which will then be copied to the other cleared states.
|
|
1576
1664
|
// Note that this is a no-op when the view is zero-sized.
|
|
@@ -1578,39 +1666,44 @@ ggml_tensor * llm_graph_context::build_rs(
|
|
|
1578
1666
|
ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
|
|
1579
1667
|
|
|
1580
1668
|
// copy states
|
|
1581
|
-
// NOTE: assuming the copy destinations are ALL contained between
|
|
1582
|
-
// {state_size,
|
|
1583
|
-
ggml_tensor * output_states = get_state_rows(ctx0, states,
|
|
1669
|
+
// NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
|
|
1670
|
+
// {state_size, rs_size} -> {state_size, n_seqs}
|
|
1671
|
+
ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
|
|
1584
1672
|
ggml_build_forward_expand(gf, output_states);
|
|
1585
1673
|
|
|
1586
|
-
// copy extra states which won't be changed further (between n_seqs and
|
|
1587
|
-
ggml_tensor * states_extra = ggml_get_rows(ctx0, states,
|
|
1674
|
+
// copy extra states which won't be changed further (between n_seqs and n_rs)
|
|
1675
|
+
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
|
|
1588
1676
|
ggml_build_forward_expand(gf,
|
|
1589
1677
|
ggml_cpy(ctx0,
|
|
1590
1678
|
states_extra,
|
|
1591
|
-
ggml_view_1d(ctx0, s, state_size*(
|
|
1679
|
+
ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));
|
|
1592
1680
|
|
|
1593
1681
|
return output_states;
|
|
1594
1682
|
}
|
|
1595
1683
|
|
|
1596
1684
|
static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
|
|
1597
1685
|
ggml_context * ctx0,
|
|
1686
|
+
const llama_ubatch & ubatch,
|
|
1598
1687
|
const llama_memory_recurrent_context * mctx_cur) {
|
|
1599
1688
|
|
|
1600
1689
|
auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
|
|
1601
1690
|
|
|
1602
|
-
const
|
|
1691
|
+
const int64_t n_rs = mctx_cur->get_n_rs();
|
|
1692
|
+
const int64_t n_seqs = ubatch.n_seqs;
|
|
1603
1693
|
|
|
1604
1694
|
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
|
1605
1695
|
ggml_set_input(inp->s_copy);
|
|
1606
1696
|
|
|
1697
|
+
inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
|
|
1698
|
+
inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
|
|
1699
|
+
|
|
1607
1700
|
return inp;
|
|
1608
1701
|
}
|
|
1609
1702
|
|
|
1610
1703
|
llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
|
|
1611
1704
|
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
|
1612
1705
|
|
|
1613
|
-
auto inp = build_rs_inp_impl(ctx0, mctx_cur);
|
|
1706
|
+
auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
|
|
1614
1707
|
|
|
1615
1708
|
return (llm_graph_input_rs *) res->add_input(std::move(inp));
|
|
1616
1709
|
}
|
|
@@ -1623,7 +1716,9 @@ ggml_tensor * llm_graph_context::build_rs(
|
|
|
1623
1716
|
const llm_graph_get_rows_fn & get_state_rows) const {
|
|
1624
1717
|
const auto * kv_state = inp->mctx;
|
|
1625
1718
|
|
|
1626
|
-
return build_rs(s, inp->
|
|
1719
|
+
return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
|
|
1720
|
+
kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
|
|
1721
|
+
get_state_rows);
|
|
1627
1722
|
}
|
|
1628
1723
|
|
|
1629
1724
|
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
|
@@ -1670,8 +1765,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
|
|
1670
1765
|
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
|
1671
1766
|
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
|
|
1672
1767
|
|
|
1673
|
-
auto inp_rs = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
|
|
1674
|
-
auto inp_attn =
|
|
1768
|
+
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
|
|
1769
|
+
auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
|
|
1675
1770
|
|
|
1676
1771
|
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
|
1677
1772
|
|