cui-llama.rn 1.6.1 → 1.7.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +6 -0
- package/android/src/main/java/com/rnllama/LlamaContext.java +51 -14
- package/android/src/main/java/com/rnllama/RNLlama.java +158 -6
- package/android/src/main/jni.cpp +153 -14
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +24 -4
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +22 -2
- package/cpp/chat.cpp +128 -106
- package/cpp/chat.h +2 -0
- package/cpp/common.cpp +38 -76
- package/cpp/common.h +23 -19
- package/cpp/ggml-backend.cpp +9 -5
- package/cpp/ggml-backend.h +4 -4
- package/cpp/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
- package/cpp/ggml-cpu/ggml-cpu-quants.c +306 -6
- package/cpp/ggml-cpu/ggml-cpu.c +5 -13
- package/cpp/ggml-cpu/ggml-cpu.cpp +29 -16
- package/cpp/ggml-cpu/ops.cpp +107 -13
- package/cpp/ggml-cpu/vec.cpp +0 -6
- package/cpp/ggml-cpu/vec.h +16 -0
- package/cpp/ggml-llama-sim.metallib +0 -0
- package/cpp/ggml-llama.metallib +0 -0
- package/cpp/ggml-metal-impl.h +36 -11
- package/cpp/ggml-metal.m +321 -132
- package/cpp/ggml-opt.cpp +373 -190
- package/cpp/ggml-opt.h +49 -28
- package/cpp/ggml-quants.c +0 -6
- package/cpp/ggml.c +93 -38
- package/cpp/ggml.h +21 -7
- package/cpp/gguf.cpp +33 -33
- package/cpp/llama-adapter.cpp +6 -0
- package/cpp/llama-arch.cpp +3 -0
- package/cpp/llama-batch.cpp +3 -1
- package/cpp/llama-chat.cpp +8 -6
- package/cpp/llama-chat.h +1 -0
- package/cpp/llama-context.cpp +349 -135
- package/cpp/llama-context.h +30 -3
- package/cpp/llama-cparams.h +1 -0
- package/cpp/llama-graph.cpp +150 -234
- package/cpp/llama-graph.h +52 -7
- package/cpp/llama-hparams.cpp +17 -1
- package/cpp/llama-hparams.h +34 -5
- package/cpp/llama-kv-cache.cpp +662 -321
- package/cpp/llama-kv-cache.h +203 -93
- package/cpp/llama-memory.h +3 -2
- package/cpp/llama-model-loader.cpp +24 -15
- package/cpp/llama-model-saver.cpp +281 -0
- package/cpp/llama-model-saver.h +37 -0
- package/cpp/llama-model.cpp +536 -132
- package/cpp/llama-model.h +7 -1
- package/cpp/llama-sampling.cpp +18 -6
- package/cpp/llama-vocab.cpp +46 -8
- package/cpp/llama-vocab.h +6 -0
- package/cpp/llama.cpp +14 -0
- package/cpp/llama.h +72 -131
- package/cpp/minja/chat-template.hpp +9 -5
- package/cpp/minja/minja.hpp +69 -36
- package/cpp/rn-llama.cpp +611 -47
- package/cpp/rn-llama.h +33 -3
- package/cpp/sampling.cpp +57 -50
- package/cpp/tools/mtmd/clip-impl.h +462 -0
- package/cpp/tools/mtmd/clip.cpp +4024 -0
- package/cpp/tools/mtmd/clip.h +101 -0
- package/cpp/tools/mtmd/miniaudio.h +93468 -0
- package/cpp/tools/mtmd/mtmd-audio.cpp +855 -0
- package/cpp/tools/mtmd/mtmd-audio.h +62 -0
- package/cpp/tools/mtmd/mtmd-helper.cpp +297 -0
- package/cpp/tools/mtmd/mtmd.cpp +942 -0
- package/cpp/tools/mtmd/mtmd.h +362 -0
- package/cpp/tools/mtmd/stb_image.h +7988 -0
- package/ios/CMakeLists.txt +7 -0
- package/ios/RNLlama.mm +77 -3
- package/ios/RNLlamaContext.h +5 -1
- package/ios/RNLlamaContext.mm +105 -10
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +23 -19
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-backend.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-metal-impl.h +36 -11
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-opt.h +49 -28
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +21 -7
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +30 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +52 -7
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +34 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +203 -93
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model-saver.h +37 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +7 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +6 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +72 -131
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +9 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +69 -36
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +33 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Info.plist +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/jest/mock.js +33 -7
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/commonjs/index.js +153 -21
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/module/index.js +152 -20
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +50 -4
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +72 -6
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +67 -4
- package/src/index.ts +212 -38
- package/lib/commonjs/chat.js +0 -37
- package/lib/commonjs/chat.js.map +0 -1
- package/lib/module/chat.js +0 -33
- package/lib/module/chat.js.map +0 -1
- package/lib/typescript/chat.d.ts +0 -10
- package/lib/typescript/chat.d.ts.map +0 -1
- package/src/chat.ts +0 -44
package/cpp/ggml-metal.m
CHANGED
@@ -149,6 +149,8 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_SIGMOID,
     LM_GGML_METAL_KERNEL_TYPE_GELU,
     LM_GGML_METAL_KERNEL_TYPE_GELU_4,
