whisper.rn 0.5.0-rc.8 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/cpp/ggml-alloc.c +1 -15
- package/cpp/ggml-backend-reg.cpp +17 -8
- package/cpp/ggml-backend.cpp +15 -22
- package/cpp/ggml-common.h +17 -0
- package/cpp/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/ggml-cpu/arch-fallback.h +34 -0
- package/cpp/ggml-cpu/ggml-cpu.c +22 -1
- package/cpp/ggml-cpu/ggml-cpu.cpp +21 -24
- package/cpp/ggml-cpu/ops.cpp +870 -211
- package/cpp/ggml-cpu/ops.h +3 -8
- package/cpp/ggml-cpu/quants.c +35 -0
- package/cpp/ggml-cpu/quants.h +8 -0
- package/cpp/ggml-cpu/repack.cpp +458 -47
- package/cpp/ggml-cpu/repack.h +22 -0
- package/cpp/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/ggml-cpu/traits.cpp +2 -2
- package/cpp/ggml-cpu/traits.h +1 -1
- package/cpp/ggml-cpu/vec.cpp +12 -9
- package/cpp/ggml-cpu/vec.h +107 -13
- package/cpp/ggml-impl.h +77 -0
- package/cpp/ggml-metal-impl.h +51 -12
- package/cpp/ggml-metal.m +610 -115
- package/cpp/ggml-opt.cpp +97 -41
- package/cpp/ggml-opt.h +25 -6
- package/cpp/ggml-quants.c +110 -16
- package/cpp/ggml-quants.h +6 -0
- package/cpp/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-whisper.metallib +0 -0
- package/cpp/ggml.c +314 -88
- package/cpp/ggml.h +137 -11
- package/cpp/gguf.cpp +8 -1
- package/cpp/jsi/RNWhisperJSI.cpp +23 -6
- package/cpp/whisper.cpp +15 -6
- package/ios/RNWhisper.mm +6 -6
- package/ios/RNWhisperContext.mm +2 -0
- package/ios/RNWhisperVadContext.mm +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +28 -2
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/module/realtime-transcription/RealtimeTranscriber.js +28 -2
- package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts +1 -0
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
- package/lib/typescript/realtime-transcription/types.d.ts +6 -0
- package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/realtime-transcription/RealtimeTranscriber.ts +32 -0
- package/src/realtime-transcription/types.ts +6 -0
package/cpp/ggml-metal.m
CHANGED
|
@@ -55,6 +55,12 @@ static struct wsp_ggml_backend_metal_device_context {
|
|
|
55
55
|
bool has_residency_sets;
|
|
56
56
|
bool has_bfloat;
|
|
57
57
|
bool use_bfloat;
|
|
58
|
+
bool use_fusion;
|
|
59
|
+
|
|
60
|
+
int debug_fusion;
|
|
61
|
+
|
|
62
|
+
// how many times a given op was fused
|
|
63
|
+
uint64_t fuse_cnt[WSP_GGML_OP_COUNT];
|
|
58
64
|
|
|
59
65
|
size_t max_size;
|
|
60
66
|
|
|
@@ -69,6 +75,9 @@ static struct wsp_ggml_backend_metal_device_context {
|
|
|
69
75
|
/*.has_residency_sets =*/ false,
|
|
70
76
|
/*.has_bfloat =*/ false,
|
|
71
77
|
/*.use_bfloat =*/ false,
|
|
78
|
+
/*.use_fusion =*/ true,
|
|
79
|
+
/*.debug_fusion =*/ 0,
|
|
80
|
+
/*.fuse_cnt =*/ { 0 },
|
|
72
81
|
/*.max_size =*/ 0,
|
|
73
82
|
/*.name =*/ "",
|
|
74
83
|
};
|
|
@@ -83,16 +92,14 @@ static id<MTLDevice> wsp_ggml_backend_metal_device_acq(struct wsp_ggml_backend_m
|
|
|
83
92
|
|
|
84
93
|
if (ctx->mtl_device == nil) {
|
|
85
94
|
ctx->mtl_device = MTLCreateSystemDefaultDevice();
|
|
86
|
-
}
|
|
87
95
|
|
|
88
|
-
if (ctx->mtl_device) {
|
|
89
96
|
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
|
90
97
|
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
|
91
98
|
|
|
92
99
|
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
|
93
100
|
|
|
94
101
|
#if defined(WSP_GGML_METAL_HAS_RESIDENCY_SETS)
|
|
95
|
-
ctx->has_residency_sets = getenv("WSP_GGML_METAL_NO_RESIDENCY") ==
|
|
102
|
+
ctx->has_residency_sets = getenv("WSP_GGML_METAL_NO_RESIDENCY") == nil;
|
|
96
103
|
#endif
|
|
97
104
|
|
|
98
105
|
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
|
@@ -103,6 +110,14 @@ static id<MTLDevice> wsp_ggml_backend_metal_device_acq(struct wsp_ggml_backend_m
|
|
|
103
110
|
#else
|
|
104
111
|
ctx->use_bfloat = false;
|
|
105
112
|
#endif
|
|
113
|
+
ctx->use_fusion = getenv("WSP_GGML_METAL_FUSION_DISABLE") == nil;
|
|
114
|
+
|
|
115
|
+
{
|
|
116
|
+
const char * val = getenv("WSP_GGML_METAL_FUSION_DEBUG");
|
|
117
|
+
ctx->debug_fusion = val ? atoi(val) : 0;
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
|
|
106
121
|
|
|
107
122
|
ctx->max_size = ctx->mtl_device.maxBufferLength;
|
|
108
123
|
|
|
@@ -122,6 +137,18 @@ static void wsp_ggml_backend_metal_device_rel(struct wsp_ggml_backend_metal_devi
|
|
|
122
137
|
ctx->mtl_device_ref_count--;
|
|
123
138
|
|
|
124
139
|
if (ctx->mtl_device_ref_count == 0) {
|
|
140
|
+
if (ctx->debug_fusion > 0) {
|
|
141
|
+
fprintf(stderr, "%s: fusion stats:\n", __func__);
|
|
142
|
+
for (int i = 0; i < WSP_GGML_OP_COUNT; i++) {
|
|
143
|
+
if (ctx->fuse_cnt[i] == 0) {
|
|
144
|
+
continue;
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
// note: cannot use wsp_ggml_log here
|
|
148
|
+
fprintf(stderr, "%s: - %s: %" PRIu64 "\n", __func__, wsp_ggml_op_name((enum wsp_ggml_op) i), ctx->fuse_cnt[i]);
|
|
149
|
+
}
|
|
150
|
+
}
|
|
151
|
+
|
|
125
152
|
if (ctx->mtl_lock) {
|
|
126
153
|
[ctx->mtl_lock release];
|
|
127
154
|
ctx->mtl_lock = nil;
|
|
@@ -147,13 +174,28 @@ struct wsp_ggml_metal_kernel {
|
|
|
147
174
|
|
|
148
175
|
enum wsp_ggml_metal_kernel_type {
|
|
149
176
|
WSP_GGML_METAL_KERNEL_TYPE_ADD,
|
|
150
|
-
|
|
177
|
+
WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
|
|
178
|
+
WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
|
|
179
|
+
WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
|
|
180
|
+
WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
|
|
181
|
+
WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
|
|
182
|
+
WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
|
|
183
|
+
WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
|
|
184
|
+
WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
|
|
185
|
+
WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
|
|
186
|
+
WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
|
|
187
|
+
WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
|
|
188
|
+
WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
|
|
189
|
+
WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
|
|
190
|
+
WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
|
|
191
|
+
WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
|
|
151
192
|
WSP_GGML_METAL_KERNEL_TYPE_SUB,
|
|
152
|
-
|
|
193
|
+
WSP_GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
|
|
153
194
|
WSP_GGML_METAL_KERNEL_TYPE_MUL,
|
|
154
|
-
|
|
195
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
|
|
155
196
|
WSP_GGML_METAL_KERNEL_TYPE_DIV,
|
|
156
|
-
|
|
197
|
+
WSP_GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
|
|
198
|
+
WSP_GGML_METAL_KERNEL_TYPE_ADD_ID,
|
|
157
199
|
WSP_GGML_METAL_KERNEL_TYPE_REPEAT_F32,
|
|
158
200
|
WSP_GGML_METAL_KERNEL_TYPE_REPEAT_F16,
|
|
159
201
|
WSP_GGML_METAL_KERNEL_TYPE_REPEAT_I32,
|
|
@@ -173,6 +215,12 @@ enum wsp_ggml_metal_kernel_type {
|
|
|
173
215
|
WSP_GGML_METAL_KERNEL_TYPE_SILU,
|
|
174
216
|
WSP_GGML_METAL_KERNEL_TYPE_SILU_4,
|
|
175
217
|
WSP_GGML_METAL_KERNEL_TYPE_ELU,
|
|
218
|
+
WSP_GGML_METAL_KERNEL_TYPE_ABS,
|
|
219
|
+
WSP_GGML_METAL_KERNEL_TYPE_SGN,
|
|
220
|
+
WSP_GGML_METAL_KERNEL_TYPE_STEP,
|
|
221
|
+
WSP_GGML_METAL_KERNEL_TYPE_HARDSWISH,
|
|
222
|
+
WSP_GGML_METAL_KERNEL_TYPE_HARDSIGMOID,
|
|
223
|
+
WSP_GGML_METAL_KERNEL_TYPE_EXP,
|
|
176
224
|
WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
|
177
225
|
WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
|
178
226
|
WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
|
@@ -187,6 +235,7 @@ enum wsp_ggml_metal_kernel_type {
|
|
|
187
235
|
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
|
|
188
236
|
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
|
|
189
237
|
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
|
|
238
|
+
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4,
|
|
190
239
|
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
|
|
191
240
|
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
|
|
192
241
|
WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
|
|
@@ -212,11 +261,14 @@ enum wsp_ggml_metal_kernel_type {
|
|
|
212
261
|
WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
|
|
213
262
|
WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
|
|
214
263
|
WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
|
264
|
+
WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
|
|
265
|
+
WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
|
|
215
266
|
WSP_GGML_METAL_KERNEL_TYPE_L2_NORM,
|
|
216
267
|
WSP_GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
|
217
268
|
WSP_GGML_METAL_KERNEL_TYPE_NORM,
|
|
218
269
|
WSP_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
|
219
270
|
WSP_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
|
271
|
+
WSP_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
|
|
220
272
|
WSP_GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
|
|
221
273
|
WSP_GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
|
|
222
274
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
|
@@ -236,6 +288,7 @@ enum wsp_ggml_metal_kernel_type {
|
|
|
236
288
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
|
|
237
289
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
|
|
238
290
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
|
|
291
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
|
|
239
292
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
|
|
240
293
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
|
|
241
294
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
|
|
@@ -260,6 +313,10 @@ enum wsp_ggml_metal_kernel_type {
|
|
|
260
313
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
|
|
261
314
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
|
|
262
315
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
|
|
316
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2,
|
|
317
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3,
|
|
318
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4,
|
|
319
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5,
|
|
263
320
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
|
|
264
321
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
|
|
265
322
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
|
|
@@ -301,6 +358,7 @@ enum wsp_ggml_metal_kernel_type {
|
|
|
301
358
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
|
|
302
359
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
|
|
303
360
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
|
|
361
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,
|
|
304
362
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
|
|
305
363
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
|
|
306
364
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
|
|
@@ -323,6 +381,7 @@ enum wsp_ggml_metal_kernel_type {
|
|
|
323
381
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
|
|
324
382
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
|
|
325
383
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
|
|
384
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,
|
|
326
385
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
|
|
327
386
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
|
|
328
387
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
|
|
@@ -347,6 +406,7 @@ enum wsp_ggml_metal_kernel_type {
|
|
|
347
406
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
|
|
348
407
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
|
|
349
408
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
|
|
409
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16,
|
|
350
410
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
|
|
351
411
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
|
|
352
412
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
|
|
@@ -529,6 +589,9 @@ enum wsp_ggml_metal_kernel_type {
|
|
|
529
589
|
WSP_GGML_METAL_KERNEL_TYPE_REGLU,
|
|
530
590
|
WSP_GGML_METAL_KERNEL_TYPE_GEGLU,
|
|
531
591
|
WSP_GGML_METAL_KERNEL_TYPE_SWIGLU,
|
|
592
|
+
WSP_GGML_METAL_KERNEL_TYPE_SWIGLU_OAI,
|
|
593
|
+
WSP_GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
|
|
594
|
+
WSP_GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
|
|
532
595
|
WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
|
533
596
|
WSP_GGML_METAL_KERNEL_TYPE_MEAN,
|
|
534
597
|
WSP_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
|
@@ -1130,13 +1193,28 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
1130
1193
|
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
|
1131
1194
|
|
|
1132
1195
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
|
1133
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
1196
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true);
|
|
1197
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true);
|
|
1198
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true);
|
|
1199
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true);
|
|
1200
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true);
|
|
1201
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true);
|
|
1202
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true);
|
|
1203
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true);
|
|
1204
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true);
|
|
1205
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true);
|
|
1206
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true);
|
|
1207
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true);
|
|
1208
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true);
|
|
1209
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true);
|
|
1210
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true);
|
|
1134
1211
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SUB, sub, true);
|
|
1135
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
1212
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true);
|
|
1136
1213
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
|
1137
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
1214
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
|
|
1138
1215
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
|
1139
|
-
WSP_GGML_METAL_ADD_KERNEL(
|
|
1216
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
|
|
1217
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ADD_ID, add_id, true);
|
|
1140
1218
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
|
|
1141
1219
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
|
|
1142
1220
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
|
|
@@ -1156,6 +1234,12 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
1156
1234
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
|
1157
1235
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
|
1158
1236
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ELU, elu, true);
|
|
1237
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ABS, abs, true);
|
|
1238
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SGN, sgn, true);
|
|
1239
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_STEP, step, true);
|
|
1240
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_HARDSWISH, hardswish, true);
|
|
1241
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_HARDSIGMOID, hardsigmoid, true);
|
|
1242
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_EXP, exp, true);
|
|
1159
1243
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
|
|
1160
1244
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
|
|
1161
1245
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
|
|
@@ -1170,6 +1254,7 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
1170
1254
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
|
|
1171
1255
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
|
|
1172
1256
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
|
|
1257
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true);
|
|
1173
1258
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
|
|
1174
1259
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
|
|
1175
1260
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
|
|
@@ -1195,11 +1280,14 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
1195
1280
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
|
|
1196
1281
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
|
|
1197
1282
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
|
1283
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
|
|
1284
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
|
|
1198
1285
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
|
|
1199
1286
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
|
1200
1287
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
|
1201
1288
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
|
1202
1289
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
|
1290
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
|
|
1203
1291
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
|
|
1204
1292
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
|
|
1205
1293
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
|
|
@@ -1219,6 +1307,7 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
1219
1307
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
|
|
1220
1308
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
|
|
1221
1309
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
|
|
1310
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
|
|
1222
1311
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
|
|
1223
1312
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
|
|
1224
1313
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
|
|
@@ -1243,6 +1332,10 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
1243
1332
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
|
|
1244
1333
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
|
|
1245
1334
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
|
|
1335
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2, mul_mv_ext_mxfp4_f32_r1_2, has_simdgroup_reduction);
|
|
1336
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3, mul_mv_ext_mxfp4_f32_r1_3, has_simdgroup_reduction);
|
|
1337
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4, mul_mv_ext_mxfp4_f32_r1_4, has_simdgroup_reduction);
|
|
1338
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5, mul_mv_ext_mxfp4_f32_r1_5, has_simdgroup_reduction);
|
|
1246
1339
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
|
|
1247
1340
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
|
|
1248
1341
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
|
|
@@ -1284,6 +1377,7 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
1284
1377
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
|
|
1285
1378
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
|
|
1286
1379
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
|
|
1380
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, has_simdgroup_reduction);
|
|
1287
1381
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
|
|
1288
1382
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
|
|
1289
1383
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
|
|
@@ -1306,6 +1400,8 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
1306
1400
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
|
|
1307
1401
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
|
|
1308
1402
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
|
|
1403
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
|
|
1404
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
|
|
1309
1405
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
|
|
1310
1406
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
|
|
1311
1407
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
|
|
@@ -1330,6 +1426,7 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
1330
1426
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
|
|
1331
1427
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm);
|
|
1332
1428
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm);
|
|
1429
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, mul_mm_id_mxfp4_f16, has_simdgroup_mm);
|
|
1333
1430
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm);
|
|
1334
1431
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm);
|
|
1335
1432
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm);
|
|
@@ -1512,6 +1609,9 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
1512
1609
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
|
|
1513
1610
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
|
|
1514
1611
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
|
|
1612
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SWIGLU_OAI, swiglu_oai, true);
|
|
1613
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
|
|
1614
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
|
|
1515
1615
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
|
1516
1616
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
|
1517
1617
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
|
@@ -1686,6 +1786,12 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
|
|
|
1686
1786
|
case WSP_GGML_UNARY_OP_SILU:
|
|
1687
1787
|
case WSP_GGML_UNARY_OP_ELU:
|
|
1688
1788
|
case WSP_GGML_UNARY_OP_NEG:
|
|
1789
|
+
case WSP_GGML_UNARY_OP_ABS:
|
|
1790
|
+
case WSP_GGML_UNARY_OP_SGN:
|
|
1791
|
+
case WSP_GGML_UNARY_OP_STEP:
|
|
1792
|
+
case WSP_GGML_UNARY_OP_HARDSWISH:
|
|
1793
|
+
case WSP_GGML_UNARY_OP_HARDSIGMOID:
|
|
1794
|
+
case WSP_GGML_UNARY_OP_EXP:
|
|
1689
1795
|
return wsp_ggml_is_contiguous(op->src[0]) && op->src[0]->type == WSP_GGML_TYPE_F32;
|
|
1690
1796
|
default:
|
|
1691
1797
|
return false;
|
|
@@ -1695,6 +1801,9 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
|
|
|
1695
1801
|
case WSP_GGML_GLU_OP_REGLU:
|
|
1696
1802
|
case WSP_GGML_GLU_OP_GEGLU:
|
|
1697
1803
|
case WSP_GGML_GLU_OP_SWIGLU:
|
|
1804
|
+
case WSP_GGML_GLU_OP_SWIGLU_OAI:
|
|
1805
|
+
case WSP_GGML_GLU_OP_GEGLU_ERF:
|
|
1806
|
+
case WSP_GGML_GLU_OP_GEGLU_QUICK:
|
|
1698
1807
|
return wsp_ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == WSP_GGML_TYPE_F32;
|
|
1699
1808
|
default:
|
|
1700
1809
|
return false;
|
|
@@ -1710,6 +1819,7 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
|
|
|
1710
1819
|
case WSP_GGML_OP_SUB:
|
|
1711
1820
|
case WSP_GGML_OP_MUL:
|
|
1712
1821
|
case WSP_GGML_OP_DIV:
|
|
1822
|
+
case WSP_GGML_OP_ADD_ID:
|
|
1713
1823
|
return op->src[0]->type == WSP_GGML_TYPE_F32;
|
|
1714
1824
|
case WSP_GGML_OP_ACC:
|
|
1715
1825
|
case WSP_GGML_OP_REPEAT:
|
|
@@ -1729,7 +1839,7 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
|
|
|
1729
1839
|
case WSP_GGML_OP_MEAN:
|
|
1730
1840
|
case WSP_GGML_OP_SOFT_MAX:
|
|
1731
1841
|
case WSP_GGML_OP_GROUP_NORM:
|
|
1732
|
-
return has_simdgroup_reduction &&
|
|
1842
|
+
return has_simdgroup_reduction && wsp_ggml_is_contiguous_rows(op->src[0]);
|
|
1733
1843
|
case WSP_GGML_OP_RMS_NORM:
|
|
1734
1844
|
case WSP_GGML_OP_L2_NORM:
|
|
1735
1845
|
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && wsp_ggml_is_contiguous_1(op->src[0]));
|
|
@@ -1871,9 +1981,10 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
|
|
|
1871
1981
|
}
|
|
1872
1982
|
}
|
|
1873
1983
|
|
|
1874
|
-
static
|
|
1984
|
+
static int wsp_ggml_metal_encode_node(
|
|
1875
1985
|
wsp_ggml_backend_t backend,
|
|
1876
1986
|
int idx,
|
|
1987
|
+
int idx_end,
|
|
1877
1988
|
id<MTLComputeCommandEncoder> encoder,
|
|
1878
1989
|
struct wsp_ggml_metal_mem_pool * mem_pool) {
|
|
1879
1990
|
struct wsp_ggml_backend_metal_context * ctx = backend->context;
|
|
@@ -1881,7 +1992,10 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
1881
1992
|
|
|
1882
1993
|
struct wsp_ggml_cgraph * gf = ctx->gf;
|
|
1883
1994
|
|
|
1884
|
-
|
|
1995
|
+
enum wsp_ggml_op ops[8];
|
|
1996
|
+
|
|
1997
|
+
struct wsp_ggml_tensor ** nodes = wsp_ggml_graph_nodes(gf) + idx;
|
|
1998
|
+
struct wsp_ggml_tensor * node = nodes[0];
|
|
1885
1999
|
|
|
1886
2000
|
//WSP_GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, wsp_ggml_op_name(node->op));
|
|
1887
2001
|
|
|
@@ -1891,7 +2005,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
1891
2005
|
struct wsp_ggml_tensor * dst = node;
|
|
1892
2006
|
|
|
1893
2007
|
if (wsp_ggml_is_empty(dst)) {
|
|
1894
|
-
return
|
|
2008
|
+
return 1;
|
|
1895
2009
|
}
|
|
1896
2010
|
|
|
1897
2011
|
switch (dst->op) {
|
|
@@ -1902,7 +2016,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
1902
2016
|
case WSP_GGML_OP_PERMUTE:
|
|
1903
2017
|
{
|
|
1904
2018
|
// noop -> next node
|
|
1905
|
-
} return
|
|
2019
|
+
} return 1;
|
|
1906
2020
|
default:
|
|
1907
2021
|
{
|
|
1908
2022
|
} break;
|
|
@@ -1957,6 +2071,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
1957
2071
|
|
|
1958
2072
|
const enum wsp_ggml_type src0t = src0 ? src0->type : WSP_GGML_TYPE_COUNT;
|
|
1959
2073
|
const enum wsp_ggml_type src1t = src1 ? src1->type : WSP_GGML_TYPE_COUNT;
|
|
2074
|
+
const enum wsp_ggml_type src2t = src2 ? src2->type : WSP_GGML_TYPE_COUNT;
|
|
1960
2075
|
const enum wsp_ggml_type dstt = dst ? dst->type : WSP_GGML_TYPE_COUNT;
|
|
1961
2076
|
|
|
1962
2077
|
size_t offs_src0 = 0;
|
|
@@ -1969,6 +2084,8 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
1969
2084
|
id<MTLBuffer> id_src2 = src2 ? wsp_ggml_metal_get_buffer(src2, &offs_src2) : nil;
|
|
1970
2085
|
id<MTLBuffer> id_dst = dst ? wsp_ggml_metal_get_buffer(dst, &offs_dst) : nil;
|
|
1971
2086
|
|
|
2087
|
+
int n_fuse = 1;
|
|
2088
|
+
|
|
1972
2089
|
#if 0
|
|
1973
2090
|
WSP_GGML_LOG_INFO("%s: op - %s\n", __func__, wsp_ggml_op_name(dst->op));
|
|
1974
2091
|
if (src0) {
|
|
@@ -2040,37 +2157,15 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
2040
2157
|
WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_F32);
|
|
2041
2158
|
WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
|
|
2042
2159
|
|
|
2160
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(src0));
|
|
2161
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(src1));
|
|
2162
|
+
|
|
2043
2163
|
const size_t offs = 0;
|
|
2044
2164
|
|
|
2045
2165
|
bool bcast_row = false;
|
|
2046
2166
|
|
|
2047
2167
|
id<MTLComputePipelineState> pipeline = nil;
|
|
2048
2168
|
|
|
2049
|
-
if (wsp_ggml_nelements(src1) == ne10 && wsp_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
|
2050
|
-
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
|
|
2051
|
-
|
|
2052
|
-
// src1 is a row
|
|
2053
|
-
WSP_GGML_ASSERT(ne11 == 1);
|
|
2054
|
-
|
|
2055
|
-
switch (dst->op) {
|
|
2056
|
-
case WSP_GGML_OP_ADD: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
|
|
2057
|
-
case WSP_GGML_OP_SUB: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
|
|
2058
|
-
case WSP_GGML_OP_MUL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
|
|
2059
|
-
case WSP_GGML_OP_DIV: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
|
|
2060
|
-
default: WSP_GGML_ABORT("fatal error");
|
|
2061
|
-
}
|
|
2062
|
-
|
|
2063
|
-
bcast_row = true;
|
|
2064
|
-
} else {
|
|
2065
|
-
switch (dst->op) {
|
|
2066
|
-
case WSP_GGML_OP_ADD: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
|
|
2067
|
-
case WSP_GGML_OP_SUB: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
|
|
2068
|
-
case WSP_GGML_OP_MUL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
|
|
2069
|
-
case WSP_GGML_OP_DIV: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
|
|
2070
|
-
default: WSP_GGML_ABORT("fatal error");
|
|
2071
|
-
}
|
|
2072
|
-
}
|
|
2073
|
-
|
|
2074
2169
|
wsp_ggml_metal_kargs_bin args = {
|
|
2075
2170
|
/*.ne00 =*/ ne00,
|
|
2076
2171
|
/*.ne01 =*/ ne01,
|
|
@@ -2097,12 +2192,119 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
2097
2192
|
/*.nb2 =*/ nb2,
|
|
2098
2193
|
/*.nb3 =*/ nb3,
|
|
2099
2194
|
/*.offs =*/ offs,
|
|
2195
|
+
/*.o1 =*/ { offs_src1 },
|
|
2100
2196
|
};
|
|
2101
2197
|
|
|
2198
|
+
// c[0] = add(a, b[0])
|
|
2199
|
+
// c[1] = add(c[0], b[1])
|
|
2200
|
+
// c[2] = add(c[1], b[2])
|
|
2201
|
+
// ...
|
|
2202
|
+
if (ctx_dev->use_fusion) {
|
|
2203
|
+
ops[0] = WSP_GGML_OP_ADD;
|
|
2204
|
+
ops[1] = WSP_GGML_OP_ADD;
|
|
2205
|
+
ops[2] = WSP_GGML_OP_ADD;
|
|
2206
|
+
ops[3] = WSP_GGML_OP_ADD;
|
|
2207
|
+
ops[4] = WSP_GGML_OP_ADD;
|
|
2208
|
+
ops[5] = WSP_GGML_OP_ADD;
|
|
2209
|
+
ops[6] = WSP_GGML_OP_ADD;
|
|
2210
|
+
ops[7] = WSP_GGML_OP_ADD;
|
|
2211
|
+
|
|
2212
|
+
size_t offs_fuse;
|
|
2213
|
+
id<MTLBuffer> id_fuse;
|
|
2214
|
+
|
|
2215
|
+
// note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes
|
|
2216
|
+
// across splits. idx_end indicates the last node in the current split
|
|
2217
|
+
for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
|
|
2218
|
+
if (!wsp_ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
|
|
2219
|
+
break;
|
|
2220
|
+
}
|
|
2221
|
+
|
|
2222
|
+
if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
|
|
2223
|
+
break;
|
|
2224
|
+
}
|
|
2225
|
+
|
|
2226
|
+
// b[0] === b[1] === ...
|
|
2227
|
+
if (!wsp_ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) {
|
|
2228
|
+
break;
|
|
2229
|
+
}
|
|
2230
|
+
|
|
2231
|
+
// only fuse nodes if src1 is in the same Metal buffer
|
|
2232
|
+
id_fuse = wsp_ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse);
|
|
2233
|
+
if (id_fuse != id_src1) {
|
|
2234
|
+
break;
|
|
2235
|
+
}
|
|
2236
|
+
|
|
2237
|
+
ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
|
|
2238
|
+
|
|
2239
|
+
args.o1[n_fuse + 1] = offs_fuse;
|
|
2240
|
+
}
|
|
2241
|
+
|
|
2242
|
+
++n_fuse;
|
|
2243
|
+
|
|
2244
|
+
if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
|
|
2245
|
+
WSP_GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
|
|
2246
|
+
}
|
|
2247
|
+
}
|
|
2248
|
+
|
|
2249
|
+
if (wsp_ggml_nelements(src1) == ne10 && wsp_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
|
2250
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
|
|
2251
|
+
|
|
2252
|
+
// src1 is a row
|
|
2253
|
+
WSP_GGML_ASSERT(ne11 == 1);
|
|
2254
|
+
|
|
2255
|
+
switch (dst->op) {
|
|
2256
|
+
case WSP_GGML_OP_ADD:
|
|
2257
|
+
{
|
|
2258
|
+
switch (n_fuse) {
|
|
2259
|
+
case 1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline; break;
|
|
2260
|
+
case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break;
|
|
2261
|
+
case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break;
|
|
2262
|
+
case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break;
|
|
2263
|
+
case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break;
|
|
2264
|
+
case 6: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break;
|
|
2265
|
+
case 7: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break;
|
|
2266
|
+
case 8: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break;
|
|
2267
|
+
default: WSP_GGML_ABORT("fatal error");
|
|
2268
|
+
}
|
|
2269
|
+
} break;
|
|
2270
|
+
case WSP_GGML_OP_SUB: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break;
|
|
2271
|
+
case WSP_GGML_OP_MUL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break;
|
|
2272
|
+
case WSP_GGML_OP_DIV: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break;
|
|
2273
|
+
default: WSP_GGML_ABORT("fatal error");
|
|
2274
|
+
}
|
|
2275
|
+
|
|
2276
|
+
bcast_row = true;
|
|
2277
|
+
} else {
|
|
2278
|
+
switch (dst->op) {
|
|
2279
|
+
case WSP_GGML_OP_ADD:
|
|
2280
|
+
{
|
|
2281
|
+
switch (n_fuse) {
|
|
2282
|
+
case 1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD ].pipeline; break;
|
|
2283
|
+
case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
|
|
2284
|
+
case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break;
|
|
2285
|
+
case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
|
|
2286
|
+
case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break;
|
|
2287
|
+
case 6: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
|
|
2288
|
+
case 7: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break;
|
|
2289
|
+
case 8: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
|
|
2290
|
+
default: WSP_GGML_ABORT("fatal error");
|
|
2291
|
+
}
|
|
2292
|
+
} break;
|
|
2293
|
+
case WSP_GGML_OP_SUB: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
|
|
2294
|
+
case WSP_GGML_OP_MUL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
|
|
2295
|
+
case WSP_GGML_OP_DIV: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
|
|
2296
|
+
default: WSP_GGML_ABORT("fatal error");
|
|
2297
|
+
}
|
|
2298
|
+
}
|
|
2299
|
+
|
|
2300
|
+
if (n_fuse > 1) {
|
|
2301
|
+
id_dst = wsp_ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
|
|
2302
|
+
}
|
|
2303
|
+
|
|
2102
2304
|
[encoder setComputePipelineState:pipeline];
|
|
2103
2305
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
2104
2306
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
2105
|
-
[encoder setBuffer:id_src1 offset:
|
|
2307
|
+
[encoder setBuffer:id_src1 offset:0 atIndex:2];
|
|
2106
2308
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
2107
2309
|
|
|
2108
2310
|
if (bcast_row) {
|
|
@@ -2110,11 +2312,47 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
2110
2312
|
|
|
2111
2313
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2112
2314
|
} else {
|
|
2113
|
-
|
|
2315
|
+
int nth = 32;
|
|
2316
|
+
|
|
2317
|
+
while (16*nth < ne0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
2318
|
+
nth *= 2;
|
|
2319
|
+
}
|
|
2114
2320
|
|
|
2115
2321
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2116
2322
|
}
|
|
2117
2323
|
} break;
|
|
2324
|
+
case WSP_GGML_OP_ADD_ID:
|
|
2325
|
+
{
|
|
2326
|
+
WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_F32);
|
|
2327
|
+
WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
|
|
2328
|
+
WSP_GGML_ASSERT(src2t == WSP_GGML_TYPE_I32);
|
|
2329
|
+
WSP_GGML_ASSERT(dstt == WSP_GGML_TYPE_F32);
|
|
2330
|
+
|
|
2331
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(src0));
|
|
2332
|
+
|
|
2333
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ADD_ID].pipeline;
|
|
2334
|
+
|
|
2335
|
+
wsp_ggml_metal_kargs_add_id args = {
|
|
2336
|
+
/*.ne0 =*/ ne0,
|
|
2337
|
+
/*.ne1 =*/ ne1,
|
|
2338
|
+
/*.nb01 =*/ nb01,
|
|
2339
|
+
/*.nb02 =*/ nb02,
|
|
2340
|
+
/*.nb11 =*/ nb11,
|
|
2341
|
+
/*.nb21 =*/ nb21,
|
|
2342
|
+
|
|
2343
|
+
};
|
|
2344
|
+
|
|
2345
|
+
[encoder setComputePipelineState:pipeline];
|
|
2346
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
2347
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
2348
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
|
2349
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
|
2350
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
|
2351
|
+
|
|
2352
|
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
|
|
2353
|
+
|
|
2354
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2355
|
+
} break;
|
|
2118
2356
|
case WSP_GGML_OP_REPEAT:
|
|
2119
2357
|
{
|
|
2120
2358
|
id<MTLComputePipelineState> pipeline;
|
|
@@ -2235,12 +2473,13 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
2235
2473
|
/*.nb2 =*/ pnb2,
|
|
2236
2474
|
/*.nb3 =*/ pnb3,
|
|
2237
2475
|
/*.offs =*/ offs,
|
|
2476
|
+
/*.o1 =*/ { offs_src1},
|
|
2238
2477
|
};
|
|
2239
2478
|
|
|
2240
2479
|
[encoder setComputePipelineState:pipeline];
|
|
2241
2480
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
2242
2481
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
2243
|
-
[encoder setBuffer:id_src1 offset:
|
|
2482
|
+
[encoder setBuffer:id_src1 offset:0 atIndex:2];
|
|
2244
2483
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
2245
2484
|
|
|
2246
2485
|
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
|
|
@@ -2252,7 +2491,9 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
2252
2491
|
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
|
|
2253
2492
|
|
|
2254
2493
|
float scale;
|
|
2255
|
-
|
|
2494
|
+
float bias;
|
|
2495
|
+
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float));
|
|
2496
|
+
memcpy(&bias, ((const int32_t *) dst->op_params) + 1, sizeof(float));
|
|
2256
2497
|
|
|
2257
2498
|
int64_t n = wsp_ggml_nelements(dst);
|
|
2258
2499
|
|
|
@@ -2269,6 +2510,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
2269
2510
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2270
2511
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2271
2512
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
|
2513
|
+
[encoder setBytes:&bias length:sizeof(bias) atIndex:3];
|
|
2272
2514
|
|
|
2273
2515
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2274
2516
|
} break;
|
|
@@ -2432,6 +2674,78 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
2432
2674
|
|
|
2433
2675
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2434
2676
|
} break;
|
|
2677
|
+
case WSP_GGML_UNARY_OP_ABS:
|
|
2678
|
+
{
|
|
2679
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ABS].pipeline;
|
|
2680
|
+
|
|
2681
|
+
[encoder setComputePipelineState:pipeline];
|
|
2682
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2683
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2684
|
+
|
|
2685
|
+
const int64_t n = wsp_ggml_nelements(dst);
|
|
2686
|
+
|
|
2687
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2688
|
+
} break;
|
|
2689
|
+
case WSP_GGML_UNARY_OP_SGN:
|
|
2690
|
+
{
|
|
2691
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SGN].pipeline;
|
|
2692
|
+
|
|
2693
|
+
[encoder setComputePipelineState:pipeline];
|
|
2694
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2695
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2696
|
+
|
|
2697
|
+
const int64_t n = wsp_ggml_nelements(dst);
|
|
2698
|
+
|
|
2699
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2700
|
+
} break;
|
|
2701
|
+
case WSP_GGML_UNARY_OP_STEP:
|
|
2702
|
+
{
|
|
2703
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_STEP].pipeline;
|
|
2704
|
+
|
|
2705
|
+
[encoder setComputePipelineState:pipeline];
|
|
2706
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2707
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2708
|
+
|
|
2709
|
+
const int64_t n = wsp_ggml_nelements(dst);
|
|
2710
|
+
|
|
2711
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2712
|
+
} break;
|
|
2713
|
+
case WSP_GGML_UNARY_OP_HARDSWISH:
|
|
2714
|
+
{
|
|
2715
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline;
|
|
2716
|
+
|
|
2717
|
+
[encoder setComputePipelineState:pipeline];
|
|
2718
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2719
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2720
|
+
|
|
2721
|
+
const int64_t n = wsp_ggml_nelements(dst);
|
|
2722
|
+
|
|
2723
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2724
|
+
} break;
|
|
2725
|
+
case WSP_GGML_UNARY_OP_HARDSIGMOID:
|
|
2726
|
+
{
|
|
2727
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline;
|
|
2728
|
+
|
|
2729
|
+
[encoder setComputePipelineState:pipeline];
|
|
2730
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2731
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2732
|
+
|
|
2733
|
+
const int64_t n = wsp_ggml_nelements(dst);
|
|
2734
|
+
|
|
2735
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2736
|
+
} break;
|
|
2737
|
+
case WSP_GGML_UNARY_OP_EXP:
|
|
2738
|
+
{
|
|
2739
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_EXP].pipeline;
|
|
2740
|
+
|
|
2741
|
+
[encoder setComputePipelineState:pipeline];
|
|
2742
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2743
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2744
|
+
|
|
2745
|
+
const int64_t n = wsp_ggml_nelements(dst);
|
|
2746
|
+
|
|
2747
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2748
|
+
} break;
|
|
2435
2749
|
default:
|
|
2436
2750
|
{
|
|
2437
2751
|
WSP_GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, wsp_ggml_op_name(dst->op));
|
|
@@ -2458,11 +2772,22 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
2458
2772
|
case WSP_GGML_GLU_OP_SWIGLU:
|
|
2459
2773
|
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
|
|
2460
2774
|
break;
|
|
2775
|
+
case WSP_GGML_GLU_OP_SWIGLU_OAI:
|
|
2776
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SWIGLU_OAI].pipeline;
|
|
2777
|
+
break;
|
|
2778
|
+
case WSP_GGML_GLU_OP_GEGLU_ERF:
|
|
2779
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
|
|
2780
|
+
break;
|
|
2781
|
+
case WSP_GGML_GLU_OP_GEGLU_QUICK:
|
|
2782
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline;
|
|
2783
|
+
break;
|
|
2461
2784
|
default:
|
|
2462
2785
|
WSP_GGML_ABORT("fatal error");
|
|
2463
2786
|
}
|
|
2464
2787
|
|
|
2465
|
-
const int32_t swp = (
|
|
2788
|
+
const int32_t swp = wsp_ggml_get_op_params_i32(dst, 1);
|
|
2789
|
+
const float alpha = wsp_ggml_get_op_params_f32(dst, 2);
|
|
2790
|
+
const float limit = wsp_ggml_get_op_params_f32(dst, 3);
|
|
2466
2791
|
|
|
2467
2792
|
const int32_t i00 = swp ? ne0 : 0;
|
|
2468
2793
|
const int32_t i10 = swp ? 0 : ne0;
|
|
@@ -2476,6 +2801,8 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
2476
2801
|
/*.nb1 =*/ nb1,
|
|
2477
2802
|
/*.i00 =*/ src1 ? 0 : i00,
|
|
2478
2803
|
/*.i10 =*/ src1 ? 0 : i10,
|
|
2804
|
+
/*.alpha=*/ alpha,
|
|
2805
|
+
/*.limit=*/ limit
|
|
2479
2806
|
};
|
|
2480
2807
|
|
|
2481
2808
|
[encoder setComputePipelineState:pipeline];
|
|
@@ -2648,10 +2975,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
2648
2975
|
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
|
|
2649
2976
|
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
|
2650
2977
|
|
|
2651
|
-
const
|
|
2652
|
-
const int64_t nrows_y = src0->ne[1];
|
|
2653
|
-
|
|
2654
|
-
const uint32_t n_head = nrows_x/nrows_y;
|
|
2978
|
+
const uint32_t n_head = src0->ne[2];
|
|
2655
2979
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
|
2656
2980
|
|
|
2657
2981
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
@@ -2664,7 +2988,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
2664
2988
|
id<MTLBuffer> h_src0 = h_src0 = wsp_ggml_metal_mem_pool_alloc(mem_pool, wsp_ggml_nbytes(src0));
|
|
2665
2989
|
if (!h_src0) {
|
|
2666
2990
|
WSP_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, wsp_ggml_nbytes(src0));
|
|
2667
|
-
return
|
|
2991
|
+
return 0;
|
|
2668
2992
|
}
|
|
2669
2993
|
|
|
2670
2994
|
offs_src0 = 0;
|
|
@@ -2711,6 +3035,18 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
2711
3035
|
/*.ne00 =*/ ne00,
|
|
2712
3036
|
/*.ne01 =*/ ne01,
|
|
2713
3037
|
/*.ne02 =*/ ne02,
|
|
3038
|
+
/*.nb01 =*/ nb01,
|
|
3039
|
+
/*.nb02 =*/ nb02,
|
|
3040
|
+
/*.nb03 =*/ nb03,
|
|
3041
|
+
/*.ne11 =*/ ne11,
|
|
3042
|
+
/*.ne12 =*/ ne12,
|
|
3043
|
+
/*.ne13 =*/ ne13,
|
|
3044
|
+
/*.nb11 =*/ nb11,
|
|
3045
|
+
/*.nb12 =*/ nb12,
|
|
3046
|
+
/*.nb13 =*/ nb13,
|
|
3047
|
+
/*.nb1 =*/ nb1,
|
|
3048
|
+
/*.nb2 =*/ nb2,
|
|
3049
|
+
/*.nb3 =*/ nb3,
|
|
2714
3050
|
/*.scale =*/ scale,
|
|
2715
3051
|
/*.max_bias =*/ max_bias,
|
|
2716
3052
|
/*.m0 =*/ m0,
|
|
@@ -2725,12 +3061,17 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
2725
3061
|
} else {
|
|
2726
3062
|
[encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
|
|
2727
3063
|
}
|
|
2728
|
-
|
|
2729
|
-
|
|
3064
|
+
if (id_src2) {
|
|
3065
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
|
3066
|
+
} else {
|
|
3067
|
+
[encoder setBuffer:h_src0 offset:offs_src0 atIndex:2];
|
|
3068
|
+
}
|
|
3069
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
3070
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:4];
|
|
2730
3071
|
|
|
2731
3072
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
2732
3073
|
|
|
2733
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01
|
|
3074
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2734
3075
|
} break;
|
|
2735
3076
|
case WSP_GGML_OP_DIAG_MASK_INF:
|
|
2736
3077
|
{
|
|
@@ -2804,71 +3145,92 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
2804
3145
|
struct wsp_ggml_tensor * src3 = node->src[3];
|
|
2805
3146
|
struct wsp_ggml_tensor * src4 = node->src[4];
|
|
2806
3147
|
struct wsp_ggml_tensor * src5 = node->src[5];
|
|
3148
|
+
struct wsp_ggml_tensor * src6 = node->src[6];
|
|
2807
3149
|
|
|
2808
3150
|
WSP_GGML_ASSERT(src3);
|
|
2809
3151
|
WSP_GGML_ASSERT(src4);
|
|
2810
3152
|
WSP_GGML_ASSERT(src5);
|
|
3153
|
+
WSP_GGML_ASSERT(src6);
|
|
2811
3154
|
|
|
2812
3155
|
size_t offs_src3 = 0;
|
|
2813
3156
|
size_t offs_src4 = 0;
|
|
2814
3157
|
size_t offs_src5 = 0;
|
|
3158
|
+
size_t offs_src6 = 0;
|
|
2815
3159
|
|
|
2816
3160
|
id<MTLBuffer> id_src3 = src3 ? wsp_ggml_metal_get_buffer(src3, &offs_src3) : nil;
|
|
2817
3161
|
id<MTLBuffer> id_src4 = src4 ? wsp_ggml_metal_get_buffer(src4, &offs_src4) : nil;
|
|
2818
3162
|
id<MTLBuffer> id_src5 = src5 ? wsp_ggml_metal_get_buffer(src5, &offs_src5) : nil;
|
|
3163
|
+
id<MTLBuffer> id_src6 = src6 ? wsp_ggml_metal_get_buffer(src6, &offs_src6) : nil;
|
|
2819
3164
|
|
|
2820
|
-
const int64_t ne30 = src3->ne[0];
|
|
3165
|
+
const int64_t ne30 = src3->ne[0];
|
|
2821
3166
|
const int64_t ne31 = src3->ne[1]; WSP_GGML_UNUSED(ne31);
|
|
2822
3167
|
|
|
2823
|
-
const uint64_t nb30 = src3->nb[0];
|
|
3168
|
+
const uint64_t nb30 = src3->nb[0]; WSP_GGML_UNUSED(nb30);
|
|
2824
3169
|
const uint64_t nb31 = src3->nb[1];
|
|
2825
3170
|
|
|
2826
3171
|
const int64_t ne40 = src4->ne[0]; WSP_GGML_UNUSED(ne40);
|
|
2827
|
-
const int64_t ne41 = src4->ne[1];
|
|
3172
|
+
const int64_t ne41 = src4->ne[1];
|
|
2828
3173
|
const int64_t ne42 = src4->ne[2]; WSP_GGML_UNUSED(ne42);
|
|
3174
|
+
const int64_t ne43 = src4->ne[3]; WSP_GGML_UNUSED(ne43);
|
|
2829
3175
|
|
|
2830
|
-
const uint64_t nb40 = src4->nb[0];
|
|
3176
|
+
const uint64_t nb40 = src4->nb[0]; WSP_GGML_UNUSED(nb40);
|
|
2831
3177
|
const uint64_t nb41 = src4->nb[1];
|
|
2832
3178
|
const uint64_t nb42 = src4->nb[2];
|
|
3179
|
+
const uint64_t nb43 = src4->nb[3];
|
|
2833
3180
|
|
|
2834
3181
|
const int64_t ne50 = src5->ne[0]; WSP_GGML_UNUSED(ne50);
|
|
2835
3182
|
const int64_t ne51 = src5->ne[1]; WSP_GGML_UNUSED(ne51);
|
|
2836
3183
|
const int64_t ne52 = src5->ne[2]; WSP_GGML_UNUSED(ne52);
|
|
3184
|
+
const int64_t ne53 = src5->ne[3]; WSP_GGML_UNUSED(ne53);
|
|
2837
3185
|
|
|
2838
|
-
const uint64_t nb50 = src5->nb[0];
|
|
3186
|
+
const uint64_t nb50 = src5->nb[0]; WSP_GGML_UNUSED(nb50);
|
|
2839
3187
|
const uint64_t nb51 = src5->nb[1];
|
|
2840
3188
|
const uint64_t nb52 = src5->nb[2];
|
|
3189
|
+
const uint64_t nb53 = src5->nb[3];
|
|
3190
|
+
|
|
3191
|
+
const int64_t ne60 = src6->ne[0]; WSP_GGML_UNUSED(ne60);
|
|
3192
|
+
|
|
3193
|
+
const uint64_t nb60 = src6->nb[0]; WSP_GGML_UNUSED(nb60);
|
|
2841
3194
|
|
|
2842
3195
|
const int64_t d_state = ne00;
|
|
2843
3196
|
const int64_t d_inner = ne01;
|
|
2844
|
-
const int64_t
|
|
2845
|
-
const int64_t
|
|
3197
|
+
const int64_t n_head = ne02;
|
|
3198
|
+
const int64_t n_group = ne41;
|
|
3199
|
+
const int64_t n_seq_tokens = ne12;
|
|
3200
|
+
const int64_t n_seqs = ne13;
|
|
2846
3201
|
|
|
2847
|
-
id<MTLComputePipelineState> pipeline =
|
|
3202
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
3203
|
+
|
|
3204
|
+
if (ne30 == 1) {
|
|
3205
|
+
// Mamba-2
|
|
3206
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
|
|
3207
|
+
} else {
|
|
3208
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
|
|
3209
|
+
}
|
|
2848
3210
|
|
|
2849
3211
|
wsp_ggml_metal_kargs_ssm_scan args = {
|
|
2850
|
-
/*.d_state
|
|
2851
|
-
/*.d_inner
|
|
3212
|
+
/*.d_state =*/ d_state,
|
|
3213
|
+
/*.d_inner =*/ d_inner,
|
|
3214
|
+
/*.n_head =*/ n_head,
|
|
3215
|
+
/*.n_group =*/ n_group,
|
|
2852
3216
|
/*.n_seq_tokens =*/ n_seq_tokens,
|
|
2853
|
-
/*.n_seqs
|
|
2854
|
-
/*.
|
|
2855
|
-
/*.nb01
|
|
2856
|
-
/*.nb02
|
|
2857
|
-
/*.
|
|
2858
|
-
/*.nb11
|
|
2859
|
-
/*.nb12
|
|
2860
|
-
/*.nb13
|
|
2861
|
-
/*.
|
|
2862
|
-
/*.
|
|
2863
|
-
/*.
|
|
2864
|
-
/*.
|
|
2865
|
-
/*.
|
|
2866
|
-
/*.
|
|
2867
|
-
/*.
|
|
2868
|
-
/*.
|
|
2869
|
-
/*.
|
|
2870
|
-
/*.nb51 =*/ nb51,
|
|
2871
|
-
/*.nb52 =*/ nb52,
|
|
3217
|
+
/*.n_seqs =*/ n_seqs,
|
|
3218
|
+
/*.s_off =*/ wsp_ggml_nelements(src1) * sizeof(float),
|
|
3219
|
+
/*.nb01 =*/ nb01,
|
|
3220
|
+
/*.nb02 =*/ nb02,
|
|
3221
|
+
/*.nb03 =*/ nb03,
|
|
3222
|
+
/*.nb11 =*/ nb11,
|
|
3223
|
+
/*.nb12 =*/ nb12,
|
|
3224
|
+
/*.nb13 =*/ nb13,
|
|
3225
|
+
/*.nb21 =*/ nb21,
|
|
3226
|
+
/*.nb22 =*/ nb22,
|
|
3227
|
+
/*.nb31 =*/ nb31,
|
|
3228
|
+
/*.nb41 =*/ nb41,
|
|
3229
|
+
/*.nb42 =*/ nb42,
|
|
3230
|
+
/*.nb43 =*/ nb43,
|
|
3231
|
+
/*.nb51 =*/ nb51,
|
|
3232
|
+
/*.nb52 =*/ nb52,
|
|
3233
|
+
/*.nb53 =*/ nb53,
|
|
2872
3234
|
};
|
|
2873
3235
|
|
|
2874
3236
|
[encoder setComputePipelineState:pipeline];
|
|
@@ -2878,10 +3240,27 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
2878
3240
|
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
|
2879
3241
|
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
|
2880
3242
|
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
|
2881
|
-
[encoder setBuffer:
|
|
2882
|
-
[encoder
|
|
3243
|
+
[encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
|
|
3244
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];
|
|
3245
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:8];
|
|
3246
|
+
|
|
3247
|
+
// One shared memory bucket for each simd group in the threadgroup
|
|
3248
|
+
// NOTE: Metal kernels require the buffer size to be multiple of 16 bytes
|
|
3249
|
+
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
|
|
3250
|
+
if (d_state >= 32) {
|
|
3251
|
+
WSP_GGML_ASSERT((int64_t)(d_state / 32) <= 32);
|
|
3252
|
+
const int64_t shmem_size = 32;
|
|
3253
|
+
WSP_GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
|
|
3254
|
+
[encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
|
|
3255
|
+
}
|
|
2883
3256
|
|
|
2884
|
-
|
|
3257
|
+
if (ne30 == 1) {
|
|
3258
|
+
// Mamba-2
|
|
3259
|
+
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
|
|
3260
|
+
} else {
|
|
3261
|
+
WSP_GGML_ASSERT(d_inner == 1);
|
|
3262
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
|
|
3263
|
+
}
|
|
2885
3264
|
} break;
|
|
2886
3265
|
case WSP_GGML_OP_RWKV_WKV6:
|
|
2887
3266
|
{
|
|
@@ -2986,6 +3365,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
2986
3365
|
src0t == WSP_GGML_TYPE_Q5_0 ||
|
|
2987
3366
|
src0t == WSP_GGML_TYPE_Q5_1 ||
|
|
2988
3367
|
src0t == WSP_GGML_TYPE_Q8_0 ||
|
|
3368
|
+
src0t == WSP_GGML_TYPE_MXFP4 ||
|
|
2989
3369
|
src0t == WSP_GGML_TYPE_IQ4_NL ||
|
|
2990
3370
|
false) && (ne11 >= 2 && ne11 <= 8)
|
|
2991
3371
|
) ||
|
|
@@ -3078,6 +3458,14 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
3078
3458
|
case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
|
|
3079
3459
|
default: WSP_GGML_ABORT("not implemented");
|
|
3080
3460
|
} break;
|
|
3461
|
+
case WSP_GGML_TYPE_MXFP4:
|
|
3462
|
+
switch (r1ptg) {
|
|
3463
|
+
case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2].pipeline; break;
|
|
3464
|
+
case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3].pipeline; break;
|
|
3465
|
+
case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4].pipeline; break;
|
|
3466
|
+
case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5].pipeline; break;
|
|
3467
|
+
default: WSP_GGML_ABORT("not implemented");
|
|
3468
|
+
} break;
|
|
3081
3469
|
case WSP_GGML_TYPE_Q4_K:
|
|
3082
3470
|
switch (r1ptg) {
|
|
3083
3471
|
case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
|
|
@@ -3176,6 +3564,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
3176
3564
|
case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
|
|
3177
3565
|
case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
|
|
3178
3566
|
case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
|
|
3567
|
+
case WSP_GGML_TYPE_MXFP4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
|
|
3179
3568
|
case WSP_GGML_TYPE_Q2_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
|
|
3180
3569
|
case WSP_GGML_TYPE_Q3_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
|
|
3181
3570
|
case WSP_GGML_TYPE_Q4_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
|
|
@@ -3318,6 +3707,13 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
3318
3707
|
nr0 = N_R0_Q8_0;
|
|
3319
3708
|
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
|
|
3320
3709
|
} break;
|
|
3710
|
+
case WSP_GGML_TYPE_MXFP4:
|
|
3711
|
+
{
|
|
3712
|
+
nsg = N_SG_MXFP4;
|
|
3713
|
+
nr0 = N_R0_MXFP4;
|
|
3714
|
+
smem = 32*sizeof(float);
|
|
3715
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline;
|
|
3716
|
+
} break;
|
|
3321
3717
|
case WSP_GGML_TYPE_Q2_K:
|
|
3322
3718
|
{
|
|
3323
3719
|
nsg = N_SG_Q2_K;
|
|
@@ -3451,8 +3847,6 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
3451
3847
|
case WSP_GGML_OP_MUL_MAT_ID:
|
|
3452
3848
|
{
|
|
3453
3849
|
// src2 = ids
|
|
3454
|
-
const enum wsp_ggml_type src2t = src2->type; WSP_GGML_UNUSED(src2t);
|
|
3455
|
-
|
|
3456
3850
|
WSP_GGML_ASSERT(src2t == WSP_GGML_TYPE_I32);
|
|
3457
3851
|
|
|
3458
3852
|
WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src0));
|
|
@@ -3501,7 +3895,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
3501
3895
|
id<MTLBuffer> h_src1 = wsp_ggml_metal_mem_pool_alloc(mem_pool, s_src1);
|
|
3502
3896
|
if (!h_src1) {
|
|
3503
3897
|
WSP_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
|
|
3504
|
-
return
|
|
3898
|
+
return 0;
|
|
3505
3899
|
}
|
|
3506
3900
|
|
|
3507
3901
|
const int64_t neh0 = ne0;
|
|
@@ -3517,7 +3911,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
3517
3911
|
id<MTLBuffer> h_dst = wsp_ggml_metal_mem_pool_alloc(mem_pool, s_dst);
|
|
3518
3912
|
if (!h_dst) {
|
|
3519
3913
|
WSP_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
|
|
3520
|
-
return
|
|
3914
|
+
return 0;
|
|
3521
3915
|
}
|
|
3522
3916
|
|
|
3523
3917
|
// tokens per expert
|
|
@@ -3525,7 +3919,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
3525
3919
|
id<MTLBuffer> h_tpe = wsp_ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
|
|
3526
3920
|
if (!h_tpe) {
|
|
3527
3921
|
WSP_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
|
|
3528
|
-
return
|
|
3922
|
+
return 0;
|
|
3529
3923
|
}
|
|
3530
3924
|
|
|
3531
3925
|
// id map
|
|
@@ -3534,7 +3928,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
3534
3928
|
id<MTLBuffer> h_ids = wsp_ggml_metal_mem_pool_alloc(mem_pool, s_ids);
|
|
3535
3929
|
if (!h_ids) {
|
|
3536
3930
|
WSP_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
|
|
3537
|
-
return
|
|
3931
|
+
return 0;
|
|
3538
3932
|
}
|
|
3539
3933
|
|
|
3540
3934
|
{
|
|
@@ -3578,6 +3972,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
3578
3972
|
case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
|
|
3579
3973
|
case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
|
|
3580
3974
|
case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
|
|
3975
|
+
case WSP_GGML_TYPE_MXFP4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16 ].pipeline; break;
|
|
3581
3976
|
case WSP_GGML_TYPE_Q2_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
|
|
3582
3977
|
case WSP_GGML_TYPE_Q3_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
|
|
3583
3978
|
case WSP_GGML_TYPE_Q4_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;
|
|
@@ -3713,6 +4108,13 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
3713
4108
|
nr0 = N_R0_Q8_0;
|
|
3714
4109
|
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
|
|
3715
4110
|
} break;
|
|
4111
|
+
case WSP_GGML_TYPE_MXFP4:
|
|
4112
|
+
{
|
|
4113
|
+
nsg = N_SG_MXFP4;
|
|
4114
|
+
nr0 = N_R0_MXFP4;
|
|
4115
|
+
smem = 32*sizeof(float);
|
|
4116
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline;
|
|
4117
|
+
} break;
|
|
3716
4118
|
case WSP_GGML_TYPE_Q2_K:
|
|
3717
4119
|
{
|
|
3718
4120
|
nsg = N_SG_Q2_K;
|
|
@@ -3865,6 +4267,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
3865
4267
|
case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
|
|
3866
4268
|
case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
|
|
3867
4269
|
case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
|
|
4270
|
+
case WSP_GGML_TYPE_MXFP4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4 ].pipeline; break;
|
|
3868
4271
|
case WSP_GGML_TYPE_Q2_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
|
|
3869
4272
|
case WSP_GGML_TYPE_Q3_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
|
|
3870
4273
|
case WSP_GGML_TYPE_Q4_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
|
|
@@ -3966,12 +4369,95 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
3966
4369
|
case WSP_GGML_OP_RMS_NORM:
|
|
3967
4370
|
{
|
|
3968
4371
|
WSP_GGML_ASSERT(ne00 % 4 == 0);
|
|
3969
|
-
WSP_GGML_ASSERT(
|
|
4372
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous_rows(src0));
|
|
3970
4373
|
|
|
3971
4374
|
float eps;
|
|
3972
4375
|
memcpy(&eps, dst->op_params, sizeof(float));
|
|
3973
4376
|
|
|
3974
|
-
|
|
4377
|
+
wsp_ggml_metal_kargs_rms_norm args = {
|
|
4378
|
+
/*.ne00 =*/ ne00,
|
|
4379
|
+
/*.ne00_4 =*/ ne00/4,
|
|
4380
|
+
/*.nb1 =*/ nb1,
|
|
4381
|
+
/*.nb2 =*/ nb2,
|
|
4382
|
+
/*.nb3 =*/ nb3,
|
|
4383
|
+
/*.eps =*/ eps,
|
|
4384
|
+
/*.nef1 =*/ { ne01 },
|
|
4385
|
+
/*.nef2 =*/ { ne02 },
|
|
4386
|
+
/*.nef3 =*/ { ne03 },
|
|
4387
|
+
/*.nbf1 =*/ { nb01 },
|
|
4388
|
+
/*.nbf2 =*/ { nb02 },
|
|
4389
|
+
/*.nbf3 =*/ { nb03 },
|
|
4390
|
+
};
|
|
4391
|
+
|
|
4392
|
+
size_t offs_fuse[2] = { 0, 0 };
|
|
4393
|
+
id<MTLBuffer> id_fuse[2] = { id_src0, id_src0 };
|
|
4394
|
+
|
|
4395
|
+
// d[0] = rms_norm(a)
|
|
4396
|
+
// d[1] = mul(d[0], b)
|
|
4397
|
+
// d[2] = add(d[1], c)
|
|
4398
|
+
if (ctx_dev->use_fusion) {
|
|
4399
|
+
ops[0] = WSP_GGML_OP_RMS_NORM;
|
|
4400
|
+
ops[1] = WSP_GGML_OP_MUL;
|
|
4401
|
+
ops[2] = WSP_GGML_OP_ADD;
|
|
4402
|
+
|
|
4403
|
+
for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
|
|
4404
|
+
if (!wsp_ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
|
|
4405
|
+
break;
|
|
4406
|
+
}
|
|
4407
|
+
|
|
4408
|
+
if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
|
|
4409
|
+
break;
|
|
4410
|
+
}
|
|
4411
|
+
|
|
4412
|
+
if (nodes[n_fuse + 1]->src[1]->ne[0] != node->ne[0]) {
|
|
4413
|
+
break;
|
|
4414
|
+
}
|
|
4415
|
+
|
|
4416
|
+
if (!wsp_ggml_is_contiguous_rows(nodes[n_fuse + 1]->src[1])) {
|
|
4417
|
+
break;
|
|
4418
|
+
}
|
|
4419
|
+
|
|
4420
|
+
if (nodes[n_fuse + 1]->type != WSP_GGML_TYPE_F32) {
|
|
4421
|
+
break;
|
|
4422
|
+
}
|
|
4423
|
+
|
|
4424
|
+
ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
|
|
4425
|
+
|
|
4426
|
+
id_fuse[n_fuse] = wsp_ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]);
|
|
4427
|
+
|
|
4428
|
+
args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1];
|
|
4429
|
+
args.nef2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[2];
|
|
4430
|
+
args.nef3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[3];
|
|
4431
|
+
|
|
4432
|
+
args.nbf1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[1];
|
|
4433
|
+
args.nbf2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[2];
|
|
4434
|
+
args.nbf3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[3];
|
|
4435
|
+
}
|
|
4436
|
+
|
|
4437
|
+
++n_fuse;
|
|
4438
|
+
|
|
4439
|
+
if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
|
|
4440
|
+
if (n_fuse == 2) {
|
|
4441
|
+
WSP_GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
|
|
4442
|
+
}
|
|
4443
|
+
if (n_fuse == 3) {
|
|
4444
|
+
WSP_GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
|
|
4445
|
+
}
|
|
4446
|
+
}
|
|
4447
|
+
}
|
|
4448
|
+
|
|
4449
|
+
if (n_fuse > 1) {
|
|
4450
|
+
id_dst = wsp_ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
|
|
4451
|
+
}
|
|
4452
|
+
|
|
4453
|
+
id<MTLComputePipelineState> pipeline;
|
|
4454
|
+
|
|
4455
|
+
switch (n_fuse) {
|
|
4456
|
+
case 1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline; break;
|
|
4457
|
+
case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline; break;
|
|
4458
|
+
case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break;
|
|
4459
|
+
default: WSP_GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
|
|
4460
|
+
}
|
|
3975
4461
|
|
|
3976
4462
|
int nth = 32; // SIMD width
|
|
3977
4463
|
|
|
@@ -3982,23 +4468,16 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
3982
4468
|
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
|
3983
4469
|
nth = MIN(nth, ne00/4);
|
|
3984
4470
|
|
|
3985
|
-
wsp_ggml_metal_kargs_rms_norm args = {
|
|
3986
|
-
/*.ne00 =*/ ne00,
|
|
3987
|
-
/*.ne00_4 =*/ ne00/4,
|
|
3988
|
-
/*.nb01 =*/ nb01,
|
|
3989
|
-
/*.eps =*/ eps,
|
|
3990
|
-
};
|
|
3991
|
-
|
|
3992
4471
|
[encoder setComputePipelineState:pipeline];
|
|
3993
|
-
[encoder setBytes:&args length:sizeof(args)
|
|
3994
|
-
[encoder setBuffer:id_src0
|
|
3995
|
-
[encoder setBuffer:
|
|
4472
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
4473
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
4474
|
+
[encoder setBuffer:id_fuse[0] offset:offs_fuse[0] atIndex:2];
|
|
4475
|
+
[encoder setBuffer:id_fuse[1] offset:offs_fuse[1] atIndex:3];
|
|
4476
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
|
3996
4477
|
|
|
3997
4478
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
3998
4479
|
|
|
3999
|
-
|
|
4000
|
-
|
|
4001
|
-
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
4480
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
4002
4481
|
} break;
|
|
4003
4482
|
case WSP_GGML_OP_L2_NORM:
|
|
4004
4483
|
{
|
|
@@ -4599,11 +5078,14 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
4599
5078
|
WSP_GGML_ASSERT(ne11 == ne21);
|
|
4600
5079
|
WSP_GGML_ASSERT(ne12 == ne22);
|
|
4601
5080
|
|
|
4602
|
-
struct wsp_ggml_tensor * src3 = node->src[3];
|
|
5081
|
+
struct wsp_ggml_tensor * src3 = node->src[3]; // mask
|
|
5082
|
+
struct wsp_ggml_tensor * src4 = node->src[4]; // sinks
|
|
4603
5083
|
|
|
4604
5084
|
size_t offs_src3 = 0;
|
|
5085
|
+
size_t offs_src4 = 0;
|
|
4605
5086
|
|
|
4606
5087
|
id<MTLBuffer> id_src3 = src3 ? wsp_ggml_metal_get_buffer(src3, &offs_src3) : nil;
|
|
5088
|
+
id<MTLBuffer> id_src4 = src4 ? wsp_ggml_metal_get_buffer(src4, &offs_src4) : nil;
|
|
4607
5089
|
|
|
4608
5090
|
WSP_GGML_ASSERT(!src3 || src3->type == WSP_GGML_TYPE_F16);
|
|
4609
5091
|
WSP_GGML_ASSERT(!src3 || src3->ne[1] >= WSP_GGML_PAD(src0->ne[1], 8) &&
|
|
@@ -4619,8 +5101,6 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
4619
5101
|
const uint64_t nb32 = src3 ? src3->nb[2] : 0; WSP_GGML_UNUSED(nb32);
|
|
4620
5102
|
const uint64_t nb33 = src3 ? src3->nb[3] : 0; WSP_GGML_UNUSED(nb33);
|
|
4621
5103
|
|
|
4622
|
-
const enum wsp_ggml_type src2t = src2 ? src2->type : WSP_GGML_TYPE_COUNT; WSP_GGML_UNUSED(src2t);
|
|
4623
|
-
|
|
4624
5104
|
float scale;
|
|
4625
5105
|
float max_bias;
|
|
4626
5106
|
float logit_softcap;
|
|
@@ -4983,7 +5463,11 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
4983
5463
|
/*.nb21 =*/ nb21,
|
|
4984
5464
|
/*.nb22 =*/ nb22,
|
|
4985
5465
|
/*.nb23 =*/ nb23,
|
|
5466
|
+
/*.ne32 =*/ ne32,
|
|
5467
|
+
/*.ne33 =*/ ne33,
|
|
4986
5468
|
/*.nb31 =*/ nb31,
|
|
5469
|
+
/*.nb32 =*/ nb32,
|
|
5470
|
+
/*.nb33 =*/ nb33,
|
|
4987
5471
|
/*.ne1 =*/ ne1,
|
|
4988
5472
|
/*.ne2 =*/ ne2,
|
|
4989
5473
|
/*.scale =*/ scale,
|
|
@@ -5004,7 +5488,12 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
5004
5488
|
} else {
|
|
5005
5489
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
|
|
5006
5490
|
}
|
|
5007
|
-
|
|
5491
|
+
if (id_src4) {
|
|
5492
|
+
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
|
|
5493
|
+
} else {
|
|
5494
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
|
|
5495
|
+
}
|
|
5496
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
|
5008
5497
|
|
|
5009
5498
|
if (!use_vec_kernel) {
|
|
5010
5499
|
// half8x8 kernel
|
|
@@ -5389,7 +5878,7 @@ static bool wsp_ggml_metal_encode_node(
|
|
|
5389
5878
|
}
|
|
5390
5879
|
}
|
|
5391
5880
|
|
|
5392
|
-
return
|
|
5881
|
+
return n_fuse;
|
|
5393
5882
|
}
|
|
5394
5883
|
|
|
5395
5884
|
static enum wsp_ggml_status wsp_ggml_metal_graph_compute(
|
|
@@ -5895,20 +6384,26 @@ static void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb
|
|
|
5895
6384
|
struct wsp_ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
|
|
5896
6385
|
wsp_ggml_metal_mem_pool_reset(mem_pool);
|
|
5897
6386
|
|
|
5898
|
-
for (int idx = node_start; idx < node_end;
|
|
6387
|
+
for (int idx = node_start; idx < node_end;) {
|
|
5899
6388
|
if (should_capture) {
|
|
5900
6389
|
[encoder pushDebugGroup:[NSString stringWithCString:wsp_ggml_op_desc(wsp_ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
|
|
5901
6390
|
}
|
|
5902
6391
|
|
|
5903
|
-
const
|
|
6392
|
+
const int res = wsp_ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
|
|
6393
|
+
if (idx + res > node_end) {
|
|
6394
|
+
WSP_GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
|
|
6395
|
+
"https://github.com/ggml-org/llama.cpp/pull/14849");
|
|
6396
|
+
}
|
|
5904
6397
|
|
|
5905
6398
|
if (should_capture) {
|
|
5906
6399
|
[encoder popDebugGroup];
|
|
5907
6400
|
}
|
|
5908
6401
|
|
|
5909
|
-
if (
|
|
6402
|
+
if (res == 0) {
|
|
5910
6403
|
break;
|
|
5911
6404
|
}
|
|
6405
|
+
|
|
6406
|
+
idx += res;
|
|
5912
6407
|
}
|
|
5913
6408
|
|
|
5914
6409
|
[encoder endEncoding];
|