@novastera-oss/llamarn 0.3.1 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +86 -3
- package/RNLlamaCpp.podspec +1 -1
- package/android/CMakeLists.txt +11 -3
- package/android/generated/jni/react/renderer/components/RNLlamaCppSpec/RNLlamaCppSpecJSI.h +49 -4
- package/android/src/main/cpp/include/llama.h +53 -114
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +2 -10
- package/cpp/PureCppImpl.cpp +71 -4
- package/cpp/SystemUtils.cpp +3 -7
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -1
- package/cpp/llama.cpp/Makefile +6 -1605
- package/cpp/llama.cpp/README.md +5 -1
- package/cpp/llama.cpp/common/arg.cpp +230 -51
- package/cpp/llama.cpp/common/chat-parser.cpp +9 -1
- package/cpp/llama.cpp/common/chat.cpp +539 -8
- package/cpp/llama.cpp/common/chat.h +8 -1
- package/cpp/llama.cpp/common/common.cpp +60 -15
- package/cpp/llama.cpp/common/common.h +64 -15
- package/cpp/llama.cpp/common/speculative.cpp +135 -54
- package/cpp/llama.cpp/common/speculative.h +8 -1
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1216 -109
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +19 -6
- package/cpp/llama.cpp/convert_lora_to_gguf.py +1 -1
- package/cpp/llama.cpp/flake.nix +0 -5
- package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +71 -70
- package/cpp/llama.cpp/ggml/include/ggml-opt.h +25 -6
- package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +90 -3
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +13 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +10 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +113 -17
- package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +701 -585
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +13 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +52 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +274 -91
- package/cpp/llama.cpp/ggml/src/ggml-common.h +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +90 -569
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +55 -341
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +371 -298
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +33 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +428 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +458 -46
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +122 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +9 -11
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cu +58 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/add-id.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +275 -170
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +103 -65
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +171 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +33 -7
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +2 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +3 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh +14 -40
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +83 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +116 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +45 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +56 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +61 -39
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +70 -49
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +70 -21
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +162 -50
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +5 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +208 -97
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +46 -35
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +56 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +95 -51
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +427 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +204 -57
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +252 -168
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
- package/cpp/llama.cpp/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +10 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +192 -19
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cu +67 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/roll.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +1 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cu +34 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softcap.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +16 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +153 -71
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +6 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +21 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +75 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -25
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -1
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +31 -20
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +342 -131
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +464 -134
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1108 -176
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add.cl +107 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +343 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +346 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +41 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +97 -41
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +110 -16
- package/cpp/llama.cpp/ggml/src/ggml-quants.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +22 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +0 -212
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +213 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +117 -238
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quantize.hpp +133 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1666 -633
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +107 -43
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +18 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +16 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +44 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +44 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +26 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -17
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +37 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +109 -55
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +71 -41
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -11
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +55 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +75 -20
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +807 -412
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +72 -22
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +1794 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +846 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +204 -50
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +187 -2
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +11 -2
- package/cpp/llama.cpp/gguf-py/gguf/quants.py +53 -4
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +67 -63
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +7 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +120 -16
- package/cpp/llama.cpp/gguf-py/gguf/utility.py +5 -1
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +284 -1
- package/cpp/llama.cpp/gguf-py/tests/test_quants.py +14 -5
- package/cpp/llama.cpp/include/llama.h +53 -114
- package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +171 -0
- package/cpp/llama.cpp/models/templates/README.md +2 -1
- package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +59 -0
- package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +331 -0
- package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +105 -0
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -1
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +0 -6
- package/cpp/llama.cpp/requirements/requirements-pydantic.txt +1 -1
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-adapter.cpp +68 -4
- package/cpp/llama.cpp/src/llama-adapter.h +3 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +192 -2
- package/cpp/llama.cpp/src/llama-arch.h +18 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +2 -2
- package/cpp/llama.cpp/src/llama-chat.cpp +47 -6
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +61 -252
- package/cpp/llama.cpp/src/llama-context.h +10 -15
- package/cpp/llama.cpp/src/llama-cparams.h +0 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +180 -85
- package/cpp/llama.cpp/src/llama-graph.h +90 -51
- package/cpp/llama.cpp/src/llama-hparams.cpp +34 -3
- package/cpp/llama.cpp/src/llama-hparams.h +21 -6
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.cpp → llama-kv-cache-iswa.cpp} +79 -56
- package/cpp/llama.cpp/src/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +30 -28
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.cpp → llama-kv-cache.cpp} +240 -632
- package/cpp/llama.cpp/src/{llama-kv-cache-unified.h → llama-kv-cache.h} +39 -74
- package/cpp/llama.cpp/src/llama-kv-cells.h +21 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +41 -35
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +26 -29
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +13 -9
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +10 -14
- package/cpp/llama.cpp/src/llama-memory.h +13 -10
- package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
- package/cpp/llama.cpp/src/llama-model-loader.h +3 -2
- package/cpp/llama.cpp/src/llama-model.cpp +1959 -419
- package/cpp/llama.cpp/src/llama-model.h +28 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +40 -4
- package/cpp/llama.cpp/src/llama-vocab.cpp +51 -2
- package/cpp/llama.cpp/src/llama-vocab.h +1 -0
- package/cpp/llama.cpp/vendor/minja/chat-template.hpp +16 -7
- package/cpp/llama.cpp/vendor/minja/minja.hpp +47 -12
- package/cpp/rn-completion.cpp +3 -27
- package/ios/generated/RNLlamaCppSpec/RNLlamaCppSpec.h +30 -0
- package/ios/generated/RNLlamaCppSpecJSI.h +49 -4
- package/ios/include/chat.h +8 -1
- package/ios/include/common/minja/chat-template.hpp +16 -7
- package/ios/include/common/minja/minja.hpp +47 -12
- package/ios/include/common.h +64 -15
- package/ios/include/llama.h +53 -114
- package/ios/include/speculative.h +8 -1
- package/ios/libs/llama.xcframework/Info.plist +18 -18
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5557 -5267
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5520 -5238
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4242 -4016
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5556 -5267
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5519 -5238
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4241 -4014
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5553 -5303
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5515 -5274
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4238 -4044
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +25 -6
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +90 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +53 -114
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/lib/module/NativeRNLlamaCpp.js.map +1 -1
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts +5 -0
- package/lib/typescript/src/NativeRNLlamaCpp.d.ts.map +1 -1
- package/package.json +1 -2
- package/src/NativeRNLlamaCpp.ts +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -56
|
@@ -109,6 +109,9 @@ enum common_chat_format {
|
|
|
109
109
|
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
|
110
110
|
COMMON_CHAT_FORMAT_HERMES_2_PRO,
|
|
111
111
|
COMMON_CHAT_FORMAT_COMMAND_R7B,
|
|
112
|
+
COMMON_CHAT_FORMAT_GRANITE,
|
|
113
|
+
COMMON_CHAT_FORMAT_GPT_OSS,
|
|
114
|
+
COMMON_CHAT_FORMAT_SEED_OSS,
|
|
112
115
|
|
|
113
116
|
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
|
114
117
|
};
|
|
@@ -127,6 +130,8 @@ struct common_chat_templates_inputs {
|
|
|
127
130
|
bool enable_thinking = true;
|
|
128
131
|
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
|
129
132
|
std::map<std::string, std::string> chat_template_kwargs;
|
|
133
|
+
bool add_bos = false;
|
|
134
|
+
bool add_eos = false;
|
|
130
135
|
};
|
|
131
136
|
|
|
132
137
|
struct common_chat_params {
|
|
@@ -183,10 +188,12 @@ std::string common_chat_format_single(
|
|
|
183
188
|
// Returns an example of formatted chat
|
|
184
189
|
std::string common_chat_format_example(
|
|
185
190
|
const struct common_chat_templates * tmpls,
|
|
186
|
-
bool use_jinja
|
|
191
|
+
bool use_jinja,
|
|
192
|
+
const std::map<std::string, std::string> & chat_template_kwargs);
|
|
187
193
|
|
|
188
194
|
const char* common_chat_format_name(common_chat_format format);
|
|
189
195
|
const char* common_reasoning_format_name(common_reasoning_format format);
|
|
196
|
+
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
|
|
190
197
|
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
|
191
198
|
|
|
192
199
|
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
|
|
@@ -41,6 +41,7 @@
|
|
|
41
41
|
#endif
|
|
42
42
|
#include <locale>
|
|
43
43
|
#include <windows.h>
|
|
44
|
+
#include <string.h>
|
|
44
45
|
#include <fcntl.h>
|
|
45
46
|
#include <io.h>
|
|
46
47
|
#else
|
|
@@ -557,13 +558,6 @@ std::string string_from(const struct llama_context * ctx, const std::vector<llam
|
|
|
557
558
|
|
|
558
559
|
auto detokenized = common_token_to_piece(ctx, token);
|
|
559
560
|
|
|
560
|
-
detokenized.erase(
|
|
561
|
-
std::remove_if(
|
|
562
|
-
detokenized.begin(),
|
|
563
|
-
detokenized.end(),
|
|
564
|
-
[](const unsigned char c) { return !std::isprint(c); }),
|
|
565
|
-
detokenized.end());
|
|
566
|
-
|
|
567
561
|
buf << "'" << detokenized << "'"
|
|
568
562
|
<< ":" << std::to_string(token);
|
|
569
563
|
}
|
|
@@ -588,13 +582,6 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
|
|
|
588
582
|
|
|
589
583
|
auto detokenized = common_token_to_piece(ctx, batch.token[i]);
|
|
590
584
|
|
|
591
|
-
detokenized.erase(
|
|
592
|
-
std::remove_if(
|
|
593
|
-
detokenized.begin(),
|
|
594
|
-
detokenized.end(),
|
|
595
|
-
[](const unsigned char c) { return !std::isprint(c); }),
|
|
596
|
-
detokenized.end());
|
|
597
|
-
|
|
598
585
|
buf << "\n" << std::to_string(i)
|
|
599
586
|
<< ", token '" << detokenized << "'"
|
|
600
587
|
<< ", pos " << std::to_string(batch.pos[i])
|
|
@@ -1001,7 +988,12 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|
|
1001
988
|
return iparams;
|
|
1002
989
|
}
|
|
1003
990
|
|
|
991
|
+
char buf[1024];
|
|
1004
992
|
la.ptr = lora.get();
|
|
993
|
+
llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf));
|
|
994
|
+
la.task_name = buf;
|
|
995
|
+
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
|
|
996
|
+
la.prompt_prefix = buf;
|
|
1005
997
|
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
|
|
1006
998
|
}
|
|
1007
999
|
|
|
@@ -1122,6 +1114,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
|
|
|
1122
1114
|
mparams.use_mmap = params.use_mmap;
|
|
1123
1115
|
mparams.use_mlock = params.use_mlock;
|
|
1124
1116
|
mparams.check_tensors = params.check_tensors;
|
|
1117
|
+
mparams.use_extra_bufts = !params.no_extra_bufts;
|
|
1125
1118
|
|
|
1126
1119
|
if (params.kv_overrides.empty()) {
|
|
1127
1120
|
mparams.kv_overrides = NULL;
|
|
@@ -1164,7 +1157,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
|
|
1164
1157
|
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
|
|
1165
1158
|
cparams.pooling_type = params.pooling_type;
|
|
1166
1159
|
cparams.attention_type = params.attention_type;
|
|
1167
|
-
cparams.defrag_thold = params.defrag_thold;
|
|
1168
1160
|
cparams.cb_eval = params.cb_eval;
|
|
1169
1161
|
cparams.cb_eval_user_data = params.cb_eval_user_data;
|
|
1170
1162
|
cparams.offload_kqv = !params.no_kv_offload;
|
|
@@ -1564,3 +1556,56 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std
|
|
|
1564
1556
|
|
|
1565
1557
|
return result;
|
|
1566
1558
|
}
|
|
1559
|
+
|
|
1560
|
+
ggml_opt_optimizer_params common_opt_lr_pars(void * userdata) {
|
|
1561
|
+
ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(nullptr);
|
|
1562
|
+
const lr_opt & d = *(lr_opt *) userdata;
|
|
1563
|
+
result.adamw.alpha = result.sgd.alpha = d.get_lr(d.epoch);
|
|
1564
|
+
result.sgd.wd = result.adamw.wd = d.wd;
|
|
1565
|
+
return result;
|
|
1566
|
+
}
|
|
1567
|
+
|
|
1568
|
+
// TODO make all command line args case-insensitive
|
|
1569
|
+
static inline bool eq_case_insensitive(char const* a, char const* b) {
|
|
1570
|
+
return !
|
|
1571
|
+
#if defined(_MSC_VER)
|
|
1572
|
+
_stricmp
|
|
1573
|
+
#else
|
|
1574
|
+
strcasecmp
|
|
1575
|
+
#endif // defined(_MSC_VER)
|
|
1576
|
+
(a, b);
|
|
1577
|
+
}
|
|
1578
|
+
|
|
1579
|
+
enum ggml_opt_optimizer_type common_opt_get_optimizer(const char * n) {
|
|
1580
|
+
if (eq_case_insensitive("adamw", n)) {
|
|
1581
|
+
return GGML_OPT_OPTIMIZER_TYPE_ADAMW;
|
|
1582
|
+
}
|
|
1583
|
+
if (eq_case_insensitive("sgd", n)) {
|
|
1584
|
+
return GGML_OPT_OPTIMIZER_TYPE_SGD;
|
|
1585
|
+
}
|
|
1586
|
+
return GGML_OPT_OPTIMIZER_TYPE_COUNT;
|
|
1587
|
+
}
|
|
1588
|
+
|
|
1589
|
+
// TODO simplify to use just log and exp
|
|
1590
|
+
static float const k_log_2 = std::log(2.f);
|
|
1591
|
+
|
|
1592
|
+
void lr_opt::init() {
|
|
1593
|
+
if (lr_min > 0 && lr_min < lr0) {
|
|
1594
|
+
float nhalf = std::log(lr0 / lr_min) / k_log_2;
|
|
1595
|
+
float e = epochs;
|
|
1596
|
+
if (decay_epochs > 0 && decay_epochs < e) {
|
|
1597
|
+
e = decay_epochs;
|
|
1598
|
+
} else {
|
|
1599
|
+
decay_epochs = e;
|
|
1600
|
+
}
|
|
1601
|
+
scale_epoch = nhalf / e;
|
|
1602
|
+
}
|
|
1603
|
+
}
|
|
1604
|
+
|
|
1605
|
+
float lr_opt::get_lr(float epoch) const {
|
|
1606
|
+
float r = lr_min <= 0 ? lr0 :
|
|
1607
|
+
epoch >= decay_epochs ? lr_min :
|
|
1608
|
+
lr0 * std::pow(0.5f, epoch * scale_epoch);
|
|
1609
|
+
LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
|
|
1610
|
+
return r;
|
|
1611
|
+
}
|
|
@@ -2,14 +2,17 @@
|
|
|
2
2
|
|
|
3
3
|
#pragma once
|
|
4
4
|
|
|
5
|
-
#include "llama-cpp.h"
|
|
6
|
-
|
|
7
5
|
#include <set>
|
|
6
|
+
#include <sstream>
|
|
8
7
|
#include <string>
|
|
9
8
|
#include <string_view>
|
|
10
9
|
#include <vector>
|
|
11
10
|
#include <map>
|
|
12
11
|
#include <sstream>
|
|
12
|
+
#include <cmath>
|
|
13
|
+
|
|
14
|
+
#include "ggml-opt.h"
|
|
15
|
+
#include "llama-cpp.h"
|
|
13
16
|
|
|
14
17
|
#ifdef _WIN32
|
|
15
18
|
#define DIRECTORY_SEPARATOR '\\'
|
|
@@ -31,6 +34,9 @@ struct common_adapter_lora_info {
|
|
|
31
34
|
std::string path;
|
|
32
35
|
float scale;
|
|
33
36
|
|
|
37
|
+
std::string task_name;
|
|
38
|
+
std::string prompt_prefix;
|
|
39
|
+
|
|
34
40
|
struct llama_adapter_lora * ptr;
|
|
35
41
|
};
|
|
36
42
|
|
|
@@ -82,6 +88,7 @@ enum llama_example {
|
|
|
82
88
|
LLAMA_EXAMPLE_PARALLEL,
|
|
83
89
|
LLAMA_EXAMPLE_TTS,
|
|
84
90
|
LLAMA_EXAMPLE_DIFFUSION,
|
|
91
|
+
LLAMA_EXAMPLE_FINETUNE,
|
|
85
92
|
|
|
86
93
|
LLAMA_EXAMPLE_COUNT,
|
|
87
94
|
};
|
|
@@ -201,6 +208,8 @@ struct common_params_speculative {
|
|
|
201
208
|
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
|
202
209
|
float p_split = 0.1f; // speculative decoding split probability
|
|
203
210
|
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
|
211
|
+
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
|
|
212
|
+
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
|
204
213
|
|
|
205
214
|
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
|
|
206
215
|
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
|
|
@@ -220,19 +229,49 @@ struct common_params_vocoder {
|
|
|
220
229
|
};
|
|
221
230
|
|
|
222
231
|
struct common_params_diffusion {
|
|
223
|
-
int32_t steps
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
float
|
|
227
|
-
|
|
232
|
+
int32_t steps = 128;
|
|
233
|
+
bool visual_mode = false;
|
|
234
|
+
|
|
235
|
+
float eps = 0; // epsilon for timesteps
|
|
236
|
+
int32_t block_length = 0; // block length for generation
|
|
237
|
+
|
|
238
|
+
int32_t algorithm = 4; // default algorithm: low-confidence
|
|
239
|
+
float alg_temp = 0.0f; // algorithm temperature
|
|
240
|
+
|
|
241
|
+
float cfg_scale = 0; // classifier-free guidance scale
|
|
242
|
+
bool add_gumbel_noise = false; // add gumbel noise to the logits if temp > 0.0
|
|
228
243
|
};
|
|
229
244
|
|
|
245
|
+
// reasoning API response format (not to be confused with the chat template's reasoning format)
|
|
230
246
|
enum common_reasoning_format {
    // numeric values are spelled out; they are the same ones the implicit numbering assigns
    COMMON_REASONING_FORMAT_NONE            = 0,
    COMMON_REASONING_FORMAT_AUTO            = 1, // Same as deepseek, using `message.reasoning_content`
    COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY = 2, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
    COMMON_REASONING_FORMAT_DEEPSEEK        = 3, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
    // do not extend this enum unless you absolutely have to
    // in most cases, use COMMON_REASONING_FORMAT_AUTO
    // see: https://github.com/ggml-org/llama.cpp/pull/15408
};
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
// Learning-rate schedule settings.
// Call init() once after the fields have been filled in (e.g. by argument
// parsing) and before the first get_lr(); init() derives scale_epoch and may
// adjust decay_epochs.
struct lr_opt {
    float    lr0          = 1e-5f; // learning rate at first epoch
    float    lr_min       = -1;    // floor reached after decay; <= 0 disables decay (get_lr always returns lr0)
    float    decay_epochs = -1;    // if >0, the learning rate starts at lr0 and decays to lr_min after this many epochs
    float    scale_epoch  = 0;     // halvings of lr0 per epoch; computed by init()
    float    wd           = 0;     // presumably weight decay; forwarded to the optimizer's adamw/sgd `wd` parameters
    unsigned epochs       = 2;     // number of epochs; init() uses it as the decay length when decay_epochs is not usable

    // set by optimizer outer (epochs) loop; starts at 0 so that get_lr() never
    // reads an indeterminate value on a default-initialized object
    unsigned epoch = 0;

    // learning rate decay - constant LR per epoch only for now
    float get_lr(float e) const;
    float get_lr() const { return get_lr(epoch); }

    // must call after arg parse, before get_lr
    void init();
};
|
|
235
272
|
|
|
273
|
+
struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);
|
|
274
|
+
|
|
236
275
|
struct common_params {
|
|
237
276
|
int32_t n_predict = -1; // new tokens to predict
|
|
238
277
|
int32_t n_ctx = 4096; // context size
|
|
@@ -252,7 +291,6 @@ struct common_params {
|
|
|
252
291
|
float yarn_beta_fast = 32.0f; // YaRN low correction dim
|
|
253
292
|
float yarn_beta_slow = 1.0f; // YaRN high correction dim
|
|
254
293
|
int32_t yarn_orig_ctx = 0; // YaRN original context length
|
|
255
|
-
float defrag_thold = 0.1f; // KV cache defragmentation threshold
|
|
256
294
|
|
|
257
295
|
// offload params
|
|
258
296
|
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
|
@@ -339,7 +377,7 @@ struct common_params {
|
|
|
339
377
|
bool cont_batching = true; // insert new sequences for decoding on-the-fly
|
|
340
378
|
bool flash_attn = false; // flash attention
|
|
341
379
|
bool no_perf = false; // disable performance metrics
|
|
342
|
-
bool ctx_shift =
|
|
380
|
+
bool ctx_shift = false; // context shift on infinite text generation
|
|
343
381
|
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
|
344
382
|
bool kv_unified = false; // enable unified KV cache
|
|
345
383
|
|
|
@@ -352,6 +390,7 @@ struct common_params {
|
|
|
352
390
|
bool warmup = true; // warmup run
|
|
353
391
|
bool check_tensors = false; // validate tensor data
|
|
354
392
|
bool no_op_offload = false; // globally disable offload host tensor operations to device
|
|
393
|
+
bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking)
|
|
355
394
|
|
|
356
395
|
bool single_turn = false; // single turn chat conversation
|
|
357
396
|
|
|
@@ -366,6 +405,11 @@ struct common_params {
|
|
|
366
405
|
bool no_mmproj = false; // explicitly disable multimodal model
|
|
367
406
|
std::vector<std::string> image; // path to image file(s)
|
|
368
407
|
|
|
408
|
+
// finetune
|
|
409
|
+
struct lr_opt lr;
|
|
410
|
+
enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
|
|
411
|
+
float val_split = 0.05f; // fraction of the data used for the validation set
|
|
412
|
+
|
|
369
413
|
// embedding
|
|
370
414
|
bool embedding = false; // get only sentence embedding
|
|
371
415
|
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
|
|
@@ -374,11 +418,12 @@ struct common_params {
|
|
|
374
418
|
std::string cls_sep = "\t"; // separator of classification sequences
|
|
375
419
|
|
|
376
420
|
// server params
|
|
377
|
-
int32_t port
|
|
378
|
-
int32_t timeout_read
|
|
379
|
-
int32_t timeout_write
|
|
380
|
-
int32_t n_threads_http
|
|
381
|
-
int32_t n_cache_reuse
|
|
421
|
+
int32_t port = 8080; // server listens on this network port
|
|
422
|
+
int32_t timeout_read = 600; // http read timeout in seconds
|
|
423
|
+
int32_t timeout_write = timeout_read; // http write timeout in seconds
|
|
424
|
+
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
|
|
425
|
+
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
|
|
426
|
+
int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot
|
|
382
427
|
|
|
383
428
|
std::string hostname = "127.0.0.1";
|
|
384
429
|
std::string public_path = ""; // NOLINT
|
|
@@ -386,7 +431,7 @@ struct common_params {
|
|
|
386
431
|
std::string chat_template = ""; // NOLINT
|
|
387
432
|
bool use_jinja = false; // NOLINT
|
|
388
433
|
bool enable_chat_template = true;
|
|
389
|
-
common_reasoning_format reasoning_format =
|
|
434
|
+
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
|
|
390
435
|
int reasoning_budget = -1;
|
|
391
436
|
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
|
|
392
437
|
|
|
@@ -431,6 +476,7 @@ struct common_params {
|
|
|
431
476
|
int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations
|
|
432
477
|
int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations
|
|
433
478
|
int32_t i_chunk = 0; // start processing from this chunk
|
|
479
|
+
int8_t imat_dat = 0; // whether the legacy imatrix.dat format should be output (gguf <= 0 < dat)
|
|
434
480
|
|
|
435
481
|
bool process_output = false; // collect data for the output tensor
|
|
436
482
|
bool compute_ppl = true; // whether to compute perplexity
|
|
@@ -692,3 +738,6 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
|
|
|
692
738
|
//
|
|
693
739
|
|
|
694
740
|
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
|
|
741
|
+
|
|
742
|
+
// "adamw" or "sgd" (case insensitive)
|
|
743
|
+
enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
|
|
@@ -1,30 +1,39 @@
|
|
|
1
1
|
#include "speculative.h"
|
|
2
2
|
|
|
3
|
+
#include "ggml.h"
|
|
4
|
+
#include "llama.h"
|
|
3
5
|
#include "log.h"
|
|
4
6
|
#include "common.h"
|
|
5
7
|
#include "sampling.h"
|
|
6
8
|
|
|
7
9
|
#include <cstring>
|
|
8
10
|
#include <algorithm>
|
|
11
|
+
#include <map>
|
|
9
12
|
|
|
10
13
|
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
|
11
14
|
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
|
12
15
|
|
|
13
16
|
struct common_speculative {
|
|
14
|
-
struct llama_context *
|
|
17
|
+
struct llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
|
|
18
|
+
struct llama_context * ctx_dft;
|
|
15
19
|
struct common_sampler * smpl;
|
|
16
20
|
|
|
17
21
|
llama_batch batch;
|
|
18
|
-
llama_tokens
|
|
22
|
+
llama_tokens prompt_dft;
|
|
23
|
+
bool vocab_dft_compatible = true; // whether retokenization is needed
|
|
24
|
+
std::map<std::string, std::string> tgt_dft_replacements = {};
|
|
19
25
|
};
|
|
20
26
|
|
|
21
27
|
struct common_speculative * common_speculative_init(
|
|
28
|
+
struct llama_context * ctx_tgt,
|
|
22
29
|
struct llama_context * ctx_dft) {
|
|
23
30
|
auto * result = new common_speculative {
|
|
24
|
-
/* .
|
|
25
|
-
/* .
|
|
26
|
-
/* .
|
|
27
|
-
/* .
|
|
31
|
+
/* .ctx_tgt = */ ctx_tgt,
|
|
32
|
+
/* .ctx_dft = */ ctx_dft,
|
|
33
|
+
/* .smpl = */ nullptr,
|
|
34
|
+
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
|
|
35
|
+
/* .prompt_dft = */ {},
|
|
36
|
+
/* .vocab_dft_compatible = */ false,
|
|
28
37
|
};
|
|
29
38
|
|
|
30
39
|
// TODO: optimize or pass from outside?
|
|
@@ -59,6 +68,9 @@ struct common_speculative * common_speculative_init(
|
|
|
59
68
|
}
|
|
60
69
|
#endif
|
|
61
70
|
|
|
71
|
+
result->vocab_dft_compatible = common_speculative_are_compatible(ctx_tgt, ctx_dft);
|
|
72
|
+
LOG_DBG("vocab_dft_compatible = %d\n", result->vocab_dft_compatible);
|
|
73
|
+
|
|
62
74
|
return result;
|
|
63
75
|
}
|
|
64
76
|
|
|
@@ -75,8 +87,8 @@ void common_speculative_free(struct common_speculative * spec) {
|
|
|
75
87
|
}
|
|
76
88
|
|
|
77
89
|
bool common_speculative_are_compatible(
|
|
78
|
-
|
|
79
|
-
|
|
90
|
+
const struct llama_context * ctx_tgt,
|
|
91
|
+
const struct llama_context * ctx_dft) {
|
|
80
92
|
const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
|
|
81
93
|
const struct llama_model * model_dft = llama_get_model(ctx_dft);
|
|
82
94
|
|
|
@@ -90,31 +102,32 @@ bool common_speculative_are_compatible(
|
|
|
90
102
|
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
|
|
91
103
|
|
|
92
104
|
if (vocab_type_tgt != vocab_type_dft) {
|
|
93
|
-
|
|
94
|
-
|
|
105
|
+
LOG_DBG("%s: draft model vocab type must match target model to use speculation but ", __func__);
|
|
106
|
+
LOG_DBG("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
|
|
95
107
|
return false;
|
|
96
108
|
}
|
|
97
109
|
|
|
98
|
-
if (
|
|
110
|
+
if (
|
|
111
|
+
llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
|
|
99
112
|
llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
|
|
100
113
|
llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
|
|
101
|
-
llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft));
|
|
114
|
+
llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
|
|
115
|
+
) {
|
|
116
|
+
LOG_DBG("%s: draft model special tokens must match target model to use speculation\n", __func__);
|
|
105
117
|
return false;
|
|
106
118
|
}
|
|
107
119
|
|
|
108
120
|
{
|
|
109
121
|
const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
|
|
110
122
|
const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
|
|
111
|
-
|
|
112
|
-
|
|
123
|
+
const int vocab_diff = n_vocab_tgt > n_vocab_dft
|
|
124
|
+
? n_vocab_tgt - n_vocab_dft
|
|
125
|
+
: n_vocab_dft - n_vocab_tgt;
|
|
113
126
|
|
|
114
127
|
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
128
|
+
LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
|
|
129
|
+
LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
|
|
130
|
+
n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
|
|
118
131
|
return false;
|
|
119
132
|
}
|
|
120
133
|
|
|
@@ -122,8 +135,8 @@ bool common_speculative_are_compatible(
|
|
|
122
135
|
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
|
|
123
136
|
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
|
|
124
137
|
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
|
125
|
-
|
|
126
|
-
|
|
138
|
+
LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
|
|
139
|
+
LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
|
|
127
140
|
common_token_to_piece(ctx_tgt, i).c_str(),
|
|
128
141
|
common_token_to_piece(ctx_dft, i).c_str());
|
|
129
142
|
return false;
|
|
@@ -134,32 +147,93 @@ bool common_speculative_are_compatible(
|
|
|
134
147
|
return true;
|
|
135
148
|
}
|
|
136
149
|
|
|
150
|
+
void common_speculative_add_replacement_tgt_dft(
|
|
151
|
+
struct common_speculative * spec,
|
|
152
|
+
const char *source, const char *dest) {
|
|
153
|
+
spec->tgt_dft_replacements[source] = dest;
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
static std::string replace_to_dft(
|
|
157
|
+
struct common_speculative * spec,
|
|
158
|
+
const std::string& input) {
|
|
159
|
+
std::string result = input;
|
|
160
|
+
for (const auto & pair : spec->tgt_dft_replacements) {
|
|
161
|
+
size_t pos = result.find(pair.first);
|
|
162
|
+
while (pos != std::string::npos) {
|
|
163
|
+
result.replace(pos, pair.first.length(), pair.second);
|
|
164
|
+
pos = result.find(pair.first, pos + pair.second.length());
|
|
165
|
+
}
|
|
166
|
+
}
|
|
167
|
+
return result;
|
|
168
|
+
}
|
|
169
|
+
|
|
170
|
+
static std::string replace_to_tgt(
|
|
171
|
+
struct common_speculative * spec,
|
|
172
|
+
const std::string& input) {
|
|
173
|
+
std::string result = input;
|
|
174
|
+
for (const auto& pair : spec->tgt_dft_replacements) {
|
|
175
|
+
size_t pos = result.find(pair.second);
|
|
176
|
+
while (pos != std::string::npos) {
|
|
177
|
+
result.replace(pos, pair.second.length(), pair.first);
|
|
178
|
+
pos = result.find(pair.second, pos + pair.first.length());
|
|
179
|
+
}
|
|
180
|
+
}
|
|
181
|
+
return result;
|
|
182
|
+
}
|
|
183
|
+
|
|
184
|
+
|
|
137
185
|
llama_tokens common_speculative_gen_draft(
|
|
138
186
|
struct common_speculative * spec,
|
|
139
187
|
struct common_speculative_params params,
|
|
140
|
-
const llama_tokens &
|
|
188
|
+
const llama_tokens & prompt_tgt_main_model, // specified in target model vocab
|
|
141
189
|
llama_token id_last) {
|
|
142
190
|
auto & batch = spec->batch;
|
|
143
|
-
auto &
|
|
191
|
+
auto & ctx_tgt = spec->ctx_tgt;
|
|
192
|
+
auto & ctx_dft = spec->ctx_dft;
|
|
144
193
|
auto & smpl = spec->smpl;
|
|
145
|
-
auto &
|
|
194
|
+
auto & prompt_dft = spec->prompt_dft;
|
|
146
195
|
|
|
147
|
-
auto *
|
|
196
|
+
auto * mem_dft = llama_get_memory(ctx_dft);
|
|
148
197
|
|
|
149
198
|
int reuse_i = 0;
|
|
150
199
|
int reuse_n = 0;
|
|
151
200
|
|
|
152
|
-
const int n_ctx = llama_n_ctx(
|
|
201
|
+
const int n_ctx = llama_n_ctx(ctx_dft) - params.n_draft;
|
|
202
|
+
|
|
203
|
+
llama_tokens prompt_tgt_draft_model;
|
|
204
|
+
if (!spec->vocab_dft_compatible) {
|
|
205
|
+
std::string text;
|
|
206
|
+
text = common_detokenize(ctx_tgt, prompt_tgt_main_model, true);
|
|
207
|
+
text = replace_to_dft(spec, text);
|
|
208
|
+
LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str());
|
|
209
|
+
prompt_tgt_draft_model = common_tokenize(ctx_dft, text, false, true);
|
|
210
|
+
|
|
211
|
+
// convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation
|
|
212
|
+
const auto * model_tgt = llama_get_model(ctx_tgt);
|
|
213
|
+
const auto * vocab_tgt = llama_model_get_vocab(model_tgt);
|
|
214
|
+
|
|
215
|
+
int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false);
|
|
216
|
+
GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last");
|
|
217
|
+
text.resize(-n_chars);
|
|
218
|
+
llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false);
|
|
219
|
+
text = replace_to_dft(spec, text);
|
|
220
|
+
|
|
221
|
+
LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
|
|
222
|
+
id_last = common_tokenize(ctx_dft, text, false, true)[0];
|
|
223
|
+
}
|
|
224
|
+
// prompt_tgt's tokens will always be compatible with ctx_dft
|
|
225
|
+
const llama_tokens &prompt_tgt =
|
|
226
|
+
spec->vocab_dft_compatible ? prompt_tgt_main_model : prompt_tgt_draft_model;
|
|
153
227
|
|
|
154
228
|
const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
|
|
155
229
|
|
|
156
230
|
// reuse as much as possible from the old draft context
|
|
157
231
|
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
|
|
158
|
-
for (int i = 0; i < (int)
|
|
232
|
+
for (int i = 0; i < (int) prompt_dft.size(); ++i) {
|
|
159
233
|
int cur = 0;
|
|
160
234
|
while (i_start + cur < (int) prompt_tgt.size() &&
|
|
161
|
-
i + cur < (int)
|
|
162
|
-
prompt_tgt[i_start + cur] ==
|
|
235
|
+
i + cur < (int) prompt_dft.size() &&
|
|
236
|
+
prompt_tgt[i_start + cur] == prompt_dft[i + cur]) {
|
|
163
237
|
cur++;
|
|
164
238
|
}
|
|
165
239
|
|
|
@@ -169,21 +243,20 @@ llama_tokens common_speculative_gen_draft(
|
|
|
169
243
|
}
|
|
170
244
|
}
|
|
171
245
|
|
|
172
|
-
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int)
|
|
246
|
+
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());
|
|
173
247
|
|
|
174
248
|
llama_tokens result;
|
|
175
249
|
result.reserve(params.n_draft);
|
|
176
250
|
|
|
177
251
|
if (reuse_n == 0) {
|
|
178
|
-
llama_memory_clear(
|
|
179
|
-
|
|
180
|
-
prompt.clear();
|
|
252
|
+
llama_memory_clear(mem_dft, false);
|
|
253
|
+
prompt_dft.clear();
|
|
181
254
|
} else {
|
|
182
255
|
// this happens when a previous draft has been discarded (for example, due to being too small), but the
|
|
183
256
|
// target model agreed with it. in this case, we simply pass back the previous results to save compute
|
|
184
|
-
if (reuse_i + reuse_n < (int)
|
|
185
|
-
for (int i = reuse_i + reuse_n + 1; i < (int)
|
|
186
|
-
result.push_back(
|
|
257
|
+
if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
|
|
258
|
+
for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
|
|
259
|
+
result.push_back(prompt_dft[i]);
|
|
187
260
|
|
|
188
261
|
if (params.n_draft <= (int) result.size()) {
|
|
189
262
|
break;
|
|
@@ -194,16 +267,15 @@ llama_tokens common_speculative_gen_draft(
|
|
|
194
267
|
}
|
|
195
268
|
|
|
196
269
|
if (reuse_i > 0) {
|
|
197
|
-
llama_memory_seq_rm (
|
|
198
|
-
llama_memory_seq_add(
|
|
270
|
+
llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
|
|
271
|
+
llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
|
|
199
272
|
|
|
200
|
-
|
|
273
|
+
prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
|
|
201
274
|
}
|
|
202
275
|
|
|
203
|
-
if (reuse_n < (int)
|
|
204
|
-
llama_memory_seq_rm (
|
|
205
|
-
|
|
206
|
-
prompt.erase(prompt.begin() + reuse_n, prompt.end());
|
|
276
|
+
if (reuse_n < (int) prompt_dft.size()) {
|
|
277
|
+
llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
|
|
278
|
+
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
|
|
207
279
|
}
|
|
208
280
|
}
|
|
209
281
|
|
|
@@ -214,28 +286,28 @@ llama_tokens common_speculative_gen_draft(
|
|
|
214
286
|
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
|
|
215
287
|
common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
|
|
216
288
|
|
|
217
|
-
|
|
289
|
+
prompt_dft.push_back(prompt_tgt[i]);
|
|
218
290
|
}
|
|
219
291
|
|
|
220
292
|
// we should rarely end-up here during normal decoding
|
|
221
293
|
if (batch.n_tokens > 0) {
|
|
222
294
|
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
|
|
223
295
|
|
|
224
|
-
llama_decode(
|
|
296
|
+
llama_decode(ctx_dft, batch);
|
|
225
297
|
}
|
|
226
298
|
|
|
227
|
-
const llama_pos n_past =
|
|
299
|
+
const llama_pos n_past = prompt_dft.size();
|
|
228
300
|
|
|
229
301
|
LOG_DBG("%s: n_past = %d\n", __func__, n_past);
|
|
230
302
|
|
|
231
303
|
common_batch_clear(batch);
|
|
232
304
|
common_batch_add (batch, id_last, n_past, { 0 }, true);
|
|
233
305
|
|
|
234
|
-
|
|
306
|
+
prompt_dft.push_back(id_last);
|
|
235
307
|
|
|
236
|
-
|
|
308
|
+
LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
|
|
237
309
|
|
|
238
|
-
llama_decode(
|
|
310
|
+
llama_decode(ctx_dft, batch);
|
|
239
311
|
|
|
240
312
|
common_sampler_reset(smpl);
|
|
241
313
|
|
|
@@ -243,13 +315,13 @@ llama_tokens common_speculative_gen_draft(
|
|
|
243
315
|
for (int i = 0; i < params.n_draft; ++i) {
|
|
244
316
|
common_batch_clear(batch);
|
|
245
317
|
|
|
246
|
-
common_sampler_sample(smpl,
|
|
318
|
+
common_sampler_sample(smpl, ctx_dft, 0, true);
|
|
247
319
|
|
|
248
320
|
const auto * cur_p = common_sampler_get_candidates(smpl);
|
|
249
321
|
|
|
250
322
|
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
|
|
251
323
|
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
|
252
|
-
k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(
|
|
324
|
+
k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
|
|
253
325
|
}
|
|
254
326
|
|
|
255
327
|
// add drafted token for each sequence
|
|
@@ -271,10 +343,19 @@ llama_tokens common_speculative_gen_draft(
|
|
|
271
343
|
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
|
|
272
344
|
|
|
273
345
|
// evaluate the drafted tokens on the draft model
|
|
274
|
-
llama_decode(
|
|
346
|
+
llama_decode(ctx_dft, batch);
|
|
275
347
|
|
|
276
|
-
|
|
348
|
+
prompt_dft.push_back(id);
|
|
277
349
|
}
|
|
278
350
|
|
|
351
|
+
if (!spec->vocab_dft_compatible) {
|
|
352
|
+
std::string detokenized = common_detokenize(ctx_dft, result, true);
|
|
353
|
+
detokenized = replace_to_tgt(spec, detokenized);
|
|
354
|
+
LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
|
|
355
|
+
result = common_tokenize(ctx_tgt, detokenized, false, true);
|
|
356
|
+
if (result.size() > (size_t)params.n_draft) {
|
|
357
|
+
result.resize(params.n_draft);
|
|
358
|
+
}
|
|
359
|
+
}
|
|
279
360
|
return result;
|
|
280
361
|
}
|