cui-llama.rn 1.4.6 → 1.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +9 -2
- package/android/src/main/jni.cpp +52 -34
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/cpp/binary-ops.cpp +158 -0
- package/cpp/binary-ops.h +16 -0
- package/cpp/chat.cpp +1769 -1779
- package/cpp/chat.h +9 -1
- package/cpp/common.cpp +20 -522
- package/cpp/common.h +13 -36
- package/cpp/cpu-common.h +72 -0
- package/cpp/ggml-common.h +12 -6
- package/cpp/ggml-cpu-aarch64.cpp +1557 -80
- package/cpp/ggml-cpu-impl.h +2 -21
- package/cpp/ggml-cpu-quants.c +904 -405
- package/cpp/ggml-cpu.c +909 -13237
- package/cpp/ggml-impl.h +50 -23
- package/cpp/ggml-metal-impl.h +77 -3
- package/cpp/ggml-metal.m +794 -580
- package/cpp/ggml.c +92 -3
- package/cpp/ggml.h +29 -5
- package/cpp/gguf.cpp +1 -0
- package/cpp/llama-adapter.cpp +55 -20
- package/cpp/llama-adapter.h +11 -9
- package/cpp/llama-arch.cpp +217 -16
- package/cpp/llama-arch.h +25 -0
- package/cpp/llama-batch.h +2 -2
- package/cpp/llama-chat.cpp +54 -2
- package/cpp/llama-chat.h +3 -0
- package/cpp/llama-context.cpp +2294 -1238
- package/cpp/llama-context.h +214 -77
- package/cpp/llama-cparams.h +1 -0
- package/cpp/llama-graph.cpp +1695 -0
- package/cpp/llama-graph.h +592 -0
- package/cpp/llama-hparams.cpp +8 -0
- package/cpp/llama-hparams.h +17 -0
- package/cpp/llama-io.cpp +15 -0
- package/cpp/llama-io.h +35 -0
- package/cpp/llama-kv-cache.cpp +965 -303
- package/cpp/llama-kv-cache.h +145 -151
- package/cpp/llama-memory.cpp +1 -0
- package/cpp/llama-memory.h +21 -0
- package/cpp/llama-mmap.cpp +1 -1
- package/cpp/llama-model-loader.cpp +10 -5
- package/cpp/llama-model-loader.h +5 -3
- package/cpp/llama-model.cpp +9194 -201
- package/cpp/llama-model.h +40 -1
- package/cpp/llama-sampling.cpp +5 -0
- package/cpp/llama-vocab.cpp +36 -5
- package/cpp/llama.cpp +51 -9984
- package/cpp/llama.h +102 -22
- package/cpp/log.cpp +34 -0
- package/cpp/minja/chat-template.hpp +15 -7
- package/cpp/minja/minja.hpp +120 -94
- package/cpp/ops.cpp +8723 -0
- package/cpp/ops.h +128 -0
- package/cpp/rn-llama.cpp +44 -53
- package/cpp/rn-llama.h +2 -12
- package/cpp/sampling.cpp +3 -0
- package/cpp/sgemm.cpp +533 -88
- package/cpp/simd-mappings.h +888 -0
- package/cpp/speculative.cpp +4 -4
- package/cpp/unary-ops.cpp +186 -0
- package/cpp/unary-ops.h +28 -0
- package/cpp/vec.cpp +258 -0
- package/cpp/vec.h +802 -0
- package/ios/CMakeLists.txt +5 -2
- package/ios/RNLlama.mm +2 -2
- package/ios/RNLlamaContext.mm +40 -24
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +6 -4
- package/src/index.ts +3 -1
- package/cpp/chat-template.hpp +0 -529
- package/cpp/minja.hpp +0 -2915
package/cpp/ggml-metal.m
CHANGED
@@ -184,10 +184,13 @@ enum lm_ggml_metal_kernel_type {
|
|
184
184
|
LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
185
185
|
LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
186
186
|
LM_GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
187
|
+
LM_GGML_METAL_KERNEL_TYPE_L2_NORM,
|
187
188
|
LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
188
189
|
LM_GGML_METAL_KERNEL_TYPE_NORM,
|
189
190
|
LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
190
191
|
LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
192
|
+
LM_GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
|
193
|
+
LM_GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
|
191
194
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
192
195
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
193
196
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
@@ -348,42 +351,56 @@ enum lm_ggml_metal_kernel_type {
|
|
348
351
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
349
352
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
350
353
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
354
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
|
355
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
|
351
356
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
352
357
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
|
353
358
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
|
354
359
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
|
355
360
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
|
356
361
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
|
362
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
|
363
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
|
357
364
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
|
358
365
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
359
366
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
360
367
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
361
368
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
|
362
369
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
|
370
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
|
371
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
|
363
372
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
|
364
373
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
|
365
374
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
|
366
375
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
|
367
376
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
|
368
377
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
|
378
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
|
379
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
|
369
380
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
|
370
381
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
|
371
382
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
|
372
383
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
|
373
384
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
|
374
385
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
|
386
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
|
387
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
|
375
388
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
|
376
389
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
|
377
390
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
|
378
391
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
|
379
392
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
|
380
393
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
|
394
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
|
395
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
|
381
396
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
|
382
397
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
|
383
398
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
|
384
399
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
|
385
400
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
|
386
401
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
|
402
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
|
403
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
|
387
404
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
388
405
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
389
406
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
|
@@ -392,6 +409,20 @@ enum lm_ggml_metal_kernel_type {
|
|
392
409
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
|
393
410
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
|
394
411
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
|
412
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192,
|
413
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192,
|
414
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192,
|
415
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192,
|
416
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192,
|
417
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192,
|
418
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192,
|
419
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128,
|
420
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128,
|
421
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128,
|
422
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128,
|
423
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128,
|
424
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128,
|
425
|
+
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128,
|
395
426
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
|
396
427
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
|
397
428
|
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
|
@@ -755,310 +786,341 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
755
786
|
|
756
787
|
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
757
788
|
|
758
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ADD,
|
759
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ADD_ROW,
|
760
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUB,
|
761
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUB_ROW,
|
762
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL,
|
763
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_ROW,
|
764
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIV,
|
765
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
766
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_F32,
|
767
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_F16,
|
768
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_I32,
|
769
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_I16,
|
770
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SCALE,
|
771
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SCALE_4,
|
772
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CLAMP,
|
773
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_TANH,
|
774
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RELU,
|
775
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIGMOID,
|
776
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU,
|
777
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_4,
|
778
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK,
|
779
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
780
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU,
|
781
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU_4,
|
782
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ELU,
|
783
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
784
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
785
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
786
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
|
787
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
|
788
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
|
789
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
|
790
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
|
791
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
|
792
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
|
793
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
|
794
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
|
795
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
|
796
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
|
797
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
|
798
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
|
799
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
|
800
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
|
801
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
|
802
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
|
803
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
|
804
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
|
805
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
|
806
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
|
807
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
|
808
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
|
809
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
810
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
811
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
812
|
-
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
813
|
-
LM_GGML_METAL_ADD_KERNEL(
|
814
|
-
LM_GGML_METAL_ADD_KERNEL(
|
815
|
-
LM_GGML_METAL_ADD_KERNEL(
|
816
|
-
LM_GGML_METAL_ADD_KERNEL(
|
817
|
-
LM_GGML_METAL_ADD_KERNEL(
|
818
|
-
LM_GGML_METAL_ADD_KERNEL(
|
819
|
-
LM_GGML_METAL_ADD_KERNEL(
|
820
|
-
LM_GGML_METAL_ADD_KERNEL(
|
821
|
-
LM_GGML_METAL_ADD_KERNEL(
|
822
|
-
LM_GGML_METAL_ADD_KERNEL(
|
823
|
-
LM_GGML_METAL_ADD_KERNEL(
|
824
|
-
LM_GGML_METAL_ADD_KERNEL(
|
825
|
-
LM_GGML_METAL_ADD_KERNEL(
|
826
|
-
LM_GGML_METAL_ADD_KERNEL(
|
827
|
-
LM_GGML_METAL_ADD_KERNEL(
|
828
|
-
LM_GGML_METAL_ADD_KERNEL(
|
829
|
-
LM_GGML_METAL_ADD_KERNEL(
|
830
|
-
LM_GGML_METAL_ADD_KERNEL(
|
831
|
-
LM_GGML_METAL_ADD_KERNEL(
|
832
|
-
LM_GGML_METAL_ADD_KERNEL(
|
833
|
-
LM_GGML_METAL_ADD_KERNEL(
|
834
|
-
LM_GGML_METAL_ADD_KERNEL(
|
835
|
-
LM_GGML_METAL_ADD_KERNEL(
|
836
|
-
LM_GGML_METAL_ADD_KERNEL(
|
837
|
-
LM_GGML_METAL_ADD_KERNEL(
|
838
|
-
LM_GGML_METAL_ADD_KERNEL(
|
839
|
-
LM_GGML_METAL_ADD_KERNEL(
|
840
|
-
LM_GGML_METAL_ADD_KERNEL(
|
841
|
-
LM_GGML_METAL_ADD_KERNEL(
|
842
|
-
LM_GGML_METAL_ADD_KERNEL(
|
843
|
-
LM_GGML_METAL_ADD_KERNEL(
|
844
|
-
LM_GGML_METAL_ADD_KERNEL(
|
845
|
-
LM_GGML_METAL_ADD_KERNEL(
|
846
|
-
LM_GGML_METAL_ADD_KERNEL(
|
847
|
-
LM_GGML_METAL_ADD_KERNEL(
|
848
|
-
LM_GGML_METAL_ADD_KERNEL(
|
849
|
-
LM_GGML_METAL_ADD_KERNEL(
|
850
|
-
LM_GGML_METAL_ADD_KERNEL(
|
851
|
-
LM_GGML_METAL_ADD_KERNEL(
|
852
|
-
LM_GGML_METAL_ADD_KERNEL(
|
853
|
-
LM_GGML_METAL_ADD_KERNEL(
|
854
|
-
LM_GGML_METAL_ADD_KERNEL(
|
855
|
-
LM_GGML_METAL_ADD_KERNEL(
|
856
|
-
LM_GGML_METAL_ADD_KERNEL(
|
857
|
-
LM_GGML_METAL_ADD_KERNEL(
|
858
|
-
LM_GGML_METAL_ADD_KERNEL(
|
859
|
-
LM_GGML_METAL_ADD_KERNEL(
|
860
|
-
LM_GGML_METAL_ADD_KERNEL(
|
861
|
-
LM_GGML_METAL_ADD_KERNEL(
|
862
|
-
LM_GGML_METAL_ADD_KERNEL(
|
863
|
-
LM_GGML_METAL_ADD_KERNEL(
|
864
|
-
LM_GGML_METAL_ADD_KERNEL(
|
865
|
-
LM_GGML_METAL_ADD_KERNEL(
|
866
|
-
LM_GGML_METAL_ADD_KERNEL(
|
867
|
-
LM_GGML_METAL_ADD_KERNEL(
|
868
|
-
LM_GGML_METAL_ADD_KERNEL(
|
869
|
-
LM_GGML_METAL_ADD_KERNEL(
|
870
|
-
LM_GGML_METAL_ADD_KERNEL(
|
871
|
-
LM_GGML_METAL_ADD_KERNEL(
|
872
|
-
LM_GGML_METAL_ADD_KERNEL(
|
873
|
-
LM_GGML_METAL_ADD_KERNEL(
|
874
|
-
LM_GGML_METAL_ADD_KERNEL(
|
875
|
-
LM_GGML_METAL_ADD_KERNEL(
|
876
|
-
LM_GGML_METAL_ADD_KERNEL(
|
877
|
-
LM_GGML_METAL_ADD_KERNEL(
|
878
|
-
LM_GGML_METAL_ADD_KERNEL(
|
879
|
-
LM_GGML_METAL_ADD_KERNEL(
|
880
|
-
LM_GGML_METAL_ADD_KERNEL(
|
881
|
-
LM_GGML_METAL_ADD_KERNEL(
|
882
|
-
LM_GGML_METAL_ADD_KERNEL(
|
883
|
-
LM_GGML_METAL_ADD_KERNEL(
|
884
|
-
LM_GGML_METAL_ADD_KERNEL(
|
885
|
-
LM_GGML_METAL_ADD_KERNEL(
|
886
|
-
LM_GGML_METAL_ADD_KERNEL(
|
887
|
-
|
888
|
-
|
889
|
-
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
LM_GGML_METAL_ADD_KERNEL(
|
894
|
-
LM_GGML_METAL_ADD_KERNEL(
|
895
|
-
LM_GGML_METAL_ADD_KERNEL(
|
896
|
-
LM_GGML_METAL_ADD_KERNEL(
|
897
|
-
LM_GGML_METAL_ADD_KERNEL(
|
898
|
-
LM_GGML_METAL_ADD_KERNEL(
|
899
|
-
LM_GGML_METAL_ADD_KERNEL(
|
900
|
-
LM_GGML_METAL_ADD_KERNEL(
|
901
|
-
LM_GGML_METAL_ADD_KERNEL(
|
902
|
-
LM_GGML_METAL_ADD_KERNEL(
|
903
|
-
LM_GGML_METAL_ADD_KERNEL(
|
904
|
-
LM_GGML_METAL_ADD_KERNEL(
|
905
|
-
LM_GGML_METAL_ADD_KERNEL(
|
906
|
-
LM_GGML_METAL_ADD_KERNEL(
|
907
|
-
LM_GGML_METAL_ADD_KERNEL(
|
908
|
-
LM_GGML_METAL_ADD_KERNEL(
|
909
|
-
LM_GGML_METAL_ADD_KERNEL(
|
910
|
-
LM_GGML_METAL_ADD_KERNEL(
|
911
|
-
LM_GGML_METAL_ADD_KERNEL(
|
912
|
-
LM_GGML_METAL_ADD_KERNEL(
|
913
|
-
LM_GGML_METAL_ADD_KERNEL(
|
914
|
-
LM_GGML_METAL_ADD_KERNEL(
|
915
|
-
LM_GGML_METAL_ADD_KERNEL(
|
916
|
-
LM_GGML_METAL_ADD_KERNEL(
|
917
|
-
LM_GGML_METAL_ADD_KERNEL(
|
918
|
-
LM_GGML_METAL_ADD_KERNEL(
|
919
|
-
LM_GGML_METAL_ADD_KERNEL(
|
920
|
-
LM_GGML_METAL_ADD_KERNEL(
|
921
|
-
LM_GGML_METAL_ADD_KERNEL(
|
922
|
-
LM_GGML_METAL_ADD_KERNEL(
|
923
|
-
LM_GGML_METAL_ADD_KERNEL(
|
924
|
-
LM_GGML_METAL_ADD_KERNEL(
|
925
|
-
LM_GGML_METAL_ADD_KERNEL(
|
926
|
-
LM_GGML_METAL_ADD_KERNEL(
|
927
|
-
LM_GGML_METAL_ADD_KERNEL(
|
928
|
-
LM_GGML_METAL_ADD_KERNEL(
|
929
|
-
LM_GGML_METAL_ADD_KERNEL(
|
930
|
-
LM_GGML_METAL_ADD_KERNEL(
|
931
|
-
LM_GGML_METAL_ADD_KERNEL(
|
932
|
-
LM_GGML_METAL_ADD_KERNEL(
|
933
|
-
LM_GGML_METAL_ADD_KERNEL(
|
934
|
-
LM_GGML_METAL_ADD_KERNEL(
|
935
|
-
LM_GGML_METAL_ADD_KERNEL(
|
936
|
-
LM_GGML_METAL_ADD_KERNEL(
|
937
|
-
LM_GGML_METAL_ADD_KERNEL(
|
938
|
-
LM_GGML_METAL_ADD_KERNEL(
|
939
|
-
LM_GGML_METAL_ADD_KERNEL(
|
940
|
-
LM_GGML_METAL_ADD_KERNEL(
|
941
|
-
LM_GGML_METAL_ADD_KERNEL(
|
942
|
-
LM_GGML_METAL_ADD_KERNEL(
|
943
|
-
LM_GGML_METAL_ADD_KERNEL(
|
944
|
-
LM_GGML_METAL_ADD_KERNEL(
|
945
|
-
LM_GGML_METAL_ADD_KERNEL(
|
946
|
-
LM_GGML_METAL_ADD_KERNEL(
|
947
|
-
LM_GGML_METAL_ADD_KERNEL(
|
948
|
-
LM_GGML_METAL_ADD_KERNEL(
|
949
|
-
LM_GGML_METAL_ADD_KERNEL(
|
950
|
-
LM_GGML_METAL_ADD_KERNEL(
|
951
|
-
LM_GGML_METAL_ADD_KERNEL(
|
952
|
-
LM_GGML_METAL_ADD_KERNEL(
|
953
|
-
LM_GGML_METAL_ADD_KERNEL(
|
954
|
-
LM_GGML_METAL_ADD_KERNEL(
|
955
|
-
LM_GGML_METAL_ADD_KERNEL(
|
956
|
-
LM_GGML_METAL_ADD_KERNEL(
|
957
|
-
LM_GGML_METAL_ADD_KERNEL(
|
958
|
-
LM_GGML_METAL_ADD_KERNEL(
|
959
|
-
LM_GGML_METAL_ADD_KERNEL(
|
960
|
-
LM_GGML_METAL_ADD_KERNEL(
|
961
|
-
LM_GGML_METAL_ADD_KERNEL(
|
962
|
-
LM_GGML_METAL_ADD_KERNEL(
|
963
|
-
LM_GGML_METAL_ADD_KERNEL(
|
964
|
-
LM_GGML_METAL_ADD_KERNEL(
|
965
|
-
LM_GGML_METAL_ADD_KERNEL(
|
966
|
-
LM_GGML_METAL_ADD_KERNEL(
|
967
|
-
LM_GGML_METAL_ADD_KERNEL(
|
968
|
-
LM_GGML_METAL_ADD_KERNEL(
|
969
|
-
LM_GGML_METAL_ADD_KERNEL(
|
970
|
-
LM_GGML_METAL_ADD_KERNEL(
|
971
|
-
LM_GGML_METAL_ADD_KERNEL(
|
972
|
-
LM_GGML_METAL_ADD_KERNEL(
|
973
|
-
LM_GGML_METAL_ADD_KERNEL(
|
974
|
-
LM_GGML_METAL_ADD_KERNEL(
|
975
|
-
LM_GGML_METAL_ADD_KERNEL(
|
976
|
-
LM_GGML_METAL_ADD_KERNEL(
|
977
|
-
LM_GGML_METAL_ADD_KERNEL(
|
978
|
-
LM_GGML_METAL_ADD_KERNEL(
|
979
|
-
LM_GGML_METAL_ADD_KERNEL(
|
980
|
-
LM_GGML_METAL_ADD_KERNEL(
|
981
|
-
LM_GGML_METAL_ADD_KERNEL(
|
982
|
-
LM_GGML_METAL_ADD_KERNEL(
|
983
|
-
LM_GGML_METAL_ADD_KERNEL(
|
984
|
-
LM_GGML_METAL_ADD_KERNEL(
|
985
|
-
LM_GGML_METAL_ADD_KERNEL(
|
986
|
-
LM_GGML_METAL_ADD_KERNEL(
|
987
|
-
LM_GGML_METAL_ADD_KERNEL(
|
988
|
-
LM_GGML_METAL_ADD_KERNEL(
|
989
|
-
LM_GGML_METAL_ADD_KERNEL(
|
990
|
-
LM_GGML_METAL_ADD_KERNEL(
|
991
|
-
LM_GGML_METAL_ADD_KERNEL(
|
992
|
-
LM_GGML_METAL_ADD_KERNEL(
|
993
|
-
LM_GGML_METAL_ADD_KERNEL(
|
994
|
-
LM_GGML_METAL_ADD_KERNEL(
|
995
|
-
LM_GGML_METAL_ADD_KERNEL(
|
996
|
-
LM_GGML_METAL_ADD_KERNEL(
|
997
|
-
LM_GGML_METAL_ADD_KERNEL(
|
998
|
-
LM_GGML_METAL_ADD_KERNEL(
|
999
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1000
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1001
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1002
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1003
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1004
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1005
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1006
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1007
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1008
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1009
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1010
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1011
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1012
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1013
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1014
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1015
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1016
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1017
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1018
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1019
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1020
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1021
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1022
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1023
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1024
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1025
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1026
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1027
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1028
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1029
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1030
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1031
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1032
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1033
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1034
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1035
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1036
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1037
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1038
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1039
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1040
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1041
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1042
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1043
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1044
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1045
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1046
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1047
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1048
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1049
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1050
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1051
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1052
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1053
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1054
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1055
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1056
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1057
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1058
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1059
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1060
|
-
LM_GGML_METAL_ADD_KERNEL(
|
1061
|
-
LM_GGML_METAL_ADD_KERNEL(
|
789
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
790
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
|
791
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUB, sub, true);
|
792
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
|
793
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
794
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
795
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
796
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
797
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
|
798
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
|
799
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
|
800
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
|
801
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
802
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
803
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
804
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
805
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
806
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
|
807
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
808
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
|
809
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
|
810
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
811
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
812
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
813
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ELU, elu, true);
|
814
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
|
815
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
|
816
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
|
817
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction);
|
818
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
819
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
|
820
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
|
821
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
|
822
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat);
|
823
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
|
824
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
|
825
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
|
826
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
|
827
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
|
828
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
|
829
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
|
830
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
|
831
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
|
832
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
|
833
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
834
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
835
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
|
836
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
|
837
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
|
838
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
|
839
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
|
840
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
841
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
842
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
843
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
844
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
|
845
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
846
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
847
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
848
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
849
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
|
850
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
|
851
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
852
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
|
853
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
|
854
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
|
855
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
|
856
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
|
857
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
|
858
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
|
859
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
|
860
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction);
|
861
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction);
|
862
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
|
863
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
|
864
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
|
865
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
|
866
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
|
867
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
|
868
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
|
869
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
|
870
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
|
871
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
|
872
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
|
873
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
|
874
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
|
875
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
|
876
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
|
877
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
|
878
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
|
879
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
|
880
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
|
881
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
|
882
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
|
883
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
|
884
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
|
885
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
|
886
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
|
887
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
|
888
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
|
889
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
|
890
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
|
891
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
|
892
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction);
|
893
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction);
|
894
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction);
|
895
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction);
|
896
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction);
|
897
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction);
|
898
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction);
|
899
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction);
|
900
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction);
|
901
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction);
|
902
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction);
|
903
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction);
|
904
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction);
|
905
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
|
906
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
|
907
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
|
908
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction);
|
909
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction);
|
910
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction);
|
911
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction);
|
912
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction);
|
913
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction);
|
914
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction);
|
915
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction);
|
916
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction);
|
917
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction);
|
918
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction);
|
919
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction);
|
920
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction);
|
921
|
+
//LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction);
|
922
|
+
//LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction);
|
923
|
+
//LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction);
|
924
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat);
|
925
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction);
|
926
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction);
|
927
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
|
928
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
|
929
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
|
930
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
|
931
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
|
932
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
|
933
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction);
|
934
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction);
|
935
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction);
|
936
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction);
|
937
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction);
|
938
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction);
|
939
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction);
|
940
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction);
|
941
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction);
|
942
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction);
|
943
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
|
944
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
|
945
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
|
946
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat);
|
947
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm);
|
948
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm);
|
949
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
|
950
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
|
951
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
|
952
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
|
953
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
|
954
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
|
955
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm);
|
956
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm);
|
957
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm);
|
958
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm);
|
959
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm);
|
960
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm);
|
961
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm);
|
962
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm);
|
963
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
|
964
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
|
965
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
|
966
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
|
967
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
|
968
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat);
|
969
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
|
970
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
|
971
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
|
972
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm);
|
973
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm);
|
974
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm);
|
975
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm);
|
976
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm);
|
977
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm);
|
978
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm);
|
979
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm);
|
980
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm);
|
981
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm);
|
982
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm);
|
983
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm);
|
984
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm);
|
985
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm);
|
986
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm);
|
987
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm);
|
988
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
989
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
990
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
991
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
|
992
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
993
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
994
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
|
995
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
|
996
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true);
|
997
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
|
998
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
999
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
1000
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true);
|
1001
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
1002
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
|
1003
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
1004
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
1005
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
1006
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
|
1007
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
|
1008
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
|
1009
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
|
1010
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
|
1011
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm);
|
1012
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
|
1013
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
|
1014
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
|
1015
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
|
1016
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
|
1017
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat);
|
1018
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat);
|
1019
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat);
|
1020
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
|
1021
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
|
1022
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
|
1023
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
|
1024
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
|
1025
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm);
|
1026
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm);
|
1027
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm);
|
1028
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
|
1029
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
|
1030
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
|
1031
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
|
1032
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
|
1033
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm);
|
1034
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm);
|
1035
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm);
|
1036
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
|
1037
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
|
1038
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
|
1039
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
|
1040
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
|
1041
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm);
|
1042
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm);
|
1043
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm);
|
1044
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
|
1045
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
|
1046
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
|
1047
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
|
1048
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
|
1049
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm);
|
1050
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm);
|
1051
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm);
|
1052
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
|
1053
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
|
1054
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
|
1055
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
|
1056
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
|
1057
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm);
|
1058
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
|
1059
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm);
|
1060
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
|
1061
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
|
1062
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
|
1063
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
|
1064
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
|
1065
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
|
1066
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
|
1067
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
|
1068
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
|
1069
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, flash_attn_ext_vec_f16_h192, has_simdgroup_reduction);
|
1070
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, flash_attn_ext_vec_bf16_h192, has_simdgroup_reduction && use_bfloat);
|
1071
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, flash_attn_ext_vec_q4_0_h192, has_simdgroup_reduction);
|
1072
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, flash_attn_ext_vec_q4_1_h192, has_simdgroup_reduction);
|
1073
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, flash_attn_ext_vec_q5_0_h192, has_simdgroup_reduction);
|
1074
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, flash_attn_ext_vec_q5_1_h192, has_simdgroup_reduction);
|
1075
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, flash_attn_ext_vec_q8_0_h192, has_simdgroup_reduction);
|
1076
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, flash_attn_ext_vec_f16_hk192_hv128, has_simdgroup_reduction);
|
1077
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, flash_attn_ext_vec_bf16_hk192_hv128, has_simdgroup_reduction && use_bfloat);
|
1078
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, flash_attn_ext_vec_q4_0_hk192_hv128, has_simdgroup_reduction);
|
1079
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, flash_attn_ext_vec_q4_1_hk192_hv128, has_simdgroup_reduction);
|
1080
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, flash_attn_ext_vec_q5_0_hk192_hv128, has_simdgroup_reduction);
|
1081
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, flash_attn_ext_vec_q5_1_hk192_hv128, has_simdgroup_reduction);
|
1082
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, flash_attn_ext_vec_q8_0_hk192_hv128, has_simdgroup_reduction);
|
1083
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
|
1084
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat);
|
1085
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
|
1086
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
|
1087
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
|
1088
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
|
1089
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
|
1090
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
|
1091
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
|
1092
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
1093
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
1094
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
|
1095
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
1096
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
1097
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
|
1098
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
|
1099
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
1100
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
1101
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
1102
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
1103
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
1104
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
|
1105
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true);
|
1106
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true);
|
1107
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true);
|
1108
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true);
|
1109
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true);
|
1110
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true);
|
1111
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true);
|
1112
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true);
|
1113
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true);
|
1114
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true);
|
1115
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
1116
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
1117
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
|
1118
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
1119
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
1120
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
1121
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
1122
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
1123
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
1062
1124
|
}
|
1063
1125
|
|
1064
1126
|
return ctx;
|
@@ -1251,6 +1313,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
|
|
1251
1313
|
case LM_GGML_OP_GROUP_NORM:
|
1252
1314
|
return has_simdgroup_reduction && lm_ggml_is_contiguous(op->src[0]);
|
1253
1315
|
case LM_GGML_OP_RMS_NORM:
|
1316
|
+
case LM_GGML_OP_L2_NORM:
|
1254
1317
|
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && lm_ggml_is_contiguous_1(op->src[0]));
|
1255
1318
|
case LM_GGML_OP_ARGMAX:
|
1256
1319
|
return true;
|
@@ -1282,12 +1345,19 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
|
|
1282
1345
|
case LM_GGML_OP_ARANGE:
|
1283
1346
|
return true;
|
1284
1347
|
case LM_GGML_OP_FLASH_ATTN_EXT:
|
1348
|
+
if (op->src[0]->ne[0] == 32) {
|
1349
|
+
// head size == 32 (e.g. bert-bge-small)
|
1350
|
+
// TODO: not sure if it is worth adding kernels for this size
|
1351
|
+
return false;
|
1352
|
+
}
|
1285
1353
|
if (op->src[1]->type != op->src[2]->type) {
|
1286
1354
|
return false;
|
1287
1355
|
}
|
1288
1356
|
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
1289
1357
|
case LM_GGML_OP_SSM_CONV:
|
1290
1358
|
case LM_GGML_OP_SSM_SCAN:
|
1359
|
+
case LM_GGML_OP_RWKV_WKV6:
|
1360
|
+
case LM_GGML_OP_RWKV_WKV7:
|
1291
1361
|
return true;
|
1292
1362
|
case LM_GGML_OP_MUL_MAT:
|
1293
1363
|
case LM_GGML_OP_MUL_MAT_ID:
|
@@ -2216,6 +2286,83 @@ static void lm_ggml_metal_encode_node(
|
|
2216
2286
|
|
2217
2287
|
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
2218
2288
|
} break;
|
2289
|
+
case LM_GGML_OP_RWKV_WKV6:
|
2290
|
+
{
|
2291
|
+
const int64_t B = dst->src[5]->ne[1];
|
2292
|
+
const int64_t T = dst->src[0]->ne[2];
|
2293
|
+
const int64_t C = dst->ne[0];
|
2294
|
+
const int64_t H = dst->src[0]->ne[1];
|
2295
|
+
|
2296
|
+
LM_GGML_ASSERT(dst->src[5]->type == LM_GGML_TYPE_F32);
|
2297
|
+
LM_GGML_ASSERT(C % H == 0);
|
2298
|
+
LM_GGML_ASSERT(C / H == 64);
|
2299
|
+
|
2300
|
+
size_t offs_src3 = 0;
|
2301
|
+
size_t offs_src4 = 0;
|
2302
|
+
size_t offs_src5 = 0;
|
2303
|
+
|
2304
|
+
id<MTLBuffer> id_src3 = dst->src[3] ? lm_ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
|
2305
|
+
id<MTLBuffer> id_src4 = dst->src[4] ? lm_ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
|
2306
|
+
id<MTLBuffer> id_src5 = dst->src[5] ? lm_ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
|
2307
|
+
|
2308
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
|
2309
|
+
|
2310
|
+
[encoder setComputePipelineState:pipeline];
|
2311
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2312
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
2313
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
2314
|
+
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
2315
|
+
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
2316
|
+
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
2317
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
2318
|
+
|
2319
|
+
[encoder setBytes:&B length:sizeof(B) atIndex:7];
|
2320
|
+
[encoder setBytes:&T length:sizeof(T) atIndex:8];
|
2321
|
+
[encoder setBytes:&C length:sizeof(C) atIndex:9];
|
2322
|
+
[encoder setBytes:&H length:sizeof(H) atIndex:10];
|
2323
|
+
|
2324
|
+
[encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
|
2325
|
+
} break;
|
2326
|
+
case LM_GGML_OP_RWKV_WKV7:
|
2327
|
+
{
|
2328
|
+
const int64_t B = dst->src[6]->ne[1];
|
2329
|
+
const int64_t T = dst->src[0]->ne[2];
|
2330
|
+
const int64_t C = dst->ne[0];
|
2331
|
+
const int64_t H = dst->src[0]->ne[1];
|
2332
|
+
|
2333
|
+
LM_GGML_ASSERT(dst->src[6]->type == LM_GGML_TYPE_F32);
|
2334
|
+
LM_GGML_ASSERT(C % H == 0);
|
2335
|
+
LM_GGML_ASSERT(C / H == 64);
|
2336
|
+
|
2337
|
+
size_t offs_src3 = 0;
|
2338
|
+
size_t offs_src4 = 0;
|
2339
|
+
size_t offs_src5 = 0;
|
2340
|
+
size_t offs_src6 = 0;
|
2341
|
+
|
2342
|
+
id<MTLBuffer> id_src3 = dst->src[3] ? lm_ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
|
2343
|
+
id<MTLBuffer> id_src4 = dst->src[4] ? lm_ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
|
2344
|
+
id<MTLBuffer> id_src5 = dst->src[5] ? lm_ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
|
2345
|
+
id<MTLBuffer> id_src6 = dst->src[6] ? lm_ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil;
|
2346
|
+
|
2347
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline;
|
2348
|
+
|
2349
|
+
[encoder setComputePipelineState:pipeline];
|
2350
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2351
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
2352
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
2353
|
+
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
2354
|
+
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
2355
|
+
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
2356
|
+
[encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
|
2357
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];
|
2358
|
+
|
2359
|
+
[encoder setBytes:&B length:sizeof(B) atIndex:8];
|
2360
|
+
[encoder setBytes:&T length:sizeof(T) atIndex:9];
|
2361
|
+
[encoder setBytes:&C length:sizeof(C) atIndex:10];
|
2362
|
+
[encoder setBytes:&H length:sizeof(H) atIndex:11];
|
2363
|
+
|
2364
|
+
[encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
|
2365
|
+
} break;
|
2219
2366
|
case LM_GGML_OP_MUL_MAT:
|
2220
2367
|
{
|
2221
2368
|
LM_GGML_ASSERT(ne00 == ne10);
|
@@ -2475,171 +2622,180 @@ static void lm_ggml_metal_encode_node(
|
|
2475
2622
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
2476
2623
|
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
2477
2624
|
} else {
|
2478
|
-
int nth0 = 32;
|
2479
|
-
int nth1 = 1;
|
2480
|
-
int nrows = 1;
|
2481
|
-
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
2482
|
-
|
2483
2625
|
id<MTLComputePipelineState> pipeline = nil;
|
2484
2626
|
|
2627
|
+
int nsg = 0; // number of simdgroups
|
2628
|
+
int nr0 = 0; // number of src0 rows per simdgroup
|
2629
|
+
int nr1 = 1; // number of src1 rows per threadgroup
|
2630
|
+
|
2631
|
+
size_t smem = 0; // shared memory
|
2632
|
+
|
2485
2633
|
// use custom matrix x vector kernel
|
2486
2634
|
switch (src0t) {
|
2487
2635
|
case LM_GGML_TYPE_F32:
|
2488
2636
|
{
|
2489
2637
|
LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
|
2638
|
+
nsg = 1;
|
2639
|
+
nr0 = 1;
|
2640
|
+
nr1 = 4;
|
2490
2641
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
|
2491
|
-
nrows = 4;
|
2492
2642
|
} break;
|
2493
2643
|
case LM_GGML_TYPE_F16:
|
2494
2644
|
{
|
2495
|
-
|
2496
|
-
|
2645
|
+
nsg = 1;
|
2646
|
+
nr0 = 1;
|
2497
2647
|
if (src1t == LM_GGML_TYPE_F32) {
|
2498
2648
|
if (ne11 * ne12 < 4) {
|
2499
2649
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
2500
2650
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
2501
2651
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
2502
|
-
|
2652
|
+
nr1 = ne11;
|
2503
2653
|
} else {
|
2504
2654
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
|
2505
|
-
|
2655
|
+
nr1 = 4;
|
2506
2656
|
}
|
2507
2657
|
} else {
|
2508
2658
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
|
2509
|
-
|
2659
|
+
nr1 = 4;
|
2510
2660
|
}
|
2511
2661
|
} break;
|
2512
2662
|
case LM_GGML_TYPE_BF16:
|
2513
2663
|
{
|
2514
|
-
|
2515
|
-
|
2664
|
+
nsg = 1;
|
2665
|
+
nr0 = 1;
|
2516
2666
|
if (src1t == LM_GGML_TYPE_F32) {
|
2517
2667
|
if (ne11 * ne12 < 4) {
|
2518
2668
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
|
2519
2669
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
2520
2670
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
|
2521
|
-
|
2671
|
+
nr1 = ne11;
|
2522
2672
|
} else {
|
2523
2673
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
|
2524
|
-
|
2674
|
+
nr1 = 4;
|
2525
2675
|
}
|
2526
2676
|
} else {
|
2527
2677
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
|
2528
|
-
|
2678
|
+
nr1 = 4;
|
2529
2679
|
}
|
2530
2680
|
} break;
|
2531
2681
|
case LM_GGML_TYPE_Q4_0:
|
2532
2682
|
{
|
2533
|
-
|
2534
|
-
|
2683
|
+
nsg = N_SG_Q4_0;
|
2684
|
+
nr0 = N_R0_Q4_0;
|
2535
2685
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
|
2536
2686
|
} break;
|
2537
2687
|
case LM_GGML_TYPE_Q4_1:
|
2538
2688
|
{
|
2539
|
-
|
2540
|
-
|
2689
|
+
nsg = N_SG_Q4_1;
|
2690
|
+
nr0 = N_R0_Q4_1;
|
2541
2691
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
|
2542
2692
|
} break;
|
2543
2693
|
case LM_GGML_TYPE_Q5_0:
|
2544
2694
|
{
|
2545
|
-
|
2546
|
-
|
2695
|
+
nsg = N_SG_Q5_0;
|
2696
|
+
nr0 = N_R0_Q5_0;
|
2547
2697
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
|
2548
2698
|
} break;
|
2549
2699
|
case LM_GGML_TYPE_Q5_1:
|
2550
2700
|
{
|
2551
|
-
|
2552
|
-
|
2701
|
+
nsg = N_SG_Q5_1;
|
2702
|
+
nr0 = N_R0_Q5_1;
|
2553
2703
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
|
2554
2704
|
} break;
|
2555
2705
|
case LM_GGML_TYPE_Q8_0:
|
2556
2706
|
{
|
2557
|
-
|
2558
|
-
|
2707
|
+
nsg = N_SG_Q8_0;
|
2708
|
+
nr0 = N_R0_Q8_0;
|
2559
2709
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
|
2560
2710
|
} break;
|
2561
2711
|
case LM_GGML_TYPE_Q2_K:
|
2562
2712
|
{
|
2563
|
-
|
2564
|
-
|
2713
|
+
nsg = N_SG_Q2_K;
|
2714
|
+
nr0 = N_R0_Q2_K;
|
2565
2715
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
|
2566
2716
|
} break;
|
2567
2717
|
case LM_GGML_TYPE_Q3_K:
|
2568
2718
|
{
|
2569
|
-
|
2570
|
-
|
2719
|
+
nsg = N_SG_Q3_K;
|
2720
|
+
nr0 = N_R0_Q3_K;
|
2571
2721
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
|
2572
2722
|
} break;
|
2573
2723
|
case LM_GGML_TYPE_Q4_K:
|
2574
2724
|
{
|
2575
|
-
|
2576
|
-
|
2725
|
+
nsg = N_SG_Q4_K;
|
2726
|
+
nr0 = N_R0_Q4_K;
|
2577
2727
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
|
2578
2728
|
} break;
|
2579
2729
|
case LM_GGML_TYPE_Q5_K:
|
2580
2730
|
{
|
2581
|
-
|
2582
|
-
|
2731
|
+
nsg = N_SG_Q5_K;
|
2732
|
+
nr0 = N_R0_Q5_K;
|
2583
2733
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
|
2584
2734
|
} break;
|
2585
2735
|
case LM_GGML_TYPE_Q6_K:
|
2586
2736
|
{
|
2587
|
-
|
2588
|
-
|
2737
|
+
nsg = N_SG_Q6_K;
|
2738
|
+
nr0 = N_R0_Q6_K;
|
2589
2739
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
|
2590
2740
|
} break;
|
2591
2741
|
case LM_GGML_TYPE_IQ2_XXS:
|
2592
2742
|
{
|
2593
|
-
|
2594
|
-
|
2743
|
+
nsg = N_SG_IQ2_XXS;
|
2744
|
+
nr0 = N_R0_IQ2_XXS;
|
2745
|
+
smem = 256*8+128;
|
2595
2746
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
|
2596
2747
|
} break;
|
2597
2748
|
case LM_GGML_TYPE_IQ2_XS:
|
2598
2749
|
{
|
2599
|
-
|
2600
|
-
|
2750
|
+
nsg = N_SG_IQ2_XS;
|
2751
|
+
nr0 = N_R0_IQ2_XS;
|
2752
|
+
smem = 512*8+128;
|
2601
2753
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
|
2602
2754
|
} break;
|
2603
2755
|
case LM_GGML_TYPE_IQ3_XXS:
|
2604
2756
|
{
|
2605
|
-
|
2606
|
-
|
2757
|
+
nsg = N_SG_IQ3_XXS;
|
2758
|
+
nr0 = N_R0_IQ3_XXS;
|
2759
|
+
smem = 256*4+128;
|
2607
2760
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
2608
2761
|
} break;
|
2609
2762
|
case LM_GGML_TYPE_IQ3_S:
|
2610
2763
|
{
|
2611
|
-
|
2612
|
-
|
2764
|
+
nsg = N_SG_IQ3_S;
|
2765
|
+
nr0 = N_R0_IQ3_S;
|
2766
|
+
smem = 512*4;
|
2613
2767
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
|
2614
2768
|
} break;
|
2615
2769
|
case LM_GGML_TYPE_IQ2_S:
|
2616
2770
|
{
|
2617
|
-
|
2618
|
-
|
2771
|
+
nsg = N_SG_IQ2_S;
|
2772
|
+
nr0 = N_R0_IQ2_S;
|
2619
2773
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
|
2620
2774
|
} break;
|
2621
2775
|
case LM_GGML_TYPE_IQ1_S:
|
2622
2776
|
{
|
2623
|
-
|
2624
|
-
|
2777
|
+
nsg = N_SG_IQ1_S;
|
2778
|
+
nr0 = N_R0_IQ1_S;
|
2625
2779
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
|
2626
2780
|
} break;
|
2627
2781
|
case LM_GGML_TYPE_IQ1_M:
|
2628
2782
|
{
|
2629
|
-
|
2630
|
-
|
2783
|
+
nsg = N_SG_IQ1_M;
|
2784
|
+
nr0 = N_R0_IQ1_M;
|
2631
2785
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
|
2632
2786
|
} break;
|
2633
2787
|
case LM_GGML_TYPE_IQ4_NL:
|
2634
2788
|
{
|
2635
|
-
|
2636
|
-
|
2789
|
+
nsg = N_SG_IQ4_NL;
|
2790
|
+
nr0 = N_R0_IQ4_NL;
|
2791
|
+
smem = 32*sizeof(float);
|
2637
2792
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
|
2638
2793
|
} break;
|
2639
2794
|
case LM_GGML_TYPE_IQ4_XS:
|
2640
2795
|
{
|
2641
|
-
|
2642
|
-
|
2796
|
+
nsg = N_SG_IQ4_XS;
|
2797
|
+
nr0 = N_R0_IQ4_XS;
|
2798
|
+
smem = 32*sizeof(float);
|
2643
2799
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
|
2644
2800
|
} break;
|
2645
2801
|
default:
|
@@ -2676,41 +2832,10 @@ static void lm_ggml_metal_encode_node(
|
|
2676
2832
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
2677
2833
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
2678
2834
|
|
2679
|
-
if (
|
2680
|
-
|
2681
|
-
src0t == LM_GGML_TYPE_IQ1_S || src0t == LM_GGML_TYPE_IQ1_M || src0t == LM_GGML_TYPE_IQ2_S) {
|
2682
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2683
|
-
}
|
2684
|
-
else if (src0t == LM_GGML_TYPE_IQ2_XXS || src0t == LM_GGML_TYPE_IQ2_XS) {
|
2685
|
-
const int mem_size = src0t == LM_GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
2686
|
-
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
2687
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2688
|
-
}
|
2689
|
-
else if (src0t == LM_GGML_TYPE_IQ3_XXS || src0t == LM_GGML_TYPE_IQ3_S) {
|
2690
|
-
const int mem_size = src0t == LM_GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
2691
|
-
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
2692
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2693
|
-
}
|
2694
|
-
else if (src0t == LM_GGML_TYPE_IQ4_NL || src0t == LM_GGML_TYPE_IQ4_XS) {
|
2695
|
-
const int mem_size = 32*sizeof(float);
|
2696
|
-
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
2697
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2698
|
-
}
|
2699
|
-
else if (src0t == LM_GGML_TYPE_Q4_K) {
|
2700
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2701
|
-
}
|
2702
|
-
else if (src0t == LM_GGML_TYPE_Q3_K) {
|
2703
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2704
|
-
}
|
2705
|
-
else if (src0t == LM_GGML_TYPE_Q5_K) {
|
2706
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2707
|
-
}
|
2708
|
-
else if (src0t == LM_GGML_TYPE_Q6_K) {
|
2709
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2710
|
-
} else {
|
2711
|
-
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
2712
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2835
|
+
if (smem > 0) {
|
2836
|
+
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
2713
2837
|
}
|
2838
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
2714
2839
|
}
|
2715
2840
|
} break;
|
2716
2841
|
case LM_GGML_OP_MUL_MAT_ID:
|
@@ -2736,20 +2861,19 @@ static void lm_ggml_metal_encode_node(
|
|
2736
2861
|
// ne21 = n_rows
|
2737
2862
|
const int dst_rows = ne20*ne21;
|
2738
2863
|
const int dst_rows_min = n_as;
|
2739
|
-
const int dst_rows_max = (device.maxThreadgroupMemoryLength -
|
2864
|
+
const int dst_rows_max = (device.maxThreadgroupMemoryLength/2 - 8192)/4;
|
2740
2865
|
|
2741
2866
|
// max size of the rowids array in the kernel shared buffer
|
2742
|
-
LM_GGML_ASSERT(dst_rows <= dst_rows_max);
|
2867
|
+
//LM_GGML_ASSERT(dst_rows <= dst_rows_max);
|
2743
2868
|
|
2744
2869
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
2745
2870
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
2746
|
-
// !!!
|
2747
|
-
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
|
2748
|
-
// indirect matrix multiplication
|
2749
|
-
// !!!
|
2750
2871
|
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
2751
2872
|
ne00 % 32 == 0 && ne00 >= 64 &&
|
2752
|
-
|
2873
|
+
//ne01 / ne02 >= 512 && // NOTE: this is based on Mixtral shapes, might need adjustments
|
2874
|
+
dst_rows > dst_rows_min &&
|
2875
|
+
dst_rows <= dst_rows_max) {
|
2876
|
+
|
2753
2877
|
// some Metal matrix data types require aligned pointers
|
2754
2878
|
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
2755
2879
|
switch (src0->type) {
|
@@ -2816,146 +2940,155 @@ static void lm_ggml_metal_encode_node(
|
|
2816
2940
|
|
2817
2941
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
2818
2942
|
} else {
|
2819
|
-
int nth0 = 32;
|
2820
|
-
int nth1 = 1;
|
2821
|
-
int nrows = 1;
|
2822
|
-
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
2823
|
-
|
2824
2943
|
id<MTLComputePipelineState> pipeline = nil;
|
2825
2944
|
|
2945
|
+
int nsg = 0; // number of simdgroups
|
2946
|
+
int nr0 = 0; // number of src0 rows per simdgroup
|
2947
|
+
int nr1 = 1; // number of src1 rows per threadgroup
|
2948
|
+
|
2949
|
+
size_t smem = 0; // shared memory
|
2950
|
+
|
2826
2951
|
// use custom matrix x vector kernel
|
2827
2952
|
switch (src0t) {
|
2828
2953
|
case LM_GGML_TYPE_F32:
|
2829
2954
|
{
|
2830
2955
|
LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
|
2956
|
+
nsg = 1;
|
2957
|
+
nr0 = 1;
|
2831
2958
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
|
2832
2959
|
} break;
|
2833
2960
|
case LM_GGML_TYPE_F16:
|
2834
2961
|
{
|
2835
2962
|
LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
|
2836
|
-
|
2837
|
-
|
2963
|
+
nsg = 1;
|
2964
|
+
nr0 = 1;
|
2838
2965
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
|
2839
2966
|
} break;
|
2840
2967
|
case LM_GGML_TYPE_BF16:
|
2841
2968
|
{
|
2842
2969
|
LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
|
2843
|
-
|
2844
|
-
|
2970
|
+
nsg = 1;
|
2971
|
+
nr0 = 1;
|
2845
2972
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
|
2846
2973
|
} break;
|
2847
2974
|
case LM_GGML_TYPE_Q4_0:
|
2848
2975
|
{
|
2849
|
-
|
2850
|
-
|
2976
|
+
nsg = N_SG_Q4_0;
|
2977
|
+
nr0 = N_R0_Q4_0;
|
2851
2978
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
|
2852
2979
|
} break;
|
2853
2980
|
case LM_GGML_TYPE_Q4_1:
|
2854
2981
|
{
|
2855
|
-
|
2856
|
-
|
2982
|
+
nsg = N_SG_Q4_1;
|
2983
|
+
nr0 = N_R0_Q4_1;
|
2857
2984
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
|
2858
2985
|
} break;
|
2859
2986
|
case LM_GGML_TYPE_Q5_0:
|
2860
2987
|
{
|
2861
|
-
|
2862
|
-
|
2988
|
+
nsg = N_SG_Q5_0;
|
2989
|
+
nr0 = N_R0_Q5_0;
|
2863
2990
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
|
2864
2991
|
} break;
|
2865
2992
|
case LM_GGML_TYPE_Q5_1:
|
2866
2993
|
{
|
2867
|
-
|
2868
|
-
|
2994
|
+
nsg = N_SG_Q5_1;
|
2995
|
+
nr0 = N_R0_Q5_1;
|
2869
2996
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
|
2870
2997
|
} break;
|
2871
2998
|
case LM_GGML_TYPE_Q8_0:
|
2872
2999
|
{
|
2873
|
-
|
2874
|
-
|
3000
|
+
nsg = N_SG_Q8_0;
|
3001
|
+
nr0 = N_R0_Q8_0;
|
2875
3002
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
|
2876
3003
|
} break;
|
2877
3004
|
case LM_GGML_TYPE_Q2_K:
|
2878
3005
|
{
|
2879
|
-
|
2880
|
-
|
3006
|
+
nsg = N_SG_Q2_K;
|
3007
|
+
nr0 = N_R0_Q2_K;
|
2881
3008
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
|
2882
3009
|
} break;
|
2883
3010
|
case LM_GGML_TYPE_Q3_K:
|
2884
3011
|
{
|
2885
|
-
|
2886
|
-
|
3012
|
+
nsg = N_SG_Q3_K;
|
3013
|
+
nr0 = N_R0_Q3_K;
|
2887
3014
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
|
2888
3015
|
} break;
|
2889
3016
|
case LM_GGML_TYPE_Q4_K:
|
2890
3017
|
{
|
2891
|
-
|
2892
|
-
|
3018
|
+
nsg = N_SG_Q4_K;
|
3019
|
+
nr0 = N_R0_Q4_K;
|
2893
3020
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
|
2894
3021
|
} break;
|
2895
3022
|
case LM_GGML_TYPE_Q5_K:
|
2896
3023
|
{
|
2897
|
-
|
2898
|
-
|
3024
|
+
nsg = N_SG_Q5_K;
|
3025
|
+
nr0 = N_R0_Q5_K;
|
2899
3026
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
|
2900
3027
|
} break;
|
2901
3028
|
case LM_GGML_TYPE_Q6_K:
|
2902
3029
|
{
|
2903
|
-
|
2904
|
-
|
3030
|
+
nsg = N_SG_Q6_K;
|
3031
|
+
nr0 = N_R0_Q6_K;
|
2905
3032
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
|
2906
3033
|
} break;
|
2907
3034
|
case LM_GGML_TYPE_IQ2_XXS:
|
2908
3035
|
{
|
2909
|
-
|
2910
|
-
|
3036
|
+
nsg = N_SG_IQ2_XXS;
|
3037
|
+
nr0 = N_R0_IQ2_XXS;
|
3038
|
+
smem = 256*8+128;
|
2911
3039
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
|
2912
3040
|
} break;
|
2913
3041
|
case LM_GGML_TYPE_IQ2_XS:
|
2914
3042
|
{
|
2915
|
-
|
2916
|
-
|
3043
|
+
nsg = N_SG_IQ2_XS;
|
3044
|
+
nr0 = N_R0_IQ2_XS;
|
3045
|
+
smem = 512*8+128;
|
2917
3046
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
|
2918
3047
|
} break;
|
2919
3048
|
case LM_GGML_TYPE_IQ3_XXS:
|
2920
3049
|
{
|
2921
|
-
|
2922
|
-
|
3050
|
+
nsg = N_SG_IQ3_XXS;
|
3051
|
+
nr0 = N_R0_IQ3_XXS;
|
3052
|
+
smem = 256*4+128;
|
2923
3053
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
|
2924
3054
|
} break;
|
2925
3055
|
case LM_GGML_TYPE_IQ3_S:
|
2926
3056
|
{
|
2927
|
-
|
2928
|
-
|
3057
|
+
nsg = N_SG_IQ3_S;
|
3058
|
+
nr0 = N_R0_IQ3_S;
|
3059
|
+
smem = 512*4;
|
2929
3060
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
|
2930
3061
|
} break;
|
2931
3062
|
case LM_GGML_TYPE_IQ2_S:
|
2932
3063
|
{
|
2933
|
-
|
2934
|
-
|
3064
|
+
nsg = N_SG_IQ2_S;
|
3065
|
+
nr0 = N_R0_IQ2_S;
|
2935
3066
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
|
2936
3067
|
} break;
|
2937
3068
|
case LM_GGML_TYPE_IQ1_S:
|
2938
3069
|
{
|
2939
|
-
|
2940
|
-
|
3070
|
+
nsg = N_SG_IQ1_S;
|
3071
|
+
nr0 = N_R0_IQ1_S;
|
2941
3072
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
|
2942
3073
|
} break;
|
2943
3074
|
case LM_GGML_TYPE_IQ1_M:
|
2944
3075
|
{
|
2945
|
-
|
2946
|
-
|
3076
|
+
nsg = N_SG_IQ1_M;
|
3077
|
+
nr0 = N_R0_IQ1_M;
|
2947
3078
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
|
2948
3079
|
} break;
|
2949
3080
|
case LM_GGML_TYPE_IQ4_NL:
|
2950
3081
|
{
|
2951
|
-
|
2952
|
-
|
3082
|
+
nsg = N_SG_IQ4_NL;
|
3083
|
+
nr0 = N_R0_IQ4_NL;
|
3084
|
+
smem = 32*sizeof(float);
|
2953
3085
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
|
2954
3086
|
} break;
|
2955
3087
|
case LM_GGML_TYPE_IQ4_XS:
|
2956
3088
|
{
|
2957
|
-
|
2958
|
-
|
3089
|
+
nsg = N_SG_IQ4_XS;
|
3090
|
+
nr0 = N_R0_IQ4_XS;
|
3091
|
+
smem = 32*sizeof(float);
|
2959
3092
|
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
2960
3093
|
} break;
|
2961
3094
|
default:
|
@@ -2966,7 +3099,7 @@ static void lm_ggml_metal_encode_node(
|
|
2966
3099
|
};
|
2967
3100
|
|
2968
3101
|
if (lm_ggml_is_quantized(src0t)) {
|
2969
|
-
LM_GGML_ASSERT(ne00 >=
|
3102
|
+
LM_GGML_ASSERT(ne00 >= nsg*nr0);
|
2970
3103
|
}
|
2971
3104
|
|
2972
3105
|
lm_ggml_metal_kargs_mul_mv_id args = {
|
@@ -2999,43 +3132,12 @@ static void lm_ggml_metal_encode_node(
|
|
2999
3132
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
|
3000
3133
|
|
3001
3134
|
const int64_t _ne1 = 1;
|
3002
|
-
const
|
3135
|
+
const int64_t ne123 = dst_rows;
|
3003
3136
|
|
3004
|
-
if (
|
3005
|
-
|
3006
|
-
src0t == LM_GGML_TYPE_IQ1_S || src0t == LM_GGML_TYPE_IQ1_M || src0t == LM_GGML_TYPE_IQ2_S) {
|
3007
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
3008
|
-
}
|
3009
|
-
else if (src0t == LM_GGML_TYPE_IQ2_XXS || src0t == LM_GGML_TYPE_IQ2_XS) {
|
3010
|
-
const int mem_size = src0t == LM_GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
3011
|
-
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
3012
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
3013
|
-
}
|
3014
|
-
else if (src0t == LM_GGML_TYPE_IQ3_XXS || src0t == LM_GGML_TYPE_IQ3_S) {
|
3015
|
-
const int mem_size = src0t == LM_GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
3016
|
-
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
3017
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
3018
|
-
}
|
3019
|
-
else if (src0t == LM_GGML_TYPE_IQ4_NL || src0t == LM_GGML_TYPE_IQ4_XS) {
|
3020
|
-
const int mem_size = 32*sizeof(float);
|
3021
|
-
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
3022
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
3023
|
-
}
|
3024
|
-
else if (src0t == LM_GGML_TYPE_Q4_K) {
|
3025
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
3026
|
-
}
|
3027
|
-
else if (src0t == LM_GGML_TYPE_Q3_K) {
|
3028
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
3029
|
-
}
|
3030
|
-
else if (src0t == LM_GGML_TYPE_Q5_K) {
|
3031
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
3032
|
-
}
|
3033
|
-
else if (src0t == LM_GGML_TYPE_Q6_K) {
|
3034
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
3035
|
-
} else {
|
3036
|
-
const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
|
3037
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
3137
|
+
if (smem > 0) {
|
3138
|
+
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
3038
3139
|
}
|
3140
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
3039
3141
|
}
|
3040
3142
|
} break;
|
3041
3143
|
case LM_GGML_OP_GET_ROWS:
|
@@ -3122,6 +3224,42 @@ static void lm_ggml_metal_encode_node(
|
|
3122
3224
|
|
3123
3225
|
const int64_t nrows = lm_ggml_nrows(src0);
|
3124
3226
|
|
3227
|
+
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
3228
|
+
} break;
|
3229
|
+
case LM_GGML_OP_L2_NORM:
|
3230
|
+
{
|
3231
|
+
LM_GGML_ASSERT(ne00 % 4 == 0);
|
3232
|
+
LM_GGML_ASSERT(lm_ggml_is_contiguous_1(src0));
|
3233
|
+
|
3234
|
+
float eps;
|
3235
|
+
memcpy(&eps, dst->op_params, sizeof(float));
|
3236
|
+
|
3237
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline;
|
3238
|
+
|
3239
|
+
int nth = 32; // SIMD width
|
3240
|
+
|
3241
|
+
while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
3242
|
+
nth *= 2;
|
3243
|
+
}
|
3244
|
+
|
3245
|
+
nth = MIN(nth, ne00/4);
|
3246
|
+
|
3247
|
+
lm_ggml_metal_kargs_l2_norm args = {
|
3248
|
+
/*.ne00 =*/ ne00,
|
3249
|
+
/*.ne00_4 =*/ ne00/4,
|
3250
|
+
/*.nb01 =*/ nb01,
|
3251
|
+
/*.eps =*/ eps,
|
3252
|
+
};
|
3253
|
+
|
3254
|
+
[encoder setComputePipelineState:pipeline];
|
3255
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
3256
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
3257
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
3258
|
+
|
3259
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
3260
|
+
|
3261
|
+
const int64_t nrows = lm_ggml_nrows(src0);
|
3262
|
+
|
3125
3263
|
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
3126
3264
|
} break;
|
3127
3265
|
case LM_GGML_OP_GROUP_NORM:
|
@@ -3654,7 +3792,9 @@ static void lm_ggml_metal_encode_node(
|
|
3654
3792
|
LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
|
3655
3793
|
LM_GGML_ASSERT(src1->type == src2->type);
|
3656
3794
|
|
3657
|
-
LM_GGML_ASSERT(lm_ggml_are_same_shape (src1, src2));
|
3795
|
+
//LM_GGML_ASSERT(lm_ggml_are_same_shape (src1, src2));
|
3796
|
+
LM_GGML_ASSERT(ne11 == ne21);
|
3797
|
+
LM_GGML_ASSERT(ne12 == ne22);
|
3658
3798
|
|
3659
3799
|
struct lm_ggml_tensor * src3 = node->src[3];
|
3660
3800
|
|
@@ -3701,125 +3841,161 @@ static void lm_ggml_metal_encode_node(
|
|
3701
3841
|
|
3702
3842
|
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
|
3703
3843
|
// for now avoiding mainly to keep the number of templates/kernels a bit lower
|
3704
|
-
|
3844
|
+
// these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
|
3845
|
+
if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192)) {
|
3705
3846
|
switch (src1->type) {
|
3706
3847
|
case LM_GGML_TYPE_F16:
|
3707
3848
|
{
|
3708
|
-
|
3709
|
-
|
3710
|
-
|
3711
|
-
|
3712
|
-
|
3713
|
-
|
3714
|
-
|
3715
|
-
|
3716
|
-
|
3717
|
-
|
3718
|
-
|
3719
|
-
|
3720
|
-
|
3849
|
+
if (ne00 == 192 && ne20 == 128) {
|
3850
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
|
3851
|
+
} else {
|
3852
|
+
switch (ne00) {
|
3853
|
+
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
3854
|
+
case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
3855
|
+
case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
3856
|
+
case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
3857
|
+
case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
3858
|
+
case 192: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192].pipeline; break;
|
3859
|
+
case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
3860
|
+
default:
|
3861
|
+
{
|
3862
|
+
LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
3863
|
+
LM_GGML_LOG_ERROR("add template specialization for this size\n");
|
3864
|
+
LM_GGML_ABORT("add template specialization for this size");
|
3865
|
+
}
|
3866
|
+
}
|
3721
3867
|
}
|
3722
3868
|
} break;
|
3723
3869
|
case LM_GGML_TYPE_BF16:
|
3724
3870
|
{
|
3725
|
-
|
3726
|
-
|
3727
|
-
|
3728
|
-
|
3729
|
-
|
3730
|
-
|
3731
|
-
|
3732
|
-
|
3733
|
-
|
3734
|
-
|
3735
|
-
|
3736
|
-
|
3737
|
-
|
3871
|
+
if (ne00 == 192 && ne20 == 128) {
|
3872
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
|
3873
|
+
} else {
|
3874
|
+
switch (ne00) {
|
3875
|
+
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
|
3876
|
+
case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
|
3877
|
+
case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
|
3878
|
+
case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
|
3879
|
+
case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
|
3880
|
+
case 192: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192].pipeline; break;
|
3881
|
+
case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
|
3882
|
+
default:
|
3883
|
+
{
|
3884
|
+
LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
3885
|
+
LM_GGML_LOG_ERROR("add template specialization for this size\n");
|
3886
|
+
LM_GGML_ABORT("add template specialization for this size");
|
3887
|
+
}
|
3888
|
+
}
|
3738
3889
|
}
|
3739
3890
|
} break;
|
3740
3891
|
case LM_GGML_TYPE_Q4_0:
|
3741
3892
|
{
|
3742
|
-
|
3743
|
-
|
3744
|
-
|
3745
|
-
|
3746
|
-
|
3747
|
-
|
3748
|
-
|
3749
|
-
|
3750
|
-
|
3751
|
-
|
3752
|
-
|
3753
|
-
|
3754
|
-
|
3893
|
+
if (ne00 == 192 && ne20 == 128) {
|
3894
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
|
3895
|
+
} else {
|
3896
|
+
switch (ne00) {
|
3897
|
+
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
|
3898
|
+
case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
|
3899
|
+
case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
|
3900
|
+
case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
|
3901
|
+
case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
|
3902
|
+
case 192: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192].pipeline; break;
|
3903
|
+
case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
|
3904
|
+
default:
|
3905
|
+
{
|
3906
|
+
LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
3907
|
+
LM_GGML_LOG_ERROR("add template specialization for this size\n");
|
3908
|
+
LM_GGML_ABORT("add template specialization for this size");
|
3909
|
+
}
|
3910
|
+
}
|
3755
3911
|
}
|
3756
3912
|
} break;
|
3757
3913
|
case LM_GGML_TYPE_Q4_1:
|
3758
3914
|
{
|
3759
|
-
|
3760
|
-
|
3761
|
-
|
3762
|
-
|
3763
|
-
|
3764
|
-
|
3765
|
-
|
3766
|
-
|
3767
|
-
|
3768
|
-
|
3769
|
-
|
3770
|
-
|
3771
|
-
|
3915
|
+
if (ne00 == 192 && ne20 == 128) {
|
3916
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
|
3917
|
+
} else {
|
3918
|
+
switch (ne00) {
|
3919
|
+
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
|
3920
|
+
case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
|
3921
|
+
case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
|
3922
|
+
case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
|
3923
|
+
case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
|
3924
|
+
case 192: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192].pipeline; break;
|
3925
|
+
case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
|
3926
|
+
default:
|
3927
|
+
{
|
3928
|
+
LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
3929
|
+
LM_GGML_LOG_ERROR("add template specialization for this size\n");
|
3930
|
+
LM_GGML_ABORT("add template specialization for this size");
|
3931
|
+
}
|
3932
|
+
}
|
3772
3933
|
}
|
3773
3934
|
} break;
|
3774
3935
|
case LM_GGML_TYPE_Q5_0:
|
3775
3936
|
{
|
3776
|
-
|
3777
|
-
|
3778
|
-
|
3779
|
-
|
3780
|
-
|
3781
|
-
|
3782
|
-
|
3783
|
-
|
3784
|
-
|
3785
|
-
|
3786
|
-
|
3787
|
-
|
3788
|
-
|
3937
|
+
if (ne00 == 192 && ne20 == 128) {
|
3938
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
|
3939
|
+
} else {
|
3940
|
+
switch (ne00) {
|
3941
|
+
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
|
3942
|
+
case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
|
3943
|
+
case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
|
3944
|
+
case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
|
3945
|
+
case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
|
3946
|
+
case 192: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192].pipeline; break;
|
3947
|
+
case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
|
3948
|
+
default:
|
3949
|
+
{
|
3950
|
+
LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
3951
|
+
LM_GGML_LOG_ERROR("add template specialization for this size\n");
|
3952
|
+
LM_GGML_ABORT("add template specialization for this size");
|
3953
|
+
}
|
3954
|
+
}
|
3789
3955
|
}
|
3790
3956
|
} break;
|
3791
3957
|
case LM_GGML_TYPE_Q5_1:
|
3792
3958
|
{
|
3793
|
-
|
3794
|
-
|
3795
|
-
|
3796
|
-
|
3797
|
-
|
3798
|
-
|
3799
|
-
|
3800
|
-
|
3801
|
-
|
3802
|
-
|
3803
|
-
|
3804
|
-
|
3805
|
-
|
3959
|
+
if (ne00 == 192 && ne20 == 128) {
|
3960
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
|
3961
|
+
} else {
|
3962
|
+
switch (ne00) {
|
3963
|
+
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
|
3964
|
+
case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
|
3965
|
+
case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
|
3966
|
+
case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
|
3967
|
+
case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
|
3968
|
+
case 192: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192].pipeline; break;
|
3969
|
+
case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
|
3970
|
+
default:
|
3971
|
+
{
|
3972
|
+
LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
3973
|
+
LM_GGML_LOG_ERROR("add template specialization for this size\n");
|
3974
|
+
LM_GGML_ABORT("add template specialization for this size");
|
3975
|
+
}
|
3976
|
+
}
|
3806
3977
|
}
|
3807
3978
|
} break;
|
3808
3979
|
case LM_GGML_TYPE_Q8_0:
|
3809
3980
|
{
|
3810
|
-
|
3811
|
-
|
3812
|
-
|
3813
|
-
|
3814
|
-
|
3815
|
-
|
3816
|
-
|
3817
|
-
|
3818
|
-
|
3819
|
-
|
3820
|
-
|
3821
|
-
|
3822
|
-
|
3981
|
+
if (ne00 == 192 && ne20 == 128) {
|
3982
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
|
3983
|
+
} else {
|
3984
|
+
switch (ne00) {
|
3985
|
+
case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
|
3986
|
+
case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
|
3987
|
+
case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
|
3988
|
+
case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
|
3989
|
+
case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
|
3990
|
+
case 192: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192].pipeline; break;
|
3991
|
+
case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
|
3992
|
+
default:
|
3993
|
+
{
|
3994
|
+
LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
3995
|
+
LM_GGML_LOG_ERROR("add template specialization for this size\n");
|
3996
|
+
LM_GGML_ABORT("add template specialization for this size");
|
3997
|
+
}
|
3998
|
+
}
|
3823
3999
|
}
|
3824
4000
|
} break;
|
3825
4001
|
default:
|
@@ -3851,6 +4027,42 @@ static void lm_ggml_metal_encode_node(
|
|
3851
4027
|
}
|
3852
4028
|
}
|
3853
4029
|
} break;
|
4030
|
+
case 192:
|
4031
|
+
{
|
4032
|
+
if (ne20 == 128) {
|
4033
|
+
switch (src1->type) {
|
4034
|
+
case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline; break;
|
4035
|
+
case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128].pipeline; break;
|
4036
|
+
case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128].pipeline; break;
|
4037
|
+
case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128].pipeline; break;
|
4038
|
+
case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128].pipeline; break;
|
4039
|
+
case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128].pipeline; break;
|
4040
|
+
case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline; break;
|
4041
|
+
default:
|
4042
|
+
{
|
4043
|
+
LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
4044
|
+
LM_GGML_LOG_ERROR("add template specialization for this type\n");
|
4045
|
+
LM_GGML_ABORT("add template specialization for this type");
|
4046
|
+
}
|
4047
|
+
}
|
4048
|
+
} else {
|
4049
|
+
switch (src1->type) {
|
4050
|
+
case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192].pipeline; break;
|
4051
|
+
case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192].pipeline; break;
|
4052
|
+
case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192].pipeline; break;
|
4053
|
+
case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192].pipeline; break;
|
4054
|
+
case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192].pipeline; break;
|
4055
|
+
case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192].pipeline; break;
|
4056
|
+
case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192].pipeline; break;
|
4057
|
+
default:
|
4058
|
+
{
|
4059
|
+
LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
4060
|
+
LM_GGML_LOG_ERROR("add template specialization for this type\n");
|
4061
|
+
LM_GGML_ABORT("add template specialization for this type");
|
4062
|
+
}
|
4063
|
+
}
|
4064
|
+
}
|
4065
|
+
} break;
|
3854
4066
|
case 256:
|
3855
4067
|
{
|
3856
4068
|
switch (src1->type) {
|
@@ -3888,9 +4100,12 @@ static void lm_ggml_metal_encode_node(
|
|
3888
4100
|
/*.ne11 =*/ ne11,
|
3889
4101
|
/*.ne_12_2 =*/ ne12,
|
3890
4102
|
/*.ne_12_3 =*/ ne13,
|
3891
|
-
/*.
|
3892
|
-
/*.
|
3893
|
-
/*.
|
4103
|
+
/*.nb11 =*/ nb11,
|
4104
|
+
/*.nb12 =*/ nb12,
|
4105
|
+
/*.nb13 =*/ nb13,
|
4106
|
+
/*.nb21 =*/ nb21,
|
4107
|
+
/*.nb22 =*/ nb22,
|
4108
|
+
/*.nb23 =*/ nb23,
|
3894
4109
|
/*.nb31 =*/ nb31,
|
3895
4110
|
/*.ne1 =*/ ne1,
|
3896
4111
|
/*.ne2 =*/ ne2,
|
@@ -3969,10 +4184,9 @@ static void lm_ggml_metal_encode_node(
|
|
3969
4184
|
// ne00*(nsg)
|
3970
4185
|
// each simdgroup has a full f16 head vector in shared mem to accumulate results
|
3971
4186
|
//
|
3972
|
-
#define FATTN_SMEM(nsg) (LM_GGML_PAD((nqptg*(ne00 +
|
4187
|
+
#define FATTN_SMEM(nsg) (LM_GGML_PAD((nqptg*(LM_GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
|
3973
4188
|
|
3974
4189
|
int64_t nsgmax = 2;
|
3975
|
-
|
3976
4190
|
while (true) {
|
3977
4191
|
const size_t smem = FATTN_SMEM(nsgmax);
|
3978
4192
|
if (smem > device.maxThreadgroupMemoryLength) {
|