cui-llama.rn 1.6.0 → 1.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +35 -7
- package/android/src/main/CMakeLists.txt +16 -11
- package/android/src/main/java/com/rnllama/LlamaContext.java +4 -1
- package/android/src/main/jni.cpp +20 -4
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/cpp/LICENSE +21 -0
- package/cpp/chat.cpp +1 -1
- package/cpp/common.cpp +17 -2
- package/cpp/common.h +7 -3
- package/cpp/ggml-alloc.c +4 -1
- package/cpp/ggml-cpp.h +1 -1
- package/cpp/ggml-cpu/amx/amx.cpp +221 -0
- package/cpp/ggml-cpu/amx/amx.h +8 -0
- package/cpp/ggml-cpu/amx/common.h +91 -0
- package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
- package/cpp/ggml-cpu/amx/mmq.h +10 -0
- package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/binary-ops.h +1 -1
- package/cpp/ggml-cpu/common.h +72 -0
- package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -101
- package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +109 -42
- package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +3 -0
- package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +246 -160
- package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/ops.h +2 -20
- package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
- package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/simd-mappings.h +7 -3
- package/{ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/unary-ops.h +1 -1
- package/cpp/ggml-cpu.h +5 -0
- package/cpp/ggml-impl.h +16 -9
- package/cpp/ggml-llama-sim.metallib +0 -0
- package/cpp/ggml-llama.metallib +0 -0
- package/cpp/ggml-metal.m +492 -47
- package/cpp/ggml.c +134 -244
- package/cpp/ggml.h +61 -94
- package/cpp/json-schema-to-grammar.cpp +3 -0
- package/cpp/llama-arch.cpp +46 -17
- package/cpp/llama-arch.h +9 -0
- package/cpp/llama-batch.cpp +5 -1
- package/cpp/llama-batch.h +2 -1
- package/cpp/llama-chat.cpp +31 -10
- package/cpp/llama-chat.h +3 -2
- package/cpp/llama-context.cpp +104 -489
- package/cpp/llama-context.h +14 -30
- package/cpp/llama-graph.cpp +69 -62
- package/cpp/llama-graph.h +21 -18
- package/cpp/llama-hparams.h +5 -0
- package/cpp/llama-kv-cache.cpp +1497 -391
- package/cpp/llama-kv-cache.h +272 -80
- package/cpp/llama-memory.h +11 -1
- package/cpp/llama-model.cpp +502 -176
- package/cpp/llama-model.h +13 -3
- package/cpp/llama-sampling.cpp +2 -1
- package/cpp/llama-vocab.cpp +8 -1
- package/cpp/llama.h +14 -11
- package/cpp/rn-llama.cpp +20 -172
- package/cpp/rn-llama.h +1 -5
- package/ios/CMakeLists.txt +13 -10
- package/ios/RNLlama.h +6 -0
- package/ios/RNLlama.mm +5 -0
- package/ios/RNLlamaContext.mm +26 -28
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +7 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +61 -94
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +14 -30
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +13 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +14 -11
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +7 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +61 -94
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +14 -30
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +13 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +14 -11
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +4 -0
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +5 -0
- package/cpp/binary-ops.h +0 -16
- package/cpp/ops.h +0 -128
- package/cpp/simd-mappings.h +0 -888
- package/cpp/unary-ops.h +0 -28
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +0 -128
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +0 -802
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +0 -128
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +0 -28
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/vec.h +0 -802
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +0 -16
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +0 -128
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +0 -888
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
- /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
- /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
- /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
- /package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +0 -0
- /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
- /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
- /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
- /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
- /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
- /package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -0
- /package/cpp/{vec.h → ggml-cpu/vec.h} +0 -0
package/cpp/ggml-metal.m
CHANGED
@@ -44,8 +44,8 @@ static struct lm_ggml_backend_device g_lm_ggml_backend_metal_device;
|
|
44
44
|
// note: assumes single GPU device - the default one
|
45
45
|
// TODO: support multiple GPU devices
|
46
46
|
static struct lm_ggml_backend_metal_device_context {
|
47
|
-
id<MTLDevice>
|
48
|
-
int
|
47
|
+
id<MTLDevice> mtl_device;
|
48
|
+
int mtl_device_ref_count;
|
49
49
|
id<MTLLibrary> mtl_library;
|
50
50
|
|
51
51
|
bool has_simdgroup_reduction;
|
@@ -354,6 +354,7 @@ enum lm_ggml_metal_kernel_type {
|
|
354
354
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
|
355
355
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
|
356
356
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
357
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
|
357
358
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
|
358
359
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
|
359
360
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
|
@@ -362,6 +363,7 @@ enum lm_ggml_metal_kernel_type {
|
|
362
363
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
|
363
364
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
|
364
365
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
|
366
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
|
365
367
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
366
368
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
367
369
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
@@ -370,6 +372,7 @@ enum lm_ggml_metal_kernel_type {
|
|
370
372
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
|
371
373
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
|
372
374
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
|
375
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
|
373
376
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
|
374
377
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
|
375
378
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
|
@@ -378,6 +381,7 @@ enum lm_ggml_metal_kernel_type {
|
|
378
381
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
|
379
382
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
|
380
383
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
|
384
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
|
381
385
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
|
382
386
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
|
383
387
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
|
@@ -386,6 +390,7 @@ enum lm_ggml_metal_kernel_type {
|
|
386
390
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
|
387
391
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
|
388
392
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
|
393
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
|
389
394
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
|
390
395
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
|
391
396
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
|
@@ -394,6 +399,7 @@ enum lm_ggml_metal_kernel_type {
|
|
394
399
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
|
395
400
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
|
396
401
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
|
402
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
|
397
403
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
|
398
404
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
|
399
405
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
|
@@ -402,6 +408,14 @@ enum lm_ggml_metal_kernel_type {
|
|
402
408
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
|
403
409
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
|
404
410
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
411
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
|
412
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
|
413
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
|
414
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
|
415
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96,
|
416
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96,
|
417
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96,
|
418
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96,
|
405
419
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
406
420
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
|
407
421
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
|
@@ -430,6 +444,13 @@ enum lm_ggml_metal_kernel_type {
|
|
430
444
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
|
431
445
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
|
432
446
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
|
447
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,
|
448
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,
|
449
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,
|
450
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,
|
451
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
|
452
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
|
453
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
|
433
454
|
LM_GGML_METAL_KERNEL_TYPE_SET_I32,
|
434
455
|
LM_GGML_METAL_KERNEL_TYPE_SET_F32,
|
435
456
|
LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
@@ -460,6 +481,7 @@ enum lm_ggml_metal_kernel_type {
|
|
460
481
|
LM_GGML_METAL_KERNEL_TYPE_SQRT,
|
461
482
|
LM_GGML_METAL_KERNEL_TYPE_SIN,
|
462
483
|
LM_GGML_METAL_KERNEL_TYPE_COS,
|
484
|
+
LM_GGML_METAL_KERNEL_TYPE_NEG,
|
463
485
|
LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
464
486
|
LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
465
487
|
LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
@@ -468,7 +490,259 @@ enum lm_ggml_metal_kernel_type {
|
|
468
490
|
LM_GGML_METAL_KERNEL_TYPE_COUNT
|
469
491
|
};
|
470
492
|
|
493
|
+
//
|
494
|
+
// lm_ggml_metal_heap
|
495
|
+
//
|
496
|
+
|
497
|
+
struct lm_ggml_metal_heap {
|
498
|
+
// number of times the heap was unused
|
499
|
+
int n_unused;
|
500
|
+
|
501
|
+
// total number of buffer allocations in this heap across all computes
|
502
|
+
int64_t n_alloc;
|
503
|
+
|
504
|
+
// current offset in the heap - we reset this after each node in order to reuse the memory
|
505
|
+
size_t offs;
|
506
|
+
|
507
|
+
// the currently allocated MTLBuffer objects in this heap
|
508
|
+
id<MTLHeap> obj;
|
509
|
+
|
510
|
+
NSMutableArray * bufs;
|
511
|
+
};
|
512
|
+
|
513
|
+
static struct lm_ggml_metal_heap * lm_ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
|
514
|
+
struct lm_ggml_metal_heap * heap = calloc(1, sizeof(struct lm_ggml_metal_heap));
|
515
|
+
|
516
|
+
MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
|
517
|
+
desc.storageMode = MTLStorageModePrivate;
|
518
|
+
desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
|
519
|
+
desc.type = MTLHeapTypePlacement;
|
520
|
+
desc.size = size;
|
521
|
+
|
522
|
+
heap->n_unused = 0;
|
523
|
+
heap->n_alloc = 0;
|
524
|
+
|
525
|
+
heap->obj = [device newHeapWithDescriptor:desc];
|
526
|
+
if (!heap->obj) {
|
527
|
+
LM_GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);
|
528
|
+
|
529
|
+
free(heap);
|
530
|
+
|
531
|
+
return false;
|
532
|
+
}
|
533
|
+
|
534
|
+
[desc release];
|
535
|
+
|
536
|
+
heap->bufs = [[NSMutableArray alloc] init];
|
537
|
+
|
538
|
+
return heap;
|
539
|
+
}
|
540
|
+
|
541
|
+
static void lm_ggml_metal_heap_reset(struct lm_ggml_metal_heap * heap) {
|
542
|
+
heap->offs = 0;
|
543
|
+
|
544
|
+
// count how many graph computes the heap ended up being unused
|
545
|
+
if ([heap->bufs count] > 0) {
|
546
|
+
heap->n_unused = 0;
|
547
|
+
} else {
|
548
|
+
heap->n_unused++;
|
549
|
+
}
|
550
|
+
|
551
|
+
for (id<MTLBuffer> buf in heap->bufs) {
|
552
|
+
[buf release];
|
553
|
+
}
|
554
|
+
[heap->bufs removeAllObjects];
|
555
|
+
|
556
|
+
// tell the OS that it can reuse this memory if needed
|
557
|
+
// ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
|
558
|
+
[heap->obj setPurgeableState:MTLPurgeableStateVolatile];
|
559
|
+
}
|
560
|
+
|
561
|
+
static void lm_ggml_metal_heap_free(struct lm_ggml_metal_heap * heap) {
|
562
|
+
if (heap == nil) {
|
563
|
+
return;
|
564
|
+
}
|
565
|
+
|
566
|
+
lm_ggml_metal_heap_reset(heap);
|
567
|
+
|
568
|
+
[heap->obj release];
|
569
|
+
[heap->bufs release];
|
570
|
+
|
571
|
+
free(heap);
|
572
|
+
}
|
573
|
+
|
574
|
+
@interface lm_ggml_metal_heap_ptr : NSObject
|
575
|
+
|
576
|
+
@property (nonatomic, assign) struct lm_ggml_metal_heap * data;
|
577
|
+
|
578
|
+
@end
|
579
|
+
|
580
|
+
@implementation lm_ggml_metal_heap_ptr
|
581
|
+
@end
|
582
|
+
|
583
|
+
//
|
584
|
+
// lm_ggml_metal_mem_pool
|
585
|
+
//
|
586
|
+
|
587
|
+
struct lm_ggml_metal_mem_pool {
|
588
|
+
id<MTLDevice> device;
|
589
|
+
|
590
|
+
int n_heaps; // total number of heaps ever created (including those that were removed)
|
591
|
+
|
592
|
+
NSMutableArray * heaps;
|
593
|
+
NSMutableArray * heaps_to_remove;
|
594
|
+
};
|
595
|
+
|
596
|
+
static struct lm_ggml_metal_mem_pool * lm_ggml_metal_mem_pool_init(void) {
|
597
|
+
struct lm_ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct lm_ggml_metal_mem_pool));
|
598
|
+
|
599
|
+
mem_pool->n_heaps = 0;
|
600
|
+
|
601
|
+
mem_pool->heaps = [[NSMutableArray alloc] init];
|
602
|
+
mem_pool->heaps_to_remove = [[NSMutableArray alloc] init];
|
603
|
+
|
604
|
+
return mem_pool;
|
605
|
+
}
|
606
|
+
|
607
|
+
static void lm_ggml_metal_mem_pool_free(struct lm_ggml_metal_mem_pool * mem_pool) {
|
608
|
+
LM_GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps);
|
609
|
+
|
610
|
+
size_t size_all = 0;
|
611
|
+
size_t size_cur = 0;
|
612
|
+
|
613
|
+
for (lm_ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
|
614
|
+
LM_GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data);
|
615
|
+
LM_GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, ptr.data->n_alloc);
|
616
|
+
LM_GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused);
|
617
|
+
LM_GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0);
|
618
|
+
LM_GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [ptr.data->bufs count]);
|
619
|
+
|
620
|
+
if ([ptr.data->bufs count] > 0) {
|
621
|
+
size_cur += [ptr.data->obj size];
|
622
|
+
}
|
623
|
+
size_all += [ptr.data->obj size];
|
624
|
+
|
625
|
+
lm_ggml_metal_heap_free(ptr.data);
|
626
|
+
[ptr release];
|
627
|
+
}
|
628
|
+
[mem_pool->heaps release];
|
629
|
+
[mem_pool->heaps_to_remove release];
|
630
|
+
|
631
|
+
if (size_all > 0) {
|
632
|
+
LM_GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0);
|
633
|
+
LM_GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0);
|
634
|
+
}
|
635
|
+
|
636
|
+
free(mem_pool);
|
637
|
+
}
|
638
|
+
|
639
|
+
static void lm_ggml_metal_mem_pool_reset(struct lm_ggml_metal_mem_pool * mem_pool) {
|
640
|
+
for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) {
|
641
|
+
lm_ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i];
|
642
|
+
|
643
|
+
struct lm_ggml_metal_heap * heap = ptr.data;
|
644
|
+
lm_ggml_metal_heap_reset(heap);
|
645
|
+
|
646
|
+
// if the heap hasn't been used for a while, remove it
|
647
|
+
if (heap->n_unused >= 128) {
|
648
|
+
[mem_pool->heaps_to_remove addObject:@(i)];
|
649
|
+
}
|
650
|
+
}
|
651
|
+
|
652
|
+
if (mem_pool->heaps_to_remove.count > 0) {
|
653
|
+
for (NSUInteger i = 0; i < [mem_pool->heaps_to_remove count]; i++) {
|
654
|
+
NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
|
655
|
+
lm_ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
|
656
|
+
|
657
|
+
struct lm_ggml_metal_heap * heap = ptr.data;
|
658
|
+
lm_ggml_metal_heap_free(heap);
|
659
|
+
|
660
|
+
[mem_pool->heaps removeObjectAtIndex:index];
|
661
|
+
[ptr release];
|
662
|
+
}
|
663
|
+
|
664
|
+
[mem_pool->heaps_to_remove removeAllObjects];
|
665
|
+
}
|
666
|
+
}
|
667
|
+
|
668
|
+
static void lm_ggml_metal_mem_pool_clear(struct lm_ggml_metal_mem_pool * mem_pool) {
|
669
|
+
for (lm_ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
|
670
|
+
ptr.data->offs = 0;
|
671
|
+
}
|
672
|
+
}
|
673
|
+
|
674
|
+
static id<MTLBuffer> lm_ggml_metal_mem_pool_alloc(struct lm_ggml_metal_mem_pool * mem_pool, size_t size) {
|
675
|
+
const size_t alignment = 32;
|
676
|
+
|
677
|
+
const size_t size_aligned = LM_GGML_PAD(size, alignment);
|
678
|
+
|
679
|
+
// try one of the existing heaps
|
680
|
+
for (lm_ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
|
681
|
+
struct lm_ggml_metal_heap * heap = ptr.data;
|
682
|
+
if (heap->offs + size_aligned <= [heap->obj size]) {
|
683
|
+
// if this is the first buffer in the heap for the current command buffer, tell the OS that
|
684
|
+
// it cannot free the memory used by the heap
|
685
|
+
// ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
|
686
|
+
if ([heap->bufs count] == 0) {
|
687
|
+
[heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
|
688
|
+
}
|
689
|
+
|
690
|
+
id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
|
691
|
+
if (buf == nil) {
|
692
|
+
LM_GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
|
693
|
+
return nil;
|
694
|
+
}
|
695
|
+
|
696
|
+
heap->n_alloc++;
|
697
|
+
heap->offs += size_aligned;
|
698
|
+
|
699
|
+
[heap->bufs addObject:buf];
|
700
|
+
|
701
|
+
return buf;
|
702
|
+
}
|
703
|
+
}
|
704
|
+
|
705
|
+
// create a new heap that can fit this buffer
|
706
|
+
lm_ggml_metal_heap_ptr * heap_ptr = [lm_ggml_metal_heap_ptr new];
|
707
|
+
|
708
|
+
struct lm_ggml_metal_heap * heap = lm_ggml_metal_heap_init(mem_pool->device, size_aligned);
|
709
|
+
if (heap == NULL) {
|
710
|
+
LM_GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned);
|
711
|
+
return NULL;
|
712
|
+
}
|
713
|
+
|
714
|
+
//LM_GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]);
|
715
|
+
|
716
|
+
heap_ptr.data = heap;
|
717
|
+
lm_ggml_metal_heap_reset(heap);
|
718
|
+
|
719
|
+
[heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
|
720
|
+
id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
|
721
|
+
if (buf == nil) {
|
722
|
+
LM_GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
|
723
|
+
return NULL;
|
724
|
+
}
|
725
|
+
|
726
|
+
heap->n_alloc++;
|
727
|
+
heap->offs += size_aligned;
|
728
|
+
|
729
|
+
[heap->bufs addObject:buf];
|
730
|
+
|
731
|
+
[mem_pool->heaps addObject:heap_ptr];
|
732
|
+
mem_pool->n_heaps++;
|
733
|
+
|
734
|
+
return buf;
|
735
|
+
}
|
736
|
+
|
737
|
+
struct lm_ggml_metal_command_buffer {
|
738
|
+
id<MTLCommandBuffer> obj;
|
739
|
+
|
740
|
+
// each command buffer has a memory pool from which it can allocate temporary buffers during the compute
|
741
|
+
struct lm_ggml_metal_mem_pool * mem_pool;
|
742
|
+
};
|
743
|
+
|
471
744
|
struct lm_ggml_backend_metal_context {
|
745
|
+
id<MTLDevice> device;
|
472
746
|
id<MTLCommandQueue> queue;
|
473
747
|
|
474
748
|
dispatch_queue_t d_queue;
|
@@ -493,7 +767,7 @@ struct lm_ggml_backend_metal_context {
|
|
493
767
|
void (^encode_async)(size_t ith);
|
494
768
|
|
495
769
|
// n_cb command buffers + 1 used by the main thread
|
496
|
-
|
770
|
+
struct lm_ggml_metal_command_buffer cmd_bufs[LM_GGML_METAL_MAX_COMMAND_BUFFERS + 1];
|
497
771
|
|
498
772
|
// abort lm_ggml_metal_graph_compute if callback returns true
|
499
773
|
lm_ggml_abort_callback abort_callback;
|
@@ -687,9 +961,11 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
687
961
|
struct lm_ggml_backend_metal_device_context * ctx_dev = dev->context;
|
688
962
|
|
689
963
|
id<MTLDevice> device = lm_ggml_backend_metal_device_acq(ctx_dev);
|
964
|
+
|
690
965
|
LM_GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
691
966
|
|
692
|
-
ctx->
|
967
|
+
ctx->device = device;
|
968
|
+
ctx->queue = [device newCommandQueue];
|
693
969
|
if (ctx->queue == nil) {
|
694
970
|
LM_GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
|
695
971
|
return NULL;
|
@@ -750,7 +1026,10 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
750
1026
|
ctx->gf = nil;
|
751
1027
|
ctx->encode_async = nil;
|
752
1028
|
for (int i = 0; i < LM_GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
753
|
-
ctx->
|
1029
|
+
ctx->cmd_bufs[i].obj = nil;
|
1030
|
+
|
1031
|
+
ctx->cmd_bufs[i].mem_pool = lm_ggml_metal_mem_pool_init();
|
1032
|
+
ctx->cmd_bufs[i].mem_pool->device = device;
|
754
1033
|
}
|
755
1034
|
|
756
1035
|
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
@@ -1015,6 +1294,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1015
1294
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm);
|
1016
1295
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
|
1017
1296
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
|
1297
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
|
1018
1298
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
|
1019
1299
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
|
1020
1300
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
|
@@ -1023,6 +1303,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1023
1303
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat);
|
1024
1304
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
|
1025
1305
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
|
1306
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
|
1026
1307
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
|
1027
1308
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
|
1028
1309
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
|
@@ -1031,6 +1312,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1031
1312
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm);
|
1032
1313
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
|
1033
1314
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
|
1315
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
|
1034
1316
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
|
1035
1317
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
|
1036
1318
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
|
@@ -1039,6 +1321,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1039
1321
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm);
|
1040
1322
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
|
1041
1323
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
|
1324
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
|
1042
1325
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
|
1043
1326
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
|
1044
1327
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
|
@@ -1047,6 +1330,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1047
1330
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm);
|
1048
1331
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
|
1049
1332
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
|
1333
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
|
1050
1334
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
|
1051
1335
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
|
1052
1336
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
|
@@ -1055,6 +1339,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1055
1339
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm);
|
1056
1340
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
|
1057
1341
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
|
1342
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
|
1058
1343
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
|
1059
1344
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
|
1060
1345
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
|
@@ -1063,6 +1348,14 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1063
1348
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm);
|
1064
1349
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
|
1065
1350
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
|
1351
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
|
1352
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
|
1353
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
|
1354
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
|
1355
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction);
|
1356
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction);
|
1357
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction);
|
1358
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction);
|
1066
1359
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
|
1067
1360
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
|
1068
1361
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
|
@@ -1091,6 +1384,13 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1091
1384
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
|
1092
1385
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
|
1093
1386
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
|
1387
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction);
|
1388
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat);
|
1389
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction);
|
1390
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction);
|
1391
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
|
1392
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
|
1393
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
|
1094
1394
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
|
1095
1395
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
|
1096
1396
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
@@ -1121,6 +1421,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
1121
1421
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
|
1122
1422
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
1123
1423
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
1424
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
1124
1425
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
1125
1426
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
1126
1427
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
@@ -1141,6 +1442,12 @@ static void lm_ggml_metal_free(struct lm_ggml_backend_metal_context * ctx) {
|
|
1141
1442
|
|
1142
1443
|
[ctx->queue release];
|
1143
1444
|
|
1445
|
+
for (int i = 0; i < LM_GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
1446
|
+
// ctx->cmd_bufs[i].obj is auto released
|
1447
|
+
|
1448
|
+
lm_ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
|
1449
|
+
}
|
1450
|
+
|
1144
1451
|
dispatch_release(ctx->d_queue);
|
1145
1452
|
|
1146
1453
|
free(ctx);
|
@@ -1282,6 +1589,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
|
|
1282
1589
|
case LM_GGML_UNARY_OP_GELU_QUICK:
|
1283
1590
|
case LM_GGML_UNARY_OP_SILU:
|
1284
1591
|
case LM_GGML_UNARY_OP_ELU:
|
1592
|
+
case LM_GGML_UNARY_OP_NEG:
|
1285
1593
|
return lm_ggml_is_contiguous(op->src[0]) && op->src[0]->type == LM_GGML_TYPE_F32;
|
1286
1594
|
default:
|
1287
1595
|
return false;
|
@@ -1338,8 +1646,9 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
|
|
1338
1646
|
return op->src[0]->type == LM_GGML_TYPE_F16;
|
1339
1647
|
case LM_GGML_OP_POOL_1D:
|
1340
1648
|
return false;
|
1341
|
-
case LM_GGML_OP_POOL_2D:
|
1342
1649
|
case LM_GGML_OP_UPSCALE:
|
1650
|
+
return op->src[0]->type == LM_GGML_TYPE_F32 && op->op_params[0] == LM_GGML_SCALE_MODE_NEAREST;
|
1651
|
+
case LM_GGML_OP_POOL_2D:
|
1343
1652
|
case LM_GGML_OP_PAD:
|
1344
1653
|
case LM_GGML_OP_PAD_REFLECT_1D:
|
1345
1654
|
case LM_GGML_OP_TIMESTEP_EMBEDDING:
|
@@ -1354,6 +1663,11 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
|
|
1354
1663
|
// TODO: not sure if it is worth adding kernels for this size
|
1355
1664
|
return false;
|
1356
1665
|
}
|
1666
|
+
if (op->src[0]->ne[0] == 576) {
|
1667
|
+
// DeepSeek sizes
|
1668
|
+
// TODO: disabled for now, until optmized
|
1669
|
+
return false;
|
1670
|
+
}
|
1357
1671
|
if (op->src[1]->type != op->src[2]->type) {
|
1358
1672
|
return false;
|
1359
1673
|
}
|
@@ -1439,10 +1753,11 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
|
|
1439
1753
|
}
|
1440
1754
|
}
|
1441
1755
|
|
1442
|
-
static
|
1756
|
+
static bool lm_ggml_metal_encode_node(
|
1443
1757
|
lm_ggml_backend_t backend,
|
1444
1758
|
int idx,
|
1445
|
-
id<MTLComputeCommandEncoder> encoder
|
1759
|
+
id<MTLComputeCommandEncoder> encoder,
|
1760
|
+
struct lm_ggml_metal_mem_pool * mem_pool) {
|
1446
1761
|
struct lm_ggml_backend_metal_context * ctx = backend->context;
|
1447
1762
|
struct lm_ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
1448
1763
|
|
@@ -1458,7 +1773,7 @@ static void lm_ggml_metal_encode_node(
|
|
1458
1773
|
struct lm_ggml_tensor * dst = node;
|
1459
1774
|
|
1460
1775
|
if (lm_ggml_is_empty(dst)) {
|
1461
|
-
return;
|
1776
|
+
return true;
|
1462
1777
|
}
|
1463
1778
|
|
1464
1779
|
switch (dst->op) {
|
@@ -1469,7 +1784,7 @@ static void lm_ggml_metal_encode_node(
|
|
1469
1784
|
case LM_GGML_OP_PERMUTE:
|
1470
1785
|
{
|
1471
1786
|
// noop -> next node
|
1472
|
-
} return;
|
1787
|
+
} return true;
|
1473
1788
|
default:
|
1474
1789
|
{
|
1475
1790
|
} break;
|
@@ -1480,6 +1795,8 @@ static void lm_ggml_metal_encode_node(
|
|
1480
1795
|
LM_GGML_ABORT("unsupported op");
|
1481
1796
|
}
|
1482
1797
|
|
1798
|
+
lm_ggml_metal_mem_pool_clear(mem_pool);
|
1799
|
+
|
1483
1800
|
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
1484
1801
|
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
1485
1802
|
const int64_t ne02 = src0 ? src0->ne[2] : 0;
|
@@ -1966,6 +2283,18 @@ static void lm_ggml_metal_encode_node(
|
|
1966
2283
|
|
1967
2284
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1968
2285
|
} break;
|
2286
|
+
case LM_GGML_UNARY_OP_NEG:
|
2287
|
+
{
|
2288
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_NEG].pipeline;
|
2289
|
+
|
2290
|
+
[encoder setComputePipelineState:pipeline];
|
2291
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2292
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2293
|
+
|
2294
|
+
const int64_t n = lm_ggml_nelements(dst);
|
2295
|
+
|
2296
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
2297
|
+
} break;
|
1969
2298
|
default:
|
1970
2299
|
{
|
1971
2300
|
LM_GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, lm_ggml_op_name(dst->op));
|
@@ -2114,26 +2443,76 @@ static void lm_ggml_metal_encode_node(
|
|
2114
2443
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
2115
2444
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
2116
2445
|
|
2117
|
-
|
2446
|
+
// use this branch to test the lm_ggml_metal_mem_pool functionality
|
2447
|
+
#if 0
|
2448
|
+
// cpy to tmp buffer in MTLHeap
|
2449
|
+
|
2450
|
+
id<MTLBuffer> h_src0 = h_src0 = lm_ggml_metal_mem_pool_alloc(mem_pool, lm_ggml_nbytes(src0));
|
2451
|
+
if (!h_src0) {
|
2452
|
+
LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, lm_ggml_nbytes(src0));
|
2453
|
+
return false;
|
2454
|
+
}
|
2455
|
+
|
2456
|
+
offs_src0 = 0;
|
2457
|
+
|
2458
|
+
lm_ggml_metal_kargs_cpy args_cpy = {
|
2118
2459
|
/*.ne00 =*/ ne00,
|
2119
2460
|
/*.ne01 =*/ ne01,
|
2120
2461
|
/*.ne02 =*/ ne02,
|
2121
|
-
/*.
|
2122
|
-
/*.
|
2123
|
-
/*.
|
2124
|
-
/*.
|
2462
|
+
/*.ne03 =*/ ne03,
|
2463
|
+
/*.nb00 =*/ nb00,
|
2464
|
+
/*.nb01 =*/ nb01,
|
2465
|
+
/*.nb02 =*/ nb02,
|
2466
|
+
/*.nb03 =*/ nb03,
|
2467
|
+
/*.ne0 =*/ ne00,
|
2468
|
+
/*.ne1 =*/ ne01,
|
2469
|
+
/*.ne2 =*/ ne02,
|
2470
|
+
/*.ne3 =*/ ne03,
|
2471
|
+
/*.nb0 =*/ nb00,
|
2472
|
+
/*.nb1 =*/ nb01,
|
2473
|
+
/*.nb2 =*/ nb02,
|
2474
|
+
/*.nb3 =*/ nb03,
|
2475
|
+
};
|
2476
|
+
|
2477
|
+
if (src0->type == LM_GGML_TYPE_F16) {
|
2478
|
+
[encoder setComputePipelineState:ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
|
2479
|
+
} else {
|
2480
|
+
[encoder setComputePipelineState:ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
|
2481
|
+
}
|
2482
|
+
[encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
|
2483
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
2484
|
+
[encoder setBuffer:h_src0 offset:0 atIndex:2];
|
2485
|
+
|
2486
|
+
LM_GGML_ASSERT(ne00 % lm_ggml_blck_size(src0->type) == 0);
|
2487
|
+
int nth_cpy = MIN(1024, ne00 / lm_ggml_blck_size(src0->type));
|
2488
|
+
|
2489
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
|
2490
|
+
|
2491
|
+
#else
|
2492
|
+
id<MTLBuffer> h_src0 = id_src0;
|
2493
|
+
#endif
|
2494
|
+
// softmax
|
2495
|
+
|
2496
|
+
lm_ggml_metal_kargs_soft_max args = {
|
2497
|
+
/*.ne00 =*/ ne00,
|
2498
|
+
/*.ne01 =*/ ne01,
|
2499
|
+
/*.ne02 =*/ ne02,
|
2500
|
+
/*.scale =*/ scale,
|
2501
|
+
/*.max_bias =*/ max_bias,
|
2502
|
+
/*.m0 =*/ m0,
|
2503
|
+
/*.m1 =*/ m1,
|
2125
2504
|
/*.n_head_log2 =*/ n_head_log2,
|
2126
2505
|
};
|
2127
2506
|
|
2128
2507
|
[encoder setComputePipelineState:pipeline];
|
2129
|
-
[encoder setBuffer:
|
2508
|
+
[encoder setBuffer:h_src0 offset:offs_src0 atIndex:0];
|
2130
2509
|
if (id_src1) {
|
2131
|
-
[encoder setBuffer:id_src1 offset:offs_src1
|
2510
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
2132
2511
|
} else {
|
2133
|
-
[encoder setBuffer:
|
2512
|
+
[encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
|
2134
2513
|
}
|
2135
|
-
[encoder setBuffer:id_dst
|
2136
|
-
[encoder setBytes:&args
|
2514
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
2515
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
2137
2516
|
|
2138
2517
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
2139
2518
|
|
@@ -3846,12 +4225,14 @@ static void lm_ggml_metal_encode_node(
|
|
3846
4225
|
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
|
3847
4226
|
// for now avoiding mainly to keep the number of templates/kernels a bit lower
|
3848
4227
|
// these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
|
3849
|
-
if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192)) {
|
4228
|
+
if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
|
3850
4229
|
switch (src1->type) {
|
3851
4230
|
case LM_GGML_TYPE_F16:
|
3852
4231
|
{
|
3853
4232
|
if (ne00 == 192 && ne20 == 128) {
|
3854
4233
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
|
4234
|
+
} else if (ne00 == 576 && ne20 == 512) {
|
4235
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
|
3855
4236
|
} else {
|
3856
4237
|
switch (ne00) {
|
3857
4238
|
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
@@ -3874,6 +4255,8 @@ static void lm_ggml_metal_encode_node(
|
|
3874
4255
|
{
|
3875
4256
|
if (ne00 == 192 && ne20 == 128) {
|
3876
4257
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
|
4258
|
+
} else if (ne00 == 576 && ne20 == 512) {
|
4259
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
|
3877
4260
|
} else {
|
3878
4261
|
switch (ne00) {
|
3879
4262
|
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
|
@@ -3896,6 +4279,8 @@ static void lm_ggml_metal_encode_node(
|
|
3896
4279
|
{
|
3897
4280
|
if (ne00 == 192 && ne20 == 128) {
|
3898
4281
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
|
4282
|
+
} else if (ne00 == 576 && ne20 == 512) {
|
4283
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
|
3899
4284
|
} else {
|
3900
4285
|
switch (ne00) {
|
3901
4286
|
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
|
@@ -3918,6 +4303,8 @@ static void lm_ggml_metal_encode_node(
|
|
3918
4303
|
{
|
3919
4304
|
if (ne00 == 192 && ne20 == 128) {
|
3920
4305
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
|
4306
|
+
} else if (ne00 == 576 && ne20 == 512) {
|
4307
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
|
3921
4308
|
} else {
|
3922
4309
|
switch (ne00) {
|
3923
4310
|
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
|
@@ -3940,6 +4327,8 @@ static void lm_ggml_metal_encode_node(
|
|
3940
4327
|
{
|
3941
4328
|
if (ne00 == 192 && ne20 == 128) {
|
3942
4329
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
|
4330
|
+
} else if (ne00 == 576 && ne20 == 512) {
|
4331
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
|
3943
4332
|
} else {
|
3944
4333
|
switch (ne00) {
|
3945
4334
|
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
|
@@ -3962,6 +4351,8 @@ static void lm_ggml_metal_encode_node(
|
|
3962
4351
|
{
|
3963
4352
|
if (ne00 == 192 && ne20 == 128) {
|
3964
4353
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
|
4354
|
+
} else if (ne00 == 576 && ne20 == 512) {
|
4355
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
|
3965
4356
|
} else {
|
3966
4357
|
switch (ne00) {
|
3967
4358
|
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
|
@@ -3984,6 +4375,8 @@ static void lm_ggml_metal_encode_node(
|
|
3984
4375
|
{
|
3985
4376
|
if (ne00 == 192 && ne20 == 128) {
|
3986
4377
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
|
4378
|
+
} else if (ne00 == 576 && ne20 == 512) {
|
4379
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
|
3987
4380
|
} else {
|
3988
4381
|
switch (ne00) {
|
3989
4382
|
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
|
@@ -4013,6 +4406,24 @@ static void lm_ggml_metal_encode_node(
            use_vec_kernel = true;

            switch (ne00) {
+               case 96:
+                   {
+                       switch (src1->type) {
+                           case LM_GGML_TYPE_F16:  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline;  break;
+                           case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break;
+                           case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break;
+                           case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break;
+                           case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break;
+                           case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break;
+                           case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break;
+                           default:
+                               {
+                                   LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                   LM_GGML_LOG_ERROR("add template specialization for this type\n");
+                                   LM_GGML_ABORT("add template specialization for this type");
+                               }
+                       }
+                   } break;
                case 128:
                    {
                        switch (src1->type) {
@@ -4085,12 +4496,36 @@ static void lm_ggml_metal_encode_node(
                            }
                        }
                    } break;
+               case 576:
+                   {
+                       if (ne20 == 512) {
+                           switch (src1->type) {
+                               case LM_GGML_TYPE_F16:  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline;  break;
+                               case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break;
+                               case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break;
+                               case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break;
+                               case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break;
+                               case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break;
+                               case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break;
+                               default:
+                                   {
+                                       LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                       LM_GGML_LOG_ERROR("add template specialization for this type\n");
+                                       LM_GGML_ABORT("add template specialization for this type");
+                                   }
+                           }
+                       } else {
+                           LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne20);
+                           LM_GGML_LOG_ERROR("add template specialization for this size\n");
+                           LM_GGML_ABORT("add template specialization for this size");
+                       }
+                   } break;
                default:
-                   …
-                   …
-                   …
-                   …
-                   …
+                   {
+                       LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                       LM_GGML_LOG_ERROR("add template specialization for this size\n");
+                       LM_GGML_ABORT("add template specialization for this size");
+                   }
            }
        }

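The vector (single-row) flash-attention path gains two new head-size cases, 96 and the 576/512 pair, and its `default:` branch now logs the offending size before aborting. The dispatch is two-level: an outer switch on the K head size and an inner switch on the K/V-cache type. The sketch below shows that shape in condensed form; `pick_vec_kernel` and `vec_kernel_for_type` are hypothetical helpers used only to keep the sketch short, the real code writes both switches out inline.

```objc
// Structural sketch of the extended vec-kernel dispatch (not library API).
// vec_kernel_for_type() stands in for the inner switch over src1->type that
// the hunks above spell out per head size.
static int pick_vec_kernel(int64_t ne00, int64_t ne20, enum lm_ggml_type kv_type) {
    switch (ne00) {
        case  96: return vec_kernel_for_type(kv_type,  96,  96);   // added in this diff
        case 128: return vec_kernel_for_type(kv_type, 128, 128);
        // ... other existing head sizes elided ...
        case 576:                                                   // added in this diff
            if (ne20 == 512) {
                return vec_kernel_for_type(kv_type, 576, 512);
            }
            LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne20);
            LM_GGML_ABORT("add template specialization for this size");
        default:                                                    // now reports the size
            LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
            LM_GGML_ABORT("add template specialization for this size");
    }
}
```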
@@ -4486,6 +4921,8 @@ static void lm_ggml_metal_encode_node(
                LM_GGML_ABORT("fatal error");
            }
    }
+
+    return true;
 }

 static enum lm_ggml_status lm_ggml_metal_graph_compute(
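This small hunk is one half of a signature change: `lm_ggml_metal_encode_node` now ends with `return true;`, so it reports per-node success instead of returning `void` (the `@@` context line still shows the old signature). The prototype below is inferred from this hunk and from the updated call site in the later `lm_ggml_backend_metal_set_n_cb` hunk; the parameter types are assumptions based on those call sites.

```objc
// Inferred signature after this change (an assumption reconstructed from the
// call sites in this diff): a failing node now returns false so the caller
// can stop encoding early instead of aborting the whole process.
static bool lm_ggml_metal_encode_node(
        lm_ggml_backend_t               backend,
        int                             idx,
        id<MTLComputeCommandEncoder>    encoder,
        struct lm_ggml_metal_mem_pool * mem_pool);
```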
@@ -4539,25 +4976,25 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
    }

    // the main thread commits the first few commands immediately
-   // …
+   // cmd_buf[n_cb]
    {
-       id<MTLCommandBuffer> …
-       ctx->…
+       id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+       ctx->cmd_bufs[n_cb].obj = cmd_buf;

-       [ …
+       [cmd_buf enqueue];
        ctx->encode_async(n_cb);
    }

    // prepare the rest of the command buffers asynchronously
-   // …
+   // cmd_buf[0.. n_cb)
    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-       id<MTLCommandBuffer> …
-       ctx->…
+       id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+       ctx->cmd_bufs[cb_idx].obj = cmd_buf;

        // always enqueue the first two command buffers
        // enqueue all of the command buffers if we don't need to abort
        if (cb_idx < 2 || ctx->abort_callback == NULL) {
-           [ …
+           [cmd_buf enqueue];
        }
    }

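Functionally this hunk is a rename plus one level of indirection: command buffers are now stored as `ctx->cmd_bufs[i].obj`, a per-slot struct member rather than a bare `id<MTLCommandBuffer>`, which is what later lets each slot also carry a memory pool. The two-phase behaviour is unchanged: the main thread's buffer (`cmd_bufs[n_cb]`) is enqueued immediately, and the worker buffers beyond the first two are only pre-enqueued when aborting mid-graph is impossible. Below is a stripped-down sketch of that flow using only standard Metal calls; the slot struct, variable names, and function name are assumptions for illustration.

```objc
#import <Metal/Metal.h>

// Hypothetical slot type mirroring the cmd_bufs[i].obj access in the hunk above.
// (Assumes the same compiler settings as ggml-metal.m for object members in C structs.)
struct cmd_buf_slot {
    id<MTLCommandBuffer> obj;
};

// Sketch of the two-phase enqueue: the main-thread buffer (index n_cb) is
// enqueued right away; worker buffers beyond the first two are pre-enqueued
// only when no abort callback can cut the graph short.
static void prepare_command_buffers(id<MTLCommandQueue> queue,
                                    struct cmd_buf_slot * slots,
                                    int n_cb,
                                    bool has_abort_callback) {
    for (int i = 0; i <= n_cb; ++i) {
        id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
        slots[i].obj = cmd_buf;

        if (i == n_cb || i < 2 || !has_abort_callback) {
            [cmd_buf enqueue];
        }
    }
}
```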
@@ -4566,14 +5003,14 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
    // wait for completion and check status of each command buffer
    // needed to detect if the device ran out-of-memory for example (#1881)
    {
-       id<MTLCommandBuffer> …
-       [ …
+       id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+       [cmd_buf waitUntilCompleted];

-       MTLCommandBufferStatus status = [ …
+       MTLCommandBufferStatus status = [cmd_buf status];
        if (status != MTLCommandBufferStatusCompleted) {
            LM_GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
            if (status == MTLCommandBufferStatusError) {
-               LM_GGML_LOG_INFO("error: %s\n", [[ …
+               LM_GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
            }

            return LM_GGML_STATUS_FAILED;
@@ -4581,20 +5018,20 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
    }

    for (int i = 0; i < n_cb; ++i) {
-       id<MTLCommandBuffer> …
-       [ …
+       id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
+       [cmd_buf waitUntilCompleted];

-       MTLCommandBufferStatus status = [ …
+       MTLCommandBufferStatus status = [cmd_buf status];
        if (status != MTLCommandBufferStatusCompleted) {
            LM_GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
            if (status == MTLCommandBufferStatusError) {
-               LM_GGML_LOG_INFO("error: %s\n", [[ …
+               LM_GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
            }

            return LM_GGML_STATUS_FAILED;
        }

-       id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->…
+       id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
        if (!next_buffer) {
            continue;
        }
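The completion path changes only in how each buffer is fetched (`ctx->cmd_bufs[...].obj`); the wait-and-check logic that surfaces Metal-level failures such as out-of-memory (see the #1881 comment) stays the same. Factored into a helper for readability, that step looks roughly like this; the helper name is hypothetical, while the Metal calls and log macro are the ones used in the hunks above.

```objc
// Wait for one command buffer and report whether it completed successfully.
// command_buffer_ok() is a hypothetical wrapper around the pattern above.
static bool command_buffer_ok(id<MTLCommandBuffer> cmd_buf, int index) {
    [cmd_buf waitUntilCompleted];

    MTLCommandBufferStatus status = [cmd_buf status];
    if (status != MTLCommandBufferStatusCompleted) {
        LM_GGML_LOG_INFO("command buffer %d failed with status %lu\n", index, (unsigned long) status);
        if (status == MTLCommandBufferStatusError) {
            LM_GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
        }
        return false;   // the caller maps this to LM_GGML_STATUS_FAILED
    }
    return true;
}
```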
@@ -4977,8 +5414,9 @@ static void lm_ggml_backend_metal_set_n_cb(lm_ggml_backend_t backend, int n_cb)

        const int n_nodes_per_cb = ctx->n_nodes_per_cb;

-       id<MTLCommandBuffer> …
-       …
+       id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
+
+       id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];

        int node_start = 0;
        int node_end = n_nodes_0;
@@ -4990,22 +5428,29 @@ static void lm_ggml_backend_metal_set_n_cb(lm_ggml_backend_t backend, int n_cb)

        const bool should_capture = ctx->capture_next_compute;

+       struct lm_ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
+       lm_ggml_metal_mem_pool_reset(mem_pool);
+
        for (int idx = node_start; idx < node_end; ++idx) {
            if (should_capture) {
                [encoder pushDebugGroup:[NSString stringWithCString:lm_ggml_op_desc(lm_ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
            }

-           lm_ggml_metal_encode_node(backend, idx, encoder);
+           const bool res = lm_ggml_metal_encode_node(backend, idx, encoder, mem_pool);

            if (should_capture) {
                [encoder popDebugGroup];
            }
+
+           if (!res) {
+               break;
+           }
        }

        [encoder endEncoding];

        if (cb_idx < 2 || ctx->abort_callback == NULL) {
-           [ …
+           [cmd_buf commit];
        }
    });
 }
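Putting this last hunk together: each command-buffer slot now also owns a `lm_ggml_metal_mem_pool`, which is reset before that buffer's nodes are encoded, and the loop stops at the first node that `lm_ggml_metal_encode_node` fails to encode (the `bool` result introduced earlier). Condensed into one place, the per-buffer encode flow now reads roughly as follows; this is a compressed restatement of the hunk, not the verbatim function body.

```objc
// Condensed sketch of the new per-command-buffer encode flow, using the names
// that appear in the hunk above.
id<MTLCommandBuffer>            cmd_buf  = ctx->cmd_bufs[cb_idx].obj;
id<MTLComputeCommandEncoder>    encoder  = [cmd_buf computeCommandEncoder];
struct lm_ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;

lm_ggml_metal_mem_pool_reset(mem_pool);    // temporary allocations are per buffer

for (int idx = node_start; idx < node_end; ++idx) {
    const bool res = lm_ggml_metal_encode_node(backend, idx, encoder, mem_pool);
    if (!res) {
        break;                             // stop encoding after the first failure
    }
}

[encoder endEncoding];

if (cb_idx < 2 || ctx->abort_callback == NULL) {
    [cmd_buf commit];                      // same condition as the enqueue phase above
}
```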