@novastera-oss/llamarn 0.2.9 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/proguard-rules.pro +12 -0
- package/android/src/main/cpp/include/llama.h +15 -47
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/CMakePresets.json +11 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -0
- package/cpp/llama.cpp/README.md +8 -8
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +62 -1
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +22 -6
- package/cpp/llama.cpp/common/common.h +22 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
- package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
- package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
- package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
- package/cpp/llama.cpp/include/llama.h +15 -47
- package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
- package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
- package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
- package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
- package/cpp/llama.cpp/src/llama-arch.h +23 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
- package/cpp/llama.cpp/src/llama-batch.h +31 -18
- package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +180 -106
- package/cpp/llama.cpp/src/llama-context.h +26 -16
- package/cpp/llama.cpp/src/llama-cparams.h +3 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
- package/cpp/llama.cpp/src/llama-graph.h +184 -122
- package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
- package/cpp/llama.cpp/src/llama-hparams.h +13 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
- package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +3 -0
- package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
- package/cpp/llama.cpp/src/llama-model.h +21 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
- package/cpp/llama.cpp/src/llama-vocab.h +43 -0
- package/cpp/llama.cpp/src/unicode.cpp +207 -0
- package/cpp/llama.cpp/src/unicode.h +2 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +22 -4
- package/ios/include/llama.h +15 -47
- package/ios/libs/llama.xcframework/Info.plist +13 -13
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +4 -4
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp:

```diff
@@ -24,6 +24,7 @@
 
 #include <acl/acl.h>
 #include <stdarg.h>
+#include <aclnnop/aclnn_trans_matmul_weight.h>
 
 #include <cmath>
 #include <cstdio>
```
```diff
@@ -1115,6 +1116,63 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
     return GGML_STATUS_SUCCESS;
 }
 
+static int CreateAclTensorWeight(const void *hostData, const std::vector<int64_t> &shape, void **deviceAddr,
+                                 aclDataType dataType, aclTensor **tensor)
+{
+    uint64_t size = 1;
+    for (auto i : shape) {
+        size *= i;
+    }
+
+    const aclIntArray *mat2Size = aclCreateIntArray(shape.data(), shape.size());
+    ACL_CHECK(aclnnCalculateMatmulWeightSizeV2(mat2Size, dataType, &size));
+
+    size *= sizeof(int16_t);
+
+    ACL_CHECK(aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST));
+    aclrtMemcpy(*deviceAddr, size, hostData, size, ACL_MEMCPY_HOST_TO_DEVICE);
+
+    std::vector<int64_t> strides(shape.size(), 1);
+    for (int64_t i = shape.size() - 2; i >= 0; i--) {
+        strides[i] = shape[i + 1] * strides[i + 1];
+    }
+
+    *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
+                              shape.data(), shape.size(), *deviceAddr);
+    return 0;
+}
+
+static void weight_format_to_nz(ggml_tensor *tensor, const void *data, size_t offset) {
+    aclrtStream stream;
+    ACL_CHECK(aclrtCreateStream(&stream));
+
+    std::vector<int64_t> weightTransposedShape = {tensor->ne[1], tensor->ne[0]};
+    void *weightTransposedDeviceAddr = nullptr;
+    aclTensor *weightTransposed = nullptr;
+    CreateAclTensorWeight(data, weightTransposedShape, &weightTransposedDeviceAddr,
+                          ggml_cann_type_mapping(tensor->type), &weightTransposed);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor *executor;
+    void *workspaceAddr = nullptr;
+
+    // TransMatmulWeight
+    ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed, &workspaceSize, &executor));
+    std::unique_ptr<void, aclError (*)(void *)> workspaceAddrPtrTrans(nullptr, aclrtFree);
+    if (workspaceSize > 0) {
+        ACL_CHECK(aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST));
+        workspaceAddrPtrTrans.reset(workspaceAddr);
+    }
+    ACL_CHECK(aclnnTransMatmulWeight(workspaceAddr, workspaceSize, executor, stream));
+
+    size_t size = ggml_nelements(tensor) * ggml_element_size(tensor);
+
+    aclrtMemcpy((char *)tensor->data + offset, size,
+                weightTransposedDeviceAddr, size, ACL_MEMCPY_HOST_TO_DEVICE);
+    ACL_CHECK(aclDestroyTensor(weightTransposed));
+    aclrtFree(weightTransposedDeviceAddr);
+}
+
 // TODO: need handle tensor which has paddings.
 /**
  * @brief Set tensor data in a CANN buffer.
```
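A note on `CreateAclTensorWeight` above: the stride loop builds standard row-major (`ACL_FORMAT_ND`) strides from the shape before wrapping the host data in an `aclTensor`. A minimal standalone sketch of that rule (the helper name `row_major_strides` is illustrative, not part of the package):

```cpp
#include <cstdint>
#include <vector>

// Row-major strides: the last dimension is contiguous (stride 1) and each
// earlier stride is the product of all later dimension sizes — exactly the
// loop used by CreateAclTensorWeight above.
static std::vector<int64_t> row_major_strides(const std::vector<int64_t> &shape) {
    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = (int64_t)shape.size() - 2; i >= 0; i--) {
        strides[i] = shape[i + 1] * strides[i + 1];
    }
    return strides;
}

// For shape {4, 3, 2} this yields strides {6, 2, 1}: advancing one step in
// dimension 0 skips 3*2 = 6 elements of the flat buffer.
```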
```diff
@@ -1139,9 +1197,16 @@ static void ggml_backend_cann_buffer_set_tensor(
     // For acl, synchronous functions use this default stream.
     // Why aclrtSynchronizeDevice?
 
+    bool weightToNZ = false;
+#ifdef ASCEND_310P
+    weightToNZ = (getenv("GGML_CANN_WEIGHT_NZ") != nullptr);
+#endif
     if (!need_transform(tensor->type)) {
         ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size,
                               ACL_MEMCPY_HOST_TO_DEVICE));
+        if (weightToNZ && is_matmul_weight((const ggml_tensor*)tensor)) {
+            weight_format_to_nz(tensor, data, offset);
+        }
     } else {
         void *transform_buffer = malloc(size);
         ggml_backend_cann_transform(tensor, data, transform_buffer);
```
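The NZ repack above is strictly opt-in: it compiles only for Ascend 310P builds (`ASCEND_310P`) and runs only when the `GGML_CANN_WEIGHT_NZ` environment variable is set and `is_matmul_weight` recognizes the tensor as a matmul weight. A hedged sketch of turning the flag on from the host process, assuming a POSIX environment (the variable must be set before weights are copied into CANN buffers):

```cpp
#include <cstdlib>

int main() {
    // Opt in to the NZ weight layout before any model weights are uploaded;
    // ggml_backend_cann_buffer_set_tensor checks this with getenv() per copy.
    setenv("GGML_CANN_WEIGHT_NZ", "1", /*overwrite=*/1);

    // ... initialize the CANN backend and load the model as usual ...
    return 0;
}
```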
```diff
@@ -1616,16 +1681,18 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(dst)) {
                 case GGML_UNARY_OP_ABS:
-
+                    GGML_CANN_CALL_OP_UNARY(Abs);
                     break;
                 case GGML_UNARY_OP_NEG:
-
+                    GGML_CANN_CALL_OP_UNARY(Neg);
                     break;
                 case GGML_UNARY_OP_GELU:
-
+                case GGML_UNARY_OP_GELU_ERF:
+                    // aclnnGelu internally uses the erf-based approximation.
+                    GGML_CANN_CALL_OP_UNARY(Gelu);
                     break;
                 case GGML_UNARY_OP_SILU:
-
+                    GGML_CANN_CALL_OP_UNARY(Silu);
                     break;
                 case GGML_UNARY_OP_GELU_QUICK: {
                     auto lambda = [](ggml_backend_cann_context& ctx,
```
```diff
@@ -1633,31 +1700,31 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
                         aclTensor* acl_dst) {
                         GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
                     };
-
+                    ggml_cann_op_unary(lambda, ctx, dst);
                 } break;
                 case GGML_UNARY_OP_TANH:
-
+                    GGML_CANN_CALL_OP_UNARY(Tanh);
                     break;
                 case GGML_UNARY_OP_RELU:
-
+                    GGML_CANN_CALL_OP_UNARY(Relu);
                     break;
                 case GGML_UNARY_OP_SIGMOID:
-
+                    GGML_CANN_CALL_OP_UNARY(Sigmoid);
                     break;
                 case GGML_UNARY_OP_HARDSIGMOID:
-
+                    GGML_CANN_CALL_OP_UNARY(Hardsigmoid);
                     break;
                 case GGML_UNARY_OP_HARDSWISH:
-
+                    GGML_CANN_CALL_OP_UNARY(Hardswish);
                     break;
                 case GGML_UNARY_OP_EXP:
-
+                    GGML_CANN_CALL_OP_UNARY(Exp);
                     break;
                 case GGML_UNARY_OP_ELU:
                     ggml_cann_elu(ctx, dst);
                     break;
                 case GGML_UNARY_OP_SGN:
-
+                    GGML_CANN_CALL_OP_UNARY(Sign);
                     break;
                 case GGML_UNARY_OP_STEP:
                     ggml_cann_step(ctx, dst);
```
```diff
@@ -1666,6 +1733,31 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
                     return false;
             }
             break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(dst)) {
+                case GGML_GLU_OP_REGLU:
+                    GGML_CANN_CALL_OP_UNARY_GATED(Relu);
+                    break;
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                    // aclnnGelu internally uses the erf-based approximation.
+                    GGML_CANN_CALL_OP_UNARY_GATED(Gelu);
+                    break;
+                case GGML_GLU_OP_SWIGLU:
+                    GGML_CANN_CALL_OP_UNARY_GATED(Silu);
+                    break;
+                case GGML_GLU_OP_GEGLU_QUICK: {
+                    auto lambda = [](ggml_backend_cann_context& ctx,
+                        aclTensor* acl_src,
+                        aclTensor* acl_dst) {
+                        GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
+                    };
+                    ggml_cann_op_unary_gated(lambda, ctx, dst);
+                } break;
+                default:
+                    return false;
+            }
+            break;
         case GGML_OP_NORM:
             ggml_cann_norm(ctx, dst);
             break;
```
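For readers unfamiliar with the gated ops being wired up here: a GLU splits its input row into two halves, applies the activation to one half, and multiplies elementwise by the other, which is why each variant maps onto a plain unary kernel plus a multiply. A scalar sketch of the SWIGLU case under that interpretation (illustrative only; the CANN path runs this through aclnn kernels):

```cpp
#include <cmath>
#include <cstddef>

// SWIGLU over one row of 2*n values laid out as [a | b]:
// out[i] = silu(a[i]) * b[i], where silu(x) = x / (1 + exp(-x)).
static void swiglu_row(const float *src, float *out, size_t n) {
    const float *a = src;
    const float *b = src + n;
    for (size_t i = 0; i < n; ++i) {
        float silu = a[i] / (1.0f + std::exp(-a[i]));
        out[i] = silu * b[i];
    }
}
```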
```diff
@@ -1708,7 +1800,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
             ggml_cann_binary_op<aclnn_mul>(ctx, dst);
             break;
         case GGML_OP_SQRT:
-
+            GGML_CANN_CALL_OP_UNARY(Sqrt);
             break;
         case GGML_OP_CLAMP:
             ggml_cann_clamp(ctx, dst);
```
```diff
@@ -1753,16 +1845,16 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
             ggml_cann_argmax(ctx, dst);
             break;
         case GGML_OP_COS:
-
+            ggml_cann_op_unary<aclnn_cos>(ctx, dst);
             break;
         case GGML_OP_SIN:
-
+            ggml_cann_op_unary<aclnn_sin>(ctx, dst);
             break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             ggml_cann_conv_transpose_1d(ctx, dst);
             break;
         case GGML_OP_LOG:
-
+            GGML_CANN_CALL_OP_UNARY(Log);
             break;
         case GGML_OP_MEAN:
             ggml_cann_mean(ctx, dst);
```
```diff
@@ -2036,10 +2128,23 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 case GGML_UNARY_OP_ELU:
                 case GGML_UNARY_OP_SGN:
                 case GGML_UNARY_OP_STEP:
+                case GGML_UNARY_OP_GELU_ERF:
                     return true;
                 default:
                     return false;
             }
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
+                    return true;
+                default:
+                    return false;
+            }
+            break;
         case GGML_OP_MUL_MAT: {
             switch (op->src[0]->type) {
                 case GGML_TYPE_F16:
```
```diff
@@ -2086,6 +2191,13 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                     return false;
             }
         } break;
+        case GGML_OP_SET_ROWS:
+            {
+                // TODO: add support
+                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
+#pragma message("TODO: implement F32, F16, BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
+                return false;
+            } break;
         case GGML_OP_CPY: {
             ggml_tensor *src = op->src[0];
             if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
```
```diff
@@ -2182,12 +2294,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
-        case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_CLAMP:
         case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_SOFT_MAX:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
```
```diff
@@ -2205,6 +2315,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
             return true;
+        case GGML_OP_SCALE:
+            float bias;
+            memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
+            return bias == 0.0f; // TODO: support bias != 0.0f
+        case GGML_OP_SOFT_MAX:
+            // TODO: support broadcast
+            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
         case GGML_OP_FLASH_ATTN_EXT:{
             // derived from [ggml-cuda.cu]
             if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
```
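The new `GGML_OP_SCALE` guard works because ggml packs an op's scalar arguments into the tensor's `op_params` words: for SCALE, float 0 is the multiplier and float 1 is the bias added after scaling, so reading index 1 and requiring 0.0f rejects the fused scale-plus-bias form this backend cannot run yet. A sketch of that packing convention (`fake_tensor` and the helpers are hypothetical, for illustration only):

```cpp
#include <cstdint>
#include <cstring>

// ggml stores per-op scalar parameters in a small int32 array on the tensor;
// floats are type-punned into it with memcpy, never through pointer casts,
// which is why the backend reads the bias back with memcpy as well.
struct fake_tensor {
    int32_t op_params[16]; // same idea as ggml_tensor::op_params
};

static void set_scale_params(fake_tensor *t, float scale, float bias) {
    memcpy((char *)t->op_params + 0 * sizeof(float), &scale, sizeof(float));
    memcpy((char *)t->op_params + 1 * sizeof(float), &bias,  sizeof(float));
}

static float get_scale_bias(const fake_tensor *t) {
    float bias;
    memcpy(&bias, (const char *)t->op_params + 1 * sizeof(float), sizeof(float));
    return bias; // the CANN supports_op check returns true only if this is 0.0f
}
```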
```diff
@@ -2227,6 +2345,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 // DeepSeek MLA
                 return false;
             }
+            // TODO: support broadcast
+            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }
```
package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt:

```diff
@@ -5,7 +5,7 @@ function(ggml_add_cpu_backend_features cpu_name arch)
     # build, using set_source_files_properties() to set the arch flags is not possible
     set(GGML_CPU_FEATS_NAME ${cpu_name}-feats)
     add_library(${GGML_CPU_FEATS_NAME} OBJECT ggml-cpu/arch/${arch}/cpu-feats.cpp)
-    target_include_directories(${GGML_CPU_FEATS_NAME} PRIVATE .
+    target_include_directories(${GGML_CPU_FEATS_NAME} PRIVATE . ../include)
     target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARGN})
     target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED)
     set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
```
```diff
@@ -70,10 +70,12 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
     if (GGML_OPENMP)
         find_package(OpenMP)
         if (OpenMP_FOUND)
+            set(GGML_OPENMP_ENABLED "ON" CACHE INTERNAL "")
             target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP)
 
             target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
         else()
+            set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "")
             message(WARNING "OpenMP not found")
         endif()
     endif()
```
```diff
@@ -456,6 +458,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             list(APPEND ARCH_FLAGS -march=z16)
         elseif (${S390X_M} MATCHES "9175|9176")
             # NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version.
+            # binutils must also be updated to the latest for the -march=z17 flag to work. Otherwise, use -march=arch15.
             message(STATUS "z17 target")
             list(APPEND ARCH_FLAGS -march=z17)
         else()
```
```diff
@@ -494,9 +497,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
 
         # Fetch KleidiAI sources:
         include(FetchContent)
-        set(KLEIDIAI_COMMIT_TAG "v1.
+        set(KLEIDIAI_COMMIT_TAG "v1.11.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5 "
+        set(KLEIDIAI_ARCHIVE_MD5 "3fe9e5ab964c375c53839296eb71eaa2")
 
         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)
```
```diff
@@ -589,4 +592,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
     if (EMSCRIPTEN)
         set_target_properties(${GGML_CPU_NAME} PROPERTIES COMPILE_FLAGS "-msimd128")
     endif()
+
+    if (CMAKE_CXX_COMPILER_ID STREQUAL "IntelLLVM")
+        # The compiler automatically enables "-ffast-math" which can cause NaNs in tests due to "-fassociative-math"
+        target_compile_options(${GGML_CPU_NAME} PRIVATE "-fno-associative-math")
+    endif()
 endfunction()
```
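The `-fno-associative-math` addition guards against a real floating-point hazard: IEEE-754 addition is not associative, so a compiler that reorders reductions under fast-math can change results, and over large magnitude ranges the strict evaluation order can be the only thing keeping intermediates finite. A tiny self-contained illustration of the non-associativity itself:

```cpp
#include <cstdio>

int main() {
    float a = 1e20f, b = -1e20f, c = 1.0f;
    // Left-to-right: a and b cancel exactly first, then c survives.
    float left  = (a + b) + c; // 1
    // Reassociated: b + c rounds back to -1e20 (c is far below one ulp of b).
    float right = a + (b + c); // 0
    printf("%g vs %g\n", left, right); // prints "1 vs 0"
    return 0;
}
```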
package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c:

```diff
@@ -544,7 +544,7 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) );
     max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
     __m128 tmp = max4;
-    max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4,
+    max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x1 ));
     const float max_scalar = ((v4f32)max4)[0];
 
     // Quantize these floats
```
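The one-character fix above restores the lane-selector immediate of `__lsx_vextrins_w` in a horizontal-max reduction: the vector of per-lane absolute maxima is folded onto itself with shuffles and pairwise maxima until lane 0 holds the global maximum, which becomes the quantization scale. A scalar sketch of the reduction being computed (`hmax4` is an illustrative name):

```cpp
#include <algorithm>
#include <cstdio>

// The LSX sequence is a horizontal max: fold lanes onto each other with
// shuffles, taking an elementwise max at each step, until lane 0 holds
// the maximum of all four elements.
static float hmax4(const float v[4]) {
    float m01 = std::max(v[0], v[1]);
    float m23 = std::max(v[2], v[3]);
    return std::max(m01, m23);
}

int main() {
    float v[4] = {0.5f, -2.0f, 3.25f, 1.0f};
    printf("%g\n", hmax4(v)); // 3.25
    return 0;
}
```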
package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c:

```diff
@@ -1193,7 +1193,7 @@ static void ggml_compute_forward_mul_mat_one_chunk(
     }
 }
 
-
+void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
               struct ggml_tensor * dst) {
 
```
```diff
@@ -1866,6 +1866,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_im2col_back_f32(params, tensor);
             } break;
+        case GGML_OP_CONV_2D:
+            {
+                ggml_compute_forward_conv_2d(params, tensor);
+            } break;
         case GGML_OP_CONV_2D_DW:
             {
                 ggml_compute_forward_conv_2d_dw(params, tensor);
```
```diff
@@ -1949,6 +1953,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_unary(params, tensor);
             } break;
+        case GGML_OP_GLU:
+            {
+                ggml_compute_forward_glu(params, tensor);
+            } break;
         case GGML_OP_GET_REL_POS:
             {
                 ggml_compute_forward_get_rel_pos(params, tensor);
```
```diff
@@ -2159,6 +2167,20 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                     GGML_ABORT("fatal error");
             }
             break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(node)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
+                    {
+                        n_tasks = n_threads;
+                    } break;
+                default:
+                    GGML_ABORT("fatal error");
+            }
+            break;
         case GGML_OP_SILU_BACK:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
```
```diff
@@ -2212,6 +2234,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_IM2COL:
         case GGML_OP_IM2COL_BACK:
+        case GGML_OP_CONV_2D:
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_CONV_TRANSPOSE_1D:
         case GGML_OP_CONV_TRANSPOSE_2D:
```
```diff
@@ -2730,6 +2753,10 @@ struct ggml_cplan ggml_graph_plan(
                     GGML_ABORT("fatal error");
             }
         } break;
+        case GGML_OP_CONV_2D:
+            {
+                cur = GGML_IM2COL_WORK_SIZE;
+            } break;
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
                 const int64_t ne00 = node->src[0]->ne[0]; // W
```
package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp:

```diff
@@ -22,9 +22,94 @@
 
 #include "kai_common.h"
 
+#include "simd-mappings.h"
+
 #include "kernels.h"
 
 #define NELEMS(x) sizeof(x) / sizeof(*x)
+
+static const size_t INT4_PER_BYTE = 2;
+static const size_t INT4_BITS = 4;
+static const int Q4_0_ZERO_POINT = 8;
+const size_t INT4_PER_UINT16 = 4;
+
+static void dequantize_row_qsi4c32pscalef16(
+    const void *packed_data,
+    int32_t row_idx,
+    int64_t nc,
+    float *out,
+    size_t nr_pack,
+    size_t packed_row_stride,
+    size_t kr,
+    size_t bl,
+    size_t num_bytes_multiplier
+) {
+    size_t group_idx = row_idx / nr_pack;
+    size_t row_in_group = row_idx % nr_pack;
+    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
+    size_t num_blocks = nc / bl;
+    const uint8_t *block_ptr = packed_group;
+
+    for (size_t b = 0; b < num_blocks; ++b) {
+        uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier));
+        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
+
+        const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier;
+        size_t num_segments = bl / kr;
+        size_t num_bytes_per_segment = kr / INT4_PER_BYTE;
+
+        for (size_t s = 0; s < num_segments; ++s) {
+            const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment;
+            const uint8_t *qbytes = seg_base + row_in_group * num_bytes_per_segment;
+            for (size_t k = 0; k < num_bytes_per_segment; ++k) {
+                uint8_t byte = qbytes[k] ^ 0x88;
+                int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT;
+                int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT;
+                out[b * bl + s * num_bytes_per_segment + k] = x0 * scale;
+                out[b * bl + s * num_bytes_per_segment + k + bl/2] = x1 * scale;
+            }
+        }
+        block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment;
+    }
+}
+
+static void dequantize_row_qsi4c32ps1s0scalef16(
+    const void *packed_data,
+    int32_t row_idx,
+    int64_t k,
+    float *out,
+    size_t nr,
+    size_t packed_row_stride,
+    size_t kr,
+    size_t bl,
+    size_t num_bytes_multiplier
+) {
+    const size_t num_blocks = k / bl;
+    const size_t bl4 = bl / INT4_PER_UINT16;
+
+    size_t group_idx = row_idx / nr;
+    size_t row_in_group = row_idx % nr;
+
+    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
+    const uint16_t *qdata = (const uint16_t *)packed_group;
+    const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier));
+
+    for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) {
+        uint16_t scale_f16 = scales[row_in_group + block_idx * nr];
+        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
+
+        for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) {
+            uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group];
+
+            for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) {
+                int v = ((q >> (qidx * 4)) & 0xF) - Q4_0_ZERO_POINT;
+                out[block_idx * bl + bl4_idx * INT4_BITS + qidx] = v * scale;
+            }
+        }
+    }
+    GGML_UNUSED(kr);
+}
+
 static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
 #if defined(__ARM_FEATURE_SME)
 {
```
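Both dequantizers above recover signed 4-bit weights stored two per byte. The `^ 0x88` in the first flips the sign bit of both nibbles in a single operation, converting two's-complement nibbles into offset-binary so that subtracting the zero point 8 (`Q4_0_ZERO_POINT`) yields the signed value. A standalone sketch of that per-byte step (`unpack_int4_pair` is an illustrative name):

```cpp
#include <cstdint>
#include <cstdio>

// Two signed int4 values live in one byte. XOR-ing the byte with 0x88 flips
// the sign bit of both nibbles at once, turning two's-complement nibbles
// into offset-binary, so subtracting the zero point 8 recovers the value.
static void unpack_int4_pair(uint8_t packed, int *lo, int *hi) {
    uint8_t b = packed ^ 0x88;
    *lo = (b & 0x0F) - 8; // low nibble
    *hi = (b >> 4)   - 8; // high nibble
}

int main() {
    // 0xF3 packs low nibble 0x3 (= 3) and high nibble 0xF (= -1 in int4).
    int lo, hi;
    unpack_int4_pair(0xF3, &lo, &hi);
    printf("%d %d\n", lo, hi); // prints "3 -1"
    return 0;
}
```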
```diff
@@ -63,8 +148,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
         },
         /* .rhs_info = */ {
-            /* .packed_size
-            /* .
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16,
         },
         /* .required_cpu = */ CPU_FEATURE_SME,
         /* .lhs_type = */ GGML_TYPE_F32,
```

```diff
@@ -107,8 +194,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
         },
         /* .rhs_info = */ {
-            /* .packed_size
-            /* .
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .packed_stride = */ NULL,
+            /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .to_float = */ NULL,
         },
         /* .required_cpu = */ CPU_FEATURE_SME,
         /* .lhs_type = */ GGML_TYPE_F32,
```

```diff
@@ -154,8 +243,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size
-            /* .
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD,
         /* .lhs_type = */ GGML_TYPE_F32,
```

```diff
@@ -200,8 +291,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size
-            /* .
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
         /* .lhs_type = */ GGML_TYPE_F32,
```

```diff
@@ -247,8 +340,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size
-            /* .
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
         /* .lhs_type = */ GGML_TYPE_F32,
```

```diff
@@ -293,8 +388,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size
-            /* .
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu = */ CPU_FEATURE_DOTPROD,
         /* .lhs_type = */ GGML_TYPE_F32,
```
package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h:

```diff
@@ -71,12 +71,15 @@ struct rhs_packing_info {
         std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
         std::function<size_t(size_t n, size_t k)>
     > packed_size;
+    size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);
     std::variant<
         std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
                            const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
         std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
                            const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
     > pack_func;
+    void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride,
+                     size_t kr, size_t bl, size_t num_bytes_multiplier);
 };
 
 struct ggml_kleidiai_kernels {
```
|