whisper.rn 0.5.0-rc.8 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. package/cpp/ggml-alloc.c +1 -15
  2. package/cpp/ggml-backend-reg.cpp +17 -8
  3. package/cpp/ggml-backend.cpp +15 -22
  4. package/cpp/ggml-common.h +17 -0
  5. package/cpp/ggml-cpu/arch/arm/quants.c +132 -596
  6. package/cpp/ggml-cpu/arch/arm/repack.cpp +14 -286
  7. package/cpp/ggml-cpu/arch/x86/quants.c +184 -675
  8. package/cpp/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  9. package/cpp/ggml-cpu/arch-fallback.h +34 -0
  10. package/cpp/ggml-cpu/ggml-cpu.c +22 -1
  11. package/cpp/ggml-cpu/ggml-cpu.cpp +21 -24
  12. package/cpp/ggml-cpu/ops.cpp +870 -211
  13. package/cpp/ggml-cpu/ops.h +3 -8
  14. package/cpp/ggml-cpu/quants.c +35 -0
  15. package/cpp/ggml-cpu/quants.h +8 -0
  16. package/cpp/ggml-cpu/repack.cpp +458 -47
  17. package/cpp/ggml-cpu/repack.h +22 -0
  18. package/cpp/ggml-cpu/simd-mappings.h +1 -1
  19. package/cpp/ggml-cpu/traits.cpp +2 -2
  20. package/cpp/ggml-cpu/traits.h +1 -1
  21. package/cpp/ggml-cpu/vec.cpp +12 -9
  22. package/cpp/ggml-cpu/vec.h +107 -13
  23. package/cpp/ggml-impl.h +77 -0
  24. package/cpp/ggml-metal-impl.h +51 -12
  25. package/cpp/ggml-metal.m +610 -115
  26. package/cpp/ggml-opt.cpp +97 -41
  27. package/cpp/ggml-opt.h +25 -6
  28. package/cpp/ggml-quants.c +110 -16
  29. package/cpp/ggml-quants.h +6 -0
  30. package/cpp/ggml-whisper-sim.metallib +0 -0
  31. package/cpp/ggml-whisper.metallib +0 -0
  32. package/cpp/ggml.c +314 -88
  33. package/cpp/ggml.h +137 -11
  34. package/cpp/gguf.cpp +8 -1
  35. package/cpp/jsi/RNWhisperJSI.cpp +23 -6
  36. package/cpp/whisper.cpp +15 -6
  37. package/ios/RNWhisper.mm +6 -6
  38. package/ios/RNWhisperContext.mm +2 -0
  39. package/ios/RNWhisperVadContext.mm +2 -0
  40. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
  42. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
  43. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  44. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  45. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
  46. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  47. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  48. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
  49. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
  50. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
  51. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  52. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  53. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
  54. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  55. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  56. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
  57. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
  58. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
  59. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  60. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  61. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
  62. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  63. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  64. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
  65. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
  66. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
  67. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  68. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
  70. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  71. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  72. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +28 -2
  73. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  74. package/lib/module/realtime-transcription/RealtimeTranscriber.js +28 -2
  75. package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  76. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts +1 -0
  77. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
  78. package/lib/typescript/realtime-transcription/types.d.ts +6 -0
  79. package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
  80. package/package.json +1 -1
  81. package/src/realtime-transcription/RealtimeTranscriber.ts +32 -0
  82. package/src/realtime-transcription/types.ts +6 -0