+    LM_GGML_METAL_KERNEL_TYPE_GELU_ERF,
+    LM_GGML_METAL_KERNEL_TYPE_GELU_ERF_4,
     LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK,
     LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
     LM_GGML_METAL_KERNEL_TYPE_SILU,
@@ -306,30 +308,36 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
     LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
     LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
-    [22 lines removed; contents not shown in this diff view]
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16,
+    LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
     LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
     LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
+    LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
+    LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
+    LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
+    LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
     LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
     LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
     LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16,
@@ -409,6 +417,13 @@ enum lm_ggml_metal_kernel_type {
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64,
+    LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
@@ -650,7 +665,8 @@ static void lm_ggml_metal_mem_pool_reset(struct lm_ggml_metal_mem_pool * mem_poo
     }
 
     if (mem_pool->heaps_to_remove.count > 0) {
-        [1 line removed; contents not shown in this diff view]
+        // remove in reverse order
+        for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) {
             NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
             lm_ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
 
@@ -659,6 +675,10 @@ static void lm_ggml_metal_mem_pool_reset(struct lm_ggml_metal_mem_pool * mem_poo
 
             [mem_pool->heaps removeObjectAtIndex:index];
             [ptr release];
+
+            if (i == 0) {
+                break;
+            }
         }
 
         [mem_pool->heaps_to_remove removeAllObjects];
@@ -672,7 +692,7 @@ static void lm_ggml_metal_mem_pool_clear(struct lm_ggml_metal_mem_pool * mem_poo
 }
 
 static id<MTLBuffer> lm_ggml_metal_mem_pool_alloc(struct lm_ggml_metal_mem_pool * mem_pool, size_t size) {
-    const size_t alignment =
+    const size_t alignment = 256;
 
     const size_t size_aligned = LM_GGML_PAD(size, alignment);
 
@@ -834,11 +854,7 @@ static id<MTLLibrary> lm_ggml_metal_load_library(id<MTLDevice> device, bool use_
     NSBundle * bundle = [NSBundle bundleForClass:[LMGGMLMetalClass class]];
 #endif
 
-    [1 line removed; contents not shown in this diff view]
-    NSString * path_lib = [bundle pathForResource:@"ggml-llama-sim" ofType:@"metallib"];
-#else
-    NSString * path_lib = [bundle pathForResource:@"ggml-llama" ofType:@"metallib"];
-#endif
+    NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
     if (path_lib == nil) {
         // Try to find the resource in the directory where the current binary located.
         NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
@@ -1089,6 +1105,8 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_ERF, gelu_erf, true);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_ERF_4, gelu_erf_4, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU, silu, true);
@@ -1246,30 +1264,36 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
-        [22 removed LM_GGML_METAL_ADD_KERNEL( lines; truncated in this diff view]
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, mul_mm_id_q4_0_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, mul_mm_id_q4_1_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, mul_mm_id_q5_K_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, mul_mm_id_q6_K_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, mul_mm_id_iq2_xxs_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, mul_mm_id_iq2_xs_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, mul_mm_id_iq3_xxs_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, mul_mm_id_iq3_s_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, mul_mm_id_iq2_s_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, mul_mm_id_iq1_s_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
@@ -1349,6 +1373,13 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction);
+        LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
         LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
@@ -1586,6 +1617,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
             case LM_GGML_UNARY_OP_RELU:
             case LM_GGML_UNARY_OP_SIGMOID:
             case LM_GGML_UNARY_OP_GELU:
+            case LM_GGML_UNARY_OP_GELU_ERF:
             case LM_GGML_UNARY_OP_GELU_QUICK:
             case LM_GGML_UNARY_OP_SILU:
             case LM_GGML_UNARY_OP_ELU:
@@ -1632,16 +1664,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
         case LM_GGML_OP_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && lm_ggml_is_contiguous_1(op->src[0]));
         case LM_GGML_OP_ROPE:
-            [1 line removed; contents not shown in this diff view]
-                const int mode = ((const int32_t *) op->op_params)[2];
-                if (mode & LM_GGML_ROPE_TYPE_MROPE) {
-                    return false;
-                }
-                if (mode & LM_GGML_ROPE_TYPE_VISION) {
-                    return false;
-                }
-                return true;
-            }
+            return true;
         case LM_GGML_OP_IM2COL:
             return op->src[0]->type == LM_GGML_TYPE_F16;
         case LM_GGML_OP_POOL_1D:
@@ -2233,6 +2256,25 @@ static bool lm_ggml_metal_encode_node(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+            case LM_GGML_UNARY_OP_GELU_ERF:
+                {
+                    int64_t n = lm_ggml_nelements(dst);
+
+                    id<MTLComputePipelineState> pipeline = nil;
+
+                    if (n % 4 == 0) {
+                        pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU_ERF_4].pipeline;
+                        n /= 4;
+                    } else {
+                        pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GELU_ERF].pipeline;
+                    }
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
             case LM_GGML_UNARY_OP_GELU_QUICK:
                 {
                     int64_t n = lm_ggml_nelements(dst);
@@ -3003,7 +3045,7 @@ static bool lm_ggml_metal_encode_node(
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
 
                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                [encoder dispatchThreadgroups:MTLSizeMake(
+                [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
             } else {
                 id<MTLComputePipelineState> pipeline = nil;
 
@@ -3223,8 +3265,6 @@ static bool lm_ggml_metal_encode_node(
             } break;
         case LM_GGML_OP_MUL_MAT_ID:
             {
-                const int n_as = src0->ne[2];
-
                 // src2 = ids
                 const enum lm_ggml_type src2t = src2->type; LM_GGML_UNUSED(src2t);
 
@@ -3238,24 +3278,21 @@ static bool lm_ggml_metal_encode_node(
                 LM_GGML_ASSERT(ne03 == 1);
                 LM_GGML_ASSERT(ne13 == 1);
 
+                const uint32_t r2 = 1;
+                const uint32_t r3 = 1;
+
                 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                 // to the matrix-vector kernel
                 // ne20 = n_used_experts
-                // ne21 = n_rows
-                const int
-                const int dst_rows_min = n_as;
-                const int dst_rows_max = (device.maxThreadgroupMemoryLength/2 - 8192)/4;
-
-                // max size of the rowids array in the kernel shared buffer
-                //LM_GGML_ASSERT(dst_rows <= dst_rows_max);
+                // ne21 = n_rows (batch size)
+                const int ne21_mm_id_min = 32;
 
                 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
                 if ([device supportsFamily:MTLGPUFamilyApple7] &&
                     ne00 % 32 == 0 && ne00 >= 64 &&
-                    [2 lines removed; contents not shown in this diff view]
-                    dst_rows <= dst_rows_max) {
+                    (ne21 >= ne21_mm_id_min)) {
+                    LM_GGML_ASSERT(ne00 % 4 == 0);
 
                     // some Metal matrix data types require aligned pointers
                     // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -3266,62 +3303,169 @@ static bool lm_ggml_metal_encode_node(
                         default: break;
                     }
 
-                    [1 line removed; contents not shown in this diff view]
+                    const int64_t neh10 = ne10; // n_embd
+                    const int64_t neh11 = ne21; // n_tokens
+                    const int64_t neh12 = ne02; // n_expert
 
-                    [10 lines removed; contents not shown in this diff view]
-                        case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
-                        case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
-                        case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
-                        case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
-                        case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
-                        case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
-                        case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
-                        case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
-                        case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
-                        case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
-                        case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
-                        case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
-                        case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
-                        default: LM_GGML_ABORT("MUL_MAT_ID not implemented");
+                    const uint64_t nbh10 = lm_ggml_type_size(LM_GGML_TYPE_F16);
+                    const uint64_t nbh11 = nbh10*neh10;
+                    const uint64_t nbh12 = nbh11*neh11;
+                    const uint64_t nbh13 = nbh12*neh12;
+
+                    const size_t s_src1 = lm_ggml_type_size(LM_GGML_TYPE_F16)*neh10*neh11*neh12;
+                    id<MTLBuffer> h_src1 = lm_ggml_metal_mem_pool_alloc(mem_pool, s_src1);
+                    if (!h_src1) {
+                        LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
+                        return false;
                     }
 
-                    [3 lines removed; contents not shown in this diff view]
-                        /*.nbi1 =*/ nb21,
-                        /*.ne00 =*/ ne00,
-                        /*.ne02 =*/ ne02,
-                        /*.nb01 =*/ nb01,
-                        /*.nb02 =*/ nb02,
-                        /*.ne11 =*/ ne11,
-                        /*.ne12 =*/ ne12,
-                        /*.ne13 =*/ ne13,
-                        /*.nb10 =*/ nb10,
-                        /*.nb11 =*/ nb11,
-                        /*.nb12 =*/ nb12,
-                        /*.ne0 =*/ ne0,
-                        /*.ne1 =*/ ne1,
-                    };
+                    const int64_t neh0 = ne0;
+                    const int64_t neh1 = ne21;
+                    const int64_t neh2 = ne02;
 
-                    [6 lines removed; contents not shown in this diff view]
+                    const uint64_t nbh0 = lm_ggml_type_size(LM_GGML_TYPE_F32);
+                    const uint64_t nbh1 = nbh0*neh0;
+                    const uint64_t nbh2 = nbh1*neh1;
+                    //const uint64_t nbh3 = nbh2*neh2;
+
+                    const size_t s_dst = lm_ggml_type_size(LM_GGML_TYPE_F32)*neh0*neh1*neh2;
+                    id<MTLBuffer> h_dst = lm_ggml_metal_mem_pool_alloc(mem_pool, s_dst);
+                    if (!h_dst) {
+                        LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
+                        return false;
+                    }
+
+                    // tokens per expert
+                    const size_t s_tpe = lm_ggml_type_size(LM_GGML_TYPE_I32)*ne02;
+                    id<MTLBuffer> h_tpe = lm_ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
+                    if (!h_tpe) {
+                        LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
+                        return false;
+                    }
+
+                    // id map
+                    // [n_expert_used, n_tokens]
+                    const size_t s_ids = lm_ggml_type_size(LM_GGML_TYPE_I32)*ne20*ne21;
+                    id<MTLBuffer> h_ids = lm_ggml_metal_mem_pool_alloc(mem_pool, s_ids);
+                    if (!h_ids) {
+                        LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
+                        return false;
+                    }
 
-                    [1 line removed; contents not shown in this diff view]
+                    {
+                        const int nth = MIN(1024, ne10/4);
+
+                        lm_ggml_metal_kargs_mul_mm_id_map0 args = {
+                            ne10,
+                            ne11, // n_expert_used (bcast)
+                            nb11,
+                            nb12,
+                            neh11, // n_tokens
+                            nbh11,
+                            ne20, // n_expert_used
+                            nb21,
+                        };
+
+                        id<MTLComputePipelineState> pipeline = nil;
+
+                        pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                        [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                        [encoder setBuffer: h_src1 offset:0 atIndex:3];
+                        [encoder setBuffer: h_tpe offset:0 atIndex:4];
+                        [encoder setBuffer: h_ids offset:0 atIndex:5];
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                    }
+
+                    {
+                        id<MTLComputePipelineState> pipeline = nil;
+
+                        switch (src0->type) {
+                            case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16 ].pipeline; break;
+                            case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16 ].pipeline; break;
+                            case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16 ].pipeline; break;
+                            case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16 ].pipeline; break;
+                            case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16 ].pipeline; break;
+                            case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
+                            case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
+                            case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
+                            case LM_GGML_TYPE_Q2_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
+                            case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
+                            case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;
+                            case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16 ].pipeline; break;
+                            case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16 ].pipeline; break;
+                            case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16].pipeline; break;
+                            case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16 ].pipeline; break;
+                            case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16].pipeline; break;
+                            case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16 ].pipeline; break;
+                            case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16 ].pipeline; break;
+                            case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16 ].pipeline; break;
+                            case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break;
+                            case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break;
+                            case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break;
+                            default: LM_GGML_ABORT("MUL_MAT_ID not implemented");
+                        }
+
+                        lm_ggml_metal_kargs_mul_mm_id args = {
+                            /*.ne00 =*/ ne00,
+                            /*.ne02 =*/ ne02,
+                            /*.nb01 =*/ nb01,
+                            /*.nb02 =*/ nb02,
+                            /*.nb03 =*/ nb03,
+                            /*.neh12 =*/ neh12,
+                            /*.nbh10 =*/ nbh10,
+                            /*.nbh11 =*/ nbh11,
+                            /*.nbh12 =*/ nbh12,
+                            /*.nbh13 =*/ nbh13,
+                            /*.neh0 =*/ neh0,
+                            /*.neh1 =*/ neh1,
+                            /*.r2 =*/ r2,
+                            /*.r3 =*/ r3,
+                        };
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                        [encoder setBuffer: h_src1 offset:0 atIndex:2];
+                        [encoder setBuffer: h_tpe offset:0 atIndex:3];
+                        [encoder setBuffer: h_dst offset:0 atIndex:4];
+
+                        [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                    }
+
+                    {
+                        LM_GGML_ASSERT(ne0 % 4 == 0);
+
+                        const int nth = MIN(1024, ne0/4);
 
-                    [1 line removed; contents not shown in this diff view]
+                        lm_ggml_metal_kargs_mul_mm_id_map1 args = {
+                            ne20, // n_expert_used
+                            neh0,
+                            neh1,
+                            nbh1,
+                            nbh2,
+                            ne0,
+                            nb1,
+                            nb2,
+                        };
+
+                        id<MTLComputePipelineState> pipeline = nil;
+
+                        pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                        [encoder setBuffer: h_dst offset:0 atIndex:1];
+                        [encoder setBuffer: h_ids offset:0 atIndex:2];
+                        [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                    }
                 } else {
                     id<MTLComputePipelineState> pipeline = nil;
 
@@ -3515,7 +3659,7 @@ static bool lm_ggml_metal_encode_node(
                 [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
 
                 const int64_t _ne1 = 1;
-                const int64_t ne123 =
+                const int64_t ne123 = ne20*ne21;
 
                 if (smem > 0) {
                     [encoder setThreadgroupMemoryLength:smem atIndex:0];
@@ -3719,6 +3863,7 @@ static bool lm_ggml_metal_encode_node(
             } break;
         case LM_GGML_OP_ROPE:
             {
+
                 // make sure we have one or more position id(ne10) per token(ne02)
                 LM_GGML_ASSERT(ne10 % ne02 == 0);
                 LM_GGML_ASSERT(ne10 >= ne02);
@@ -3745,20 +3890,42 @@ static bool lm_ggml_metal_encode_node(
                 memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
                 memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
 
-                const bool is_neox
+                const bool is_neox = mode & LM_GGML_ROPE_TYPE_NEOX;
+                const bool is_mrope = mode & LM_GGML_ROPE_TYPE_MROPE;
+                const bool is_vision = mode == LM_GGML_ROPE_TYPE_VISION;
+
+                // mrope
+                const int sect_0 = ((const int32_t *) dst->op_params)[11];
+                const int sect_1 = ((const int32_t *) dst->op_params)[12];
+                const int sect_2 = ((const int32_t *) dst->op_params)[13];
+                const int sect_3 = ((const int32_t *) dst->op_params)[14];
 
                 id<MTLComputePipelineState> pipeline = nil;
 
-                if (
+                if (is_neox) {
                     switch (src0->type) {
-                        case LM_GGML_TYPE_F32: pipeline = ctx->kernels[
-                        case LM_GGML_TYPE_F16: pipeline = ctx->kernels[
+                        case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+                        case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+                        default: LM_GGML_ABORT("fatal error");
+                    };
+                } else if (is_mrope && !is_vision) {
+                    LM_GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
+                    switch (src0->type) {
+                        case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
+                        case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
+                        default: LM_GGML_ABORT("fatal error");
+                    };
+                } else if (is_vision) {
+                    LM_GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
+                    switch (src0->type) {
+                        case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
+                        case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
                         default: LM_GGML_ABORT("fatal error");
                     };
                 } else {
                     switch (src0->type) {
-                        case LM_GGML_TYPE_F32: pipeline = ctx->kernels[
-                        case LM_GGML_TYPE_F16: pipeline = ctx->kernels[
+                        case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+                        case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
                         default: LM_GGML_ABORT("fatal error");
                     };
                 }
@@ -3789,6 +3956,10 @@ static bool lm_ggml_metal_encode_node(
                     /*.attn_factor =*/ attn_factor,
                     /*.beta_fast =*/ beta_fast,
                     /*.beta_slow =*/ beta_slow,
+                    /* sect_0 =*/ sect_0,
+                    /* sect_1 =*/ sect_1,
+                    /* sect_2 =*/ sect_2,
+                    /* sect_3 =*/ sect_3,
                 };
 
                 [encoder setComputePipelineState:pipeline];
@@ -4225,7 +4396,7 @@ static bool lm_ggml_metal_encode_node(
                 // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
                 // for now avoiding mainly to keep the number of templates/kernels a bit lower
                 // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
-                if (ne01 >=
+                if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
                     switch (src1->type) {
                         case LM_GGML_TYPE_F16:
                             {
@@ -4406,6 +4577,24 @@ static bool lm_ggml_metal_encode_node(
                     use_vec_kernel = true;
 
                     switch (ne00) {
+                        case 64:
+                            {
+                                switch (src1->type) {
+                                    case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break;
+                                    case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break;
+                                    case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break;
+                                    case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break;
+                                    case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break;
+                                    case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break;
+                                    case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break;
+                                    default:
+                                        {
+                                            LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                            LM_GGML_LOG_ERROR("add template specialization for this type\n");
+                                            LM_GGML_ABORT("add template specialization for this type");
+                                        }
+                                }
+                            } break;
                         case 96:
                             {
                                 switch (src1->type) {