@novastera-oss/llamarn 0.2.7 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +8 -3
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +56 -22
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +1 -2
- package/cpp/llama.cpp/README.md +4 -5
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +24 -0
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +3 -0
- package/cpp/llama.cpp/common/common.h +5 -0
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
- package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
- package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
- package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +8 -43
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
- package/cpp/llama.cpp/src/llama-arch.h +36 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
- package/cpp/llama.cpp/src/llama-batch.h +105 -70
- package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +101 -107
- package/cpp/llama.cpp/src/llama-context.h +13 -13
- package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
- package/cpp/llama.cpp/src/llama-graph.h +78 -79
- package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
- package/cpp/llama.cpp/src/llama-hparams.h +11 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
- package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +21 -22
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
- package/cpp/llama.cpp/src/llama-model.h +40 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
- package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
- package/cpp/llama.cpp/src/llama-vocab.h +42 -0
- package/cpp/rn-utils.h +3 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +5 -0
- package/ios/include/llama.h +8 -43
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -48,22 +48,28 @@ static struct ggml_backend_metal_device_context {
|
|
|
48
48
|
int mtl_device_ref_count;
|
|
49
49
|
id<MTLLibrary> mtl_library;
|
|
50
50
|
|
|
51
|
+
NSLock * mtl_lock;
|
|
52
|
+
|
|
51
53
|
bool has_simdgroup_reduction;
|
|
52
54
|
bool has_simdgroup_mm;
|
|
53
55
|
bool has_residency_sets;
|
|
54
56
|
bool has_bfloat;
|
|
55
57
|
bool use_bfloat;
|
|
56
58
|
|
|
59
|
+
size_t max_size;
|
|
60
|
+
|
|
57
61
|
char name[128];
|
|
58
62
|
} g_ggml_ctx_dev_main = {
|
|
59
63
|
/*.mtl_device =*/ nil,
|
|
60
64
|
/*.mtl_device_ref_count =*/ 0,
|
|
61
65
|
/*.mtl_library =*/ nil,
|
|
66
|
+
/*.mtl_lock =*/ nil,
|
|
62
67
|
/*.has_simdgroup_reduction =*/ false,
|
|
63
68
|
/*.has_simdgroup_mm =*/ false,
|
|
64
69
|
/*.has_residency_sets =*/ false,
|
|
65
70
|
/*.has_bfloat =*/ false,
|
|
66
71
|
/*.use_bfloat =*/ false,
|
|
72
|
+
/*.max_size =*/ 0,
|
|
67
73
|
/*.name =*/ "",
|
|
68
74
|
};
|
|
69
75
|
|
|
@@ -71,6 +77,10 @@ static struct ggml_backend_metal_device_context {
|
|
|
71
77
|
static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
|
|
72
78
|
assert(ctx != NULL);
|
|
73
79
|
|
|
80
|
+
if (ctx->mtl_lock == nil) {
|
|
81
|
+
ctx->mtl_lock = [[NSLock alloc] init];
|
|
82
|
+
}
|
|
83
|
+
|
|
74
84
|
if (ctx->mtl_device == nil) {
|
|
75
85
|
ctx->mtl_device = MTLCreateSystemDefaultDevice();
|
|
76
86
|
}
|
|
@@ -94,6 +104,8 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
|
|
|
94
104
|
ctx->use_bfloat = false;
|
|
95
105
|
#endif
|
|
96
106
|
|
|
107
|
+
ctx->max_size = ctx->mtl_device.maxBufferLength;
|
|
108
|
+
|
|
97
109
|
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
|
|
98
110
|
}
|
|
99
111
|
|
|
@@ -110,6 +122,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
|
|
|
110
122
|
ctx->mtl_device_ref_count--;
|
|
111
123
|
|
|
112
124
|
if (ctx->mtl_device_ref_count == 0) {
|
|
125
|
+
if (ctx->mtl_lock) {
|
|
126
|
+
[ctx->mtl_lock release];
|
|
127
|
+
ctx->mtl_lock = nil;
|
|
128
|
+
}
|
|
129
|
+
|
|
113
130
|
if (ctx->mtl_library) {
|
|
114
131
|
[ctx->mtl_library release];
|
|
115
132
|
ctx->mtl_library = nil;
|
|
@@ -185,20 +202,33 @@ enum ggml_metal_kernel_type {
|
|
|
185
202
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
|
186
203
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
|
187
204
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
|
205
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
|
|
206
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
|
|
207
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
|
|
208
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
|
|
209
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
|
|
210
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
|
|
211
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
|
|
212
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
|
|
213
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
|
|
188
214
|
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
|
189
215
|
GGML_METAL_KERNEL_TYPE_L2_NORM,
|
|
190
216
|
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
|
191
217
|
GGML_METAL_KERNEL_TYPE_NORM,
|
|
192
218
|
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
|
193
219
|
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
|
220
|
+
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
|
|
194
221
|
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
|
|
195
222
|
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
|
|
196
223
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
|
224
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4,
|
|
197
225
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
|
226
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4,
|
|
198
227
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
|
199
228
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
|
|
200
229
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
|
201
230
|
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
|
|
231
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4,
|
|
202
232
|
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
|
|
203
233
|
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
|
|
204
234
|
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
|
|
@@ -497,6 +527,11 @@ enum ggml_metal_kernel_type {
|
|
|
497
527
|
GGML_METAL_KERNEL_TYPE_SIN,
|
|
498
528
|
GGML_METAL_KERNEL_TYPE_COS,
|
|
499
529
|
GGML_METAL_KERNEL_TYPE_NEG,
|
|
530
|
+
GGML_METAL_KERNEL_TYPE_REGLU,
|
|
531
|
+
GGML_METAL_KERNEL_TYPE_GEGLU,
|
|
532
|
+
GGML_METAL_KERNEL_TYPE_SWIGLU,
|
|
533
|
+
GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
|
|
534
|
+
GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
|
|
500
535
|
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
|
501
536
|
GGML_METAL_KERNEL_TYPE_MEAN,
|
|
502
537
|
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
|
@@ -977,7 +1012,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
977
1012
|
struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
|
|
978
1013
|
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
|
|
979
1014
|
|
|
980
|
-
id<MTLDevice> device =
|
|
1015
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
981
1016
|
|
|
982
1017
|
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
|
983
1018
|
|
|
@@ -991,9 +1026,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
991
1026
|
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
|
992
1027
|
|
|
993
1028
|
// load library
|
|
994
|
-
|
|
995
|
-
ctx_dev->
|
|
1029
|
+
{
|
|
1030
|
+
[ctx_dev->mtl_lock lock];
|
|
1031
|
+
|
|
1032
|
+
if (ctx_dev->mtl_library == nil) {
|
|
1033
|
+
ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
|
|
1034
|
+
}
|
|
1035
|
+
|
|
1036
|
+
[ctx_dev->mtl_lock unlock];
|
|
996
1037
|
}
|
|
1038
|
+
|
|
997
1039
|
id<MTLLibrary> metal_library = ctx_dev->mtl_library;
|
|
998
1040
|
if (metal_library == nil) {
|
|
999
1041
|
GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
|
|
@@ -1142,20 +1184,33 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
1142
1184
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
|
1143
1185
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
|
1144
1186
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
|
1187
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true);
|
|
1188
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true);
|
|
1189
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
|
|
1190
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true);
|
|
1191
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true);
|
|
1192
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true);
|
|
1193
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
|
|
1194
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
|
|
1195
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
|
|
1145
1196
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
|
1146
1197
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
|
|
1147
1198
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
|
1148
1199
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
|
1149
1200
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
|
1150
1201
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
|
1202
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
|
|
1151
1203
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
|
|
1152
1204
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
|
|
1153
1205
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
|
1206
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true);
|
|
1154
1207
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
|
|
1208
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat);
|
|
1155
1209
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
|
|
1156
1210
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
|
|
1157
1211
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
|
|
1158
1212
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
|
|
1213
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true);
|
|
1159
1214
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
|
|
1160
1215
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
|
|
1161
1216
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
|
|
@@ -1454,6 +1509,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
1454
1509
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
|
1455
1510
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
|
1456
1511
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
|
1512
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
|
|
1513
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
|
|
1514
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
|
|
1515
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
|
|
1516
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
|
|
1457
1517
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
|
1458
1518
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
|
1459
1519
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
|
@@ -1605,6 +1665,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
1605
1665
|
const bool use_bfloat = ctx_dev->use_bfloat;
|
|
1606
1666
|
|
|
1607
1667
|
if (!use_bfloat) {
|
|
1668
|
+
if (op->type == GGML_TYPE_BF16) {
|
|
1669
|
+
return false;
|
|
1670
|
+
}
|
|
1671
|
+
|
|
1608
1672
|
for (size_t i = 0, n = 3; i < n; ++i) {
|
|
1609
1673
|
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
|
|
1610
1674
|
return false;
|
|
@@ -1628,6 +1692,17 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
1628
1692
|
default:
|
|
1629
1693
|
return false;
|
|
1630
1694
|
}
|
|
1695
|
+
case GGML_OP_GLU:
|
|
1696
|
+
switch (ggml_get_glu_op(op)) {
|
|
1697
|
+
case GGML_GLU_OP_REGLU:
|
|
1698
|
+
case GGML_GLU_OP_GEGLU:
|
|
1699
|
+
case GGML_GLU_OP_SWIGLU:
|
|
1700
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
1701
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
1702
|
+
return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
|
1703
|
+
default:
|
|
1704
|
+
return false;
|
|
1705
|
+
}
|
|
1631
1706
|
case GGML_OP_NONE:
|
|
1632
1707
|
case GGML_OP_RESHAPE:
|
|
1633
1708
|
case GGML_OP_VIEW:
|
|
@@ -1658,7 +1733,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
1658
1733
|
case GGML_OP_MEAN:
|
|
1659
1734
|
case GGML_OP_SOFT_MAX:
|
|
1660
1735
|
case GGML_OP_GROUP_NORM:
|
|
1661
|
-
return has_simdgroup_reduction &&
|
|
1736
|
+
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
|
1662
1737
|
case GGML_OP_RMS_NORM:
|
|
1663
1738
|
case GGML_OP_L2_NORM:
|
|
1664
1739
|
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
|
@@ -1774,6 +1849,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
1774
1849
|
{
|
|
1775
1850
|
return op->ne[3] == 1;
|
|
1776
1851
|
}
|
|
1852
|
+
case GGML_OP_SET_ROWS:
|
|
1853
|
+
{
|
|
1854
|
+
if (op->src[0]->type != GGML_TYPE_F32) {
|
|
1855
|
+
return false;
|
|
1856
|
+
}
|
|
1857
|
+
|
|
1858
|
+
switch (op->type) {
|
|
1859
|
+
case GGML_TYPE_F32:
|
|
1860
|
+
case GGML_TYPE_F16:
|
|
1861
|
+
case GGML_TYPE_BF16:
|
|
1862
|
+
case GGML_TYPE_Q8_0:
|
|
1863
|
+
case GGML_TYPE_Q4_0:
|
|
1864
|
+
case GGML_TYPE_Q4_1:
|
|
1865
|
+
case GGML_TYPE_Q5_0:
|
|
1866
|
+
case GGML_TYPE_Q5_1:
|
|
1867
|
+
case GGML_TYPE_IQ4_NL:
|
|
1868
|
+
return true;
|
|
1869
|
+
default:
|
|
1870
|
+
return false;
|
|
1871
|
+
};
|
|
1872
|
+
}
|
|
1777
1873
|
default:
|
|
1778
1874
|
return false;
|
|
1779
1875
|
}
|
|
@@ -2160,7 +2256,9 @@ static bool ggml_metal_encode_node(
|
|
|
2160
2256
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
2161
2257
|
|
|
2162
2258
|
float scale;
|
|
2163
|
-
|
|
2259
|
+
float bias;
|
|
2260
|
+
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float));
|
|
2261
|
+
memcpy(&bias, ((const int32_t *) dst->op_params) + 1, sizeof(float));
|
|
2164
2262
|
|
|
2165
2263
|
int64_t n = ggml_nelements(dst);
|
|
2166
2264
|
|
|
@@ -2177,6 +2275,7 @@ static bool ggml_metal_encode_node(
|
|
|
2177
2275
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2178
2276
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2179
2277
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
|
2278
|
+
[encoder setBytes:&bias length:sizeof(bias) atIndex:3];
|
|
2180
2279
|
|
|
2181
2280
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2182
2281
|
} break;
|
|
@@ -2346,6 +2445,68 @@ static bool ggml_metal_encode_node(
|
|
|
2346
2445
|
GGML_ABORT("fatal error");
|
|
2347
2446
|
}
|
|
2348
2447
|
} break;
|
|
2448
|
+
case GGML_OP_GLU:
|
|
2449
|
+
{
|
|
2450
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
2451
|
+
|
|
2452
|
+
if (src1) {
|
|
2453
|
+
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
|
2454
|
+
}
|
|
2455
|
+
|
|
2456
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2457
|
+
|
|
2458
|
+
switch (ggml_get_glu_op(node)) {
|
|
2459
|
+
case GGML_GLU_OP_REGLU:
|
|
2460
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
|
|
2461
|
+
break;
|
|
2462
|
+
case GGML_GLU_OP_GEGLU:
|
|
2463
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
|
|
2464
|
+
break;
|
|
2465
|
+
case GGML_GLU_OP_SWIGLU:
|
|
2466
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
|
|
2467
|
+
break;
|
|
2468
|
+
case GGML_GLU_OP_GEGLU_ERF:
|
|
2469
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
|
|
2470
|
+
break;
|
|
2471
|
+
case GGML_GLU_OP_GEGLU_QUICK:
|
|
2472
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline;
|
|
2473
|
+
break;
|
|
2474
|
+
default:
|
|
2475
|
+
GGML_ABORT("fatal error");
|
|
2476
|
+
}
|
|
2477
|
+
|
|
2478
|
+
const int32_t swp = ((const int32_t *) dst->op_params)[1];
|
|
2479
|
+
|
|
2480
|
+
const int32_t i00 = swp ? ne0 : 0;
|
|
2481
|
+
const int32_t i10 = swp ? 0 : ne0;
|
|
2482
|
+
|
|
2483
|
+
ggml_metal_kargs_glu args = {
|
|
2484
|
+
/*.ne00 =*/ ne00,
|
|
2485
|
+
/*.nb01 =*/ nb01,
|
|
2486
|
+
/*.ne10 =*/ src1 ? ne10 : ne00,
|
|
2487
|
+
/*.nb11 =*/ src1 ? nb11 : nb01,
|
|
2488
|
+
/*.ne0 =*/ ne0,
|
|
2489
|
+
/*.nb1 =*/ nb1,
|
|
2490
|
+
/*.i00 =*/ src1 ? 0 : i00,
|
|
2491
|
+
/*.i10 =*/ src1 ? 0 : i10,
|
|
2492
|
+
};
|
|
2493
|
+
|
|
2494
|
+
[encoder setComputePipelineState:pipeline];
|
|
2495
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2496
|
+
if (src1) {
|
|
2497
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
2498
|
+
} else {
|
|
2499
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
2500
|
+
}
|
|
2501
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
2502
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
|
2503
|
+
|
|
2504
|
+
const int64_t nrows = ggml_nrows(src0);
|
|
2505
|
+
|
|
2506
|
+
const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2);
|
|
2507
|
+
|
|
2508
|
+
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2509
|
+
} break;
|
|
2349
2510
|
case GGML_OP_SQR:
|
|
2350
2511
|
{
|
|
2351
2512
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
@@ -2426,6 +2587,7 @@ static bool ggml_metal_encode_node(
|
|
|
2426
2587
|
nth *= 2;
|
|
2427
2588
|
}
|
|
2428
2589
|
|
|
2590
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
2429
2591
|
nth = MIN(nth, ne00);
|
|
2430
2592
|
|
|
2431
2593
|
ggml_metal_kargs_sum_rows args = {
|
|
@@ -2499,10 +2661,7 @@ static bool ggml_metal_encode_node(
|
|
|
2499
2661
|
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
|
|
2500
2662
|
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
|
2501
2663
|
|
|
2502
|
-
const
|
|
2503
|
-
const int64_t nrows_y = src0->ne[1];
|
|
2504
|
-
|
|
2505
|
-
const uint32_t n_head = nrows_x/nrows_y;
|
|
2664
|
+
const uint32_t n_head = src0->ne[2];
|
|
2506
2665
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
|
2507
2666
|
|
|
2508
2667
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
@@ -2562,6 +2721,18 @@ static bool ggml_metal_encode_node(
|
|
|
2562
2721
|
/*.ne00 =*/ ne00,
|
|
2563
2722
|
/*.ne01 =*/ ne01,
|
|
2564
2723
|
/*.ne02 =*/ ne02,
|
|
2724
|
+
/*.nb01 =*/ nb01,
|
|
2725
|
+
/*.nb02 =*/ nb02,
|
|
2726
|
+
/*.nb03 =*/ nb03,
|
|
2727
|
+
/*.ne11 =*/ ne11,
|
|
2728
|
+
/*.ne12 =*/ ne12,
|
|
2729
|
+
/*.ne13 =*/ ne13,
|
|
2730
|
+
/*.nb11 =*/ nb11,
|
|
2731
|
+
/*.nb12 =*/ nb12,
|
|
2732
|
+
/*.nb13 =*/ nb13,
|
|
2733
|
+
/*.nb1 =*/ nb1,
|
|
2734
|
+
/*.nb2 =*/ nb2,
|
|
2735
|
+
/*.nb3 =*/ nb3,
|
|
2565
2736
|
/*.scale =*/ scale,
|
|
2566
2737
|
/*.max_bias =*/ max_bias,
|
|
2567
2738
|
/*.m0 =*/ m0,
|
|
@@ -2581,7 +2752,7 @@ static bool ggml_metal_encode_node(
|
|
|
2581
2752
|
|
|
2582
2753
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
2583
2754
|
|
|
2584
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01
|
|
2755
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2585
2756
|
} break;
|
|
2586
2757
|
case GGML_OP_DIAG_MASK_INF:
|
|
2587
2758
|
{
|
|
@@ -2655,71 +2826,91 @@ static bool ggml_metal_encode_node(
|
|
|
2655
2826
|
struct ggml_tensor * src3 = node->src[3];
|
|
2656
2827
|
struct ggml_tensor * src4 = node->src[4];
|
|
2657
2828
|
struct ggml_tensor * src5 = node->src[5];
|
|
2829
|
+
struct ggml_tensor * src6 = node->src[6];
|
|
2658
2830
|
|
|
2659
2831
|
GGML_ASSERT(src3);
|
|
2660
2832
|
GGML_ASSERT(src4);
|
|
2661
2833
|
GGML_ASSERT(src5);
|
|
2834
|
+
GGML_ASSERT(src6);
|
|
2662
2835
|
|
|
2663
2836
|
size_t offs_src3 = 0;
|
|
2664
2837
|
size_t offs_src4 = 0;
|
|
2665
2838
|
size_t offs_src5 = 0;
|
|
2839
|
+
size_t offs_src6 = 0;
|
|
2666
2840
|
|
|
2667
2841
|
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
|
|
2668
2842
|
id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
|
|
2669
2843
|
id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
|
|
2844
|
+
id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
|
|
2670
2845
|
|
|
2671
|
-
const int64_t ne30 = src3->ne[0];
|
|
2846
|
+
const int64_t ne30 = src3->ne[0];
|
|
2672
2847
|
const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
|
|
2673
2848
|
|
|
2674
|
-
const uint64_t nb30 = src3->nb[0];
|
|
2849
|
+
const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
|
|
2675
2850
|
const uint64_t nb31 = src3->nb[1];
|
|
2676
2851
|
|
|
2677
2852
|
const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
|
|
2678
|
-
const int64_t ne41 = src4->ne[1];
|
|
2853
|
+
const int64_t ne41 = src4->ne[1];
|
|
2679
2854
|
const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
|
|
2855
|
+
const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);
|
|
2680
2856
|
|
|
2681
|
-
const uint64_t nb40 = src4->nb[0];
|
|
2857
|
+
const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
|
|
2682
2858
|
const uint64_t nb41 = src4->nb[1];
|
|
2683
2859
|
const uint64_t nb42 = src4->nb[2];
|
|
2860
|
+
const uint64_t nb43 = src4->nb[3];
|
|
2684
2861
|
|
|
2685
2862
|
const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
|
|
2686
2863
|
const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
|
|
2687
2864
|
const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
|
|
2865
|
+
const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);
|
|
2688
2866
|
|
|
2689
|
-
const uint64_t nb50 = src5->nb[0];
|
|
2867
|
+
const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
|
|
2690
2868
|
const uint64_t nb51 = src5->nb[1];
|
|
2691
2869
|
const uint64_t nb52 = src5->nb[2];
|
|
2870
|
+
const uint64_t nb53 = src5->nb[3];
|
|
2871
|
+
|
|
2872
|
+
const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);
|
|
2873
|
+
|
|
2874
|
+
const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
|
|
2692
2875
|
|
|
2693
2876
|
const int64_t d_state = ne00;
|
|
2694
2877
|
const int64_t d_inner = ne01;
|
|
2695
|
-
const int64_t
|
|
2696
|
-
const int64_t
|
|
2878
|
+
const int64_t n_head = ne02;
|
|
2879
|
+
const int64_t n_group = ne41;
|
|
2880
|
+
const int64_t n_seq_tokens = ne12;
|
|
2881
|
+
const int64_t n_seqs = ne13;
|
|
2882
|
+
|
|
2883
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2697
2884
|
|
|
2698
|
-
|
|
2885
|
+
if (ne30 == 1) {
|
|
2886
|
+
// Mamba-2
|
|
2887
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
|
|
2888
|
+
} else {
|
|
2889
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
|
|
2890
|
+
}
|
|
2699
2891
|
|
|
2700
2892
|
ggml_metal_kargs_ssm_scan args = {
|
|
2701
|
-
/*.d_state
|
|
2702
|
-
/*.d_inner
|
|
2893
|
+
/*.d_state =*/ d_state,
|
|
2894
|
+
/*.d_inner =*/ d_inner,
|
|
2895
|
+
/*.n_head =*/ n_head,
|
|
2896
|
+
/*.n_group =*/ n_group,
|
|
2703
2897
|
/*.n_seq_tokens =*/ n_seq_tokens,
|
|
2704
|
-
/*.n_seqs
|
|
2705
|
-
/*.
|
|
2706
|
-
/*.
|
|
2707
|
-
/*.
|
|
2708
|
-
/*.
|
|
2709
|
-
/*.
|
|
2710
|
-
/*.
|
|
2711
|
-
/*.
|
|
2712
|
-
/*.
|
|
2713
|
-
/*.
|
|
2714
|
-
/*.
|
|
2715
|
-
/*.
|
|
2716
|
-
/*.
|
|
2717
|
-
/*.
|
|
2718
|
-
/*.
|
|
2719
|
-
/*.
|
|
2720
|
-
/*.nb50 =*/ nb50,
|
|
2721
|
-
/*.nb51 =*/ nb51,
|
|
2722
|
-
/*.nb52 =*/ nb52,
|
|
2898
|
+
/*.n_seqs =*/ n_seqs,
|
|
2899
|
+
/*.nb01 =*/ nb01,
|
|
2900
|
+
/*.nb02 =*/ nb02,
|
|
2901
|
+
/*.nb03 =*/ nb03,
|
|
2902
|
+
/*.nb11 =*/ nb11,
|
|
2903
|
+
/*.nb12 =*/ nb12,
|
|
2904
|
+
/*.nb13 =*/ nb13,
|
|
2905
|
+
/*.nb21 =*/ nb21,
|
|
2906
|
+
/*.nb22 =*/ nb22,
|
|
2907
|
+
/*.nb31 =*/ nb31,
|
|
2908
|
+
/*.nb41 =*/ nb41,
|
|
2909
|
+
/*.nb42 =*/ nb42,
|
|
2910
|
+
/*.nb43 =*/ nb43,
|
|
2911
|
+
/*.nb51 =*/ nb51,
|
|
2912
|
+
/*.nb52 =*/ nb52,
|
|
2913
|
+
/*.nb53 =*/ nb53,
|
|
2723
2914
|
};
|
|
2724
2915
|
|
|
2725
2916
|
[encoder setComputePipelineState:pipeline];
|
|
@@ -2729,10 +2920,17 @@ static bool ggml_metal_encode_node(
|
|
|
2729
2920
|
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
|
2730
2921
|
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
|
2731
2922
|
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
|
2732
|
-
[encoder setBuffer:
|
|
2733
|
-
[encoder
|
|
2923
|
+
[encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
|
|
2924
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];
|
|
2925
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:8];
|
|
2734
2926
|
|
|
2735
|
-
|
|
2927
|
+
if (ne30 == 1) {
|
|
2928
|
+
// Mamba-2
|
|
2929
|
+
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2930
|
+
} else {
|
|
2931
|
+
GGML_ASSERT(d_inner == 1);
|
|
2932
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2933
|
+
}
|
|
2736
2934
|
} break;
|
|
2737
2935
|
case GGML_OP_RWKV_WKV6:
|
|
2738
2936
|
{
|
|
@@ -3086,14 +3284,23 @@ static bool ggml_metal_encode_node(
|
|
|
3086
3284
|
nsg = 1;
|
|
3087
3285
|
nr0 = 1;
|
|
3088
3286
|
nr1 = 4;
|
|
3089
|
-
|
|
3287
|
+
if (ne00 == 4) {
|
|
3288
|
+
nr0 = 32;
|
|
3289
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline;
|
|
3290
|
+
} else {
|
|
3291
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
|
|
3292
|
+
}
|
|
3090
3293
|
} break;
|
|
3091
3294
|
case GGML_TYPE_F16:
|
|
3092
3295
|
{
|
|
3093
3296
|
nsg = 1;
|
|
3094
3297
|
nr0 = 1;
|
|
3095
3298
|
if (src1t == GGML_TYPE_F32) {
|
|
3096
|
-
if (
|
|
3299
|
+
if (ne00 == 4) {
|
|
3300
|
+
nr0 = 32;
|
|
3301
|
+
nr1 = 4;
|
|
3302
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline;
|
|
3303
|
+
} else if (ne11 * ne12 < 4) {
|
|
3097
3304
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
|
3098
3305
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
|
3099
3306
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
|
@@ -3112,7 +3319,11 @@ static bool ggml_metal_encode_node(
|
|
|
3112
3319
|
nsg = 1;
|
|
3113
3320
|
nr0 = 1;
|
|
3114
3321
|
if (src1t == GGML_TYPE_F32) {
|
|
3115
|
-
if (
|
|
3322
|
+
if (ne00 == 4) {
|
|
3323
|
+
nr0 = 32;
|
|
3324
|
+
nr1 = 4;
|
|
3325
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline;
|
|
3326
|
+
} else if (ne11 * ne12 < 4) {
|
|
3116
3327
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
|
|
3117
3328
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
|
3118
3329
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
|
|
@@ -3733,13 +3944,74 @@ static bool ggml_metal_encode_node(
|
|
|
3733
3944
|
};
|
|
3734
3945
|
|
|
3735
3946
|
[encoder setComputePipelineState:pipeline];
|
|
3736
|
-
[encoder
|
|
3737
|
-
[encoder setBuffer:
|
|
3738
|
-
[encoder setBuffer:
|
|
3739
|
-
[encoder
|
|
3947
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
3948
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
3949
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
|
3950
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
3740
3951
|
|
|
3741
3952
|
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
|
3742
3953
|
} break;
|
|
3954
|
+
case GGML_OP_SET_ROWS:
|
|
3955
|
+
{
|
|
3956
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
3957
|
+
|
|
3958
|
+
switch (dst->type) {
|
|
3959
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break;
|
|
3960
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break;
|
|
3961
|
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break;
|
|
3962
|
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break;
|
|
3963
|
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break;
|
|
3964
|
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break;
|
|
3965
|
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break;
|
|
3966
|
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break;
|
|
3967
|
+
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break;
|
|
3968
|
+
default: GGML_ABORT("not implemented");
|
|
3969
|
+
}
|
|
3970
|
+
|
|
3971
|
+
const int32_t nk0 = ne0/ggml_blck_size(dst->type);
|
|
3972
|
+
|
|
3973
|
+
int nth = 32; // SIMD width
|
|
3974
|
+
|
|
3975
|
+
while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
3976
|
+
nth *= 2;
|
|
3977
|
+
}
|
|
3978
|
+
|
|
3979
|
+
int nrptg = 1;
|
|
3980
|
+
if (nth > nk0) {
|
|
3981
|
+
nrptg = (nth + nk0 - 1)/nk0;
|
|
3982
|
+
nth = nk0;
|
|
3983
|
+
|
|
3984
|
+
if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
3985
|
+
nrptg--;
|
|
3986
|
+
}
|
|
3987
|
+
}
|
|
3988
|
+
|
|
3989
|
+
nth = MIN(nth, nk0);
|
|
3990
|
+
|
|
3991
|
+
ggml_metal_kargs_set_rows args = {
|
|
3992
|
+
/*.nk0 =*/ nk0,
|
|
3993
|
+
/*.ne01 =*/ ne01,
|
|
3994
|
+
/*.nb01 =*/ nb01,
|
|
3995
|
+
/*.nb02 =*/ nb02,
|
|
3996
|
+
/*.nb03 =*/ nb03,
|
|
3997
|
+
/*.ne11 =*/ ne11,
|
|
3998
|
+
/*.ne12 =*/ ne12,
|
|
3999
|
+
/*.nb10 =*/ nb10,
|
|
4000
|
+
/*.nb11 =*/ nb11,
|
|
4001
|
+
/*.nb12 =*/ nb12,
|
|
4002
|
+
/*.nb1 =*/ nb1,
|
|
4003
|
+
/*.nb2 =*/ nb2,
|
|
4004
|
+
/*.nb3 =*/ nb3,
|
|
4005
|
+
};
|
|
4006
|
+
|
|
4007
|
+
[encoder setComputePipelineState:pipeline];
|
|
4008
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
4009
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
4010
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
|
4011
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
4012
|
+
|
|
4013
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
|
|
4014
|
+
} break;
|
|
3743
4015
|
case GGML_OP_RMS_NORM:
|
|
3744
4016
|
{
|
|
3745
4017
|
GGML_ASSERT(ne00 % 4 == 0);
|
|
@@ -3756,6 +4028,7 @@ static bool ggml_metal_encode_node(
|
|
|
3756
4028
|
nth *= 2;
|
|
3757
4029
|
}
|
|
3758
4030
|
|
|
4031
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
3759
4032
|
nth = MIN(nth, ne00/4);
|
|
3760
4033
|
|
|
3761
4034
|
ggml_metal_kargs_rms_norm args = {
|
|
@@ -3792,6 +4065,7 @@ static bool ggml_metal_encode_node(
|
|
|
3792
4065
|
nth *= 2;
|
|
3793
4066
|
}
|
|
3794
4067
|
|
|
4068
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
3795
4069
|
nth = MIN(nth, ne00/4);
|
|
3796
4070
|
|
|
3797
4071
|
ggml_metal_kargs_l2_norm args = {
|
|
@@ -3864,6 +4138,7 @@ static bool ggml_metal_encode_node(
|
|
|
3864
4138
|
nth *= 2;
|
|
3865
4139
|
}
|
|
3866
4140
|
|
|
4141
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
3867
4142
|
nth = MIN(nth, ne00/4);
|
|
3868
4143
|
|
|
3869
4144
|
ggml_metal_kargs_norm args = {
|
|
@@ -4757,7 +5032,11 @@ static bool ggml_metal_encode_node(
|
|
|
4757
5032
|
/*.nb21 =*/ nb21,
|
|
4758
5033
|
/*.nb22 =*/ nb22,
|
|
4759
5034
|
/*.nb23 =*/ nb23,
|
|
5035
|
+
/*.ne32 =*/ ne32,
|
|
5036
|
+
/*.ne33 =*/ ne33,
|
|
4760
5037
|
/*.nb31 =*/ nb31,
|
|
5038
|
+
/*.nb32 =*/ nb32,
|
|
5039
|
+
/*.nb33 =*/ nb33,
|
|
4761
5040
|
/*.ne1 =*/ ne1,
|
|
4762
5041
|
/*.ne2 =*/ ne2,
|
|
4763
5042
|
/*.scale =*/ scale,
|
|
@@ -4950,8 +5229,39 @@ static bool ggml_metal_encode_node(
|
|
|
4950
5229
|
default: GGML_ABORT("not implemented");
|
|
4951
5230
|
}
|
|
4952
5231
|
|
|
5232
|
+
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
|
5233
|
+
|
|
5234
|
+
// TODO: support
|
|
5235
|
+
//const int32_t nk00 = ne00/ggml_blck_size(dst->type);
|
|
5236
|
+
const int32_t nk00 = ne00;
|
|
5237
|
+
|
|
5238
|
+
int nth = 32; // SIMD width
|
|
5239
|
+
|
|
5240
|
+
while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
5241
|
+
nth *= 2;
|
|
5242
|
+
}
|
|
5243
|
+
|
|
5244
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
5245
|
+
|
|
5246
|
+
// when rows are small, we can batch them together in a single threadgroup
|
|
5247
|
+
int nrptg = 1;
|
|
5248
|
+
|
|
5249
|
+
// TODO: relax this constraint in the future
|
|
5250
|
+
if (ggml_blck_size(src0->type) == 1 && ggml_blck_size(dst->type) == 1) {
|
|
5251
|
+
if (nth > nk00) {
|
|
5252
|
+
nrptg = (nth + nk00 - 1)/nk00;
|
|
5253
|
+
nth = nk00;
|
|
5254
|
+
|
|
5255
|
+
if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
5256
|
+
nrptg--;
|
|
5257
|
+
}
|
|
5258
|
+
}
|
|
5259
|
+
}
|
|
5260
|
+
|
|
5261
|
+
nth = MIN(nth, nk00);
|
|
5262
|
+
|
|
4953
5263
|
ggml_metal_kargs_cpy args = {
|
|
4954
|
-
/*.ne00 =*/
|
|
5264
|
+
/*.ne00 =*/ nk00,
|
|
4955
5265
|
/*.ne01 =*/ ne01,
|
|
4956
5266
|
/*.ne02 =*/ ne02,
|
|
4957
5267
|
/*.ne03 =*/ ne03,
|
|
@@ -4974,11 +5284,7 @@ static bool ggml_metal_encode_node(
|
|
|
4974
5284
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
4975
5285
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
4976
5286
|
|
|
4977
|
-
|
|
4978
|
-
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
|
|
4979
|
-
|
|
4980
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
4981
|
-
|
|
5287
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
|
|
4982
5288
|
} break;
|
|
4983
5289
|
case GGML_OP_SET:
|
|
4984
5290
|
{
|
|
@@ -5284,7 +5590,6 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
|
|
|
5284
5590
|
}
|
|
5285
5591
|
|
|
5286
5592
|
ggml_backend_metal_buffer_rset_free(ctx);
|
|
5287
|
-
ggml_backend_metal_device_rel(buffer->buft->device->context);
|
|
5288
5593
|
|
|
5289
5594
|
if (ctx->owned) {
|
|
5290
5595
|
#if TARGET_OS_OSX
|
|
@@ -5393,7 +5698,10 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
|
5393
5698
|
}
|
|
5394
5699
|
|
|
5395
5700
|
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
|
|
5396
|
-
|
|
5701
|
+
|
|
5702
|
+
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
|
5703
|
+
|
|
5704
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
5397
5705
|
|
|
5398
5706
|
ctx->all_data = ggml_metal_host_malloc(size_aligned);
|
|
5399
5707
|
ctx->all_size = size_aligned;
|
|
@@ -5416,14 +5724,12 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
|
5416
5724
|
if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
|
|
5417
5725
|
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
|
|
5418
5726
|
free(ctx);
|
|
5419
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5420
5727
|
return NULL;
|
|
5421
5728
|
}
|
|
5422
5729
|
|
|
5423
5730
|
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
|
5424
5731
|
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
|
5425
5732
|
free(ctx);
|
|
5426
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5427
5733
|
return NULL;
|
|
5428
5734
|
}
|
|
5429
5735
|
|
|
@@ -5434,17 +5740,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
|
5434
5740
|
|
|
5435
5741
|
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
|
5436
5742
|
return 32;
|
|
5743
|
+
|
|
5437
5744
|
GGML_UNUSED(buft);
|
|
5438
5745
|
}
|
|
5439
5746
|
|
|
5440
5747
|
static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
|
|
5441
|
-
|
|
5442
|
-
const size_t max_size = device.maxBufferLength;
|
|
5443
|
-
ggml_backend_metal_device_rel(buft->device->context);
|
|
5748
|
+
const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size;
|
|
5444
5749
|
|
|
5445
5750
|
return max_size;
|
|
5446
|
-
|
|
5447
|
-
GGML_UNUSED(buft);
|
|
5448
5751
|
}
|
|
5449
5752
|
|
|
5450
5753
|
static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
|
@@ -5517,7 +5820,10 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
|
|
5517
5820
|
}
|
|
5518
5821
|
|
|
5519
5822
|
struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
|
|
5520
|
-
|
|
5823
|
+
|
|
5824
|
+
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
|
5825
|
+
|
|
5826
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
5521
5827
|
|
|
5522
5828
|
// the buffer fits into the max buffer size allowed by the device
|
|
5523
5829
|
if (size_aligned <= device.maxBufferLength) {
|
|
@@ -5573,7 +5879,6 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
|
|
5573
5879
|
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
|
5574
5880
|
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
|
5575
5881
|
free(ctx);
|
|
5576
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5577
5882
|
return NULL;
|
|
5578
5883
|
}
|
|
5579
5884
|
|
|
@@ -5589,10 +5894,8 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
|
|
5589
5894
|
}
|
|
5590
5895
|
|
|
5591
5896
|
static void ggml_backend_metal_free(ggml_backend_t backend) {
|
|
5592
|
-
struct ggml_backend_metal_context
|
|
5593
|
-
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
|
5897
|
+
struct ggml_backend_metal_context * ctx = backend->context;
|
|
5594
5898
|
|
|
5595
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5596
5899
|
ggml_metal_free(ctx);
|
|
5597
5900
|
|
|
5598
5901
|
free(backend);
|
|
@@ -5732,6 +6035,8 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
|
|
5732
6035
|
|
|
5733
6036
|
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
|
5734
6037
|
|
|
6038
|
+
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
|
6039
|
+
|
|
5735
6040
|
return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
|
5736
6041
|
}
|
|
5737
6042
|
|
|
@@ -5751,10 +6056,7 @@ static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
|
|
|
5751
6056
|
}
|
|
5752
6057
|
|
|
5753
6058
|
static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
|
|
5754
|
-
// acq/rel just to populate ctx->name in case it hasn't been done yet
|
|
5755
6059
|
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
|
5756
|
-
ggml_backend_metal_device_acq(ctx_dev);
|
|
5757
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5758
6060
|
|
|
5759
6061
|
return ctx_dev->name;
|
|
5760
6062
|
}
|
|
@@ -5762,12 +6064,10 @@ static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t
|
|
|
5762
6064
|
static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
|
5763
6065
|
if (@available(macOS 10.12, iOS 16.0, *)) {
|
|
5764
6066
|
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
|
5765
|
-
id<MTLDevice> device =
|
|
6067
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
5766
6068
|
|
|
5767
6069
|
*total = device.recommendedMaxWorkingSetSize;
|
|
5768
6070
|
*free = *total - device.currentAllocatedSize;
|
|
5769
|
-
|
|
5770
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5771
6071
|
} else {
|
|
5772
6072
|
*free = 1;
|
|
5773
6073
|
*total = 1;
|
|
@@ -5845,7 +6145,10 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
|
|
|
5845
6145
|
}
|
|
5846
6146
|
|
|
5847
6147
|
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
|
5848
|
-
|
|
6148
|
+
|
|
6149
|
+
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
|
6150
|
+
|
|
6151
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
5849
6152
|
|
|
5850
6153
|
// the buffer fits into the max buffer size allowed by the device
|
|
5851
6154
|
if (size_aligned <= device.maxBufferLength) {
|
|
@@ -5901,7 +6204,6 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
|
|
|
5901
6204
|
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
|
5902
6205
|
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
|
5903
6206
|
free(ctx);
|
|
5904
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5905
6207
|
return NULL;
|
|
5906
6208
|
}
|
|
5907
6209
|
|
|
@@ -5915,8 +6217,9 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
|
|
|
5915
6217
|
}
|
|
5916
6218
|
|
|
5917
6219
|
static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
|
5918
|
-
return
|
|
5919
|
-
|
|
6220
|
+
return
|
|
6221
|
+
buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
|
|
6222
|
+
buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
|
|
5920
6223
|
|
|
5921
6224
|
GGML_UNUSED(dev);
|
|
5922
6225
|
}
|
|
@@ -6001,8 +6304,19 @@ static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
|
|
|
6001
6304
|
/* .get_proc_address = */ ggml_backend_metal_get_proc_address,
|
|
6002
6305
|
};
|
|
6003
6306
|
|
|
6307
|
+
// called upon program exit
|
|
6308
|
+
static void ggml_metal_cleanup(void) {
|
|
6309
|
+
ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);
|
|
6310
|
+
}
|
|
6311
|
+
|
|
6312
|
+
// TODO: make thread-safe
|
|
6004
6313
|
ggml_backend_reg_t ggml_backend_metal_reg(void) {
|
|
6005
|
-
|
|
6314
|
+
ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
|
|
6315
|
+
|
|
6316
|
+
// register cleanup callback
|
|
6317
|
+
// TODO: not ideal, but not sure if there is a better way to do this in Objective-C
|
|
6318
|
+
atexit(ggml_metal_cleanup);
|
|
6319
|
+
|
|
6006
6320
|
{
|
|
6007
6321
|
g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
|
|
6008
6322
|
/* .api_version = */ GGML_BACKEND_API_VERSION,
|