package/cpp/ggml-metal.m CHANGED
@@ -55,6 +55,12 @@ static struct wsp_ggml_backend_metal_device_context {
55
55
  bool has_residency_sets;
56
56
  bool has_bfloat;
57
57
  bool use_bfloat;
58
+ bool use_fusion;
59
+
60
+ int debug_fusion;
61
+
62
+ // how many times a given op was fused
63
+ uint64_t fuse_cnt[WSP_GGML_OP_COUNT];
58
64
 
59
65
  size_t max_size;
60
66
 
@@ -69,6 +75,9 @@ static struct wsp_ggml_backend_metal_device_context {
69
75
  /*.has_residency_sets =*/ false,
70
76
  /*.has_bfloat =*/ false,
71
77
  /*.use_bfloat =*/ false,
78
+ /*.use_fusion =*/ true,
79
+ /*.debug_fusion =*/ 0,
80
+ /*.fuse_cnt =*/ { 0 },
72
81
  /*.max_size =*/ 0,
73
82
  /*.name =*/ "",
74
83
  };
@@ -83,16 +92,14 @@ static id<MTLDevice> wsp_ggml_backend_metal_device_acq(struct wsp_ggml_backend_m
83
92
 
84
93
  if (ctx->mtl_device == nil) {
85
94
  ctx->mtl_device = MTLCreateSystemDefaultDevice();
86
- }
87
95
 
88
- if (ctx->mtl_device) {
89
96
  ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
90
97
  ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
91
98
 
92
99
  ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
93
100
 
94
101
  #if defined(WSP_GGML_METAL_HAS_RESIDENCY_SETS)
95
- ctx->has_residency_sets = getenv("WSP_GGML_METAL_NO_RESIDENCY") == NULL;
102
+ ctx->has_residency_sets = getenv("WSP_GGML_METAL_NO_RESIDENCY") == nil;
96
103
  #endif
97
104
 
98
105
  ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
@@ -103,6 +110,14 @@ static id<MTLDevice> wsp_ggml_backend_metal_device_acq(struct wsp_ggml_backend_m
103
110
  #else
104
111
  ctx->use_bfloat = false;
105
112
  #endif
113
+ ctx->use_fusion = getenv("WSP_GGML_METAL_FUSION_DISABLE") == nil;
114
+
115
+ {
116
+ const char * val = getenv("WSP_GGML_METAL_FUSION_DEBUG");
117
+ ctx->debug_fusion = val ? atoi(val) : 0;
118
+ }
119
+
120
+ memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
106
121
 
107
122
  ctx->max_size = ctx->mtl_device.maxBufferLength;
108
123
 
@@ -122,6 +137,18 @@ static void wsp_ggml_backend_metal_device_rel(struct wsp_ggml_backend_metal_devi
122
137
  ctx->mtl_device_ref_count--;
123
138
 
124
139
  if (ctx->mtl_device_ref_count == 0) {
140
+ if (ctx->debug_fusion > 0) {
141
+ fprintf(stderr, "%s: fusion stats:\n", __func__);
142
+ for (int i = 0; i < WSP_GGML_OP_COUNT; i++) {
143
+ if (ctx->fuse_cnt[i] == 0) {
144
+ continue;
145
+ }
146
+
147
+ // note: cannot use wsp_ggml_log here
148
+ fprintf(stderr, "%s: - %s: %" PRIu64 "\n", __func__, wsp_ggml_op_name((enum wsp_ggml_op) i), ctx->fuse_cnt[i]);
149
+ }
150
+ }
151
+
125
152
  if (ctx->mtl_lock) {
126
153
  [ctx->mtl_lock release];
127
154
  ctx->mtl_lock = nil;
@@ -147,13 +174,28 @@ struct wsp_ggml_metal_kernel {
147
174
 
148
175
  enum wsp_ggml_metal_kernel_type {
149
176
  WSP_GGML_METAL_KERNEL_TYPE_ADD,
150
- WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW,
177
+ WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
178
+ WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
179
+ WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
180
+ WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
181
+ WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
182
+ WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
183
+ WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
184
+ WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
185
+ WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
186
+ WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
187
+ WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
188
+ WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
189
+ WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
190
+ WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
191
+ WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
151
192
  WSP_GGML_METAL_KERNEL_TYPE_SUB,
152
- WSP_GGML_METAL_KERNEL_TYPE_SUB_ROW,
193
+ WSP_GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
153
194
  WSP_GGML_METAL_KERNEL_TYPE_MUL,
154
- WSP_GGML_METAL_KERNEL_TYPE_MUL_ROW,
195
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
155
196
  WSP_GGML_METAL_KERNEL_TYPE_DIV,
156
- WSP_GGML_METAL_KERNEL_TYPE_DIV_ROW,
197
+ WSP_GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
198
+ WSP_GGML_METAL_KERNEL_TYPE_ADD_ID,
157
199
  WSP_GGML_METAL_KERNEL_TYPE_REPEAT_F32,
158
200
  WSP_GGML_METAL_KERNEL_TYPE_REPEAT_F16,
159
201
  WSP_GGML_METAL_KERNEL_TYPE_REPEAT_I32,
@@ -173,6 +215,12 @@ enum wsp_ggml_metal_kernel_type {
173
215
  WSP_GGML_METAL_KERNEL_TYPE_SILU,
174
216
  WSP_GGML_METAL_KERNEL_TYPE_SILU_4,
175
217
  WSP_GGML_METAL_KERNEL_TYPE_ELU,
218
+ WSP_GGML_METAL_KERNEL_TYPE_ABS,
219
+ WSP_GGML_METAL_KERNEL_TYPE_SGN,
220
+ WSP_GGML_METAL_KERNEL_TYPE_STEP,
221
+ WSP_GGML_METAL_KERNEL_TYPE_HARDSWISH,
222
+ WSP_GGML_METAL_KERNEL_TYPE_HARDSIGMOID,
223
+ WSP_GGML_METAL_KERNEL_TYPE_EXP,
176
224
  WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
177
225
  WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
178
226
  WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -187,6 +235,7 @@ enum wsp_ggml_metal_kernel_type {
187
235
  WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
188
236
  WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
189
237
  WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
238
+ WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4,
190
239
  WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
191
240
  WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
192
241
  WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
@@ -212,11 +261,14 @@ enum wsp_ggml_metal_kernel_type {
212
261
  WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
213
262
  WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
214
263
  WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM,
264
+ WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
265
+ WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
215
266
  WSP_GGML_METAL_KERNEL_TYPE_L2_NORM,
216
267
  WSP_GGML_METAL_KERNEL_TYPE_GROUP_NORM,
217
268
  WSP_GGML_METAL_KERNEL_TYPE_NORM,
218
269
  WSP_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
219
270
  WSP_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
271
+ WSP_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
220
272
  WSP_GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
221
273
  WSP_GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
222
274
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
@@ -236,6 +288,7 @@ enum wsp_ggml_metal_kernel_type {
236
288
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
237
289
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
238
290
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
291
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
239
292
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
240
293
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
241
294
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
@@ -260,6 +313,10 @@ enum wsp_ggml_metal_kernel_type {
260
313
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
261
314
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
262
315
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
316
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2,
317
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3,
318
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4,
319
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5,
263
320
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
264
321
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
265
322
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
@@ -301,6 +358,7 @@ enum wsp_ggml_metal_kernel_type {
301
358
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
302
359
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
303
360
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
361
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,
304
362
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
305
363
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
306
364
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
@@ -323,6 +381,7 @@ enum wsp_ggml_metal_kernel_type {
323
381
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
324
382
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
325
383
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
384
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,
326
385
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
327
386
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
328
387
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
@@ -347,6 +406,7 @@ enum wsp_ggml_metal_kernel_type {
347
406
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
348
407
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
349
408
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
409
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16,
350
410
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
351
411
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
352
412
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
@@ -529,6 +589,9 @@ enum wsp_ggml_metal_kernel_type {
529
589
  WSP_GGML_METAL_KERNEL_TYPE_REGLU,
530
590
  WSP_GGML_METAL_KERNEL_TYPE_GEGLU,
531
591
  WSP_GGML_METAL_KERNEL_TYPE_SWIGLU,
592
+ WSP_GGML_METAL_KERNEL_TYPE_SWIGLU_OAI,
593
+ WSP_GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
594
+ WSP_GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
532
595
  WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
533
596
  WSP_GGML_METAL_KERNEL_TYPE_MEAN,
534
597
  WSP_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
@@ -1130,13 +1193,28 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
1130
1193
  // simd_sum and simd_max requires MTLGPUFamilyApple7
1131
1194
 
1132
1195
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD, add, true);
1133
- WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
1196
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true);
1197
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true);
1198
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true);
1199
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true);
1200
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true);
1201
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true);
1202
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true);
1203
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true);
1204
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true);
1205
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true);
1206
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true);
1207
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true);
1208
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true);
1209
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true);
1210
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true);
1134
1211
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SUB, sub, true);
1135
- WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
1212
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true);
1136
1213
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL, mul, true);
1137
- WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
1214
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
1138
1215
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_DIV, div, true);
1139
- WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
1216
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
1217
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ID, add_id, true);
1140
1218
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
1141
1219
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
1142
1220
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
@@ -1156,6 +1234,12 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
1156
1234
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SILU, silu, true);
1157
1235
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
1158
1236
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ELU, elu, true);
1237
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ABS, abs, true);
1238
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SGN, sgn, true);
1239
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_STEP, step, true);
1240
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_HARDSWISH, hardswish, true);
1241
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_HARDSIGMOID, hardsigmoid, true);
1242
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_EXP, exp, true);
1159
1243
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
1160
1244
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
1161
1245
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
@@ -1170,6 +1254,7 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
1170
1254
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
1171
1255
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
1172
1256
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
1257
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true);
1173
1258
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
1174
1259
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
1175
1260
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
@@ -1195,11 +1280,14 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
1195
1280
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
1196
1281
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
1197
1282
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
1283
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
1284
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
1198
1285
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
1199
1286
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
1200
1287
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_NORM, norm, true);
1201
1288
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
1202
1289
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
1290
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
1203
1291
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
1204
1292
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
1205
1293
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
@@ -1219,6 +1307,7 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
1219
1307
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
1220
1308
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
1221
1309
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
1310
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
1222
1311
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
1223
1312
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
1224
1313
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
@@ -1243,6 +1332,10 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
1243
1332
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
1244
1333
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
1245
1334
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
1335
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2, mul_mv_ext_mxfp4_f32_r1_2, has_simdgroup_reduction);
1336
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3, mul_mv_ext_mxfp4_f32_r1_3, has_simdgroup_reduction);
1337
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4, mul_mv_ext_mxfp4_f32_r1_4, has_simdgroup_reduction);
1338
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5, mul_mv_ext_mxfp4_f32_r1_5, has_simdgroup_reduction);
1246
1339
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
1247
1340
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
1248
1341
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
@@ -1284,6 +1377,7 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
1284
1377
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
1285
1378
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
1286
1379
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
1380
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, has_simdgroup_reduction);
1287
1381
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
1288
1382
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
1289
1383
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
@@ -1306,6 +1400,8 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
1306
1400
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
1307
1401
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
1308
1402
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
1403
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
1404
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
1309
1405
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
1310
1406
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
1311
1407
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
@@ -1330,6 +1426,7 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
1330
1426
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
1331
1427
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm);
1332
1428
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm);
1429
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, mul_mm_id_mxfp4_f16, has_simdgroup_mm);
1333
1430
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm);
1334
1431
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm);
1335
1432
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm);
@@ -1512,6 +1609,9 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
1512
1609
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
1513
1610
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
1514
1611
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
1612
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SWIGLU_OAI, swiglu_oai, true);
1613
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
1614
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
1515
1615
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
1516
1616
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
1517
1617
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
@@ -1686,6 +1786,12 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
1686
1786
  case WSP_GGML_UNARY_OP_SILU:
