@novastera-oss/llamarn 0.2.6 → 0.2.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +141 -38
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +58 -24
- package/cpp/LlamaCppModel.h +3 -3
- package/cpp/PureCppImpl.cpp +1 -1
- package/cpp/PureCppImpl.h +2 -2
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +15 -4
- package/cpp/llama.cpp/Makefile +2 -2
- package/cpp/llama.cpp/README.md +32 -13
- package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
- package/cpp/llama.cpp/common/arg.cpp +37 -6
- package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
- package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
- package/cpp/llama.cpp/common/chat-parser.h +2 -0
- package/cpp/llama.cpp/common/chat.cpp +12 -9
- package/cpp/llama.cpp/common/chat.h +1 -1
- package/cpp/llama.cpp/common/common.cpp +53 -40
- package/cpp/llama.cpp/common/common.h +6 -2
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/common/speculative.cpp +6 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +215 -76
- package/cpp/llama.cpp/ggml/CMakeLists.txt +48 -2
- package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +64 -13
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +124 -26
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +4 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +93 -104
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +194 -69
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1158 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1571 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +213 -37
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +59 -37
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +90 -39
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +260 -49
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +497 -282
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1078 -468
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +20 -48
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +117 -165
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +192 -53
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +8 -105
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +209 -92
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +158 -203
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +36 -28
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +487 -247
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +69 -19
- package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +133 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +25 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +78 -3
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +141 -38
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-arch.cpp +150 -3
- package/cpp/llama.cpp/src/llama-arch.h +25 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +736 -274
- package/cpp/llama.cpp/src/llama-batch.h +110 -57
- package/cpp/llama.cpp/src/llama-chat.cpp +30 -8
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +360 -266
- package/cpp/llama.cpp/src/llama-context.h +27 -23
- package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
- package/cpp/llama.cpp/src/llama-cparams.h +1 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +411 -344
- package/cpp/llama.cpp/src/llama-graph.h +126 -58
- package/cpp/llama.cpp/src/llama-hparams.cpp +10 -2
- package/cpp/llama.cpp/src/llama-hparams.h +16 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +103 -73
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +34 -42
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +345 -221
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +75 -50
- package/cpp/llama.cpp/src/llama-kv-cells.h +51 -22
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +138 -0
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +302 -317
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +60 -68
- package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
- package/cpp/llama.cpp/src/llama-memory.h +73 -36
- package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
- package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +1630 -511
- package/cpp/llama.cpp/src/llama-model.h +26 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +89 -6
- package/cpp/llama.cpp/src/llama-vocab.cpp +58 -26
- package/cpp/llama.cpp/src/llama-vocab.h +1 -0
- package/cpp/llama.cpp/src/llama.cpp +11 -7
- package/cpp/llama.cpp/src/unicode.cpp +5 -0
- package/cpp/rn-completion.cpp +2 -2
- package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
- package/cpp/{rn-utils.hpp → rn-utils.h} +3 -0
- package/ios/include/chat.h +1 -1
- package/ios/include/common.h +6 -2
- package/ios/include/llama.h +141 -38
- package/ios/libs/llama.xcframework/Info.plist +15 -15
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3624
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4725
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4746
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3652
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -2
- package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
- package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
|
@@ -48,22 +48,28 @@ static struct ggml_backend_metal_device_context {
|
|
|
48
48
|
int mtl_device_ref_count;
|
|
49
49
|
id<MTLLibrary> mtl_library;
|
|
50
50
|
|
|
51
|
+
NSLock * mtl_lock;
|
|
52
|
+
|
|
51
53
|
bool has_simdgroup_reduction;
|
|
52
54
|
bool has_simdgroup_mm;
|
|
53
55
|
bool has_residency_sets;
|
|
54
56
|
bool has_bfloat;
|
|
55
57
|
bool use_bfloat;
|
|
56
58
|
|
|
59
|
+
size_t max_size;
|
|
60
|
+
|
|
57
61
|
char name[128];
|
|
58
62
|
} g_ggml_ctx_dev_main = {
|
|
59
63
|
/*.mtl_device =*/ nil,
|
|
60
64
|
/*.mtl_device_ref_count =*/ 0,
|
|
61
65
|
/*.mtl_library =*/ nil,
|
|
66
|
+
/*.mtl_lock =*/ nil,
|
|
62
67
|
/*.has_simdgroup_reduction =*/ false,
|
|
63
68
|
/*.has_simdgroup_mm =*/ false,
|
|
64
69
|
/*.has_residency_sets =*/ false,
|
|
65
70
|
/*.has_bfloat =*/ false,
|
|
66
71
|
/*.use_bfloat =*/ false,
|
|
72
|
+
/*.max_size =*/ 0,
|
|
67
73
|
/*.name =*/ "",
|
|
68
74
|
};
|
|
69
75
|
|
|
@@ -71,6 +77,10 @@ static struct ggml_backend_metal_device_context {
|
|
|
71
77
|
static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
|
|
72
78
|
assert(ctx != NULL);
|
|
73
79
|
|
|
80
|
+
if (ctx->mtl_lock == nil) {
|
|
81
|
+
ctx->mtl_lock = [[NSLock alloc] init];
|
|
82
|
+
}
|
|
83
|
+
|
|
74
84
|
if (ctx->mtl_device == nil) {
|
|
75
85
|
ctx->mtl_device = MTLCreateSystemDefaultDevice();
|
|
76
86
|
}
|
|
@@ -94,6 +104,8 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
|
|
|
94
104
|
ctx->use_bfloat = false;
|
|
95
105
|
#endif
|
|
96
106
|
|
|
107
|
+
ctx->max_size = ctx->mtl_device.maxBufferLength;
|
|
108
|
+
|
|
97
109
|
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
|
|
98
110
|
}
|
|
99
111
|
|
|
@@ -110,6 +122,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
|
|
|
110
122
|
ctx->mtl_device_ref_count--;
|
|
111
123
|
|
|
112
124
|
if (ctx->mtl_device_ref_count == 0) {
|
|
125
|
+
if (ctx->mtl_lock) {
|
|
126
|
+
[ctx->mtl_lock release];
|
|
127
|
+
ctx->mtl_lock = nil;
|
|
128
|
+
}
|
|
129
|
+
|
|
113
130
|
if (ctx->mtl_library) {
|
|
114
131
|
[ctx->mtl_library release];
|
|
115
132
|
ctx->mtl_library = nil;
|
|
@@ -185,6 +202,15 @@ enum ggml_metal_kernel_type {
|
|
|
185
202
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
|
186
203
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
|
187
204
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
|
205
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
|
|
206
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
|
|
207
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
|
|
208
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
|
|
209
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
|
|
210
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
|
|
211
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
|
|
212
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
|
|
213
|
+
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
|
|
188
214
|
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
|
189
215
|
GGML_METAL_KERNEL_TYPE_L2_NORM,
|
|
190
216
|
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
|
@@ -194,11 +220,14 @@ enum ggml_metal_kernel_type {
|
|
|
194
220
|
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
|
|
195
221
|
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
|
|
196
222
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
|
223
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4,
|
|
197
224
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
|
225
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4,
|
|
198
226
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
|
199
227
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
|
|
200
228
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
|
201
229
|
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
|
|
230
|
+
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4,
|
|
202
231
|
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
|
|
203
232
|
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
|
|
204
233
|
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
|
|
@@ -498,6 +527,7 @@ enum ggml_metal_kernel_type {
|
|
|
498
527
|
GGML_METAL_KERNEL_TYPE_COS,
|
|
499
528
|
GGML_METAL_KERNEL_TYPE_NEG,
|
|
500
529
|
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
|
530
|
+
GGML_METAL_KERNEL_TYPE_MEAN,
|
|
501
531
|
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
|
502
532
|
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
|
503
533
|
GGML_METAL_KERNEL_TYPE_ARGMAX,
|
|
@@ -976,7 +1006,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
976
1006
|
struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
|
|
977
1007
|
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
|
|
978
1008
|
|
|
979
|
-
id<MTLDevice> device =
|
|
1009
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
980
1010
|
|
|
981
1011
|
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
|
982
1012
|
|
|
@@ -990,9 +1020,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
990
1020
|
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
|
991
1021
|
|
|
992
1022
|
// load library
|
|
993
|
-
|
|
994
|
-
ctx_dev->
|
|
1023
|
+
{
|
|
1024
|
+
[ctx_dev->mtl_lock lock];
|
|
1025
|
+
|
|
1026
|
+
if (ctx_dev->mtl_library == nil) {
|
|
1027
|
+
ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
|
|
1028
|
+
}
|
|
1029
|
+
|
|
1030
|
+
[ctx_dev->mtl_lock unlock];
|
|
995
1031
|
}
|
|
1032
|
+
|
|
996
1033
|
id<MTLLibrary> metal_library = ctx_dev->mtl_library;
|
|
997
1034
|
if (metal_library == nil) {
|
|
998
1035
|
GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
|
|
@@ -1141,6 +1178,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
1141
1178
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
|
1142
1179
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
|
1143
1180
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
|
1181
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true);
|
|
1182
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true);
|
|
1183
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
|
|
1184
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true);
|
|
1185
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true);
|
|
1186
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true);
|
|
1187
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
|
|
1188
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
|
|
1189
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
|
|
1144
1190
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
|
1145
1191
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
|
|
1146
1192
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
|
@@ -1150,11 +1196,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
1150
1196
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
|
|
1151
1197
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
|
|
1152
1198
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
|
1199
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true);
|
|
1153
1200
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
|
|
1201
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat);
|
|
1154
1202
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
|
|
1155
1203
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
|
|
1156
1204
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
|
|
1157
1205
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
|
|
1206
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true);
|
|
1158
1207
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
|
|
1159
1208
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
|
|
1160
1209
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
|
|
@@ -1454,6 +1503,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
1454
1503
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
|
1455
1504
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
|
1456
1505
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
|
1506
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
|
1457
1507
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
|
1458
1508
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
|
1459
1509
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
|
@@ -1603,6 +1653,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
1603
1653
|
const bool use_bfloat = ctx_dev->use_bfloat;
|
|
1604
1654
|
|
|
1605
1655
|
if (!use_bfloat) {
|
|
1656
|
+
if (op->type == GGML_TYPE_BF16) {
|
|
1657
|
+
return false;
|
|
1658
|
+
}
|
|
1659
|
+
|
|
1606
1660
|
for (size_t i = 0, n = 3; i < n; ++i) {
|
|
1607
1661
|
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
|
|
1608
1662
|
return false;
|
|
@@ -1653,6 +1707,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
1653
1707
|
case GGML_OP_LOG:
|
|
1654
1708
|
return false; // TODO: implement
|
|
1655
1709
|
case GGML_OP_SUM_ROWS:
|
|
1710
|
+
case GGML_OP_MEAN:
|
|
1656
1711
|
case GGML_OP_SOFT_MAX:
|
|
1657
1712
|
case GGML_OP_GROUP_NORM:
|
|
1658
1713
|
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
|
@@ -1771,6 +1826,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
1771
1826
|
{
|
|
1772
1827
|
return op->ne[3] == 1;
|
|
1773
1828
|
}
|
|
1829
|
+
case GGML_OP_SET_ROWS:
|
|
1830
|
+
{
|
|
1831
|
+
if (op->src[0]->type != GGML_TYPE_F32) {
|
|
1832
|
+
return false;
|
|
1833
|
+
}
|
|
1834
|
+
|
|
1835
|
+
switch (op->type) {
|
|
1836
|
+
case GGML_TYPE_F32:
|
|
1837
|
+
case GGML_TYPE_F16:
|
|
1838
|
+
case GGML_TYPE_BF16:
|
|
1839
|
+
case GGML_TYPE_Q8_0:
|
|
1840
|
+
case GGML_TYPE_Q4_0:
|
|
1841
|
+
case GGML_TYPE_Q4_1:
|
|
1842
|
+
case GGML_TYPE_Q5_0:
|
|
1843
|
+
case GGML_TYPE_Q5_1:
|
|
1844
|
+
case GGML_TYPE_IQ4_NL:
|
|
1845
|
+
return true;
|
|
1846
|
+
default:
|
|
1847
|
+
return false;
|
|
1848
|
+
};
|
|
1849
|
+
}
|
|
1774
1850
|
default:
|
|
1775
1851
|
return false;
|
|
1776
1852
|
}
|
|
@@ -2400,11 +2476,31 @@ static bool ggml_metal_encode_node(
|
|
|
2400
2476
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2401
2477
|
} break;
|
|
2402
2478
|
case GGML_OP_SUM_ROWS:
|
|
2479
|
+
case GGML_OP_MEAN:
|
|
2403
2480
|
{
|
|
2404
2481
|
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
|
2405
2482
|
|
|
2406
|
-
id<MTLComputePipelineState> pipeline =
|
|
2483
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2484
|
+
|
|
2485
|
+
switch (dst->op) {
|
|
2486
|
+
case GGML_OP_SUM_ROWS:
|
|
2487
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
|
2488
|
+
break;
|
|
2489
|
+
case GGML_OP_MEAN:
|
|
2490
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
|
|
2491
|
+
break;
|
|
2492
|
+
default:
|
|
2493
|
+
GGML_ABORT("fatal error");
|
|
2494
|
+
}
|
|
2407
2495
|
|
|
2496
|
+
int nth = 32; // SIMD width
|
|
2497
|
+
|
|
2498
|
+
while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
2499
|
+
nth *= 2;
|
|
2500
|
+
}
|
|
2501
|
+
|
|
2502
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
2503
|
+
nth = MIN(nth, ne00);
|
|
2408
2504
|
|
|
2409
2505
|
ggml_metal_kargs_sum_rows args = {
|
|
2410
2506
|
/*.ne00 =*/ ne00,
|
|
@@ -2434,11 +2530,12 @@ static bool ggml_metal_encode_node(
|
|
|
2434
2530
|
};
|
|
2435
2531
|
|
|
2436
2532
|
[encoder setComputePipelineState:pipeline];
|
|
2437
|
-
[encoder
|
|
2438
|
-
[encoder setBuffer:
|
|
2439
|
-
[encoder
|
|
2533
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
2534
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
2535
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
2536
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
2440
2537
|
|
|
2441
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(
|
|
2538
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2442
2539
|
} break;
|
|
2443
2540
|
case GGML_OP_SOFT_MAX:
|
|
2444
2541
|
{
|
|
@@ -3063,14 +3160,23 @@ static bool ggml_metal_encode_node(
|
|
|
3063
3160
|
nsg = 1;
|
|
3064
3161
|
nr0 = 1;
|
|
3065
3162
|
nr1 = 4;
|
|
3066
|
-
|
|
3163
|
+
if (ne00 == 4) {
|
|
3164
|
+
nr0 = 32;
|
|
3165
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline;
|
|
3166
|
+
} else {
|
|
3167
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
|
|
3168
|
+
}
|
|
3067
3169
|
} break;
|
|
3068
3170
|
case GGML_TYPE_F16:
|
|
3069
3171
|
{
|
|
3070
3172
|
nsg = 1;
|
|
3071
3173
|
nr0 = 1;
|
|
3072
3174
|
if (src1t == GGML_TYPE_F32) {
|
|
3073
|
-
if (
|
|
3175
|
+
if (ne00 == 4) {
|
|
3176
|
+
nr0 = 32;
|
|
3177
|
+
nr1 = 4;
|
|
3178
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline;
|
|
3179
|
+
} else if (ne11 * ne12 < 4) {
|
|
3074
3180
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
|
3075
3181
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
|
3076
3182
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
|
@@ -3089,7 +3195,11 @@ static bool ggml_metal_encode_node(
|
|
|
3089
3195
|
nsg = 1;
|
|
3090
3196
|
nr0 = 1;
|
|
3091
3197
|
if (src1t == GGML_TYPE_F32) {
|
|
3092
|
-
if (
|
|
3198
|
+
if (ne00 == 4) {
|
|
3199
|
+
nr0 = 32;
|
|
3200
|
+
nr1 = 4;
|
|
3201
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline;
|
|
3202
|
+
} else if (ne11 * ne12 < 4) {
|
|
3093
3203
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
|
|
3094
3204
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
|
3095
3205
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
|
|
@@ -3710,13 +3820,74 @@ static bool ggml_metal_encode_node(
|
|
|
3710
3820
|
};
|
|
3711
3821
|
|
|
3712
3822
|
[encoder setComputePipelineState:pipeline];
|
|
3713
|
-
[encoder
|
|
3714
|
-
[encoder setBuffer:
|
|
3715
|
-
[encoder setBuffer:
|
|
3716
|
-
[encoder
|
|
3823
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
3824
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
3825
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
|
3826
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
3717
3827
|
|
|
3718
3828
|
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
|
3719
3829
|
} break;
|
|
3830
|
+
case GGML_OP_SET_ROWS:
|
|
3831
|
+
{
|
|
3832
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
3833
|
+
|
|
3834
|
+
switch (dst->type) {
|
|
3835
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break;
|
|
3836
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break;
|
|
3837
|
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break;
|
|
3838
|
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break;
|
|
3839
|
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break;
|
|
3840
|
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break;
|
|
3841
|
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break;
|
|
3842
|
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break;
|
|
3843
|
+
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break;
|
|
3844
|
+
default: GGML_ABORT("not implemented");
|
|
3845
|
+
}
|
|
3846
|
+
|
|
3847
|
+
const int32_t nk0 = ne0/ggml_blck_size(dst->type);
|
|
3848
|
+
|
|
3849
|
+
int nth = 32; // SIMD width
|
|
3850
|
+
|
|
3851
|
+
while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
3852
|
+
nth *= 2;
|
|
3853
|
+
}
|
|
3854
|
+
|
|
3855
|
+
int nrptg = 1;
|
|
3856
|
+
if (nth > nk0) {
|
|
3857
|
+
nrptg = (nth + nk0 - 1)/nk0;
|
|
3858
|
+
nth = nk0;
|
|
3859
|
+
|
|
3860
|
+
if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
3861
|
+
nrptg--;
|
|
3862
|
+
}
|
|
3863
|
+
}
|
|
3864
|
+
|
|
3865
|
+
nth = MIN(nth, nk0);
|
|
3866
|
+
|
|
3867
|
+
ggml_metal_kargs_set_rows args = {
|
|
3868
|
+
/*.nk0 =*/ nk0,
|
|
3869
|
+
/*.ne01 =*/ ne01,
|
|
3870
|
+
/*.nb01 =*/ nb01,
|
|
3871
|
+
/*.nb02 =*/ nb02,
|
|
3872
|
+
/*.nb03 =*/ nb03,
|
|
3873
|
+
/*.ne11 =*/ ne11,
|
|
3874
|
+
/*.ne12 =*/ ne12,
|
|
3875
|
+
/*.nb10 =*/ nb10,
|
|
3876
|
+
/*.nb11 =*/ nb11,
|
|
3877
|
+
/*.nb12 =*/ nb12,
|
|
3878
|
+
/*.nb1 =*/ nb1,
|
|
3879
|
+
/*.nb2 =*/ nb2,
|
|
3880
|
+
/*.nb3 =*/ nb3,
|
|
3881
|
+
};
|
|
3882
|
+
|
|
3883
|
+
[encoder setComputePipelineState:pipeline];
|
|
3884
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
3885
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
3886
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
|
3887
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
3888
|
+
|
|
3889
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
|
|
3890
|
+
} break;
|
|
3720
3891
|
case GGML_OP_RMS_NORM:
|
|
3721
3892
|
{
|
|
3722
3893
|
GGML_ASSERT(ne00 % 4 == 0);
|
|
@@ -3733,6 +3904,7 @@ static bool ggml_metal_encode_node(
|
|
|
3733
3904
|
nth *= 2;
|
|
3734
3905
|
}
|
|
3735
3906
|
|
|
3907
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
3736
3908
|
nth = MIN(nth, ne00/4);
|
|
3737
3909
|
|
|
3738
3910
|
ggml_metal_kargs_rms_norm args = {
|
|
@@ -3769,6 +3941,7 @@ static bool ggml_metal_encode_node(
|
|
|
3769
3941
|
nth *= 2;
|
|
3770
3942
|
}
|
|
3771
3943
|
|
|
3944
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
3772
3945
|
nth = MIN(nth, ne00/4);
|
|
3773
3946
|
|
|
3774
3947
|
ggml_metal_kargs_l2_norm args = {
|
|
@@ -3841,6 +4014,7 @@ static bool ggml_metal_encode_node(
|
|
|
3841
4014
|
nth *= 2;
|
|
3842
4015
|
}
|
|
3843
4016
|
|
|
4017
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
3844
4018
|
nth = MIN(nth, ne00/4);
|
|
3845
4019
|
|
|
3846
4020
|
ggml_metal_kargs_norm args = {
|
|
@@ -4766,6 +4940,8 @@ static bool ggml_metal_encode_node(
|
|
|
4766
4940
|
GGML_ASSERT(nqptg % 8 == 0);
|
|
4767
4941
|
GGML_ASSERT(ncpsg % 32 == 0);
|
|
4768
4942
|
|
|
4943
|
+
const int is_q = ggml_is_quantized(src1->type) ? 1 : 0;
|
|
4944
|
+
|
|
4769
4945
|
// 2*(2*ncpsg + nqptg)*(nsg)
|
|
4770
4946
|
// ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
|
|
4771
4947
|
//
|
|
@@ -4773,7 +4949,7 @@ static bool ggml_metal_encode_node(
|
|
|
4773
4949
|
// the shared memory needed for the simdgroups to load the KV cache
|
|
4774
4950
|
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
|
|
4775
4951
|
//
|
|
4776
|
-
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
|
|
4952
|
+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
|
|
4777
4953
|
|
|
4778
4954
|
int64_t nsgmax = 2;
|
|
4779
4955
|
|
|
@@ -4810,9 +4986,9 @@ static bool ggml_metal_encode_node(
|
|
|
4810
4986
|
// and store the soft_max values and the mask
|
|
4811
4987
|
//
|
|
4812
4988
|
// ne00*(nsg)
|
|
4813
|
-
// each simdgroup has a full
|
|
4989
|
+
// each simdgroup has a full f32 head vector in shared mem to accumulate results
|
|
4814
4990
|
//
|
|
4815
|
-
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
|
|
4991
|
+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
|
|
4816
4992
|
|
|
4817
4993
|
int64_t nsgmax = 2;
|
|
4818
4994
|
while (true) {
|
|
@@ -4925,8 +5101,39 @@ static bool ggml_metal_encode_node(
|
|
|
4925
5101
|
default: GGML_ABORT("not implemented");
|
|
4926
5102
|
}
|
|
4927
5103
|
|
|
5104
|
+
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
|
5105
|
+
|
|
5106
|
+
// TODO: support
|
|
5107
|
+
//const int32_t nk00 = ne00/ggml_blck_size(dst->type);
|
|
5108
|
+
const int32_t nk00 = ne00;
|
|
5109
|
+
|
|
5110
|
+
int nth = 32; // SIMD width
|
|
5111
|
+
|
|
5112
|
+
while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
5113
|
+
nth *= 2;
|
|
5114
|
+
}
|
|
5115
|
+
|
|
5116
|
+
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
5117
|
+
|
|
5118
|
+
// when rows are small, we can batch them together in a single threadgroup
|
|
5119
|
+
int nrptg = 1;
|
|
5120
|
+
|
|
5121
|
+
// TODO: relax this constraint in the future
|
|
5122
|
+
if (ggml_blck_size(src0->type) == 1 && ggml_blck_size(dst->type) == 1) {
|
|
5123
|
+
if (nth > nk00) {
|
|
5124
|
+
nrptg = (nth + nk00 - 1)/nk00;
|
|
5125
|
+
nth = nk00;
|
|
5126
|
+
|
|
5127
|
+
if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
5128
|
+
nrptg--;
|
|
5129
|
+
}
|
|
5130
|
+
}
|
|
5131
|
+
}
|
|
5132
|
+
|
|
5133
|
+
nth = MIN(nth, nk00);
|
|
5134
|
+
|
|
4928
5135
|
ggml_metal_kargs_cpy args = {
|
|
4929
|
-
/*.ne00 =*/
|
|
5136
|
+
/*.ne00 =*/ nk00,
|
|
4930
5137
|
/*.ne01 =*/ ne01,
|
|
4931
5138
|
/*.ne02 =*/ ne02,
|
|
4932
5139
|
/*.ne03 =*/ ne03,
|
|
@@ -4949,11 +5156,7 @@ static bool ggml_metal_encode_node(
|
|
|
4949
5156
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
4950
5157
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
4951
5158
|
|
|
4952
|
-
|
|
4953
|
-
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
|
|
4954
|
-
|
|
4955
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
4956
|
-
|
|
5159
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
|
|
4957
5160
|
} break;
|
|
4958
5161
|
case GGML_OP_SET:
|
|
4959
5162
|
{
|
|
@@ -5259,7 +5462,6 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
|
|
|
5259
5462
|
}
|
|
5260
5463
|
|
|
5261
5464
|
ggml_backend_metal_buffer_rset_free(ctx);
|
|
5262
|
-
ggml_backend_metal_device_rel(buffer->buft->device->context);
|
|
5263
5465
|
|
|
5264
5466
|
if (ctx->owned) {
|
|
5265
5467
|
#if TARGET_OS_OSX
|
|
@@ -5368,7 +5570,10 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
|
5368
5570
|
}
|
|
5369
5571
|
|
|
5370
5572
|
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
|
|
5371
|
-
|
|
5573
|
+
|
|
5574
|
+
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
|
5575
|
+
|
|
5576
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
5372
5577
|
|
|
5373
5578
|
ctx->all_data = ggml_metal_host_malloc(size_aligned);
|
|
5374
5579
|
ctx->all_size = size_aligned;
|
|
@@ -5391,14 +5596,12 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
|
5391
5596
|
if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
|
|
5392
5597
|
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
|
|
5393
5598
|
free(ctx);
|
|
5394
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5395
5599
|
return NULL;
|
|
5396
5600
|
}
|
|
5397
5601
|
|
|
5398
5602
|
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
|
5399
5603
|
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
|
5400
5604
|
free(ctx);
|
|
5401
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5402
5605
|
return NULL;
|
|
5403
5606
|
}
|
|
5404
5607
|
|
|
@@ -5409,17 +5612,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
|
5409
5612
|
|
|
5410
5613
|
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
|
5411
5614
|
return 32;
|
|
5615
|
+
|
|
5412
5616
|
GGML_UNUSED(buft);
|
|
5413
5617
|
}
|
|
5414
5618
|
|
|
5415
5619
|
static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
|
|
5416
|
-
|
|
5417
|
-
const size_t max_size = device.maxBufferLength;
|
|
5418
|
-
ggml_backend_metal_device_rel(buft->device->context);
|
|
5620
|
+
const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size;
|
|
5419
5621
|
|
|
5420
5622
|
return max_size;
|
|
5421
|
-
|
|
5422
|
-
GGML_UNUSED(buft);
|
|
5423
5623
|
}
|
|
5424
5624
|
|
|
5425
5625
|
static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
|
@@ -5492,7 +5692,10 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
|
|
5492
5692
|
}
|
|
5493
5693
|
|
|
5494
5694
|
struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
|
|
5495
|
-
|
|
5695
|
+
|
|
5696
|
+
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
|
5697
|
+
|
|
5698
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
5496
5699
|
|
|
5497
5700
|
// the buffer fits into the max buffer size allowed by the device
|
|
5498
5701
|
if (size_aligned <= device.maxBufferLength) {
|
|
@@ -5548,7 +5751,6 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
|
|
5548
5751
|
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
|
5549
5752
|
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
|
5550
5753
|
free(ctx);
|
|
5551
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5552
5754
|
return NULL;
|
|
5553
5755
|
}
|
|
5554
5756
|
|
|
@@ -5564,10 +5766,8 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
|
|
5564
5766
|
}
|
|
5565
5767
|
|
|
5566
5768
|
static void ggml_backend_metal_free(ggml_backend_t backend) {
|
|
5567
|
-
struct ggml_backend_metal_context
|
|
5568
|
-
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
|
5769
|
+
struct ggml_backend_metal_context * ctx = backend->context;
|
|
5569
5770
|
|
|
5570
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5571
5771
|
ggml_metal_free(ctx);
|
|
5572
5772
|
|
|
5573
5773
|
free(backend);
|
|
@@ -5707,6 +5907,8 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
|
|
5707
5907
|
|
|
5708
5908
|
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
|
5709
5909
|
|
|
5910
|
+
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
|
5911
|
+
|
|
5710
5912
|
return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
|
5711
5913
|
}
|
|
5712
5914
|
|
|
@@ -5726,10 +5928,7 @@ static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
|
|
|
5726
5928
|
}
|
|
5727
5929
|
|
|
5728
5930
|
static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
|
|
5729
|
-
// acq/rel just to populate ctx->name in case it hasn't been done yet
|
|
5730
5931
|
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
|
5731
|
-
ggml_backend_metal_device_acq(ctx_dev);
|
|
5732
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5733
5932
|
|
|
5734
5933
|
return ctx_dev->name;
|
|
5735
5934
|
}
|
|
@@ -5737,12 +5936,10 @@ static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t
|
|
|
5737
5936
|
static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
|
5738
5937
|
if (@available(macOS 10.12, iOS 16.0, *)) {
|
|
5739
5938
|
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
|
5740
|
-
id<MTLDevice> device =
|
|
5939
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
5741
5940
|
|
|
5742
5941
|
*total = device.recommendedMaxWorkingSetSize;
|
|
5743
5942
|
*free = *total - device.currentAllocatedSize;
|
|
5744
|
-
|
|
5745
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5746
5943
|
} else {
|
|
5747
5944
|
*free = 1;
|
|
5748
5945
|
*total = 1;
|
|
@@ -5820,7 +6017,10 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
|
|
|
5820
6017
|
}
|
|
5821
6018
|
|
|
5822
6019
|
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
|
5823
|
-
|
|
6020
|
+
|
|
6021
|
+
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
|
6022
|
+
|
|
6023
|
+
id<MTLDevice> device = ctx_dev->mtl_device;
|
|
5824
6024
|
|
|
5825
6025
|
// the buffer fits into the max buffer size allowed by the device
|
|
5826
6026
|
if (size_aligned <= device.maxBufferLength) {
|
|
@@ -5876,7 +6076,6 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
|
|
|
5876
6076
|
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
|
5877
6077
|
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
|
5878
6078
|
free(ctx);
|
|
5879
|
-
ggml_backend_metal_device_rel(ctx_dev);
|
|
5880
6079
|
return NULL;
|
|
5881
6080
|
}
|
|
5882
6081
|
|
|
@@ -5890,8 +6089,9 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
|
|
|
5890
6089
|
}
|
|
5891
6090
|
|
|
5892
6091
|
static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
|
5893
|
-
return
|
|
5894
|
-
|
|
6092
|
+
return
|
|
6093
|
+
buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
|
|
6094
|
+
buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
|
|
5895
6095
|
|
|
5896
6096
|
GGML_UNUSED(dev);
|
|
5897
6097
|
}
|
|
@@ -5976,8 +6176,19 @@ static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
|
|
|
5976
6176
|
/* .get_proc_address = */ ggml_backend_metal_get_proc_address,
|
|
5977
6177
|
};
|
|
5978
6178
|
|
|
6179
|
+
// called upon program exit
|
|
6180
|
+
static void ggml_metal_cleanup(void) {
|
|
6181
|
+
ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);
|
|
6182
|
+
}
|
|
6183
|
+
|
|
6184
|
+
// TODO: make thread-safe
|
|
5979
6185
|
ggml_backend_reg_t ggml_backend_metal_reg(void) {
|
|
5980
|
-
|
|
6186
|
+
ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
|
|
6187
|
+
|
|
6188
|
+
// register cleanup callback
|
|
6189
|
+
// TODO: not ideal, but not sure if there is a better way to do this in Objective-C
|
|
6190
|
+
atexit(ggml_metal_cleanup);
|
|
6191
|
+
|
|
5981
6192
|
{
|
|
5982
6193
|
g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
|
|
5983
6194
|
/* .api_version = */ GGML_BACKEND_API_VERSION,
|