@novastera-oss/llamarn 0.2.9 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/proguard-rules.pro +12 -0
- package/android/src/main/cpp/include/llama.h +15 -47
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/CMakePresets.json +11 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -0
- package/cpp/llama.cpp/README.md +8 -8
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +62 -1
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +22 -6
- package/cpp/llama.cpp/common/common.h +22 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
- package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
- package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
- package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
- package/cpp/llama.cpp/include/llama.h +15 -47
- package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
- package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
- package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
- package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
- package/cpp/llama.cpp/src/llama-arch.h +23 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
- package/cpp/llama.cpp/src/llama-batch.h +31 -18
- package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +180 -106
- package/cpp/llama.cpp/src/llama-context.h +26 -16
- package/cpp/llama.cpp/src/llama-cparams.h +3 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
- package/cpp/llama.cpp/src/llama-graph.h +184 -122
- package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
- package/cpp/llama.cpp/src/llama-hparams.h +13 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
- package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +3 -0
- package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
- package/cpp/llama.cpp/src/llama-model.h +21 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
- package/cpp/llama.cpp/src/llama-vocab.h +43 -0
- package/cpp/llama.cpp/src/unicode.cpp +207 -0
- package/cpp/llama.cpp/src/unicode.h +2 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +22 -4
- package/ios/include/llama.h +15 -47
- package/ios/libs/llama.xcframework/Info.plist +13 -13
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +4 -4
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -55,6 +55,12 @@ static struct ggml_backend_metal_device_context {
|
|
|
55
55
|
bool has_residency_sets;
|
|
56
56
|
bool has_bfloat;
|
|
57
57
|
bool use_bfloat;
|
|
58
|
+
bool use_fusion;
|
|
59
|
+
|
|
60
|
+
int debug_fusion;
|
|
61
|
+
|
|
62
|
+
// how many times a given op was fused
|
|
63
|
+
uint64_t fuse_cnt[GGML_OP_COUNT];
|
|
58
64
|
|
|
59
65
|
size_t max_size;
|
|
60
66
|
|
|
@@ -69,6 +75,9 @@ static struct ggml_backend_metal_device_context {
|
|
|
69
75
|
/*.has_residency_sets =*/ false,
|
|
70
76
|
/*.has_bfloat =*/ false,
|
|
71
77
|
/*.use_bfloat =*/ false,
|
|
78
|
+
/*.use_fusion =*/ true,
|
|
79
|
+
/*.debug_fusion =*/ 0,
|
|
80
|
+
/*.fuse_cnt =*/ { 0 },
|
|
72
81
|
/*.max_size =*/ 0,
|
|
73
82
|
/*.name =*/ "",
|
|
74
83
|
};
|
|
@@ -83,16 +92,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
|
|
|
83
92
|
|
|
84
93
|
if (ctx->mtl_device == nil) {
|
|
85
94
|
ctx->mtl_device = MTLCreateSystemDefaultDevice();
|
|
86
|
-
}
|
|
87
95
|
|
|
88
|
-
if (ctx->mtl_device) {
|
|
89
96
|
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
|
90
97
|
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
|
91
98
|
|
|
92
99
|
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
|
93
100
|
|
|
94
101
|
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
|
|
95
|
-
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") ==
|
|
102
|
+
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
|
|
96
103
|
#endif
|
|
97
104
|
|
|
98
105
|
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
|
@@ -103,6 +110,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
|
|
|
103
110
|
#else
|
|
104
111
|
ctx->use_bfloat = false;
|
|
105
112
|
#endif
|
|
113
|
+
ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
|
|
114
|
+
|
|
115
|
+
{
|
|
116
|
+
const char * val = getenv("GGML_METAL_FUSION_DEBUG");
|
|
117
|
+
ctx->debug_fusion = val ? atoi(val) : 0;
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
|
|
106
121
|
|
|
107
122
|
ctx->max_size = ctx->mtl_device.maxBufferLength;
|
|
108
123
|
|
|
@@ -122,6 +137,18 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
|
|
|
122
137
|
ctx->mtl_device_ref_count--;
|
|
123
138
|
|
|
124
139
|
if (ctx->mtl_device_ref_count == 0) {
|
|
140
|
+
if (ctx->debug_fusion > 0) {
|
|
141
|
+
fprintf(stderr, "%s: fusion stats:\n", __func__);
|
|
142
|
+
for (int i = 0; i < GGML_OP_COUNT; i++) {
|
|
143
|
+
if (ctx->fuse_cnt[i] == 0) {
|
|
144
|
+
continue;
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
// note: cannot use ggml_log here
|
|
148
|
+
fprintf(stderr, "%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
|
|
149
|
+
}
|
|
150
|
+
}
|
|
151
|
+
|
|
125
152
|
if (ctx->mtl_lock) {
|
|
126
153
|
[ctx->mtl_lock release];
|
|
127
154
|
ctx->mtl_lock = nil;
|
|
@@ -147,13 +174,27 @@ struct ggml_metal_kernel {
|
|
|
147
174
|
|
|
148
175
|
enum ggml_metal_kernel_type {
|
|
149
176
|
GGML_METAL_KERNEL_TYPE_ADD,
|
|
150
|
-
|
|
177
|
+
GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
|
|
178
|
+
GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
|
|
179
|
+
GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
|
|
180
|
+
GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
|
|
181
|
+
GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
|
|
182
|
+
GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
|
|
183
|
+
GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
|
|
184
|
+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
|
|
185
|
+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
|
|
186
|
+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
|
|
187
|
+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
|
|
188
|
+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
|
|
189
|
+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
|
|
190
|
+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
|
|
191
|
+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
|
|
151
192
|
GGML_METAL_KERNEL_TYPE_SUB,
|
|
152
|
-
|
|
193
|
+
GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
|
|
153
194
|
GGML_METAL_KERNEL_TYPE_MUL,
|
|
154
|
-
|
|
195
|
+
GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
|
|
155
196
|
GGML_METAL_KERNEL_TYPE_DIV,
|
|
156
|
-
|
|
197
|
+
GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
|
|
157
198
|
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
|
|
158
199
|
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
|
|
159
200
|
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
|
|
@@ -173,6 +214,12 @@ enum ggml_metal_kernel_type {
|
|
|
173
214
|
GGML_METAL_KERNEL_TYPE_SILU,
|
|
174
215
|
GGML_METAL_KERNEL_TYPE_SILU_4,
|
|
175
216
|
GGML_METAL_KERNEL_TYPE_ELU,
|
|
217
|
+
GGML_METAL_KERNEL_TYPE_ABS,
|
|
218
|
+
GGML_METAL_KERNEL_TYPE_SGN,
|
|
219
|
+
GGML_METAL_KERNEL_TYPE_STEP,
|
|
220
|
+
GGML_METAL_KERNEL_TYPE_HARDSWISH,
|
|
221
|
+
GGML_METAL_KERNEL_TYPE_HARDSIGMOID,
|
|
222
|
+
GGML_METAL_KERNEL_TYPE_EXP,
|
|
176
223
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
|
177
224
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
|
178
225
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
|
@@ -212,11 +259,14 @@ enum ggml_metal_kernel_type {
|
|
|
212
259
|
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
|
|
213
260
|
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
|
|
214
261
|
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
|
262
|
+
GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
|
|
263
|
+
GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
|
|
215
264
|
GGML_METAL_KERNEL_TYPE_L2_NORM,
|
|
216
265
|
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
|
217
266
|
GGML_METAL_KERNEL_TYPE_NORM,
|
|
218
267
|
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
|
219
268
|
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
|
269
|
+
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
|
|
220
270
|
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
|
|
221
271
|
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
|
|
222
272
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
|
@@ -526,6 +576,11 @@ enum ggml_metal_kernel_type {
|
|
|
526
576
|
GGML_METAL_KERNEL_TYPE_SIN,
|
|
527
577
|
GGML_METAL_KERNEL_TYPE_COS,
|
|
528
578
|
GGML_METAL_KERNEL_TYPE_NEG,
|
|
579
|
+
GGML_METAL_KERNEL_TYPE_REGLU,
|
|
580
|
+
GGML_METAL_KERNEL_TYPE_GEGLU,
|
|
581
|
+
GGML_METAL_KERNEL_TYPE_SWIGLU,
|
|
582
|
+
GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
|
|
583
|
+
GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
|
|
529
584
|
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
|
530
585
|
GGML_METAL_KERNEL_TYPE_MEAN,
|
|
531
586
|
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
|
@@ -1123,13 +1178,27 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
1123
1178
|
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
|
1124
1179
|
|
|
1125
1180
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
|
1126
|
-
GGML_METAL_ADD_KERNEL(
|
|
1181
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true);
|
|
1182
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true);
|
|
1183
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true);
|
|
1184
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true);
|
|
1185
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true);
|
|
1186
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true);
|
|
1187
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true);
|
|
1188
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true);
|
|
1189
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true);
|
|
1190
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true);
|
|
1191
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true);
|
|
1192
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true);
|
|
1193
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true);
|
|
1194
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true);
|
|
1195
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true);
|
|
1127
1196
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
|
|
1128
|
-
GGML_METAL_ADD_KERNEL(
|
|
1197
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true);
|
|
1129
1198
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
|
1130
|
-
GGML_METAL_ADD_KERNEL(
|
|
1199
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
|
|
1131
1200
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
|
1132
|
-
GGML_METAL_ADD_KERNEL(
|
|
1201
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
|
|
1133
1202
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
|
|
1134
1203
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
|
|
1135
1204
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
|
|
@@ -1149,6 +1218,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
1149
1218
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
|
1150
1219
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
|
1151
1220
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
|
|
1221
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS, abs, true);
|
|
1222
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN, sgn, true);
|
|
1223
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP, step, true);
|
|
1224
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH, hardswish, true);
|
|
1225
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID, hardsigmoid, true);
|
|
1226
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP, exp, true);
|
|
1152
1227
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
|
|
1153
1228
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
|
|
1154
1229
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
|
|
@@ -1188,11 +1263,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
1188
1263
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
|
|
1189
1264
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
|
|
1190
1265
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
|
1266
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
|
|
1267
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
|
|
1191
1268
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
|
|
1192
1269
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
|
1193
1270
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
|
1194
1271
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
|
1195
1272
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
|
1273
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
|
|
1196
1274
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
|
|
1197
1275
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
|
|
1198
1276
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
|
@@ -1502,6 +1580,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
1502
1580
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
|
1503
1581
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
|
1504
1582
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
|
1583
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
|
|
1584
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
|
|
1585
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
|
|
1586
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
|
|
1587
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
|
|
1505
1588
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
|
1506
1589
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
|
1507
1590
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
|
@@ -1676,10 +1759,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
1676
1759
|
case GGML_UNARY_OP_SILU:
|
|
1677
1760
|
case GGML_UNARY_OP_ELU:
|
|
1678
1761
|
case GGML_UNARY_OP_NEG:
|
|
1762
|
+
case GGML_UNARY_OP_ABS:
|
|
1763
|
+
case GGML_UNARY_OP_SGN:
|
|
1764
|
+
case GGML_UNARY_OP_STEP:
|
|
1765
|
+
case GGML_UNARY_OP_HARDSWISH:
|
|
1766
|
+
case GGML_UNARY_OP_HARDSIGMOID:
|
|
1767
|
+
case GGML_UNARY_OP_EXP:
|
|
1679
1768
|
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
|
1680
1769
|
default:
|
|
1681
1770
|
return false;
|
|
1682
1771
|
}
|
|
1772
|
+
case GGML_OP_GLU:
|
|
1773
|
+
switch (ggml_get_glu_op(op)) {
|
|
1774
|
+
case GGML_GLU_OP_REGLU:
|
|
1775
|
+
case GGML_GLU_OP_GEGLU:
|
|
1776
|
+
case GGML_GLU_OP_SWIGLU:
|
|
1777
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
1778
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
1779
|
+
return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
|
1780
|
+
default:
|
|
1781
|
+
return false;
|
|
1782
|
+
}
|
|
1683
1783
|
case GGML_OP_NONE:
|
|
1684
1784
|
case GGML_OP_RESHAPE:
|
|
1685
1785
|
case GGML_OP_VIEW:
|
|
@@ -1710,7 +1810,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
1710
1810
|
case GGML_OP_MEAN:
|
|
1711
1811
|
case GGML_OP_SOFT_MAX:
|
|
1712
1812
|
case GGML_OP_GROUP_NORM:
|
|
1713
|
-
return has_simdgroup_reduction &&
|
|
1813
|
+
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
|
1714
1814
|
case GGML_OP_RMS_NORM:
|
|
1715
1815
|
case GGML_OP_L2_NORM:
|
|
1716
1816
|
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
|
@@ -1852,9 +1952,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
1852
1952
|
}
|
|
1853
1953
|
}
|
|
1854
1954
|
|
|
1855
|
-
static
|
|
1955
|
+
static int ggml_metal_encode_node(
|
|
1856
1956
|
ggml_backend_t backend,
|
|
1857
1957
|
int idx,
|
|
1958
|
+
int idx_end,
|
|
1858
1959
|
id<MTLComputeCommandEncoder> encoder,
|
|
1859
1960
|
struct ggml_metal_mem_pool * mem_pool) {
|
|
1860
1961
|
struct ggml_backend_metal_context * ctx = backend->context;
|
|
@@ -1862,7 +1963,10 @@ static bool ggml_metal_encode_node(
|
|
|
1862
1963
|
|
|
1863
1964
|
struct ggml_cgraph * gf = ctx->gf;
|
|
1864
1965
|
|
|
1865
|
-
|
|
1966
|
+
enum ggml_op ops[8];
|
|
1967
|
+
|
|
1968
|
+
struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx;
|
|
1969
|
+
struct ggml_tensor * node = nodes[0];
|
|
1866
1970
|
|
|
1867
1971
|
//GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
|
|
1868
1972
|
|
|
@@ -1872,7 +1976,7 @@ static bool ggml_metal_encode_node(
|
|
|
1872
1976
|
struct ggml_tensor * dst = node;
|
|
1873
1977
|
|
|
1874
1978
|
if (ggml_is_empty(dst)) {
|
|
1875
|
-
return
|
|
1979
|
+
return 1;
|
|
1876
1980
|
}
|
|
1877
1981
|
|
|
1878
1982
|
switch (dst->op) {
|
|
@@ -1883,7 +1987,7 @@ static bool ggml_metal_encode_node(
|
|
|
1883
1987
|
case GGML_OP_PERMUTE:
|
|
1884
1988
|
{
|
|
1885
1989
|
// noop -> next node
|
|
1886
|
-
} return
|
|
1990
|
+
} return 1;
|
|
1887
1991
|
default:
|
|
1888
1992
|
{
|
|
1889
1993
|
} break;
|
|
@@ -1950,6 +2054,8 @@ static bool ggml_metal_encode_node(
|
|
|
1950
2054
|
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
|
|
1951
2055
|
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
|
|
1952
2056
|
|
|
2057
|
+
int n_fuse = 1;
|
|
2058
|
+
|
|
1953
2059
|
#if 0
|
|
1954
2060
|
GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
|
|
1955
2061
|
if (src0) {
|
|
@@ -2021,37 +2127,15 @@ static bool ggml_metal_encode_node(
|
|
|
2021
2127
|
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
|
2022
2128
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
|
2023
2129
|
|
|
2130
|
+
GGML_ASSERT(ggml_is_contiguous_rows(src0));
|
|
2131
|
+
GGML_ASSERT(ggml_is_contiguous_rows(src1));
|
|
2132
|
+
|
|
2024
2133
|
const size_t offs = 0;
|
|
2025
2134
|
|
|
2026
2135
|
bool bcast_row = false;
|
|
2027
2136
|
|
|
2028
2137
|
id<MTLComputePipelineState> pipeline = nil;
|
|
2029
2138
|
|
|
2030
|
-
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
|
2031
|
-
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
2032
|
-
|
|
2033
|
-
// src1 is a row
|
|
2034
|
-
GGML_ASSERT(ne11 == 1);
|
|
2035
|
-
|
|
2036
|
-
switch (dst->op) {
|
|
2037
|
-
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
|
|
2038
|
-
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
|
|
2039
|
-
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
|
|
2040
|
-
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
|
|
2041
|
-
default: GGML_ABORT("fatal error");
|
|
2042
|
-
}
|
|
2043
|
-
|
|
2044
|
-
bcast_row = true;
|
|
2045
|
-
} else {
|
|
2046
|
-
switch (dst->op) {
|
|
2047
|
-
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
|
|
2048
|
-
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
|
|
2049
|
-
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
|
|
2050
|
-
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
|
|
2051
|
-
default: GGML_ABORT("fatal error");
|
|
2052
|
-
}
|
|
2053
|
-
}
|
|
2054
|
-
|
|
2055
2139
|
ggml_metal_kargs_bin args = {
|
|
2056
2140
|
/*.ne00 =*/ ne00,
|
|
2057
2141
|
/*.ne01 =*/ ne01,
|
|
@@ -2078,12 +2162,119 @@ static bool ggml_metal_encode_node(
|
|
|
2078
2162
|
/*.nb2 =*/ nb2,
|
|
2079
2163
|
/*.nb3 =*/ nb3,
|
|
2080
2164
|
/*.offs =*/ offs,
|
|
2165
|
+
/*.o1 =*/ { offs_src1 },
|
|
2081
2166
|
};
|
|
2082
2167
|
|
|
2168
|
+
// c[0] = add(a, b[0])
|
|
2169
|
+
// c[1] = add(c[0], b[1])
|
|
2170
|
+
// c[2] = add(c[1], b[2])
|
|
2171
|
+
// ...
|
|
2172
|
+
if (ctx_dev->use_fusion) {
|
|
2173
|
+
ops[0] = GGML_OP_ADD;
|
|
2174
|
+
ops[1] = GGML_OP_ADD;
|
|
2175
|
+
ops[2] = GGML_OP_ADD;
|
|
2176
|
+
ops[3] = GGML_OP_ADD;
|
|
2177
|
+
ops[4] = GGML_OP_ADD;
|
|
2178
|
+
ops[5] = GGML_OP_ADD;
|
|
2179
|
+
ops[6] = GGML_OP_ADD;
|
|
2180
|
+
ops[7] = GGML_OP_ADD;
|
|
2181
|
+
|
|
2182
|
+
size_t offs_fuse;
|
|
2183
|
+
id<MTLBuffer> id_fuse;
|
|
2184
|
+
|
|
2185
|
+
// note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes
|
|
2186
|
+
// across splits. idx_end indicates the last node in the current split
|
|
2187
|
+
for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
|
|
2188
|
+
if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
|
|
2189
|
+
break;
|
|
2190
|
+
}
|
|
2191
|
+
|
|
2192
|
+
if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
|
|
2193
|
+
break;
|
|
2194
|
+
}
|
|
2195
|
+
|
|
2196
|
+
// b[0] === b[1] === ...
|
|
2197
|
+
if (!ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) {
|
|
2198
|
+
break;
|
|
2199
|
+
}
|
|
2200
|
+
|
|
2201
|
+
// only fuse nodes if src1 is in the same Metal buffer
|
|
2202
|
+
id_fuse = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse);
|
|
2203
|
+
if (id_fuse != id_src1) {
|
|
2204
|
+
break;
|
|
2205
|
+
}
|
|
2206
|
+
|
|
2207
|
+
ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
|
|
2208
|
+
|
|
2209
|
+
args.o1[n_fuse + 1] = offs_fuse;
|
|
2210
|
+
}
|
|
2211
|
+
|
|
2212
|
+
++n_fuse;
|
|
2213
|
+
|
|
2214
|
+
if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
|
|
2215
|
+
GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
|
|
2216
|
+
}
|
|
2217
|
+
}
|
|
2218
|
+
|
|
2219
|
+
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
|
2220
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
2221
|
+
|
|
2222
|
+
// src1 is a row
|
|
2223
|
+
GGML_ASSERT(ne11 == 1);
|
|
2224
|
+
|
|
2225
|
+
switch (dst->op) {
|
|
2226
|
+
case GGML_OP_ADD:
|
|
2227
|
+
{
|
|
2228
|
+
switch (n_fuse) {
|
|
2229
|
+
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline; break;
|
|
2230
|
+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break;
|
|
2231
|
+
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break;
|
|
2232
|
+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break;
|
|
2233
|
+
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break;
|
|
2234
|
+
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break;
|
|
2235
|
+
case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break;
|
|
2236
|
+
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break;
|
|
2237
|
+
default: GGML_ABORT("fatal error");
|
|
2238
|
+
}
|
|
2239
|
+
} break;
|
|
2240
|
+
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break;
|
|
2241
|
+
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break;
|
|
2242
|
+
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break;
|
|
2243
|
+
default: GGML_ABORT("fatal error");
|
|
2244
|
+
}
|
|
2245
|
+
|
|
2246
|
+
bcast_row = true;
|
|
2247
|
+
} else {
|
|
2248
|
+
switch (dst->op) {
|
|
2249
|
+
case GGML_OP_ADD:
|
|
2250
|
+
{
|
|
2251
|
+
switch (n_fuse) {
|
|
2252
|
+
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD ].pipeline; break;
|
|
2253
|
+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
|
|
2254
|
+
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break;
|
|
2255
|
+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
|
|
2256
|
+
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break;
|
|
2257
|
+
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
|
|
2258
|
+
case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break;
|
|
2259
|
+
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
|
|
2260
|
+
default: GGML_ABORT("fatal error");
|
|
2261
|
+
}
|
|
2262
|
+
} break;
|
|
2263
|
+
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
|
|
2264
|
+
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
|
|
2265
|
+
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
|
|
2266
|
+
default: GGML_ABORT("fatal error");
|
|
2267
|
+
}
|
|
2268
|
+
}
|
|
2269
|
+
|
|
2270
|
+
if (n_fuse > 1) {
|
|
2271
|
+
id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
|
|
2272
|
+
}
|
|
2273
|
+
|
|
2083
2274
|
[encoder setComputePipelineState:pipeline];
|
|
2084
2275
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
2085
2276
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
2086
|
-
[encoder setBuffer:id_src1 offset:
|
|
2277
|
+
[encoder setBuffer:id_src1 offset:0 atIndex:2];
|
|
2087
2278
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
2088
2279
|
|
|
2089
2280
|
if (bcast_row) {
|
|
@@ -2091,7 +2282,11 @@ static bool ggml_metal_encode_node(
|
|
|
2091
2282
|
|
|
2092
2283
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2093
2284
|
} else {
|
|
2094
|
-
|
|
2285
|
+
int nth = 32;
|
|
2286
|
+
|
|
2287
|
+
while (16*nth < ne0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
2288
|
+
nth *= 2;
|
|
2289
|
+
}
|
|
2095
2290
|
|
|
2096
2291
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2097
2292
|
}
|
|
@@ -2216,12 +2411,13 @@ static bool ggml_metal_encode_node(
|
|
|
2216
2411
|
/*.nb2 =*/ pnb2,
|
|
2217
2412
|
/*.nb3 =*/ pnb3,
|
|
2218
2413
|
/*.offs =*/ offs,
|
|
2414
|
+
/*.o1 =*/ { offs_src1},
|
|
2219
2415
|
};
|
|
2220
2416
|
|
|
2221
2417
|
[encoder setComputePipelineState:pipeline];
|
|
2222
2418
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
2223
2419
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
2224
|
-
[encoder setBuffer:id_src1 offset:
|
|
2420
|
+
[encoder setBuffer:id_src1 offset:0 atIndex:2];
|
|
2225
2421
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
2226
2422
|
|
|
2227
2423
|
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
|
|
@@ -2233,7 +2429,9 @@ static bool ggml_metal_encode_node(
|
|
|
2233
2429
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
2234
2430
|
|
|
2235
2431
|
float scale;
|
|
2236
|
-
|
|
2432
|
+
float bias;
|
|
2433
|
+
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float));
|
|
2434
|
+
memcpy(&bias, ((const int32_t *) dst->op_params) + 1, sizeof(float));
|
|
2237
2435
|
|
|
2238
2436
|
int64_t n = ggml_nelements(dst);
|
|
2239
2437
|
|
|
@@ -2250,6 +2448,7 @@ static bool ggml_metal_encode_node(
|
|
|
2250
2448
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2251
2449
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2252
2450
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
|
2451
|
+
[encoder setBytes:&bias length:sizeof(bias) atIndex:3];
|
|
2253
2452
|
|
|
2254
2453
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2255
2454
|
} break;
|
|
@@ -2413,12 +2612,146 @@ static bool ggml_metal_encode_node(
|
|
|
2413
2612
|
|
|
2414
2613
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2415
2614
|
} break;
|
|
2615
|
+
case GGML_UNARY_OP_ABS:
|
|
2616
|
+
{
|
|
2617
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline;
|
|
2618
|
+
|
|
2619
|
+
[encoder setComputePipelineState:pipeline];
|
|
2620
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2621
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2622
|
+
|
|
2623
|
+
const int64_t n = ggml_nelements(dst);
|
|
2624
|
+
|
|
2625
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2626
|
+
} break;
|
|
2627
|
+
case GGML_UNARY_OP_SGN:
|
|
2628
|
+
{
|
|
2629
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline;
|
|
2630
|
+
|
|
2631
|
+
[encoder setComputePipelineState:pipeline];
|
|
2632
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2633
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2634
|
+
|
|
2635
|
+
const int64_t n = ggml_nelements(dst);
|
|
2636
|
+
|
|
2637
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2638
|
+
} break;
|
|
2639
|
+
case GGML_UNARY_OP_STEP:
|
|
2640
|
+
{
|
|
2641
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline;
|
|
2642
|
+
|
|
2643
|
+
[encoder setComputePipelineState:pipeline];
|
|
2644
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2645
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2646
|
+
|
|
2647
|
+
const int64_t n = ggml_nelements(dst);
|
|
2648
|
+
|
|
2649
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2650
|
+
} break;
|
|
2651
|
+
case GGML_UNARY_OP_HARDSWISH:
|
|
2652
|
+
{
|
|
2653
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline;
|
|
2654
|
+
|
|
2655
|
+
[encoder setComputePipelineState:pipeline];
|
|
2656
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2657
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2658
|
+
|
|
2659
|
+
const int64_t n = ggml_nelements(dst);
|
|
2660
|
+
|
|
2661
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2662
|
+
} break;
|
|
2663
|
+
case GGML_UNARY_OP_HARDSIGMOID:
|
|
2664
|
+
{
|
|
2665
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline;
|
|
2666
|
+
|
|
2667
|
+
[encoder setComputePipelineState:pipeline];
|
|
2668
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2669
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2670
|
+
|
|
2671
|
+
const int64_t n = ggml_nelements(dst);
|
|
2672
|
+
|
|
2673
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2674
|
+
} break;
|
|
2675
|
+
case GGML_UNARY_OP_EXP:
|
|
2676
|
+
{
|
|
2677
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline;
|
|
2678
|
+
|
|
2679
|
+
[encoder setComputePipelineState:pipeline];
|
|
2680
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2681
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2682
|
+
|
|
2683
|
+
const int64_t n = ggml_nelements(dst);
|
|
2684
|
+
|
|
2685
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2686
|
+
} break;
|
|
2416
2687
|
default:
|
|
2417
2688
|
{
|
|
2418
2689
|
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
|
2419
2690
|
GGML_ABORT("fatal error");
|
|
2420
2691
|
}
|
|
2421
2692
|
} break;
|
|
2693
|
+
case GGML_OP_GLU:
|
|
2694
|
+
{
|
|
2695
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
2696
|
+
|
|
2697
|
+
if (src1) {
|
|
2698
|
+
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
|
2699
|
+
}
|
|
2700
|
+
|
|
2701
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2702
|
+
|
|
2703
|
+
switch (ggml_get_glu_op(node)) {
|
|
2704
|
+
case GGML_GLU_OP_REGLU:
|
|
2705
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
|
|
2706
|
+
break;
|
|
2707
|
+
case GGML_GLU_OP_GEGLU:
|
|
2708
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
|
|
2709
|
+
break;
|
|
2710
|
+
case GGML_GLU_OP_SWIGLU:
|
|
2711
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
|
|
2712
|
+
break;
|
|
2713
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
2714
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
|
|
2715
|
+
break;
|
|
2716
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
2717
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline;
|
|
2718
|
+
break;
|
|
2719
|
+
default:
|
|
2720
|
+
GGML_ABORT("fatal error");
|
|
2721
|
+
}
|
|
2722
|
+
|
|
2723
|
+
const int32_t swp = ((const int32_t *) dst->op_params)[1];
|
|
2724
|
+
|
|
2725
|
+
const int32_t i00 = swp ? ne0 : 0;
|
|
2726
|
+
const int32_t i10 = swp ? 0 : ne0;
|
|
2727
|
+
|
|
2728
|
+
ggml_metal_kargs_glu args = {
|
|
2729
|
+
/*.ne00 =*/ ne00,
|
|
2730
|
+
/*.nb01 =*/ nb01,
|
|
2731
|
+
/*.ne10 =*/ src1 ? ne10 : ne00,
|
|
2732
|
+
/*.nb11 =*/ src1 ? nb11 : nb01,
|
|
2733
|
+
/*.ne0 =*/ ne0,
|
|
2734
|
+
/*.nb1 =*/ nb1,
|
|
2735
|
+
/*.i00 =*/ src1 ? 0 : i00,
|
|
2736
|
+
/*.i10 =*/ src1 ? 0 : i10,
|
|
2737
|
+
};
|
|
2738
|
+
|
|
2739
|
+
[encoder setComputePipelineState:pipeline];
|
|
2740
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2741
|
+
if (src1) {
|
|
2742
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
2743
|
+
} else {
|
|
2744
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
2745
|
+
}
|
|
2746
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
2747
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
|
2748
|
+
|
|
2749
|
+
const int64_t nrows = ggml_nrows(src0);
|
|
2750
|
+
|
|
2751
|
+
const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2);
|
|
2752
|
+
|
|
2753
|
+
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2754
|
+
} break;
|
|
2422
2755
|
case GGML_OP_SQR:
|
|
2423
2756
|
{
|
|
2424
2757
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
@@ -2573,10 +2906,7 @@ static bool ggml_metal_encode_node(
|
|
|
2573
2906
|
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
|
|
2574
2907
|
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
|
2575
2908
|
|
|
2576
|
-
const
|
|
2577
|
-
const int64_t nrows_y = src0->ne[1];
|
|
2578
|
-
|
|
2579
|
-
const uint32_t n_head = nrows_x/nrows_y;
|
|
2909
|
+
const uint32_t n_head = src0->ne[2];
|
|
2580
2910
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
|
2581
2911
|
|
|
2582
2912
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
@@ -2589,7 +2919,7 @@ static bool ggml_metal_encode_node(
|
|
|
2589
2919
|
id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
|
|
2590
2920
|
if (!h_src0) {
|
|
2591
2921
|
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
|
|
2592
|
-
return
|
|
2922
|
+
return 0;
|
|
2593
2923
|
}
|
|
2594
2924
|
|
|
2595
2925
|
offs_src0 = 0;
|
|
@@ -2636,6 +2966,18 @@ static bool ggml_metal_encode_node(
|
|
|
2636
2966
|
/*.ne00 =*/ ne00,
|
|
2637
2967
|
/*.ne01 =*/ ne01,
|
|
2638
2968
|
/*.ne02 =*/ ne02,
|
|
2969
|
+
/*.nb01 =*/ nb01,
|
|
2970
|
+
/*.nb02 =*/ nb02,
|
|
2971
|
+
/*.nb03 =*/ nb03,
|
|
2972
|
+
/*.ne11 =*/ ne11,
|
|
2973
|
+
/*.ne12 =*/ ne12,
|
|
2974
|
+
/*.ne13 =*/ ne13,
|
|
2975
|
+
/*.nb11 =*/ nb11,
|
|
2976
|
+
/*.nb12 =*/ nb12,
|
|
2977
|
+
/*.nb13 =*/ nb13,
|
|
2978
|
+
/*.nb1 =*/ nb1,
|
|
2979
|
+
/*.nb2 =*/ nb2,
|
|
2980
|
+
/*.nb3 =*/ nb3,
|
|
2639
2981
|
/*.scale =*/ scale,
|
|
2640
2982
|
/*.max_bias =*/ max_bias,
|
|
2641
2983
|
/*.m0 =*/ m0,
|
|
@@ -2655,7 +2997,7 @@ static bool ggml_metal_encode_node(
|
|
|
2655
2997
|
|
|
2656
2998
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
2657
2999
|
|
|
2658
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01
|
|
3000
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2659
3001
|
} break;
|
|
2660
3002
|
case GGML_OP_DIAG_MASK_INF:
|
|
2661
3003
|
{
|
|
@@ -2729,71 +3071,92 @@ static bool ggml_metal_encode_node(
|
|
|
2729
3071
|
struct ggml_tensor * src3 = node->src[3];
|
|
2730
3072
|
struct ggml_tensor * src4 = node->src[4];
|
|
2731
3073
|
struct ggml_tensor * src5 = node->src[5];
|
|
3074
|
+
struct ggml_tensor * src6 = node->src[6];
|
|
2732
3075
|
|
|
2733
3076
|
GGML_ASSERT(src3);
|
|
2734
3077
|
GGML_ASSERT(src4);
|
|
2735
3078
|
GGML_ASSERT(src5);
|
|
3079
|
+
GGML_ASSERT(src6);
|
|
2736
3080
|
|
|
2737
3081
|
size_t offs_src3 = 0;
|
|
2738
3082
|
size_t offs_src4 = 0;
|
|
2739
3083
|
size_t offs_src5 = 0;
|
|
3084
|
+
size_t offs_src6 = 0;
|
|
2740
3085
|
|
|
2741
3086
|
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
|
|
2742
3087
|
id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
|
|
2743
3088
|
id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
|
|
3089
|
+
id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
|
|
2744
3090
|
|
|
2745
|
-
const int64_t ne30 = src3->ne[0];
|
|
3091
|
+
const int64_t ne30 = src3->ne[0];
|
|
2746
3092
|
const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
|
|
2747
3093
|
|
|
2748
|
-
const uint64_t nb30 = src3->nb[0];
|
|
3094
|
+
const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
|
|
2749
3095
|
const uint64_t nb31 = src3->nb[1];
|
|
2750
3096
|
|
|
2751
3097
|
const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
|
|
2752
|
-
const int64_t ne41 = src4->ne[1];
|
|
3098
|
+
const int64_t ne41 = src4->ne[1];
|
|
2753
3099
|
const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
|
|
3100
|
+
const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);
|
|
2754
3101
|
|
|
2755
|
-
const uint64_t nb40 = src4->nb[0];
|
|
3102
|
+
const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
|
|
2756
3103
|
const uint64_t nb41 = src4->nb[1];
|
|
2757
3104
|
const uint64_t nb42 = src4->nb[2];
|
|
3105
|
+
const uint64_t nb43 = src4->nb[3];
|
|
2758
3106
|
|
|
2759
3107
|
const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
|
|
2760
3108
|
const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
|
|
2761
3109
|
const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
|
|
3110
|
+
const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);
|
|
2762
3111
|
|
|
2763
|
-
const uint64_t nb50 = src5->nb[0];
|
|
3112
|
+
const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
|
|
2764
3113
|
const uint64_t nb51 = src5->nb[1];
|
|
2765
3114
|
const uint64_t nb52 = src5->nb[2];
|
|
3115
|
+
const uint64_t nb53 = src5->nb[3];
|
|
3116
|
+
|
|
3117
|
+
const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);
|
|
3118
|
+
|
|
3119
|
+
const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
|
|
2766
3120
|
|
|
2767
3121
|
const int64_t d_state = ne00;
|
|
2768
3122
|
const int64_t d_inner = ne01;
|
|
2769
|
-
const int64_t
|
|
2770
|
-
const int64_t
|
|
3123
|
+
const int64_t n_head = ne02;
|
|
3124
|
+
const int64_t n_group = ne41;
|
|
3125
|
+
const int64_t n_seq_tokens = ne12;
|
|
3126
|
+
const int64_t n_seqs = ne13;
|
|
2771
3127
|
|
|
2772
|
-
id<MTLComputePipelineState> pipeline =
|
|
3128
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
3129
|
+
|
|
3130
|
+
if (ne30 == 1) {
|
|
3131
|
+
// Mamba-2
|
|
3132
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
|
|
3133
|
+
} else {
|
|
3134
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
|
|
3135
|
+
}
|
|
2773
3136
|
|
|
2774
3137
|
ggml_metal_kargs_ssm_scan args = {
|
|
2775
|
-
/*.d_state
|
|
2776
|
-
/*.d_inner
|
|
3138
|
+
/*.d_state =*/ d_state,
|
|
3139
|
+
/*.d_inner =*/ d_inner,
|
|
3140
|
+
/*.n_head =*/ n_head,
|
|
3141
|
+
/*.n_group =*/ n_group,
|
|
2777
3142
|
/*.n_seq_tokens =*/ n_seq_tokens,
|
|
2778
|
-
/*.n_seqs
|
|
2779
|
-
/*.
|
|
2780
|
-
/*.nb01
|
|
2781
|
-
/*.nb02
|
|
2782
|
-
/*.
|
|
2783
|
-
/*.nb11
|
|
2784
|
-
/*.nb12
|
|
2785
|
-
/*.nb13
|
|
2786
|
-
/*.
|
|
2787
|
-
/*.
|
|
2788
|
-
/*.
|
|
2789
|
-
/*.
|
|
2790
|
-
/*.
|
|
2791
|
-
/*.
|
|
2792
|
-
/*.
|
|
2793
|
-
/*.
|
|
2794
|
-
/*.
|
|
2795
|
-
/*.nb51 =*/ nb51,
|
|
2796
|
-
/*.nb52 =*/ nb52,
|
|
3143
|
+
/*.n_seqs =*/ n_seqs,
|
|
3144
|
+
/*.s_off =*/ ggml_nelements(src1) * sizeof(float),
|
|
3145
|
+
/*.nb01 =*/ nb01,
|
|
3146
|
+
/*.nb02 =*/ nb02,
|
|
3147
|
+
/*.nb03 =*/ nb03,
|
|
3148
|
+
/*.nb11 =*/ nb11,
|
|
3149
|
+
/*.nb12 =*/ nb12,
|
|
3150
|
+
/*.nb13 =*/ nb13,
|
|
3151
|
+
/*.nb21 =*/ nb21,
|
|
3152
|
+
/*.nb22 =*/ nb22,
|
|
3153
|
+
/*.nb31 =*/ nb31,
|
|
3154
|
+
/*.nb41 =*/ nb41,
|
|
3155
|
+
/*.nb42 =*/ nb42,
|
|
3156
|
+
/*.nb43 =*/ nb43,
|
|
3157
|
+
/*.nb51 =*/ nb51,
|
|
3158
|
+
/*.nb52 =*/ nb52,
|
|
3159
|
+
/*.nb53 =*/ nb53,
|
|
2797
3160
|
};
|
|
2798
3161
|
|
|
2799
3162
|
[encoder setComputePipelineState:pipeline];
|
|
@@ -2803,10 +3166,27 @@ static bool ggml_metal_encode_node(
|
|
|
2803
3166
|
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
|
2804
3167
|
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
|
2805
3168
|
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
|
2806
|
-
[encoder setBuffer:
|
|
2807
|
-
[encoder
|
|
3169
|
+
[encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
|
|
3170
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];
|
|
3171
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:8];
|
|
3172
|
+
|
|
3173
|
+
// One shared memory bucket for each simd group in the threadgroup
|
|
3174
|
+
// NOTE: Metal kernels require the buffer size to be multiple of 16 bytes
|
|
3175
|
+
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
|
|
3176
|
+
if (d_state >= 32) {
|
|
3177
|
+
GGML_ASSERT((int64_t)(d_state / 32) <= 32);
|
|
3178
|
+
const int64_t shmem_size = 32;
|
|
3179
|
+
GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
|
|
3180
|
+
[encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
|
|
3181
|
+
}
|
|
2808
3182
|
|
|
2809
|
-
|
|
3183
|
+
if (ne30 == 1) {
|
|
3184
|
+
// Mamba-2
|
|
3185
|
+
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
|
|
3186
|
+
} else {
|
|
3187
|
+
GGML_ASSERT(d_inner == 1);
|
|
3188
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
|
|
3189
|
+
}
|
|
2810
3190
|
} break;
|
|
2811
3191
|
case GGML_OP_RWKV_WKV6:
|
|
2812
3192
|
{
|
|
@@ -3426,7 +3806,7 @@ static bool ggml_metal_encode_node(
|
|
|
3426
3806
|
id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
|
|
3427
3807
|
if (!h_src1) {
|
|
3428
3808
|
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
|
|
3429
|
-
return
|
|
3809
|
+
return 0;
|
|
3430
3810
|
}
|
|
3431
3811
|
|
|
3432
3812
|
const int64_t neh0 = ne0;
|
|
@@ -3442,7 +3822,7 @@ static bool ggml_metal_encode_node(
|
|
|
3442
3822
|
id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
|
|
3443
3823
|
if (!h_dst) {
|
|
3444
3824
|
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
|
|
3445
|
-
return
|
|
3825
|
+
return 0;
|
|
3446
3826
|
}
|
|
3447
3827
|
|
|
3448
3828
|
// tokens per expert
|
|
@@ -3450,7 +3830,7 @@ static bool ggml_metal_encode_node(
|
|
|
3450
3830
|
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
|
|
3451
3831
|
if (!h_tpe) {
|
|
3452
3832
|
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
|
|
3453
|
-
return
|
|
3833
|
+
return 0;
|
|
3454
3834
|
}
|
|
3455
3835
|
|
|
3456
3836
|
// id map
|
|
@@ -3459,7 +3839,7 @@ static bool ggml_metal_encode_node(
|
|
|
3459
3839
|
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
|
|
3460
3840
|
if (!h_ids) {
|
|
3461
3841
|
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
|
|
3462
|
-
return
|
|
3842
|
+
return 0;
|
|
3463
3843
|
}
|
|
3464
3844
|
|
|
3465
3845
|
{
|
|
@@ -3891,12 +4271,95 @@ static bool ggml_metal_encode_node(
|
|
|
3891
4271
|
case GGML_OP_RMS_NORM:
|
|
3892
4272
|
{
|
|
3893
4273
|
GGML_ASSERT(ne00 % 4 == 0);
|
|
3894
|
-
GGML_ASSERT(
|
|
4274
|
+
GGML_ASSERT(ggml_is_contiguous_rows(src0));
|
|
3895
4275
|
|
|
3896
4276
|
float eps;
|
|
3897
4277
|
memcpy(&eps, dst->op_params, sizeof(float));
|
|
3898
4278
|
|
|
3899
|
-
|
|
4279
|
+
ggml_metal_kargs_rms_norm args = {
|
|
4280
|
+
/*.ne00 =*/ ne00,
|
|
4281
|
+
/*.ne00_4 =*/ ne00/4,
|
|
4282
|
+
/*.nb1 =*/ nb1,
|
|
4283
|
+
/*.nb2 =*/ nb2,
|
|
4284
|
+
/*.nb3 =*/ nb3,
|
|
4285
|
+
/*.eps =*/ eps,
|
|
4286
|
+
/*.nef1 =*/ { ne01 },
|
|
4287
|
+
/*.nef2 =*/ { ne02 },
|
|
4288
|
+
/*.nef3 =*/ { ne03 },
|
|
4289
|
+
/*.nbf1 =*/ { nb01 },
|
|
4290
|
+
/*.nbf2 =*/ { nb02 },
|
|
4291
|
+
/*.nbf3 =*/ { nb03 },
|
|
4292
|
+
};
|
|
4293
|
+
|
|
4294
|
+
size_t offs_fuse[2] = { 0, 0 };
|
|
4295
|
+
id<MTLBuffer> id_fuse[2] = { id_src0, id_src0 };
|
|
4296
|
+
|
|
4297
|
+
// d[0] = rms_norm(a)
|
|
4298
|
+
// d[1] = mul(d[0], b)
|
|
4299
|
+
// d[2] = add(d[1], c)
|
|
4300
|
+
if (ctx_dev->use_fusion) {
|
|
4301
|
+
ops[0] = GGML_OP_RMS_NORM;
|
|
4302
|
+
ops[1] = GGML_OP_MUL;
|
|
4303
|
+
ops[2] = GGML_OP_ADD;
|
|
4304
|
+
|
|
4305
|
+
for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
|
|
4306
|
+
if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
|
|
4307
|
+
break;
|
|
4308
|
+
}
|
|
4309
|
+
|
|
4310
|
+
if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
|
|
4311
|
+
break;
|
|
4312
|
+
}
|
|
4313
|
+
|
|
4314
|
+
if (nodes[n_fuse + 1]->src[1]->ne[0] != node->ne[0]) {
|
|
4315
|
+
break;
|
|
4316
|
+
}
|
|
4317
|
+
|
|
4318
|
+
if (!ggml_is_contiguous_rows(nodes[n_fuse + 1]->src[1])) {
|
|
4319
|
+
break;
|
|
4320
|
+
}
|
|
4321
|
+
|
|
4322
|
+
if (nodes[n_fuse + 1]->type != GGML_TYPE_F32) {
|
|
4323
|
+
break;
|
|
4324
|
+
}
|
|
4325
|
+
|
|
4326
|
+
ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
|
|
4327
|
+
|
|
4328
|
+
id_fuse[n_fuse] = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]);
|
|
4329
|
+
|
|
4330
|
+
args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1];
|
|
4331
|
+
args.nef2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[2];
|
|
4332
|
+
args.nef3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[3];
|
|
4333
|
+
|
|
4334
|
+
args.nbf1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[1];
|
|
4335
|
+
args.nbf2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[2];
|
|
4336
|
+
args.nbf3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[3];
|
|
4337
|
+
}
|
|
4338
|
+
|
|
4339
|
+
++n_fuse;
|
|
4340
|
+
|
|
4341
|
+
if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
|
|
4342
|
+
if (n_fuse == 2) {
|
|
4343
|
+
GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
|
|
4344
|
+
}
|
|
4345
|
+
if (n_fuse == 3) {
|
|
4346
|
+
GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
|
|
4347
|
+
}
|
|
4348
|
+
}
|
|
4349
|
+
}
|
|
4350
|
+
|
|
4351
|
+
if (n_fuse > 1) {
|
|
4352
|
+
id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
|
|
4353
|
+
}
|
|
4354
|
+
|
|
4355
|
+
id<MTLComputePipelineState> pipeline;
|
|
4356
|
+
|
|
4357
|
+
switch (n_fuse) {
|
|
4358
|
+
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline; break;
|
|
4359
|
+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline; break;
|
|
4360
|
+
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break;
|
|
4361
|
+
default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
|
|
4362
|
+
}
|
|
3900
4363
|
|
|
3901
4364
|
int nth = 32; // SIMD width
|
|
3902
4365
|
|
|
@@ -3907,23 +4370,16 @@ static bool ggml_metal_encode_node(
|
|
|
3907
4370
|
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
3908
4371
|
nth = MIN(nth, ne00/4);
|
|
3909
4372
|
|
|
3910
|
-
ggml_metal_kargs_rms_norm args = {
|
|
3911
|
-
/*.ne00 =*/ ne00,
|
|
3912
|
-
/*.ne00_4 =*/ ne00/4,
|
|
3913
|
-
/*.nb01 =*/ nb01,
|
|
3914
|
-
/*.eps =*/ eps,
|
|
3915
|
-
};
|
|
3916
|
-
|
|
3917
4373
|
[encoder setComputePipelineState:pipeline];
|
|
3918
|
-
[encoder setBytes:&args length:sizeof(args)
|
|
3919
|
-
[encoder setBuffer:id_src0
|
|
3920
|
-
[encoder setBuffer:
|
|
4374
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
4375
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
4376
|
+
[encoder setBuffer:id_fuse[0] offset:offs_fuse[0] atIndex:2];
|
|
4377
|
+
[encoder setBuffer:id_fuse[1] offset:offs_fuse[1] atIndex:3];
|
|
4378
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
|
3921
4379
|
|
|
3922
4380
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
3923
4381
|
|
|
3924
|
-
|
|
3925
|
-
|
|
3926
|
-
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
4382
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
3927
4383
|
} break;
|
|
3928
4384
|
case GGML_OP_L2_NORM:
|
|
3929
4385
|
{
|
|
@@ -4908,7 +5364,11 @@ static bool ggml_metal_encode_node(
|
|
|
4908
5364
|
/*.nb21 =*/ nb21,
|
|
4909
5365
|
/*.nb22 =*/ nb22,
|
|
4910
5366
|
/*.nb23 =*/ nb23,
|
|
5367
|
+
/*.ne32 =*/ ne32,
|
|
5368
|
+
/*.ne33 =*/ ne33,
|
|
4911
5369
|
/*.nb31 =*/ nb31,
|
|
5370
|
+
/*.nb32 =*/ nb32,
|
|
5371
|
+
/*.nb33 =*/ nb33,
|
|
4912
5372
|
/*.ne1 =*/ ne1,
|
|
4913
5373
|
/*.ne2 =*/ ne2,
|
|
4914
5374
|
/*.scale =*/ scale,
|
|
@@ -5314,7 +5774,7 @@ static bool ggml_metal_encode_node(
|
|
|
5314
5774
|
}
|
|
5315
5775
|
}
|
|
5316
5776
|
|
|
5317
|
-
return
|
|
5777
|
+
return n_fuse;
|
|
5318
5778
|
}
|
|
5319
5779
|
|
|
5320
5780
|
static enum ggml_status ggml_metal_graph_compute(
|
|
@@ -5820,20 +6280,26 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
|
|
5820
6280
|
struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
|
|
5821
6281
|
ggml_metal_mem_pool_reset(mem_pool);
|
|
5822
6282
|
|
|
5823
|
-
for (int idx = node_start; idx < node_end;
|
|
6283
|
+
for (int idx = node_start; idx < node_end;) {
|
|
5824
6284
|
if (should_capture) {
|
|
5825
6285
|
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
|
|
5826
6286
|
}
|
|
5827
6287
|
|
|
5828
|
-
const
|
|
6288
|
+
const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
|
|
6289
|
+
if (idx + res > node_end) {
|
|
6290
|
+
GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
|
|
6291
|
+
"https://github.com/ggml-org/llama.cpp/pull/14849");
|
|
6292
|
+
}
|
|
5829
6293
|
|
|
5830
6294
|
if (should_capture) {
|
|
5831
6295
|
[encoder popDebugGroup];
|
|
5832
6296
|
}
|
|
5833
6297
|
|
|
5834
|
-
if (
|
|
6298
|
+
if (res == 0) {
|
|
5835
6299
|
break;
|
|
5836
6300
|
}
|
|
6301
|
+
|
|
6302
|
+
idx += res;
|
|
5837
6303
|
}
|
|
5838
6304
|
|
|
5839
6305
|
[encoder endEncoding];
|