1687
1787
  case WSP_GGML_UNARY_OP_ELU:
1688
1788
  case WSP_GGML_UNARY_OP_NEG:
1789
+ case WSP_GGML_UNARY_OP_ABS:
1790
+ case WSP_GGML_UNARY_OP_SGN:
1791
+ case WSP_GGML_UNARY_OP_STEP:
1792
+ case WSP_GGML_UNARY_OP_HARDSWISH:
1793
+ case WSP_GGML_UNARY_OP_HARDSIGMOID:
1794
+ case WSP_GGML_UNARY_OP_EXP:
1689
1795
  return wsp_ggml_is_contiguous(op->src[0]) && op->src[0]->type == WSP_GGML_TYPE_F32;
1690
1796
  default:
1691
1797
  return false;
@@ -1695,6 +1801,9 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
1695
1801
  case WSP_GGML_GLU_OP_REGLU:
1696
1802
  case WSP_GGML_GLU_OP_GEGLU:
1697
1803
  case WSP_GGML_GLU_OP_SWIGLU:
1804
+ case WSP_GGML_GLU_OP_SWIGLU_OAI:
1805
+ case WSP_GGML_GLU_OP_GEGLU_ERF:
1806
+ case WSP_GGML_GLU_OP_GEGLU_QUICK:
1698
1807
  return wsp_ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == WSP_GGML_TYPE_F32;
1699
1808
  default:
1700
1809
  return false;
@@ -1710,6 +1819,7 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
1710
1819
  case WSP_GGML_OP_SUB:
1711
1820
  case WSP_GGML_OP_MUL:
1712
1821
  case WSP_GGML_OP_DIV:
1822
+ case WSP_GGML_OP_ADD_ID:
1713
1823
  return op->src[0]->type == WSP_GGML_TYPE_F32;
1714
1824
  case WSP_GGML_OP_ACC:
1715
1825
  case WSP_GGML_OP_REPEAT:
@@ -1729,7 +1839,7 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
1729
1839
  case WSP_GGML_OP_MEAN:
1730
1840
  case WSP_GGML_OP_SOFT_MAX:
1731
1841
  case WSP_GGML_OP_GROUP_NORM:
1732
- return has_simdgroup_reduction && wsp_ggml_is_contiguous(op->src[0]);
1842
+ return has_simdgroup_reduction && wsp_ggml_is_contiguous_rows(op->src[0]);
1733
1843
  case WSP_GGML_OP_RMS_NORM:
1734
1844
  case WSP_GGML_OP_L2_NORM:
1735
1845
  return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && wsp_ggml_is_contiguous_1(op->src[0]));
@@ -1871,9 +1981,10 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
1871
1981
  }
1872
1982
  }
1873
1983
 
