@novastera-oss/llamarn 0.4.0 → 0.4.3-beta4
This diff compares the contents of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
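For reference, a comparable diff can typically be reproduced locally with npm's built-in diff command (npm 7+), assuming both versions are available on the public npm registry: `npm diff --diff=@novastera-oss/llamarn@0.4.0 --diff=@novastera-oss/llamarn@0.4.3-beta4`.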
- package/RNLlamaCpp.podspec +4 -1
- package/android/CMakeLists.txt +13 -3
- package/android/src/main/cpp/include/llama.h +44 -21
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +2 -10
- package/cpp/SystemUtils.cpp +3 -7
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +12 -0
- package/cpp/llama.cpp/CODEOWNERS +116 -10
- package/cpp/llama.cpp/CONTRIBUTING.md +30 -3
- package/cpp/llama.cpp/README.md +13 -5
- package/cpp/llama.cpp/build-xcframework.sh +5 -0
- package/cpp/llama.cpp/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
- package/cpp/llama.cpp/common/CMakeLists.txt +12 -2
- package/cpp/llama.cpp/common/arg.cpp +303 -795
- package/cpp/llama.cpp/common/arg.h +2 -3
- package/cpp/llama.cpp/common/chat-parser-xml-toolcall.cpp +861 -0
- package/cpp/llama.cpp/common/chat-parser-xml-toolcall.h +45 -0
- package/cpp/llama.cpp/common/chat-parser.cpp +156 -15
- package/cpp/llama.cpp/common/chat-parser.h +13 -0
- package/cpp/llama.cpp/common/chat.cpp +1147 -88
- package/cpp/llama.cpp/common/chat.h +16 -3
- package/cpp/llama.cpp/common/common.cpp +70 -15
- package/cpp/llama.cpp/common/common.h +57 -19
- package/cpp/llama.cpp/common/download.cpp +1072 -0
- package/cpp/llama.cpp/common/download.h +55 -0
- package/cpp/llama.cpp/common/http.h +73 -0
- package/cpp/llama.cpp/common/json-partial.cpp +70 -2
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +61 -22
- package/cpp/llama.cpp/common/json-schema-to-grammar.h +2 -0
- package/cpp/llama.cpp/common/log.cpp +59 -2
- package/cpp/llama.cpp/common/log.h +12 -4
- package/cpp/llama.cpp/common/sampling.cpp +84 -8
- package/cpp/llama.cpp/common/sampling.h +3 -1
- package/cpp/llama.cpp/common/speculative.cpp +1 -1
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1608 -233
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +6 -1
- package/cpp/llama.cpp/convert_lora_to_gguf.py +37 -5
- package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -28
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +19 -1
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-hexagon.h +19 -0
- package/cpp/llama.cpp/ggml/include/ggml-metal.h +1 -6
- package/cpp/llama.cpp/ggml/include/ggml-rpc.h +7 -9
- package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +2 -1
- package/cpp/llama.cpp/ggml/include/ggml.h +199 -6
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +38 -0
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +299 -130
- package/cpp/llama.cpp/ggml/src/ggml-backend-impl.h +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +21 -5
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +99 -2
- package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +138 -47
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +1584 -1773
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +201 -317
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +146 -187
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +771 -713
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +135 -77
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +16 -17
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +318 -145
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +155 -60
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +14 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +108 -64
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +14 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +530 -87
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +37 -45
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +349 -127
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +947 -1218
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +5 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +143 -29
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +82 -76
- package/cpp/llama.cpp/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +233 -28
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +326 -66
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +12 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/argsort.cu +102 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +110 -76
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +167 -38
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +6 -11
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +12 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +245 -151
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cuh +1 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +341 -289
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cuh +1233 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +123 -220
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +41 -39
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +715 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +150 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cuh +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +321 -24
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +93 -351
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +828 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmid.cu +164 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmid.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +3 -166
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvf.cu +371 -78
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvf.cuh +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +279 -147
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cuh +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +97 -85
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad.cu +46 -23
- package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +63 -54
- package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +12 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +192 -77
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +10 -9
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +137 -75
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set.cu +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set.cuh +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/topk-moe.cu +336 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/topk-moe.cuh +16 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/tsembd.cu +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +105 -11
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +36 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +87 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +28 -12
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/CMakeLists.txt +68 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3807 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/CMakeLists.txt +40 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/act-ops.c +442 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ctx.h +40 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-dma.c +69 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-dma.h +119 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-msg.h +156 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ops.h +64 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-exp.c +93 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-inverse.c +60 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-utils.c +960 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h +1032 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/main.c +829 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c +2223 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/rope-ops.c +418 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c +255 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp-utils.c +448 -0
- package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp-utils.h +220 -0
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +8 -13
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +110 -12
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +6 -5
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m +599 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.cpp +1662 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.h +251 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m +1527 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +244 -39
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp +3844 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.h +90 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.cpp +723 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +3453 -1907
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +10 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1331 -109
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/cvt.cl +126 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +31 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +35 -7
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +31 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +123 -10
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +341 -161
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +6 -5
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +74 -15
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +50 -30
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +10 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +166 -99
- package/cpp/llama.cpp/ggml/src/ggml-sycl/count-equal.cpp +79 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/count-equal.hpp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +72 -94
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +67 -49
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +21 -31
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +252 -316
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +6 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +9 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +359 -142
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +80 -60
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +201 -132
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +230 -55
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.hpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/pad.cpp +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/pad.hpp +24 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/pad_reflect_1d.cpp +72 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/pad_reflect_1d.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/roll.cpp +122 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/roll.hpp +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +50 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set.cpp +73 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set.hpp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +45 -36
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +330 -165
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.hpp +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +12 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +16 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4184 -2159
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +53 -30
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +13 -6
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +138 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +52 -14
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +50 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +61 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +54 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +5 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +15 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +229 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +33 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +3 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +3 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +3 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +106 -634
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +118 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +556 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +70 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +77 -214
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +589 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +25 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +55 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +45 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +227 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +5 -52
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +5 -35
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +5 -35
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +5 -41
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +140 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +171 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +79 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +471 -196
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1690 -383
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +57 -10
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +25 -912
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/{set_rows.wgsl → set_rows.tmpl.wgsl} +38 -8
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/common.hpp +59 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +96 -314
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/utils.cpp +79 -0
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/utils.hpp +19 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +440 -17
- package/cpp/llama.cpp/ggml/src/gguf.cpp +104 -29
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +363 -13
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +64 -0
- package/cpp/llama.cpp/gguf-py/gguf/lazy.py +8 -3
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +6 -0
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +156 -18
- package/cpp/llama.cpp/gguf-py/gguf/utility.py +80 -0
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +4 -4
- package/cpp/llama.cpp/include/llama.h +44 -21
- package/cpp/llama.cpp/media/llama1-icon-transparent.png +0 -0
- package/cpp/llama.cpp/media/llama1-icon-transparent.svg +77 -0
- package/cpp/llama.cpp/media/llama1-icon.png +0 -0
- package/cpp/llama.cpp/media/llama1-icon.svg +87 -0
- package/cpp/llama.cpp/requirements/requirements-all.txt +2 -0
- package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -3
- package/cpp/llama.cpp/requirements/requirements-convert_legacy_llama.txt +3 -1
- package/cpp/llama.cpp/requirements/requirements-tool_bench.txt +1 -1
- package/cpp/llama.cpp/src/CMakeLists.txt +101 -0
- package/cpp/llama.cpp/src/llama-adapter.cpp +33 -0
- package/cpp/llama.cpp/src/llama-adapter.h +3 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +344 -14
- package/cpp/llama.cpp/src/llama-arch.h +50 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +63 -31
- package/cpp/llama.cpp/src/llama-batch.h +13 -2
- package/cpp/llama.cpp/src/llama-chat.cpp +85 -3
- package/cpp/llama.cpp/src/llama-chat.h +4 -0
- package/cpp/llama.cpp/src/llama-context.cpp +300 -45
- package/cpp/llama.cpp/src/llama-context.h +16 -6
- package/cpp/llama.cpp/src/llama-cparams.h +2 -1
- package/cpp/llama.cpp/src/llama-grammar.cpp +17 -9
- package/cpp/llama.cpp/src/llama-graph.cpp +226 -64
- package/cpp/llama.cpp/src/llama-graph.h +27 -5
- package/cpp/llama.cpp/src/llama-hparams.cpp +53 -2
- package/cpp/llama.cpp/src/llama-hparams.h +48 -8
- package/cpp/llama.cpp/src/llama-impl.cpp +3 -3
- package/cpp/llama.cpp/src/llama-impl.h +2 -0
- package/cpp/llama.cpp/src/llama-kv-cache-iswa.cpp +13 -3
- package/cpp/llama.cpp/src/llama-kv-cache-iswa.h +2 -0
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +120 -62
- package/cpp/llama.cpp/src/llama-kv-cache.h +13 -4
- package/cpp/llama.cpp/src/llama-kv-cells.h +44 -2
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +19 -9
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +2 -0
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +38 -17
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +5 -2
- package/cpp/llama.cpp/src/llama-memory.h +3 -0
- package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
- package/cpp/llama.cpp/src/llama-model.cpp +1070 -12614
- package/cpp/llama.cpp/src/llama-model.h +40 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +14 -6
- package/cpp/llama.cpp/src/llama-sampling.cpp +243 -136
- package/cpp/llama.cpp/src/llama-vocab.cpp +43 -3
- package/cpp/llama.cpp/src/llama-vocab.h +43 -39
- package/cpp/llama.cpp/src/llama.cpp +69 -10
- package/cpp/llama.cpp/src/models/afmoe.cpp +187 -0
- package/cpp/llama.cpp/src/models/apertus.cpp +125 -0
- package/cpp/llama.cpp/src/models/arcee.cpp +135 -0
- package/cpp/llama.cpp/src/models/arctic.cpp +138 -0
- package/cpp/llama.cpp/src/models/arwkv7.cpp +86 -0
- package/cpp/llama.cpp/src/models/baichuan.cpp +122 -0
- package/cpp/llama.cpp/src/models/bailingmoe.cpp +144 -0
- package/cpp/llama.cpp/src/models/bailingmoe2.cpp +135 -0
- package/cpp/llama.cpp/src/models/bert.cpp +176 -0
- package/cpp/llama.cpp/src/models/bitnet.cpp +160 -0
- package/cpp/llama.cpp/src/models/bloom.cpp +101 -0
- package/cpp/llama.cpp/src/models/chameleon.cpp +178 -0
- package/cpp/llama.cpp/src/models/chatglm.cpp +132 -0
- package/cpp/llama.cpp/src/models/codeshell.cpp +111 -0
- package/cpp/llama.cpp/src/models/cogvlm.cpp +100 -0
- package/cpp/llama.cpp/src/models/cohere2-iswa.cpp +131 -0
- package/cpp/llama.cpp/src/models/command-r.cpp +122 -0
- package/cpp/llama.cpp/src/models/dbrx.cpp +123 -0
- package/cpp/llama.cpp/src/models/deci.cpp +135 -0
- package/cpp/llama.cpp/src/models/deepseek.cpp +144 -0
- package/cpp/llama.cpp/src/models/deepseek2.cpp +237 -0
- package/cpp/llama.cpp/src/models/dots1.cpp +134 -0
- package/cpp/llama.cpp/src/models/dream.cpp +105 -0
- package/cpp/llama.cpp/src/models/ernie4-5-moe.cpp +150 -0
- package/cpp/llama.cpp/src/models/ernie4-5.cpp +110 -0
- package/cpp/llama.cpp/src/models/exaone.cpp +114 -0
- package/cpp/llama.cpp/src/models/exaone4.cpp +123 -0
- package/cpp/llama.cpp/src/models/falcon-h1.cpp +113 -0
- package/cpp/llama.cpp/src/models/falcon.cpp +120 -0
- package/cpp/llama.cpp/src/models/gemma-embedding.cpp +120 -0
- package/cpp/llama.cpp/src/models/gemma.cpp +112 -0
- package/cpp/llama.cpp/src/models/gemma2-iswa.cpp +125 -0
- package/cpp/llama.cpp/src/models/gemma3-iswa.cpp +131 -0
- package/cpp/llama.cpp/src/models/gemma3n-iswa.cpp +377 -0
- package/cpp/llama.cpp/src/models/glm4-moe.cpp +153 -0
- package/cpp/llama.cpp/src/models/glm4.cpp +127 -0
- package/cpp/llama.cpp/src/models/gpt2.cpp +105 -0
- package/cpp/llama.cpp/src/models/gptneox.cpp +144 -0
- package/cpp/llama.cpp/src/models/granite-hybrid.cpp +196 -0
- package/cpp/llama.cpp/src/models/granite.cpp +211 -0
- package/cpp/llama.cpp/src/models/graph-context-mamba.cpp +283 -0
- package/cpp/llama.cpp/src/models/grok.cpp +159 -0
- package/cpp/llama.cpp/src/models/grovemoe.cpp +141 -0
- package/cpp/llama.cpp/src/models/hunyuan-dense.cpp +132 -0
- package/cpp/llama.cpp/src/models/hunyuan-moe.cpp +154 -0
- package/cpp/llama.cpp/src/models/internlm2.cpp +120 -0
- package/cpp/llama.cpp/src/models/jais.cpp +86 -0
- package/cpp/llama.cpp/src/models/jamba.cpp +106 -0
- package/cpp/llama.cpp/src/models/lfm2.cpp +173 -0
- package/cpp/llama.cpp/src/models/llada-moe.cpp +122 -0
- package/cpp/llama.cpp/src/models/llada.cpp +99 -0
- package/cpp/llama.cpp/src/models/llama-iswa.cpp +174 -0
- package/cpp/llama.cpp/src/models/llama.cpp +155 -0
- package/cpp/llama.cpp/src/models/mamba.cpp +55 -0
- package/cpp/llama.cpp/src/models/minicpm3.cpp +199 -0
- package/cpp/llama.cpp/src/models/minimax-m2.cpp +124 -0
- package/cpp/llama.cpp/src/models/models.h +485 -0
- package/cpp/llama.cpp/src/models/mpt.cpp +126 -0
- package/cpp/llama.cpp/src/models/nemotron-h.cpp +121 -0
- package/cpp/llama.cpp/src/models/nemotron.cpp +122 -0
- package/cpp/llama.cpp/src/models/neo-bert.cpp +104 -0
- package/cpp/llama.cpp/src/models/olmo.cpp +121 -0
- package/cpp/llama.cpp/src/models/olmo2.cpp +150 -0
- package/cpp/llama.cpp/src/models/olmoe.cpp +124 -0
- package/cpp/llama.cpp/src/models/openai-moe-iswa.cpp +124 -0
- package/cpp/llama.cpp/src/models/openelm.cpp +124 -0
- package/cpp/llama.cpp/src/models/orion.cpp +123 -0
- package/cpp/llama.cpp/src/models/pangu-embedded.cpp +121 -0
- package/cpp/llama.cpp/src/models/phi2.cpp +121 -0
- package/cpp/llama.cpp/src/models/phi3.cpp +152 -0
- package/cpp/llama.cpp/src/models/plamo.cpp +110 -0
- package/cpp/llama.cpp/src/models/plamo2.cpp +316 -0
- package/cpp/llama.cpp/src/models/plm.cpp +168 -0
- package/cpp/llama.cpp/src/models/qwen.cpp +108 -0
- package/cpp/llama.cpp/src/models/qwen2.cpp +117 -0
- package/cpp/llama.cpp/src/models/qwen2moe.cpp +151 -0
- package/cpp/llama.cpp/src/models/qwen2vl.cpp +117 -0
- package/cpp/llama.cpp/src/models/qwen3.cpp +117 -0
- package/cpp/llama.cpp/src/models/qwen3moe.cpp +124 -0
- package/cpp/llama.cpp/src/models/qwen3vl-moe.cpp +149 -0
- package/cpp/llama.cpp/src/models/qwen3vl.cpp +141 -0
- package/cpp/llama.cpp/src/models/refact.cpp +94 -0
- package/cpp/llama.cpp/src/models/rwkv6-base.cpp +162 -0
- package/cpp/llama.cpp/src/models/rwkv6.cpp +94 -0
- package/cpp/llama.cpp/src/models/rwkv6qwen2.cpp +86 -0
- package/cpp/llama.cpp/src/models/rwkv7-base.cpp +135 -0
- package/cpp/llama.cpp/src/models/rwkv7.cpp +90 -0
- package/cpp/llama.cpp/src/models/seed-oss.cpp +124 -0
- package/cpp/llama.cpp/src/models/smallthinker.cpp +120 -0
- package/cpp/llama.cpp/src/models/smollm3.cpp +128 -0
- package/cpp/llama.cpp/src/models/stablelm.cpp +146 -0
- package/cpp/llama.cpp/src/models/starcoder.cpp +100 -0
- package/cpp/llama.cpp/src/models/starcoder2.cpp +121 -0
- package/cpp/llama.cpp/src/models/t5-dec.cpp +166 -0
- package/cpp/llama.cpp/src/models/t5-enc.cpp +96 -0
- package/cpp/llama.cpp/src/models/wavtokenizer-dec.cpp +149 -0
- package/cpp/llama.cpp/src/models/xverse.cpp +108 -0
- package/cpp/llama.cpp/src/unicode.cpp +77 -0
- package/cpp/llama.cpp/src/unicode.h +43 -0
- package/cpp/llama.cpp/vendor/cpp-httplib/CMakeLists.txt +94 -0
- package/cpp/llama.cpp/vendor/cpp-httplib/httplib.cpp +9339 -0
- package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +433 -8222
- package/cpp/llama.cpp/vendor/cpp-httplib/patch-boringssl.cmake +6 -0
- package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +4179 -1900
- package/cpp/llama.cpp/vendor/minja/chat-template.hpp +9 -2
- package/cpp/llama.cpp/vendor/minja/minja.hpp +101 -22
- package/cpp/rn-completion.cpp +3 -27
- package/ios/include/chat.h +16 -3
- package/ios/include/common/minja/chat-template.hpp +9 -2
- package/ios/include/common/minja/minja.hpp +101 -22
- package/ios/include/common.h +57 -19
- package/ios/include/json-schema-to-grammar.h +2 -0
- package/ios/include/llama.h +44 -21
- package/ios/include/log.h +12 -4
- package/ios/include/sampling.h +3 -1
- package/ios/libs/llama.xcframework/Info.plist +20 -20
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +6399 -5557
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +19 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-metal.h +1 -6
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +199 -6
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +44 -21
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +6362 -5520
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4813 -4241
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +19 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-metal.h +1 -6
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +199 -6
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +44 -21
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +10 -4
- package/cpp/llama.cpp/ggml/src/ggml-cann/Doxyfile +0 -2579
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -371
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -379
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -495
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -486
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +0 -6886
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -154
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
- package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +0 -97
- package/cpp/llama.cpp/models/ggml-vocab-aquila.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-baichuan.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-gpt-neox.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-refact.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +0 -46
- package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf +0 -0
- package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +0 -112
- package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +0 -46
- package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +0 -171
- package/cpp/llama.cpp/models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja +0 -202
- package/cpp/llama.cpp/models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja +0 -156
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +0 -124
- package/cpp/llama.cpp/models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja +0 -152
- package/cpp/llama.cpp/models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja +0 -152
- package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +0 -62
- package/cpp/llama.cpp/models/templates/Qwen-Qwen2.5-7B-Instruct.jinja +0 -54
- package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +0 -85
- package/cpp/llama.cpp/models/templates/README.md +0 -25
- package/cpp/llama.cpp/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja +0 -1
- package/cpp/llama.cpp/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja +0 -1
- package/cpp/llama.cpp/models/templates/fireworks-ai-llama-3-firefunction-v2.jinja +0 -57
- package/cpp/llama.cpp/models/templates/google-gemma-2-2b-it.jinja +0 -4
- package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +0 -59
- package/cpp/llama.cpp/models/templates/llama-cpp-deepseek-r1.jinja +0 -76
- package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +0 -34
- package/cpp/llama.cpp/models/templates/meetkai-functionary-medium-v3.1.jinja +0 -58
- package/cpp/llama.cpp/models/templates/meetkai-functionary-medium-v3.2.jinja +0 -287
- package/cpp/llama.cpp/models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja +0 -109
- package/cpp/llama.cpp/models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja +0 -93
- package/cpp/llama.cpp/models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja +0 -109
- package/cpp/llama.cpp/models/templates/microsoft-Phi-3.5-mini-instruct.jinja +0 -8
- package/cpp/llama.cpp/models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja +0 -87
- package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +0 -43
- package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +0 -331
- package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +0 -105
- package/cpp/llama.cpp/prompts/LLM-questions.txt +0 -49
- package/cpp/llama.cpp/prompts/alpaca.txt +0 -1
- package/cpp/llama.cpp/prompts/assistant.txt +0 -31
- package/cpp/llama.cpp/prompts/chat-with-baichuan.txt +0 -4
- package/cpp/llama.cpp/prompts/chat-with-bob.txt +0 -7
- package/cpp/llama.cpp/prompts/chat-with-qwen.txt +0 -1
- package/cpp/llama.cpp/prompts/chat-with-vicuna-v0.txt +0 -7
- package/cpp/llama.cpp/prompts/chat-with-vicuna-v1.txt +0 -7
- package/cpp/llama.cpp/prompts/chat.txt +0 -28
- package/cpp/llama.cpp/prompts/dan-modified.txt +0 -1
- package/cpp/llama.cpp/prompts/dan.txt +0 -1
- package/cpp/llama.cpp/prompts/mnemonics.txt +0 -93
- package/cpp/llama.cpp/prompts/parallel-questions.txt +0 -43
- package/cpp/llama.cpp/prompts/reason-act.txt +0 -18
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5524
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +0 -4247
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-alloc.h +0 -76
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +0 -354
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-blas.h +0 -25
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +0 -145
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-metal.h +0 -66
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +0 -256
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +0 -2492
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/gguf.h +0 -202
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +0 -1391
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Modules/module.modulemap +0 -17
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Resources/Info.plist +0 -32
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-alloc.h +0 -76
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +0 -354
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-blas.h +0 -25
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +0 -145
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-metal.h +0 -66
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +0 -256
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +0 -2492
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/gguf.h +0 -202
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +0 -1391
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Modules/module.modulemap +0 -17
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Resources/Info.plist +0 -32
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-alloc.h +0 -76
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +0 -354
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-blas.h +0 -25
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +0 -145
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-metal.h +0 -66
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +0 -256
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +0 -2492
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/gguf.h +0 -202
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +0 -1391
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Modules/module.modulemap +0 -17
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Resources/Info.plist +0 -32
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5561
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-alloc.h +0 -76
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +0 -354
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-blas.h +0 -25
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +0 -145
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-metal.h +0 -66
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +0 -256
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +0 -2492
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/gguf.h +0 -202
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +0 -1391
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Info.plist +0 -35
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Modules/module.modulemap +0 -17
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5524
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +0 -4246
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-alloc.h +0 -76
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +0 -354
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-blas.h +0 -25
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +0 -145
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-metal.h +0 -66
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +0 -256
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +0 -2492
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/gguf.h +0 -202
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -1391
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Info.plist +0 -35
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Modules/module.modulemap +0 -17
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5558
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-alloc.h +0 -76
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +0 -354
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-blas.h +0 -25
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +0 -145
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-metal.h +0 -66
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +0 -256
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +0 -2492
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/gguf.h +0 -202
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +0 -1391
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Info.plist +0 -32
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Modules/module.modulemap +0 -17
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5520
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +0 -4243
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-alloc.h +0 -76
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +0 -354
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-blas.h +0 -25
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +0 -145
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-metal.h +0 -66
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +0 -256
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +0 -2492
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/gguf.h +0 -202
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -1391
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Info.plist +0 -32
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Modules/module.modulemap +0 -17
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
|
@@ -7,8 +7,10 @@
|
|
|
7
7
|
#include "unary-ops.h"
|
|
8
8
|
#include "vec.h"
|
|
9
9
|
|
|
10
|
-
#include <
|
|
10
|
+
#include <cfloat>
|
|
11
11
|
#include <algorithm>
|
|
12
|
+
#include <cmath>
|
|
13
|
+
#include <functional>
|
|
12
14
|
|
|
13
15
|
// ggml_compute_forward_dup
|
|
14
16
|
|
|
@@ -41,628 +43,15 @@ static void ggml_compute_forward_dup_same_cont(
|
|
|
41
43
|
}
|
|
42
44
|
}
|
|
43
45
|
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
ggml_tensor * dst) {
|
|
47
|
-
|
|
48
|
-
const ggml_tensor * src0 = dst->src[0];
|
|
49
|
-
|
|
50
|
-
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
|
51
|
-
|
|
52
|
-
GGML_TENSOR_UNARY_OP_LOCALS
|
|
53
|
-
|
|
54
|
-
const int ith = params->ith; // thread index
|
|
55
|
-
const int nth = params->nth; // number of threads
|
|
56
|
-
|
|
57
|
-
// parallelize by rows
|
|
58
|
-
const int nr = ne01;
|
|
59
|
-
// number of rows per thread
|
|
60
|
-
const int dr = (nr + nth - 1) / nth;
|
|
61
|
-
// row range for this thread
|
|
62
|
-
const int ir0 = dr * ith;
|
|
63
|
-
const int ir1 = MIN(ir0 + dr, nr);
|
|
64
|
-
|
|
65
|
-
if (src0->type == dst->type &&
|
|
66
|
-
ne00 == ne0 &&
|
|
67
|
-
nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
|
|
68
|
-
// copy by rows
|
|
69
|
-
const size_t rs = ne00*nb00;
|
|
70
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
71
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
72
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
73
|
-
memcpy(
|
|
74
|
-
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
|
|
75
|
-
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
|
|
76
|
-
rs);
|
|
77
|
-
}
|
|
78
|
-
}
|
|
79
|
-
}
|
|
80
|
-
return;
|
|
81
|
-
}
|
|
82
|
-
|
|
83
|
-
// TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
|
|
84
|
-
|
|
85
|
-
if (ggml_is_contiguous(dst)) {
|
|
86
|
-
if (nb00 == sizeof(ggml_fp16_t)) {
|
|
87
|
-
if (dst->type == GGML_TYPE_F16) {
|
|
88
|
-
size_t id = 0;
|
|
89
|
-
const size_t rs = ne00 * nb00;
|
|
90
|
-
char * dst_ptr = (char *) dst->data;
|
|
91
|
-
|
|
92
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
93
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
94
|
-
id += rs * ir0;
|
|
95
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
96
|
-
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
|
97
|
-
memcpy(dst_ptr + id, src0_ptr, rs);
|
|
98
|
-
id += rs;
|
|
99
|
-
}
|
|
100
|
-
id += rs * (ne01 - ir1);
|
|
101
|
-
}
|
|
102
|
-
}
|
|
103
|
-
} else if (dst->type == GGML_TYPE_F32) {
|
|
104
|
-
size_t id = 0;
|
|
105
|
-
float * dst_ptr = (float *) dst->data;
|
|
106
|
-
|
|
107
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
108
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
109
|
-
id += ne00 * ir0;
|
|
110
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
111
|
-
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
112
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
113
|
-
dst_ptr[id] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
|
|
114
|
-
id++;
|
|
115
|
-
}
|
|
116
|
-
}
|
|
117
|
-
id += ne00 * (ne01 - ir1);
|
|
118
|
-
}
|
|
119
|
-
}
|
|
120
|
-
} else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
|
|
121
|
-
ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
|
|
122
|
-
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
|
123
|
-
|
|
124
|
-
size_t id = 0;
|
|
125
|
-
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
|
|
126
|
-
char * dst_ptr = (char *) dst->data;
|
|
127
|
-
|
|
128
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
129
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
130
|
-
id += rs * ir0;
|
|
131
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
132
|
-
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
133
|
-
|
|
134
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
135
|
-
src0_f32[i00] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
|
|
136
|
-
}
|
|
137
|
-
|
|
138
|
-
quantize_row_q(src0_f32, dst_ptr + id, ne00);
|
|
139
|
-
id += rs;
|
|
140
|
-
}
|
|
141
|
-
id += rs * (ne01 - ir1);
|
|
142
|
-
}
|
|
143
|
-
}
|
|
144
|
-
} else {
|
|
145
|
-
GGML_ABORT("fatal error"); // TODO: implement
|
|
146
|
-
}
|
|
147
|
-
} else {
|
|
148
|
-
//printf("%s: this is not optimal - fix me\n", __func__);
|
|
149
|
-
|
|
150
|
-
if (dst->type == GGML_TYPE_F32) {
|
|
151
|
-
size_t id = 0;
|
|
152
|
-
float * dst_ptr = (float *) dst->data;
|
|
153
|
-
|
|
154
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
155
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
156
|
-
id += ne00 * ir0;
|
|
157
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
158
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
159
|
-
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
160
|
-
|
|
161
|
-
dst_ptr[id] = GGML_CPU_FP16_TO_FP32(*src0_ptr);
|
|
162
|
-
id++;
|
|
163
|
-
}
|
|
164
|
-
}
|
|
165
|
-
id += ne00 * (ne01 - ir1);
|
|
166
|
-
}
|
|
167
|
-
}
|
|
168
|
-
} else if (dst->type == GGML_TYPE_F16) {
|
|
169
|
-
size_t id = 0;
|
|
170
|
-
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
|
171
|
-
|
|
172
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
173
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
174
|
-
id += ne00 * ir0;
|
|
175
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
176
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
177
|
-
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
178
|
-
|
|
179
|
-
dst_ptr[id] = *src0_ptr;
|
|
180
|
-
id++;
|
|
181
|
-
}
|
|
182
|
-
}
|
|
183
|
-
id += ne00 * (ne01 - ir1);
|
|
184
|
-
}
|
|
185
|
-
}
|
|
186
|
-
} else {
|
|
187
|
-
GGML_ABORT("fatal error"); // TODO: implement
|
|
188
|
-
}
|
|
189
|
-
}
|
|
190
|
-
return;
|
|
191
|
-
}
|
|
192
|
-
|
|
193
|
-
// dst counters
|
|
194
|
-
int64_t i10 = 0;
|
|
195
|
-
int64_t i11 = 0;
|
|
196
|
-
int64_t i12 = 0;
|
|
197
|
-
int64_t i13 = 0;
|
|
198
|
-
|
|
199
|
-
if (dst->type == GGML_TYPE_F16) {
|
|
200
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
201
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
202
|
-
i10 += ne00 * ir0;
|
|
203
|
-
while (i10 >= ne0) {
|
|
204
|
-
i10 -= ne0;
|
|
205
|
-
if (++i11 == ne1) {
|
|
206
|
-
i11 = 0;
|
|
207
|
-
if (++i12 == ne2) {
|
|
208
|
-
i12 = 0;
|
|
209
|
-
if (++i13 == ne3) {
|
|
210
|
-
i13 = 0;
|
|
211
|
-
}
|
|
212
|
-
}
|
|
213
|
-
}
|
|
214
|
-
}
|
|
215
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
216
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
217
|
-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
218
|
-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
219
|
-
|
|
220
|
-
memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
|
|
221
|
-
|
|
222
|
-
if (++i10 == ne00) {
|
|
223
|
-
i10 = 0;
|
|
224
|
-
if (++i11 == ne01) {
|
|
225
|
-
i11 = 0;
|
|
226
|
-
if (++i12 == ne02) {
|
|
227
|
-
i12 = 0;
|
|
228
|
-
if (++i13 == ne03) {
|
|
229
|
-
i13 = 0;
|
|
230
|
-
}
|
|
231
|
-
}
|
|
232
|
-
}
|
|
233
|
-
}
|
|
234
|
-
}
|
|
235
|
-
}
|
|
236
|
-
i10 += ne00 * (ne01 - ir1);
|
|
237
|
-
while (i10 >= ne0) {
|
|
238
|
-
i10 -= ne0;
|
|
239
|
-
if (++i11 == ne1) {
|
|
240
|
-
i11 = 0;
|
|
241
|
-
if (++i12 == ne2) {
|
|
242
|
-
i12 = 0;
|
|
243
|
-
if (++i13 == ne3) {
|
|
244
|
-
i13 = 0;
|
|
245
|
-
}
|
|
246
|
-
}
|
|
247
|
-
}
|
|
248
|
-
}
|
|
249
|
-
}
|
|
250
|
-
}
|
|
251
|
-
} else if (dst->type == GGML_TYPE_F32) {
|
|
252
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
253
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
254
|
-
i10 += ne00 * ir0;
|
|
255
|
-
while (i10 >= ne0) {
|
|
256
|
-
i10 -= ne0;
|
|
257
|
-
if (++i11 == ne1) {
|
|
258
|
-
i11 = 0;
|
|
259
|
-
if (++i12 == ne2) {
|
|
260
|
-
i12 = 0;
|
|
261
|
-
if (++i13 == ne3) {
|
|
262
|
-
i13 = 0;
|
|
263
|
-
}
|
|
264
|
-
}
|
|
265
|
-
}
|
|
266
|
-
}
|
|
267
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
268
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
269
|
-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
270
|
-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
271
|
-
|
|
272
|
-
*(float *) dst_ptr = GGML_CPU_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
|
|
273
|
-
|
|
274
|
-
if (++i10 == ne0) {
|
|
275
|
-
i10 = 0;
|
|
276
|
-
if (++i11 == ne1) {
|
|
277
|
-
i11 = 0;
|
|
278
|
-
if (++i12 == ne2) {
|
|
279
|
-
i12 = 0;
|
|
280
|
-
if (++i13 == ne3) {
|
|
281
|
-
i13 = 0;
|
|
282
|
-
}
|
|
283
|
-
}
|
|
284
|
-
}
|
|
285
|
-
}
|
|
286
|
-
}
|
|
287
|
-
}
|
|
288
|
-
i10 += ne00 * (ne01 - ir1);
|
|
289
|
-
while (i10 >= ne0) {
|
|
290
|
-
i10 -= ne0;
|
|
291
|
-
if (++i11 == ne1) {
|
|
292
|
-
i11 = 0;
|
|
293
|
-
if (++i12 == ne2) {
|
|
294
|
-
i12 = 0;
|
|
295
|
-
if (++i13 == ne3) {
|
|
296
|
-
i13 = 0;
|
|
297
|
-
}
|
|
298
|
-
}
|
|
299
|
-
}
|
|
300
|
-
}
|
|
301
|
-
}
|
|
302
|
-
}
|
|
303
|
-
} else {
|
|
304
|
-
GGML_ABORT("fatal error"); // TODO: implement
|
|
305
|
-
}
|
|
306
|
-
}
|
|
307
|
-
|
|
308
|
-
static void ggml_compute_forward_dup_bf16(
|
|
309
|
-
const ggml_compute_params * params,
|
|
310
|
-
ggml_tensor * dst) {
|
|
311
|
-
|
|
312
|
-
const ggml_tensor * src0 = dst->src[0];
|
|
313
|
-
|
|
314
|
-
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
|
315
|
-
|
|
316
|
-
GGML_TENSOR_UNARY_OP_LOCALS
|
|
317
|
-
|
|
318
|
-
const int ith = params->ith; // thread index
|
|
319
|
-
const int nth = params->nth; // number of threads
|
|
320
|
-
|
|
321
|
-
// parallelize by rows
|
|
322
|
-
const int nr = ne01;
|
|
323
|
-
// number of rows per thread
|
|
324
|
-
const int dr = (nr + nth - 1) / nth;
|
|
325
|
-
// row range for this thread
|
|
326
|
-
const int ir0 = dr * ith;
|
|
327
|
-
const int ir1 = MIN(ir0 + dr, nr);
|
|
328
|
-
|
|
329
|
-
if (src0->type == dst->type &&
|
|
330
|
-
ne00 == ne0 &&
|
|
331
|
-
nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
|
|
332
|
-
// copy by rows
|
|
333
|
-
const size_t rs = ne00*nb00;
|
|
334
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
335
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
336
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
337
|
-
memcpy(
|
|
338
|
-
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
|
|
339
|
-
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
|
|
340
|
-
rs);
|
|
341
|
-
}
|
|
342
|
-
}
|
|
343
|
-
}
|
|
344
|
-
return;
|
|
345
|
-
}
|
|
346
|
-
|
|
347
|
-
// TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
|
|
348
|
-
|
|
349
|
-
if (ggml_is_contiguous(dst)) {
|
|
350
|
-
if (nb00 == sizeof(ggml_bf16_t)) {
|
|
351
|
-
if (dst->type == GGML_TYPE_BF16) {
|
|
352
|
-
size_t id = 0;
|
|
353
|
-
const size_t rs = ne00 * nb00;
|
|
354
|
-
char * dst_ptr = (char *) dst->data;
|
|
355
|
-
|
|
356
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
357
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
358
|
-
id += rs * ir0;
|
|
359
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
360
|
-
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
|
361
|
-
memcpy(dst_ptr + id, src0_ptr, rs);
|
|
362
|
-
id += rs;
|
|
363
|
-
}
|
|
364
|
-
id += rs * (ne01 - ir1);
|
|
365
|
-
}
|
|
366
|
-
}
|
|
367
|
-
} else if (dst->type == GGML_TYPE_F16) {
|
|
368
|
-
size_t id = 0;
|
|
369
|
-
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
|
370
|
-
|
|
371
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
372
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
373
|
-
id += ne00 * ir0;
|
|
374
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
375
|
-
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
376
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
377
|
-
dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
|
|
378
|
-
id++;
|
|
379
|
-
}
|
|
380
|
-
}
|
|
381
|
-
id += ne00 * (ne01 - ir1);
|
|
382
|
-
}
|
|
383
|
-
}
|
|
384
|
-
} else if (dst->type == GGML_TYPE_F32) {
|
|
385
|
-
size_t id = 0;
|
|
386
|
-
float * dst_ptr = (float *) dst->data;
|
|
387
|
-
|
|
388
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
389
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
390
|
-
id += ne00 * ir0;
|
|
391
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
392
|
-
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
393
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
394
|
-
dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
|
|
395
|
-
id++;
|
|
396
|
-
}
|
|
397
|
-
}
|
|
398
|
-
id += ne00 * (ne01 - ir1);
|
|
399
|
-
}
|
|
400
|
-
}
|
|
401
|
-
} else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
|
|
402
|
-
ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
|
|
403
|
-
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
|
404
|
-
|
|
405
|
-
size_t id = 0;
|
|
406
|
-
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
|
|
407
|
-
char * dst_ptr = (char *) dst->data;
|
|
408
|
-
|
|
409
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
410
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
411
|
-
id += rs * ir0;
|
|
412
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
413
|
-
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
414
|
-
|
|
415
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
416
|
-
src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
|
|
417
|
-
}
|
|
418
|
-
|
|
419
|
-
quantize_row_q(src0_f32, dst_ptr + id, ne00);
|
|
420
|
-
id += rs;
|
|
421
|
-
}
|
|
422
|
-
id += rs * (ne01 - ir1);
|
|
423
|
-
}
|
|
424
|
-
}
|
|
425
|
-
} else {
|
|
426
|
-
GGML_ABORT("fatal error"); // TODO: implement
|
|
427
|
-
}
|
|
428
|
-
} else {
|
|
429
|
-
//printf("%s: this is not optimal - fix me\n", __func__);
|
|
430
|
-
|
|
431
|
-
if (dst->type == GGML_TYPE_F32) {
|
|
432
|
-
size_t id = 0;
|
|
433
|
-
float * dst_ptr = (float *) dst->data;
|
|
434
|
-
|
|
435
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
436
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
437
|
-
id += ne00 * ir0;
|
|
438
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
439
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
440
|
-
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
441
|
-
|
|
442
|
-
dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
|
|
443
|
-
id++;
|
|
444
|
-
}
|
|
445
|
-
}
|
|
446
|
-
id += ne00 * (ne01 - ir1);
|
|
447
|
-
}
|
|
448
|
-
}
|
|
449
|
-
} else if (dst->type == GGML_TYPE_BF16) {
|
|
450
|
-
size_t id = 0;
|
|
451
|
-
ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
|
|
452
|
-
|
|
453
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
454
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
455
|
-
id += ne00 * ir0;
|
|
456
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
457
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
458
|
-
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
459
|
-
|
|
460
|
-
dst_ptr[id] = *src0_ptr;
|
|
461
|
-
id++;
|
|
462
|
-
}
|
|
463
|
-
}
|
|
464
|
-
id += ne00 * (ne01 - ir1);
|
|
465
|
-
}
|
|
466
|
-
}
|
|
467
|
-
} else if (dst->type == GGML_TYPE_F16) {
|
|
468
|
-
size_t id = 0;
|
|
469
|
-
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
|
470
|
-
|
|
471
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
472
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
473
|
-
id += ne00 * ir0;
|
|
474
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
475
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
476
|
-
const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
477
|
-
|
|
478
|
-
dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
|
|
479
|
-
id++;
|
|
480
|
-
}
|
|
481
|
-
}
|
|
482
|
-
id += ne00 * (ne01 - ir1);
|
|
483
|
-
}
|
|
484
|
-
}
|
|
485
|
-
} else {
|
|
486
|
-
GGML_ABORT("fatal error"); // TODO: implement
|
|
487
|
-
}
|
|
488
|
-
}
|
|
489
|
-
return;
|
|
490
|
-
}
|
|
491
|
-
|
|
492
|
-
// dst counters
|
|
493
|
-
int64_t i10 = 0;
|
|
494
|
-
int64_t i11 = 0;
|
|
495
|
-
int64_t i12 = 0;
|
|
496
|
-
int64_t i13 = 0;
|
|
497
|
-
|
|
498
|
-
if (dst->type == GGML_TYPE_BF16) {
|
|
499
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
500
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
501
|
-
i10 += ne00 * ir0;
|
|
502
|
-
while (i10 >= ne0) {
|
|
503
|
-
i10 -= ne0;
|
|
504
|
-
if (++i11 == ne1) {
|
|
505
|
-
i11 = 0;
|
|
506
|
-
if (++i12 == ne2) {
|
|
507
|
-
i12 = 0;
|
|
508
|
-
if (++i13 == ne3) {
|
|
509
|
-
i13 = 0;
|
|
510
|
-
}
|
|
511
|
-
}
|
|
512
|
-
}
|
|
513
|
-
}
|
|
514
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
515
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
516
|
-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
517
|
-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
518
|
-
|
|
519
|
-
memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
|
|
520
|
-
|
|
521
|
-
if (++i10 == ne00) {
|
|
522
|
-
i10 = 0;
|
|
523
|
-
if (++i11 == ne01) {
|
|
524
|
-
i11 = 0;
|
|
525
|
-
if (++i12 == ne02) {
|
|
526
|
-
i12 = 0;
|
|
527
|
-
if (++i13 == ne03) {
|
|
528
|
-
i13 = 0;
|
|
529
|
-
}
|
|
530
|
-
}
|
|
531
|
-
}
|
|
532
|
-
}
|
|
533
|
-
}
|
|
534
|
-
}
|
|
535
|
-
i10 += ne00 * (ne01 - ir1);
|
|
536
|
-
while (i10 >= ne0) {
|
|
537
|
-
i10 -= ne0;
|
|
538
|
-
if (++i11 == ne1) {
|
|
539
|
-
i11 = 0;
|
|
540
|
-
if (++i12 == ne2) {
|
|
541
|
-
i12 = 0;
|
|
542
|
-
if (++i13 == ne3) {
|
|
543
|
-
i13 = 0;
|
|
544
|
-
}
|
|
545
|
-
}
|
|
546
|
-
}
|
|
547
|
-
}
|
|
548
|
-
}
|
|
549
|
-
}
|
|
550
|
-
} else if (dst->type == GGML_TYPE_F16) {
|
|
551
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
552
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
553
|
-
i10 += ne00 * ir0;
|
|
554
|
-
while (i10 >= ne0) {
|
|
555
|
-
i10 -= ne0;
|
|
556
|
-
if (++i11 == ne1) {
|
|
557
|
-
i11 = 0;
|
|
558
|
-
if (++i12 == ne2) {
|
|
559
|
-
i12 = 0;
|
|
560
|
-
if (++i13 == ne3) {
|
|
561
|
-
i13 = 0;
|
|
562
|
-
}
|
|
563
|
-
}
|
|
564
|
-
}
|
|
565
|
-
}
|
|
566
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
567
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
568
|
-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
569
|
-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
570
|
-
|
|
571
|
-
*(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
|
|
572
|
-
|
|
573
|
-
if (++i10 == ne0) {
|
|
574
|
-
i10 = 0;
|
|
575
|
-
if (++i11 == ne1) {
|
|
576
|
-
i11 = 0;
|
|
577
|
-
if (++i12 == ne2) {
|
|
578
|
-
i12 = 0;
|
|
579
|
-
if (++i13 == ne3) {
|
|
580
|
-
i13 = 0;
|
|
581
|
-
}
|
|
582
|
-
}
|
|
583
|
-
}
|
|
584
|
-
}
|
|
585
|
-
}
|
|
586
|
-
}
|
|
587
|
-
i10 += ne00 * (ne01 - ir1);
|
|
588
|
-
while (i10 >= ne0) {
|
|
589
|
-
i10 -= ne0;
|
|
590
|
-
if (++i11 == ne1) {
|
|
591
|
-
i11 = 0;
|
|
592
|
-
if (++i12 == ne2) {
|
|
593
|
-
i12 = 0;
|
|
594
|
-
if (++i13 == ne3) {
|
|
595
|
-
i13 = 0;
|
|
596
|
-
}
|
|
597
|
-
}
|
|
598
|
-
}
|
|
599
|
-
}
|
|
600
|
-
}
|
|
601
|
-
}
|
|
602
|
-
} else if (dst->type == GGML_TYPE_F32) {
|
|
603
|
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
604
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
605
|
-
i10 += ne00 * ir0;
|
|
606
|
-
while (i10 >= ne0) {
|
|
607
|
-
i10 -= ne0;
|
|
608
|
-
if (++i11 == ne1) {
|
|
609
|
-
i11 = 0;
|
|
610
|
-
if (++i12 == ne2) {
|
|
611
|
-
i12 = 0;
|
|
612
|
-
if (++i13 == ne3) {
|
|
613
|
-
i13 = 0;
|
|
614
|
-
}
|
|
615
|
-
}
|
|
616
|
-
}
|
|
617
|
-
}
|
|
618
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
619
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
620
|
-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
621
|
-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
622
|
-
|
|
623
|
-
*(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
|
|
624
|
-
|
|
625
|
-
if (++i10 == ne0) {
|
|
626
|
-
i10 = 0;
|
|
627
|
-
if (++i11 == ne1) {
|
|
628
|
-
i11 = 0;
|
|
629
|
-
if (++i12 == ne2) {
|
|
630
|
-
i12 = 0;
|
|
631
|
-
if (++i13 == ne3) {
|
|
632
|
-
i13 = 0;
|
|
633
|
-
}
|
|
634
|
-
}
|
|
635
|
-
}
|
|
636
|
-
}
|
|
637
|
-
}
|
|
638
|
-
}
|
|
639
|
-
i10 += ne00 * (ne01 - ir1);
|
|
640
|
-
while (i10 >= ne0) {
|
|
641
|
-
i10 -= ne0;
|
|
642
|
-
if (++i11 == ne1) {
|
|
643
|
-
i11 = 0;
|
|
644
|
-
if (++i12 == ne2) {
|
|
645
|
-
i12 = 0;
|
|
646
|
-
if (++i13 == ne3) {
|
|
647
|
-
i13 = 0;
|
|
648
|
-
}
|
|
649
|
-
}
|
|
650
|
-
}
|
|
651
|
-
}
|
|
652
|
-
}
|
|
653
|
-
}
|
|
654
|
-
} else {
|
|
655
|
-
GGML_ABORT("fatal error"); // TODO: implement
|
|
656
|
-
}
|
|
657
|
-
}
|
|
658
|
-
|
|
659
|
-
static void ggml_compute_forward_dup_f32(
|
|
46
|
+
template<typename src_t, typename dst_t>
|
|
47
|
+
static void ggml_compute_forward_dup_flt(
|
|
660
48
|
const ggml_compute_params * params,
|
|
661
49
|
ggml_tensor * dst) {
|
|
662
50
|
|
|
663
51
|
const ggml_tensor * src0 = dst->src[0];
|
|
664
52
|
|
|
665
53
|
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
|
54
|
+
GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type));
|
|
666
55
|
|
|
667
56
|
GGML_TENSOR_UNARY_OP_LOCALS
|
|
668
57
|
|
|
@@ -677,6 +66,7 @@ static void ggml_compute_forward_dup_f32(
|
|
|
677
66
|
const int ir0 = dr * ith;
|
|
678
67
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
679
68
|
|
|
69
|
+
// case: type & row size equal
|
|
680
70
|
if (src0->type == dst->type &&
|
|
681
71
|
ne00 == ne0 &&
|
|
682
72
|
nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
|
|
@@ -695,103 +85,78 @@ static void ggml_compute_forward_dup_f32(
|
|
|
695
85
|
return;
|
|
696
86
|
}
|
|
697
87
|
|
|
88
|
+
// case: dst tensor is contiguous
|
|
698
89
|
if (ggml_is_contiguous(dst)) {
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
|
|
703
|
-
|
|
90
|
+
if (nb00 == sizeof(src_t)) {
|
|
91
|
+
if constexpr (std::is_same_v<dst_t, src_t>) {
|
|
92
|
+
// same type
|
|
704
93
|
size_t id = 0;
|
|
705
|
-
size_t rs =
|
|
94
|
+
const size_t rs = ne00 * nb00;
|
|
706
95
|
char * dst_ptr = (char *) dst->data;
|
|
707
96
|
|
|
708
97
|
for (int i03 = 0; i03 < ne03; i03++) {
|
|
709
98
|
for (int i02 = 0; i02 < ne02; i02++) {
|
|
710
99
|
id += rs * ir0;
|
|
711
100
|
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
712
|
-
const
|
|
713
|
-
|
|
101
|
+
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
|
102
|
+
memcpy(dst_ptr + id, src0_ptr, rs);
|
|
714
103
|
id += rs;
|
|
715
104
|
}
|
|
716
105
|
id += rs * (ne01 - ir1);
|
|
717
106
|
}
|
|
718
107
|
}
|
|
719
108
|
} else {
|
|
720
|
-
|
|
721
|
-
}
|
|
722
|
-
} else {
|
|
723
|
-
//printf("%s: this is not optimal - fix me\n", __func__);
|
|
724
|
-
|
|
725
|
-
if (dst->type == GGML_TYPE_F32) {
|
|
109
|
+
// casting between non-quantized types
|
|
726
110
|
size_t id = 0;
|
|
727
|
-
|
|
111
|
+
dst_t * dst_ptr = (dst_t *) dst->data;
|
|
728
112
|
|
|
729
113
|
for (int i03 = 0; i03 < ne03; i03++) {
|
|
730
114
|
for (int i02 = 0; i02 < ne02; i02++) {
|
|
731
115
|
id += ne00 * ir0;
|
|
732
116
|
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
117
|
+
const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
733
118
|
for (int i00 = 0; i00 < ne00; i00++) {
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
dst_ptr[id] = *src0_ptr;
|
|
119
|
+
float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
|
|
120
|
+
dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
|
|
737
121
|
id++;
|
|
738
122
|
}
|
|
739
123
|
}
|
|
740
124
|
id += ne00 * (ne01 - ir1);
|
|
741
125
|
}
|
|
742
126
|
}
|
|
743
|
-
}
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
for (int i03 = 0; i03 < ne03; i03++) {
|
|
748
|
-
for (int i02 = 0; i02 < ne02; i02++) {
|
|
749
|
-
id += ne00 * ir0;
|
|
750
|
-
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
751
|
-
for (int i00 = 0; i00 < ne00; i00++) {
|
|
752
|
-
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
127
|
+
}
|
|
128
|
+
} else {
|
|
129
|
+
//printf("%s: this is not optimal - fix me\n", __func__);
|
|
753
130
|
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
}
|
|
757
|
-
}
|
|
758
|
-
id += ne00 * (ne01 - ir1);
|
|
759
|
-
}
|
|
760
|
-
}
|
|
761
|
-
} else if (dst->type == GGML_TYPE_BF16) {
|
|
762
|
-
size_t id = 0;
|
|
763
|
-
ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
|
|
131
|
+
size_t id = 0;
|
|
132
|
+
dst_t * dst_ptr = (dst_t *) dst->data;
|
|
764
133
|
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
134
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
|
135
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
|
136
|
+
id += ne00 * ir0;
|
|
137
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
138
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
|
139
|
+
const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
771
140
|
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
141
|
+
float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
|
|
142
|
+
dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
|
|
143
|
+
id++;
|
|
775
144
|
}
|
|
776
|
-
id += ne00 * (ne01 - ir1);
|
|
777
145
|
}
|
|
146
|
+
id += ne00 * (ne01 - ir1);
|
|
778
147
|
}
|
|
779
|
-
} else {
|
|
780
|
-
GGML_ABORT("fatal error"); // TODO: implement
|
|
781
148
|
}
|
|
782
149
|
}
|
|
783
|
-
|
|
784
150
|
return;
|
|
785
151
|
}
|
|
786
152
|
|
|
787
153
|
// dst counters
|
|
788
|
-
|
|
789
154
|
int64_t i10 = 0;
|
|
790
155
|
int64_t i11 = 0;
|
|
791
156
|
int64_t i12 = 0;
|
|
792
157
|
int64_t i13 = 0;
|
|
793
158
|
|
|
794
|
-
if (
|
|
159
|
+
if constexpr (std::is_same_v<dst_t, src_t>) {
|
|
795
160
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
796
161
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
797
162
|
i10 += ne00 * ir0;
|
|
@@ -812,15 +177,15 @@ static void ggml_compute_forward_dup_f32(
|
|
|
812
177
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
813
178
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
814
179
|
|
|
815
|
-
memcpy(dst_ptr, src0_ptr, sizeof(
|
|
180
|
+
memcpy(dst_ptr, src0_ptr, sizeof(dst_t));
|
|
816
181
|
|
|
817
|
-
if (++i10 ==
|
|
182
|
+
if (++i10 == ne00) {
|
|
818
183
|
i10 = 0;
|
|
819
|
-
if (++i11 ==
|
|
184
|
+
if (++i11 == ne01) {
|
|
820
185
|
i11 = 0;
|
|
821
|
-
if (++i12 ==
|
|
186
|
+
if (++i12 == ne02) {
|
|
822
187
|
i12 = 0;
|
|
823
|
-
if (++i13 ==
|
|
188
|
+
if (++i13 == ne03) {
|
|
824
189
|
i13 = 0;
|
|
825
190
|
}
|
|
826
191
|
}
|
|
@@ -843,7 +208,8 @@ static void ggml_compute_forward_dup_f32(
|
|
|
843
208
|
}
|
|
844
209
|
}
|
|
845
210
|
}
|
|
846
|
-
|
|
211
|
+
|
|
212
|
+
} else {
|
|
847
213
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
|
848
214
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
849
215
|
i10 += ne00 * ir0;
|
|
@@ -864,7 +230,8 @@ static void ggml_compute_forward_dup_f32(
|
|
|
864
230
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
865
231
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
866
232
|
|
|
867
|
-
|
|
233
|
+
float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
|
|
234
|
+
*(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);
|
|
868
235
|
|
|
869
236
|
if (++i10 == ne0) {
|
|
870
237
|
i10 = 0;
|
|
@@ -895,60 +262,63 @@ static void ggml_compute_forward_dup_f32(
|
|
|
895
262
|
}
|
|
896
263
|
}
|
|
897
264
|
}
|
|
898
|
-
}
|
|
899
|
-
|
|
900
|
-
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
|
901
|
-
i10 += ne00 * ir0;
|
|
902
|
-
while (i10 >= ne0) {
|
|
903
|
-
i10 -= ne0;
|
|
904
|
-
if (++i11 == ne1) {
|
|
905
|
-
i11 = 0;
|
|
906
|
-
if (++i12 == ne2) {
|
|
907
|
-
i12 = 0;
|
|
908
|
-
if (++i13 == ne3) {
|
|
909
|
-
i13 = 0;
|
|
910
|
-
}
|
|
911
|
-
}
|
|
912
|
-
}
|
|
913
|
-
}
|
|
914
|
-
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
|
915
|
-
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
|
916
|
-
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
917
|
-
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
|
265
|
+
}
|
|
266
|
+
}
|
|
918
267
|
|
|
919
|
-
*(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
|
|
920
268
|
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
269
|
+
template<typename src_t>
|
|
270
|
+
static void ggml_compute_forward_dup_to_q(
|
|
271
|
+
const ggml_compute_params * params,
|
|
272
|
+
ggml_tensor * dst) {
|
|
273
|
+
|
|
274
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
275
|
+
|
|
276
|
+
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
|
277
|
+
GGML_ASSERT(!ggml_is_quantized(src0->type));
|
|
278
|
+
|
|
279
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
|
280
|
+
|
|
281
|
+
const int ith = params->ith; // thread index
|
|
282
|
+
const int nth = params->nth; // number of threads
|
|
283
|
+
|
|
284
|
+
// parallelize by rows
|
|
285
|
+
const int nr = ne01;
|
|
286
|
+
// number of rows per thread
|
|
287
|
+
const int dr = (nr + nth - 1) / nth;
|
|
288
|
+
// row range for this thread
|
|
289
|
+
const int ir0 = dr * ith;
|
|
290
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
291
|
+
|
|
292
|
+
if (ggml_is_contiguous(dst) &&
|
|
293
|
+
nb00 == sizeof(src_t) &&
|
|
294
|
+
ggml_get_type_traits_cpu(dst->type)->from_float) {
|
|
295
|
+
// casting non-quantized types --> intermediate f32 --> quantized
|
|
296
|
+
ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
|
|
297
|
+
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
|
298
|
+
|
|
299
|
+
size_t id = 0;
|
|
300
|
+
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
|
|
301
|
+
char * dst_ptr = (char *) dst->data;
|
|
302
|
+
|
|
303
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
|
304
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
|
305
|
+
id += rs * ir0;
|
|
306
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
|
307
|
+
const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
308
|
+
|
|
309
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
|
310
|
+
src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
|
|
946
311
|
}
|
|
312
|
+
|
|
313
|
+
quantize_row_q(src0_f32, dst_ptr + id, ne00);
|
|
314
|
+
id += rs;
|
|
947
315
|
}
|
|
316
|
+
id += rs * (ne01 - ir1);
|
|
948
317
|
}
|
|
949
318
|
}
|
|
950
319
|
} else {
|
|
951
|
-
|
|
320
|
+
// printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type));
|
|
321
|
+
GGML_ABORT("not implemented");
|
|
952
322
|
}
|
|
953
323
|
}
|
|
954
324
|
|
|
@@ -1102,7 +472,7 @@ static void ggml_compute_forward_dup_bytes(
|
|
|
1102
472
|
}
|
|
1103
473
|
}
|
|
1104
474
|
|
|
1105
|
-
static void
|
|
475
|
+
static void ggml_compute_forward_dup_from_q(
|
|
1106
476
|
const ggml_compute_params * params,
|
|
1107
477
|
ggml_tensor * dst) {
|
|
1108
478
|
|
|
@@ -1167,20 +537,35 @@ void ggml_compute_forward_dup(
|
|
|
1167
537
|
switch (src0->type) {
|
|
1168
538
|
case GGML_TYPE_F16:
|
|
1169
539
|
{
|
|
1170
|
-
|
|
540
|
+
/**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_fp16_t>(params, dst);
|
|
541
|
+
else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_bf16_t>(params, dst);
|
|
542
|
+
else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<ggml_fp16_t, float >(params, dst);
|
|
543
|
+
else ggml_compute_forward_dup_to_q<ggml_fp16_t>(params, dst);
|
|
1171
544
|
} break;
|
|
1172
545
|
case GGML_TYPE_BF16:
|
|
1173
546
|
{
|
|
1174
|
-
|
|
547
|
+
/**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_fp16_t>(params, dst);
|
|
548
|
+
else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(params, dst);
|
|
549
|
+
else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<ggml_bf16_t, float >(params, dst);
|
|
550
|
+
else ggml_compute_forward_dup_to_q<ggml_bf16_t>(params, dst);
|
|
1175
551
|
} break;
|
|
1176
552
|
case GGML_TYPE_F32:
|
|
1177
553
|
{
|
|
1178
|
-
|
|
554
|
+
/**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<float, ggml_fp16_t>(params, dst);
|
|
555
|
+
else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<float, ggml_bf16_t>(params, dst);
|
|
556
|
+
else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<float, float >(params, dst);
|
|
557
|
+
else if (dst->type == GGML_TYPE_I32) ggml_compute_forward_dup_flt<float, int32_t >(params, dst);
|
|
558
|
+
else ggml_compute_forward_dup_to_q<float>(params, dst);
|
|
559
|
+
} break;
|
|
560
|
+
case GGML_TYPE_I32:
|
|
561
|
+
{
|
|
562
|
+
if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
|
|
563
|
+
else GGML_ABORT("not implemented");
|
|
1179
564
|
} break;
|
|
1180
565
|
default:
|
|
1181
566
|
{
|
|
1182
567
|
if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
|
|
1183
|
-
|
|
568
|
+
ggml_compute_forward_dup_from_q(params, dst);
|
|
1184
569
|
break;
|
|
1185
570
|
}
|
|
1186
571
|
GGML_ABORT("fatal error");
|
|
@@ -2002,7 +1387,57 @@ void ggml_compute_forward_sum(
|
|
|
2002
1387
|
} break;
|
|
2003
1388
|
case GGML_TYPE_BF16:
|
|
2004
1389
|
{
|
|
2005
|
-
ggml_compute_forward_sum_bf16(params, dst);
|
|
1390
|
+
ggml_compute_forward_sum_bf16(params, dst);
|
|
1391
|
+
} break;
|
|
1392
|
+
default:
|
|
1393
|
+
{
|
|
1394
|
+
GGML_ABORT("fatal error");
|
|
1395
|
+
}
|
|
1396
|
+
}
|
|
1397
|
+
}
|
|
1398
|
+
|
|
1399
|
+
// ggml_compute_forward_cumsum
|
|
1400
|
+
|
|
1401
|
+
static void ggml_compute_forward_cumsum_f32(
|
|
1402
|
+
const ggml_compute_params * params,
|
|
1403
|
+
ggml_tensor * dst) {
|
|
1404
|
+
|
|
1405
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
1406
|
+
|
|
1407
|
+
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
|
1408
|
+
GGML_ASSERT(dst->nb[0] == sizeof(float));
|
|
1409
|
+
|
|
1410
|
+
GGML_TENSOR_UNARY_OP_LOCALS
|
|
1411
|
+
|
|
1412
|
+
GGML_ASSERT(ne0 == ne00);
|
|
1413
|
+
GGML_ASSERT(ne1 == ne01);
|
|
1414
|
+
GGML_ASSERT(ne2 == ne02);
|
|
1415
|
+
GGML_ASSERT(ne3 == ne03);
|
|
1416
|
+
|
|
1417
|
+
const auto [ir0, ir1] = get_thread_range(params, src0);
|
|
1418
|
+
|
|
1419
|
+
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
|
1420
|
+
const int64_t i03 = ir/(ne02*ne01);
|
|
1421
|
+
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
|
1422
|
+
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
|
1423
|
+
|
|
1424
|
+
float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
|
1425
|
+
float * dst_row = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
|
1426
|
+
|
|
1427
|
+
ggml_vec_cumsum_f32(ne00, dst_row, src_row);
|
|
1428
|
+
}
|
|
1429
|
+
}
|
|
1430
|
+
|
|
1431
|
+
void ggml_compute_forward_cumsum(
|
|
1432
|
+
const ggml_compute_params * params,
|
|
1433
|
+
ggml_tensor * dst) {
|
|
1434
|
+
|
|
1435
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
1436
|
+
|
|
1437
|
+
switch (src0->type) {
|
|
1438
|
+
case GGML_TYPE_F32:
|
|
1439
|
+
{
|
|
1440
|
+
ggml_compute_forward_cumsum_f32(params, dst);
|
|
2006
1441
|
} break;
|
|
2007
1442
|
default:
|
|
2008
1443
|
{
|
|
@@ -2757,6 +2192,83 @@ static void ggml_compute_forward_gelu(
 }
 }

+// ggml_compute_fill
+
+static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+const float c = ggml_get_op_params_f32(dst, 0);
+
+GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
+GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
+
+const auto [ir0, ir1] = get_thread_range(params, dst);
+
+for (int64_t ir = ir0; ir < ir1; ++ir) {
+const int64_t i03 = ir/(ne2*ne1);
+const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
+const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
+
+float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
+
+ggml_vec_set_f32(ne0, dst_ptr, c);
+}
+}
+
+void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
+ggml_compute_forward_fill_f32(params, dst);
+}
+
+// ggml_compute_tri
+
+static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+const ggml_tensor * src0 = dst->src[0];
+
+const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
+
+GGML_ASSERT(ggml_is_contiguous(src0));
+
+GGML_TENSOR_UNARY_OP_LOCALS
+
+const auto [ir0, ir1] = get_thread_range(params, src0);
+
+bool (*bipred)(int, int);
+
+switch (ttype) {
+case GGML_TRI_TYPE_LOWER: bipred = [](int i, int r) { return i < r; }; break;
+case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
+case GGML_TRI_TYPE_UPPER: bipred = [](int i, int r) { return i > r; }; break;
+case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
+default: GGML_ABORT("invalid tri type");
+}
+
+for (int64_t ir = ir0; ir < ir1; ++ir) {
+const int64_t i03 = ir/(ne02*ne01);
+const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+float * dst_ptr = ( float *) (( char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
+
+for (int i0 = 0; i0 < ne0; ++i0) {
+dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
+}
+}
+}
+
+void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
+const ggml_tensor * src0 = dst->src[0];
+
+switch (src0->type) {
+case GGML_TYPE_F32:
+{
+ggml_compute_forward_tri_f32(params, dst);
+} break;
+default:
+{
+GGML_ABORT("fatal error");
+}
+}
+}
+
 // ggml_compute_forward_gelu_erf

 static void ggml_compute_forward_gelu_erf_f32(
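The new tri kernel selects one of four predicates over (column, row) and zeroes everything the predicate rejects. A self-contained sketch of the same masking idea, using plain enums in place of the ggml types:

    #include <cstdint>

    enum tri_type { TRI_LOWER, TRI_LOWER_DIAG, TRI_UPPER, TRI_UPPER_DIAG };

    // Keep src[i0] only where the (column i0, row r) predicate holds;
    // mirrors the bipred lambdas in ggml_compute_forward_tri_f32.
    static void tri_row_f32(tri_type t, int64_t ne0, int64_t r,
                            float * dst, const float * src) {
        for (int64_t i0 = 0; i0 < ne0; ++i0) {
            bool keep = false;
            switch (t) {
                case TRI_LOWER:      keep = i0 <  r; break;
                case TRI_LOWER_DIAG: keep = i0 <= r; break;
                case TRI_UPPER:      keep = i0 >  r; break;
                case TRI_UPPER_DIAG: keep = i0 >= r; break;
            }
            dst[i0] = keep ? src[i0] : 0.0f;
        }
    }

Applied over all rows, TRI_LOWER_DIAG reproduces the usual causal-attention mask shape.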
@@ -4084,31 +3596,27 @@ static void ggml_compute_forward_norm_f32(

 GGML_ASSERT(eps >= 0.0f);

-// TODO: optimize
 for (int64_t i03 = 0; i03 < ne03; i03++) {
 for (int64_t i02 = 0; i02 < ne02; i02++) {
 for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

-
-
-sum += (ggml_float)x[i00];
-}
-
+float sum = 0.0;
+ggml_vec_sum_f32(ne00, &sum, x);
 float mean = sum/ne00;

 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+float variance = 0;

-
-
-
-
-
-
+#ifdef GGML_USE_ACCELERATE
+mean = -mean;
+vDSP_vsadd(x, 1, &mean, y, 1, ne00);
+vDSP_measqv(y, 1, &variance, ne00);
+#else
+variance = ggml_vec_cvar_f32(ne00, y, x, mean);
+#endif //GGML_USE_ACCELERATE

-float variance = sum2/ne00;
 const float scale = 1.0f/sqrtf(variance + eps);
-
 ggml_vec_scale_f32(ne00, y, scale);
 }
 }
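The norm hunk replaces the open-coded accumulation with ggml_vec_sum_f32 plus a fused center-and-variance helper, ggml_vec_cvar_f32 (or vDSP calls under Accelerate). A sketch of one row under that structure; vec_cvar_f32 here is an assumed body for the helper, which the diff only calls:

    #include <cmath>
    #include <cstdint>

    // Assumed body for ggml_vec_cvar_f32: center x into y around mean and
    // return the population variance of the row in one pass.
    static float vec_cvar_f32(const int64_t n, float * y, const float * x, float mean) {
        float sum2 = 0.0f;
        for (int64_t i = 0; i < n; ++i) {
            const float v = x[i] - mean;
            y[i]  = v;      // centered value is reused by the scale step
            sum2 += v*v;
        }
        return sum2/n;
    }

    // One row of the normalization: mean, centered variance, then scale.
    static void norm_row_f32(const int64_t n, float * y, const float * x, float eps) {
        float sum = 0.0f;
        for (int64_t i = 0; i < n; ++i) {
            sum += x[i];
        }
        const float mean     = sum/n;
        const float variance = vec_cvar_f32(n, y, x, mean);
        const float scale    = 1.0f/sqrtf(variance + eps);
        for (int64_t i = 0; i < n; ++i) {
            y[i] *= scale;
        }
    }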
@@ -5076,46 +4584,6 @@ void ggml_compute_forward_cont(
 ggml_compute_forward_dup(params, dst);
 }

-// ggml_compute_forward_reshape
-
-void ggml_compute_forward_reshape(
-const ggml_compute_params * params,
-ggml_tensor * dst) {
-// NOP
-GGML_UNUSED(params);
-GGML_UNUSED(dst);
-}
-
-// ggml_compute_forward_view
-
-void ggml_compute_forward_view(
-const ggml_compute_params * params,
-ggml_tensor * dst) {
-// NOP
-GGML_UNUSED(params);
-GGML_UNUSED(dst);
-}
-
-// ggml_compute_forward_permute
-
-void ggml_compute_forward_permute(
-const ggml_compute_params * params,
-ggml_tensor * dst) {
-// NOP
-GGML_UNUSED(params);
-GGML_UNUSED(dst);
-}
-
-// ggml_compute_forward_transpose
-
-void ggml_compute_forward_transpose(
-const ggml_compute_params * params,
-ggml_tensor * dst) {
-// NOP
-GGML_UNUSED(params);
-GGML_UNUSED(dst);
-}
-
 // ggml_compute_forward_get_rows

 static void ggml_compute_forward_get_rows_q(
@@ -5356,6 +4824,7 @@ void ggml_compute_forward_get_rows(
 //}
 }

+template<typename idx_t>
 static void ggml_compute_forward_set_rows_f32(
 const ggml_compute_params * params,
 ggml_tensor * dst) {
@@ -5394,7 +4863,7 @@ static void ggml_compute_forward_set_rows_f32(
 const int64_t i11 = i02%ne11;
 const int64_t i10 = i;

-const int64_t i1 = *(
+const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

 GGML_ASSERT(i1 >= 0 && i1 < ne1);

@@ -5411,11 +4880,18 @@ void ggml_compute_forward_set_rows(
 ggml_tensor * dst) {

 const ggml_tensor * src0 = dst->src[0];
+const ggml_tensor * src1 = dst->src[1];

 switch (src0->type) {
 case GGML_TYPE_F32:
 {
-
+if (src1->type == GGML_TYPE_I64) {
+ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
+} else if (src1->type == GGML_TYPE_I32) {
+ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
+} else {
+GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
+}
 } break;
 default:
 {
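The set_rows change makes the index type a template parameter so I32 and I64 row-index tensors share one kernel, with the dispatch done once on the runtime type. The same pattern in isolation (simplified to contiguous buffers; the real kernel walks strided tensor views):

    #include <cstdint>
    #include <cstring>

    // Copy n_rows rows of width ne0 from src into dst at the positions named
    // by an index buffer whose element type is resolved at compile time.
    template<typename idx_t>
    static void set_rows(float * dst, int64_t ne0, const float * src,
                         const void * idx_data, int64_t n_rows) {
        for (int64_t r = 0; r < n_rows; ++r) {
            const int64_t i1 = (int64_t) ((const idx_t *) idx_data)[r];
            memcpy(dst + i1*ne0, src + r*ne0, ne0*sizeof(float));
        }
    }

    // Runtime dispatch, as in ggml_compute_forward_set_rows:
    //   idx_is_i64 ? set_rows<int64_t>(...) : set_rows<int32_t>(...);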
@@ -6067,270 +5543,117 @@ static void rope_yarn(
 mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
 }
 *cos_theta = cosf(theta) * mscale;
-*sin_theta = sinf(theta) * mscale;
-}
-
-static void ggml_rope_cache_init(
-float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
-float * cache, float sin_sign, float theta_scale) {
-// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
-float theta = theta_base;
-for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
-rope_yarn(
-theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
-);
-cache[i0 + 1] *= sin_sign;
-
-theta *= theta_scale;
-}
-}
-
-static void ggml_mrope_cache_init(
-float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
-float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
-float * cache, float sin_sign, float theta_scale) {
-// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
-float theta_t = theta_base_t;
-float theta_h = theta_base_h;
-float theta_w = theta_base_w;
-float theta_e = theta_base_e; // extra position id for vision encoder
-int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
-int sec_w = sections[1] + sections[0];
-int sec_e = sections[2] + sec_w;
-GGML_ASSERT(sect_dims <= ne0);
-
-for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
-
-int sector = (i0 / 2) % sect_dims;
-if (indep_sects) {
-// compute theta independently for each dim sections
-// (i.e. reset corresponding theta when `i0` go from one section to another)
-if (sector == 0) {
-theta_t = theta_base_t;
-}
-else if (sector == sections[0]) {
-theta_h = theta_base_h;;
-}
-else if (sector == sec_w) {
-theta_w = theta_base_w;
-}
-else if (sector == sec_e) {
-theta_e = theta_base_e;
-}
-}
-
-float theta = theta_t;
-if (sector >= sections[0] && sector < sec_w) {
-theta = theta_h;
-}
-else if (sector >= sec_w && sector < sec_w + sections[2]) {
-theta = theta_w;
-}
-else if (sector >= sec_w + sections[2]) {
-theta = theta_e;
-}
-
-rope_yarn(
-theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
-);
-cache[i0 + 1] *= sin_sign;
-
-theta_t *= theta_scale;
-theta_w *= theta_scale;
-theta_h *= theta_scale;
-theta_e *= theta_scale;
-}
-}
-
-static void ggml_compute_forward_rope_f32(
-const ggml_compute_params * params,
-ggml_tensor * dst,
-const bool forward) {
-
-const ggml_tensor * src0 = dst->src[0];
-const ggml_tensor * src1 = dst->src[1];
-const ggml_tensor * src2 = dst->src[2];
-
-float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-int sections[4];
-
-//const int n_past = ((int32_t *) dst->op_params)[0];
-const int n_dims = ((int32_t *) dst->op_params)[1];
-const int mode = ((int32_t *) dst->op_params)[2];
-//const int n_ctx = ((int32_t *) dst->op_params)[3];
-const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
-
-memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
-memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
-memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
-memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
-memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
-memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
-memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
-
-GGML_TENSOR_UNARY_OP_LOCALS
-
-//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
-//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
-
-GGML_ASSERT(nb00 == sizeof(float));
-
-const int ith = params->ith;
-const int nth = params->nth;
-
-const int nr = ggml_nrows(dst);
-
-GGML_ASSERT(n_dims <= ne0);
-GGML_ASSERT(n_dims % 2 == 0);
-
-// rows per thread
-const int dr = (nr + nth - 1)/nth;
-
-// row range for this thread
-const int ir0 = dr*ith;
-const int ir1 = MIN(ir0 + dr, nr);
-
-// row index used to determine which thread to use
-int ir = 0;
-
-const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
-float corr_dims[2];
-ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
-const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
-const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
-
-if (is_mrope) {
-GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
-}
-
-if (is_vision) {
-GGML_ASSERT(n_dims == ne0/2);
-}
-
-const float * freq_factors = NULL;
-if (src2 != NULL) {
-GGML_ASSERT(src2->type == GGML_TYPE_F32);
-GGML_ASSERT(src2->ne[0] >= n_dims / 2);
-freq_factors = (const float *) src2->data;
-}
-
-// backward process uses inverse rotation by cos and sin.
-// cos and sin build a rotation matrix, where the inverse is the transpose.
-// this essentially just switches the sign of sin.
-const float sin_sign = forward ? 1.0f : -1.0f;
-
-const int32_t * pos = (const int32_t *) src1->data;
-
-for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
-for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
-
-float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-if (!is_mrope) {
-const int64_t p = pos[i2];
-ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-}
-else {
-const int64_t p_t = pos[i2];
-const int64_t p_h = pos[i2 + ne2];
-const int64_t p_w = pos[i2 + ne2 * 2];
-const int64_t p_e = pos[i2 + ne2 * 3];
-ggml_mrope_cache_init(
-p_t, p_h, p_w, p_e, sections, is_vision,
-freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-}
-
-for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
-if (ir++ < ir0) continue;
-if (ir > ir1) break;
-
-if (is_neox || is_mrope) {
-if (is_vision){
-for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-const int64_t ic = i0/2;
-
-const float cos_theta = cache[i0 + 0];
-const float sin_theta = cache[i0 + 1];
-
-const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
-const float x0 = src[0];
-const float x1 = src[n_dims];
+*sin_theta = sinf(theta) * mscale;
+}

-
-
-
-
-
-
+static void ggml_rope_cache_init(
+float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+float * cache, float sin_sign, float theta_scale) {
+// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
+float theta = theta_base;
+for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
+rope_yarn(
+theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+);
+cache[i0 + 1] *= sin_sign;

-
-
+theta *= theta_scale;
+}
+}

-
-
+static void ggml_mrope_cache_init(
+float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
+float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+float * cache, float sin_sign, float theta_scale) {
+// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
+float theta_t = theta_base_t;
+float theta_h = theta_base_h;
+float theta_w = theta_base_w;
+float theta_e = theta_base_e; // extra position id for vision encoder
+int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
+int sec_w = sections[1] + sections[0];
+int sec_e = sections[2] + sec_w;
+GGML_ASSERT(sect_dims <= ne0);

-
-
+for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;

-
-
-
-
-
-
-
-
+int sector = (i0 / 2) % sect_dims;
+if (indep_sects) {
+// compute theta independently for each dim sections
+// (i.e. reset corresponding theta when `i0` go from one section to another)
+if (sector == 0) {
+theta_t = theta_base_t;
+}
+else if (sector == sections[0]) {
+theta_h = theta_base_h;;
+}
+else if (sector == sec_w) {
+theta_w = theta_base_w;
+}
+else if (sector == sec_e) {
+theta_e = theta_base_e;
+}
+}

-
-
+float theta = theta_t;
+if (is_imrope) { // qwen3vl apply interleaved mrope
+if (sector % 3 == 1 && sector < 3 * sections[1]) {
+theta = theta_h;
+} else if (sector % 3 == 2 && sector < 3 * sections[2]) {
+theta = theta_w;
+} else if (sector % 3 == 0 && sector < 3 * sections[0]) {
+theta = theta_t;
+} else {
+theta = theta_e;
+}
+} else {
+if (sector >= sections[0] && sector < sec_w) {
+theta = theta_h;
+}
+else if (sector >= sec_w && sector < sec_w + sections[2]) {
+theta = theta_w;
+}
+else if (sector >= sec_w + sections[2]) {
+theta = theta_e;
+}
+}

-
-
+rope_yarn(
+theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+);
+cache[i0 + 1] *= sin_sign;

-
-
-
-
+theta_t *= theta_scale;
+theta_w *= theta_scale;
+theta_h *= theta_scale;
+theta_e *= theta_scale;
+}
+}

-if (is_vision) {
-for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-const int64_t ic = i0/2;

-
-
+template<typename T>
+static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
+for (int64_t i0 = 0; i0 < n; i0 += 2) {
+const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2

-
-
+const float cos_theta = cache[i0 + 0];
+const float sin_theta = cache[i0 + 1];

-
-
+const T * const src = src_data + ic;
+T * dst = dst_data + ic;

-
-
-}
-} else {
-// fill the remain channels with data from src tensor
-for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+const float x0 = type_conversion_table<T>::to_f32(src[0]);
+const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);

-
-
-
-}
-}
-}
-}
+dst[0] = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
+dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
+}
 }

-
-static void
+template<typename T> //float or ggml_fp16_t
+static void ggml_compute_forward_rope_flt(
 const ggml_compute_params * params,
 ggml_tensor * dst,
 const bool forward) {
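The rope refactor folds the separate f32 and f16 kernels into one template. Two pieces carry that: rotate_pairs, shown in the hunk, and a type_conversion_table trait the hunk only calls. A sketch of how the pieces plausibly fit together (the float specialization is the identity; an fp16 specialization would wrap the FP16/FP32 converters):

    #include <cstdint>

    template<typename T> struct type_conversion_table;

    template<> struct type_conversion_table<float> {
        static float to_f32  (float x) { return x; }
        static float from_f32(float x) { return x; }
    };

    // Rotate (x0, x1) pairs by angles cached as interleaved (cos, sin).
    // n_offset picks the pairing: 1 pairs neighbors (NORMAL), n/2 pairs the
    // two halves of the head (NEOX/MROPE), n_dims pairs halves of ne0 (VISION).
    template<typename T>
    static void rotate_pairs(const int64_t n, const int64_t n_offset,
                             const float * cache, const T * src, T * dst,
                             const int scale = 2) {
        for (int64_t i0 = 0; i0 < n; i0 += 2) {
            const int64_t ic = i0/scale; // i0 for NORMAL (scale = 1), i0/2 otherwise
            const float cos_theta = cache[i0 + 0];
            const float sin_theta = cache[i0 + 1];
            const float x0 = type_conversion_table<T>::to_f32(src[ic]);
            const float x1 = type_conversion_table<T>::to_f32(src[ic + n_offset]);
            dst[ic]            = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
            dst[ic + n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
        }
    }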
@@ -6339,6 +5662,9 @@ static void ggml_compute_forward_rope_f16(
 const ggml_tensor * src1 = dst->src[1];
 const ggml_tensor * src2 = dst->src[2];

+GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+GGML_ASSERT(src1->type == GGML_TYPE_I32);
+
 float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 int sections[4];

@@ -6347,6 +5673,7 @@ static void ggml_compute_forward_rope_f16(
 const int mode = ((int32_t *) dst->op_params)[2];
 //const int n_ctx = ((int32_t *) dst->op_params)[3];
 const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
 memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
 memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
 memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -6355,13 +5682,13 @@ static void ggml_compute_forward_rope_f16(
 memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
 memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);

-
 GGML_TENSOR_UNARY_OP_LOCALS

 //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
 //printf("n_past = %d, ne2 = %d\n", n_past, ne2);

-GGML_ASSERT(nb0 ==
+GGML_ASSERT(nb0 == nb00);
+GGML_ASSERT(nb0 == sizeof(T));

 const int ith = params->ith;
 const int nth = params->nth;
@@ -6386,11 +5713,11 @@ static void ggml_compute_forward_rope_f16(
 float corr_dims[2];
 ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);

-const bool
-const bool
+const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
+const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
 const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

-if (
+if (mrope_used) {
 GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
 }

@@ -6412,11 +5739,11 @@ static void ggml_compute_forward_rope_f16(

 const int32_t * pos = (const int32_t *) src1->data;

-for (int64_t i3 = 0; i3 < ne3; i3++) {
-for (int64_t i2 = 0; i2 < ne2; i2++) {
+for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
+for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len

 float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-if (!
+if (!mrope_used) {
 const int64_t p = pos[i2];
 ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
 }
@@ -6426,90 +5753,44 @@ static void ggml_compute_forward_rope_f16(
 const int64_t p_w = pos[i2 + ne2 * 2];
 const int64_t p_e = pos[i2 + ne2 * 3];
 ggml_mrope_cache_init(
-p_t, p_h, p_w, p_e, sections, is_vision,
+p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
 freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
 }

-for (int64_t i1 = 0; i1 < ne1; i1++) {
+for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
 if (ir++ < ir0) continue;
 if (ir > ir1) break;

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-const float cos_theta = cache[i0 + 0];
-const float sin_theta = cache[i0 + 1];
-
-const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
-const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
-const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
-
-dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-}
-}
-} else {
-for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-const float cos_theta = cache[i0 + 0];
-const float sin_theta = cache[i0 + 1];
-
-const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
-const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
-
-dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-}
-}
-
-if (is_vision) {
-for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-const int64_t ic = i0/2;
-
-const float cos_theta = cache[i0 + 0];
-const float sin_theta = cache[i0 + 1];
-
-const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
-
-const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
-const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
-
-dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-}
-} else {
+T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+
+switch (mode) {
+case GGML_ROPE_TYPE_NORMAL:
+rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
+break;
+case GGML_ROPE_TYPE_NEOX:
+case GGML_ROPE_TYPE_MROPE:
+case GGML_ROPE_TYPE_IMROPE:
+rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
+break;
+case GGML_ROPE_TYPE_VISION:
+rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
+break;
+default:
+GGML_ABORT("rope type not supported");
+}
+
+if (!is_vision) {
+// fill the remain channels with data from src tensor
 for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-const
-
+const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

 dst_data[0] = src[0];
 dst_data[1] = src[1];
 }
 }
-}
+} //attn-heads
 }
 }
 }
@@ -6523,11 +5804,11 @@ void ggml_compute_forward_rope(
 switch (src0->type) {
 case GGML_TYPE_F16:
 {
-
+ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
 } break;
 case GGML_TYPE_F32:
 {
-
+ggml_compute_forward_rope_flt<float>(params, dst, true);
 } break;
 default:
 {
@@ -6547,11 +5828,11 @@ void ggml_compute_forward_rope_back(
 switch (src0->type) {
 case GGML_TYPE_F16:
 {
-
+ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
 } break;
 case GGML_TYPE_F32:
 {
-
+ggml_compute_forward_rope_flt<float>(params, dst, false);
 } break;
 default:
 {
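Both entry points now instantiate ggml_compute_forward_rope_flt with a forward flag; the backward pass only flips the sign of the cached sine, because a 2D rotation's inverse is its transpose. A tiny runnable check of that identity:

    #include <cmath>
    #include <cstdio>

    int main() {
        float x0 = 0.3f, x1 = -1.2f;
        const float theta = 0.7f;
        const float signs[2] = {1.0f, -1.0f}; // forward, then backward
        for (float sin_sign : signs) {
            const float c = cosf(theta), s = sinf(theta)*sin_sign;
            const float r0 = x0*c - x1*s;
            const float r1 = x0*s + x1*c;
            x0 = r0; x1 = r1;
        }
        printf("%f %f\n", x0, x1); // prints ~0.3 -1.2: the rotation is undone
        return 0;
    }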
@@ -6938,10 +6219,198 @@ void ggml_compute_forward_im2col_back_f32(
 const ggml_compute_params * params,
 ggml_tensor * dst) {

-const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
-const ggml_tensor * src1 = dst->src[1]; // convolution kernel
+const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
+const ggml_tensor * src1 = dst->src[1]; // convolution kernel
+
+GGML_ASSERT(src0->type == GGML_TYPE_F32);
+GGML_ASSERT(src1->type == GGML_TYPE_F32);
+GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+GGML_TENSOR_BINARY_OP_LOCALS;
+
+const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+const int ith = params->ith;
+const int nth = params->nth;
+
+const int64_t N = is_2D ? ne3 : ne2;
+const int64_t IC = is_2D ? ne2 : ne1;
+const int64_t IH = is_2D ? ne1 : 1;
+const int64_t IW = ne0;
+
+const int64_t KH = is_2D ? ne11 : 1;
+const int64_t KW = ne10;
+
+const int64_t OH = is_2D ? ne02 : 1;
+const int64_t OW = ne01;
+
+int ofs0 = is_2D ? nb3 : nb2;
+int ofs1 = is_2D ? nb2 : nb1;
+
+GGML_ASSERT(nb0 == sizeof(float));
+
+// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+{
+float * const wdata = (float *) dst->data;
+
+for (int64_t in = 0; in < N; in++) {
+for (int64_t iic = ith; iic < IC; iic += nth) {
+for (int64_t iih = 0; iih < IH; iih++) {
+for (int64_t iiw = 0; iiw < IW; iiw++) {
+
+// micro kernel
+float grad = 0.0f;
+for (int64_t ikh = 0; ikh < KH; ikh++) {
+for (int64_t ikw = 0; ikw < KW; ikw++) {
+// For s0 > 1 some values were skipped over in the forward pass.
+// These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
+const int64_t tmpw = (iiw + p0 - ikw*d0);
+if (tmpw % s0 != 0) {
+continue;
+}
+const int64_t iow = tmpw / s0;
+
+// Equivalent logic as above except for s1.
+int64_t ioh;
+if (is_2D) {
+const int64_t tmph = iih + p1 - ikh*d1;
+
+if (tmph % s1 != 0) {
+continue;
+}
+
+ioh = tmph / s1;
+} else {
+ioh = 0;
+}
+
+if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
+continue;
+}
+
+const float * const grad_in = (const float *) src0->data
++ (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
+}
+}
+float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
+dst_data[iih*IW + iiw] = grad;
+}
+}
+}
+}
+}
+}
+
+
+// ggml_compute_forward_im2col_3d_f16
+// src0: kernel [OC*IC, KD, KH, KW]
+// src1: image [N*IC, ID, IH, IW]
+// dst: result [N*OD, OH, OW, IC * KD * KH * KW]
+static void ggml_compute_forward_im2col_3d_f16(
+const ggml_compute_params * params,
+ggml_tensor * dst) {
+
+const ggml_tensor * src0 = dst->src[0];
+const ggml_tensor * src1 = dst->src[1];
+
+GGML_ASSERT(src0->type == GGML_TYPE_F16);
+GGML_ASSERT(src1->type == GGML_TYPE_F32);
+GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+GGML_TENSOR_BINARY_OP_LOCALS;
+
+const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+
+const int ith = params->ith;
+const int nth = params->nth;
+
+const int64_t N = ne13 / IC;
+const int64_t ID = ne12;
+const int64_t IH = ne11;
+const int64_t IW = ne10;
+
+const int64_t OC = ne03 / IC;
+GGML_UNUSED(OC);
+const int64_t KD = ne02;
+const int64_t KH = ne01;
+const int64_t KW = ne00;
+
+const int64_t OD = ne3 / N;
+const int64_t OH = ne2;
+const int64_t OW = ne1;
+const int64_t OH_OW = OH*OW;
+const int64_t KD_KH_KW = KD*KH*KW;
+const int64_t KH_KW = KH*KW;
+const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
+
+GGML_ASSERT(nb10 == sizeof(float));
+
+// im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+{
+ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
+
+for (int64_t in = 0; in < N; in++) {
+for (int64_t iod = 0; iod < OD; iod++) {
+for (int64_t ioh = 0; ioh < OH; ioh++) {
+for (int64_t iow = 0; iow < OW; iow++) {
+for (int64_t iic = ith; iic < IC; iic += nth) {
+
+// micro kernel
+ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+for (int64_t ikd = 0; ikd < KD; ikd++) {
+for (int64_t ikh = 0; ikh < KH; ikh++) {
+for (int64_t ikw = 0; ikw < KW; ikw++) {
+const int64_t iiw = iow*s0 + ikw*d0 - p0;
+const int64_t iih = ioh*s1 + ikh*d1 - p1;
+const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
+dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
+} else {
+const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);
+}
+}
+}
+}
+}
+}
+}
+}
+}
+}
+
+// ggml_compute_forward_im2col_3d_f32
+// src0: kernel [OC*IC, KD, KH, KW]
+// src1: image [N*IC, ID, IH, IW]
+// dst: result [N*OD, OH, OW, IC * KD * KH * KW]
+static void ggml_compute_forward_im2col_3d_f32(
+const ggml_compute_params * params,
+ggml_tensor * dst) {
+
+const ggml_tensor * src0 = dst->src[0];
+const ggml_tensor * src1 = dst->src[1];

-GGML_ASSERT(src0->type == GGML_TYPE_F32);
 GGML_ASSERT(src1->type == GGML_TYPE_F32);
 GGML_ASSERT( dst->type == GGML_TYPE_F32);

@@ -6949,77 +6418,72 @@ void ggml_compute_forward_im2col_back_f32(

 const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
 const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
-const int32_t
-const int32_t
-const int32_t
-const int32_t
-const
+const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+

 const int ith = params->ith;
 const int nth = params->nth;

-const int64_t N =
-const int64_t
-const int64_t IH =
-const int64_t IW =
+const int64_t N = ne13 / IC;
+const int64_t ID = ne12;
+const int64_t IH = ne11;
+const int64_t IW = ne10;

-const int64_t
-
+const int64_t OC = ne03 / IC;
+GGML_UNUSED(OC);
+const int64_t KD = ne02;
+const int64_t KH = ne01;
+const int64_t KW = ne00;

-const int64_t
-const int64_t
+const int64_t OD = ne3 / N;
+const int64_t OH = ne2;
+const int64_t OW = ne1;

-
-
+const int64_t OH_OW = OH*OW;
+const int64_t KD_KH_KW = KD*KH*KW;
+const int64_t KH_KW = KH*KW;
+const int64_t IC_KD_KH_KW = IC*KD*KH*KW;

-GGML_ASSERT(
+GGML_ASSERT(nb10 == sizeof(float));

-// im2col: [N,
+// im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
 {
 float * const wdata = (float *) dst->data;

 for (int64_t in = 0; in < N; in++) {
-for (int64_t
-for (int64_t
-for (int64_t
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-continue;
+for (int64_t iod = 0; iod < OD; iod++) {
+for (int64_t ioh = 0; ioh < OH; ioh++) {
+for (int64_t iow = 0; iow < OW; iow++) {
+for (int64_t iic = ith; iic < IC; iic += nth) {
+
+// micro kernel
+float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+for (int64_t ikd = 0; ikd < KD; ikd++) {
+for (int64_t ikh = 0; ikh < KH; ikh++) {
+for (int64_t ikw = 0; ikw < KW; ikw++) {
+const int64_t iiw = iow*s0 + ikw*d0 - p0;
+const int64_t iih = ioh*s1 + ikh*d1 - p1;
+const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
+dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
+} else {
+const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
+}
 }
-
-ioh = tmph / s1;
-} else {
-ioh = 0;
-}
-
-if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
-continue;
 }
-
-const float * const grad_in = (const float *) src0->data
-+ (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
 }
 }
-float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
-dst_data[iih*IW + iiw] = grad;
 }
 }
 }
@@ -7027,6 +6491,26 @@ void ggml_compute_forward_im2col_back_f32(
 }
 }

+
+void ggml_compute_forward_im2col_3d(
+const ggml_compute_params * params,
+ggml_tensor * dst) {
+switch (dst->type) {
+case GGML_TYPE_F16:
+{
+ggml_compute_forward_im2col_3d_f16(params, dst);
+} break;
+case GGML_TYPE_F32:
+{
+ggml_compute_forward_im2col_3d_f32(params, dst);
+} break;
+default:
+{
+GGML_ABORT("fatal error");
+}
+}
+}
+
 static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
 void * a, void * b, float * c) {
 const ggml_type_traits * traits = ggml_get_type_traits(type);
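The core of the two new im2col_3d kernels is one coordinate mapping: an output position and a kernel tap map back to an input voxel through stride, dilation and padding, with out-of-range taps reading zero. That mapping in isolation, on a contiguous [ID, IH, IW] volume:

    #include <cstdint>

    static float im2col_3d_tap(const float * img,
                               int64_t ID, int64_t IH, int64_t IW,
                               int64_t iod, int64_t ioh, int64_t iow,   // output voxel
                               int64_t ikd, int64_t ikh, int64_t ikw,   // kernel tap
                               int s0, int s1, int s2,                  // strides
                               int p0, int p1, int p2,                  // padding
                               int d0, int d1, int d2) {                // dilation
        const int64_t iiw = iow*s0 + ikw*d0 - p0;
        const int64_t iih = ioh*s1 + ikh*d1 - p1;
        const int64_t iid = iod*s2 + ikd*d2 - p2;
        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
            return 0.0f; // padded region
        }
        return img[(iid*IH + iih)*IW + iiw];
    }

Note the bounds test in the shipped kernels repeats the iid range check twice; it is redundant but harmless.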
@@ -7480,7 +6964,11 @@ static void ggml_compute_forward_conv_2d_dw_cwhn(
 const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);

 #ifdef GGML_SIMD
-
+#if defined(__ARM_FEATURE_SVE)
+const int64_t pkg_size = svcntw();
+#else
+const int64_t pkg_size = GGML_F32_EPR;
+#endif
 const int64_t pkg_count = c / pkg_size;
 const int64_t c_pkg_end = pkg_count * pkg_size;
 #else
@@ -7903,10 +7391,17 @@ static void ggml_compute_forward_upscale_f32(
 float sf1 = (float)ne1/src0->ne[1];
 float sf2 = (float)ne2/src0->ne[2];
 float sf3 = (float)ne3/src0->ne[3];
+float pixel_offset = 0.5f;

 const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
 const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);

+if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
+pixel_offset = 0.0f;
+sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
+sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
+}
+
 if (mode == GGML_SCALE_MODE_NEAREST) {
 for (int64_t i3 = 0; i3 < ne3; i3++) {
 const int64_t i03 = i3 / sf3;
@@ -7926,13 +7421,6 @@ static void ggml_compute_forward_upscale_f32(
 }
 }
 } else if (mode == GGML_SCALE_MODE_BILINEAR) {
-float pixel_offset = 0.5f;
-if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
-pixel_offset = 0.0f;
-sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1);
-sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1);
-}
-
 for (int64_t i3 = 0; i3 < ne3; i3++) {
 const int64_t i03 = i3 / sf3;
 for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
@@ -7967,6 +7455,51 @@ static void ggml_compute_forward_upscale_f32(

 const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;

+float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+*y_dst = val;
+}
+}
+}
+}
+} else if (mode == GGML_SCALE_MODE_BICUBIC) {
+// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
+const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
+auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
+auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
+auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
+const float w0 = weight2(x + 1);
+const float w1 = weight1(x + 0);
+const float w2 = weight1(1 - x);
+const float w3 = weight2(2 - x);
+return p0*w0 + p1*w1 + p2*w2 + p3*w3;
+};
+
+for (int64_t i3 = 0; i3 < ne3; i3++) {
+const int64_t i03 = i3 / sf3;
+for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+const int64_t i02 = i2 / sf2;
+for (int64_t i1 = 0; i1 < ne1; i1++) {
+const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
+const int64_t y0 = (int64_t)floorf(y);
+const float dy = y - (float)y0;
+
+for (int64_t i0 = 0; i0 < ne0; i0++) {
+const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
+const int64_t x0 = (int64_t)floorf(x);
+const float dx = x - (float)x0;
+
+auto p = [=](int64_t x_off, int64_t y_off) -> float {
+int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
+int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
+return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+};
+
+const float val = bicubic(
+bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
+bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
+bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
+bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
+
 float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
 *y_dst = val;
 }
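The bicubic branch is separable: the same 1D convolution kernel (alpha = -0.75, PyTorch's choice) is applied four times horizontally, then once vertically. The 1D filter on its own:

    // weight1 covers taps with |x| <= 1, weight2 covers 1 < |x| < 2;
    // coefficients follow the bicubic convolution algorithm with a = -0.75.
    static float bicubic_1d(float p0, float p1, float p2, float p3, float x) {
        const float a = -0.75f;
        auto w1 = [a](float t) { return ((a + 2)*t - (a + 3))*t*t + 1; };
        auto w2 = [a](float t) { return ((a*t - 5*a)*t + 8*a)*t - 4*a; };
        return p0*w2(x + 1) + p1*w1(x) + p2*w1(1 - x) + p3*w2(2 - x);
    }

With x = 0 the weights collapse to (0, 1, 0, 0), so sample points are reproduced exactly, which is the property that distinguishes interpolation from mere smoothing.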
@@ -8014,6 +7547,15 @@ static void ggml_compute_forward_pad_f32(
 GGML_TENSOR_UNARY_OP_LOCALS

 float * dst_ptr = (float *) dst->data;
+const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
+const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
+const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
+const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
+const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
+const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
+const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
+const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
+

 // TODO: optimize

@@ -8022,10 +7564,12 @@ static void ggml_compute_forward_pad_f32(
 for (int64_t i0 = 0; i0 < ne0; ++i0) {
 for (int64_t i3 = 0; i3 < ne3; ++i3) {
 const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-
-
-
-
+if ((i0 >= lp0 && i0 < ne0 - rp0) \
+&& (i1 >= lp1 && i1 < ne1 - rp1) \
+&& (i2 >= lp2 && i2 < ne2 - rp2) \
+&& (i3 >= lp3 && i3 < ne3 - rp3)) {
+const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
+const float * src_ptr = (const float *)((char *) src0->data + src_idx);
 dst_ptr[dst_idx] = *src_ptr;
 } else {
 dst_ptr[dst_idx] = 0;
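pad now reads four (left, right) pad widths from op_params and tests whether each destination element falls inside the unpadded window; only then does it map back to a source element. The 1D version of that test (the shipped kernel does the same per dimension, with strided addressing):

    #include <cstdint>

    static void pad_1d(float * dst, int64_t ne0, const float * src,
                       int64_t lp0, int64_t rp0) {
        for (int64_t i0 = 0; i0 < ne0; ++i0) {
            if (i0 >= lp0 && i0 < ne0 - rp0) {
                dst[i0] = src[i0 - lp0]; // interior: shifted copy of the source
            } else {
                dst[i0] = 0.0f;          // left or right padding: zero fill
            }
        }
    }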
@@ -8224,7 +7768,7 @@ static void ggml_compute_forward_timestep_embedding_f32(
 embed_data[j + half] = sinf(arg);
 }
 if (dim % 2 != 0 && ith == 0) {
-embed_data[
+embed_data[2 * half] = 0.f;
 }
 }
 }
@@ -8249,6 +7793,18 @@ void ggml_compute_forward_timestep_embedding(

 // ggml_compute_forward_argsort

+template<enum ggml_sort_order order>
+struct argsort_cmp {
+const float * data;
+bool operator()(int32_t a, int32_t b) const {
+if constexpr (order == GGML_SORT_ORDER_ASC) {
+return data[a] < data[b];
+} else {
+return data[a] > data[b];
+}
+}
+};
+
 static void ggml_compute_forward_argsort_f32(
 const ggml_compute_params * params,
 ggml_tensor * dst) {
@@ -8267,23 +7823,25 @@ static void ggml_compute_forward_argsort_f32(
 ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);

 for (int64_t i = ith; i < nr; i += nth) {
-int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
 const float * src_data = (float *)((char *) src0->data + i*nb01);

+int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+
 for (int64_t j = 0; j < ne0; j++) {
 dst_data[j] = j;
 }

-
-
-
-
-
-
-
-
-
-
+switch (order) {
+case GGML_SORT_ORDER_ASC:
+std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
+break;
+
+case GGML_SORT_ORDER_DESC:
+std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
+break;
+
+default:
+GGML_ABORT("invalid sort order");
 }
 }
 }
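argsort now fills the index row and hands it to std::sort with a comparator whose direction is fixed at compile time by if constexpr, so the hot comparison carries no runtime branch on the order. The same pattern, self-contained:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    enum sort_order { SORT_ASC, SORT_DESC };

    template<sort_order order>
    struct idx_cmp {
        const float * data;
        bool operator()(int32_t a, int32_t b) const {
            if constexpr (order == SORT_ASC) {
                return data[a] < data[b];
            } else {
                return data[a] > data[b];
            }
        }
    };

    static std::vector<int32_t> argsort(const float * data, int32_t n, sort_order order) {
        std::vector<int32_t> idx(n);
        for (int32_t j = 0; j < n; ++j) idx[j] = j;  // identity permutation
        if (order == SORT_ASC) {
            std::sort(idx.begin(), idx.end(), idx_cmp<SORT_ASC>{data});
        } else {
            std::sort(idx.begin(), idx.end(), idx_cmp<SORT_DESC>{data});
        }
        return idx;  // idx[k] = source position of the k-th element in order
    }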
@@ -8308,10 +7866,10 @@ void ggml_compute_forward_argsort(

 // ggml_compute_forward_flash_attn_ext

-static void
+static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
 const ggml_compute_params * params,
-ggml_tensor * dst
-
+ggml_tensor * dst,
+int ir0, int ir1) {
 const ggml_tensor * q = dst->src[0];
 const ggml_tensor * k = dst->src[1];
 const ggml_tensor * v = dst->src[2];
@@ -8327,9 +7885,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
 GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

-const int ith = params->ith;
-const int nth = params->nth;
-
 const int64_t DK = nek0;
 const int64_t DV = nev0;
 const int64_t N = neq1;
@@ -8363,16 +7918,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(

 // parallelize by q rows using ggml_vec_dot_f32

-// total rows in q
-const int nr = neq1*neq2*neq3;
-
-// rows per thread
-const int dr = (nr + nth - 1)/nth;
-
-// row range for this thread
-const int ir0 = dr*ith;
-const int ir1 = MIN(ir0 + dr, nr);
-
 float scale = 1.0f;
 float max_bias = 0.0f;
 float logit_softcap = 0.0f;
@@ -8399,6 +7944,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
 GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");

+int ith = params->ith;
+
 // loop over n_batch and n_head
 for (int ir = ir0; ir < ir1; ++ir) {
 // q indices
@@ -8530,7 +8077,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 }

 // V /= S
-const float S_inv = 1.0f/S;
+const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
 ggml_vec_scale_f32(DV, VKQ32, S_inv);

 // dst indices
@@ -8546,6 +8093,91 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     }
 }
 
+static void ggml_compute_forward_flash_attn_ext_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * q = dst->src[0];
+    const ggml_tensor * k = dst->src[1];
+    const ggml_tensor * v = dst->src[2];
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
+    GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
+    GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
+    GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N = neq1;
+
+    GGML_ASSERT(ne0 == DV);
+    GGML_ASSERT(ne2 == N);
+
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+    GGML_ASSERT(neq0 == DK);
+    GGML_ASSERT(nek0 == DK);
+    GGML_ASSERT(nev0 == DV);
+
+    GGML_ASSERT(neq1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int64_t nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // disable for NUMA
+    const bool disable_chunking = ggml_is_numa();
+
+    // 4x chunks per thread
+    int nth_scaled = nth * 4;
+    int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+    int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
+
+    if (nth == 1 || nchunk < nth || disable_chunking) {
+        nchunk = nth;
+    }
+
+    if (ith == 0) {
+        // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+        ggml_threadpool_chunk_set(params->threadpool, nth);
+    }
+
+    ggml_barrier(params->threadpool);
+
+    // The number of elements in each chunk
+    const int64_t dr = (nr + nchunk - 1) / nchunk;
+
+    // The first chunk comes from our thread_id, the rest will get auto-assigned.
+    int current_chunk = ith;
+
+    while (current_chunk < nchunk) {
+        const int64_t ir0 = dr * current_chunk;
+        const int64_t ir1 = MIN(ir0 + dr, nr);
+
+        ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+
+        current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+    }
+}
+
 void ggml_compute_forward_flash_attn_ext(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
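Taken together, the flash-attention hunks above are one refactor: the f16 kernel loses its internal ith/nth row split, is renamed ..._one_chunk, and receives an explicit half-open row range [ir0, ir1); the new wrapper carves the nr q rows into roughly four chunks per thread and hands them out through a shared counter, so threads that finish early pull extra chunks instead of idling (chunking is bypassed on NUMA machines, where work stealing hurts memory locality). A standalone sketch of that scheduling pattern, assuming ggml_threadpool_chunk_set/_add act as an atomic store and fetch-add; worker, process_rows, and print_range are illustrative names, not ggml's API:

#include <algorithm>
#include <atomic>
#include <cstdint>
#include <cstdio>

static std::atomic<int> next_chunk; // shared chunk counter, reset per op invocation

// Each of the nth workers claims chunk `ith` first, then pulls further
// chunks from the counter until all nchunk chunks are processed.
static void worker(int ith, int nth, int64_t nr,
                   void (*process_rows)(int64_t ir0, int64_t ir1)) {
    // ~4 chunks per thread (assumes nr > 0)
    const int64_t chunk_size = (nr + 4*nth - 1) / (4*nth);
    int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
    if (nth == 1 || nchunk < nth) {
        nchunk = nth; // chunking off: exactly one chunk per thread
    }
    const int64_t dr = (nr + nchunk - 1) / nchunk; // rows per chunk

    if (ith == 0) {
        next_chunk.store(nth); // chunks 0..nth-1 are implicitly taken below
    }
    // the real code runs ggml_barrier here so no thread reads a stale counter

    for (int c = ith; c < nchunk; c = next_chunk.fetch_add(1)) {
        const int64_t ir0 = dr * c;
        const int64_t ir1 = std::min<int64_t>(ir0 + dr, nr);
        process_rows(ir0, ir1);
    }
}

static void print_range(int64_t ir0, int64_t ir1) {
    printf("[%lld, %lld)\n", (long long) ir0, (long long) ir1);
}

int main() {
    worker(/*ith=*/0, /*nth=*/1, /*nr=*/10, print_range); // single-threaded demo: [0, 10)
    return 0;
}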
@@ -9032,7 +8664,7 @@ static void ggml_compute_forward_ssm_scan_f32(
     // n_head
     for (int h = ih0; h < ih1; ++h) {
         // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-        const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+        const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
         const float dA = expf(dt_soft_plus * A[h]);
         const int g = h / (nh / ng); // repeat_interleave
 
@@ -9129,7 +8761,7 @@ static void ggml_compute_forward_ssm_scan_f32(
     // n_head
     for (int h = ih0; h < ih1; ++h) {
         // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-        const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+        const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
         const int g = h / (nh / ng); // repeat_interleave
 
         // dim
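Both ssm_scan hunks make the same substitution: the inline softplus of dt[h] becomes a call to the shared ggml_compute_softplus_f32 helper. Softplus is log(1 + e^x), which overflows naively for large x; the inlined expression handled that with a cutoff (returning x directly above 20.0f, following the referenced mamba kernel), and presumably the helper centralizes the same stabilized form. A self-contained sketch of that stabilization (softplus_f32 is a local reimplementation, not the ggml helper):

#include <cmath>
#include <cstdio>

// Threshold-stabilized softplus, log(1 + e^x). Whether
// ggml_compute_softplus_f32 uses exactly this form is an assumption here.
static float softplus_f32(float x) {
    // for x > 20, log1p(e^x) == x to float precision, while expf(x)
    // itself would overflow to inf not much further out
    return x <= 20.0f ? log1pf(expf(x)) : x;
}

int main() {
    printf("%f %f %f\n", softplus_f32(-5.0f), softplus_f32(0.0f), softplus_f32(100.0f));
    // ~0.006715 0.693147 100.000000
    return 0;
}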
@@ -9392,6 +9024,34 @@ void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_exp(params, dst);
             } break;
+        case GGML_UNARY_OP_FLOOR:
+            {
+                ggml_compute_forward_floor(params, dst);
+            } break;
+        case GGML_UNARY_OP_CEIL:
+            {
+                ggml_compute_forward_ceil(params, dst);
+            } break;
+        case GGML_UNARY_OP_ROUND:
+            {
+                ggml_compute_forward_round(params, dst);
+            } break;
+        case GGML_UNARY_OP_TRUNC:
+            {
+                ggml_compute_forward_trunc(params, dst);
+            } break;
+        case GGML_UNARY_OP_XIELU:
+            {
+                ggml_compute_forward_xielu(params, dst);
+            } break;
+        case GGML_UNARY_OP_EXPM1:
+            {
+                ggml_compute_forward_expm1(params, dst);
+            } break;
+        case GGML_UNARY_OP_SOFTPLUS:
+            {
+                ggml_compute_forward_softplus(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
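The unary dispatcher gains seven cases. XIELU is a parameterized activation with its own kernel; the others reduce to a single libm function applied elementwise. A rough sketch of what such simple forward kernels amount to on contiguous F32 data (the real kernels also handle strides and thread partitioning; the enum and function below are illustrative, not ggml's):

#include <cmath>
#include <cstdint>
#include <cstdio>

enum unary_op { OP_FLOOR, OP_CEIL, OP_ROUND, OP_TRUNC, OP_EXPM1, OP_SOFTPLUS };

// apply one of the simple new unary ops over n contiguous floats
static void apply_unary_f32(unary_op op, const float * src, float * dst, int64_t n) {
    for (int64_t i = 0; i < n; ++i) {
        const float x = src[i];
        switch (op) {
            case OP_FLOOR:    dst[i] = floorf(x); break;
            case OP_CEIL:     dst[i] = ceilf(x);  break;
            case OP_ROUND:    dst[i] = roundf(x); break;
            case OP_TRUNC:    dst[i] = truncf(x); break;
            case OP_EXPM1:    dst[i] = expm1f(x); break; // e^x - 1, accurate near 0
            case OP_SOFTPLUS: dst[i] = x <= 20.0f ? log1pf(expf(x)) : x; break;
        }
    }
}

int main() {
    const float src[3] = {-1.5f, 0.5f, 2.5f};
    float dst[3];
    apply_unary_f32(OP_ROUND, src, dst, 3);
    printf("%g %g %g\n", dst[0], dst[1], dst[2]); // -2 1 3 (roundf ties away from zero)
    return 0;
}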
@@ -9988,6 +9648,75 @@ void ggml_compute_forward_gla(
     }
 }
 
+static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0]; // A (lower triangular)
+    const struct ggml_tensor * src1 = dst->src[1]; // B (RHS)
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(ne00 == ne01); // A must be square
+    GGML_ASSERT(ne0 == ne10); // solution cols == B cols
+    GGML_ASSERT(ne1 == ne11); // solution rows == B rows
+
+    GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
+    GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t k = ne10; // number of RHS columns
+    const int64_t n = ne11; // A is n×n
+    const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
+
+    // chunks per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // chunk range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    const float * A = (const float *) src0->data; // [n, n, B1, B2]
+    const float * B = (const float *) src1->data; // [n, k, B1, B2]
+    float       * X = (      float *) dst->data;  // [n, k, B1, B2]
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*k);
+        const int64_t i02 = (ir - i03*ne02*k)/k;
+        const int64_t i01 = (ir - i03*ne02*k - i02*k);
+
+        const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
+        const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
+
+        float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
+
+        for (int64_t i00 = 0; i00 < n; ++i00) {
+            float sum = 0.0f;
+            for (int64_t t = 0; t < i00; ++t) {
+                sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
+            }
+
+            const float diag = A_batch[i00 * n + i00];
+            GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
+            X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
+        }
+    }
+}
+
+void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        ggml_compute_forward_solve_tri_f32(params, dst);
+    } else {
+        GGML_ABORT("fatal error");
+    }
+}
+
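The new solve_tri op is plain forward substitution: with A lower triangular, row i of a solution column depends only on rows t < i, so x_i = (b_i - sum_{t<i} A_it * x_t) / A_ii. Columns are independent of one another, which is why the kernel parallelizes over seq x token x column units while walking rows sequentially inside each column. A standalone check of the same recurrence on a single made-up 3x3 system:

#include <cstdio>

// Forward substitution for a lower-triangular system A x = b, the recurrence
// ggml_compute_forward_solve_tri_f32 applies per RHS column.
// Example values are invented for illustration.
int main() {
    const int n = 3;
    const float A[3][3] = {{2, 0, 0},
                           {1, 3, 0},
                           {4, 1, 5}}; // lower triangular, nonzero diagonal
    const float b[3] = {4, 7, 27};
    float x[3];

    for (int i = 0; i < n; ++i) {
        float sum = 0.0f;
        for (int t = 0; t < i; ++t) {
            sum += A[i][t] * x[t]; // contributions from already-solved rows
        }
        x[i] = (b[i] - sum) / A[i][i];
    }

    printf("%g %g %g\n", x[0], x[1], x[2]); // 2 1.66667 3.46667
    return 0;
}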
 // ggml_compute_forward_rwkv_wkv7
 
 static void ggml_compute_forward_rwkv_wkv7_f32(
|