@novastera-oss/llamarn 0.2.7 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +8 -3
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +56 -22
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +1 -2
- package/cpp/llama.cpp/README.md +4 -5
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +24 -0
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +3 -0
- package/cpp/llama.cpp/common/common.h +5 -0
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
- package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
- package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
- package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +8 -43
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
- package/cpp/llama.cpp/src/llama-arch.h +36 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
- package/cpp/llama.cpp/src/llama-batch.h +105 -70
- package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +101 -107
- package/cpp/llama.cpp/src/llama-context.h +13 -13
- package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
- package/cpp/llama.cpp/src/llama-graph.h +78 -79
- package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
- package/cpp/llama.cpp/src/llama-hparams.h +11 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
- package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +21 -22
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
- package/cpp/llama.cpp/src/llama-model.h +40 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
- package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
- package/cpp/llama.cpp/src/llama-vocab.h +42 -0
- package/cpp/rn-utils.h +3 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +5 -0
- package/ios/include/llama.h +8 -43
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
#include "llama-batch.h"
|
|
2
2
|
|
|
3
3
|
#include "llama-impl.h"
|
|
4
|
-
#include "llama-cparams.h"
|
|
5
4
|
#include "llama-vocab.h"
|
|
6
5
|
#include "llama-memory.h"
|
|
7
6
|
|
|
@@ -10,282 +9,7 @@
|
|
|
10
9
|
#include <algorithm>
|
|
11
10
|
#include <sstream>
|
|
12
11
|
|
|
13
|
-
|
|
14
|
-
// clear empty sequences
|
|
15
|
-
// the previous ubatch is assumed to be gone,
|
|
16
|
-
// so nothing should refer to values in these sequences anymore.
|
|
17
|
-
for (size_t i = seq.size(); i-- > 0;) {
|
|
18
|
-
if (seq[i].length == 0) {
|
|
19
|
-
seq.pop_back();
|
|
20
|
-
} else {
|
|
21
|
-
break;
|
|
22
|
-
}
|
|
23
|
-
}
|
|
24
|
-
|
|
25
|
-
udatas.push_back({});
|
|
26
|
-
|
|
27
|
-
auto & udata = udatas.back();
|
|
28
|
-
|
|
29
|
-
udata.token.resize(!has_embd ? n_ubatch : 0);
|
|
30
|
-
udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
|
|
31
|
-
udata.pos.resize(n_ubatch);
|
|
32
|
-
udata.n_seq_id.resize(n_ubatch);
|
|
33
|
-
udata.seq_id.resize(n_ubatch);
|
|
34
|
-
udata.output.resize(n_ubatch);
|
|
35
|
-
|
|
36
|
-
llama_ubatch ubatch = {
|
|
37
|
-
/*equal_seqs =*/ true,
|
|
38
|
-
/*n_tokens =*/ 0,
|
|
39
|
-
/*n_seq_tokens =*/ 0,
|
|
40
|
-
/*n_seqs =*/ 0,
|
|
41
|
-
/*token =*/ !has_embd ? udata.token.data() : nullptr,
|
|
42
|
-
/*embd =*/ has_embd ? udata.embd.data() : nullptr,
|
|
43
|
-
/*pos =*/ udata.pos.data(),
|
|
44
|
-
/*n_seq_id =*/ udata.n_seq_id.data(),
|
|
45
|
-
/*seq_id =*/ udata.seq_id.data(),
|
|
46
|
-
/*output =*/ udata.output.data(),
|
|
47
|
-
};
|
|
48
|
-
|
|
49
|
-
return ubatch;
|
|
50
|
-
}
|
|
51
|
-
|
|
52
|
-
void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
|
|
53
|
-
GGML_ASSERT(batch != nullptr);
|
|
54
|
-
GGML_ASSERT(length <= seq.length);
|
|
55
|
-
// Can only add sequences of equal lengths to a batch,
|
|
56
|
-
// otherwise it isn't clear to which sequence a token belongs
|
|
57
|
-
GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
|
|
58
|
-
GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
|
|
59
|
-
// NOTE: loops are separated for cache-friendliness
|
|
60
|
-
if (batch->token) {
|
|
61
|
-
if (ubatch.equal_seqs) {
|
|
62
|
-
for (size_t i = 0; i < length; ++i) {
|
|
63
|
-
ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
|
|
64
|
-
}
|
|
65
|
-
} else {
|
|
66
|
-
// simple split
|
|
67
|
-
ubatch.token = batch->token + seq.offset;
|
|
68
|
-
}
|
|
69
|
-
} else {
|
|
70
|
-
ubatch.token = nullptr;
|
|
71
|
-
}
|
|
72
|
-
if (batch->embd) {
|
|
73
|
-
if (ubatch.equal_seqs) {
|
|
74
|
-
for (size_t i = 0; i < length; ++i) {
|
|
75
|
-
memcpy(
|
|
76
|
-
ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
|
|
77
|
-
batch->embd + (n_embd * ids[seq.offset + i]),
|
|
78
|
-
n_embd * sizeof(float)
|
|
79
|
-
);
|
|
80
|
-
}
|
|
81
|
-
} else {
|
|
82
|
-
// simple split
|
|
83
|
-
ubatch.embd = batch->embd + (n_embd * seq.offset);
|
|
84
|
-
}
|
|
85
|
-
} else {
|
|
86
|
-
ubatch.embd = nullptr;
|
|
87
|
-
}
|
|
88
|
-
if (ubatch.equal_seqs) {
|
|
89
|
-
for (size_t i = 0; i < length; ++i) {
|
|
90
|
-
ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
|
|
91
|
-
}
|
|
92
|
-
} else {
|
|
93
|
-
// simple split
|
|
94
|
-
ubatch.pos = batch->pos + seq.offset;
|
|
95
|
-
}
|
|
96
|
-
if (ubatch.equal_seqs) {
|
|
97
|
-
ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
|
|
98
|
-
if (seq.seq_id) {
|
|
99
|
-
ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
|
|
100
|
-
}
|
|
101
|
-
} else {
|
|
102
|
-
// simple split
|
|
103
|
-
if (batch->n_seq_id) {
|
|
104
|
-
ubatch.n_seq_id = batch->n_seq_id + seq.offset;
|
|
105
|
-
} else {
|
|
106
|
-
for (size_t i = 0; i < length; ++i) {
|
|
107
|
-
ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
|
|
108
|
-
}
|
|
109
|
-
}
|
|
110
|
-
if (batch->seq_id) {
|
|
111
|
-
ubatch.seq_id = batch->seq_id + seq.offset;
|
|
112
|
-
}
|
|
113
|
-
}
|
|
114
|
-
if (batch->logits) {
|
|
115
|
-
if (ubatch.equal_seqs) {
|
|
116
|
-
for (size_t i = 0; i < length; ++i) {
|
|
117
|
-
size_t id = ids[seq.offset + i];
|
|
118
|
-
int8_t is_output = batch->logits[id];
|
|
119
|
-
ubatch.output[ubatch.n_tokens + i] = is_output;
|
|
120
|
-
if (is_output) { out_ids.push_back(id); }
|
|
121
|
-
}
|
|
122
|
-
} else {
|
|
123
|
-
// simple split
|
|
124
|
-
ubatch.output = batch->logits + seq.offset;
|
|
125
|
-
for (size_t i = 0; i < length; ++i) {
|
|
126
|
-
if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
|
|
127
|
-
}
|
|
128
|
-
}
|
|
129
|
-
} else {
|
|
130
|
-
// only get last output
|
|
131
|
-
for (size_t i = 0; i < length; ++i) {
|
|
132
|
-
size_t id = ids[seq.offset + i];
|
|
133
|
-
int8_t is_last = id == ids.size() - 1;
|
|
134
|
-
ubatch.output[ubatch.n_tokens + i] = is_last;
|
|
135
|
-
if (is_last) { out_ids.push_back(id); }
|
|
136
|
-
}
|
|
137
|
-
}
|
|
138
|
-
if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
|
|
139
|
-
ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
|
|
140
|
-
}
|
|
141
|
-
ubatch.n_tokens += length;
|
|
142
|
-
ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
|
|
143
|
-
seq.offset += length;
|
|
144
|
-
seq.length -= length;
|
|
145
|
-
n_tokens -= length;
|
|
146
|
-
GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
|
|
147
|
-
}
|
|
148
|
-
|
|
149
|
-
llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
|
|
150
|
-
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
|
|
151
|
-
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
|
|
152
|
-
ubatch.equal_seqs = false;
|
|
153
|
-
if (!seq.empty()) {
|
|
154
|
-
llama_sbatch_seq & s = seq[0];
|
|
155
|
-
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
|
|
156
|
-
GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
|
|
157
|
-
add_seq_to_ubatch(ubatch, s, length);
|
|
158
|
-
}
|
|
159
|
-
return ubatch;
|
|
160
|
-
}
|
|
161
|
-
|
|
162
|
-
llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
|
|
163
|
-
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
|
|
164
|
-
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
|
|
165
|
-
if (!seq.empty()) {
|
|
166
|
-
size_t length = 0;
|
|
167
|
-
size_t n_tokens_in_ubatch = 0;
|
|
168
|
-
GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
|
|
169
|
-
// smallest first, because it's easier to split this way;
|
|
170
|
-
// starting from the end to pop in constant time.
|
|
171
|
-
for (size_t i = seq.size(); i-- > 0;) {
|
|
172
|
-
llama_sbatch_seq & s = seq[i];
|
|
173
|
-
GGML_ASSERT(s.length > 0);
|
|
174
|
-
if (length == 0) {
|
|
175
|
-
length = s.length < n_ubatch ? s.length : n_ubatch;
|
|
176
|
-
}
|
|
177
|
-
add_seq_to_ubatch(ubatch, s, length);
|
|
178
|
-
n_tokens_in_ubatch += length;
|
|
179
|
-
// shared prompts can't be mixed with any of their sequences,
|
|
180
|
-
// so it's safer to compute them in their own ubatch
|
|
181
|
-
if (s.n_seq_id > 1) { break; }
|
|
182
|
-
// stop when there isn't enough space for another sequence
|
|
183
|
-
if (length + n_tokens_in_ubatch > n_ubatch) { break; }
|
|
184
|
-
}
|
|
185
|
-
}
|
|
186
|
-
return ubatch;
|
|
187
|
-
}
|
|
188
|
-
|
|
189
|
-
llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
|
|
190
|
-
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
|
|
191
|
-
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
|
|
192
|
-
if (!seq.empty()) {
|
|
193
|
-
llama_sbatch_seq & s = seq[seq.size() - 1];
|
|
194
|
-
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
|
|
195
|
-
GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
|
|
196
|
-
add_seq_to_ubatch(ubatch, s, length);
|
|
197
|
-
}
|
|
198
|
-
return ubatch;
|
|
199
|
-
}
|
|
200
|
-
|
|
201
|
-
// Legacy constructor of the batch splitter (llama_sbatch).
//
// Fills `ids` with a permutation of the token indices of `batch`. With simple_split the
// permutation is the identity and the whole batch becomes a single pseudo-sequence with no
// seq_id information. Otherwise tokens are sorted (tokens assigned to more sequences first,
// then by seq_id list, then by pos) and runs of consecutive tokens with an identical
// seq_id list are grouped into `seq` entries, which are finally ordered by n_seq_id
// ascending and, within equal n_seq_id, by length descending.
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
    GGML_ASSERT(batch.n_tokens >= 0);
    // NOTE(review): only the address of `batch` is stored; presumably the caller must keep it alive while this object is used — confirm
    this->batch = &batch;
    this->n_embd = n_embd;

    n_tokens = batch.n_tokens;
    ids.resize(n_tokens);
    out_ids.clear();
    // TODO: reserve out_ids and seq

    // start from the identity permutation
    for (size_t i = 0; i < n_tokens; ++i) {
        ids[i] = i;
    }

    // simple split: a single entry covering all tokens, without per-token sequence info
    if (simple_split) {
        seq.resize(1);
        llama_sbatch_seq & s = seq[0];
        s.n_seq_id = 0;
        s.seq_id = nullptr;
        s.offset = 0;
        s.length = n_tokens;
        return;
    }

    std::sort(ids.begin(), ids.end(),
            [&batch](size_t a, size_t b) {
                // a missing n_seq_id array is treated as "every token belongs to one sequence"
                int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
                int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
                // sort by seq_id, then by pos
                if (n_seq_a == n_seq_b) {
                    if (batch.seq_id) {
                        for (int32_t i = 0; i < n_seq_a; ++i) {
                            llama_seq_id seq_id_a = batch.seq_id[a][i];
                            llama_seq_id seq_id_b = batch.seq_id[b][i];
                            // smaller seq_ids go first
                            if (seq_id_a != seq_id_b) {
                                return seq_id_a < seq_id_b;
                            }
                        }
                    }
                    // when all else is equal, sort by pos
                    if (batch.pos) {
                        return batch.pos[a] < batch.pos[b];
                    }
                    // no pos, sort by id
                    return a < b;
                }
                // shared prompts go first
                return n_seq_a > n_seq_b;
            }
    );

    // init seq
    // NOTE(review): unlike the comparator above, this loop dereferences batch.n_seq_id and
    // batch.seq_id without null checks — presumably callers guarantee they are set here; confirm
    llama_sbatch_seq * last_seq = nullptr;

    for (size_t i = 0; i < n_tokens; ++i) {
        const size_t bi = ids[i];
        const int32_t n_seqs = batch.n_seq_id[bi];
        llama_seq_id * seq_ids = batch.seq_id[bi];
        if (last_seq != nullptr) {
            // extend the current run if this token has exactly the same seq_id list
            bool same = n_seqs == last_seq->n_seq_id;
            for (int32_t j = 0; same && j < n_seqs; ++j) {
                if (seq_ids[j] != last_seq->seq_id[j]) {
                    same = false;
                }
            }
            if (same) {
                last_seq->length += 1;
                continue;
            }
        }
        // start a new run at sorted position i with length 1
        // (presumably field order is n_seq_id, seq_id, offset, length — matching the simple_split branch)
        llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
        seq.push_back(new_seq);
        // push_back may reallocate, so the pointer is re-taken after every insertion
        last_seq = &seq.back();
    }

    // keep shared prompts first at the end, then sort by length descending.
    std::sort(seq.begin(), seq.end(),
            [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
                if (a.n_seq_id == b.n_seq_id) {
                    return a.length > b.length;
                }
                return a.n_seq_id < b.n_seq_id;
            }
            );
}
|
|
287
|
-
|
|
288
|
-
llama_batch_allocr::llama_batch_allocr() {
|
|
12
|
+
llama_batch_allocr::llama_batch_allocr(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {
|
|
289
13
|
const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
|
|
290
14
|
debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
|
|
291
15
|
|
|
@@ -294,17 +18,22 @@ llama_batch_allocr::llama_batch_allocr() {
|
|
|
294
18
|
for (auto & cur : seq_cpl) {
|
|
295
19
|
cur.resize(LLAMA_MAX_SEQ);
|
|
296
20
|
}
|
|
21
|
+
|
|
22
|
+
seq_idx.resize(LLAMA_MAX_SEQ, -1);
|
|
297
23
|
}
|
|
298
24
|
|
|
299
25
|
bool llama_batch_allocr::init(
|
|
300
26
|
const llama_batch & batch_inp,
|
|
301
27
|
const llama_vocab & vocab,
|
|
302
28
|
const llama_memory_i * memory,
|
|
303
|
-
|
|
29
|
+
uint32_t n_embd,
|
|
30
|
+
bool output_all) {
|
|
304
31
|
clear();
|
|
305
32
|
|
|
306
33
|
batch = batch_inp;
|
|
307
34
|
|
|
35
|
+
this->vocab = &vocab;
|
|
36
|
+
|
|
308
37
|
GGML_ASSERT(batch.n_tokens > 0);
|
|
309
38
|
|
|
310
39
|
//
|
|
@@ -359,6 +88,7 @@ bool llama_batch_allocr::init(
|
|
|
359
88
|
llama_pos p0[LLAMA_MAX_SEQ];
|
|
360
89
|
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
361
90
|
if (!memory) {
|
|
91
|
+
// if no memory -> start from 0
|
|
362
92
|
p0[s] = 0;
|
|
363
93
|
} else {
|
|
364
94
|
p0[s] = memory->seq_pos_max(s) + 1;
|
|
@@ -370,8 +100,11 @@ bool llama_batch_allocr::init(
|
|
|
370
100
|
|
|
371
101
|
pos[i] = p0[seq_id];
|
|
372
102
|
|
|
103
|
+
// update the starting position for all sequences that are assigned to the this token
|
|
373
104
|
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
|
374
|
-
|
|
105
|
+
const llama_seq_id seq_id = batch.seq_id[i][s];
|
|
106
|
+
|
|
107
|
+
p0[seq_id] = pos[i] + 1;
|
|
375
108
|
}
|
|
376
109
|
}
|
|
377
110
|
|
|
@@ -379,7 +112,7 @@ bool llama_batch_allocr::init(
|
|
|
379
112
|
}
|
|
380
113
|
|
|
381
114
|
if (!batch.logits) {
|
|
382
|
-
if (
|
|
115
|
+
if (output_all) {
|
|
383
116
|
// return the output for all tokens
|
|
384
117
|
output.resize(batch.n_tokens, true);
|
|
385
118
|
} else {
|
|
@@ -389,7 +122,7 @@ bool llama_batch_allocr::init(
|
|
|
389
122
|
}
|
|
390
123
|
|
|
391
124
|
batch.logits = output.data();
|
|
392
|
-
} else if (
|
|
125
|
+
} else if (output_all) {
|
|
393
126
|
bool warn = false;
|
|
394
127
|
|
|
395
128
|
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
@@ -410,6 +143,9 @@ bool llama_batch_allocr::init(
|
|
|
410
143
|
// compute stats
|
|
411
144
|
//
|
|
412
145
|
|
|
146
|
+
this->n_embd = n_embd;
|
|
147
|
+
|
|
148
|
+
// count the outputs in this batch
|
|
413
149
|
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
414
150
|
n_outputs += batch.logits[i] != 0;
|
|
415
151
|
}
|
|
@@ -417,85 +153,88 @@ bool llama_batch_allocr::init(
|
|
|
417
153
|
// determine coupled sequences
|
|
418
154
|
// these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
|
|
419
155
|
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
156
|
+
const llama_seq_id s0 = batch.seq_id[i][0];
|
|
157
|
+
|
|
420
158
|
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
|
421
|
-
|
|
159
|
+
const llama_seq_id s1 = batch.seq_id[i][s];
|
|
422
160
|
|
|
423
|
-
|
|
424
|
-
const llama_seq_id s0 = batch.seq_id[i][0];
|
|
425
|
-
const llama_seq_id s1 = batch.seq_id[i][s];
|
|
161
|
+
seq_pos[s1].insert(batch.pos[i]);
|
|
426
162
|
|
|
163
|
+
if (s > 0) {
|
|
427
164
|
// mark that sequence s1 is coupled to s0
|
|
428
165
|
seq_cpl[s1][s0] = true;
|
|
429
166
|
|
|
430
|
-
// note: the other way around is not necessary for now
|
|
167
|
+
// note: tracking the other way around is not necessary for now
|
|
431
168
|
//seq_cpl[s0][s1] = true;
|
|
169
|
+
|
|
170
|
+
has_cpl = true;
|
|
432
171
|
}
|
|
433
172
|
}
|
|
434
173
|
}
|
|
435
174
|
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token);
|
|
440
|
-
LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd);
|
|
441
|
-
LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos);
|
|
442
|
-
LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) batch.n_seq_id);
|
|
443
|
-
LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) batch.seq_id);
|
|
444
|
-
LLAMA_LOG_DEBUG("%s: logits = %p\n", __func__, (void *) batch.logits);
|
|
445
|
-
LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
|
|
175
|
+
// precompute the sequence sets for each token and determine the unique sequence ids that participate in the batch
|
|
176
|
+
{
|
|
177
|
+
seq_set_t seq_set_unq;
|
|
446
178
|
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
for (int32_t
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
}
|
|
179
|
+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
180
|
+
seq_set_t cur;
|
|
181
|
+
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
|
182
|
+
const llama_seq_id seq_id = batch.seq_id[i][s];
|
|
183
|
+
|
|
184
|
+
cur .set(seq_id);
|
|
185
|
+
seq_set_unq.set(seq_id);
|
|
455
186
|
}
|
|
456
|
-
++seq_id_max;
|
|
457
187
|
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
188
|
+
seq_set.push_back(cur);
|
|
189
|
+
seq_set_map[cur].push_back(i);
|
|
190
|
+
}
|
|
461
191
|
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
192
|
+
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
193
|
+
if (seq_set_unq.test(s)) {
|
|
194
|
+
seq_idx[s] = seq_id_unq.size();
|
|
195
|
+
seq_id_unq.push_back(s);
|
|
196
|
+
}
|
|
197
|
+
}
|
|
198
|
+
}
|
|
465
199
|
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
if (seq_id[s]) {
|
|
469
|
-
ss << s%10;
|
|
470
|
-
} else {
|
|
471
|
-
ss << ".";
|
|
472
|
-
}
|
|
473
|
-
}
|
|
200
|
+
if (debug > 0) {
|
|
201
|
+
LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
|
|
474
202
|
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
203
|
+
llama_ubatch ubatch {
|
|
204
|
+
/*.equal_seqs =*/ false,
|
|
205
|
+
/*.n_tokens =*/ (uint32_t) batch.n_tokens,
|
|
206
|
+
/*.n_seq_tokens =*/ (uint32_t) 1,
|
|
207
|
+
/*.n_seqs =*/ (uint32_t) batch.n_tokens,
|
|
208
|
+
/*.n_seqs_unq =*/ (uint32_t) this->seq_id_unq.size(),
|
|
209
|
+
/*.token =*/ batch.token,
|
|
210
|
+
/*.embd =*/ batch.embd,
|
|
211
|
+
/*.pos =*/ batch.pos,
|
|
212
|
+
/*.n_seq_id =*/ batch.n_seq_id,
|
|
213
|
+
/*.seq_id =*/ batch.seq_id,
|
|
214
|
+
/*.seq_id_unq =*/ this->seq_id_unq.data(),
|
|
215
|
+
/*.seq_idx =*/ this->seq_idx.data(),
|
|
216
|
+
/*.output =*/ batch.logits,
|
|
217
|
+
};
|
|
218
|
+
|
|
219
|
+
ubatch_print(ubatch, debug);
|
|
220
|
+
|
|
221
|
+
LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
|
|
222
|
+
for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
|
|
223
|
+
if (seq_pos[s0].empty()) {
|
|
224
|
+
continue;
|
|
478
225
|
}
|
|
479
|
-
LLAMA_LOG_DEBUG("%s: ]\n", __func__);
|
|
480
226
|
|
|
481
|
-
|
|
482
|
-
for (int
|
|
483
|
-
if (
|
|
484
|
-
|
|
485
|
-
}
|
|
486
|
-
|
|
487
|
-
std::stringstream ss;
|
|
488
|
-
for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
|
|
489
|
-
if (seq_cpl[s0][s1]) {
|
|
490
|
-
ss << s1 << " ";
|
|
491
|
-
}
|
|
227
|
+
std::stringstream ss;
|
|
228
|
+
for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
|
|
229
|
+
if (seq_cpl[s0][s1]) {
|
|
230
|
+
ss << s1 << " ";
|
|
492
231
|
}
|
|
493
|
-
|
|
494
|
-
LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
|
|
495
|
-
__func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
|
|
496
232
|
}
|
|
497
|
-
|
|
233
|
+
|
|
234
|
+
LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
|
|
235
|
+
__func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
|
|
498
236
|
}
|
|
237
|
+
LLAMA_LOG_DEBUG("%s: ]\n", __func__);
|
|
499
238
|
}
|
|
500
239
|
|
|
501
240
|
//
|
|
@@ -507,9 +246,35 @@ bool llama_batch_allocr::init(
|
|
|
507
246
|
continue;
|
|
508
247
|
}
|
|
509
248
|
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
249
|
+
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
|
|
250
|
+
|
|
251
|
+
if (p0 >= 0) {
|
|
252
|
+
bool ok = true;
|
|
253
|
+
|
|
254
|
+
if (batch.token) {
|
|
255
|
+
if (seq_pos_min(s) != p0 + 1) {
|
|
256
|
+
ok = false;
|
|
257
|
+
}
|
|
258
|
+
} else {
|
|
259
|
+
assert(batch.embd);
|
|
260
|
+
|
|
261
|
+
// for embeddings (typically used as vision input), we allow them to have repeating positions
|
|
262
|
+
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
|
|
263
|
+
if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
|
|
264
|
+
ok = false;
|
|
265
|
+
}
|
|
266
|
+
}
|
|
267
|
+
|
|
268
|
+
if (!ok) {
|
|
269
|
+
LLAMA_LOG_ERROR(
|
|
270
|
+
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
|
|
271
|
+
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
|
|
272
|
+
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
|
|
273
|
+
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
|
|
274
|
+
__func__, s, s, p0, s, seq_pos_min(s));
|
|
275
|
+
|
|
276
|
+
return false;
|
|
277
|
+
}
|
|
513
278
|
}
|
|
514
279
|
|
|
515
280
|
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
|
|
@@ -532,17 +297,124 @@ bool llama_batch_allocr::init(
|
|
|
532
297
|
}
|
|
533
298
|
}
|
|
534
299
|
|
|
300
|
+
// disallow partial sequence sub-sets:
|
|
301
|
+
//
|
|
302
|
+
// invalid: x
|
|
303
|
+
// i: 0 1 2 ...
|
|
304
|
+
// ---------------------------------------
|
|
305
|
+
// seq_id[i][0]: 0 0 1
|
|
306
|
+
// seq_id[i][1]: 1 1 2
|
|
307
|
+
// seq_id[i][2]: 2
|
|
308
|
+
//
|
|
309
|
+
// disallow decreasing sequence positions:
|
|
310
|
+
//
|
|
311
|
+
// invalid: x
|
|
312
|
+
// i: 0 1 2 3 4 5 6 ...
|
|
313
|
+
// ---------------------------------------
|
|
314
|
+
// pos[i]: 4 5 0 1 6 2 3
|
|
315
|
+
// seq_id[i][0]: 0 0 1 1 0 1 0
|
|
316
|
+
//
|
|
317
|
+
{
|
|
318
|
+
seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
|
|
319
|
+
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
320
|
+
cur_seq_set[s].set();
|
|
321
|
+
}
|
|
322
|
+
|
|
323
|
+
llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
|
|
324
|
+
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
325
|
+
cur_seq_pos[s] = -1;
|
|
326
|
+
}
|
|
327
|
+
|
|
328
|
+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
329
|
+
const llama_pos pos = batch.pos[i];
|
|
330
|
+
|
|
331
|
+
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
|
332
|
+
const llama_seq_id seq_id = batch.seq_id[i][s];
|
|
333
|
+
|
|
334
|
+
cur_seq_set[seq_id] &= seq_set[i];
|
|
335
|
+
|
|
336
|
+
if (cur_seq_set[seq_id].none()) {
|
|
337
|
+
LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets (not allowed)\n", __func__, seq_id);
|
|
338
|
+
return false;
|
|
339
|
+
}
|
|
340
|
+
|
|
341
|
+
if (pos < cur_seq_pos[seq_id]) {
|
|
342
|
+
LLAMA_LOG_ERROR("%s: sequence %d positions are decreasing (not allowed)\n", __func__, seq_id);
|
|
343
|
+
return false;
|
|
344
|
+
}
|
|
345
|
+
}
|
|
346
|
+
}
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
split_reset();
|
|
350
|
+
|
|
535
351
|
return true;
|
|
536
352
|
}
|
|
537
353
|
|
|
354
|
+
// Create a placeholder ubatch of n_seqs sequences with n_seq_tokens tokens each (equal split),
// backed by storage appended to `ubatches`. Any previous batch/split state is cleared first.
// Token ids, positions, n_seq_id counts and output flags are zero-filled by resize() and the
// seq_id pointers are null; sequence s is mapped to unique index s.
// NOTE(review): presumably used to size/reserve compute graphs — confirm against callers.
llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
    // seq_idx below has exactly LLAMA_MAX_SEQ slots and is indexed by sequence id,
    // so more sequences than that would write out of bounds
    GGML_ASSERT(n_seqs <= (uint32_t) LLAMA_MAX_SEQ);

    const uint32_t n_tokens = n_seq_tokens*n_seqs;

    clear();
    split_reset();

    ubatches.emplace_back();

    auto & ubatch = ubatches.back();

    ubatch.token     .resize(n_tokens);
    ubatch.embd      .clear();
    ubatch.pos       .resize(n_tokens);
    ubatch.n_seq_id  .resize(n_tokens);
    ubatch.seq_id    .resize(n_tokens);
    ubatch.seq_id_unq.resize(0);
    ubatch.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
    ubatch.output    .resize(n_tokens);

    for (uint32_t s = 0; s < n_seqs; ++s) {
        ubatch.seq_idx[s] = s;
        ubatch.seq_id_unq.push_back(s);
    }

    llama_ubatch res {
        /*.equal_seqs   =*/ true,
        /*.n_tokens     =*/ n_tokens,
        /*.n_seq_tokens =*/ n_seq_tokens,
        /*.n_seqs       =*/ n_seqs,
        /*.n_seqs_unq   =*/ n_seqs,

        /*.token        =*/ ubatch.token.data(),
        /*.embd         =*/ nullptr,
        /*.pos          =*/ ubatch.pos.data(),
        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
        /*.seq_id       =*/ ubatch.seq_id.data(),
        /*.seq_id_unq   =*/ ubatch.seq_id_unq.data(),
        /*.seq_idx      =*/ ubatch.seq_idx.data(),
        /*.output       =*/ ubatch.output.data(),
    };

    return res;
}
|
|
397
|
+
|
|
538
398
|
// Read-only view of the batch stored by the last init() (including any defaults init() filled in).
const llama_batch & llama_batch_allocr::get_batch() const {
    const llama_batch & cur = this->batch;
    return cur;
}
|
|
541
401
|
|
|
402
|
+
uint32_t llama_batch_allocr::get_n_tokens() const {
|
|
403
|
+
return batch.n_tokens;
|
|
404
|
+
}
|
|
405
|
+
|
|
542
406
|
uint32_t llama_batch_allocr::get_n_outputs() const {
|
|
543
407
|
return n_outputs;
|
|
544
408
|
}
|
|
545
409
|
|
|
410
|
+
uint32_t llama_batch_allocr::get_n_used() const {
|
|
411
|
+
return n_used;
|
|
412
|
+
}
|
|
413
|
+
|
|
414
|
+
std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
|
|
415
|
+
return out_ids;
|
|
416
|
+
}
|
|
417
|
+
|
|
546
418
|
// Smallest position seen for `seq_id` in the current input batch, or -1 if the sequence has no tokens.
llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
    const auto & positions = seq_pos[seq_id];
    if (positions.empty()) {
        return -1;
    }
    return *positions.begin();
}
|
|
@@ -551,14 +423,208 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
|
|
|
551
423
|
return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
|
|
552
424
|
}
|
|
553
425
|
|
|
426
|
+
// Forget all previously produced ubatches and mark every token of the current batch as unused,
// so that splitting can start over from the beginning.
void llama_batch_allocr::split_reset() {
    ubatches.clear();
    out_ids.clear();

    n_used = 0;
    used.assign(get_n_tokens(), false);
}
|
|
436
|
+
|
|
437
|
+
// Simple split: take the next run of consecutive tokens starting at the first unused one,
// at most n_ubatch of them, without regard to sequence membership. Each token is its own
// "sequence" in the resulting ubatch. Returns an empty ubatch once all tokens are used.
llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
    // skip over tokens consumed by earlier splits
    uint32_t idx = 0;
    for (; idx < used.size() && used[idx]; ++idx) {
    }

    if (idx >= used.size()) {
        return {}; // nothing left to split
    }

    std::vector<int32_t> picked;

    do {
        picked.push_back(idx);

        used[idx] = true;
        ++n_used;

        ++idx;
    } while (idx < used.size() && picked.size() < n_ubatch);

    return ubatch_add(picked, picked.size(), false);
}
|
|
470
|
+
|
|
471
|
+
// Equal split: pick a group of mutually non-overlapping sequence sets from the unused tokens
// and take the same number of tokens from each set, so the ubatch holds n_seqs sequences of
// n_seq_tokens tokens each, with n_seqs*n_seq_tokens <= n_ubatch.
// With `sequential`, only sets whose first seq_id follows the previous one are accepted;
// this mode is refused (empty ubatch + error log) when the batch has coupled sequences.
// Returns an empty ubatch when no unused tokens remain.
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
    if (sequential && has_cpl) {
        LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);

        return {};
    }

    std::vector<seq_set_t> cur_seq_set;

    llama_seq_id last_seq_id = -1;

    // determine the non-overlapping sequence sets participating in this ubatch
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        if (used[i]) {
            continue;
        }

        bool add = true;

        for (uint32_t s = 0; s < cur_seq_set.size(); ++s) {
            // no overlap with existing sequence sets:
            if (!(cur_seq_set[s] & seq_set[i]).none()) {
                add = false;
                break;
            }
        }

        // accept only increasing sequence ids
        if (sequential) {
            add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
        }

        if (add) {
            cur_seq_set.push_back(seq_set[i]);

            last_seq_id = batch.seq_id[i][0];

            // every selected set contributes at least one token per expansion step below,
            // so stop as soon as n_ubatch sets are selected - one more would overflow the ubatch
            if (cur_seq_set.size() >= n_ubatch) {
                break;
            }
        }
    }

    const uint32_t n_seqs = cur_seq_set.size();

    // we are done
    if (n_seqs == 0) {
        return {};
    }

    // the current batch index of each sequence set
    std::vector<int32_t> cur_idx(n_seqs, 0);

    // advance each set to its first unused token (each set has at least one, since it was
    // selected from an unused token above)
    for (uint32_t s = 0; s < n_seqs; ++s) {
        while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
            ++cur_idx[s];
        }
    }

    // the list of batch indices for each sequence set
    // at the end we will concat these to get the final ubatch
    std::vector<idx_vec_t> idxs_per_seq(n_seqs);

    while (true) {
        // we can only add new n_seq_tokens tokens if all the sequence sets have at least one more unused token and
        // if we haven't reached n_ubatch
        bool can_expand = true;

        for (uint32_t s = 0; s < n_seqs; ++s) {
            if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
                can_expand = false;
                break;
            }
        }

        if (!can_expand) {
            break;
        }

        for (uint32_t s = 0; s < n_seqs; ++s) {
            const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];

            idxs_per_seq[s].push_back(idx);

            used[idx] = true;
            ++n_used;

            ++cur_idx[s];
        }

        // stop if another round of one token per set would exceed n_ubatch
        if ((idxs_per_seq[0].size() + 1)*n_seqs > n_ubatch) {
            break;
        }
    }

    // concat the per-sequence-set lists
    std::vector<int32_t> idxs;

    for (uint32_t s = 0; s < n_seqs; ++s) {
        idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
    }

    return ubatch_add(idxs, n_seqs, true);
}
|
|
575
|
+
|
|
576
|
+
// Sequence split: starting at the first unused token, collect up to n_ubatch tokens whose
// sequence set is a subset of the running set; the running set narrows to the set of each
// token accepted. The result is a single-sequence ubatch. Empty once all tokens are used.
llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
    uint32_t idx = 0;
    while (idx < used.size() && used[idx]) {
        ++idx;
    }

    if (idx >= used.size()) {
        return {}; // every token has already been consumed
    }

    // tokens may only join if their sequence set fits inside this (narrowing) set
    auto active = seq_set[idx];

    std::vector<int32_t> picked;

    for (;;) {
        picked.push_back(idx);

        used[idx] = true;
        ++n_used;

        if (picked.size() >= n_ubatch) {
            break;
        }

        // look for the next unused token whose sequence set is contained in the active set
        for (++idx; idx < get_n_tokens(); ++idx) {
            if (!used[idx] && (active & seq_set[idx]) == seq_set[idx]) {
                break;
            }
        }

        if (idx == get_n_tokens()) {
            break;
        }

        active = seq_set[idx];
    }

    return ubatch_add(picked, 1, true);
}
|
|
617
|
+
|
|
554
618
|
void llama_batch_allocr::clear() {
|
|
555
619
|
n_outputs = 0;
|
|
556
620
|
|
|
557
621
|
batch = {};
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
622
|
+
|
|
623
|
+
pos .clear();
|
|
624
|
+
n_seq_id .clear();
|
|
625
|
+
seq_id .clear();
|
|
626
|
+
seq_id_unq.clear();
|
|
627
|
+
output .clear();
|
|
562
628
|
|
|
563
629
|
for (auto & cur : seq_pos) {
|
|
564
630
|
cur.clear();
|
|
@@ -567,6 +633,177 @@ void llama_batch_allocr::clear() {
|
|
|
567
633
|
for (auto & cur : seq_cpl) {
|
|
568
634
|
std::fill(cur.begin(), cur.end(), false);
|
|
569
635
|
}
|
|
636
|
+
|
|
637
|
+
seq_set.clear();
|
|
638
|
+
|
|
639
|
+
seq_set_map.clear();
|
|
640
|
+
|
|
641
|
+
std::fill(seq_idx.begin(), seq_idx.end(), -1);
|
|
642
|
+
}
|
|
643
|
+
|
|
644
|
+
// Materialize a ubatch from the rows of the stored batch listed in `idxs` (in that order).
// The copied data lives in a new entry appended to `ubatches` (released by split_reset());
// the returned llama_ubatch points into that storage. Indices of rows flagged for output are
// appended to out_ids. Positions are laid out as n_pos_cur planes of n_tokens entries each.
llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
    const uint32_t n_tokens = idxs.size();

    // every sequence must contribute the same number of tokens
    assert(n_tokens%n_seqs == 0);

    ubatches.emplace_back();
    auto & dst = ubatches.back();

    // embedding inputs carry n_pos_per_embd position planes, token inputs carry one
    const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;

    const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
    const int64_t n_pos_all  =              (int64_t) n_tokens*n_pos_cur;

    dst.token     .resize(n_tokens);
    dst.embd      .resize(n_embd_all);
    dst.pos       .resize(n_pos_all);
    dst.n_seq_id  .resize(n_tokens);
    dst.seq_id    .resize(n_tokens);
    dst.seq_id_unq.resize(0);
    dst.seq_idx   .resize(LLAMA_MAX_SEQ, -1);
    dst.output    .resize(n_tokens);

    // union of all sequence ids referenced by the selected rows
    seq_set_t present;

    for (size_t k = 0; k < idxs.size(); ++k) {
        const int32_t src = idxs[k];

        if (batch.token) {
            dst.token[k] = batch.token[src];
        }

        if (batch.embd) {
            memcpy(dst.embd.data() + k*n_embd, batch.embd + (int64_t) src*n_embd, n_embd*sizeof(float));
        }

        for (int j = 0; j < n_pos_cur; ++j) {
            dst.pos[j*n_tokens + k] = batch.pos[j*batch.n_tokens + src];
        }

        dst.n_seq_id[k] = batch.n_seq_id[src];
        dst.seq_id[k]   = batch.seq_id[src];
        dst.output[k]   = batch.logits[src];

        for (int s = 0; s < dst.n_seq_id[k]; ++s) {
            present.set(dst.seq_id[k][s]);
        }

        if (dst.output[k]) {
            out_ids.push_back(src);
        }
    }

    // give each participating sequence id a compact index, in ascending id order
    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
        if (!present.test(s)) {
            continue;
        }

        dst.seq_idx[s] = dst.seq_id_unq.size();
        dst.seq_id_unq.push_back(s);
    }

    llama_ubatch res {
        /*.equal_seqs   =*/ equal_seqs,
        /*.n_tokens     =*/ n_tokens,
        /*.n_seq_tokens =*/ n_tokens/n_seqs,
        /*.n_seqs       =*/ n_seqs,
        /*.n_seqs_unq   =*/ (uint32_t) dst.seq_id_unq.size(),

        /*.token        =*/ batch.token ? dst.token.data() : nullptr,
        /*.embd         =*/ batch.embd ? dst.embd.data() : nullptr,
        /*.pos          =*/ dst.pos.data(),
        /*.n_seq_id     =*/ dst.n_seq_id.data(),
        /*.seq_id       =*/ dst.seq_id.data(),
        /*.seq_id_unq   =*/ dst.seq_id_unq.data(),
        /*.seq_idx      =*/ dst.seq_idx.data(),
        /*.output       =*/ dst.output.data(),
    };

    if (debug > 0) {
        LLAMA_LOG_DEBUG("%s: added ubatch %d to split:\n", __func__, (int) ubatches.size() - 1);

        ubatch_print(res, debug);
    }

    return res;
}
|
|
727
|
+
|
|
728
|
+
void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
|
|
729
|
+
if (debug > 0) {
|
|
730
|
+
LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs);
|
|
731
|
+
LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, ubatch.n_tokens);
|
|
732
|
+
LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
|
|
733
|
+
LLAMA_LOG_DEBUG("%s: n_seqs = %d\n", __func__, ubatch.n_seqs);
|
|
734
|
+
LLAMA_LOG_DEBUG("%s: n_seqs_unq = %d\n", __func__, ubatch.n_seqs_unq);
|
|
735
|
+
|
|
736
|
+
std::stringstream ss_seq_id_unq;
|
|
737
|
+
std::stringstream ss_seq_idx;
|
|
738
|
+
|
|
739
|
+
ss_seq_id_unq << "[ ";
|
|
740
|
+
ss_seq_idx << "[";
|
|
741
|
+
|
|
742
|
+
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
|
743
|
+
ss_seq_id_unq << ubatch.seq_id_unq[s] << " ";
|
|
744
|
+
}
|
|
745
|
+
|
|
746
|
+
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
|
747
|
+
if (ubatch.seq_idx[s] >= 0) {
|
|
748
|
+
ss_seq_idx << ubatch.seq_idx[s]%10;
|
|
749
|
+
} else {
|
|
750
|
+
ss_seq_idx << ".";
|
|
751
|
+
}
|
|
752
|
+
}
|
|
753
|
+
|
|
754
|
+
ss_seq_id_unq << "]";
|
|
755
|
+
ss_seq_idx << "]";
|
|
756
|
+
|
|
757
|
+
LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) ubatch.token);
|
|
758
|
+
LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) ubatch.embd);
|
|
759
|
+
LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) ubatch.pos);
|
|
760
|
+
LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) ubatch.n_seq_id);
|
|
761
|
+
LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) ubatch.seq_id);
|
|
762
|
+
LLAMA_LOG_DEBUG("%s: seq_id_unq = %s\n", __func__, ss_seq_id_unq.str().c_str());
|
|
763
|
+
LLAMA_LOG_DEBUG("%s: seq_idx = %s\n", __func__, ss_seq_idx.str().c_str());
|
|
764
|
+
LLAMA_LOG_DEBUG("%s: output = %p\n", __func__, (void *) ubatch.output);
|
|
765
|
+
LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
|
|
766
|
+
|
|
767
|
+
if (debug > 1) {
|
|
768
|
+
int seq_id_max = 0;
|
|
769
|
+
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
|
770
|
+
for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
|
|
771
|
+
for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
|
|
772
|
+
seq_id_max = std::max(seq_id_max, ubatch.seq_id[i][s]);
|
|
773
|
+
}
|
|
774
|
+
}
|
|
775
|
+
}
|
|
776
|
+
++seq_id_max;
|
|
777
|
+
|
|
778
|
+
LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
|
|
779
|
+
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
|
780
|
+
std::vector<int8_t> seq_id(seq_id_max);
|
|
781
|
+
|
|
782
|
+
for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
|
|
783
|
+
seq_id[ubatch.seq_id[i][s]] = 1;
|
|
784
|
+
}
|
|
785
|
+
|
|
786
|
+
std::stringstream ss;
|
|
787
|
+
for (int s = 0; s < seq_id_max; ++s) {
|
|
788
|
+
if (seq_id[s]) {
|
|
789
|
+
ss << s%10;
|
|
790
|
+
} else {
|
|
791
|
+
ss << ".";
|
|
792
|
+
}
|
|
793
|
+
}
|
|
794
|
+
|
|
795
|
+
if (ubatch.token) {
|
|
796
|
+
LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
|
|
797
|
+
__func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(),
|
|
798
|
+
ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
|
|
799
|
+
} else {
|
|
800
|
+
LLAMA_LOG_DEBUG("%s: %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
|
|
801
|
+
__func__, i, ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
|
|
802
|
+
}
|
|
803
|
+
}
|
|
804
|
+
LLAMA_LOG_DEBUG("%s: ]\n", __func__);
|
|
805
|
+
}
|
|
806
|
+
}
|
|
570
807
|
}
|
|
571
808
|
|
|
572
809
|
//
|
|
@@ -577,25 +814,25 @@ struct llama_batch llama_batch_get_one(
|
|
|
577
814
|
llama_token * tokens,
|
|
578
815
|
int32_t n_tokens) {
|
|
579
816
|
return {
|
|
580
|
-
/*n_tokens
|
|
581
|
-
/*tokens
|
|
582
|
-
/*embd
|
|
583
|
-
/*pos
|
|
584
|
-
/*n_seq_id
|
|
585
|
-
/*seq_id
|
|
586
|
-
/*logits
|
|
817
|
+
/*n_tokens =*/ n_tokens,
|
|
818
|
+
/*tokens =*/ tokens,
|
|
819
|
+
/*embd =*/ nullptr,
|
|
820
|
+
/*pos =*/ nullptr,
|
|
821
|
+
/*n_seq_id =*/ nullptr,
|
|
822
|
+
/*seq_id =*/ nullptr,
|
|
823
|
+
/*logits =*/ nullptr,
|
|
587
824
|
};
|
|
588
825
|
}
|
|
589
826
|
|
|
590
827
|
struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
|
|
591
828
|
llama_batch batch = {
|
|
592
|
-
/*n_tokens
|
|
593
|
-
/*tokens
|
|
594
|
-
/*embd
|
|
595
|
-
/*pos
|
|
596
|
-
/*n_seq_id
|
|
597
|
-
/*seq_id
|
|
598
|
-
/*logits
|
|
829
|
+
/*n_tokens =*/ 0,
|
|
830
|
+
/*tokens =*/ nullptr,
|
|
831
|
+
/*embd =*/ nullptr,
|
|
832
|
+
/*pos =*/ nullptr,
|
|
833
|
+
/*n_seq_id =*/ nullptr,
|
|
834
|
+
/*seq_id =*/ nullptr,
|
|
835
|
+
/*logits =*/ nullptr,
|
|
599
836
|
};
|
|
600
837
|
|
|
601
838
|
if (embd) {
|