1874
- static bool wsp_ggml_metal_encode_node(
1984
+ static int wsp_ggml_metal_encode_node(
1875
1985
  wsp_ggml_backend_t backend,
1876
1986
  int idx,
1987
+ int idx_end,
1877
1988
  id<MTLComputeCommandEncoder> encoder,
1878
1989
  struct wsp_ggml_metal_mem_pool * mem_pool) {
1879
1990
  struct wsp_ggml_backend_metal_context * ctx = backend->context;
@@ -1881,7 +1992,10 @@ static bool wsp_ggml_metal_encode_node(
1881
1992
 
1882
1993
  struct wsp_ggml_cgraph * gf = ctx->gf;
1883
1994
 
1884
- struct wsp_ggml_tensor * node = wsp_ggml_graph_node(gf, idx);
1995
+ enum wsp_ggml_op ops[8];
1996
+
1997
+ struct wsp_ggml_tensor ** nodes = wsp_ggml_graph_nodes(gf) + idx;
1998
+ struct wsp_ggml_tensor * node = nodes[0];
1885
1999
 
1886
2000
  //WSP_GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, wsp_ggml_op_name(node->op));
1887
2001
 
@@ -1891,7 +2005,7 @@ static bool wsp_ggml_metal_encode_node(
1891
2005
  struct wsp_ggml_tensor * dst = node;
1892
2006
 
1893
2007
  if (wsp_ggml_is_empty(dst)) {
1894
- return true;
2008
+ return 1;
1895
2009
  }
1896
2010
 
1897
2011
  switch (dst->op) {
@@ -1902,7 +2016,7 @@ static bool wsp_ggml_metal_encode_node(
1902
2016
  case WSP_GGML_OP_PERMUTE:
1903
2017
  {
1904
2018
  // noop -> next node
1905
- } return true;
2019
+ } return 1;
1906
2020
  default:
1907
2021
  {
1908
2022
  } break;
@@ -1957,6 +2071,7 @@ static bool wsp_ggml_metal_encode_node(
1957
2071
 
1958
2072
  const enum wsp_ggml_type src0t = src0 ? src0->type : WSP_GGML_TYPE_COUNT;
1959
2073
  const enum wsp_ggml_type src1t = src1 ? src1->type : WSP_GGML_TYPE_COUNT;
2074
+ const enum wsp_ggml_type src2t = src2 ? src2->type : WSP_GGML_TYPE_COUNT;
1960
2075
  const enum wsp_ggml_type dstt = dst ? dst->type : WSP_GGML_TYPE_COUNT;
1961
2076
 
1962
2077
  size_t offs_src0 = 0;
@@ -1969,6 +2084,8 @@ static bool wsp_ggml_metal_encode_node(
1969
2084
  id<MTLBuffer> id_src2 = src2 ? wsp_ggml_metal_get_buffer(src2, &offs_src2) : nil;
1970
2085
  id<MTLBuffer> id_dst = dst ? wsp_ggml_metal_get_buffer(dst, &offs_dst) : nil;
1971
2086
 
2087
+ int n_fuse = 1;
2088
+
1972
2089
  #if 0
1973
2090
  WSP_GGML_LOG_INFO("%s: op - %s\n", __func__, wsp_ggml_op_name(dst->op));
1974
2091
  if (src0) {
@@ -2040,37 +2157,15 @@ static bool wsp_ggml_metal_encode_node(
2040
2157
  WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_F32);
2041
2158
  WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
2042
2159
 
2160
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(src0));
2161
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(src1));
2162
+
2043
2163
  const size_t offs = 0;
2044
2164
 
2045
2165
  bool bcast_row = false;
2046
2166
 
2047
2167
  id<MTLComputePipelineState> pipeline = nil;
2048
2168
 
2049
- if (wsp_ggml_nelements(src1) == ne10 && wsp_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
2050
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
2051
-
2052
- // src1 is a row
2053
- WSP_GGML_ASSERT(ne11 == 1);
2054
-
2055
- switch (dst->op) {
2056
- case WSP_GGML_OP_ADD: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
2057
- case WSP_GGML_OP_SUB: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
2058
- case WSP_GGML_OP_MUL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
2059
- case WSP_GGML_OP_DIV: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
2060
- default: WSP_GGML_ABORT("fatal error");
2061
- }
2062
-
2063
- bcast_row = true;
2064
- } else {
2065
- switch (dst->op) {
2066
- case WSP_GGML_OP_ADD: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
2067
- case WSP_GGML_OP_SUB: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
2068
- case WSP_GGML_OP_MUL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
2069
- case WSP_GGML_OP_DIV: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
2070
- default: WSP_GGML_ABORT("fatal error");
2071
- }
2072
- }
2073
-
2074
2169
  wsp_ggml_metal_kargs_bin args = {
2075
2170
  /*.ne00 =*/ ne00,
2076
2171
  /*.ne01 =*/ ne01,
@@ -2097,12 +2192,119 @@ static bool wsp_ggml_metal_encode_node(
2097
2192
  /*.nb2 =*/ nb2,
2098
2193
  /*.nb3 =*/ nb3,
2099
2194
  /*.offs =*/ offs,
2195
+ /*.o1 =*/ { offs_src1 },
2100
2196
  };
2101
2197
 
2198
+ // c[0] = add(a, b[0])
2199
+ // c[1] = add(c[0], b[1])
2200
+ // c[2] = add(c[1], b[2])
2201
+ // ...
2202
+ if (ctx_dev->use_fusion) {
2203
+ ops[0] = WSP_GGML_OP_ADD;
2204
+ ops[1] = WSP_GGML_OP_ADD;
2205
+ ops[2] = WSP_GGML_OP_ADD;
2206
+ ops[3] = WSP_GGML_OP_ADD;
2207
+ ops[4] = WSP_GGML_OP_ADD;
2208
+ ops[5] = WSP_GGML_OP_ADD;
2209
+ ops[6] = WSP_GGML_OP_ADD;
2210
+ ops[7] = WSP_GGML_OP_ADD;
2211
+
2212
+ size_t offs_fuse;
2213
+ id<MTLBuffer> id_fuse;
2214
+
2215
+ // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes
2216
+ // across splits. idx_end indicates the last node in the current split
2217
+ for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
2218
+ if (!wsp_ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
2219
+ break;
2220
+ }
2221
+
2222
+ if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
2223
+ break;
2224
+ }
2225
+
2226
+ // b[0] === b[1] === ...
2227
+ if (!wsp_ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) {
2228
+ break;
2229
+ }
2230
+
2231
+ // only fuse nodes if src1 is in the same Metal buffer
2232
+ id_fuse = wsp_ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse);
2233
+ if (id_fuse != id_src1) {
2234
+ break;
2235
+ }
2236
+
2237
+ ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
2238
+
2239
+ args.o1[n_fuse + 1] = offs_fuse;
2240
+ }
2241
+
2242
+ ++n_fuse;
2243
+
2244
+ if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
2245
+ WSP_GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
2246
+ }
2247
+ }
2248
+
2249
+ if (wsp_ggml_nelements(src1) == ne10 && wsp_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
2250
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
2251
+
2252
+ // src1 is a row
2253
+ WSP_GGML_ASSERT(ne11 == 1);
2254
+
2255
+ switch (dst->op) {
2256
+ case WSP_GGML_OP_ADD:
2257
+ {
2258
+ switch (n_fuse) {
2259
+ case 1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline; break;
2260
+ case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break;
2261
+ case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break;
2262
+ case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break;
2263
+ case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break;
2264
+ case 6: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break;
2265
+ case 7: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break;
2266
+ case 8: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break;
2267
+ default: WSP_GGML_ABORT("fatal error");
2268
+ }
2269
+ } break;
2270
+ case WSP_GGML_OP_SUB: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break;
2271
+ case WSP_GGML_OP_MUL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break;
2272
+ case WSP_GGML_OP_DIV: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break;
2273
+ default: WSP_GGML_ABORT("fatal error");
2274
+ }
2275
+
2276
+ bcast_row = true;
2277
+ } else {
2278
+ switch (dst->op) {
2279
+ case WSP_GGML_OP_ADD:
2280
+ {
2281
+ switch (n_fuse) {
2282
+ case 1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD ].pipeline; break;
2283
+ case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
2284
+ case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break;
2285
+ case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
2286
+ case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break;
2287
+ case 6: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
2288
+ case 7: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break;
2289
+ case 8: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
2290
+ default: WSP_GGML_ABORT("fatal error");
2291
+ }
2292
+ } break;
2293
+ case WSP_GGML_OP_SUB: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
2294
+ case WSP_GGML_OP_MUL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
2295
+ case WSP_GGML_OP_DIV: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
2296
+ default: WSP_GGML_ABORT("fatal error");
2297
+ }
2298
+ }
2299
+
2300
+ if (n_fuse > 1) {
2301
+ id_dst = wsp_ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
2302
+ }
2303
+
2102
2304
  [encoder setComputePipelineState:pipeline];
2103
2305
  [encoder setBytes:&args length:sizeof(args) atIndex:0];
2104
2306
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2105
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2307
+ [encoder setBuffer:id_src1 offset:0 atIndex:2];
2106
2308
  [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2107
2309
 
2108
2310
  if (bcast_row) {
@@ -2110,11 +2312,47 @@ static bool wsp_ggml_metal_encode_node(
2110
2312
 
2111
2313
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2112
2314
  } else {
2113
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
2315
+ int nth = 32;
2316
+
2317
+ while (16*nth < ne0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
2318
+ nth *= 2;
2319
+ }
2114
2320
 
2115
2321
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2116
2322
  }
2117
2323
  } break;
2324
+ case WSP_GGML_OP_ADD_ID:
2325
+ {
2326
+ WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_F32);
2327
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
2328
+ WSP_GGML_ASSERT(src2t == WSP_GGML_TYPE_I32);
2329
+ WSP_GGML_ASSERT(dstt == WSP_GGML_TYPE_F32);
2330
+
2331
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(src0));
2332
+
2333
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ID].pipeline;
2334
+
2335
+ wsp_ggml_metal_kargs_add_id args = {
2336
+ /*.ne0 =*/ ne0,
2337
+ /*.ne1 =*/ ne1,
2338
+ /*.nb01 =*/ nb01,
2339
+ /*.nb02 =*/ nb02,
2340
+ /*.nb11 =*/ nb11,
2341
+ /*.nb21 =*/ nb21,
2342
+
2343
+ };
2344
+
2345
+ [encoder setComputePipelineState:pipeline];
2346
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
2347
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2348
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2349
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
2350
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2351
+
2352
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
2353
+
2354
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2355
+ } break;
2118
2356
  case WSP_GGML_OP_REPEAT:
