cui-llama.rn 1.2.6 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. package/README.md +3 -2
  2. package/android/src/main/CMakeLists.txt +26 -6
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +115 -27
  4. package/android/src/main/java/com/rnllama/RNLlama.java +40 -7
  5. package/android/src/main/jni.cpp +228 -40
  6. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +9 -4
  7. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +9 -4
  8. package/cpp/amx/amx.cpp +196 -0
  9. package/cpp/amx/amx.h +20 -0
  10. package/cpp/amx/common.h +101 -0
  11. package/cpp/amx/mmq.cpp +2524 -0
  12. package/cpp/amx/mmq.h +16 -0
  13. package/cpp/common.cpp +118 -251
  14. package/cpp/common.h +53 -30
  15. package/cpp/ggml-aarch64.c +46 -3395
  16. package/cpp/ggml-aarch64.h +0 -20
  17. package/cpp/ggml-alloc.c +6 -8
  18. package/cpp/ggml-backend-impl.h +33 -11
  19. package/cpp/ggml-backend-reg.cpp +423 -0
  20. package/cpp/ggml-backend.cpp +14 -676
  21. package/cpp/ggml-backend.h +46 -9
  22. package/cpp/ggml-common.h +6 -0
  23. package/cpp/ggml-cpu-aarch64.c +3823 -0
  24. package/cpp/ggml-cpu-aarch64.h +32 -0
  25. package/cpp/ggml-cpu-impl.h +14 -242
  26. package/cpp/ggml-cpu-quants.c +10835 -0
  27. package/cpp/ggml-cpu-quants.h +63 -0
  28. package/cpp/ggml-cpu.c +13971 -13720
  29. package/cpp/ggml-cpu.cpp +715 -0
  30. package/cpp/ggml-cpu.h +65 -63
  31. package/cpp/ggml-impl.h +285 -25
  32. package/cpp/ggml-metal.h +8 -8
  33. package/cpp/ggml-metal.m +1221 -728
  34. package/cpp/ggml-quants.c +189 -10681
  35. package/cpp/ggml-quants.h +78 -125
  36. package/cpp/ggml-threading.cpp +12 -0
  37. package/cpp/ggml-threading.h +12 -0
  38. package/cpp/ggml.c +688 -1460
  39. package/cpp/ggml.h +58 -244
  40. package/cpp/json-schema-to-grammar.cpp +1045 -1045
  41. package/cpp/json.hpp +24766 -24766
  42. package/cpp/llama-sampling.cpp +5 -2
  43. package/cpp/llama.cpp +409 -123
  44. package/cpp/llama.h +8 -4
  45. package/cpp/rn-llama.hpp +89 -25
  46. package/cpp/sampling.cpp +42 -3
  47. package/cpp/sampling.h +22 -1
  48. package/cpp/sgemm.cpp +608 -0
  49. package/cpp/speculative.cpp +270 -0
  50. package/cpp/speculative.h +28 -0
  51. package/cpp/unicode.cpp +11 -0
  52. package/ios/RNLlama.mm +43 -20
  53. package/ios/RNLlamaContext.h +9 -3
  54. package/ios/RNLlamaContext.mm +146 -33
  55. package/jest/mock.js +0 -1
  56. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  57. package/lib/commonjs/grammar.js +4 -2
  58. package/lib/commonjs/grammar.js.map +1 -1
  59. package/lib/commonjs/index.js +52 -15
  60. package/lib/commonjs/index.js.map +1 -1
  61. package/lib/module/NativeRNLlama.js.map +1 -1
  62. package/lib/module/grammar.js +2 -1
  63. package/lib/module/grammar.js.map +1 -1
  64. package/lib/module/index.js +51 -15
  65. package/lib/module/index.js.map +1 -1
  66. package/lib/typescript/NativeRNLlama.d.ts +122 -8
  67. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  68. package/lib/typescript/grammar.d.ts +5 -6
  69. package/lib/typescript/grammar.d.ts.map +1 -1
  70. package/lib/typescript/index.d.ts +15 -6
  71. package/lib/typescript/index.d.ts.map +1 -1
  72. package/package.json +2 -1
  73. package/src/NativeRNLlama.ts +135 -13
  74. package/src/grammar.ts +10 -8
  75. package/src/index.ts +104 -28
package/cpp/ggml-metal.m CHANGED
@@ -2,6 +2,7 @@
 
  #import "ggml-impl.h"
  #import "ggml-backend-impl.h"
+ #import "ggml-metal-impl.h"
 
  #import <Foundation/Foundation.h>
 
@@ -36,16 +37,20 @@ static struct lm_ggml_backend_metal_device_context {
  id<MTLDevice> mtl_device;
  int mtl_device_ref_count;
 
- bool support_simdgroup_reduction;
- bool support_simdgroup_mm;
+ bool has_simdgroup_reduction;
+ bool has_simdgroup_mm;
+ bool has_bfloat;
+ bool use_bfloat;
 
  char name[128];
  } g_lm_ggml_ctx_dev_main = {
- /*.mtl_device =*/ nil,
- /*.mtl_device_ref_count =*/ 0,
- /*.support_simdgroup_reduction =*/ false,
- /*.support_simdgroup_mm =*/ false,
- /*.name =*/ "",
+ /*.mtl_device =*/ nil,
+ /*.mtl_device_ref_count =*/ 0,
+ /*.has_simdgroup_reduction =*/ false,
+ /*.has_simdgroup_mm =*/ false,
+ /*.has_bfloat =*/ false,
+ /*.use_bfloat =*/ false,
+ /*.name =*/ "",
  };
 
  // acquire
@@ -55,10 +60,19 @@ static id<MTLDevice> lm_ggml_backend_metal_device_acq(struct lm_ggml_backend_met
  if (ctx->mtl_device == nil) {
  ctx->mtl_device = MTLCreateSystemDefaultDevice();
 
- ctx->support_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
- ctx->support_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
+ ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+ ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
 
- ctx->support_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+ ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+
+ ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
+ ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
+
+ #if defined(LM_GGML_METAL_USE_BF16)
+ ctx->use_bfloat = ctx->has_bfloat;
+ #else
+ ctx->use_bfloat = false;
+ #endif
 
  strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
  }
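Note on the hunk above: the renamed has_* fields record raw hardware capability, while use_bfloat additionally requires the LM_GGML_METAL_USE_BF16 compile-time opt-in. A minimal standalone sketch of that pattern, assuming only a Metal device handle (MTLGPUFamilyMetal3_GGML is this project's own constant; the stock MTLGPUFamilyMetal3 stands in for it here):

    #import <Metal/Metal.h>

    // Sketch only, not the package's API: splits "what the GPU can do"
    // from "what this build chose to enable".
    static bool device_use_bfloat(id<MTLDevice> device) {
        bool has_bfloat = [device supportsFamily:MTLGPUFamilyMetal3]   // stand-in for MTLGPUFamilyMetal3_GGML
                       || [device supportsFamily:MTLGPUFamilyApple6];
    #if defined(LM_GGML_METAL_USE_BF16)
        return has_bfloat;   // capability honored only when compiled in
    #else
        (void) has_bfloat;
        return false;        // build opted out; BF16 kernels never used
    #endif
    }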
@@ -112,6 +126,7 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
  LM_GGML_METAL_KERNEL_TYPE_SILU,
  LM_GGML_METAL_KERNEL_TYPE_SILU_4,
+ LM_GGML_METAL_KERNEL_TYPE_ELU,
  LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
  LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
  LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -120,6 +135,7 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
  LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
  LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
+ LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
  LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
  LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
  LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
@@ -146,10 +162,14 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
  LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
- LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
@@ -170,10 +190,11 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
- //LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
  //LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
  //LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
+ //LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
@@ -195,6 +216,7 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
@@ -216,6 +238,7 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
+ LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
  LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
@@ -255,13 +278,64 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
- //LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
- //LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
  LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
  LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
+ LM_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
  LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
  LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
+ LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
+ LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,
  LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
  LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
  LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
@@ -440,7 +514,15 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  // dictionary of preprocessor macros
  NSMutableDictionary * prep = [NSMutableDictionary dictionary];
 
- MTLCompileOptions* options = [MTLCompileOptions new];
+ if (ctx_dev->use_bfloat) {
+ [prep setObject:@"1" forKey:@"LM_GGML_METAL_USE_BF16"];
+ }
+
+ #if LM_GGML_METAL_EMBED_LIBRARY
+ [prep setObject:@"1" forKey:@"LM_GGML_METAL_EMBED_LIBRARY"];
+ #endif
+
+ MTLCompileOptions * options = [MTLCompileOptions new];
  options.preprocessorMacros = prep;
 
  //[options setFastMathEnabled:false];
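The hunk above threads the same flag into shader compilation: entries in MTLCompileOptions.preprocessorMacros are injected as #define lines when the Metal source library is compiled, so a single .metal source can carry both the FP16-only and the BF16 kernel variants. A hedged sketch of the mechanism (device and shaderSrc are placeholders, not names from this file):

    NSMutableDictionary *prep = [NSMutableDictionary dictionary];
    prep[@"LM_GGML_METAL_USE_BF16"] = @"1"; // becomes `#define LM_GGML_METAL_USE_BF16 1` in the shader source

    MTLCompileOptions *options = [MTLCompileOptions new];
    options.preprocessorMacros = prep;

    NSError *error = nil;
    id<MTLLibrary> library = [device newLibraryWithSource:shaderSrc
                                                  options:options
                                                    error:&error];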
@@ -490,9 +572,11 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  }
  }
 
- LM_GGML_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
- LM_GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
- LM_GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
+ LM_GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false");
+ LM_GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false");
+ LM_GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
+ LM_GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
+ LM_GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
 
  ctx->capture_next_compute = false;
  ctx->capture_started = false;
@@ -518,16 +602,14 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  ctx->kernels[i].pipeline = nil;
  }
 
- /*
- LM_GGML_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
- (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
- (int) kernel->pipeline.threadExecutionWidth); \
- */
  #define LM_GGML_METAL_ADD_KERNEL(e, name, supported) \
  if (supported) { \
  struct lm_ggml_metal_kernel * kernel = &ctx->kernels[e]; \
  id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
  kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
+ LM_GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+ (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+ (int) kernel->pipeline.threadExecutionWidth); \
  [metal_function release]; \
  if (error) { \
  LM_GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
@@ -538,8 +620,9 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
  }
 
- const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
- const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
+ const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm;
+ const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
+ const bool use_bfloat = ctx_dev->use_bfloat;
 
  // simd_sum and simd_max requires MTLGPUFamilyApple7
 
@@ -567,14 +650,16 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU, silu, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, support_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ELU, elu, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
@@ -595,101 +680,108 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, support_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_NORM, norm, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, support_simdgroup_reduction);
- //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, support_simdgroup_reduction);
- //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, support_simdgroup_reduction);
- //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, support_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction);
+ //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction);
+ //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction);
+ //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
@@ -705,18 +797,69 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, support_simdgroup_mm);
- //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
- //LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
- LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
@@ -806,15 +949,18 @@ static id<MTLBuffer> lm_ggml_metal_get_buffer(struct lm_ggml_tensor * t, size_t
  }
 
  static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_context * ctx_dev, const struct lm_ggml_tensor * op) {
- for (size_t i = 0, n = 3; i < n; ++i) {
- if (op->src[i] != NULL && op->src[i]->type == LM_GGML_TYPE_BF16) {
- return false;
+ const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm;
+ const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
+ const bool use_bfloat = ctx_dev->use_bfloat;
+
+ if (!use_bfloat) {
+ for (size_t i = 0, n = 3; i < n; ++i) {
+ if (op->src[i] != NULL && op->src[i]->type == LM_GGML_TYPE_BF16) {
+ return false;
+ }
  }
  }
 
- const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
- const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
-
  switch (op->op) {
  case LM_GGML_OP_UNARY:
  switch (lm_ggml_get_unary_op(op)) {
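The rewritten guard above only rejects BF16 source tensors when the backend was built without BF16 support; ggml then routes that op to a fallback backend. Roughly, as a standalone predicate (a sketch, not the package's API):

    // Returns true when the op must be refused because it touches a BF16
    // source tensor and this build cannot run BF16 kernels.
    static bool must_reject_bf16(const struct lm_ggml_tensor * op, bool use_bfloat) {
        if (use_bfloat) {
            return false; // BF16 kernels available, nothing to reject
        }
        for (size_t i = 0; i < 3; ++i) {
            if (op->src[i] != NULL && op->src[i]->type == LM_GGML_TYPE_BF16) {
                return true; // unsupported here; caller returns false
            }
        }
        return false;
    }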
@@ -824,6 +970,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
  case LM_GGML_UNARY_OP_GELU:
  case LM_GGML_UNARY_OP_GELU_QUICK:
  case LM_GGML_UNARY_OP_SILU:
+ case LM_GGML_UNARY_OP_ELU:
  return lm_ggml_is_contiguous(op->src[0]);
  default:
  return false;
@@ -850,9 +997,10 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
  return lm_ggml_is_contiguous(op->src[0]);
  case LM_GGML_OP_SUM_ROWS:
  case LM_GGML_OP_SOFT_MAX:
- case LM_GGML_OP_RMS_NORM:
  case LM_GGML_OP_GROUP_NORM:
- return support_simdgroup_reduction;
+ return has_simdgroup_reduction;
+ case LM_GGML_OP_RMS_NORM:
+ return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
  case LM_GGML_OP_NORM:
  case LM_GGML_OP_ROPE:
  return true;
@@ -869,22 +1017,16 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
  case LM_GGML_OP_LEAKY_RELU:
  return true;
  case LM_GGML_OP_FLASH_ATTN_EXT:
- if (op->src[1]->type != LM_GGML_TYPE_F16) {
- return false;
- }
- if (op->src[2]->type != LM_GGML_TYPE_F16) {
- return false;
- }
- if (op->src[0]->ne[0] == 256) {
+ if (op->src[1]->type != op->src[2]->type) {
  return false;
  }
- return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
+ return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
  case LM_GGML_OP_SSM_CONV:
  case LM_GGML_OP_SSM_SCAN:
  return true;
  case LM_GGML_OP_MUL_MAT:
  case LM_GGML_OP_MUL_MAT_ID:
- return support_simdgroup_reduction &&
+ return has_simdgroup_reduction &&
  (op->src[0]->type != LM_GGML_TYPE_F32 || op->src[1]->type == LM_GGML_TYPE_F32);
  case LM_GGML_OP_CPY:
  case LM_GGML_OP_DUP:
@@ -895,6 +1037,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
  switch (op->type) {
  case LM_GGML_TYPE_F32:
  case LM_GGML_TYPE_F16:
+ case LM_GGML_TYPE_BF16:
  case LM_GGML_TYPE_Q8_0:
  case LM_GGML_TYPE_Q4_0:
  case LM_GGML_TYPE_Q4_1:
@@ -907,10 +1050,18 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
  }
  case LM_GGML_TYPE_F16:
  switch (op->type) {
- case LM_GGML_TYPE_F32:
- case LM_GGML_TYPE_F16:
+ case LM_GGML_TYPE_F32:
+ case LM_GGML_TYPE_F16:
  return true;
- default:
+ default:
+ return false;
+ }
+ case LM_GGML_TYPE_BF16:
+ switch (op->type) {
+ case LM_GGML_TYPE_F32:
+ case LM_GGML_TYPE_BF16:
+ return true;
+ default:
  return false;
  }
  default:
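Read as a table, the CPY/DUP/CONT switch above now permits these source-to-destination pairs (BF16 rows only when use_bfloat is set): F32 to {F32, F16, BF16, Q8_0, Q4_0, Q4_1, ...}; F16 to {F32, F16}; BF16 to {F32, BF16}. The same predicate as a compact sketch (an illustrative helper, not part of the file):

    static bool cpy_dst_supported(enum lm_ggml_type src, enum lm_ggml_type dst) {
        switch (src) {
            case LM_GGML_TYPE_F16:
                return dst == LM_GGML_TYPE_F32 || dst == LM_GGML_TYPE_F16;
            case LM_GGML_TYPE_BF16: // new in this release
                return dst == LM_GGML_TYPE_F32 || dst == LM_GGML_TYPE_BF16;
            default:
                return false; // F32 sources are handled by the wider branch above
        }
    }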
@@ -996,7 +1147,7 @@ static void lm_ggml_metal_encode_node(
996
1147
  const uint64_t nb20 = src2 ? src2->nb[0] : 0; LM_GGML_UNUSED(nb20);
997
1148
  const uint64_t nb21 = src2 ? src2->nb[1] : 0;
998
1149
  const uint64_t nb22 = src2 ? src2->nb[2] : 0;
999
- const uint64_t nb23 = src2 ? src2->nb[3] : 0;
1150
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0; LM_GGML_UNUSED(nb23);
1000
1151
 
1001
1152
  const int64_t ne0 = dst ? dst->ne[0] : 0;
1002
1153
  const int64_t ne1 = dst ? dst->ne[1] : 0;
@@ -1047,35 +1198,39 @@ static void lm_ggml_metal_encode_node(
1047
1198
 
1048
1199
  const int32_t dim = ((const int32_t *) dst->op_params)[0];
1049
1200
 
1201
+ lm_ggml_metal_kargs_concat args = {
1202
+ /*.ne00 =*/ ne00,
1203
+ /*.ne01 =*/ ne01,
1204
+ /*.ne02 =*/ ne02,
1205
+ /*.ne03 =*/ ne03,
1206
+ /*.nb00 =*/ nb00,
1207
+ /*.nb01 =*/ nb01,
1208
+ /*.nb02 =*/ nb02,
1209
+ /*.nb03 =*/ nb03,
1210
+ /*.ne10 =*/ ne10,
1211
+ /*.ne11 =*/ ne11,
1212
+ /*.ne12 =*/ ne12,
1213
+ /*.ne13 =*/ ne13,
1214
+ /*.nb10 =*/ nb10,
1215
+ /*.nb11 =*/ nb11,
1216
+ /*.nb12 =*/ nb12,
1217
+ /*.nb13 =*/ nb13,
1218
+ /*.ne0 =*/ ne0,
1219
+ /*.ne1 =*/ ne1,
1220
+ /*.ne2 =*/ ne2,
1221
+ /*.ne3 =*/ ne3,
1222
+ /*.nb0 =*/ nb0,
1223
+ /*.nb1 =*/ nb1,
1224
+ /*.nb2 =*/ nb2,
1225
+ /*.nb3 =*/ nb3,
1226
+ /*.dim =*/ dim,
1227
+ };
1228
+
1050
1229
  [encoder setComputePipelineState:pipeline];
1051
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1052
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1053
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1054
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1055
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1056
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1057
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1058
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1059
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
1060
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
1061
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
1062
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1063
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1064
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1065
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1066
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1067
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1068
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1069
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1070
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1071
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1072
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1073
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1074
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1075
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1076
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1077
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1078
- [encoder setBytes:&dim length:sizeof(dim) atIndex:27];
1230
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
1231
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1232
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
1233
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
1079
1234
 
1080
1235
  const int nth = MIN(1024, ne0);
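This CONCAT hunk sets the template for the rest of the file: the long run of per-scalar setBytes: calls becomes a single packed argument struct bound once at buffer index 0, with the tensor buffers shifted to indices 1 and up. The struct definition is not part of this file (upstream it lives in a shared ggml-metal-impl.h so host and shader agree on the layout); a sketch reconstructed from the initializer above, with field types assumed to mirror the tensor metadata (int64_t extents, uint64_t byte strides):

    typedef struct {
        int64_t  ne00, ne01, ne02, ne03;   // src0 extents
        uint64_t nb00, nb01, nb02, nb03;   // src0 strides in bytes
        int64_t  ne10, ne11, ne12, ne13;   // src1 extents
        uint64_t nb10, nb11, nb12, nb13;   // src1 strides
        int64_t  ne0,  ne1,  ne2,  ne3;    // dst extents
        uint64_t nb0,  nb1,  nb2,  nb3;    // dst strides
        int32_t  dim;                      // concatenation axis
    } lm_ggml_metal_kargs_concat;

One setBytes: upload of sizeof(args) replaces twenty-five argument bindings, which means less encoder overhead and a single definition to keep in sync instead of a list of indices.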
1081
1236
 
@@ -1093,8 +1248,6 @@ static void lm_ggml_metal_encode_node(
1093
1248
 
1094
1249
  bool bcast_row = false;
1095
1250
 
1096
- int64_t nb = ne00; // used by the "row" kernels
1097
-
1098
1251
  id<MTLComputePipelineState> pipeline = nil;
1099
1252
 
1100
1253
  if (lm_ggml_nelements(src1) == ne10 && lm_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
@@ -1103,7 +1256,6 @@ static void lm_ggml_metal_encode_node(
1103
1256
  // src1 is a row
1104
1257
  LM_GGML_ASSERT(ne11 == 1);
1105
1258
 
1106
- nb = ne00 / 4;
1107
1259
  switch (dst->op) {
1108
1260
  case LM_GGML_OP_ADD: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
1109
1261
  case LM_GGML_OP_SUB: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
@@ -1123,36 +1275,39 @@ static void lm_ggml_metal_encode_node(
1123
1275
  }
1124
1276
  }
1125
1277
 
1278
+ lm_ggml_metal_kargs_bin args = {
1279
+ /*.ne00 =*/ ne00,
1280
+ /*.ne01 =*/ ne01,
1281
+ /*.ne02 =*/ ne02,
1282
+ /*.ne03 =*/ ne03,
1283
+ /*.nb00 =*/ nb00,
1284
+ /*.nb01 =*/ nb01,
1285
+ /*.nb02 =*/ nb02,
1286
+ /*.nb03 =*/ nb03,
1287
+ /*.ne10 =*/ ne10,
1288
+ /*.ne11 =*/ ne11,
1289
+ /*.ne12 =*/ ne12,
1290
+ /*.ne13 =*/ ne13,
1291
+ /*.nb10 =*/ nb10,
1292
+ /*.nb11 =*/ nb11,
1293
+ /*.nb12 =*/ nb12,
1294
+ /*.nb13 =*/ nb13,
1295
+ /*.ne0 =*/ ne0,
1296
+ /*.ne1 =*/ ne1,
1297
+ /*.ne2 =*/ ne2,
1298
+ /*.ne3 =*/ ne3,
1299
+ /*.nb0 =*/ nb0,
1300
+ /*.nb1 =*/ nb1,
1301
+ /*.nb2 =*/ nb2,
1302
+ /*.nb3 =*/ nb3,
1303
+ /*.offs =*/ offs,
1304
+ };
1305
+
1126
1306
  [encoder setComputePipelineState:pipeline];
1127
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1128
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1129
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1130
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1131
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1132
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1133
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1134
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1135
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
1136
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
1137
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
1138
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1139
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1140
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1141
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1142
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1143
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1144
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1145
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1146
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1147
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1148
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1149
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1150
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1151
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1152
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1153
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1154
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1155
- [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
1307
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
1308
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1309
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
1310
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
1156
1311
 
1157
1312
  if (bcast_row) {
1158
1313
  const int64_t n = lm_ggml_nelements(dst)/4;
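For the element-wise ADD/SUB/MUL/DIV the row-broadcast fast path survives unchanged in spirit; only the standalone nb argument (ne00/4) is gone, since the kernel can derive it from the struct. The guard, restated from the top of this hunk:

    // src1 must be a single contiguous row and both row lengths float4-aligned
    const bool bcast_row = lm_ggml_nelements(src1) == ne10 &&
                           lm_ggml_is_contiguous(src1)     &&
                           ne00 % 4 == 0 && ne10 % 4 == 0;
    // if true, the *_ROW pipelines run over lm_ggml_nelements(dst)/4 float4s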
@@ -1176,25 +1331,29 @@ static void lm_ggml_metal_encode_node(
1176
1331
  default: LM_GGML_ABORT("fatal error");
1177
1332
  }
1178
1333
 
1334
+ lm_ggml_metal_kargs_repeat args = {
1335
+ /*.ne00 =*/ ne00,
1336
+ /*.ne01 =*/ ne01,
1337
+ /*.ne02 =*/ ne02,
1338
+ /*.ne03 =*/ ne03,
1339
+ /*.nb00 =*/ nb00,
1340
+ /*.nb01 =*/ nb01,
1341
+ /*.nb02 =*/ nb02,
1342
+ /*.nb03 =*/ nb03,
1343
+ /*.ne0 =*/ ne0,
1344
+ /*.ne1 =*/ ne1,
1345
+ /*.ne2 =*/ ne2,
1346
+ /*.ne3 =*/ ne3,
1347
+ /*.nb0 =*/ nb0,
1348
+ /*.nb1 =*/ nb1,
1349
+ /*.nb2 =*/ nb2,
1350
+ /*.nb3 =*/ nb3,
1351
+ };
1352
+
1179
1353
  [encoder setComputePipelineState:pipeline];
1180
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1181
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1182
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1183
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1184
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1185
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1186
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1187
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1188
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1189
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1190
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
1191
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
1192
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
1193
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
1194
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
1195
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1196
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
1197
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
1354
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
1355
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1356
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1198
1357
 
1199
1358
  const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1200
1359
 
@@ -1223,25 +1382,29 @@ static void lm_ggml_metal_encode_node(
1223
1382
 
1224
1383
  const id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
1225
1384
 
1385
+ lm_ggml_metal_kargs_cpy args = {
1386
+ /*.ne00 =*/ ne00,
1387
+ /*.ne01 =*/ ne01,
1388
+ /*.ne02 =*/ ne02,
1389
+ /*.ne03 =*/ ne03,
1390
+ /*.nb00 =*/ nb00,
1391
+ /*.nb01 =*/ nb01,
1392
+ /*.nb02 =*/ nb02,
1393
+ /*.nb03 =*/ nb03,
1394
+ /*.ne0 =*/ ne0,
1395
+ /*.ne1 =*/ ne1,
1396
+ /*.ne2 =*/ ne2,
1397
+ /*.ne3 =*/ ne3,
1398
+ /*.nb0 =*/ nb0,
1399
+ /*.nb1 =*/ nb1,
1400
+ /*.nb2 =*/ nb2,
1401
+ /*.nb3 =*/ nb3,
1402
+ };
1403
+
1226
1404
  [encoder setComputePipelineState:pipeline];
1227
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1228
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1229
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1230
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1231
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1232
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1233
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1234
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1235
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1236
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1237
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1238
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1239
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1240
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1241
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1242
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1243
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1244
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1405
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
1406
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1407
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1245
1408
 
1246
1409
  const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
1247
1410
 
@@ -1250,35 +1413,39 @@ static void lm_ggml_metal_encode_node(
1250
1413
 
1251
1414
  const id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ADD].pipeline;
1252
1415
 
1416
+ lm_ggml_metal_kargs_bin args = {
1417
+ /*.ne00 =*/ ne00,
1418
+ /*.ne01 =*/ ne01,
1419
+ /*.ne02 =*/ ne02,
1420
+ /*.ne03 =*/ ne03,
1421
+ /*.nb00 =*/ nb00,
1422
+ /*.nb01 =*/ pnb1,
1423
+ /*.nb02 =*/ pnb2,
1424
+ /*.nb03 =*/ pnb3,
1425
+ /*.ne10 =*/ ne10,
1426
+ /*.ne11 =*/ ne11,
1427
+ /*.ne12 =*/ ne12,
1428
+ /*.ne13 =*/ ne13,
1429
+ /*.nb10 =*/ nb10,
1430
+ /*.nb11 =*/ nb11,
1431
+ /*.nb12 =*/ nb12,
1432
+ /*.nb13 =*/ nb13,
1433
+ /*.ne0 =*/ ne0,
1434
+ /*.ne1 =*/ ne1,
1435
+ /*.ne2 =*/ ne2,
1436
+ /*.ne3 =*/ ne3,
1437
+ /*.nb0 =*/ nb0,
1438
+ /*.nb1 =*/ pnb1,
1439
+ /*.nb2 =*/ pnb2,
1440
+ /*.nb3 =*/ pnb3,
1441
+ /*.offs =*/ offs,
1442
+ };
1443
+
1253
1444
  [encoder setComputePipelineState:pipeline];
1254
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1255
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1256
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1257
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1258
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1259
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1260
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1261
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1262
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
1263
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
1264
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
1265
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1266
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1267
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1268
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1269
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1270
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1271
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1272
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1273
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1274
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1275
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1276
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1277
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1278
- [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
1279
- [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
1280
- [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
1281
- [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1445
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
1446
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1447
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
1448
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
1282
1449
 
1283
1450
  const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
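The non-contiguous ADD case (dst is a row view into a larger tensor) reuses the same lm_ggml_metal_kargs_bin layout; the only difference from the contiguous case is which values land in the stride and offset fields. Sketch of the deltas, names as in the initializer above:

    args.nb01 = pnb1;  args.nb02 = pnb2;  args.nb03 = pnb3;  // padded src0 strides
    args.nb1  = pnb1;  args.nb2  = pnb2;  args.nb3  = pnb3;  // padded dst strides
    args.offs = offs;                                        // byte offset into dst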
1284
1451
 
@@ -1319,10 +1486,10 @@ static void lm_ggml_metal_encode_node(
1319
1486
  memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float));
1320
1487
 
1321
1488
  [encoder setComputePipelineState:pipeline];
1322
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1323
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1324
- [encoder setBytes:&min length:sizeof(min) atIndex:2];
1325
- [encoder setBytes:&max length:sizeof(max) atIndex:3];
1489
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1490
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1491
+ [encoder setBytes:&min length:sizeof(min) atIndex:2];
1492
+ [encoder setBytes:&max length:sizeof(max) atIndex:3];
1326
1493
 
1327
1494
  const int64_t n = lm_ggml_nelements(dst);
1328
1495
 
@@ -1426,6 +1593,18 @@ static void lm_ggml_metal_encode_node(
1426
1593
 
1427
1594
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1428
1595
  } break;
1596
+ case LM_GGML_UNARY_OP_ELU:
1597
+ {
1598
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ELU].pipeline;
1599
+
1600
+ [encoder setComputePipelineState:pipeline];
1601
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1602
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1603
+
1604
+ const int64_t n = lm_ggml_nelements(dst);
1605
+
1606
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1607
+ } break;
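ELU joins the scalar unary kernels, gated on contiguity like GELU and SILU in the support table earlier. The kernel presumably computes the standard alpha = 1 form; a host-side reference for comparison:

    #include <math.h>

    // elu(x) = x for x > 0, exp(x) - 1 otherwise (alpha fixed at 1)
    static float elu_ref(float x) {
        return x > 0.0f ? x : expf(x) - 1.0f;
    }

The dispatch mirrors the neighboring unary ops: one single-thread threadgroup per element.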
1429
1608
  default:
1430
1609
  {
1431
1610
  LM_GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, lm_ggml_op_name(dst->op));
@@ -1494,6 +1673,7 @@ static void lm_ggml_metal_encode_node(
1494
1673
 
1495
1674
  id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
1496
1675
 
1676
+ // TODO: add lm_ggml_metal_kargs struct
1497
1677
  [encoder setComputePipelineState:pipeline];
1498
1678
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1499
1679
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -1569,6 +1749,8 @@ static void lm_ggml_metal_encode_node(
1569
1749
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1570
1750
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1571
1751
 
1752
+ // TODO: add lm_ggml_metal_kargs struct
1753
+ // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
1572
1754
  [encoder setComputePipelineState:pipeline];
1573
1755
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1574
1756
  if (id_src1) {
@@ -1585,6 +1767,7 @@ static void lm_ggml_metal_encode_node(
1585
1767
  [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
1586
1768
  [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
1587
1769
  [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
1770
+
1588
1771
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1589
1772
 
1590
1773
  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -1601,6 +1784,7 @@ static void lm_ggml_metal_encode_node(
1601
1784
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
1602
1785
  }
1603
1786
 
1787
+ // TODO: add lm_ggml_metal_kargs struct
1604
1788
  [encoder setComputePipelineState:pipeline];
1605
1789
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1606
1790
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -1625,6 +1809,7 @@ static void lm_ggml_metal_encode_node(
1625
1809
 
1626
1810
  id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
1627
1811
 
1812
+ // TODO: add lm_ggml_metal_kargs struct
1628
1813
  [encoder setComputePipelineState:pipeline];
1629
1814
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1630
1815
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -1695,6 +1880,7 @@ static void lm_ggml_metal_encode_node(
1695
1880
 
1696
1881
  id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
1697
1882
 
1883
+ // TODO: add lm_ggml_metal_kargs struct
1698
1884
  [encoder setComputePipelineState:pipeline];
1699
1885
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1700
1886
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -1742,7 +1928,7 @@ static void lm_ggml_metal_encode_node(
1742
1928
 
1743
1929
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1744
1930
  // to the matrix-vector kernel
1745
- int ne11_mm_min = 1;
1931
+ int ne11_mm_min = 4;
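Raising ne11_mm_min from 1 to 4 moves the break-even point: batches of up to four src1 rows now stay on the matrix-vector kernels, whose setup cost is lower than the simdgroup matrix-matrix path. Combined with the conditions in the block below, the mat-mat kernel is chosen roughly as:

    const bool use_mm =
        [device supportsFamily:MTLGPUFamilyApple7]              &&
        !lm_ggml_is_transposed(src0)                            &&
        !lm_ggml_is_transposed(src1)                            &&
        src1t == LM_GGML_TYPE_F32                               &&
        ne00 % 32 == 0 && ne00 >= 64                            &&
        (ne11 > 4 || (lm_ggml_is_quantized(src0t) && ne12 > 1));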
1746
1932
 
1747
1933
  #if 0
1748
1934
  // the numbers below are measured on M2 Ultra for 7B and 13B models
@@ -1766,286 +1952,316 @@ static void lm_ggml_metal_encode_node(
1766
1952
  }
1767
1953
  #endif
1768
1954
 
1769
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1770
- // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1771
- if ([device supportsFamily:MTLGPUFamilyApple7] &&
1772
- !lm_ggml_is_transposed(src0) &&
1773
- !lm_ggml_is_transposed(src1) &&
1774
- src1t == LM_GGML_TYPE_F32 &&
1775
- ne00 % 32 == 0 && ne00 >= 64 &&
1776
- (ne11 > ne11_mm_min || (lm_ggml_is_quantized(src0t) && ne12 > 1))) {
1777
- //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1778
-
1779
- // some Metal matrix data types require aligned pointers
1780
- // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1781
- switch (src0->type) {
1782
- case LM_GGML_TYPE_F32: LM_GGML_ASSERT(nb01 % 16 == 0); break;
1783
- case LM_GGML_TYPE_F16: LM_GGML_ASSERT(nb01 % 8 == 0); break;
1784
- default: break;
1785
- }
1955
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1956
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1957
+ if ([device supportsFamily:MTLGPUFamilyApple7] &&
1958
+ !lm_ggml_is_transposed(src0) &&
1959
+ !lm_ggml_is_transposed(src1) &&
1960
+ src1t == LM_GGML_TYPE_F32 &&
1961
+ ne00 % 32 == 0 && ne00 >= 64 &&
1962
+ (ne11 > ne11_mm_min || (lm_ggml_is_quantized(src0t) && ne12 > 1))) {
1963
+ //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1786
1964
 
1787
- id<MTLComputePipelineState> pipeline = nil;
1788
-
1789
- switch (src0->type) {
1790
- case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
1791
- case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
1792
- case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
1793
- case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
1794
- case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
1795
- case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
1796
- case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
1797
- case LM_GGML_TYPE_Q2_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
1798
- case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
1799
- case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
1800
- case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
1801
- case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
1802
- case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
1803
- case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
1804
- case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
1805
- case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
1806
- case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
1807
- case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
1808
- case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
1809
- case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
1810
- case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
1811
- default: LM_GGML_ABORT("MUL MAT-MAT not implemented");
1812
- }
1965
+ // some Metal matrix data types require aligned pointers
1966
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1967
+ switch (src0->type) {
1968
+ case LM_GGML_TYPE_F32: LM_GGML_ASSERT(nb01 % 16 == 0); break;
1969
+ case LM_GGML_TYPE_F16: LM_GGML_ASSERT(nb01 % 8 == 0); break;
1970
+ case LM_GGML_TYPE_BF16: LM_GGML_ASSERT(nb01 % 8 == 0); break;
1971
+ default: break;
1972
+ }
1813
1973
 
1814
- [encoder setComputePipelineState:pipeline];
1815
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1816
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1817
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1818
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1819
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1820
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
1821
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
1822
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7];
1823
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
1824
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9];
1825
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10];
1826
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11];
1827
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12];
1828
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1829
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
1830
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:15];
1831
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:16];
1832
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1833
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1834
- } else {
1835
- int nth0 = 32;
1836
- int nth1 = 1;
1837
- int nrows = 1;
1838
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1839
-
1840
- id<MTLComputePipelineState> pipeline = nil;
1841
-
1842
- // use custom matrix x vector kernel
1843
- switch (src0t) {
1844
- case LM_GGML_TYPE_F32:
1845
- {
1846
- LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
1847
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
1974
+ id<MTLComputePipelineState> pipeline = nil;
1975
+
1976
+ switch (src0->type) {
1977
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
1978
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
1979
+ case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
1980
+ case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
1981
+ case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
1982
+ case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
1983
+ case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
1984
+ case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
1985
+ case LM_GGML_TYPE_Q2_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
1986
+ case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
1987
+ case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
1988
+ case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
1989
+ case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
1990
+ case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
1991
+ case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
1992
+ case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
1993
+ case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
1994
+ case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
1995
+ case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
1996
+ case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
1997
+ case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
1998
+ case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
1999
+ default: LM_GGML_ABORT("MUL MAT-MAT not implemented");
2000
+ }
2001
+
2002
+ lm_ggml_metal_kargs_mul_mm args = {
2003
+ /*.ne00 =*/ ne00,
2004
+ /*.ne02 =*/ ne02,
2005
+ /*.nb01 =*/ nb01,
2006
+ /*.nb02 =*/ nb02,
2007
+ /*.nb03 =*/ nb03,
2008
+ /*.ne12 =*/ ne12,
2009
+ /*.nb10 =*/ nb10,
2010
+ /*.nb11 =*/ nb11,
2011
+ /*.nb12 =*/ nb12,
2012
+ /*.nb13 =*/ nb13,
2013
+ /*.ne0 =*/ ne0,
2014
+ /*.ne1 =*/ ne1,
2015
+ /*.r2 =*/ r2,
2016
+ /*.r3 =*/ r3,
2017
+ };
2018
+
2019
+ [encoder setComputePipelineState:pipeline];
2020
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
2021
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2022
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2023
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2024
+
2025
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
2026
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
2027
+ } else {
2028
+ int nth0 = 32;
2029
+ int nth1 = 1;
2030
+ int nrows = 1;
2031
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
2032
+
2033
+ id<MTLComputePipelineState> pipeline = nil;
2034
+
2035
+ // use custom matrix x vector kernel
2036
+ switch (src0t) {
2037
+ case LM_GGML_TYPE_F32:
2038
+ {
2039
+ LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
2040
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
2041
+ nrows = 4;
2042
+ } break;
2043
+ case LM_GGML_TYPE_F16:
2044
+ {
2045
+ nth0 = 32;
2046
+ nth1 = 1;
2047
+ if (src1t == LM_GGML_TYPE_F32) {
2048
+ if (ne11 * ne12 < 4) {
2049
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
2050
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
2051
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
2052
+ nrows = ne11;
2053
+ } else {
2054
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
1848
2055
  nrows = 4;
1849
- } break;
1850
- case LM_GGML_TYPE_F16:
1851
- {
1852
- nth0 = 32;
1853
- nth1 = 1;
1854
- if (src1t == LM_GGML_TYPE_F32) {
1855
- if (ne11 * ne12 < 4) {
1856
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
1857
- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1858
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
1859
- nrows = ne11;
1860
- } else {
1861
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
1862
- nrows = 4;
1863
- }
1864
- } else {
1865
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
1866
- nrows = 4;
1867
- }
1868
- } break;
1869
- case LM_GGML_TYPE_Q4_0:
1870
- {
1871
- nth0 = 8;
1872
- nth1 = 8;
1873
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
1874
- } break;
1875
- case LM_GGML_TYPE_Q4_1:
1876
- {
1877
- nth0 = 8;
1878
- nth1 = 8;
1879
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
1880
- } break;
1881
- case LM_GGML_TYPE_Q5_0:
1882
- {
1883
- nth0 = 8;
1884
- nth1 = 8;
1885
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
1886
- } break;
1887
- case LM_GGML_TYPE_Q5_1:
1888
- {
1889
- nth0 = 8;
1890
- nth1 = 8;
1891
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
1892
- } break;
1893
- case LM_GGML_TYPE_Q8_0:
1894
- {
1895
- nth0 = 8;
1896
- nth1 = 8;
1897
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
1898
- } break;
1899
- case LM_GGML_TYPE_Q2_K:
1900
- {
1901
- nth0 = 2;
1902
- nth1 = 32;
1903
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
1904
- } break;
1905
- case LM_GGML_TYPE_Q3_K:
1906
- {
1907
- nth0 = 2;
1908
- nth1 = 32;
1909
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
1910
- } break;
1911
- case LM_GGML_TYPE_Q4_K:
1912
- {
1913
- nth0 = 4; //1;
1914
- nth1 = 8; //32;
1915
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
1916
- } break;
1917
- case LM_GGML_TYPE_Q5_K:
1918
- {
1919
- nth0 = 2;
1920
- nth1 = 32;
1921
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
1922
- } break;
1923
- case LM_GGML_TYPE_Q6_K:
1924
- {
1925
- nth0 = 2;
1926
- nth1 = 32;
1927
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
1928
- } break;
1929
- case LM_GGML_TYPE_IQ2_XXS:
1930
- {
1931
- nth0 = 4;
1932
- nth1 = 16;
1933
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
1934
- } break;
1935
- case LM_GGML_TYPE_IQ2_XS:
1936
- {
1937
- nth0 = 4;
1938
- nth1 = 16;
1939
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
1940
- } break;
1941
- case LM_GGML_TYPE_IQ3_XXS:
1942
- {
1943
- nth0 = 4;
1944
- nth1 = 16;
1945
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
1946
- } break;
1947
- case LM_GGML_TYPE_IQ3_S:
1948
- {
1949
- nth0 = 4;
1950
- nth1 = 16;
1951
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
1952
- } break;
1953
- case LM_GGML_TYPE_IQ2_S:
1954
- {
1955
- nth0 = 4;
1956
- nth1 = 16;
1957
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
1958
- } break;
1959
- case LM_GGML_TYPE_IQ1_S:
1960
- {
1961
- nth0 = 4;
1962
- nth1 = 16;
1963
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
1964
- } break;
1965
- case LM_GGML_TYPE_IQ1_M:
1966
- {
1967
- nth0 = 4;
1968
- nth1 = 16;
1969
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
1970
- } break;
1971
- case LM_GGML_TYPE_IQ4_NL:
1972
- {
1973
- nth0 = 4;
1974
- nth1 = 16;
1975
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
1976
- } break;
1977
- case LM_GGML_TYPE_IQ4_XS:
1978
- {
1979
- nth0 = 4;
1980
- nth1 = 16;
1981
- pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
1982
- } break;
1983
- default:
1984
- {
1985
- LM_GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1986
- LM_GGML_ABORT("not implemented");
1987
2056
  }
1988
- };
1989
-
1990
- [encoder setComputePipelineState:pipeline];
1991
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1992
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1993
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1994
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1995
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1996
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1997
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1998
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1999
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2000
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2001
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
2002
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
2003
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
2004
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
2005
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
2006
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
2007
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
2008
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
2009
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
2010
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:19];
2011
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:20];
2012
-
2013
- if (src0t == LM_GGML_TYPE_Q4_0 || src0t == LM_GGML_TYPE_Q4_1 || src0t == LM_GGML_TYPE_Q5_0 ||
2014
- src0t == LM_GGML_TYPE_Q5_1 || src0t == LM_GGML_TYPE_Q8_0 || src0t == LM_GGML_TYPE_Q2_K ||
2015
- src0t == LM_GGML_TYPE_IQ1_S || src0t == LM_GGML_TYPE_IQ1_M || src0t == LM_GGML_TYPE_IQ2_S) {
2016
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2017
- }
2018
- else if (src0t == LM_GGML_TYPE_IQ2_XXS || src0t == LM_GGML_TYPE_IQ2_XS) {
2019
- const int mem_size = src0t == LM_GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
2020
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2021
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2022
- }
2023
- else if (src0t == LM_GGML_TYPE_IQ3_XXS || src0t == LM_GGML_TYPE_IQ3_S) {
2024
- const int mem_size = src0t == LM_GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
2025
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2026
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2027
- }
2028
- else if (src0t == LM_GGML_TYPE_IQ4_NL || src0t == LM_GGML_TYPE_IQ4_XS) {
2029
- const int mem_size = 32*sizeof(float);
2030
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2031
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2032
- }
2033
- else if (src0t == LM_GGML_TYPE_Q4_K) {
2034
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2035
- }
2036
- else if (src0t == LM_GGML_TYPE_Q3_K) {
2037
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2038
- }
2039
- else if (src0t == LM_GGML_TYPE_Q5_K) {
2040
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2041
- }
2042
- else if (src0t == LM_GGML_TYPE_Q6_K) {
2043
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2044
- } else {
2045
- const int64_t ny = (ne11 + nrows - 1)/nrows;
2046
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2057
+ } else {
2058
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
2059
+ nrows = 4;
2060
+ }
2061
+ } break;
2062
+ case LM_GGML_TYPE_BF16:
2063
+ {
2064
+ nth0 = 32;
2065
+ nth1 = 1;
2066
+ if (src1t == LM_GGML_TYPE_F32) {
2067
+ if (ne11 * ne12 < 4) {
2068
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
2069
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
2070
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
2071
+ nrows = ne11;
2072
+ } else {
2073
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
2074
+ nrows = 4;
2075
+ }
2076
+ } else {
2077
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
2078
+ nrows = 4;
2079
+ }
2080
+ } break;
2081
+ case LM_GGML_TYPE_Q4_0:
2082
+ {
2083
+ nth0 = 8;
2084
+ nth1 = 8;
2085
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
2086
+ } break;
2087
+ case LM_GGML_TYPE_Q4_1:
2088
+ {
2089
+ nth0 = 8;
2090
+ nth1 = 8;
2091
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
2092
+ } break;
2093
+ case LM_GGML_TYPE_Q5_0:
2094
+ {
2095
+ nth0 = 8;
2096
+ nth1 = 8;
2097
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
2098
+ } break;
2099
+ case LM_GGML_TYPE_Q5_1:
2100
+ {
2101
+ nth0 = 8;
2102
+ nth1 = 8;
2103
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
2104
+ } break;
2105
+ case LM_GGML_TYPE_Q8_0:
2106
+ {
2107
+ nth0 = 8;
2108
+ nth1 = 8;
2109
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
2110
+ } break;
2111
+ case LM_GGML_TYPE_Q2_K:
2112
+ {
2113
+ nth0 = 2;
2114
+ nth1 = 32;
2115
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
2116
+ } break;
2117
+ case LM_GGML_TYPE_Q3_K:
2118
+ {
2119
+ nth0 = 2;
2120
+ nth1 = 32;
2121
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
2122
+ } break;
2123
+ case LM_GGML_TYPE_Q4_K:
2124
+ {
2125
+ nth0 = 4; //1;
2126
+ nth1 = 8; //32;
2127
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
2128
+ } break;
2129
+ case LM_GGML_TYPE_Q5_K:
2130
+ {
2131
+ nth0 = 2;
2132
+ nth1 = 32;
2133
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
2134
+ } break;
2135
+ case LM_GGML_TYPE_Q6_K:
2136
+ {
2137
+ nth0 = 2;
2138
+ nth1 = 32;
2139
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
2140
+ } break;
2141
+ case LM_GGML_TYPE_IQ2_XXS:
2142
+ {
2143
+ nth0 = 4;
2144
+ nth1 = 16;
2145
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
2146
+ } break;
2147
+ case LM_GGML_TYPE_IQ2_XS:
2148
+ {
2149
+ nth0 = 4;
2150
+ nth1 = 16;
2151
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
2152
+ } break;
2153
+ case LM_GGML_TYPE_IQ3_XXS:
2154
+ {
2155
+ nth0 = 4;
2156
+ nth1 = 16;
2157
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
2158
+ } break;
2159
+ case LM_GGML_TYPE_IQ3_S:
2160
+ {
2161
+ nth0 = 4;
2162
+ nth1 = 16;
2163
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
2164
+ } break;
2165
+ case LM_GGML_TYPE_IQ2_S:
2166
+ {
2167
+ nth0 = 4;
2168
+ nth1 = 16;
2169
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
2170
+ } break;
2171
+ case LM_GGML_TYPE_IQ1_S:
2172
+ {
2173
+ nth0 = 4;
2174
+ nth1 = 16;
2175
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
2176
+ } break;
2177
+ case LM_GGML_TYPE_IQ1_M:
2178
+ {
2179
+ nth0 = 4;
2180
+ nth1 = 16;
2181
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
2182
+ } break;
2183
+ case LM_GGML_TYPE_IQ4_NL:
2184
+ {
2185
+ nth0 = 4;
2186
+ nth1 = 16;
2187
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
2188
+ } break;
2189
+ case LM_GGML_TYPE_IQ4_XS:
2190
+ {
2191
+ nth0 = 4;
2192
+ nth1 = 16;
2193
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
2194
+ } break;
2195
+ default:
2196
+ {
2197
+ LM_GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
2198
+ LM_GGML_ABORT("not implemented");
2047
2199
  }
2048
- }
2200
+ };
2201
+
2202
+ lm_ggml_metal_kargs_mul_mv args = {
2203
+ /*.ne00 =*/ ne00,
2204
+ /*.ne01 =*/ ne01,
2205
+ /*.ne02 =*/ ne02,
2206
+ /*.nb00 =*/ nb00,
2207
+ /*.nb01 =*/ nb01,
2208
+ /*.nb02 =*/ nb02,
2209
+ /*.nb03 =*/ nb03,
2210
+ /*.ne10 =*/ ne10,
2211
+ /*.ne11 =*/ ne11,
2212
+ /*.ne12 =*/ ne12,
2213
+ /*.nb10 =*/ nb10,
2214
+ /*.nb11 =*/ nb11,
2215
+ /*.nb12 =*/ nb12,
2216
+ /*.nb13 =*/ nb13,
2217
+ /*.ne0 =*/ ne0,
2218
+ /*.ne1 =*/ ne1,
2219
+ /*.r2 =*/ r2,
2220
+ /*.r3 =*/ r3,
2221
+ };
2222
+
2223
+ [encoder setComputePipelineState:pipeline];
2224
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
2225
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2226
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2227
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2228
+
2229
+ if (src0t == LM_GGML_TYPE_Q4_0 || src0t == LM_GGML_TYPE_Q4_1 || src0t == LM_GGML_TYPE_Q5_0 ||
2230
+ src0t == LM_GGML_TYPE_Q5_1 || src0t == LM_GGML_TYPE_Q8_0 || src0t == LM_GGML_TYPE_Q2_K ||
2231
+ src0t == LM_GGML_TYPE_IQ1_S || src0t == LM_GGML_TYPE_IQ1_M || src0t == LM_GGML_TYPE_IQ2_S) {
2232
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2233
+ }
2234
+ else if (src0t == LM_GGML_TYPE_IQ2_XXS || src0t == LM_GGML_TYPE_IQ2_XS) {
2235
+ const int mem_size = src0t == LM_GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
2236
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2237
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2238
+ }
2239
+ else if (src0t == LM_GGML_TYPE_IQ3_XXS || src0t == LM_GGML_TYPE_IQ3_S) {
2240
+ const int mem_size = src0t == LM_GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
2241
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2242
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2243
+ }
2244
+ else if (src0t == LM_GGML_TYPE_IQ4_NL || src0t == LM_GGML_TYPE_IQ4_XS) {
2245
+ const int mem_size = 32*sizeof(float);
2246
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2247
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2248
+ }
2249
+ else if (src0t == LM_GGML_TYPE_Q4_K) {
2250
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2251
+ }
2252
+ else if (src0t == LM_GGML_TYPE_Q3_K) {
2253
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2254
+ }
2255
+ else if (src0t == LM_GGML_TYPE_Q5_K) {
2256
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2257
+ }
2258
+ else if (src0t == LM_GGML_TYPE_Q6_K) {
2259
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2260
+ } else {
2261
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
2262
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2263
+ }
2264
+ }
2049
2265
  } break;
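All three matrix-multiply paths (mat-mat above, mat-vec, and the indirect *_ID variants below) gain BF16 pipelines alongside F16. bfloat16 keeps float32's 8-bit exponent and truncates the mantissa to 7 bits, so widening to f32 is a 16-bit shift; a host-side sketch of the bit layout (not how the Metal kernels spell it):

    #include <stdint.h>
    #include <string.h>

    static float bf16_to_f32_ref(uint16_t h) {
        const uint32_t u = (uint32_t) h << 16;  // bf16 occupies the high half of an f32
        float f;
        memcpy(&f, &u, sizeof(f));              // bit copy, avoids aliasing UB
        return f;
    }

This is also why the alignment asserts treat BF16 like F16 (nb01 % 8 == 0): both are 2-byte element types feeding the simdgroup matrix loads.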
2050
2266
  case LM_GGML_OP_MUL_MAT_ID:
2051
2267
  {
@@ -2084,12 +2300,12 @@ static void lm_ggml_metal_encode_node(
2084
2300
  if ([device supportsFamily:MTLGPUFamilyApple7] &&
2085
2301
  ne00 % 32 == 0 && ne00 >= 64 &&
2086
2302
  dst_rows > dst_rows_min) {
2087
-
2088
2303
  // some Metal matrix data types require aligned pointers
2089
2304
  // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
2090
2305
  switch (src0->type) {
2091
- case LM_GGML_TYPE_F32: LM_GGML_ASSERT(nb01 % 16 == 0); break;
2092
- case LM_GGML_TYPE_F16: LM_GGML_ASSERT(nb01 % 8 == 0); break;
2306
+ case LM_GGML_TYPE_F32: LM_GGML_ASSERT(nb01 % 16 == 0); break;
2307
+ case LM_GGML_TYPE_F16: LM_GGML_ASSERT(nb01 % 8 == 0); break;
2308
+ case LM_GGML_TYPE_BF16: LM_GGML_ASSERT(nb01 % 8 == 0); break;
2093
2309
  default: break;
2094
2310
  }
2095
2311
 
@@ -2098,6 +2314,7 @@ static void lm_ggml_metal_encode_node(
2098
2314
  switch (src0->type) {
2099
2315
  case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
2100
2316
  case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
2317
+ case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break;
2101
2318
  case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
2102
2319
  case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
2103
2320
  case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
@@ -2120,27 +2337,30 @@ static void lm_ggml_metal_encode_node(
2120
2337
  default: LM_GGML_ABORT("MUL_MAT_ID not implemented");
2121
2338
  }
2122
2339
 
2340
+ lm_ggml_metal_kargs_mul_mm_id args = {
2341
+ /*.nei0 =*/ ne20,
2342
+ /*.nei1 =*/ ne21,
2343
+ /*.nbi1 =*/ nb21,
2344
+ /*.ne00 =*/ ne00,
2345
+ /*.ne02 =*/ ne02,
2346
+ /*.nb01 =*/ nb01,
2347
+ /*.nb02 =*/ nb02,
2348
+ /*.ne11 =*/ ne11,
2349
+ /*.ne12 =*/ ne12,
2350
+ /*.ne13 =*/ ne13,
2351
+ /*.nb10 =*/ nb10,
2352
+ /*.nb11 =*/ nb11,
2353
+ /*.nb12 =*/ nb12,
2354
+ /*.ne0 =*/ ne0,
2355
+ /*.ne1 =*/ ne1,
2356
+ };
2357
+
2123
2358
  [encoder setComputePipelineState:pipeline];
2124
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2125
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2126
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2127
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
2128
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
2129
- [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
2130
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
2131
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
2132
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
2133
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
2134
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
2135
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
2136
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
2137
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
2138
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
2139
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
2140
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
2141
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
2142
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
2143
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
2359
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
2360
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2361
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2362
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2363
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
2144
2364
 
2145
2365
  [encoder setThreadgroupMemoryLength:LM_GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
2146
2366
 
@@ -2167,6 +2387,13 @@ static void lm_ggml_metal_encode_node(
2167
2387
  nth1 = 1;
2168
2388
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
2169
2389
  } break;
2390
+ case LM_GGML_TYPE_BF16:
2391
+ {
2392
+ LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
2393
+ nth0 = 32;
2394
+ nth1 = 1;
2395
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
2396
+ } break;
2170
2397
  case LM_GGML_TYPE_Q4_0:
2171
2398
  {
2172
2399
  nth0 = 8;
@@ -2292,30 +2519,34 @@ static void lm_ggml_metal_encode_node(
2292
2519
  LM_GGML_ASSERT(ne00 >= nth0*nth1);
2293
2520
  }
2294
2521
 
2522
+ lm_ggml_metal_kargs_mul_mv_id args = {
2523
+ /*.nei0 =*/ ne20,
2524
+ /*.nei1 =*/ ne21,
2525
+ /*.nbi1 =*/ nb21,
2526
+ /*.ne00 =*/ ne00,
2527
+ /*.ne01 =*/ ne01,
2528
+ /*.ne02 =*/ ne02,
2529
+ /*.nb00 =*/ nb00,
2530
+ /*.nb01 =*/ nb01,
2531
+ /*.nb02 =*/ nb02,
2532
+ /*.ne10 =*/ ne10,
2533
+ /*.ne11 =*/ ne11,
2534
+ /*.ne12 =*/ ne12,
2535
+ /*.ne13 =*/ ne13,
2536
+ /*.nb10 =*/ nb10,
2537
+ /*.nb11 =*/ nb11,
2538
+ /*.nb12 =*/ nb12,
2539
+ /*.ne0 =*/ ne0,
2540
+ /*.ne1 =*/ ne1,
2541
+ /*.nb1 =*/ nb1,
2542
+ };
2543
+
2295
2544
  [encoder setComputePipelineState:pipeline];
2296
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2297
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2298
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2299
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
2300
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
2301
- [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
2302
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
2303
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
2304
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
2305
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
2306
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
2307
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
2308
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
2309
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
2310
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
2311
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
2312
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
2313
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
2314
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
2315
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
2316
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
2317
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
2318
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
2545
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
2546
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2547
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2548
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2549
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
2319
2550
 
2320
2551
  const int64_t _ne1 = 1;
2321
2552
  const int tgz = dst_rows;
@@ -2364,6 +2595,7 @@ static void lm_ggml_metal_encode_node(
2364
2595
  switch (src0->type) {
2365
2596
  case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
2366
2597
  case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
2598
+ case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline; break;
2367
2599
  case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
2368
2600
  case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
2369
2601
  case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
@@ -2387,6 +2619,7 @@ static void lm_ggml_metal_encode_node(
2387
2619
  default: LM_GGML_ABORT("not implemented");
2388
2620
  }
2389
2621
 
2622
+ // TODO: add lm_ggml_metal_kargs struct
2390
2623
  [encoder setComputePipelineState:pipeline];
2391
2624
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2392
2625
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2410,20 +2643,28 @@ static void lm_ggml_metal_encode_node(
2410
2643
  float eps;
2411
2644
  memcpy(&eps, dst->op_params, sizeof(float));
2412
2645
 
2646
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
2647
+
2413
2648
  int nth = 32; // SIMD width
2414
2649
 
2415
- while (nth < ne00/4 && nth < 1024) {
2650
+ while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
2416
2651
  nth *= 2;
2417
2652
  }
2418
2653
 
2419
- id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
2654
+ nth = MIN(nth, ne00/4);
2655
+
2656
+ lm_ggml_metal_kargs_rms_norm args = {
2657
+ /*.ne00 =*/ ne00,
2658
+ /*.ne00_4 =*/ ne00/4,
2659
+ /*.nb01 =*/ nb01,
2660
+ /*.eps =*/ eps,
2661
+ };
2420
2662
 
2421
2663
  [encoder setComputePipelineState:pipeline];
2422
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2423
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2424
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2425
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
2426
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
2664
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
2665
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2666
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2667
+
2427
2668
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2428
2669
 
2429
2670
  const int64_t nrows = lm_ggml_nrows(src0);
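RMS_NORM now sizes its threadgroup against the pipeline's actual limit rather than a hard-coded 1024, then clamps so no thread sits beyond the float4 row; the row length itself travels pre-divided as ne00_4. Worked through, assuming a typical 1024-thread pipeline limit:

    // ne00 = 4096: ne00/4 = 1024 -> nth doubles 32, 64, ..., 1024; clamp is a no-op
    // ne00 =  128: ne00/4 =   32 -> the loop never runs and nth stays 32
    // ne00 =  120: unsupported here, since the op check requires ne00 % 4 == 0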
@@ -2432,7 +2673,6 @@ static void lm_ggml_metal_encode_node(
2432
2673
  } break;
2433
2674
  case LM_GGML_OP_GROUP_NORM:
2434
2675
  {
2435
- LM_GGML_ASSERT(ne00 % 4 == 0);
2436
2676
  LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
2437
2677
 
2438
2678
  float eps;
@@ -2448,6 +2688,7 @@ static void lm_ggml_metal_encode_node(
2448
2688
 
2449
2689
  id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
2450
2690
 
2691
+ // TODO: add lm_ggml_metal_kargs struct
2451
2692
  [encoder setComputePipelineState:pipeline];
2452
2693
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2453
2694
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -2465,22 +2706,35 @@ static void lm_ggml_metal_encode_node(
  } break;
  case LM_GGML_OP_NORM:
  {
+ LM_GGML_ASSERT(ne00 % 4 == 0);
  LM_GGML_ASSERT(lm_ggml_is_contiguous_1(src0));

  float eps;
  memcpy(&eps, dst->op_params, sizeof(float));

- const int nth = MIN(256, ne00);
-
  id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_NORM].pipeline;

+ int nth = 32; // SIMD width
+
+ while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+ nth *= 2;
+ }
+
+ nth = MIN(nth, ne00/4);
+
+ lm_ggml_metal_kargs_norm args = {
+ /*.ne00 =*/ ne00,
+ /*.ne00_4 =*/ ne00/4,
+ /*.nb01 =*/ nb01,
+ /*.eps =*/ eps,
+ };
+
  [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
- [encoder setThreadgroupMemoryLength:LM_GGML_PAD(nth*sizeof(float), 16) atIndex:0];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

  const int64_t nrows = lm_ggml_nrows(src0);

@@ -2530,40 +2784,44 @@ static void lm_ggml_metal_encode_node(
  };
  }

+ lm_ggml_metal_kargs_rope args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ /*.n_past =*/ n_past,
+ /*.n_dims =*/ n_dims,
+ /*.n_ctx_orig =*/ n_ctx_orig,
+ /*.freq_base =*/ freq_base,
+ /*.freq_scale =*/ freq_scale,
+ /*.ext_factor =*/ ext_factor,
+ /*.attn_factor =*/ attn_factor,
+ /*.beta_fast =*/ beta_fast,
+ /*.beta_slow =*/ beta_slow,
+ };
+
  [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
  if (id_src2 != nil) {
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
  } else {
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
  }
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
- [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
- [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];

  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
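
ROPE is the clearest instance of the argument-packing refactor running through this file: 25 individual setBytes uploads collapse into a single lm_ggml_metal_kargs_rope blob at buffer index 0, and the tensor bindings shift up by one index. A sketch of the pattern in C, assuming a layout equivalent to the initializer above (the real struct definition lives in a header that does not appear in this diff):

    #include <stdint.h>
    #include <stdio.h>

    typedef struct {
        int32_t  ne00, ne01, ne02, ne03;   // src0 dims
        uint64_t nb00, nb01, nb02, nb03;   // src0 strides, in bytes
        int32_t  ne0, ne1, ne2, ne3;       // dst dims
        uint64_t nb0, nb1, nb2, nb3;       // dst strides, in bytes
        int32_t  n_past, n_dims, n_ctx_orig;
        float    freq_base, freq_scale, ext_factor,
                 attn_factor, beta_fast, beta_slow;
    } kargs_rope_sketch; // hypothetical stand-in for lm_ggml_metal_kargs_rope

    int main(void) {
        // one constant upload replaces 25:
        //   [encoder setBytes:&args length:sizeof(args) atIndex:0];
        // the Metal kernel then declares the matching struct at [[buffer(0)]],
        // so host and device agree on a single layout rather than 25
        // separate (index, type) pairs
        printf("args blob: %zu bytes\n", sizeof(kargs_rope_sketch));
        return 0;
    }
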
@@ -2620,6 +2878,7 @@ static void lm_ggml_metal_encode_node(
  default: LM_GGML_ABORT("fatal error");
  };

+ // TODO: add lm_ggml_metal_kargs struct
  [encoder setComputePipelineState:pipeline];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -2660,6 +2919,7 @@ static void lm_ggml_metal_encode_node(

  const id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;

+ // TODO: add lm_ggml_metal_kargs struct
  [encoder setComputePipelineState:pipeline];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -2694,6 +2954,7 @@ static void lm_ggml_metal_encode_node(

  id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;

+ // TODO: add lm_ggml_metal_kargs struct
  [encoder setComputePipelineState:pipeline];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -2730,6 +2991,7 @@ static void lm_ggml_metal_encode_node(

  id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;

+ // TODO: add lm_ggml_metal_kargs struct
  [encoder setComputePipelineState:pipeline];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
@@ -2751,6 +3013,7 @@ static void lm_ggml_metal_encode_node(

  id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;

+ // TODO: add lm_ggml_metal_kargs struct
  [encoder setComputePipelineState:pipeline];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -2789,6 +3052,7 @@ static void lm_ggml_metal_encode_node(
  default: LM_GGML_ABORT("fatal error");
  };

+ // TODO: add lm_ggml_metal_kargs struct
  [encoder setComputePipelineState:pipeline];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -2807,6 +3071,7 @@ static void lm_ggml_metal_encode_node(

  id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;

+ // TODO: add lm_ggml_metal_kargs struct
  [encoder setComputePipelineState:pipeline];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -2822,6 +3087,7 @@ static void lm_ggml_metal_encode_node(
  LM_GGML_ASSERT(ne11 % 32 == 0);

  LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
+ LM_GGML_ASSERT(src1->type == src2->type);

  LM_GGML_ASSERT(lm_ggml_are_same_shape (src1, src2));

@@ -2868,27 +3134,176 @@ static void lm_ggml_metal_encode_node(

  bool use_vec_kernel = false;

+ // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
+ // for now avoiding mainly to keep the number of templates/kernels a bit lower
  if (ne01 >= 4 || (ne00%128 != 0)) {
- switch (ne00) {
- case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
- //case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+ switch (src1->type) {
+ case LM_GGML_TYPE_F16:
+ {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ABORT("add template specialization for this size");
+ }
+ }
+ } break;
+ case LM_GGML_TYPE_BF16:
+ {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ABORT("add template specialization for this size");
+ }
+ }
+ } break;
+ case LM_GGML_TYPE_Q4_0:
+ {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ABORT("add template specialization for this size");
+ }
+ }
+ } break;
+ case LM_GGML_TYPE_Q4_1:
+ {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ABORT("add template specialization for this size");
+ }
+ }
+ } break;
+ case LM_GGML_TYPE_Q5_0:
+ {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ABORT("add template specialization for this size");
+ }
+ }
+ } break;
+ case LM_GGML_TYPE_Q5_1:
+ {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ABORT("add template specialization for this size");
+ }
+ }
+ } break;
+ case LM_GGML_TYPE_Q8_0:
+ {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ABORT("add template specialization for this size");
+ }
+ }
+ } break;
  default:
- {
- LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- LM_GGML_LOG_ERROR("add template specialization for this size\n");
- LM_GGML_ABORT("add template specialization for this size");
- }
+ {
+ LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+ LM_GGML_LOG_ERROR("add template specialization for this type\n");
+ LM_GGML_ABORT("add template specialization for this type");
+ }
  }
  } else {
  use_vec_kernel = true;

  switch (ne00) {
- case 128: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
- //case 256: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+ case 128:
+ {
+ switch (src1->type) {
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+ case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
+ case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
+ case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
+ case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
+ case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break;
+ case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+ LM_GGML_LOG_ERROR("add template specialization for this type\n");
+ LM_GGML_ABORT("add template specialization for this type");
+ }
+ }
+ } break;
+ case 256:
+ {
+ switch (src1->type) {
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+ case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
+ case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
+ case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
+ case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
+ case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break;
+ case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+ LM_GGML_LOG_ERROR("add template specialization for this type\n");
+ LM_GGML_ABORT("add template specialization for this type");
+ }
+ }
+ } break;
  default:
  {
  LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
@@ -2898,40 +3313,41 @@ static void lm_ggml_metal_encode_node(
  }
  }

+ lm_ggml_metal_kargs_flash_attn_ext args = {
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne11 =*/ ne11,
+ /*.ne_12_2 =*/ ne12,
+ /*.ne_12_3 =*/ ne13,
+ /*.nb_12_1 =*/ nb11,
+ /*.nb_12_2 =*/ nb12,
+ /*.nb_12_3 =*/ nb13,
+ /*.nb31 =*/ nb31,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.scale =*/ scale,
+ /*.max_bias =*/ max_bias,
+ /*.m0 =*/ m0,
+ /*.m1 =*/ m1,
+ /*.n_head_log2 =*/ n_head_log2,
+ /*.logit_softcap =*/ logit_softcap,
+ };
+
  [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
  if (id_src3) {
- [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
  } else {
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
  }
- [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
- [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
- [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
- [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
- [encoder setBytes:&scale length:sizeof( float) atIndex:23];
- [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
- [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:5];

  if (!use_vec_kernel) {
  // half8x8 kernel
@@ -2942,10 +3358,19 @@ static void lm_ggml_metal_encode_node(
  LM_GGML_ASSERT(nqptg % 8 == 0);
  LM_GGML_ASSERT(ncpsg % 32 == 0);

+ // 2*(2*ncpsg + nqptg)*(nsg)
+ // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
+ //
+ // 16*32*(nsg)
+ // the shared memory needed for the simdgroups to load the KV cache
+ // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
+ //
+ #define FATTN_SMEM(nsg) (LM_GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+
  int64_t nsgmax = 2;

  while (true) {
- const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
+ const size_t smem = FATTN_SMEM(nsgmax);
  if (smem > device.maxThreadgroupMemoryLength) {
  break;
  }
@@ -2956,16 +3381,15 @@ static void lm_ggml_metal_encode_node(
  // simdgroups per threadgroup (a.k.a. warps)
  const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;

- const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+ const size_t smem = FATTN_SMEM(nsg);

- //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+ //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
  LM_GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
-
- [encoder setThreadgroupMemoryLength:LM_GGML_PAD(smem, 16) atIndex:0];
-
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
+ #undef FATTN_SMEM
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
  } else {
- // half1x4 kernel
+ // half4x4 kernel
  const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
  const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!

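The new FATTN_SMEM macro adds a 16*32*(nsg) staging area on top of the old budget so each simdgroup can load (and dequantize) KV-cache data in shared memory, a prerequisite for the quantized-KV kernels selected above. A hedged re-derivation in C; the head size of 128 and the 32 KB threadgroup-memory limit are illustrative, not read from a real device:

    #include <stdio.h>
    #include <stddef.h>

    #define PAD(x, n) (((x) + (n) - 1) / (n) * (n))

    static size_t fattn_smem(int ne00, int nqptg, int ncpsg, int nsg) {
        // nqptg*ne00                  : the query block, stored as half
        // 2*(2*ncpsg + nqptg)*nsg     : soft_max values, mask values and a
        //                               diagonal scaling matrix, per query row
        // 16*32*nsg                   : KV-cache staging, 16 elems per thread,
        //                               32 threads per simdgroup
        return PAD((size_t)(nqptg*(ne00 + 2*(2*ncpsg + nqptg)*nsg) + 16*32*nsg)
                   * (sizeof(float)/2), 16); // sizeof(float)/2 = bytes per half
    }

    int main(void) {
        // pick the largest power-of-two nsg that still fits, as the diff does
        int nsgmax = 2;
        while (fattn_smem(128, 8, 32, nsgmax) <= 32768) nsgmax *= 2;
        nsgmax /= 2;
        printf("nsgmax = %d, smem = %zu bytes\n",
               nsgmax, fattn_smem(128, 8, 32, nsgmax)); // nsgmax = 8, 28672 bytes
        return 0;
    }
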
@@ -2973,8 +3397,28 @@ static void lm_ggml_metal_encode_node(
  LM_GGML_ASSERT(nqptg % 1 == 0);
  LM_GGML_ASSERT(ncpsg % 32 == 0);

+ // ne00 + 2*ncpsg*(nsg)
+ // for each query, we load it as f16 in shared memory (ne00)
+ // and store the soft_max values and the mask
+ //
+ // ne00*(nsg)
+ // each simdgroup has a full f16 head vector in shared mem to accumulate results
+ //
+ #define FATTN_SMEM(nsg) (LM_GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
+
+ int64_t nsgmax = 2;
+
+ while (true) {
+ const size_t smem = FATTN_SMEM(nsgmax);
+ if (smem > device.maxThreadgroupMemoryLength) {
+ break;
+ }
+ nsgmax *= 2;
+ }
+ nsgmax /= 2;
+
  // simdgroups per threadgroup (a.k.a. warps)
- const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+ const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));

  int64_t nsg = 1;
  while (nsg <= nsgt) {
@@ -2982,12 +3426,12 @@ static void lm_ggml_metal_encode_node(
  }
  nsg /= 2;

- const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
+ const size_t smem = FATTN_SMEM(nsg);

- //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+ //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
  LM_GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
- [encoder setThreadgroupMemoryLength:LM_GGML_PAD(smem, 16) atIndex:0];
-
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
+ #undef FATTN_SMEM
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
  } break;
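
The vec path gains the same power-of-two search, now bounded by its own FATTN_SMEM, which charges one full f16 head vector per simdgroup (ne00*(nsg)) for result accumulation. A small C sketch with the template's nqptg = 1, ncpsg = 32 and an assumed head size of 128:

    #include <stdio.h>
    #include <stddef.h>

    #define PAD(x, n) (((x) + (n) - 1) / (n) * (n))

    static size_t fattn_vec_smem(int ne00, int nqptg, int ncpsg, int nsg) {
        // per query: the f16 query itself (ne00) plus soft_max + mask slots
        // (2*ncpsg per simdgroup); plus one f16 accumulator head vector per
        // simdgroup (ne00*nsg)
        return PAD((size_t)(nqptg*(ne00 + 2*ncpsg*nsg) + ne00*nsg)
                   * (sizeof(float)/2), 16);
    }

    int main(void) {
        for (int nsg = 2; nsg <= 16; nsg *= 2) {
            printf("nsg = %2d -> %zu bytes\n", nsg, fattn_vec_smem(128, 1, 32, nsg));
        }
        // e.g. nsg = 4 -> 1792 bytes: well under a typical 32 KB limit, so the
        // ne11/ncpsg and maxTotalThreadsPerThreadgroup caps usually bind first
        return 0;
    }
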
@@ -3009,6 +3453,7 @@ static void lm_ggml_metal_encode_node(
  switch (dstt) {
  case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
  case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
+ case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
  case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
  case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
  case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
@@ -3026,28 +3471,40 @@ static void lm_ggml_metal_encode_node(
  default: LM_GGML_ABORT("not implemented");
  };
  } break;
+ case LM_GGML_TYPE_BF16:
+ {
+ switch (dstt) {
+ case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
+ case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
+ default: LM_GGML_ASSERT(false && "not implemented");
+ };
+ } break;
  default: LM_GGML_ABORT("not implemented");
  }

+ lm_ggml_metal_kargs_cpy args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ };
+
  [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];

  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
@@ -3092,6 +3549,7 @@ static void lm_ggml_metal_encode_node(
  const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
  const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;

+ // TODO: add lm_ggml_metal_kargs struct
  [encoder setComputePipelineState:pipeline];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -3279,6 +3737,12 @@ static void * lm_ggml_backend_metal_buffer_get_base(lm_ggml_backend_buffer_t buf
  return ctx->all_data;
  }

+ static void lm_ggml_backend_metal_buffer_memset_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+ memset((char *)tensor->data + offset, value, size);
+
+ UNUSED(buffer);
+ }
+
  static void lm_ggml_backend_metal_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
  memcpy((char *)tensor->data + offset, data, size);

@@ -3311,7 +3775,7 @@ static struct lm_ggml_backend_buffer_i lm_ggml_backend_metal_buffer_i = {
  /* .free_buffer = */ lm_ggml_backend_metal_buffer_free_buffer,
  /* .get_base = */ lm_ggml_backend_metal_buffer_get_base,
  /* .init_tensor = */ NULL,
- /* .memset_tensor = */ NULL,
+ /* .memset_tensor = */ lm_ggml_backend_metal_buffer_memset_tensor,
  /* .set_tensor = */ lm_ggml_backend_metal_buffer_set_tensor,
  /* .get_tensor = */ lm_ggml_backend_metal_buffer_get_tensor,
  /* .cpy_tensor = */ lm_ggml_backend_metal_buffer_cpy_tensor,
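
With memset_tensor wired into the buffer vtable above, filling a tensor region no longer needs to stage a pre-filled buffer through set_tensor. Because these Metal buffers are host-mapped unified memory, the callback is just a memset, as the new function shows. A C sketch of the semantics (the tensor_sketch type is a hypothetical stand-in; the generic lm_ggml_backend_tensor_memset entry point is assumed from the ggml-backend.h changes in this release, not shown here):

    #include <stdint.h>
    #include <string.h>
    #include <stddef.h>

    struct tensor_sketch { void * data; size_t nbytes; };

    // essentially what the new Metal callback does: a plain CPU memset on the
    // host-visible mapping, with no GPU fill command encoded
    static void memset_tensor_sketch(struct tensor_sketch * t, uint8_t value,
                                     size_t offset, size_t size) {
        memset((char *) t->data + offset, value, size);
    }

    int main(void) {
        char backing[64] = {0};
        struct tensor_sketch t = { backing, sizeof(backing) };
        memset_tensor_sketch(&t, 0xAA, 16, 32); // fill bytes [16, 48) in place
        return 0;
    }
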
@@ -3844,7 +4308,7 @@ static lm_ggml_backend_buffer_t lm_ggml_backend_metal_device_buffer_from_ptr(lm_
  }
  }

- return lm_ggml_backend_buffer_init(lm_ggml_backend_metal_buffer_type(), lm_ggml_backend_metal_buffer_i, ctx, size);
+ return lm_ggml_backend_buffer_init(lm_ggml_backend_metal_buffer_from_ptr_type(), lm_ggml_backend_metal_buffer_i, ctx, size);
  }

  static bool lm_ggml_backend_metal_device_supports_op(lm_ggml_backend_dev_t dev, const struct lm_ggml_tensor * op) {
@@ -3854,7 +4318,8 @@ static bool lm_ggml_backend_metal_device_supports_op(lm_ggml_backend_dev_t dev,
  }

  static bool lm_ggml_backend_metal_device_supports_buft(lm_ggml_backend_dev_t dev, lm_ggml_backend_buffer_type_t buft) {
- return buft->iface.get_name == lm_ggml_backend_metal_buffer_type_get_name;
+ return buft->iface.get_name == lm_ggml_backend_metal_buffer_type_get_name ||
+ buft->iface.get_name == lm_ggml_backend_metal_buffer_from_ptr_type_get_name;

  UNUSED(dev);
  }
@@ -3907,19 +4372,45 @@ static lm_ggml_backend_dev_t lm_ggml_backend_metal_reg_device_get(lm_ggml_backen
  LM_GGML_UNUSED(index);
  }

+ static struct lm_ggml_backend_feature g_lm_ggml_backend_metal_features[] = {
+ #if defined(LM_GGML_METAL_EMBED_LIBRARY)
+ { "EMBED_LIBRARY", "1" },
+ #endif
+ #if defined(LM_GGML_METAL_USE_BF16)
+ { "BF16", "1" },
+ #endif
+ { nil, nil },
+ };
+
+ static struct lm_ggml_backend_feature * lm_ggml_backend_metal_get_features(lm_ggml_backend_reg_t reg) {
+ return g_lm_ggml_backend_metal_features;
+
+ LM_GGML_UNUSED(reg);
+ }
+
+ static void * lm_ggml_backend_metal_get_proc_address(lm_ggml_backend_reg_t reg, const char * name) {
+ if (strcmp(name, "lm_ggml_backend_get_features") == 0) {
+ return (void *)lm_ggml_backend_metal_get_features;
+ }
+
+ return NULL;
+
+ LM_GGML_UNUSED(reg);
+ }
  static struct lm_ggml_backend_reg_i lm_ggml_backend_metal_reg_i = {
  /* .get_name = */ lm_ggml_backend_metal_reg_get_name,
  /* .device_count = */ lm_ggml_backend_metal_reg_device_count,
  /* .device_get = */ lm_ggml_backend_metal_reg_device_get,
- /* .get_proc_address = */ NULL,
+ /* .get_proc_address = */ lm_ggml_backend_metal_get_proc_address,
  };

  lm_ggml_backend_reg_t lm_ggml_backend_metal_reg(void) {
  // TODO: make this thread-safe somehow?
  {
  g_lm_ggml_backend_metal_reg = (struct lm_ggml_backend_reg) {
- /* .iface = */ lm_ggml_backend_metal_reg_i,
- /* .context = */ NULL,
+ /* .api_version = */ LM_GGML_BACKEND_API_VERSION,
+ /* .iface = */ lm_ggml_backend_metal_reg_i,
+ /* .context = */ NULL,
  };

  g_lm_ggml_backend_metal_device = (struct lm_ggml_backend_device) {
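
The registry's new get_proc_address hook publishes exactly one symbol, lm_ggml_backend_get_features, returning the nil-terminated table above, so loaders can probe for EMBED_LIBRARY or BF16 at runtime. A hedged C sketch of a consumer; the lookup plumbing below is a stand-in, and only the table shape and symbol name come from this hunk:

    #include <stdio.h>
    #include <string.h>

    struct feature { const char * name; const char * value; }; // mirrors lm_ggml_backend_feature
    typedef struct feature * (*get_features_fn)(void * reg);

    static int backend_has_feature(void * reg,
                                   void * (*get_proc_address)(void *, const char *),
                                   const char * feature_name) {
        get_features_fn get_features =
            (get_features_fn) get_proc_address(reg, "lm_ggml_backend_get_features");
        if (!get_features) {
            return 0; // backend predates the feature table
        }
        // the table is terminated by a { nil, nil } sentinel entry
        for (struct feature * f = get_features(reg); f->name != NULL; ++f) {
            if (strcmp(f->name, feature_name) == 0) {
                return 1;
            }
        }
        return 0;
    }

    // tiny self-test with a stand-in registry
    static struct feature g_demo_features[] = { { "BF16", "1" }, { NULL, NULL } };
    static struct feature * demo_get_features(void * reg) { (void) reg; return g_demo_features; }
    static void * demo_proc_address(void * reg, const char * name) {
        (void) reg;
        return strcmp(name, "lm_ggml_backend_get_features") == 0
                   ? (void *) demo_get_features : NULL;
    }

    int main(void) {
        printf("BF16: %d\n", backend_has_feature(NULL, demo_proc_address, "BF16")); // 1
        return 0;
    }
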
@@ -3931,3 +4422,5 @@ lm_ggml_backend_reg_t lm_ggml_backend_metal_reg(void) {

  return &g_lm_ggml_backend_metal_reg;
  }
+
+ LM_GGML_BACKEND_DL_IMPL(lm_ggml_backend_metal_reg)