cui-llama.rn 1.5.0 → 1.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +20 -20
- package/README.md +345 -319
- package/android/build.gradle +116 -116
- package/android/gradle.properties +5 -5
- package/android/src/main/AndroidManifest.xml +4 -4
- package/android/src/main/CMakeLists.txt +129 -124
- package/android/src/main/java/com/rnllama/LlamaContext.java +648 -645
- package/android/src/main/java/com/rnllama/RNLlama.java +695 -695
- package/android/src/main/java/com/rnllama/RNLlamaPackage.java +48 -48
- package/android/src/main/jni-utils.h +100 -100
- package/android/src/main/jni.cpp +1279 -1263
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +135 -135
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +136 -136
- package/cpp/LICENSE +21 -0
- package/cpp/README.md +4 -4
- package/cpp/chat.cpp +1 -1
- package/cpp/common.cpp +17 -2
- package/cpp/common.h +7 -3
- package/cpp/ggml-alloc.c +4 -1
- package/cpp/ggml-cpp.h +1 -1
- package/cpp/ggml-cpu/amx/amx.cpp +221 -0
- package/cpp/ggml-cpu/amx/amx.h +8 -0
- package/cpp/ggml-cpu/amx/common.h +91 -0
- package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
- package/cpp/ggml-cpu/amx/mmq.h +10 -0
- package/cpp/{binary-ops.h → ggml-cpu/binary-ops.h} +1 -1
- package/cpp/ggml-cpu/common.h +72 -0
- package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -101
- package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +109 -42
- package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +3 -0
- package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +246 -160
- package/cpp/{ops.h → ggml-cpu/ops.h} +2 -20
- package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
- package/cpp/{simd-mappings.h → ggml-cpu/simd-mappings.h} +7 -3
- package/cpp/{unary-ops.h → ggml-cpu/unary-ops.h} +1 -1
- package/cpp/ggml-cpu.h +5 -0
- package/cpp/ggml-impl.h +16 -9
- package/cpp/ggml-llama-sim.metallib +0 -0
- package/cpp/ggml-llama.metallib +0 -0
- package/cpp/ggml-metal-impl.h +597 -597
- package/cpp/ggml-metal.m +496 -47
- package/cpp/ggml.c +134 -244
- package/cpp/ggml.h +62 -95
- package/cpp/json-schema-to-grammar.cpp +3 -0
- package/cpp/llama-arch.cpp +46 -17
- package/cpp/llama-arch.h +9 -0
- package/cpp/llama-batch.cpp +5 -1
- package/cpp/llama-batch.h +2 -1
- package/cpp/llama-chat.cpp +31 -10
- package/cpp/llama-chat.h +3 -2
- package/cpp/llama-context.cpp +104 -489
- package/cpp/llama-context.h +14 -30
- package/cpp/llama-graph.cpp +69 -62
- package/cpp/llama-graph.h +21 -18
- package/cpp/llama-hparams.h +5 -0
- package/cpp/llama-kv-cache.cpp +1497 -391
- package/cpp/llama-kv-cache.h +272 -80
- package/cpp/llama-memory.h +11 -1
- package/cpp/llama-model.cpp +502 -176
- package/cpp/llama-model.h +13 -3
- package/cpp/llama-sampling.cpp +2 -1
- package/cpp/llama-vocab.cpp +8 -1
- package/cpp/llama.h +14 -11
- package/cpp/rn-llama.cpp +721 -873
- package/cpp/rn-llama.h +134 -138
- package/cpp/sampling.h +107 -107
- package/cpp/unicode-data.cpp +7034 -7034
- package/cpp/unicode-data.h +20 -20
- package/cpp/unicode.cpp +849 -849
- package/cpp/unicode.h +66 -66
- package/ios/CMakeLists.txt +119 -108
- package/ios/RNLlama.h +13 -7
- package/ios/RNLlama.mm +423 -405
- package/ios/RNLlamaContext.h +57 -57
- package/ios/RNLlamaContext.mm +833 -835
- package/ios/rnllama.xcframework/Info.plist +74 -74
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +143 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +681 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/cpu-common.h +72 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +1857 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +143 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +601 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +216 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +2189 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/gguf.h +202 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json.hpp +24766 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-adapter.h +76 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +437 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +89 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +57 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +249 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +38 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cpp.h +30 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-grammar.h +173 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +595 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +161 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-impl.h +61 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-io.h +35 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +405 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +31 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-mmap.h +68 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-loader.h +169 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +419 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-sampling.h +32 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +125 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +1437 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/log.h +132 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +2941 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +134 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sampling.h +107 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/speculative.h +28 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unicode-data.h +20 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/unicode.h +66 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +143 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +681 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/cpu-common.h +72 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +1857 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +143 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +601 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +216 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +2189 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/gguf.h +202 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +24766 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-adapter.h +76 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +437 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +89 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +57 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +249 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +38 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cpp.h +30 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-grammar.h +173 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +595 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +161 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-impl.h +61 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-io.h +35 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +405 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +31 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-mmap.h +68 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-loader.h +169 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +419 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-sampling.h +32 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +125 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +1437 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/log.h +132 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +2941 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +134 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sampling.h +107 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/speculative.h +28 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unicode-data.h +20 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unicode.h +66 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +101 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +143 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +681 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/cpu-common.h +72 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +1857 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +143 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +601 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +216 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +2189 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/gguf.h +202 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json.hpp +24766 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-adapter.h +76 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +437 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +89 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +57 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +249 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +38 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cpp.h +30 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-grammar.h +173 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +595 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +161 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-impl.h +61 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-io.h +35 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +405 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +31 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-mmap.h +68 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-loader.h +169 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +419 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-sampling.h +32 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +125 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +1437 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/log.h +132 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +2941 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +134 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sampling.h +107 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/speculative.h +28 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unicode-data.h +20 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unicode.h +66 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +143 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +681 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/cpu-common.h +72 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-alloc.h +76 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend-impl.h +255 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +354 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +1857 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +39 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +143 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +601 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +597 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal.h +66 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +216 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-quants.h +100 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-threading.h +14 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +2189 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/gguf.h +202 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +21 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +24766 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-adapter.h +76 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +437 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +89 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +57 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +249 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +38 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cpp.h +30 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-grammar.h +173 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +595 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +161 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-impl.h +61 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-io.h +35 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +405 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +31 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-mmap.h +68 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-loader.h +169 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +419 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-sampling.h +32 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +125 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +1437 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/log.h +132 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +537 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +2941 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +134 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sampling.h +107 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/speculative.h +28 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unicode-data.h +20 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unicode.h +66 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +101 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/jest/mock.js +203 -203
- package/lib/commonjs/NativeRNLlama.js +1 -2
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/chat.js.map +1 -1
- package/lib/commonjs/grammar.js +12 -31
- package/lib/commonjs/grammar.js.map +1 -1
- package/lib/commonjs/index.js +47 -47
- package/lib/commonjs/index.js.map +1 -1
- package/lib/commonjs/package.json +1 -0
- package/lib/module/NativeRNLlama.js +2 -0
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/chat.js +2 -0
- package/lib/module/chat.js.map +1 -1
- package/lib/module/grammar.js +14 -31
- package/lib/module/grammar.js.map +1 -1
- package/lib/module/index.js +47 -45
- package/lib/module/index.js.map +1 -1
- package/lib/module/package.json +1 -0
- package/lib/typescript/NativeRNLlama.d.ts +10 -4
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts.map +1 -1
- package/llama-rn.podspec +48 -48
- package/package.json +233 -233
- package/src/NativeRNLlama.ts +431 -426
- package/src/chat.ts +44 -44
- package/src/grammar.ts +854 -854
- package/src/index.ts +495 -487
- /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
- /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
- /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
- /package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +0 -0
- /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
- /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
- /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
- /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
- /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
- /package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -0
- /package/cpp/{vec.h → ggml-cpu/vec.h} +0 -0
package/cpp/ggml-metal.m
CHANGED
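The headline changes in this file, as visible in the diff below: new FlashAttention kernel variants for head sizes HK=576/HV=512 (DeepSeek MLA shapes, still gated off in `lm_ggml_metal_supports_op` pending optimization) plus H96 vector variants, a `NEG` unary kernel, simulator-aware metallib selection (`ggml-llama-sim.metallib` under `TARGET_OS_SIMULATOR`), an `UPSCALE` support check restricted to F32 nearest mode, and a per-command-buffer memory pool built on `MTLHeap` for temporary buffers during graph compute.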
@@ -44,8 +44,8 @@ static struct lm_ggml_backend_device g_lm_ggml_backend_metal_device;
|
|
44
44
|
// note: assumes single GPU device - the default one
|
45
45
|
// TODO: support multiple GPU devices
|
46
46
|
static struct lm_ggml_backend_metal_device_context {
|
47
|
-
id<MTLDevice>
|
48
|
-
int
|
47
|
+
id<MTLDevice> mtl_device;
|
48
|
+
int mtl_device_ref_count;
|
49
49
|
id<MTLLibrary> mtl_library;
|
50
50
|
|
51
51
|
bool has_simdgroup_reduction;
|
@@ -354,6 +354,7 @@ enum lm_ggml_metal_kernel_type {
|
|
354
354
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
|
355
355
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
|
356
356
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
357
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
|
357
358
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
|
358
359
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
|
359
360
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
|
@@ -362,6 +363,7 @@ enum lm_ggml_metal_kernel_type {
|
|
362
363
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
|
363
364
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
|
364
365
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
|
366
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
|
365
367
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
366
368
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
367
369
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
@@ -370,6 +372,7 @@ enum lm_ggml_metal_kernel_type {
|
|
370
372
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
|
371
373
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
|
372
374
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
|
375
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
|
373
376
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
|
374
377
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
|
375
378
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
|
@@ -378,6 +381,7 @@ enum lm_ggml_metal_kernel_type {
|
|
378
381
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
|
379
382
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
|
380
383
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
|
384
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
|
381
385
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
|
382
386
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
|
383
387
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
|
@@ -386,6 +390,7 @@ enum lm_ggml_metal_kernel_type {
|
|
386
390
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
|
387
391
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
|
388
392
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
|
393
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
|
389
394
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
|
390
395
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
|
391
396
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
|
@@ -394,6 +399,7 @@ enum lm_ggml_metal_kernel_type {
|
|
394
399
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
|
395
400
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
|
396
401
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
|
402
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
|
397
403
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
|
398
404
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
|
399
405
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
|
@@ -402,6 +408,14 @@ enum lm_ggml_metal_kernel_type {
|
|
402
408
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
|
403
409
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
|
404
410
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
411
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
|
412
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
|
413
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
|
414
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
|
415
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96,
|
416
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96,
|
417
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96,
|
418
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96,
|
405
419
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
406
420
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
|
407
421
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
|
@@ -430,6 +444,13 @@ enum lm_ggml_metal_kernel_type {
|
|
430
444
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
|
431
445
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
|
432
446
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
|
447
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,
|
448
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,
|
449
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,
|
450
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,
|
451
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
|
452
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
|
453
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
|
433
454
|
LM_GGML_METAL_KERNEL_TYPE_SET_I32,
|
434
455
|
LM_GGML_METAL_KERNEL_TYPE_SET_F32,
|
435
456
|
LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
@@ -460,6 +481,7 @@ enum lm_ggml_metal_kernel_type {
|
|
460
481
|
LM_GGML_METAL_KERNEL_TYPE_SQRT,
|
461
482
|
LM_GGML_METAL_KERNEL_TYPE_SIN,
|
462
483
|
LM_GGML_METAL_KERNEL_TYPE_COS,
|
484
|
+
LM_GGML_METAL_KERNEL_TYPE_NEG,
|
463
485
|
LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
464
486
|
LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
465
487
|
LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
@@ -468,7 +490,259 @@ enum lm_ggml_metal_kernel_type {
|
|
468
490
|
LM_GGML_METAL_KERNEL_TYPE_COUNT
|
469
491
|
};
|
470
492
|
|
493
|
+
//
|
494
|
+
// lm_ggml_metal_heap
|
495
|
+
//
|
496
|
+
|
497
|
+
struct lm_ggml_metal_heap {
|
498
|
+
// number of times the heap was unused
|
499
|
+
int n_unused;
|
500
|
+
|
501
|
+
// total number of buffer allocations in this heap across all computes
|
502
|
+
int64_t n_alloc;
|
503
|
+
|
504
|
+
// current offset in the heap - we reset this after each node in order to reuse the memory
|
505
|
+
size_t offs;
|
506
|
+
|
507
|
+
// the currently allocated MTLBuffer objects in this heap
|
508
|
+
id<MTLHeap> obj;
|
509
|
+
|
510
|
+
NSMutableArray * bufs;
|
511
|
+
};
|
512
|
+
|
513
|
+
static struct lm_ggml_metal_heap * lm_ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
|
514
|
+
struct lm_ggml_metal_heap * heap = calloc(1, sizeof(struct lm_ggml_metal_heap));
|
515
|
+
|
516
|
+
MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
|
517
|
+
desc.storageMode = MTLStorageModePrivate;
|
518
|
+
desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
|
519
|
+
desc.type = MTLHeapTypePlacement;
|
520
|
+
desc.size = size;
|
521
|
+
|
522
|
+
heap->n_unused = 0;
|
523
|
+
heap->n_alloc = 0;
|
524
|
+
|
525
|
+
heap->obj = [device newHeapWithDescriptor:desc];
|
526
|
+
if (!heap->obj) {
|
527
|
+
LM_GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);
|
528
|
+
|
529
|
+
free(heap);
|
530
|
+
|
531
|
+
return false;
|
532
|
+
}
|
533
|
+
|
534
|
+
[desc release];
|
535
|
+
|
536
|
+
heap->bufs = [[NSMutableArray alloc] init];
|
537
|
+
|
538
|
+
return heap;
|
539
|
+
}
|
540
|
+
|
541
|
+
static void lm_ggml_metal_heap_reset(struct lm_ggml_metal_heap * heap) {
|
542
|
+
heap->offs = 0;
|
543
|
+
|
544
|
+
// count how many graph computes the heap ended up being unused
|
545
|
+
if ([heap->bufs count] > 0) {
|
546
|
+
heap->n_unused = 0;
|
547
|
+
} else {
|
548
|
+
heap->n_unused++;
|
549
|
+
}
|
550
|
+
|
551
|
+
for (id<MTLBuffer> buf in heap->bufs) {
|
552
|
+
[buf release];
|
553
|
+
}
|
554
|
+
[heap->bufs removeAllObjects];
|
555
|
+
|
556
|
+
// tell the OS that it can reuse this memory if needed
|
557
|
+
// ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
|
558
|
+
[heap->obj setPurgeableState:MTLPurgeableStateVolatile];
|
559
|
+
}
|
560
|
+
|
561
|
+
static void lm_ggml_metal_heap_free(struct lm_ggml_metal_heap * heap) {
|
562
|
+
if (heap == nil) {
|
563
|
+
return;
|
564
|
+
}
|
565
|
+
|
566
|
+
lm_ggml_metal_heap_reset(heap);
|
567
|
+
|
568
|
+
[heap->obj release];
|
569
|
+
[heap->bufs release];
|
570
|
+
|
571
|
+
free(heap);
|
572
|
+
}
|
573
|
+
|
574
|
+
@interface lm_ggml_metal_heap_ptr : NSObject
|
575
|
+
|
576
|
+
@property (nonatomic, assign) struct lm_ggml_metal_heap * data;
|
577
|
+
|
578
|
+
@end
|
579
|
+
|
580
|
+
@implementation lm_ggml_metal_heap_ptr
|
581
|
+
@end
|
582
|
+
|
583
|
+
//
|
584
|
+
// lm_ggml_metal_mem_pool
|
585
|
+
//
|
586
|
+
|
587
|
+
struct lm_ggml_metal_mem_pool {
|
588
|
+
id<MTLDevice> device;
|
589
|
+
|
590
|
+
int n_heaps; // total number of heaps ever created (including those that were removed)
|
591
|
+
|
592
|
+
NSMutableArray * heaps;
|
593
|
+
NSMutableArray * heaps_to_remove;
|
594
|
+
};
|
595
|
+
|
596
|
+
static struct lm_ggml_metal_mem_pool * lm_ggml_metal_mem_pool_init(void) {
|
597
|
+
struct lm_ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct lm_ggml_metal_mem_pool));
|
598
|
+
|
599
|
+
mem_pool->n_heaps = 0;
|
600
|
+
|
601
|
+
mem_pool->heaps = [[NSMutableArray alloc] init];
|
602
|
+
mem_pool->heaps_to_remove = [[NSMutableArray alloc] init];
|
603
|
+
|
604
|
+
return mem_pool;
|
605
|
+
}
|
606
|
+
|
607
|
+
static void lm_ggml_metal_mem_pool_free(struct lm_ggml_metal_mem_pool * mem_pool) {
|
608
|
+
LM_GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps);
|
609
|
+
|
610
|
+
size_t size_all = 0;
|
611
|
+
size_t size_cur = 0;
|
612
|
+
|
613
|
+
for (lm_ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
|
614
|
+
LM_GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data);
|
615
|
+
LM_GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, ptr.data->n_alloc);
|
616
|
+
LM_GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused);
|
617
|
+
LM_GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0);
|
618
|
+
LM_GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [ptr.data->bufs count]);
|
619
|
+
|
620
|
+
if ([ptr.data->bufs count] > 0) {
|
621
|
+
size_cur += [ptr.data->obj size];
|
622
|
+
}
|
623
|
+
size_all += [ptr.data->obj size];
|
624
|
+
|
625
|
+
lm_ggml_metal_heap_free(ptr.data);
|
626
|
+
[ptr release];
|
627
|
+
}
|
628
|
+
[mem_pool->heaps release];
|
629
|
+
[mem_pool->heaps_to_remove release];
|
630
|
+
|
631
|
+
if (size_all > 0) {
|
632
|
+
LM_GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0);
|
633
|
+
LM_GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0);
|
634
|
+
}
|
635
|
+
|
636
|
+
free(mem_pool);
|
637
|
+
}
|
638
|
+
|
639
|
+
static void lm_ggml_metal_mem_pool_reset(struct lm_ggml_metal_mem_pool * mem_pool) {
|
640
|
+
for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) {
|
641
|
+
lm_ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i];
|
642
|
+
|
643
|
+
struct lm_ggml_metal_heap * heap = ptr.data;
|
644
|
+
lm_ggml_metal_heap_reset(heap);
|
645
|
+
|
646
|
+
// if the heap hasn't been used for a while, remove it
|
647
|
+
if (heap->n_unused >= 128) {
|
648
|
+
[mem_pool->heaps_to_remove addObject:@(i)];
|
649
|
+
}
|
650
|
+
}
|
651
|
+
|
652
|
+
if (mem_pool->heaps_to_remove.count > 0) {
|
653
|
+
for (NSUInteger i = 0; i < [mem_pool->heaps_to_remove count]; i++) {
|
654
|
+
NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
|
655
|
+
lm_ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
|
656
|
+
|
657
|
+
struct lm_ggml_metal_heap * heap = ptr.data;
|
658
|
+
lm_ggml_metal_heap_free(heap);
|
659
|
+
|
660
|
+
[mem_pool->heaps removeObjectAtIndex:index];
|
661
|
+
[ptr release];
|
662
|
+
}
|
663
|
+
|
664
|
+
[mem_pool->heaps_to_remove removeAllObjects];
|
665
|
+
}
|
666
|
+
}
|
667
|
+
|
668
|
+
static void lm_ggml_metal_mem_pool_clear(struct lm_ggml_metal_mem_pool * mem_pool) {
|
669
|
+
for (lm_ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
|
670
|
+
ptr.data->offs = 0;
|
671
|
+
}
|
672
|
+
}
|
673
|
+
|
674
|
+
static id<MTLBuffer> lm_ggml_metal_mem_pool_alloc(struct lm_ggml_metal_mem_pool * mem_pool, size_t size) {
|
675
|
+
const size_t alignment = 32;
|
676
|
+
|
677
|
+
const size_t size_aligned = LM_GGML_PAD(size, alignment);
|
678
|
+
|
679
|
+
// try one of the existing heaps
|
680
|
+
for (lm_ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
|
681
|
+
struct lm_ggml_metal_heap * heap = ptr.data;
|
682
|
+
if (heap->offs + size_aligned <= [heap->obj size]) {
|
683
|
+
// if this is the first buffer in the heap for the current command buffer, tell the OS that
|
684
|
+
// it cannot free the memory used by the heap
|
685
|
+
// ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
|
686
|
+
if ([heap->bufs count] == 0) {
|
687
|
+
[heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
|
688
|
+
}
|
689
|
+
|
690
|
+
id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
|
691
|
+
if (buf == nil) {
|
692
|
+
LM_GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
|
693
|
+
return nil;
|
694
|
+
}
|
695
|
+
|
696
|
+
heap->n_alloc++;
|
697
|
+
heap->offs += size_aligned;
|
698
|
+
|
699
|
+
[heap->bufs addObject:buf];
|
700
|
+
|
701
|
+
return buf;
|
702
|
+
}
|
703
|
+
}
|
704
|
+
|
705
|
+
// create a new heap that can fit this buffer
|
706
|
+
lm_ggml_metal_heap_ptr * heap_ptr = [lm_ggml_metal_heap_ptr new];
|
707
|
+
|
708
|
+
struct lm_ggml_metal_heap * heap = lm_ggml_metal_heap_init(mem_pool->device, size_aligned);
|
709
|
+
if (heap == NULL) {
|
710
|
+
LM_GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned);
|
711
|
+
return NULL;
|
712
|
+
}
|
713
|
+
|
714
|
+
//LM_GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]);
|
715
|
+
|
716
|
+
heap_ptr.data = heap;
|
717
|
+
lm_ggml_metal_heap_reset(heap);
|
718
|
+
|
719
|
+
[heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
|
720
|
+
id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
|
721
|
+
if (buf == nil) {
|
722
|
+
LM_GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
|
723
|
+
return NULL;
|
724
|
+
}
|
725
|
+
|
726
|
+
heap->n_alloc++;
|
727
|
+
heap->offs += size_aligned;
|
728
|
+
|
729
|
+
[heap->bufs addObject:buf];
|
730
|
+
|
731
|
+
[mem_pool->heaps addObject:heap_ptr];
|
732
|
+
mem_pool->n_heaps++;
|
733
|
+
|
734
|
+
return buf;
|
735
|
+
}
|
736
|
+
|
737
|
+
struct lm_ggml_metal_command_buffer {
|
738
|
+
id<MTLCommandBuffer> obj;
|
739
|
+
|
740
|
+
// each command buffer has a memory pool from which it can allocate temporary buffers during the compute
|
741
|
+
struct lm_ggml_metal_mem_pool * mem_pool;
|
742
|
+
};
|
743
|
+
|
471
744
|
struct lm_ggml_backend_metal_context {
|
745
|
+
id<MTLDevice> device;
|
472
746
|
id<MTLCommandQueue> queue;
|
473
747
|
|
474
748
|
dispatch_queue_t d_queue;
|
@@ -493,7 +767,7 @@ struct lm_ggml_backend_metal_context {
|
|
493
767
|
void (^encode_async)(size_t ith);
|
494
768
|
|
495
769
|
// n_cb command buffers + 1 used by the main thread
|
496
|
-
|
770
|
+
struct lm_ggml_metal_command_buffer cmd_bufs[LM_GGML_METAL_MAX_COMMAND_BUFFERS + 1];
|
497
771
|
|
498
772
|
// abort lm_ggml_metal_graph_compute if callback returns true
|
499
773
|
lm_ggml_abort_callback abort_callback;
|
@@ -560,7 +834,11 @@ static id<MTLLibrary> lm_ggml_metal_load_library(id<MTLDevice> device, bool use_
|
|
560
834
|
NSBundle * bundle = [NSBundle bundleForClass:[LMGGMLMetalClass class]];
|
561
835
|
#endif
|
562
836
|
|
837
|
+
#if TARGET_OS_SIMULATOR
|
838
|
+
NSString * path_lib = [bundle pathForResource:@"ggml-llama-sim" ofType:@"metallib"];
|
839
|
+
#else
|
563
840
|
NSString * path_lib = [bundle pathForResource:@"ggml-llama" ofType:@"metallib"];
|
841
|
+
#endif
|
564
842
|
if (path_lib == nil) {
|
565
843
|
// Try to find the resource in the directory where the current binary located.
|
566
844
|
NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
|
@@ -683,9 +961,11 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
683
961
|
struct lm_ggml_backend_metal_device_context * ctx_dev = dev->context;
|
684
962
|
|
685
963
|
id<MTLDevice> device = lm_ggml_backend_metal_device_acq(ctx_dev);
|
964
|
+
|
686
965
|
LM_GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
687
966
|
|
688
|
-
ctx->
|
967
|
+
ctx->device = device;
|
968
|
+
ctx->queue = [device newCommandQueue];
|
689
969
|
if (ctx->queue == nil) {
|
690
970
|
LM_GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
|
691
971
|
return NULL;
|
@@ -746,7 +1026,10 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
746
1026
|
ctx->gf = nil;
|
747
1027
|
ctx->encode_async = nil;
|
748
1028
|
for (int i = 0; i < LM_GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
749
|
-
ctx->
|
1029
|
+
ctx->cmd_bufs[i].obj = nil;
|
1030
|
+
|
1031
|
+
ctx->cmd_bufs[i].mem_pool = lm_ggml_metal_mem_pool_init();
|
1032
|
+
ctx->cmd_bufs[i].mem_pool->device = device;
|
750
1033
|
}
|
751
1034
|
|
752
1035
|
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
@@ -1011,6 +1294,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1011
1294
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm);
|
1012
1295
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
|
1013
1296
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
|
1297
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
|
1014
1298
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
|
1015
1299
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
|
1016
1300
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
|
@@ -1019,6 +1303,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1019
1303
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat);
|
1020
1304
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
|
1021
1305
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
|
1306
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
|
1022
1307
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
|
1023
1308
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
|
1024
1309
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
|
@@ -1027,6 +1312,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1027
1312
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm);
|
1028
1313
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
|
1029
1314
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
|
1315
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
|
1030
1316
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
|
1031
1317
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
|
1032
1318
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
|
@@ -1035,6 +1321,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1035
1321
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm);
|
1036
1322
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
|
1037
1323
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
|
1324
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
|
1038
1325
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
|
1039
1326
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
|
1040
1327
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
|
@@ -1043,6 +1330,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1043
1330
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm);
|
1044
1331
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
|
1045
1332
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
|
1333
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
|
1046
1334
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
|
1047
1335
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
|
1048
1336
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
|
@@ -1051,6 +1339,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1051
1339
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm);
|
1052
1340
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
|
1053
1341
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
|
1342
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
|
1054
1343
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
|
1055
1344
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
|
1056
1345
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
|
@@ -1059,6 +1348,14 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1059
1348
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm);
|
1060
1349
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
|
1061
1350
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
|
1351
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
|
1352
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
|
1353
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
|
1354
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
|
1355
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction);
|
1356
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction);
|
1357
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction);
|
1358
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction);
|
1062
1359
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
|
1063
1360
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
|
1064
1361
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
|
@@ -1087,6 +1384,13 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1087
1384
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
|
1088
1385
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
|
1089
1386
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
|
1387
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction);
|
1388
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat);
|
1389
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction);
|
1390
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction);
|
1391
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
|
1392
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
|
1393
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
|
1090
1394
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
|
1091
1395
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
|
1092
1396
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
@@ -1117,6 +1421,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1117
1421
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
|
1118
1422
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
1119
1423
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
1424
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
1120
1425
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
1121
1426
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
1122
1427
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
@@ -1137,6 +1442,12 @@ static void lm_ggml_metal_free(struct lm_ggml_backend_metal_context * ctx) {
|
|
1137
1442
|
|
1138
1443
|
[ctx->queue release];
|
1139
1444
|
|
1445
|
+
for (int i = 0; i < LM_GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
1446
|
+
// ctx->cmd_bufs[i].obj is auto released
|
1447
|
+
|
1448
|
+
lm_ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
|
1449
|
+
}
|
1450
|
+
|
1140
1451
|
dispatch_release(ctx->d_queue);
|
1141
1452
|
|
1142
1453
|
free(ctx);
|
@@ -1278,6 +1589,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
|
|
1278
1589
|
case LM_GGML_UNARY_OP_GELU_QUICK:
|
1279
1590
|
case LM_GGML_UNARY_OP_SILU:
|
1280
1591
|
case LM_GGML_UNARY_OP_ELU:
|
1592
|
+
case LM_GGML_UNARY_OP_NEG:
|
1281
1593
|
return lm_ggml_is_contiguous(op->src[0]) && op->src[0]->type == LM_GGML_TYPE_F32;
|
1282
1594
|
default:
|
1283
1595
|
return false;
|
@@ -1334,8 +1646,9 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
             return op->src[0]->type == LM_GGML_TYPE_F16;
         case LM_GGML_OP_POOL_1D:
             return false;
-        case LM_GGML_OP_POOL_2D:
         case LM_GGML_OP_UPSCALE:
+            return op->src[0]->type == LM_GGML_TYPE_F32 && op->op_params[0] == LM_GGML_SCALE_MODE_NEAREST;
+        case LM_GGML_OP_POOL_2D:
         case LM_GGML_OP_PAD:
         case LM_GGML_OP_PAD_REFLECT_1D:
         case LM_GGML_OP_TIMESTEP_EMBEDDING:
@@ -1350,6 +1663,11 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
                 // TODO: not sure if it is worth adding kernels for this size
                 return false;
             }
+            if (op->src[0]->ne[0] == 576) {
+                // DeepSeek sizes
+                // TODO: disabled for now, until optmized
+                return false;
+            }
             if (op->src[1]->type != op->src[2]->type) {
                 return false;
             }
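
Note the effect of this hunk: even though the HK576/HV512 flash-attention kernels are registered and wired into the dispatch later in this diff, lm_ggml_metal_supports_op rejects head size 576 up front, so those kernels stay dark until they are optimized and ggml routes such nodes to a fallback backend instead. A minimal probe, assuming the usual ggml backend-device API as prefixed in this package:

    // Sketch: when supports_op() says no, ggml splits the graph and runs the
    // node on a fallback backend (typically the CPU) rather than aborting.
    if (!lm_ggml_backend_dev_supports_op(dev, op)) {
        // node is scheduled on a different backend
    }
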
@@ -1435,10 +1753,11 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
     }
 }
 
-static void lm_ggml_metal_encode_node(
+static bool lm_ggml_metal_encode_node(
         lm_ggml_backend_t backend,
         int idx,
-        id<MTLComputeCommandEncoder> encoder) {
+        id<MTLComputeCommandEncoder> encoder,
+        struct lm_ggml_metal_mem_pool * mem_pool) {
     struct lm_ggml_backend_metal_context * ctx = backend->context;
     struct lm_ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
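
This signature change is the heart of the refactor: lm_ggml_metal_encode_node now reports success instead of returning void, and receives the per-command-buffer pool for temporary allocations. The new contract, in sketch form (mirroring the caller added near the end of this diff):

    // false => stop encoding this command buffer and surface the failure,
    // e.g. when the mem pool cannot satisfy a temporary allocation
    const bool ok = lm_ggml_metal_encode_node(backend, idx, encoder, mem_pool);
    if (!ok) {
        // abort the encoding loop for this command buffer
    }
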
@@ -1454,7 +1773,7 @@ static void lm_ggml_metal_encode_node(
     struct lm_ggml_tensor * dst = node;
 
     if (lm_ggml_is_empty(dst)) {
-        return;
+        return true;
     }
 
     switch (dst->op) {
@@ -1465,7 +1784,7 @@ static void lm_ggml_metal_encode_node(
         case LM_GGML_OP_PERMUTE:
             {
                 // noop -> next node
-            } return;
+            } return true;
         default:
             {
             } break;
@@ -1476,6 +1795,8 @@ static void lm_ggml_metal_encode_node(
         LM_GGML_ABORT("unsupported op");
     }
 
+    lm_ggml_metal_mem_pool_clear(mem_pool);
+
     const int64_t ne00 = src0 ? src0->ne[0] : 0;
     const int64_t ne01 = src0 ? src0->ne[1] : 0;
     const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -1962,6 +2283,18 @@ static void lm_ggml_metal_encode_node(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+            case LM_GGML_UNARY_OP_NEG:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_NEG].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                    const int64_t n = lm_ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
             default:
                 {
                     LM_GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, lm_ggml_op_name(dst->op));
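
The NEG encoder mirrors the neighboring trivial unary ops: one thread per element, threadgroup size 1. The matching shader lives in the Metal source, which is not part of this hunk; by analogy with the other unary kernels it is presumably the one-line elementwise negation:

    // Presumed shader (not shown in this diff), analogous to the other
    // elementwise unary kernels in ggml's Metal source:
    kernel void kernel_neg(
            device const float * src0,
            device       float * dst,
            uint tpig[[thread_position_in_grid]]) {
        dst[tpig] = -src0[tpig];
    }
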
@@ -2110,26 +2443,76 @@ static void lm_ggml_metal_encode_node(
                 const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
                 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-
+                // use this branch to test the lm_ggml_metal_mem_pool functionality
+#if 0
+                // cpy to tmp buffer in MTLHeap
+
+                id<MTLBuffer> h_src0 = h_src0 = lm_ggml_metal_mem_pool_alloc(mem_pool, lm_ggml_nbytes(src0));
+                if (!h_src0) {
+                    LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, lm_ggml_nbytes(src0));
+                    return false;
+                }
+
+                offs_src0 = 0;
+
+                lm_ggml_metal_kargs_cpy args_cpy = {
                     /*.ne00 =*/ ne00,
                     /*.ne01 =*/ ne01,
                     /*.ne02 =*/ ne02,
-                    /*.scale =*/ scale,
-                    /*.max_bias =*/ max_bias,
-                    /*.m0 =*/ m0,
-                    /*.m1 =*/ m1,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne0 =*/ ne00,
+                    /*.ne1 =*/ ne01,
+                    /*.ne2 =*/ ne02,
+                    /*.ne3 =*/ ne03,
+                    /*.nb0 =*/ nb00,
+                    /*.nb1 =*/ nb01,
+                    /*.nb2 =*/ nb02,
+                    /*.nb3 =*/ nb03,
+                };
+
+                if (src0->type == LM_GGML_TYPE_F16) {
+                    [encoder setComputePipelineState:ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
+                } else {
+                    [encoder setComputePipelineState:ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
+                }
+                [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:h_src0 offset:0 atIndex:2];
+
+                LM_GGML_ASSERT(ne00 % lm_ggml_blck_size(src0->type) == 0);
+                int nth_cpy = MIN(1024, ne00 / lm_ggml_blck_size(src0->type));
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
+
+#else
+                id<MTLBuffer> h_src0 = id_src0;
+#endif
+                // softmax
+
+                lm_ggml_metal_kargs_soft_max args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.scale =*/ scale,
+                    /*.max_bias =*/ max_bias,
+                    /*.m0 =*/ m0,
+                    /*.m1 =*/ m1,
                     /*.n_head_log2 =*/ n_head_log2,
                 };
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:h_src0 offset:offs_src0 atIndex:0];
                 if (id_src1) {
-                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                 } else {
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                    [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
                 }
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-                [encoder setBytes:&args length:sizeof(args) atIndex:3];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                [encoder setBytes:&args length:sizeof(args) atIndex:3];
 
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
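
The #if 0 branch above is a built-in exercise for the new pool: it stages src0 through a heap-backed temporary buffer with an extra copy kernel and then runs softmax from that buffer. The usage pattern it demonstrates, with nbytes standing in for whatever size the node needs (API names exactly as they appear in this diff):

    // Clear at the start of each encoded node so temporaries don't pile up,
    // then allocate; the buffer stays valid until the next clear/reset.
    lm_ggml_metal_mem_pool_clear(mem_pool);

    id<MTLBuffer> tmp = lm_ggml_metal_mem_pool_alloc(mem_pool, nbytes);
    if (!tmp) {
        return false; // encode_node now fails gracefully instead of aborting
    }
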
@@ -3842,12 +4225,14 @@ static void lm_ggml_metal_encode_node(
                 // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
                 // for now avoiding mainly to keep the number of templates/kernels a bit lower
                 // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
-                if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192)) {
+                if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
                     switch (src1->type) {
                         case LM_GGML_TYPE_F16:
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
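
The widened condition above is what decides between the two flash-attention code paths. Negating it with De Morgan gives the vec-kernel criterion directly; as an illustrative helper (not present in the source):

    // Equivalent predicate, derived by negating the condition above: the vec
    // kernels serve small batches whose head size has a template instance.
    static bool lm_use_vec_kernel(int64_t ne01, int64_t ne00) {
        return ne01 < 4 &&
              (ne00 % 128 == 0 || ne00 == 96 || ne00 == 192 || ne00 == 576);
    }
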
@@ -3870,6 +4255,8 @@ static void lm_ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
@@ -3892,6 +4279,8 @@ static void lm_ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
@@ -3914,6 +4303,8 @@ static void lm_ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
@@ -3936,6 +4327,8 @@ static void lm_ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
@@ -3958,6 +4351,8 @@ static void lm_ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
@@ -3980,6 +4375,8 @@ static void lm_ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
@@ -4009,6 +4406,24 @@ static void lm_ggml_metal_encode_node(
                     use_vec_kernel = true;
 
                     switch (ne00) {
+                        case 96:
+                            {
+                                switch (src1->type) {
+                                    case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break;
+                                    case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break;
+                                    case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break;
+                                    case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break;
+                                    case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break;
+                                    case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break;
+                                    case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break;
+                                    default:
+                                        {
+                                            LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                            LM_GGML_LOG_ERROR("add template specialization for this type\n");
+                                            LM_GGML_ABORT("add template specialization for this type");
+                                        }
+                                }
+                            } break;
                         case 128:
                             {
                                 switch (src1->type) {
@@ -4081,12 +4496,36 @@ static void lm_ggml_metal_encode_node(
                                 }
                             }
                         } break;
+                        case 576:
+                            {
+                                if (ne20 == 512) {
+                                    switch (src1->type) {
+                                        case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break;
+                                        case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break;
+                                        case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break;
+                                        case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break;
+                                        case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break;
+                                        case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break;
+                                        case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break;
+                                        default:
+                                            {
+                                                LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                                LM_GGML_LOG_ERROR("add template specialization for this type\n");
+                                                LM_GGML_ABORT("add template specialization for this type");
+                                            }
+                                    }
+                                } else {
+                                    LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne20);
+                                    LM_GGML_LOG_ERROR("add template specialization for this size\n");
+                                    LM_GGML_ABORT("add template specialization for this size");
+                                }
+                            } break;
                         default:
-                                  {
-                                      LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                      LM_GGML_LOG_ERROR("add template specialization for this size\n");
-                                      LM_GGML_ABORT("add template specialization for this size");
-                                  }
+                            {
+                                LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                LM_GGML_LOG_ERROR("add template specialization for this size\n");
+                                LM_GGML_ABORT("add template specialization for this size");
+                            }
                     }
                 }
 
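
Across the hunks above, vec-kernel selection reduces to a lookup keyed on K head size, V head size, and KV-cache storage type, with every unknown combination aborting loudly rather than silently computing garbage. The naming convention, summarized (illustrative comments, not source code):

    //   ..._H96           K and V heads both 96
    //   ..._HK192_HV128   K head 192, V head 128 (asymmetric)
    //   ..._HK576_HV512   K head 576, V head 512 (DeepSeek MLA)
    // pipeline = kernels[FLASH_ATTN_EXT_VEC_<kv type>_<head size suffix>]
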
@@ -4482,6 +4921,8 @@ static void lm_ggml_metal_encode_node(
                 LM_GGML_ABORT("fatal error");
             }
     }
+
+    return true;
 }
 
 static enum lm_ggml_status lm_ggml_metal_graph_compute(
@@ -4535,25 +4976,25 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
     }
 
     // the main thread commits the first few commands immediately
-    // command_buffer[n_cb]
+    // cmd_buf[n_cb]
     {
-        id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
-        ctx->command_buffers[n_cb] = command_buffer;
+        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+        ctx->cmd_bufs[n_cb].obj = cmd_buf;
 
-        [command_buffer enqueue];
+        [cmd_buf enqueue];
         ctx->encode_async(n_cb);
     }
 
     // prepare the rest of the command buffers asynchronously
-    // command_buffer[0.. n_cb)
+    // cmd_buf[0.. n_cb)
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
-        ctx->command_buffers[cb_idx] = command_buffer;
+        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+        ctx->cmd_bufs[cb_idx].obj = cmd_buf;
 
         // always enqueue the first two command buffers
         // enqueue all of the command buffers if we don't need to abort
         if (cb_idx < 2 || ctx->abort_callback == NULL) {
-            [command_buffer enqueue];
+            [cmd_buf enqueue];
         }
     }
 
@@ -4562,14 +5003,14 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
     // wait for completion and check status of each command buffer
    // needed to detect if the device ran out-of-memory for example (#1881)
     {
-        id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
-        [command_buffer waitUntilCompleted];
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+        [cmd_buf waitUntilCompleted];
 
-        MTLCommandBufferStatus status = [command_buffer status];
+        MTLCommandBufferStatus status = [cmd_buf status];
         if (status != MTLCommandBufferStatusCompleted) {
             LM_GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
             if (status == MTLCommandBufferStatusError) {
-                LM_GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+                LM_GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
             }
 
             return LM_GGML_STATUS_FAILED;
@@ -4577,20 +5018,20 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
     }
 
     for (int i = 0; i < n_cb; ++i) {
-        id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
-        [command_buffer waitUntilCompleted];
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
+        [cmd_buf waitUntilCompleted];
 
-        MTLCommandBufferStatus status = [command_buffer status];
+        MTLCommandBufferStatus status = [cmd_buf status];
         if (status != MTLCommandBufferStatusCompleted) {
             LM_GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
             if (status == MTLCommandBufferStatusError) {
-                LM_GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+                LM_GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
             }
 
             return LM_GGML_STATUS_FAILED;
         }
 
-        id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
+        id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
         if (!next_buffer) {
             continue;
         }
@@ -4973,8 +5414,9 @@ static void lm_ggml_backend_metal_set_n_cb(lm_ggml_backend_t backend, int n_cb)
 
         const int n_nodes_per_cb = ctx->n_nodes_per_cb;
 
-        id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
-        id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
+
+        id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
 
         int node_start = 0;
         int node_end = n_nodes_0;
@@ -4986,22 +5428,29 @@ static void lm_ggml_backend_metal_set_n_cb(lm_ggml_backend_t backend, int n_cb)
 
         const bool should_capture = ctx->capture_next_compute;
 
+        struct lm_ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
+        lm_ggml_metal_mem_pool_reset(mem_pool);
+
         for (int idx = node_start; idx < node_end; ++idx) {
             if (should_capture) {
                 [encoder pushDebugGroup:[NSString stringWithCString:lm_ggml_op_desc(lm_ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
             }
 
-            lm_ggml_metal_encode_node(backend, idx, encoder);
+            const bool res = lm_ggml_metal_encode_node(backend, idx, encoder, mem_pool);
 
             if (should_capture) {
                 [encoder popDebugGroup];
             }
+
+            if (!res) {
+                break;
+            }
         }
 
         [encoder endEncoding];
 
         if (cb_idx < 2 || ctx->abort_callback == NULL) {
-            [command_buffer commit];
+            [cmd_buf commit];
         }
     });
 }