2119
2357
  {
2120
2358
  id<MTLComputePipelineState> pipeline;
@@ -2235,12 +2473,13 @@ static bool wsp_ggml_metal_encode_node(
2235
2473
  /*.nb2 =*/ pnb2,
2236
2474
  /*.nb3 =*/ pnb3,
2237
2475
  /*.offs =*/ offs,
2476
+ /*.o1 =*/ { offs_src1},
2238
2477
  };
2239
2478
 
2240
2479
  [encoder setComputePipelineState:pipeline];
2241
2480
  [encoder setBytes:&args length:sizeof(args) atIndex:0];
2242
2481
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2243
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2482
+ [encoder setBuffer:id_src1 offset:0 atIndex:2];
2244
2483
  [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2245
2484
 
2246
2485
  const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
@@ -2252,7 +2491,9 @@ static bool wsp_ggml_metal_encode_node(
2252
2491
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
2253
2492
 
2254
2493
  float scale;
2255
- memcpy(&scale, dst->op_params, sizeof(scale));
2494
+ float bias;
2495
+ memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float));
2496
+ memcpy(&bias, ((const int32_t *) dst->op_params) + 1, sizeof(float));
2256
2497
 
2257
2498
  int64_t n = wsp_ggml_nelements(dst);
2258
2499
 
@@ -2269,6 +2510,7 @@ static bool wsp_ggml_metal_encode_node(
2269
2510
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2270
2511
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2271
2512
  [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
2513
+ [encoder setBytes:&bias length:sizeof(bias) atIndex:3];
2272
2514
 
2273
2515
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2274
2516
  } break;
@@ -2432,6 +2674,78 @@ static bool wsp_ggml_metal_encode_node(
2432
2674
 
2433
2675
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2434
2676
  } break;
2677
+ case WSP_GGML_UNARY_OP_ABS:
2678
+ {
2679
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ABS].pipeline;
2680
+
2681
+ [encoder setComputePipelineState:pipeline];
2682
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2683
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2684
+
2685
+ const int64_t n = wsp_ggml_nelements(dst);
2686
+
2687
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2688
+ } break;
2689
+ case WSP_GGML_UNARY_OP_SGN:
2690
+ {
2691
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SGN].pipeline;
2692
+
2693
+ [encoder setComputePipelineState:pipeline];
2694
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2695
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2696
+
2697
+ const int64_t n = wsp_ggml_nelements(dst);
2698
+
2699
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2700
+ } break;
2701
+ case WSP_GGML_UNARY_OP_STEP:
2702
+ {
2703
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_STEP].pipeline;
2704
+
2705
+ [encoder setComputePipelineState:pipeline];
2706
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2707
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2708
+
2709
+ const int64_t n = wsp_ggml_nelements(dst);
2710
+
2711
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2712
+ } break;
2713
+ case WSP_GGML_UNARY_OP_HARDSWISH:
2714
+ {
2715
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline;
2716
+
2717
+ [encoder setComputePipelineState:pipeline];
2718
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2719
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2720
+
2721
+ const int64_t n = wsp_ggml_nelements(dst);
2722
+
2723
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2724
+ } break;
2725
+ case WSP_GGML_UNARY_OP_HARDSIGMOID:
2726
+ {
2727
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline;
2728
+
2729
+ [encoder setComputePipelineState:pipeline];
2730
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2731
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2732
+
2733
+ const int64_t n = wsp_ggml_nelements(dst);
2734
+
2735
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2736
+ } break;
2737
+ case WSP_GGML_UNARY_OP_EXP:
2738
+ {
2739
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_EXP].pipeline;
2740
+
2741
+ [encoder setComputePipelineState:pipeline];
2742
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2743
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2744
+
2745
+ const int64_t n = wsp_ggml_nelements(dst);
2746
+
2747
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2748
+ } break;
2435
2749
  default:
2436
2750
  {
2437
2751
  WSP_GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, wsp_ggml_op_name(dst->op));
@@ -2458,11 +2772,22 @@ static bool wsp_ggml_metal_encode_node(
2458
2772
  case WSP_GGML_GLU_OP_SWIGLU:
2459
2773
  pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
2460
2774
  break;
2775
+ case WSP_GGML_GLU_OP_SWIGLU_OAI:
2776
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SWIGLU_OAI].pipeline;
2777
+ break;
2778
+ case WSP_GGML_GLU_OP_GEGLU_ERF:
2779
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
2780
+ break;
2781
+ case WSP_GGML_GLU_OP_GEGLU_QUICK:
2782
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline;
2783
+ break;
2461
2784
  default:
2462
2785
  WSP_GGML_ABORT("fatal error");
2463
2786
  }
2464
2787
 
2465
- const int32_t swp = ((const int32_t *) dst->op_params)[1];
2788
+ const int32_t swp = wsp_ggml_get_op_params_i32(dst, 1);
2789
+ const float alpha = wsp_ggml_get_op_params_f32(dst, 2);
2790
+ const float limit = wsp_ggml_get_op_params_f32(dst, 3);
2466
2791
 
2467
2792
  const int32_t i00 = swp ? ne0 : 0;
2468
2793
  const int32_t i10 = swp ? 0 : ne0;
@@ -2476,6 +2801,8 @@ static bool wsp_ggml_metal_encode_node(
2476
2801
  /*.nb1 =*/ nb1,
2477
2802
  /*.i00 =*/ src1 ? 0 : i00,
2478
2803
  /*.i10 =*/ src1 ? 0 : i10,
2804
+ /*.alpha=*/ alpha,
2805
+ /*.limit=*/ limit
2479
2806
  };
2480
2807
 
2481
2808
  [encoder setComputePipelineState:pipeline];
