cui-llama.rn 1.6.0 → 1.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +35 -7
- package/android/src/main/CMakeLists.txt +22 -11
- package/android/src/main/java/com/rnllama/LlamaContext.java +42 -6
- package/android/src/main/java/com/rnllama/RNLlama.java +139 -4
- package/android/src/main/jni.cpp +173 -18
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
- package/cpp/LICENSE +21 -0
- package/cpp/chat.cpp +129 -107
- package/cpp/chat.h +2 -0
- package/cpp/common.cpp +58 -78
- package/cpp/common.h +29 -21
- package/cpp/ggml-alloc.c +4 -1
- package/cpp/ggml-backend.cpp +9 -5
- package/cpp/ggml-backend.h +4 -4
- package/cpp/ggml-cpp.h +1 -1
- package/cpp/ggml-cpu/amx/amx.cpp +221 -0
- package/cpp/ggml-cpu/amx/amx.h +8 -0
- package/cpp/ggml-cpu/amx/common.h +91 -0
- package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
- package/cpp/ggml-cpu/amx/mmq.h +10 -0
- package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/binary-ops.h +1 -1
- package/cpp/ggml-cpu/common.h +72 -0
- package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -103
- package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +306 -6
- package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +114 -55
- package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +32 -16
- package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +353 -173
- package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/ops.h +2 -20
- package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
- package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/simd-mappings.h +7 -3
- package/{ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/unary-ops.h +1 -1
- package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -6
- package/{ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/vec.h +16 -0
- package/cpp/ggml-cpu.h +5 -0
- package/cpp/ggml-impl.h +16 -9
- package/cpp/ggml-llama-sim.metallib +0 -0
- package/cpp/ggml-llama.metallib +0 -0
- package/cpp/ggml-metal-impl.h +36 -11
- package/cpp/ggml-metal.m +810 -176
- package/cpp/ggml-opt.cpp +373 -190
- package/cpp/ggml-opt.h +49 -28
- package/cpp/ggml-quants.c +0 -6
- package/cpp/ggml.c +227 -282
- package/cpp/ggml.h +82 -101
- package/cpp/gguf.cpp +33 -33
- package/cpp/json-schema-to-grammar.cpp +3 -0
- package/cpp/llama-adapter.cpp +6 -0
- package/cpp/llama-arch.cpp +49 -17
- package/cpp/llama-arch.h +9 -0
- package/cpp/llama-batch.cpp +8 -2
- package/cpp/llama-batch.h +2 -1
- package/cpp/llama-chat.cpp +39 -16
- package/cpp/llama-chat.h +4 -2
- package/cpp/llama-context.cpp +440 -611
- package/cpp/llama-context.h +44 -33
- package/cpp/llama-cparams.h +1 -0
- package/cpp/llama-graph.cpp +214 -291
- package/cpp/llama-graph.h +69 -21
- package/cpp/llama-hparams.cpp +17 -1
- package/cpp/llama-hparams.h +39 -5
- package/cpp/llama-kv-cache.cpp +2067 -620
- package/cpp/llama-kv-cache.h +410 -108
- package/cpp/llama-memory.h +12 -1
- package/cpp/llama-model-loader.cpp +24 -15
- package/cpp/llama-model-saver.cpp +281 -0
- package/cpp/llama-model-saver.h +37 -0
- package/cpp/llama-model.cpp +1089 -359
- package/cpp/llama-model.h +19 -3
- package/cpp/llama-sampling.cpp +20 -7
- package/cpp/llama-vocab.cpp +54 -9
- package/cpp/llama-vocab.h +6 -0
- package/cpp/llama.cpp +14 -0
- package/cpp/llama.h +86 -142
- package/cpp/minja/chat-template.hpp +9 -5
- package/cpp/minja/minja.hpp +69 -36
- package/cpp/rn-llama.cpp +602 -190
- package/cpp/rn-llama.h +34 -8
- package/cpp/sampling.cpp +57 -50
- package/cpp/tools/mtmd/clip-impl.h +462 -0
- package/cpp/tools/mtmd/clip.cpp +4024 -0
- package/cpp/tools/mtmd/clip.h +101 -0
- package/cpp/tools/mtmd/miniaudio.h +93468 -0
- package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
- package/cpp/tools/mtmd/mtmd-audio.h +62 -0
- package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
- package/cpp/tools/mtmd/mtmd.cpp +942 -0
- package/cpp/tools/mtmd/mtmd.h +362 -0
- package/cpp/tools/mtmd/stb_image.h +7988 -0
- package/ios/CMakeLists.txt +20 -10
- package/ios/RNLlama.h +6 -0
- package/ios/RNLlama.mm +82 -3
- package/ios/RNLlamaContext.h +5 -1
- package/ios/RNLlamaContext.mm +131 -38
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +29 -21
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +82 -101
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +4 -2
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +44 -33
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +69 -21
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +39 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +410 -108
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +12 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +19 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +86 -142
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +34 -8
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +29 -21
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +82 -101
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +4 -2
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +44 -33
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +69 -21
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +39 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +410 -108
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +12 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +19 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +86 -142
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +34 -8
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +29 -21
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +82 -101
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +4 -2
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +44 -33
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +69 -21
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +39 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +410 -108
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +12 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +19 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +86 -142
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +34 -8
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +29 -21
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +82 -101
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +4 -2
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +44 -33
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +69 -21
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +39 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +410 -108
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +12 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +19 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +86 -142
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +34 -8
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/jest/mock.js +33 -7
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/index.js +153 -21
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/index.js +152 -20
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +54 -4
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +72 -6
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +72 -4
- package/src/index.ts +212 -38
- package/cpp/binary-ops.h +0 -16
- package/cpp/ops.h +0 -128
- package/cpp/simd-mappings.h +0 -888
- package/cpp/unary-ops.h +0 -28
- package/cpp/vec.h +0 -802
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +0 -128
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +0 -802
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +0 -128
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +0 -28
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +0 -16
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +0 -128
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +0 -888
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
- package/lib/commonjs/chat.js +0 -37
- package/lib/commonjs/chat.js.map +0 -1
- package/lib/module/chat.js +0 -33
- package/lib/module/chat.js.map +0 -1
- package/lib/typescript/chat.d.ts +0 -10
- package/lib/typescript/chat.d.ts.map +0 -1
- package/src/chat.ts +0 -44
- /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
- /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
- /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
- /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
- /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
- /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
- /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
- /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
package/cpp/ggml-metal.m
CHANGED
@@ -44,8 +44,8 @@ static struct lm_ggml_backend_device g_lm_ggml_backend_metal_device;
|
|
44
44
|
// note: assumes single GPU device - the default one
|
45
45
|
// TODO: support multiple GPU devices
|
46
46
|
static struct lm_ggml_backend_metal_device_context {
|
47
|
-
id<MTLDevice>
|
48
|
-
int
|
47
|
+
id<MTLDevice> mtl_device;
|
48
|
+
int mtl_device_ref_count;
|
49
49
|
id<MTLLibrary> mtl_library;
|
50
50
|
|
51
51
|
bool has_simdgroup_reduction;
|
@@ -149,6 +149,8 @@ enum lm_ggml_metal_kernel_type {
|
|
149
149
|
LM_GGML_METAL_KERNEL_TYPE_SIGMOID,
|
150
150
|
LM_GGML_METAL_KERNEL_TYPE_GELU,
|
151
151
|
LM_GGML_METAL_KERNEL_TYPE_GELU_4,
|
152
|
+
LM_GGML_METAL_KERNEL_TYPE_GELU_ERF,
|
153
|
+
LM_GGML_METAL_KERNEL_TYPE_GELU_ERF_4,
|
152
154
|
LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK,
|
153
155
|
LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
154
156
|
LM_GGML_METAL_KERNEL_TYPE_SILU,
|
@@ -306,30 +308,36 @@ enum lm_ggml_metal_kernel_type {
|
|
306
308
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
|
307
309
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
|
308
310
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
311
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
|
312
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
|
313
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
|
314
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
|
315
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
|
316
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16,
|
317
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16,
|
318
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
|
319
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
|
320
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
|
321
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
|
322
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
|
323
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
|
324
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16,
|
325
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16,
|
326
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16,
|
327
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16,
|
328
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16,
|
329
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16,
|
330
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16,
|
331
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16,
|
332
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16,
|
333
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16,
|
334
|
+
LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
|
331
335
|
LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
|
332
336
|
LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
|
337
|
+
LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
|
338
|
+
LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
|
339
|
+
LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
|
340
|
+
LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
|
333
341
|
LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
|
334
342
|
LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
|
335
343
|
LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
@@ -354,6 +362,7 @@ enum lm_ggml_metal_kernel_type {
|
|
354
362
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
|
355
363
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
|
356
364
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
365
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
|
357
366
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
|
358
367
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
|
359
368
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
|
@@ -362,6 +371,7 @@ enum lm_ggml_metal_kernel_type {
|
|
362
371
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
|
363
372
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
|
364
373
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
|
374
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
|
365
375
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
366
376
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
367
377
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
@@ -370,6 +380,7 @@ enum lm_ggml_metal_kernel_type {
|
|
370
380
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
|
371
381
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
|
372
382
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
|
383
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
|
373
384
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
|
374
385
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
|
375
386
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
|
@@ -378,6 +389,7 @@ enum lm_ggml_metal_kernel_type {
|
|
378
389
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
|
379
390
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
|
380
391
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
|
392
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
|
381
393
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
|
382
394
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
|
383
395
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
|
@@ -386,6 +398,7 @@ enum lm_ggml_metal_kernel_type {
|
|
386
398
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
|
387
399
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
|
388
400
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
|
401
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
|
389
402
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
|
390
403
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
|
391
404
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
|
@@ -394,6 +407,7 @@ enum lm_ggml_metal_kernel_type {
|
|
394
407
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
|
395
408
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
|
396
409
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
|
410
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
|
397
411
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
|
398
412
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
|
399
413
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
|
@@ -402,6 +416,21 @@ enum lm_ggml_metal_kernel_type {
|
|
402
416
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
|
403
417
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
|
404
418
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
419
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
|
420
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
|
421
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
|
422
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
|
423
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64,
|
424
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64,
|
425
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64,
|
426
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64,
|
427
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
|
428
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
|
429
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
|
430
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96,
|
431
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96,
|
432
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96,
|
433
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96,
|
405
434
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
406
435
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
|
407
436
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
|
@@ -430,6 +459,13 @@ enum lm_ggml_metal_kernel_type {
|
|
430
459
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
|
431
460
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
|
432
461
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
|
462
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,
|
463
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,
|
464
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,
|
465
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,
|
466
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
|
467
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
|
468
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
|
433
469
|
LM_GGML_METAL_KERNEL_TYPE_SET_I32,
|
434
470
|
LM_GGML_METAL_KERNEL_TYPE_SET_F32,
|
435
471
|
LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
@@ -460,6 +496,7 @@ enum lm_ggml_metal_kernel_type {
|
|
460
496
|
LM_GGML_METAL_KERNEL_TYPE_SQRT,
|
461
497
|
LM_GGML_METAL_KERNEL_TYPE_SIN,
|
462
498
|
LM_GGML_METAL_KERNEL_TYPE_COS,
|
499
|
+
LM_GGML_METAL_KERNEL_TYPE_NEG,
|
463
500
|
LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
464
501
|
LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
465
502
|
LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
@@ -468,7 +505,264 @@ enum lm_ggml_metal_kernel_type {
|
|
468
505
|
LM_GGML_METAL_KERNEL_TYPE_COUNT
|
469
506
|
};
|
470
507
|
|
508
|
+
//
|
509
|
+
// lm_ggml_metal_heap
|
510
|
+
//
|
511
|
+
|
512
|
+
struct lm_ggml_metal_heap {
|
513
|
+
// number of times the heap was unused
|
514
|
+
int n_unused;
|
515
|
+
|
516
|
+
// total number of buffer allocations in this heap across all computes
|
517
|
+
int64_t n_alloc;
|
518
|
+
|
519
|
+
// current offset in the heap - we reset this after each node in order to reuse the memory
|
520
|
+
size_t offs;
|
521
|
+
|
522
|
+
// the currently allocated MTLBuffer objects in this heap
|
523
|
+
id<MTLHeap> obj;
|
524
|
+
|
525
|
+
NSMutableArray * bufs;
|
526
|
+
};
|
527
|
+
|
528
|
+
static struct lm_ggml_metal_heap * lm_ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
|
529
|
+
struct lm_ggml_metal_heap * heap = calloc(1, sizeof(struct lm_ggml_metal_heap));
|
530
|
+
|
531
|
+
MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
|
532
|
+
desc.storageMode = MTLStorageModePrivate;
|
533
|
+
desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
|
534
|
+
desc.type = MTLHeapTypePlacement;
|
535
|
+
desc.size = size;
|
536
|
+
|
537
|
+
heap->n_unused = 0;
|
538
|
+
heap->n_alloc = 0;
|
539
|
+
|
540
|
+
heap->obj = [device newHeapWithDescriptor:desc];
|
541
|
+
if (!heap->obj) {
|
542
|
+
LM_GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);
|
543
|
+
|
544
|
+
free(heap);
|
545
|
+
|
546
|
+
return false;
|
547
|
+
}
|
548
|
+
|
549
|
+
[desc release];
|
550
|
+
|
551
|
+
heap->bufs = [[NSMutableArray alloc] init];
|
552
|
+
|
553
|
+
return heap;
|
554
|
+
}
|
555
|
+
|
556
|
+
static void lm_ggml_metal_heap_reset(struct lm_ggml_metal_heap * heap) {
|
557
|
+
heap->offs = 0;
|
558
|
+
|
559
|
+
// count how many graph computes the heap ended up being unused
|
560
|
+
if ([heap->bufs count] > 0) {
|
561
|
+
heap->n_unused = 0;
|
562
|
+
} else {
|
563
|
+
heap->n_unused++;
|
564
|
+
}
|
565
|
+
|
566
|
+
for (id<MTLBuffer> buf in heap->bufs) {
|
567
|
+
[buf release];
|
568
|
+
}
|
569
|
+
[heap->bufs removeAllObjects];
|
570
|
+
|
571
|
+
// tell the OS that it can reuse this memory if needed
|
572
|
+
// ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
|
573
|
+
[heap->obj setPurgeableState:MTLPurgeableStateVolatile];
|
574
|
+
}
|
575
|
+
|
576
|
+
static void lm_ggml_metal_heap_free(struct lm_ggml_metal_heap * heap) {
|
577
|
+
if (heap == nil) {
|
578
|
+
return;
|
579
|
+
}
|
580
|
+
|
581
|
+
lm_ggml_metal_heap_reset(heap);
|
582
|
+
|
583
|
+
[heap->obj release];
|
584
|
+
[heap->bufs release];
|
585
|
+
|
586
|
+
free(heap);
|
587
|
+
}
|
588
|
+
|
589
|
+
@interface lm_ggml_metal_heap_ptr : NSObject
|
590
|
+
|
591
|
+
@property (nonatomic, assign) struct lm_ggml_metal_heap * data;
|
592
|
+
|
593
|
+
@end
|
594
|
+
|
595
|
+
@implementation lm_ggml_metal_heap_ptr
|
596
|
+
@end
|
597
|
+
|
598
|
+
//
|
599
|
+
// lm_ggml_metal_mem_pool
|
600
|
+
//
|
601
|
+
|
602
|
+
struct lm_ggml_metal_mem_pool {
|
603
|
+
id<MTLDevice> device;
|
604
|
+
|
605
|
+
int n_heaps; // total number of heaps ever created (including those that were removed)
|
606
|
+
|
607
|
+
NSMutableArray * heaps;
|
608
|
+
NSMutableArray * heaps_to_remove;
|
609
|
+
};
|
610
|
+
|
611
|
+
static struct lm_ggml_metal_mem_pool * lm_ggml_metal_mem_pool_init(void) {
|
612
|
+
struct lm_ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct lm_ggml_metal_mem_pool));
|
613
|
+
|
614
|
+
mem_pool->n_heaps = 0;
|
615
|
+
|
616
|
+
mem_pool->heaps = [[NSMutableArray alloc] init];
|
617
|
+
mem_pool->heaps_to_remove = [[NSMutableArray alloc] init];
|
618
|
+
|
619
|
+
return mem_pool;
|
620
|
+
}
|
621
|
+
|
622
|
+
static void lm_ggml_metal_mem_pool_free(struct lm_ggml_metal_mem_pool * mem_pool) {
|
623
|
+
LM_GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps);
|
624
|
+
|
625
|
+
size_t size_all = 0;
|
626
|
+
size_t size_cur = 0;
|
627
|
+
|
628
|
+
for (lm_ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
|
629
|
+
LM_GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data);
|
630
|
+
LM_GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, ptr.data->n_alloc);
|
631
|
+
LM_GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused);
|
632
|
+
LM_GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0);
|
633
|
+
LM_GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [ptr.data->bufs count]);
|
634
|
+
|
635
|
+
if ([ptr.data->bufs count] > 0) {
|
636
|
+
size_cur += [ptr.data->obj size];
|
637
|
+
}
|
638
|
+
size_all += [ptr.data->obj size];
|
639
|
+
|
640
|
+
lm_ggml_metal_heap_free(ptr.data);
|
641
|
+
[ptr release];
|
642
|
+
}
|
643
|
+
[mem_pool->heaps release];
|
644
|
+
[mem_pool->heaps_to_remove release];
|
645
|
+
|
646
|
+
if (size_all > 0) {
|
647
|
+
LM_GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0);
|
648
|
+
LM_GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0);
|
649
|
+
}
|
650
|
+
|
651
|
+
free(mem_pool);
|
652
|
+
}
|
653
|
+
|
654
|
+
static void lm_ggml_metal_mem_pool_reset(struct lm_ggml_metal_mem_pool * mem_pool) {
|
655
|
+
for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) {
|
656
|
+
lm_ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i];
|
657
|
+
|
658
|
+
struct lm_ggml_metal_heap * heap = ptr.data;
|
659
|
+
lm_ggml_metal_heap_reset(heap);
|
660
|
+
|
661
|
+
// if the heap hasn't been used for a while, remove it
|
662
|
+
if (heap->n_unused >= 128) {
|
663
|
+
[mem_pool->heaps_to_remove addObject:@(i)];
|
664
|
+
}
|
665
|
+
}
|
666
|
+
|
667
|
+
if (mem_pool->heaps_to_remove.count > 0) {
|
668
|
+
// remove in reverse order
|
669
|
+
for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) {
|
670
|
+
NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
|
671
|
+
lm_ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
|
672
|
+
|
673
|
+
struct lm_ggml_metal_heap * heap = ptr.data;
|
674
|
+
lm_ggml_metal_heap_free(heap);
|
675
|
+
|
676
|
+
[mem_pool->heaps removeObjectAtIndex:index];
|
677
|
+
[ptr release];
|
678
|
+
|
679
|
+
if (i == 0) {
|
680
|
+
break;
|
681
|
+
}
|
682
|
+
}
|
683
|
+
|
684
|
+
[mem_pool->heaps_to_remove removeAllObjects];
|
685
|
+
}
|
686
|
+
}
|
687
|
+
|
688
|
+
static void lm_ggml_metal_mem_pool_clear(struct lm_ggml_metal_mem_pool * mem_pool) {
|
689
|
+
for (lm_ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
|
690
|
+
ptr.data->offs = 0;
|
691
|
+
}
|
692
|
+
}
|
693
|
+
|
694
|
+
static id<MTLBuffer> lm_ggml_metal_mem_pool_alloc(struct lm_ggml_metal_mem_pool * mem_pool, size_t size) {
|
695
|
+
const size_t alignment = 256;
|
696
|
+
|
697
|
+
const size_t size_aligned = LM_GGML_PAD(size, alignment);
|
698
|
+
|
699
|
+
// try one of the existing heaps
|
700
|
+
for (lm_ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
|
701
|
+
struct lm_ggml_metal_heap * heap = ptr.data;
|
702
|
+
if (heap->offs + size_aligned <= [heap->obj size]) {
|
703
|
+
// if this is the first buffer in the heap for the current command buffer, tell the OS that
|
704
|
+
// it cannot free the memory used by the heap
|
705
|
+
// ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
|
706
|
+
if ([heap->bufs count] == 0) {
|
707
|
+
[heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
|
708
|
+
}
|
709
|
+
|
710
|
+
id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
|
711
|
+
if (buf == nil) {
|
712
|
+
LM_GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
|
713
|
+
return nil;
|
714
|
+
}
|
715
|
+
|
716
|
+
heap->n_alloc++;
|
717
|
+
heap->offs += size_aligned;
|
718
|
+
|
719
|
+
[heap->bufs addObject:buf];
|
720
|
+
|
721
|
+
return buf;
|
722
|
+
}
|
723
|
+
}
|
724
|
+
|
725
|
+
// create a new heap that can fit this buffer
|
726
|
+
lm_ggml_metal_heap_ptr * heap_ptr = [lm_ggml_metal_heap_ptr new];
|
727
|
+
|
728
|
+
struct lm_ggml_metal_heap * heap = lm_ggml_metal_heap_init(mem_pool->device, size_aligned);
|
729
|
+
if (heap == NULL) {
|
730
|
+
LM_GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned);
|
731
|
+
return NULL;
|
732
|
+
}
|
733
|
+
|
734
|
+
//LM_GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]);
|
735
|
+
|
736
|
+
heap_ptr.data = heap;
|
737
|
+
lm_ggml_metal_heap_reset(heap);
|
738
|
+
|
739
|
+
[heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
|
740
|
+
id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
|
741
|
+
if (buf == nil) {
|
742
|
+
LM_GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
|
743
|
+
return NULL;
|
744
|
+
}
|
745
|
+
|
746
|
+
heap->n_alloc++;
|
747
|
+
heap->offs += size_aligned;
|
748
|
+
|
749
|
+
[heap->bufs addObject:buf];
|
750
|
+
|
751
|
+
[mem_pool->heaps addObject:heap_ptr];
|
752
|
+
mem_pool->n_heaps++;
|
753
|
+
|
754
|
+
return buf;
|
755
|
+
}
|
756
|
+
|
757
|
+
struct lm_ggml_metal_command_buffer {
|
758
|
+
id<MTLCommandBuffer> obj;
|
759
|
+
|
760
|
+
// each command buffer has a memory pool from which it can allocate temporary buffers during the compute
|
761
|
+
struct lm_ggml_metal_mem_pool * mem_pool;
|
762
|
+
};
|
763
|
+
|
471
764
|
struct lm_ggml_backend_metal_context {
|
765
|
+
id<MTLDevice> device;
|
472
766
|
id<MTLCommandQueue> queue;
|
473
767
|
|
474
768
|
dispatch_queue_t d_queue;
|
@@ -493,7 +787,7 @@ struct lm_ggml_backend_metal_context {
|
|
493
787
|
void (^encode_async)(size_t ith);
|
494
788
|
|
495
789
|
// n_cb command buffers + 1 used by the main thread
|
496
|
-
|
790
|
+
struct lm_ggml_metal_command_buffer cmd_bufs[LM_GGML_METAL_MAX_COMMAND_BUFFERS + 1];
|
497
791
|
|
498
792
|
// abort lm_ggml_metal_graph_compute if callback returns true
|
499
793
|
lm_ggml_abort_callback abort_callback;
|
@@ -560,11 +854,7 @@ static id<MTLLibrary> lm_ggml_metal_load_library(id<MTLDevice> device, bool use_
|
|
560
854
|
NSBundle * bundle = [NSBundle bundleForClass:[LMGGMLMetalClass class]];
|
561
855
|
#endif
|
562
856
|
|
563
|
-
|
564
|
-
NSString * path_lib = [bundle pathForResource:@"ggml-llama-sim" ofType:@"metallib"];
|
565
|
-
#else
|
566
|
-
NSString * path_lib = [bundle pathForResource:@"ggml-llama" ofType:@"metallib"];
|
567
|
-
#endif
|
857
|
+
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
|
568
858
|
if (path_lib == nil) {
|
569
859
|
// Try to find the resource in the directory where the current binary located.
|
570
860
|
NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
|
@@ -687,9 +977,11 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
687
977
|
struct lm_ggml_backend_metal_device_context * ctx_dev = dev->context;
|
688
978
|
|
689
979
|
id<MTLDevice> device = lm_ggml_backend_metal_device_acq(ctx_dev);
|
980
|
+
|
690
981
|
LM_GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
691
982
|
|
692
|
-
ctx->
|
983
|
+
ctx->device = device;
|
984
|
+
ctx->queue = [device newCommandQueue];
|
693
985
|
if (ctx->queue == nil) {
|
694
986
|
LM_GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
|
695
987
|
return NULL;
|
@@ -750,7 +1042,10 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
750
1042
|
ctx->gf = nil;
|
751
1043
|
ctx->encode_async = nil;
|
752
1044
|
for (int i = 0; i < LM_GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
753
|
-
ctx->
|
1045
|
+
ctx->cmd_bufs[i].obj = nil;
|
1046
|
+
|
1047
|
+
ctx->cmd_bufs[i].mem_pool = lm_ggml_metal_mem_pool_init();
|
1048
|
+
ctx->cmd_bufs[i].mem_pool->device = device;
|
754
1049
|
}
|
755
1050
|
|
756
1051
|
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
@@ -810,6 +1105,8 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
810
1105
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
|
811
1106
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
812
1107
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
|
1108
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_ERF, gelu_erf, true);
|
1109
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_ERF_4, gelu_erf_4, true);
|
813
1110
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
|
814
1111
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
815
1112
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
@@ -967,30 +1264,36 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
967
1264
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
|
968
1265
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
|
969
1266
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
|
970
|
-
LM_GGML_METAL_ADD_KERNEL(
|
971
|
-
LM_GGML_METAL_ADD_KERNEL(
|
972
|
-
LM_GGML_METAL_ADD_KERNEL(
|
973
|
-
LM_GGML_METAL_ADD_KERNEL(
|
974
|
-
LM_GGML_METAL_ADD_KERNEL(
|
975
|
-
LM_GGML_METAL_ADD_KERNEL(
|
976
|
-
LM_GGML_METAL_ADD_KERNEL(
|
977
|
-
LM_GGML_METAL_ADD_KERNEL(
|
978
|
-
LM_GGML_METAL_ADD_KERNEL(
|
979
|
-
LM_GGML_METAL_ADD_KERNEL(
|
980
|
-
LM_GGML_METAL_ADD_KERNEL(
|
981
|
-
LM_GGML_METAL_ADD_KERNEL(
|
982
|
-
LM_GGML_METAL_ADD_KERNEL(
|
983
|
-
LM_GGML_METAL_ADD_KERNEL(
|
984
|
-
LM_GGML_METAL_ADD_KERNEL(
|
985
|
-
LM_GGML_METAL_ADD_KERNEL(
|
986
|
-
LM_GGML_METAL_ADD_KERNEL(
|
987
|
-
LM_GGML_METAL_ADD_KERNEL(
|
988
|
-
LM_GGML_METAL_ADD_KERNEL(
|
989
|
-
LM_GGML_METAL_ADD_KERNEL(
|
990
|
-
LM_GGML_METAL_ADD_KERNEL(
|
991
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1267
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
|
1268
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
|
1269
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
|
1270
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
|
1271
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
|
1272
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, mul_mm_id_q4_0_f16, has_simdgroup_mm);
|
1273
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, mul_mm_id_q4_1_f16, has_simdgroup_mm);
|
1274
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
|
1275
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm);
|
1276
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm);
|
1277
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm);
|
1278
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm);
|
1279
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm);
|
1280
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, mul_mm_id_q5_K_f16, has_simdgroup_mm);
|
1281
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, mul_mm_id_q6_K_f16, has_simdgroup_mm);
|
1282
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, mul_mm_id_iq2_xxs_f16, has_simdgroup_mm);
|
1283
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, mul_mm_id_iq2_xs_f16, has_simdgroup_mm);
|
1284
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, mul_mm_id_iq3_xxs_f16, has_simdgroup_mm);
|
1285
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, mul_mm_id_iq3_s_f16, has_simdgroup_mm);
|
1286
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, mul_mm_id_iq2_s_f16, has_simdgroup_mm);
|
1287
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, mul_mm_id_iq1_s_f16, has_simdgroup_mm);
|
1288
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm);
|
1289
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm);
|
1290
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
|
992
1291
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
993
1292
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
1293
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true);
|
1294
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true);
|
1295
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true);
|
1296
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true);
|
994
1297
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
995
1298
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
|
996
1299
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
@@ -1015,6 +1318,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1015
1318
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm);
|
1016
1319
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
|
1017
1320
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
|
1321
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
|
1018
1322
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
|
1019
1323
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
|
1020
1324
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
|
@@ -1023,6 +1327,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1023
1327
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat);
|
1024
1328
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
|
1025
1329
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
|
1330
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
|
1026
1331
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
|
1027
1332
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
|
1028
1333
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
|
@@ -1031,6 +1336,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1031
1336
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm);
|
1032
1337
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
|
1033
1338
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
|
1339
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
|
1034
1340
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
|
1035
1341
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
|
1036
1342
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
|
@@ -1039,6 +1345,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1039
1345
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm);
|
1040
1346
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
|
1041
1347
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
|
1348
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
|
1042
1349
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
|
1043
1350
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
|
1044
1351
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
|
@@ -1047,6 +1354,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1047
1354
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm);
|
1048
1355
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
|
1049
1356
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
|
1357
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
|
1050
1358
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
|
1051
1359
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
|
1052
1360
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
|
@@ -1055,6 +1363,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
@@ -1063,6 +1372,21 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
@@ -1091,6 +1415,13 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
@@ -1121,6 +1452,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIN, sin, true);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_COS, cos, true);
+LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_NEG, neg, true);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
 LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
@@ -1141,6 +1473,12 @@ static void lm_ggml_metal_free(struct lm_ggml_backend_metal_context * ctx) {

 [ctx->queue release];

+for (int i = 0; i < LM_GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
+// ctx->cmd_bufs[i].obj is auto released
+
+lm_ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
+}
+
 dispatch_release(ctx->d_queue);

 free(ctx);
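The hunk above gives every command buffer its own `lm_ggml_metal_mem_pool` and frees the pools together with the context. As a mental model only — the real pool in this package wraps Metal heap/buffer objects and may recycle them differently — a bump-style scratch pool with the same alloc/clear/reset/free lifecycle looks roughly like this in plain C; all names here are illustrative:

```c
// Minimal sketch, assuming a bump allocator; not the package's implementation.
#include <stdlib.h>

struct mem_pool {
    unsigned char * data;
    size_t          cap;
    size_t          used;
};

static struct mem_pool * mem_pool_new(size_t cap) {
    struct mem_pool * p = malloc(sizeof(*p));
    p->data = malloc(cap);
    p->cap  = cap;
    p->used = 0;
    return p;
}

// per-node allocation; returning NULL mirrors the diff's "fail the node" path
static void * mem_pool_alloc(struct mem_pool * p, size_t size) {
    if (p->used + size > p->cap) {
        return NULL;
    }
    void * res = p->data + p->used;
    p->used += size;
    return res;
}

static void mem_pool_clear(struct mem_pool * p) { p->used = 0; } // before each node
static void mem_pool_reset(struct mem_pool * p) { p->used = 0; } // before each command buffer
static void mem_pool_free (struct mem_pool * p) { free(p->data); free(p); }
```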
@@ -1279,9 +1617,11 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
 case LM_GGML_UNARY_OP_RELU:
 case LM_GGML_UNARY_OP_SIGMOID:
 case LM_GGML_UNARY_OP_GELU:
+case LM_GGML_UNARY_OP_GELU_ERF:
 case LM_GGML_UNARY_OP_GELU_QUICK:
 case LM_GGML_UNARY_OP_SILU:
 case LM_GGML_UNARY_OP_ELU:
+case LM_GGML_UNARY_OP_NEG:
 return lm_ggml_is_contiguous(op->src[0]) && op->src[0]->type == LM_GGML_TYPE_F32;
 default:
 return false;
@@ -1324,22 +1664,14 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
 case LM_GGML_OP_NORM:
 return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && lm_ggml_is_contiguous_1(op->src[0]));
 case LM_GGML_OP_ROPE:
-{
-const int mode = ((const int32_t *) op->op_params)[2];
-if (mode & LM_GGML_ROPE_TYPE_MROPE) {
-return false;
-}
-if (mode & LM_GGML_ROPE_TYPE_VISION) {
-return false;
-}
-return true;
-}
+return true;
 case LM_GGML_OP_IM2COL:
 return op->src[0]->type == LM_GGML_TYPE_F16;
 case LM_GGML_OP_POOL_1D:
 return false;
-case LM_GGML_OP_POOL_2D:
 case LM_GGML_OP_UPSCALE:
+return op->src[0]->type == LM_GGML_TYPE_F32 && op->op_params[0] == LM_GGML_SCALE_MODE_NEAREST;
+case LM_GGML_OP_POOL_2D:
 case LM_GGML_OP_PAD:
 case LM_GGML_OP_PAD_REFLECT_1D:
 case LM_GGML_OP_TIMESTEP_EMBEDDING:
@@ -1354,6 +1686,11 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
 // TODO: not sure if it is worth adding kernels for this size
 return false;
 }
+if (op->src[0]->ne[0] == 576) {
+// DeepSeek sizes
+// TODO: disabled for now, until optmized
+return false;
+}
 if (op->src[1]->type != op->src[2]->type) {
 return false;
 }
@@ -1439,10 +1776,11 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
 }
 }

-static void lm_ggml_metal_encode_node(
+static bool lm_ggml_metal_encode_node(
 lm_ggml_backend_t backend,
 int idx,
-id<MTLComputeCommandEncoder> encoder) {
+id<MTLComputeCommandEncoder> encoder,
+struct lm_ggml_metal_mem_pool * mem_pool) {
 struct lm_ggml_backend_metal_context * ctx = backend->context;
 struct lm_ggml_backend_metal_device_context * ctx_dev = backend->device->context;

@@ -1458,7 +1796,7 @@ static void lm_ggml_metal_encode_node(
 struct lm_ggml_tensor * dst = node;

 if (lm_ggml_is_empty(dst)) {
-return;
+return true;
 }

 switch (dst->op) {
@@ -1469,7 +1807,7 @@ static void lm_ggml_metal_encode_node(
 case LM_GGML_OP_PERMUTE:
 {
 // noop -> next node
-} return;
+} return true;
 default:
 {
 } break;
@@ -1480,6 +1818,8 @@ static void lm_ggml_metal_encode_node(
 LM_GGML_ABORT("unsupported op");
 }

+lm_ggml_metal_mem_pool_clear(mem_pool);
+
 const int64_t ne00 = src0 ? src0->ne[0] : 0;
 const int64_t ne01 = src0 ? src0->ne[1] : 0;
 const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -1916,6 +2256,25 @@ static void lm_ggml_metal_encode_node(

 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
 } break;
+case LM_GGML_UNARY_OP_GELU_ERF:
+{
+int64_t n = lm_ggml_nelements(dst);
+
+id<MTLComputePipelineState> pipeline = nil;
+
+if (n % 4 == 0) {
+pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU_ERF_4].pipeline;
+n /= 4;
+} else {
+pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU_ERF].pipeline;
+}
+
+[encoder setComputePipelineState:pipeline];
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+} break;
 case LM_GGML_UNARY_OP_GELU_QUICK:
 {
 int64_t n = lm_ggml_nelements(dst);
@@ -1966,6 +2325,18 @@ static void lm_ggml_metal_encode_node(

 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
 } break;
+case LM_GGML_UNARY_OP_NEG:
+{
+id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_NEG].pipeline;
+
+[encoder setComputePipelineState:pipeline];
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+const int64_t n = lm_ggml_nelements(dst);
+
+[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+} break;
 default:
 {
 LM_GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, lm_ggml_op_name(dst->op));
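For reference, assuming the new GELU_ERF kernels implement the erf-based GELU that their name suggests (as opposed to the tanh approximation), the scalar formula the 1-wide and 4-wide pipelines above would evaluate is gelu_erf(x) = 0.5·x·(1 + erf(x/√2)). A plain C check, purely illustrative:

```c
// gelu_erf(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
#include <math.h>
#include <stdio.h>

static float gelu_erf(float x) {
    return 0.5f * x * (1.0f + erff(x * 0.70710678f)); // 0.70710678 ~ 1/sqrt(2)
}

int main(void) {
    for (int i = -2; i <= 2; ++i) {
        printf("gelu_erf(%d) = %f\n", i, gelu_erf((float) i));
    }
    return 0;
}
```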
@@ -2114,26 +2485,76 @@ static void lm_ggml_metal_encode_node(
 const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

-lm_ggml_metal_kargs_soft_max args = {
+// use this branch to test the lm_ggml_metal_mem_pool functionality
+#if 0
+// cpy to tmp buffer in MTLHeap
+
+id<MTLBuffer> h_src0 = h_src0 = lm_ggml_metal_mem_pool_alloc(mem_pool, lm_ggml_nbytes(src0));
+if (!h_src0) {
+LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, lm_ggml_nbytes(src0));
+return false;
+}
+
+offs_src0 = 0;
+
+lm_ggml_metal_kargs_cpy args_cpy = {
 /*.ne00 =*/ ne00,
 /*.ne01 =*/ ne01,
 /*.ne02 =*/ ne02,
-/*.scale =*/ scale,
-/*.max_bias =*/ max_bias,
-/*.m0 =*/ m0,
-/*.m1 =*/ m1,
+/*.ne03 =*/ ne03,
+/*.nb00 =*/ nb00,
+/*.nb01 =*/ nb01,
+/*.nb02 =*/ nb02,
+/*.nb03 =*/ nb03,
+/*.ne0 =*/ ne00,
+/*.ne1 =*/ ne01,
+/*.ne2 =*/ ne02,
+/*.ne3 =*/ ne03,
+/*.nb0 =*/ nb00,
+/*.nb1 =*/ nb01,
+/*.nb2 =*/ nb02,
+/*.nb3 =*/ nb03,
+};
+
+if (src0->type == LM_GGML_TYPE_F16) {
+[encoder setComputePipelineState:ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
+} else {
+[encoder setComputePipelineState:ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
+}
+[encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+[encoder setBuffer:h_src0 offset:0 atIndex:2];
+
+LM_GGML_ASSERT(ne00 % lm_ggml_blck_size(src0->type) == 0);
+int nth_cpy = MIN(1024, ne00 / lm_ggml_blck_size(src0->type));
+
+[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
+
+#else
+id<MTLBuffer> h_src0 = id_src0;
+#endif
+// softmax
+
+lm_ggml_metal_kargs_soft_max args = {
+/*.ne00 =*/ ne00,
+/*.ne01 =*/ ne01,
+/*.ne02 =*/ ne02,
+/*.scale =*/ scale,
+/*.max_bias =*/ max_bias,
+/*.m0 =*/ m0,
+/*.m1 =*/ m1,
 /*.n_head_log2 =*/ n_head_log2,
 };

 [encoder setComputePipelineState:pipeline];
-[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:h_src0 offset:offs_src0 atIndex:0];
 if (id_src1) {
-[encoder setBuffer:id_src1 offset:offs_src1
+[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
 } else {
-[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+[encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
 }
-[encoder setBuffer:id_dst
-[encoder setBytes:&args
+[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+[encoder setBytes:&args length:sizeof(args) atIndex:3];

 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

@@ -2624,7 +3045,7 @@ static void lm_ggml_metal_encode_node(
 [encoder setBuffer:id_dst offset:offs_dst atIndex:3];

 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-[encoder dispatchThreadgroups:MTLSizeMake(
+[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
 } else {
 id<MTLComputePipelineState> pipeline = nil;

@@ -2844,8 +3265,6 @@ static void lm_ggml_metal_encode_node(
 } break;
 case LM_GGML_OP_MUL_MAT_ID:
 {
-const int n_as = src0->ne[2];
-
 // src2 = ids
 const enum lm_ggml_type src2t = src2->type; LM_GGML_UNUSED(src2t);

@@ -2859,24 +3278,21 @@ static void lm_ggml_metal_encode_node(
 LM_GGML_ASSERT(ne03 == 1);
 LM_GGML_ASSERT(ne13 == 1);

+const uint32_t r2 = 1;
+const uint32_t r3 = 1;
+
 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
 // to the matrix-vector kernel
 // ne20 = n_used_experts
-// ne21 = n_rows
-const int
-const int dst_rows_min = n_as;
-const int dst_rows_max = (device.maxThreadgroupMemoryLength/2 - 8192)/4;
-
-// max size of the rowids array in the kernel shared buffer
-//LM_GGML_ASSERT(dst_rows <= dst_rows_max);
+// ne21 = n_rows (batch size)
+const int ne21_mm_id_min = 32;

 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
 if ([device supportsFamily:MTLGPUFamilyApple7] &&
 ne00 % 32 == 0 && ne00 >= 64 &&
-
-
-dst_rows <= dst_rows_max) {
+(ne21 >= ne21_mm_id_min)) {
+LM_GGML_ASSERT(ne00 % 4 == 0);

 // some Metal matrix data types require aligned pointers
 // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -2887,62 +3303,169 @@ static void lm_ggml_metal_encode_node(
 default: break;
 }

-
+const int64_t neh10 = ne10; // n_embd
+const int64_t neh11 = ne21; // n_tokens
+const int64_t neh12 = ne02; // n_expert

-
-
-
-
-
-
-
-
-
-
-case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
-case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
-case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
-case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
-case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
-case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
-case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
-case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
-case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
-case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
-case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
-case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
-case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
-default: LM_GGML_ABORT("MUL_MAT_ID not implemented");
+const uint64_t nbh10 = lm_ggml_type_size(LM_GGML_TYPE_F16);
+const uint64_t nbh11 = nbh10*neh10;
+const uint64_t nbh12 = nbh11*neh11;
+const uint64_t nbh13 = nbh12*neh12;
+
+const size_t s_src1 = lm_ggml_type_size(LM_GGML_TYPE_F16)*neh10*neh11*neh12;
+id<MTLBuffer> h_src1 = lm_ggml_metal_mem_pool_alloc(mem_pool, s_src1);
+if (!h_src1) {
+LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
+return false;
 }

-
-
-
-/*.nbi1 =*/ nb21,
-/*.ne00 =*/ ne00,
-/*.ne02 =*/ ne02,
-/*.nb01 =*/ nb01,
-/*.nb02 =*/ nb02,
-/*.ne11 =*/ ne11,
-/*.ne12 =*/ ne12,
-/*.ne13 =*/ ne13,
-/*.nb10 =*/ nb10,
-/*.nb11 =*/ nb11,
-/*.nb12 =*/ nb12,
-/*.ne0 =*/ ne0,
-/*.ne1 =*/ ne1,
-};
+const int64_t neh0 = ne0;
+const int64_t neh1 = ne21;
+const int64_t neh2 = ne02;

-
-
-
-
-
-
+const uint64_t nbh0 = lm_ggml_type_size(LM_GGML_TYPE_F32);
+const uint64_t nbh1 = nbh0*neh0;
+const uint64_t nbh2 = nbh1*neh1;
+//const uint64_t nbh3 = nbh2*neh2;
+
+const size_t s_dst = lm_ggml_type_size(LM_GGML_TYPE_F32)*neh0*neh1*neh2;
+id<MTLBuffer> h_dst = lm_ggml_metal_mem_pool_alloc(mem_pool, s_dst);
+if (!h_dst) {
+LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
+return false;
+}
+
+// tokens per expert
+const size_t s_tpe = lm_ggml_type_size(LM_GGML_TYPE_I32)*ne02;
+id<MTLBuffer> h_tpe = lm_ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
+if (!h_tpe) {
+LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
+return false;
+}
+
+// id map
+// [n_expert_used, n_tokens]
+const size_t s_ids = lm_ggml_type_size(LM_GGML_TYPE_I32)*ne20*ne21;
+id<MTLBuffer> h_ids = lm_ggml_metal_mem_pool_alloc(mem_pool, s_ids);
+if (!h_ids) {
+LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
+return false;
+}
+
+{
+const int nth = MIN(1024, ne10/4);
+
+lm_ggml_metal_kargs_mul_mm_id_map0 args = {
+ne10,
+ne11, // n_expert_used (bcast)
+nb11,
+nb12,
+neh11, // n_tokens
+nbh11,
+ne20, // n_expert_used
+nb21,
+};
+
+id<MTLComputePipelineState> pipeline = nil;
+
+pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
+
+[encoder setComputePipelineState:pipeline];
+[encoder setBytes:&args length:sizeof(args) atIndex:0];
+[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+[encoder setBuffer: h_src1 offset:0 atIndex:3];
+[encoder setBuffer: h_tpe offset:0 atIndex:4];
+[encoder setBuffer: h_ids offset:0 atIndex:5];
+
+[encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+}
+
+{
+id<MTLComputePipelineState> pipeline = nil;
+
+switch (src0->type) {
+case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16 ].pipeline; break;
+case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16 ].pipeline; break;
+case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16 ].pipeline; break;
+case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16 ].pipeline; break;
+case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16 ].pipeline; break;
+case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
+case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
+case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
+case LM_GGML_TYPE_Q2_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
+case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
+case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;
+case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16 ].pipeline; break;
+case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16 ].pipeline; break;
+case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16].pipeline; break;
+case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16 ].pipeline; break;
+case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16].pipeline; break;
+case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16 ].pipeline; break;
+case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16 ].pipeline; break;
+case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16 ].pipeline; break;
+case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break;
+case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break;
+case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break;
+default: LM_GGML_ABORT("MUL_MAT_ID not implemented");
+}
+
+lm_ggml_metal_kargs_mul_mm_id args = {
+/*.ne00 =*/ ne00,
+/*.ne02 =*/ ne02,
+/*.nb01 =*/ nb01,
+/*.nb02 =*/ nb02,
+/*.nb03 =*/ nb03,
+/*.neh12 =*/ neh12,
+/*.nbh10 =*/ nbh10,
+/*.nbh11 =*/ nbh11,
+/*.nbh12 =*/ nbh12,
+/*.nbh13 =*/ nbh13,
+/*.neh0 =*/ neh0,
+/*.neh1 =*/ neh1,
+/*.r2 =*/ r2,
+/*.r3 =*/ r3,
+};
+
+[encoder setComputePipelineState:pipeline];
+[encoder setBytes:&args length:sizeof(args) atIndex:0];
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+[encoder setBuffer: h_src1 offset:0 atIndex:2];
+[encoder setBuffer: h_tpe offset:0 atIndex:3];
+[encoder setBuffer: h_dst offset:0 atIndex:4];
+
+[encoder setThreadgroupMemoryLength:8192 atIndex:0];
+[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+}

-
+{
+LM_GGML_ASSERT(ne0 % 4 == 0);

-
+const int nth = MIN(1024, ne0/4);
+
+lm_ggml_metal_kargs_mul_mm_id_map1 args = {
+ne20, // n_expert_used
+neh0,
+neh1,
+nbh1,
+nbh2,
+ne0,
+nb1,
+nb2,
+};
+
+id<MTLComputePipelineState> pipeline = nil;
+
+pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
+
+[encoder setComputePipelineState:pipeline];
+[encoder setBytes:&args length:sizeof(args) atIndex:0];
+[encoder setBuffer: h_dst offset:0 atIndex:1];
+[encoder setBuffer: h_ids offset:0 atIndex:2];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+
+[encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+}
 } else {
 id<MTLComputePipelineState> pipeline = nil;

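The rewritten MUL_MAT_ID path above stages three dispatches: a map0 kernel that regroups src1 rows by expert and fills the "tokens per expert" and id-map pool buffers, the per-expert matrix multiplications, and a map1 kernel that scatters the results back to their original (expert slot, token) positions. A hypothetical CPU-side sketch of the grouping idea — not the Metal kernels, and with illustrative names and layout assumptions — could look like this:

```c
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

// ids:    expert index chosen for each (token, slot), laid out [n_expert_used, n_tokens]
// tpe:    out, "tokens per expert" counts, one per expert
// rowids: out, (slot, token) origins grouped by expert, n_expert_used*n_tokens entries
static void group_rows_by_expert(
        const int32_t * ids,
        int n_expert, int n_expert_used, int n_tokens,
        int32_t * tpe, int32_t * rowids) {
    memset(tpe, 0, (size_t) n_expert*sizeof(int32_t));

    // pass 1: count how many rows each expert will process
    for (int t = 0; t < n_tokens; ++t) {
        for (int s = 0; s < n_expert_used; ++s) {
            tpe[ids[t*n_expert_used + s]]++;
        }
    }

    // prefix sums give each expert its slice of the grouped buffer
    int32_t * off = malloc((size_t) n_expert*sizeof(int32_t));
    int32_t acc = 0;
    for (int e = 0; e < n_expert; ++e) {
        off[e] = acc;
        acc   += tpe[e];
    }

    // pass 2: remember which (slot, token) each grouped row came from,
    // so the per-expert results can be scattered back afterwards (the map1 step)
    for (int t = 0; t < n_tokens; ++t) {
        for (int s = 0; s < n_expert_used; ++s) {
            const int32_t e = ids[t*n_expert_used + s];
            rowids[off[e]++] = s*n_tokens + t;
        }
    }

    free(off);
}
```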
@@ -3136,7 +3659,7 @@ static void lm_ggml_metal_encode_node(
 [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];

 const int64_t _ne1 = 1;
-const int64_t ne123 =
+const int64_t ne123 = ne20*ne21;

 if (smem > 0) {
 [encoder setThreadgroupMemoryLength:smem atIndex:0];
@@ -3340,6 +3863,7 @@ static void lm_ggml_metal_encode_node(
 } break;
 case LM_GGML_OP_ROPE:
 {
+
 // make sure we have one or more position id(ne10) per token(ne02)
 LM_GGML_ASSERT(ne10 % ne02 == 0);
 LM_GGML_ASSERT(ne10 >= ne02);
@@ -3366,20 +3890,42 @@ static void lm_ggml_metal_encode_node(
 memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
 memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));

-const bool is_neox
+const bool is_neox = mode & LM_GGML_ROPE_TYPE_NEOX;
+const bool is_mrope = mode & LM_GGML_ROPE_TYPE_MROPE;
+const bool is_vision = mode == LM_GGML_ROPE_TYPE_VISION;
+
+// mrope
+const int sect_0 = ((const int32_t *) dst->op_params)[11];
+const int sect_1 = ((const int32_t *) dst->op_params)[12];
+const int sect_2 = ((const int32_t *) dst->op_params)[13];
+const int sect_3 = ((const int32_t *) dst->op_params)[14];

 id<MTLComputePipelineState> pipeline = nil;

-if (
+if (is_neox) {
 switch (src0->type) {
-case LM_GGML_TYPE_F32: pipeline = ctx->kernels[
-case LM_GGML_TYPE_F16: pipeline = ctx->kernels[
+case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+default: LM_GGML_ABORT("fatal error");
+};
+} else if (is_mrope && !is_vision) {
+LM_GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
+switch (src0->type) {
+case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
+case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
+default: LM_GGML_ABORT("fatal error");
+};
+} else if (is_vision) {
+LM_GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
+switch (src0->type) {
+case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
+case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
 default: LM_GGML_ABORT("fatal error");
 };
 } else {
 switch (src0->type) {
-case LM_GGML_TYPE_F32: pipeline = ctx->kernels[
-case LM_GGML_TYPE_F16: pipeline = ctx->kernels[
+case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
 default: LM_GGML_ABORT("fatal error");
 };
 }
@@ -3410,6 +3956,10 @@ static void lm_ggml_metal_encode_node(
 /*.attn_factor =*/ attn_factor,
 /*.beta_fast =*/ beta_fast,
 /*.beta_slow =*/ beta_slow,
+/* sect_0 =*/ sect_0,
+/* sect_1 =*/ sect_1,
+/* sect_2 =*/ sect_2,
+/* sect_3 =*/ sect_3,
 };

 [encoder setComputePipelineState:pipeline];
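The ROPE hunks above pick one of four pipelines from the mode bits and pass the four multi-rope section sizes (op_params[11..14]) through to the kernel arguments. A small C sketch of the same dispatch decision — the flag values below are assumptions mirroring ggml's rope-type constants, since this package only shows the names LM_GGML_ROPE_TYPE_NEOX/MROPE/VISION:

```c
#include <stdint.h>

#define ROPE_TYPE_NEOX   2                        // assumed values, illustrative only
#define ROPE_TYPE_MROPE  8
#define ROPE_TYPE_VISION (ROPE_TYPE_MROPE | 16)

enum rope_kernel { ROPE_KERNEL_NORM, ROPE_KERNEL_NEOX, ROPE_KERNEL_MULTI, ROPE_KERNEL_VISION };

static enum rope_kernel pick_rope_kernel(int mode, const int32_t * op_params, int32_t sect[4]) {
    // multi-rope section sizes travel in op_params[11..14], as in the diff
    sect[0] = op_params[11];
    sect[1] = op_params[12];
    sect[2] = op_params[13];
    sect[3] = op_params[14];

    const int is_neox   = (mode & ROPE_TYPE_NEOX)  != 0;
    const int is_mrope  = (mode & ROPE_TYPE_MROPE) != 0;
    const int is_vision =  mode == ROPE_TYPE_VISION;

    // same branch order as the Metal code: neox, then multi, then vision, else norm
    if (is_neox)                return ROPE_KERNEL_NEOX;
    if (is_mrope && !is_vision) return ROPE_KERNEL_MULTI;
    if (is_vision)              return ROPE_KERNEL_VISION;
    return ROPE_KERNEL_NORM;
}
```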
@@ -3846,12 +4396,14 @@ static void lm_ggml_metal_encode_node(
 // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
 // for now avoiding mainly to keep the number of templates/kernels a bit lower
 // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
-if (ne01 >=
+if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
 switch (src1->type) {
 case LM_GGML_TYPE_F16:
 {
 if (ne00 == 192 && ne20 == 128) {
 pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
+} else if (ne00 == 576 && ne20 == 512) {
+pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
 } else {
 switch (ne00) {
 case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
@@ -3874,6 +4426,8 @@ static void lm_ggml_metal_encode_node(
 {
 if (ne00 == 192 && ne20 == 128) {
 pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
+} else if (ne00 == 576 && ne20 == 512) {
+pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
 } else {
 switch (ne00) {
 case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
@@ -3896,6 +4450,8 @@ static void lm_ggml_metal_encode_node(
 {
 if (ne00 == 192 && ne20 == 128) {
 pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
+} else if (ne00 == 576 && ne20 == 512) {
+pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
 } else {
 switch (ne00) {
 case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
@@ -3918,6 +4474,8 @@ static void lm_ggml_metal_encode_node(
 {
 if (ne00 == 192 && ne20 == 128) {
 pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
+} else if (ne00 == 576 && ne20 == 512) {
+pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
 } else {
 switch (ne00) {
 case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
@@ -3940,6 +4498,8 @@ static void lm_ggml_metal_encode_node(
 {
 if (ne00 == 192 && ne20 == 128) {
 pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
+} else if (ne00 == 576 && ne20 == 512) {
+pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
 } else {
 switch (ne00) {
 case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
@@ -3962,6 +4522,8 @@ static void lm_ggml_metal_encode_node(
 {
 if (ne00 == 192 && ne20 == 128) {
 pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
+} else if (ne00 == 576 && ne20 == 512) {
+pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
 } else {
 switch (ne00) {
 case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
@@ -3984,6 +4546,8 @@ static void lm_ggml_metal_encode_node(
 {
 if (ne00 == 192 && ne20 == 128) {
 pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
+} else if (ne00 == 576 && ne20 == 512) {
+pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
 } else {
 switch (ne00) {
 case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
@@ -4013,6 +4577,42 @@ static void lm_ggml_metal_encode_node(
 use_vec_kernel = true;

 switch (ne00) {
+case 64:
+{
+switch (src1->type) {
+case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break;
+case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break;
+case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break;
+case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break;
+case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break;
+case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break;
+case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break;
+default:
+{
+LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+LM_GGML_LOG_ERROR("add template specialization for this type\n");
+LM_GGML_ABORT("add template specialization for this type");
+}
+}
+} break;
+case 96:
+{
+switch (src1->type) {
+case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break;
+case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break;
+case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break;
+case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break;
+case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break;
+case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break;
+case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break;
+default:
+{
+LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+LM_GGML_LOG_ERROR("add template specialization for this type\n");
+LM_GGML_ABORT("add template specialization for this type");
+}
+}
+} break;
 case 128:
 {
 switch (src1->type) {
@@ -4085,12 +4685,36 @@ static void lm_ggml_metal_encode_node(
 }
 }
 } break;
+case 576:
+{
+if (ne20 == 512) {
+switch (src1->type) {
+case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break;
+case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break;
+case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break;
+case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break;
+case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break;
+case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break;
+case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break;
+default:
+{
+LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+LM_GGML_LOG_ERROR("add template specialization for this type\n");
+LM_GGML_ABORT("add template specialization for this type");
+}
+}
+} else {
+LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne20);
+LM_GGML_LOG_ERROR("add template specialization for this size\n");
+LM_GGML_ABORT("add template specialization for this size");
+}
+} break;
 default:
-
-
-
-
-
+{
+LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+LM_GGML_LOG_ERROR("add template specialization for this size\n");
+LM_GGML_ABORT("add template specialization for this size");
+}
 }
 }

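In the flash-attention hunks, ne00 is the K head size and ne20 the V head size; besides the existing 192/128 pair, the new 576/512 pair (the DeepSeek-style MLA shape mentioned earlier in the diff) now has both simdgroup-matrix and vector specializations. A hedged C sketch of the selection logic, with made-up labels standing in for the pipeline enums and only a subset of the real head-size cases:

```c
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

// ne00 = K head size, ne20 = V head size; returns NULL when no template exists.
static const char * pick_fattn_kernel(int64_t ne00, int64_t ne20, bool vec) {
    // mixed K/V head sizes get dedicated templates
    if (ne00 == 192 && ne20 == 128) return vec ? "vec_hk192_hv128" : "hk192_hv128";
    if (ne00 == 576 && ne20 == 512) return vec ? "vec_hk576_hv512" : "hk576_hv512";

    // otherwise dispatch on the K head size alone (a subset of the real cases)
    switch (ne00) {
        case 64:  return vec ? "vec_h64"  : "h64";
        case 96:  return vec ? "vec_h96"  : "h96";
        case 128: return vec ? "vec_h128" : "h128";
        case 256: return vec ? "vec_h256" : "h256";
        default:  return NULL; // unsupported head size in this sketch
    }
}
```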
@@ -4486,6 +5110,8 @@ static void lm_ggml_metal_encode_node(
 LM_GGML_ABORT("fatal error");
 }
 }
+
+return true;
 }

 static enum lm_ggml_status lm_ggml_metal_graph_compute(
@@ -4539,25 +5165,25 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
 }

 // the main thread commits the first few commands immediately
-//
+// cmd_buf[n_cb]
 {
-id<MTLCommandBuffer>
-ctx->
+id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+ctx->cmd_bufs[n_cb].obj = cmd_buf;

-[
+[cmd_buf enqueue];
 ctx->encode_async(n_cb);
 }

 // prepare the rest of the command buffers asynchronously
-//
+// cmd_buf[0.. n_cb)
 for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-id<MTLCommandBuffer>
-ctx->
+id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+ctx->cmd_bufs[cb_idx].obj = cmd_buf;

 // always enqueue the first two command buffers
 // enqueue all of the command buffers if we don't need to abort
 if (cb_idx < 2 || ctx->abort_callback == NULL) {
-[
+[cmd_buf enqueue];
 }
 }

@@ -4566,14 +5192,14 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
 // wait for completion and check status of each command buffer
 // needed to detect if the device ran out-of-memory for example (#1881)
 {
-id<MTLCommandBuffer>
-[
+id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+[cmd_buf waitUntilCompleted];

-MTLCommandBufferStatus status = [
+MTLCommandBufferStatus status = [cmd_buf status];
 if (status != MTLCommandBufferStatusCompleted) {
 LM_GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
 if (status == MTLCommandBufferStatusError) {
-LM_GGML_LOG_INFO("error: %s\n", [[
+LM_GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
 }

 return LM_GGML_STATUS_FAILED;
@@ -4581,20 +5207,20 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
 }

 for (int i = 0; i < n_cb; ++i) {
-id<MTLCommandBuffer>
-[
+id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
+[cmd_buf waitUntilCompleted];

-MTLCommandBufferStatus status = [
+MTLCommandBufferStatus status = [cmd_buf status];
 if (status != MTLCommandBufferStatusCompleted) {
 LM_GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
 if (status == MTLCommandBufferStatusError) {
-LM_GGML_LOG_INFO("error: %s\n", [[
+LM_GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
 }

 return LM_GGML_STATUS_FAILED;
 }

-id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->
+id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
 if (!next_buffer) {
 continue;
 }
@@ -4977,8 +5603,9 @@ static void lm_ggml_backend_metal_set_n_cb(lm_ggml_backend_t backend, int n_cb)

 const int n_nodes_per_cb = ctx->n_nodes_per_cb;

-id<MTLCommandBuffer>
-
+id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
+
+id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];

 int node_start = 0;
 int node_end = n_nodes_0;
@@ -4990,22 +5617,29 @@ static void lm_ggml_backend_metal_set_n_cb(lm_ggml_backend_t backend, int n_cb)

 const bool should_capture = ctx->capture_next_compute;

+struct lm_ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
+lm_ggml_metal_mem_pool_reset(mem_pool);
+
 for (int idx = node_start; idx < node_end; ++idx) {
 if (should_capture) {
 [encoder pushDebugGroup:[NSString stringWithCString:lm_ggml_op_desc(lm_ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
 }

-lm_ggml_metal_encode_node(backend, idx, encoder);
+const bool res = lm_ggml_metal_encode_node(backend, idx, encoder, mem_pool);

 if (should_capture) {
 [encoder popDebugGroup];
 }
+
+if (!res) {
+break;
+}
 }

 [encoder endEncoding];

 if (cb_idx < 2 || ctx->abort_callback == NULL) {
-[
+[cmd_buf commit];
 }
 });
 }