@@ -2648,10 +2975,7 @@ static bool wsp_ggml_metal_encode_node(
2648
2975
  memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
2649
2976
  memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
2650
2977
 
2651
- const int64_t nrows_x = wsp_ggml_nrows(src0);
2652
- const int64_t nrows_y = src0->ne[1];
2653
-
2654
- const uint32_t n_head = nrows_x/nrows_y;
2978
+ const uint32_t n_head = src0->ne[2];
2655
2979
  const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
2656
2980
 
2657
2981
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -2664,7 +2988,7 @@ static bool wsp_ggml_metal_encode_node(
2664
2988
  id<MTLBuffer> h_src0 = h_src0 = wsp_ggml_metal_mem_pool_alloc(mem_pool, wsp_ggml_nbytes(src0));
2665
2989
  if (!h_src0) {
2666
2990
  WSP_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, wsp_ggml_nbytes(src0));
2667
- return false;
2991
+ return 0;
2668
2992
  }
2669
2993
 
2670
2994
  offs_src0 = 0;
@@ -2711,6 +3035,18 @@ static bool wsp_ggml_metal_encode_node(
2711
3035
  /*.ne00 =*/ ne00,
2712
3036
  /*.ne01 =*/ ne01,
2713
3037
  /*.ne02 =*/ ne02,
3038
+ /*.nb01 =*/ nb01,
3039
+ /*.nb02 =*/ nb02,
3040
+ /*.nb03 =*/ nb03,
3041
+ /*.ne11 =*/ ne11,
3042
+ /*.ne12 =*/ ne12,
3043
+ /*.ne13 =*/ ne13,
3044
+ /*.nb11 =*/ nb11,
3045
+ /*.nb12 =*/ nb12,
3046
+ /*.nb13 =*/ nb13,
3047
+ /*.nb1 =*/ nb1,
3048
+ /*.nb2 =*/ nb2,
3049
+ /*.nb3 =*/ nb3,
2714
3050
  /*.scale =*/ scale,
2715
3051
  /*.max_bias =*/ max_bias,
2716
3052
  /*.m0 =*/ m0,
@@ -2725,12 +3061,17 @@ static bool wsp_ggml_metal_encode_node(
2725
3061
  } else {
2726
3062
  [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
2727
3063
  }
2728
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2729
- [encoder setBytes:&args length:sizeof(args) atIndex:3];
3064
+ if (id_src2) {
3065
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
3066
+ } else {
3067
+ [encoder setBuffer:h_src0 offset:offs_src0 atIndex:2];
3068
+ }
3069
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3070
+ [encoder setBytes:&args length:sizeof(args) atIndex:4];
2730
3071
 
2731
3072
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2732
3073
 
2733
- [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3074
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2734
3075
  } break;
2735
3076
  case WSP_GGML_OP_DIAG_MASK_INF:
2736
3077
  {
@@ -2804,71 +3145,92 @@ static bool wsp_ggml_metal_encode_node(
2804
3145
  struct wsp_ggml_tensor * src3 = node->src[3];
2805
3146
  struct wsp_ggml_tensor * src4 = node->src[4];
2806
3147
  struct wsp_ggml_tensor * src5 = node->src[5];
3148
+ struct wsp_ggml_tensor * src6 = node->src[6];
2807
3149
 
2808
3150
  WSP_GGML_ASSERT(src3);
2809
3151
  WSP_GGML_ASSERT(src4);
2810
3152
  WSP_GGML_ASSERT(src5);
3153
+ WSP_GGML_ASSERT(src6);
2811
3154
 
2812
3155
  size_t offs_src3 = 0;
2813
3156
  size_t offs_src4 = 0;
2814
3157
  size_t offs_src5 = 0;
3158
+ size_t offs_src6 = 0;
2815
3159
 
2816
3160
  id<MTLBuffer> id_src3 = src3 ? wsp_ggml_metal_get_buffer(src3, &offs_src3) : nil;
2817
3161
  id<MTLBuffer> id_src4 = src4 ? wsp_ggml_metal_get_buffer(src4, &offs_src4) : nil;
2818
3162
  id<MTLBuffer> id_src5 = src5 ? wsp_ggml_metal_get_buffer(src5, &offs_src5) : nil;
3163
+ id<MTLBuffer> id_src6 = src6 ? wsp_ggml_metal_get_buffer(src6, &offs_src6) : nil;
2819
3164
 
2820
- const int64_t ne30 = src3->ne[0]; WSP_GGML_UNUSED(ne30);
3165
+ const int64_t ne30 = src3->ne[0];
2821
3166
  const int64_t ne31 = src3->ne[1]; WSP_GGML_UNUSED(ne31);
2822
3167
 
2823
- const uint64_t nb30 = src3->nb[0];
3168
+ const uint64_t nb30 = src3->nb[0]; WSP_GGML_UNUSED(nb30);
2824
3169
  const uint64_t nb31 = src3->nb[1];
2825
3170
 
2826
3171
  const int64_t ne40 = src4->ne[0]; WSP_GGML_UNUSED(ne40);
2827
- const int64_t ne41 = src4->ne[1]; WSP_GGML_UNUSED(ne41);
3172
+ const int64_t ne41 = src4->ne[1];
2828
3173
  const int64_t ne42 = src4->ne[2]; WSP_GGML_UNUSED(ne42);
3174
+ const int64_t ne43 = src4->ne[3]; WSP_GGML_UNUSED(ne43);
2829
3175
 
2830
- const uint64_t nb40 = src4->nb[0];
3176
+ const uint64_t nb40 = src4->nb[0]; WSP_GGML_UNUSED(nb40);
2831
3177
  const uint64_t nb41 = src4->nb[1];
2832
3178
  const uint64_t nb42 = src4->nb[2];
3179
+ const uint64_t nb43 = src4->nb[3];
2833
3180
 
2834
3181
  const int64_t ne50 = src5->ne[0]; WSP_GGML_UNUSED(ne50);
2835
3182
  const int64_t ne51 = src5->ne[1]; WSP_GGML_UNUSED(ne51);
2836
3183
  const int64_t ne52 = src5->ne[2]; WSP_GGML_UNUSED(ne52);
3184
+ const int64_t ne53 = src5->ne[3]; WSP_GGML_UNUSED(ne53);
2837
3185
 
2838
- const uint64_t nb50 = src5->nb[0];
3186
+ const uint64_t nb50 = src5->nb[0]; WSP_GGML_UNUSED(nb50);
2839
3187
  const uint64_t nb51 = src5->nb[1];
2840
3188
  const uint64_t nb52 = src5->nb[2];
3189
+ const uint64_t nb53 = src5->nb[3];
3190
+
3191
+ const int64_t ne60 = src6->ne[0]; WSP_GGML_UNUSED(ne60);
3192
+
3193
+ const uint64_t nb60 = src6->nb[0]; WSP_GGML_UNUSED(nb60);
2841
3194
 
2842
3195
  const int64_t d_state = ne00;
2843
3196
  const int64_t d_inner = ne01;
2844
- const int64_t n_seq_tokens = ne11;
2845
- const int64_t n_seqs = ne02;
3197
+ const int64_t n_head = ne02;
3198
+ const int64_t n_group = ne41;
3199
+ const int64_t n_seq_tokens = ne12;
3200
+ const int64_t n_seqs = ne13;
2846
3201
 
2847
- id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
3202
+ id<MTLComputePipelineState> pipeline = nil;
3203
+
3204
+ if (ne30 == 1) {
3205
+ // Mamba-2
3206
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
3207
+ } else {
3208
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
3209
+ }
2848
3210
 
2849
3211
  wsp_ggml_metal_kargs_ssm_scan args = {
2850
- /*.d_state =*/ d_state,
2851
- /*.d_inner =*/ d_inner,
3212
+ /*.d_state =*/ d_state,
3213
+ /*.d_inner =*/ d_inner,
3214
+ /*.n_head =*/ n_head,
3215
+ /*.n_group =*/ n_group,
2852
3216
  /*.n_seq_tokens =*/ n_seq_tokens,
2853
- /*.n_seqs =*/ n_seqs,
2854
- /*.nb00 =*/ nb00,
2855
- /*.nb01 =*/ nb01,
2856
- /*.nb02 =*/ nb02,
2857
- /*.nb10 =*/ nb10,
2858
- /*.nb11 =*/ nb11,
2859
- /*.nb12 =*/ nb12,
2860
- /*.nb13 =*/ nb13,
2861
- /*.nb20 =*/ nb20,
2862
- /*.nb21 =*/ nb21,
2863
- /*.nb22 =*/ nb22,
2864
- /*.nb30 =*/ nb30,
2865
- /*.nb31 =*/ nb31,
2866
- /*.nb40 =*/ nb40,
2867
- /*.nb41 =*/ nb41,
2868
- /*.nb42 =*/ nb42,
2869
- /*.nb50 =*/ nb50,
2870
- /*.nb51 =*/ nb51,
2871
- /*.nb52 =*/ nb52,
3217
+ /*.n_seqs =*/ n_seqs,
3218
+ /*.s_off =*/ wsp_ggml_nelements(src1) * sizeof(float),
3219
+ /*.nb01 =*/ nb01,
3220
+ /*.nb02 =*/ nb02,
3221
+ /*.nb03 =*/ nb03,
3222
+ /*.nb11 =*/ nb11,
3223
+ /*.nb12 =*/ nb12,
3224
+ /*.nb13 =*/ nb13,
3225
+ /*.nb21 =*/ nb21,
3226
+ /*.nb22 =*/ nb22,
3227
+ /*.nb31 =*/ nb31,
3228
+ /*.nb41 =*/ nb41,
3229
+ /*.nb42 =*/ nb42,
3230
+ /*.nb43 =*/ nb43,
3231
+ /*.nb51 =*/ nb51,
3232
+ /*.nb52 =*/ nb52,
3233
+ /*.nb53 =*/ nb53,
2872
3234
  };
2873
3235
 
2874
3236
  [encoder setComputePipelineState:pipeline];
@@ -2878,10 +3240,27 @@ static bool wsp_ggml_metal_encode_node(
2878
3240
  [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2879
3241
  [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
2880
3242
  [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
2881
- [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
2882
- [encoder setBytes:&args length:sizeof(args) atIndex:7];
3243
+ [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
3244
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
3245
+ [encoder setBytes:&args length:sizeof(args) atIndex:8];
3246
+
3247
+ // One shared memory bucket for each simd group in the threadgroup
3248
+ // NOTE: Metal kernels require the buffer size to be multiple of 16 bytes
3249
+ // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
3250
+ if (d_state >= 32) {
3251
+ WSP_GGML_ASSERT((int64_t)(d_state / 32) <= 32);
3252
+ const int64_t shmem_size = 32;
3253
+ WSP_GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
3254
+ [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
3255
+ }
2883
3256
 
2884
- [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
3257
+ if (ne30 == 1) {
3258
+ // Mamba-2
3259
+ [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
3260
+ } else {
3261
+ WSP_GGML_ASSERT(d_inner == 1);
3262
+ [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
3263
+ }
2885
3264
  } break;
2886
3265
  case WSP_GGML_OP_RWKV_WKV6:
2887
3266
  {
@@ -2986,6 +3365,7 @@ static bool wsp_ggml_metal_encode_node(
2986
3365
  src0t == WSP_GGML_TYPE_Q5_0 ||
2987
3366
  src0t == WSP_GGML_TYPE_Q5_1 ||
2988
3367
  src0t == WSP_GGML_TYPE_Q8_0 ||
3368
+ src0t == WSP_GGML_TYPE_MXFP4 ||
2989
3369
  src0t == WSP_GGML_TYPE_IQ4_NL ||
2990
3370
  false) && (ne11 >= 2 && ne11 <= 8)
2991
3371
  ) ||
@@ -3078,6 +3458,14 @@ static bool wsp_ggml_metal_encode_node(
3078
3458
  case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
3079
3459
  default: WSP_GGML_ABORT("not implemented");
3080
3460
  } break;
3461
+ case WSP_GGML_TYPE_MXFP4:
3462
+ switch (r1ptg) {
3463
+ case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2].pipeline; break;
3464
+ case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3].pipeline; break;
3465
+ case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4].pipeline; break;
3466
+ case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5].pipeline; break;
3467
+ default: WSP_GGML_ABORT("not implemented");
3468
+ } break;
3081
3469
  case WSP_GGML_TYPE_Q4_K:
3082
3470
  switch (r1ptg) {
3083
3471
  case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
@@ -3176,6 +3564,7 @@ static bool wsp_ggml_metal_encode_node(
3176
3564
  case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
3177
3565
  case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
3178
3566
  case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
3567
+ case WSP_GGML_TYPE_MXFP4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
3179
3568
  case WSP_GGML_TYPE_Q2_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
3180
3569
  case WSP_GGML_TYPE_Q3_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
3181
3570
  case WSP_GGML_TYPE_Q4_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
@@ -3318,6 +3707,13 @@ static bool wsp_ggml_metal_encode_node(
3318
3707
  nr0 = N_R0_Q8_0;
3319
3708
  pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
3320
3709
  } break;
3710
+ case WSP_GGML_TYPE_MXFP4:
3711
+ {
3712
+ nsg = N_SG_MXFP4;
3713
+ nr0 = N_R0_MXFP4;
3714
+ smem = 32*sizeof(float);
3715
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline;
3716
+ } break;
3321
3717
  case WSP_GGML_TYPE_Q2_K:
3322
3718
  {
3323
3719
  nsg = N_SG_Q2_K;
@@ -3451,8 +3847,6 @@ static bool wsp_ggml_metal_encode_node(
3451
3847
  case WSP_GGML_OP_MUL_MAT_ID:
3452
3848
  {
3453
3849
  // src2 = ids
3454
- const enum wsp_ggml_type src2t = src2->type; WSP_GGML_UNUSED(src2t);
3455
-
3456
3850
  WSP_GGML_ASSERT(src2t == WSP_GGML_TYPE_I32);
3457
3851
 
3458
3852
  WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src0));
@@ -3501,7 +3895,7 @@ static bool wsp_ggml_metal_encode_node(
3501
3895
  id<MTLBuffer> h_src1 = wsp_ggml_metal_mem_pool_alloc(mem_pool, s_src1);
3502
3896
  if (!h_src1) {
3503
3897
  WSP_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
3504
- return false;
3898
+ return 0;
3505
3899
  }
3506
3900
 
3507
3901
  const int64_t neh0 = ne0;
@@ -3517,7 +3911,7 @@ static bool wsp_ggml_metal_encode_node(
3517
3911
  id<MTLBuffer> h_dst = wsp_ggml_metal_mem_pool_alloc(mem_pool, s_dst);
3518
3912
  if (!h_dst) {
3519
3913
  WSP_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
3520
- return false;
3914
+ return 0;
3521
3915
  }
3522
3916
 
3523
3917
  // tokens per expert
@@ -3525,7 +3919,7 @@ static bool wsp_ggml_metal_encode_node(
3525
3919
  id<MTLBuffer> h_tpe = wsp_ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
3526
3920
  if (!h_tpe) {
3527
3921
  WSP_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
3528
- return false;
3922
+ return 0;
3529
3923
  }
3530
3924
 
3531
3925
  // id map
@@ -3534,7 +3928,7 @@ static bool wsp_ggml_metal_encode_node(
3534
3928
  id<MTLBuffer> h_ids = wsp_ggml_metal_mem_pool_alloc(mem_pool, s_ids);
3535
3929
  if (!h_ids) {
3536
3930
  WSP_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
3537
- return false;
3931
+ return 0;
3538
3932
  }
3539
3933
 
3540
3934
  {
@@ -3578,6 +3972,7 @@ static bool wsp_ggml_metal_encode_node(
3578
3972
  case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
3579
3973
  case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
3580
3974
  case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
3975
+ case WSP_GGML_TYPE_MXFP4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16 ].pipeline; break;
3581
3976
  case WSP_GGML_TYPE_Q2_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
3582
3977
  case WSP_GGML_TYPE_Q3_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
3583
3978
  case WSP_GGML_TYPE_Q4_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;
@@ -3713,6 +4108,13 @@ static bool wsp_ggml_metal_encode_node(
3713
4108
  nr0 = N_R0_Q8_0;
3714
4109
  pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
3715
4110
  } break;
4111
+ case WSP_GGML_TYPE_MXFP4:
4112
+ {
4113
+ nsg = N_SG_MXFP4;
4114
+ nr0 = N_R0_MXFP4;
4115
+ smem = 32*sizeof(float);
4116
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline;
4117
+ } break;
3716
4118
  case WSP_GGML_TYPE_Q2_K:
3717
4119
  {
3718
4120
  nsg = N_SG_Q2_K;
@@ -3865,6 +4267,7 @@ static bool wsp_ggml_metal_encode_node(
3865
4267
  case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
3866
4268
  case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
3867
4269
  case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
4270
+ case WSP_GGML_TYPE_MXFP4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4 ].pipeline; break;
3868
4271
  case WSP_GGML_TYPE_Q2_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
3869
4272
  case WSP_GGML_TYPE_Q3_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
3870
4273
  case WSP_GGML_TYPE_Q4_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
@@ -3966,12 +4369,95 @@ static bool wsp_ggml_metal_encode_node(
3966
4369
  case WSP_GGML_OP_RMS_NORM:
3967
4370
  {
3968
4371
  WSP_GGML_ASSERT(ne00 % 4 == 0);
3969
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
4372
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(src0));
3970
4373
 
3971
4374
  float eps;
3972
4375
  memcpy(&eps, dst->op_params, sizeof(float));
3973
4376
 
3974
- id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
4377
+ wsp_ggml_metal_kargs_rms_norm args = {
4378
+ /*.ne00 =*/ ne00,
4379
+ /*.ne00_4 =*/ ne00/4,
4380
+ /*.nb1 =*/ nb1,
4381
+ /*.nb2 =*/ nb2,
4382
+ /*.nb3 =*/ nb3,
4383
+ /*.eps =*/ eps,
4384
+ /*.nef1 =*/ { ne01 },
4385
+ /*.nef2 =*/ { ne02 },
4386
+ /*.nef3 =*/ { ne03 },
4387
+ /*.nbf1 =*/ { nb01 },
4388
+ /*.nbf2 =*/ { nb02 },
4389
+ /*.nbf3 =*/ { nb03 },
4390
+ };
4391
+
4392
+ size_t offs_fuse[2] = { 0, 0 };
4393
+ id<MTLBuffer> id_fuse[2] = { id_src0, id_src0 };
4394
+
4395
+ // d[0] = rms_norm(a)
4396
+ // d[1] = mul(d[0], b)
4397
+ // d[2] = add(d[1], c)
4398
+ if (ctx_dev->use_fusion) {
4399
+ ops[0] = WSP_GGML_OP_RMS_NORM;
4400
+ ops[1] = WSP_GGML_OP_MUL;
4401
+ ops[2] = WSP_GGML_OP_ADD;
4402
+
4403
+ for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
4404
+ if (!wsp_ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
4405
+ break;
4406
+ }
4407
+
4408
+ if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
4409
+ break;
4410
+ }
4411
+
4412
+ if (nodes[n_fuse + 1]->src[1]->ne[0] != node->ne[0]) {
4413
+ break;
4414
+ }
4415
+
4416
+ if (!wsp_ggml_is_contiguous_rows(nodes[n_fuse + 1]->src[1])) {
4417
+ break;
4418
+ }
4419
+
4420
+ if (nodes[n_fuse + 1]->type != WSP_GGML_TYPE_F32) {
4421
+ break;
4422
+ }
4423
+
4424
+ ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
4425
+
4426
+ id_fuse[n_fuse] = wsp_ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]);
4427
+
4428
+ args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1];
4429
+ args.nef2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[2];
4430
+ args.nef3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[3];
4431
+
4432
+ args.nbf1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[1];
4433
+ args.nbf2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[2];
4434
+ args.nbf3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[3];
4435
+ }
4436
+
4437
+ ++n_fuse;
4438
+
4439
+ if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
4440
+ if (n_fuse == 2) {
4441
+ WSP_GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
4442
+ }
4443
+ if (n_fuse == 3) {
4444
+ WSP_GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
4445
+ }
4446
+ }
4447
+ }
4448
+
4449
+ if (n_fuse > 1) {
4450
+ id_dst = wsp_ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
4451
+ }
4452
+
4453
+ id<MTLComputePipelineState> pipeline;
4454
+
4455
+ switch (n_fuse) {
4456
+ case 1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline; break;
4457
+ case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline; break;
4458
+ case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break;
4459
+ default: WSP_GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
4460
+ }
3975
4461
 
3976
4462
  int nth = 32; // SIMD width
3977
4463
 
@@ -3982,23 +4468,16 @@ static bool wsp_ggml_metal_encode_node(
3982
4468
  nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3983
4469
  nth = MIN(nth, ne00/4);
3984
4470
 
3985
- wsp_ggml_metal_kargs_rms_norm args = {
3986
- /*.ne00 =*/ ne00,
3987
- /*.ne00_4 =*/ ne00/4,
3988
- /*.nb01 =*/ nb01,
3989
- /*.eps =*/ eps,
3990
- };
3991
-
3992
4471
  [encoder setComputePipelineState:pipeline];
3993
- [encoder setBytes:&args length:sizeof(args) atIndex:0];
3994
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3995
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
4472
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
4473
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
4474
+ [encoder setBuffer:id_fuse[0] offset:offs_fuse[0] atIndex:2];
4475
+ [encoder setBuffer:id_fuse[1] offset:offs_fuse[1] atIndex:3];
4476
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
3996
4477
 
3997
4478
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
3998
4479
 
3999
- const int64_t nrows = wsp_ggml_nrows(src0);
4000
-
4001
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
4480
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
4002
4481
  } break;
4003
4482
  case WSP_GGML_OP_L2_NORM:
4004
4483
  {
@@ -4599,11 +5078,14 @@ static bool wsp_ggml_metal_encode_node(
4599
5078
  WSP_GGML_ASSERT(ne11 == ne21);
4600
5079
  WSP_GGML_ASSERT(ne12 == ne22);
4601
5080
 
4602
- struct wsp_ggml_tensor * src3 = node->src[3];
5081
+ struct wsp_ggml_tensor * src3 = node->src[3]; // mask
5082
+ struct wsp_ggml_tensor * src4 = node->src[4]; // sinks
4603
5083
 
4604
5084
  size_t offs_src3 = 0;
5085
+ size_t offs_src4 = 0;
4605
5086
 
4606
5087
  id<MTLBuffer> id_src3 = src3 ? wsp_ggml_metal_get_buffer(src3, &offs_src3) : nil;
5088
+ id<MTLBuffer> id_src4 = src4 ? wsp_ggml_metal_get_buffer(src4, &offs_src4) : nil;
4607
5089
 
4608
5090
  WSP_GGML_ASSERT(!src3 || src3->type == WSP_GGML_TYPE_F16);
4609
5091
  WSP_GGML_ASSERT(!src3 || src3->ne[1] >= WSP_GGML_PAD(src0->ne[1], 8) &&
@@ -4619,8 +5101,6 @@ static bool wsp_ggml_metal_encode_node(
4619
5101
  const uint64_t nb32 = src3 ? src3->nb[2] : 0; WSP_GGML_UNUSED(nb32);
4620
5102
  const uint64_t nb33 = src3 ? src3->nb[3] : 0; WSP_GGML_UNUSED(nb33);
4621
5103
 
4622
- const enum wsp_ggml_type src2t = src2 ? src2->type : WSP_GGML_TYPE_COUNT; WSP_GGML_UNUSED(src2t);
4623
-
4624
5104
  float scale;
4625
5105
  float max_bias;
4626
5106
  float logit_softcap;
@@ -4983,7 +5463,11 @@ static bool wsp_ggml_metal_encode_node(
4983
5463
  /*.nb21 =*/ nb21,
4984
5464
  /*.nb22 =*/ nb22,
4985
5465
  /*.nb23 =*/ nb23,
5466
+ /*.ne32 =*/ ne32,
5467
+ /*.ne33 =*/ ne33,
4986
5468
  /*.nb31 =*/ nb31,
5469
+ /*.nb32 =*/ nb32,
5470
+ /*.nb33 =*/ nb33,
4987
5471
  /*.ne1 =*/ ne1,
4988
5472
  /*.ne2 =*/ ne2,
4989
5473
  /*.scale =*/ scale,
@@ -5004,7 +5488,12 @@ static bool wsp_ggml_metal_encode_node(
5004
5488
  } else {
5005
5489
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
5006
5490
  }
5007
- [encoder setBuffer:id_dst offset:offs_dst atIndex:5];
5491
+ if (id_src4) {
5492
+ [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
5493
+ } else {
5494
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
5495
+ }
5496
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
5008
5497
 
5009
5498
  if (!use_vec_kernel) {
5010
5499
  // half8x8 kernel
@@ -5389,7 +5878,7 @@ static bool wsp_ggml_metal_encode_node(
5389
5878
  }
5390
5879
  }
5391
5880
 
5392
- return true;
5881
+ return n_fuse;
5393
5882
  }
5394
5883
 
5395
5884
  static enum wsp_ggml_status wsp_ggml_metal_graph_compute(
@@ -5895,20 +6384,26 @@ static void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb
5895
6384
  struct wsp_ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
5896
6385
  wsp_ggml_metal_mem_pool_reset(mem_pool);
5897
6386
 
5898
- for (int idx = node_start; idx < node_end; ++idx) {
6387
+ for (int idx = node_start; idx < node_end;) {
5899
6388
  if (should_capture) {
5900
6389
  [encoder pushDebugGroup:[NSString stringWithCString:wsp_ggml_op_desc(wsp_ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
5901
6390
  }
5902
6391
 
5903
- const bool res = wsp_ggml_metal_encode_node(backend, idx, encoder, mem_pool);
6392
+ const int res = wsp_ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
6393
+ if (idx + res > node_end) {
6394
+ WSP_GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
6395
+ "https://github.com/ggml-org/llama.cpp/pull/14849");
6396
+ }
5904
6397
 
5905
6398
  if (should_capture) {
5906
6399
  [encoder popDebugGroup];
5907
6400
  }
5908
6401
 
5909
- if (!res) {
6402
+ if (res == 0) {
5910
6403
  break;
5911
6404
  }
6405
+
6406
+ idx += res;
5912
6407
  }
5913
6408
 
5914
6409
  [encoder endEncoding];