whisper.rn 0.4.0-rc.10 → 0.4.0-rc.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +9 -3
- package/cpp/ggml-alloc.c +6 -14
- package/cpp/ggml-backend-impl.h +50 -11
- package/cpp/ggml-backend-reg.cpp +409 -31
- package/cpp/ggml-backend.cpp +9 -3
- package/cpp/ggml-backend.h +18 -0
- package/cpp/ggml-common.h +41 -43
- package/cpp/ggml-cpp.h +1 -0
- package/cpp/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +941 -254
- package/cpp/ggml-cpu-aarch64.h +2 -24
- package/cpp/ggml-cpu-impl.h +171 -11
- package/cpp/ggml-cpu-quants.c +1812 -389
- package/cpp/ggml-cpu-traits.cpp +36 -0
- package/cpp/ggml-cpu-traits.h +38 -0
- package/cpp/ggml-cpu.c +1432 -610
- package/cpp/ggml-cpu.cpp +131 -141
- package/cpp/ggml-cpu.h +10 -50
- package/cpp/ggml-impl.h +27 -11
- package/cpp/ggml-metal-impl.h +39 -0
- package/cpp/ggml-metal.h +1 -1
- package/cpp/ggml-metal.m +1031 -359
- package/cpp/ggml-opt.cpp +854 -0
- package/cpp/ggml-opt.h +216 -0
- package/cpp/ggml-quants.c +0 -9
- package/cpp/ggml-threading.h +4 -2
- package/cpp/ggml-whisper.metallib +0 -0
- package/cpp/ggml.c +501 -1537
- package/cpp/ggml.h +144 -171
- package/cpp/gguf.cpp +1329 -0
- package/cpp/gguf.h +202 -0
- package/cpp/whisper.cpp +254 -114
- package/cpp/whisper.h +6 -3
- package/lib/commonjs/version.json +1 -1
- package/lib/module/version.json +1 -1
- package/package.json +1 -1
- package/src/version.json +1 -1
- package/whisper-rn.podspec +2 -2
- package/cpp/README.md +0 -4
- package/cpp/ggml-aarch64.c +0 -129
- package/cpp/ggml-aarch64.h +0 -19
- package/cpp/ggml-backend.cpp.rej +0 -12
package/cpp/ggml-metal.m
CHANGED
|
@@ -19,7 +19,17 @@
|
|
|
19
19
|
// max number of MTLCommandBuffer used to submit a graph for processing
|
|
20
20
|
#define WSP_GGML_METAL_MAX_COMMAND_BUFFERS 8
|
|
21
21
|
|
|
22
|
-
#
|
|
22
|
+
#ifndef TARGET_OS_VISION
|
|
23
|
+
#define TARGET_OS_VISION 0
|
|
24
|
+
#endif
|
|
25
|
+
|
|
26
|
+
// create residency sets only when building against macOS >= 15.0 (non-x86_64), iOS/tvOS >= 18.0 or visionOS >= 2.0
|
|
27
|
+
#if !TARGET_CPU_X86_64 && TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \
|
|
28
|
+
TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \
|
|
29
|
+
TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \
|
|
30
|
+
TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000
|
|
31
|
+
#define WSP_GGML_METAL_HAS_RESIDENCY_SETS 1
|
|
32
|
+
#endif
|
|
23
33
|
|
|
24
34
|
// globals
|
|
25
35
|
|
|
@@ -39,6 +49,7 @@ static struct wsp_ggml_backend_metal_device_context {
|
|
|
39
49
|
|
|
40
50
|
bool has_simdgroup_reduction;
|
|
41
51
|
bool has_simdgroup_mm;
|
|
52
|
+
bool has_residency_sets;
|
|
42
53
|
bool has_bfloat;
|
|
43
54
|
bool use_bfloat;
|
|
44
55
|
|
|
@@ -48,6 +59,7 @@ static struct wsp_ggml_backend_metal_device_context {
|
|
|
48
59
|
/*.mtl_device_ref_count =*/ 0,
|
|
49
60
|
/*.has_simdgroup_reduction =*/ false,
|
|
50
61
|
/*.has_simdgroup_mm =*/ false,
|
|
62
|
+
/*.has_residency_sets =*/ false,
|
|
51
63
|
/*.has_bfloat =*/ false,
|
|
52
64
|
/*.use_bfloat =*/ false,
|
|
53
65
|
/*.name =*/ "",
|
|
@@ -59,12 +71,18 @@ static id<MTLDevice> wsp_ggml_backend_metal_device_acq(struct wsp_ggml_backend_m
|
|
|
59
71
|
|
|
60
72
|
if (ctx->mtl_device == nil) {
|
|
61
73
|
ctx->mtl_device = MTLCreateSystemDefaultDevice();
|
|
74
|
+
}
|
|
62
75
|
|
|
76
|
+
if (ctx->mtl_device) {
|
|
63
77
|
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
|
64
78
|
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
|
65
79
|
|
|
66
80
|
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
|
67
81
|
|
|
82
|
+
#if defined(WSP_GGML_METAL_HAS_RESIDENCY_SETS)
|
|
83
|
+
ctx->has_residency_sets = getenv("WSP_GGML_METAL_NO_RESIDENCY") == NULL;
|
|
84
|
+
#endif
|
|
85
|
+
|
|
68
86
|
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
|
69
87
|
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
|
|
70
88
|
|
|
@@ -90,8 +108,10 @@ static void wsp_ggml_backend_metal_device_rel(struct wsp_ggml_backend_metal_devi
|
|
|
90
108
|
ctx->mtl_device_ref_count--;
|
|
91
109
|
|
|
92
110
|
if (ctx->mtl_device_ref_count == 0) {
|
|
93
|
-
|
|
94
|
-
|
|
111
|
+
if (ctx->mtl_device) {
|
|
112
|
+
[ctx->mtl_device release];
|
|
113
|
+
ctx->mtl_device = nil;
|
|
114
|
+
}
|
|
95
115
|
}
|
|
96
116
|
}
|
|
97
117
|
|
|
@@ -175,6 +195,46 @@ enum wsp_ggml_metal_kernel_type {
|
|
|
175
195
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
|
|
176
196
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
|
|
177
197
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
|
|
198
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
|
|
199
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
|
|
200
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
|
|
201
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,
|
|
202
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,
|
|
203
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,
|
|
204
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,
|
|
205
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,
|
|
206
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,
|
|
207
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,
|
|
208
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,
|
|
209
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,
|
|
210
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,
|
|
211
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,
|
|
212
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,
|
|
213
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,
|
|
214
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,
|
|
215
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,
|
|
216
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,
|
|
217
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,
|
|
218
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,
|
|
219
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
|
|
220
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
|
|
221
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
|
|
222
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
|
|
223
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
|
|
224
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
|
|
225
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5,
|
|
226
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2,
|
|
227
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3,
|
|
228
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4,
|
|
229
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5,
|
|
230
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2,
|
|
231
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3,
|
|
232
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4,
|
|
233
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5,
|
|
234
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2,
|
|
235
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3,
|
|
236
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4,
|
|
237
|
+
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5,
|
|
178
238
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
|
|
179
239
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
|
|
180
240
|
WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
|
|
@@ -266,8 +326,11 @@ enum wsp_ggml_metal_kernel_type {
|
|
|
266
326
|
WSP_GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
|
267
327
|
WSP_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
|
|
268
328
|
WSP_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
|
|
329
|
+
WSP_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32,
|
|
330
|
+
WSP_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
|
|
269
331
|
WSP_GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
|
270
332
|
WSP_GGML_METAL_KERNEL_TYPE_PAD_F32,
|
|
333
|
+
WSP_GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,
|
|
271
334
|
WSP_GGML_METAL_KERNEL_TYPE_ARANGE_F32,
|
|
272
335
|
WSP_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
|
|
273
336
|
WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
|
@@ -329,6 +392,8 @@ enum wsp_ggml_metal_kernel_type {
|
|
|
329
392
|
WSP_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
|
|
330
393
|
WSP_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
|
|
331
394
|
WSP_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
|
|
395
|
+
WSP_GGML_METAL_KERNEL_TYPE_SET_I32,
|
|
396
|
+
WSP_GGML_METAL_KERNEL_TYPE_SET_F32,
|
|
332
397
|
WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
|
333
398
|
WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
|
334
399
|
WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
|
|
@@ -342,6 +407,16 @@ enum wsp_ggml_metal_kernel_type {
|
|
|
342
407
|
WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
|
|
343
408
|
WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
|
|
344
409
|
WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
|
|
410
|
+
WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
|
|
411
|
+
WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,
|
|
412
|
+
WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
|
|
413
|
+
WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,
|
|
414
|
+
WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
|
|
415
|
+
WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,
|
|
416
|
+
WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
|
|
417
|
+
WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,
|
|
418
|
+
WSP_GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
|
|
419
|
+
WSP_GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
|
|
345
420
|
WSP_GGML_METAL_KERNEL_TYPE_CONCAT,
|
|
346
421
|
WSP_GGML_METAL_KERNEL_TYPE_SQR,
|
|
347
422
|
WSP_GGML_METAL_KERNEL_TYPE_SQRT,
|
|
@@ -350,6 +425,7 @@ enum wsp_ggml_metal_kernel_type {
|
|
|
350
425
|
WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
|
351
426
|
WSP_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
|
352
427
|
WSP_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
|
428
|
+
WSP_GGML_METAL_KERNEL_TYPE_ARGMAX,
|
|
353
429
|
|
|
354
430
|
WSP_GGML_METAL_KERNEL_TYPE_COUNT
|
|
355
431
|
};
|
|
@@ -437,6 +513,11 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
437
513
|
WSP_GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
|
438
514
|
|
|
439
515
|
ctx->queue = [device newCommandQueue];
|
|
516
|
+
if (ctx->queue == nil) {
|
|
517
|
+
WSP_GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
|
|
518
|
+
return NULL;
|
|
519
|
+
}
|
|
520
|
+
|
|
440
521
|
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
|
441
522
|
|
|
442
523
|
id<MTLLibrary> metal_library;
|
|
@@ -464,6 +545,35 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
464
545
|
#endif
|
|
465
546
|
|
|
466
547
|
NSString * path_lib = [bundle pathForResource:@"ggml-whisper" ofType:@"metallib"];
|
|
548
|
+
if (path_lib == nil) {
|
|
549
|
+
// Try to find the resource in the directory where the current binary is located.
|
|
550
|
+
NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
|
|
551
|
+
NSString * bin_dir = [current_binary stringByDeletingLastPathComponent];
|
|
552
|
+
NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
|
|
553
|
+
if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
|
|
554
|
+
WSP_GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]);
|
|
555
|
+
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error];
|
|
556
|
+
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
|
|
557
|
+
// Optionally, if this is a symlink, try to resolve it.
|
|
558
|
+
default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error];
|
|
559
|
+
if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) {
|
|
560
|
+
// It is a relative path, adding the binary directory as directory prefix.
|
|
561
|
+
default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]];
|
|
562
|
+
}
|
|
563
|
+
if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
|
|
564
|
+
// Link to the resource could not be resolved.
|
|
565
|
+
default_metallib_path = nil;
|
|
566
|
+
} else {
|
|
567
|
+
WSP_GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]);
|
|
568
|
+
}
|
|
569
|
+
}
|
|
570
|
+
} else {
|
|
571
|
+
// The resource couldn't be found in the binary's directory.
|
|
572
|
+
default_metallib_path = nil;
|
|
573
|
+
}
|
|
574
|
+
path_lib = default_metallib_path;
|
|
575
|
+
}
|
|
576
|
+
|
|
467
577
|
if (try_metallib && path_lib != nil) {
|
|
468
578
|
// pre-compiled library found
|
|
469
579
|
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
|
|
@@ -574,6 +684,7 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
574
684
|
|
|
575
685
|
WSP_GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false");
|
|
576
686
|
WSP_GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false");
|
|
687
|
+
WSP_GGML_LOG_INFO("%s: has residency sets = %s\n", __func__, ctx_dev->has_residency_sets ? "true" : "false");
|
|
577
688
|
WSP_GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
|
|
578
689
|
WSP_GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
|
|
579
690
|
WSP_GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
|
|
@@ -699,6 +810,46 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
699
810
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
|
|
700
811
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
|
|
701
812
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
|
|
813
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
|
|
814
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
|
|
815
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
|
|
816
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
|
|
817
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
|
|
818
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
|
|
819
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
|
|
820
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
|
|
821
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
|
|
822
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
|
|
823
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
|
|
824
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
|
|
825
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
|
|
826
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
|
|
827
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
|
|
828
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
|
|
829
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
|
|
830
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
|
|
831
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
|
|
832
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
|
|
833
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
|
|
834
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
|
|
835
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
|
|
836
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
|
|
837
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
|
|
838
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
|
|
839
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
|
|
840
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction);
|
|
841
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction);
|
|
842
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction);
|
|
843
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction);
|
|
844
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction);
|
|
845
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction);
|
|
846
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction);
|
|
847
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction);
|
|
848
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction);
|
|
849
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction);
|
|
850
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction);
|
|
851
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction);
|
|
852
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction);
|
|
702
853
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
|
|
703
854
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
|
|
704
855
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
|
|
@@ -790,8 +941,11 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
790
941
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
|
791
942
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
|
|
792
943
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
|
|
944
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true);
|
|
945
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
|
|
793
946
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
|
794
947
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
|
948
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true);
|
|
795
949
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
|
796
950
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
|
|
797
951
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
|
@@ -853,6 +1007,8 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
853
1007
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
|
|
854
1008
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
|
|
855
1009
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
|
|
1010
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
|
|
1011
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
|
|
856
1012
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
|
857
1013
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
|
858
1014
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
|
|
@@ -866,12 +1022,23 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
|
|
|
866
1022
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
|
867
1023
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
|
868
1024
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
|
|
1025
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true);
|
|
1026
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true);
|
|
1027
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true);
|
|
1028
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true);
|
|
1029
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true);
|
|
1030
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true);
|
|
1031
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true);
|
|
1032
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true);
|
|
1033
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true);
|
|
1034
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true);
|
|
869
1035
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
|
870
1036
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
|
871
1037
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
|
|
872
1038
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
|
873
1039
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
|
874
1040
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
|
1041
|
+
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
|
875
1042
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
|
876
1043
|
WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
|
877
1044
|
}
|
|
@@ -914,8 +1081,70 @@ struct wsp_ggml_backend_metal_buffer_context {
|
|
|
914
1081
|
// multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
|
|
915
1082
|
int n_buffers;
|
|
916
1083
|
struct wsp_ggml_backend_metal_buffer buffers[WSP_GGML_METAL_MAX_BUFFERS];
|
|
1084
|
+
|
|
1085
|
+
// optional MTLResidencySet
|
|
1086
|
+
id rset;
|
|
917
1087
|
};
|
|
918
1088
|
|
|
1089
|
+
// rset init
|
|
1090
|
+
// Create the optional MTLResidencySet for a buffer context and request residency
// for all of its Metal buffers.
//
//   ctx     - buffer context; ctx->rset is reset to nil first and, on success with
//             residency sets enabled, holds the newly created set
//   ctx_dev - device context; when has_residency_sets is false nothing is created
//   device  - Metal device used to create the residency set
//
// Returns false only when the residency set could not be created. Every other path
// (feature disabled, support compiled out, or the runtime @available check failing)
// returns true and leaves ctx->rset == nil.
static bool wsp_ggml_backend_metal_buffer_rset_init(
        struct wsp_ggml_backend_metal_buffer_context * ctx,
        struct wsp_ggml_backend_metal_device_context * ctx_dev,
        id<MTLDevice> device) {
    ctx->rset = nil;

    if (!ctx_dev->has_residency_sets) {
        return true;
    }

#if defined(WSP_GGML_METAL_HAS_RESIDENCY_SETS)
    if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
        MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init];
        desc.label = @"wsp_ggml_backend_metal";
        desc.initialCapacity = ctx->n_buffers;

        // this file uses manual retain/release, so an uninitialized local would hold
        // an indeterminate value - start from nil so it is only non-nil if the call set it
        NSError * error = nil;
        ctx->rset = [device newResidencySetWithDescriptor:desc error:&error];

        // the descriptor is no longer needed on either path
        [desc release];

        // Cocoa convention: failure is signalled by the nil return value, the NSError
        // only carries the details
        if (ctx->rset == nil) {
            WSP_GGML_LOG_ERROR("%s: error: %s\n", __func__, error ? [[error description] UTF8String] : "unknown error");
            return false;
        }

        for (int i = 0; i < ctx->n_buffers; i++) {
            [ctx->rset addAllocation:ctx->buffers[i].metal];
        }

        [ctx->rset commit];
        [ctx->rset requestResidency];

        return true;
    }
#else
    WSP_GGML_UNUSED(ctx_dev);
    WSP_GGML_UNUSED(device);
#endif

    return true;
}
|
|
1132
|
+
|
|
1133
|
+
// rset free
|
|
1134
|
+
// Tear down the optional residency set owned by a buffer context.
//
// Does nothing when ctx->rset is nil, when residency-set support is compiled out,
// or when the runtime @available check below fails.
static void wsp_ggml_backend_metal_buffer_rset_free(struct wsp_ggml_backend_metal_buffer_context * ctx) {
#if defined(WSP_GGML_METAL_HAS_RESIDENCY_SETS)
    if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
        if (ctx->rset) {
            [ctx->rset endResidency];
            [ctx->rset removeAllAllocations];
            [ctx->rset release];
            // clear the stale pointer so a repeated call cannot message the released object
            ctx->rset = nil;
        }
    }
#else
    WSP_GGML_UNUSED(ctx);
#endif
}
|
|
1147
|
+
|
|
919
1148
|
// finds the Metal buffer that contains the tensor data on the GPU device
|
|
920
1149
|
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
|
|
921
1150
|
// Metal buffer based on the host memory pointer
|
|
@@ -989,6 +1218,7 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
|
|
|
989
1218
|
case WSP_GGML_OP_REPEAT:
|
|
990
1219
|
case WSP_GGML_OP_SCALE:
|
|
991
1220
|
case WSP_GGML_OP_CLAMP:
|
|
1221
|
+
case WSP_GGML_OP_CONV_TRANSPOSE_1D:
|
|
992
1222
|
return true;
|
|
993
1223
|
case WSP_GGML_OP_SQR:
|
|
994
1224
|
case WSP_GGML_OP_SQRT:
|
|
@@ -997,12 +1227,25 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
|
|
|
997
1227
|
return wsp_ggml_is_contiguous(op->src[0]);
|
|
998
1228
|
case WSP_GGML_OP_SUM_ROWS:
|
|
999
1229
|
case WSP_GGML_OP_SOFT_MAX:
|
|
1000
|
-
case WSP_GGML_OP_RMS_NORM:
|
|
1001
1230
|
case WSP_GGML_OP_GROUP_NORM:
|
|
1002
|
-
return has_simdgroup_reduction;
|
|
1231
|
+
return has_simdgroup_reduction && wsp_ggml_is_contiguous(op->src[0]);
|
|
1232
|
+
case WSP_GGML_OP_RMS_NORM:
|
|
1233
|
+
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && wsp_ggml_is_contiguous_1(op->src[0]));
|
|
1234
|
+
case WSP_GGML_OP_ARGMAX:
|
|
1235
|
+
return true;
|
|
1003
1236
|
case WSP_GGML_OP_NORM:
|
|
1237
|
+
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && wsp_ggml_is_contiguous_1(op->src[0]));
|
|
1004
1238
|
case WSP_GGML_OP_ROPE:
|
|
1005
|
-
|
|
1239
|
+
{
|
|
1240
|
+
const int mode = ((const int32_t *) op->op_params)[2];
|
|
1241
|
+
if (mode & WSP_GGML_ROPE_TYPE_MROPE) {
|
|
1242
|
+
return false;
|
|
1243
|
+
}
|
|
1244
|
+
if (mode & WSP_GGML_ROPE_TYPE_VISION) {
|
|
1245
|
+
return false;
|
|
1246
|
+
}
|
|
1247
|
+
return true;
|
|
1248
|
+
}
|
|
1006
1249
|
case WSP_GGML_OP_IM2COL:
|
|
1007
1250
|
return op->src[0]->type == WSP_GGML_TYPE_F16;
|
|
1008
1251
|
case WSP_GGML_OP_POOL_1D:
|
|
@@ -1010,6 +1253,7 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
|
|
|
1010
1253
|
case WSP_GGML_OP_POOL_2D:
|
|
1011
1254
|
case WSP_GGML_OP_UPSCALE:
|
|
1012
1255
|
case WSP_GGML_OP_PAD:
|
|
1256
|
+
case WSP_GGML_OP_PAD_REFLECT_1D:
|
|
1013
1257
|
case WSP_GGML_OP_ARANGE:
|
|
1014
1258
|
case WSP_GGML_OP_TIMESTEP_EMBEDDING:
|
|
1015
1259
|
case WSP_GGML_OP_ARGSORT:
|
|
@@ -1063,6 +1307,28 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
|
|
|
1063
1307
|
default:
|
|
1064
1308
|
return false;
|
|
1065
1309
|
}
|
|
1310
|
+
case WSP_GGML_TYPE_Q4_0:
|
|
1311
|
+
case WSP_GGML_TYPE_Q4_1:
|
|
1312
|
+
case WSP_GGML_TYPE_Q5_0:
|
|
1313
|
+
case WSP_GGML_TYPE_Q5_1:
|
|
1314
|
+
case WSP_GGML_TYPE_Q8_0:
|
|
1315
|
+
switch (op->type) {
|
|
1316
|
+
case WSP_GGML_TYPE_F32:
|
|
1317
|
+
case WSP_GGML_TYPE_F16:
|
|
1318
|
+
return true;
|
|
1319
|
+
default:
|
|
1320
|
+
return false;
|
|
1321
|
+
}
|
|
1322
|
+
default:
|
|
1323
|
+
return false;
|
|
1324
|
+
};
|
|
1325
|
+
}
|
|
1326
|
+
case WSP_GGML_OP_SET:
|
|
1327
|
+
{
|
|
1328
|
+
switch (op->src[0]->type) {
|
|
1329
|
+
case WSP_GGML_TYPE_F32:
|
|
1330
|
+
case WSP_GGML_TYPE_I32:
|
|
1331
|
+
return true;
|
|
1066
1332
|
default:
|
|
1067
1333
|
return false;
|
|
1068
1334
|
};
|
|
@@ -1749,7 +2015,7 @@ static void wsp_ggml_metal_encode_node(
|
|
|
1749
2015
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
1750
2016
|
|
|
1751
2017
|
// TODO: add wsp_ggml_metal_kargs struct
|
|
1752
|
-
// TODO: optimize (see https://github.com/
|
|
2018
|
+
// TODO: optimize (see https://github.com/ggml-org/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
|
|
1753
2019
|
[encoder setComputePipelineState:pipeline];
|
|
1754
2020
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1755
2021
|
if (id_src1) {
|
|
@@ -1922,345 +2188,495 @@ static void wsp_ggml_metal_encode_node(
|
|
|
1922
2188
|
WSP_GGML_ASSERT(ne12 % ne02 == 0);
|
|
1923
2189
|
WSP_GGML_ASSERT(ne13 % ne03 == 0);
|
|
1924
2190
|
|
|
1925
|
-
const
|
|
1926
|
-
const
|
|
2191
|
+
const uint32_t r2 = ne12/ne02;
|
|
2192
|
+
const uint32_t r3 = ne13/ne03;
|
|
1927
2193
|
|
|
1928
2194
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
|
1929
2195
|
// to the matrix-vector kernel
|
|
1930
|
-
int ne11_mm_min =
|
|
2196
|
+
const int ne11_mm_min = 4;
|
|
2197
|
+
|
|
2198
|
+
// first try to use small-batch mat-mv kernels
|
|
2199
|
+
// these should be efficient for BS [2, ~8]
|
|
2200
|
+
if (src1t == WSP_GGML_TYPE_F32 && (ne00%256 == 0) &&
|
|
2201
|
+
(
|
|
2202
|
+
(
|
|
2203
|
+
(
|
|
2204
|
+
src0t == WSP_GGML_TYPE_F16 || // TODO: helper function
|
|
2205
|
+
src0t == WSP_GGML_TYPE_Q4_0 ||
|
|
2206
|
+
src0t == WSP_GGML_TYPE_Q4_1 ||
|
|
2207
|
+
src0t == WSP_GGML_TYPE_Q5_0 ||
|
|
2208
|
+
src0t == WSP_GGML_TYPE_Q5_1 ||
|
|
2209
|
+
src0t == WSP_GGML_TYPE_Q8_0 ||
|
|
2210
|
+
src0t == WSP_GGML_TYPE_IQ4_NL ||
|
|
2211
|
+
false) && (ne11 >= 2 && ne11 <= 8)
|
|
2212
|
+
) ||
|
|
2213
|
+
(
|
|
2214
|
+
(
|
|
2215
|
+
src0t == WSP_GGML_TYPE_Q4_K ||
|
|
2216
|
+
src0t == WSP_GGML_TYPE_Q5_K ||
|
|
2217
|
+
src0t == WSP_GGML_TYPE_Q6_K ||
|
|
2218
|
+
false) && (ne11 >= 4 && ne11 <= 8)
|
|
2219
|
+
)
|
|
2220
|
+
)
|
|
2221
|
+
) {
|
|
2222
|
+
// TODO: determine the optimal parameters based on grid utilization
|
|
2223
|
+
// I still don't know why we should not always use the maximum available threads:
|
|
2224
|
+
//
|
|
2225
|
+
// nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
|
|
2226
|
+
//
|
|
2227
|
+
// my current hypothesis is that the work grid is not evenly divisible for different nsg
|
|
2228
|
+
// values and there can be some tail effects when nsg is high. need to confirm this
|
|
2229
|
+
//
|
|
2230
|
+
const int nsg = 2; // num simdgroups per threadgroup
|
|
2231
|
+
const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
|
|
2232
|
+
const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
|
|
2233
|
+
const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
|
|
2234
|
+
int r1ptg = 4; // num src1 rows per threadgroup
|
|
2235
|
+
|
|
2236
|
+
// note: not sure how optimal those are across all different hardware. there might be something cleverer
|
|
2237
|
+
switch (ne11) {
|
|
2238
|
+
case 2:
|
|
2239
|
+
r1ptg = 2; break;
|
|
2240
|
+
case 3:
|
|
2241
|
+
case 6:
|
|
2242
|
+
r1ptg = 3; break;
|
|
2243
|
+
case 4:
|
|
2244
|
+
case 7:
|
|
2245
|
+
case 8:
|
|
2246
|
+
r1ptg = 4; break;
|
|
2247
|
+
case 5:
|
|
2248
|
+
r1ptg = 5; break;
|
|
2249
|
+
};
|
|
1931
2250
|
|
|
1932
|
-
|
|
1933
|
-
// the numbers below are measured on M2 Ultra for 7B and 13B models
|
|
1934
|
-
// these numbers do not translate to other devices or model sizes
|
|
1935
|
-
// TODO: need to find a better approach
|
|
1936
|
-
if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
|
|
1937
|
-
switch (src0t) {
|
|
1938
|
-
case WSP_GGML_TYPE_F16: ne11_mm_min = 2; break;
|
|
1939
|
-
case WSP_GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
|
|
1940
|
-
case WSP_GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
|
|
1941
|
-
case WSP_GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
|
|
1942
|
-
case WSP_GGML_TYPE_Q4_0:
|
|
1943
|
-
case WSP_GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
|
|
1944
|
-
case WSP_GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
|
|
1945
|
-
case WSP_GGML_TYPE_Q5_0: // not tested yet
|
|
1946
|
-
case WSP_GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
|
|
1947
|
-
case WSP_GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
|
|
1948
|
-
case WSP_GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
|
|
1949
|
-
default: ne11_mm_min = 1; break;
|
|
1950
|
-
}
|
|
1951
|
-
}
|
|
1952
|
-
#endif
|
|
2251
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
1953
2252
|
|
|
1954
|
-
|
|
1955
|
-
|
|
1956
|
-
|
|
1957
|
-
|
|
1958
|
-
|
|
1959
|
-
|
|
1960
|
-
|
|
1961
|
-
|
|
1962
|
-
|
|
1963
|
-
|
|
1964
|
-
|
|
1965
|
-
|
|
1966
|
-
|
|
1967
|
-
case
|
|
1968
|
-
case
|
|
1969
|
-
|
|
1970
|
-
|
|
1971
|
-
|
|
2253
|
+
switch (src0->type) {
|
|
2254
|
+
case WSP_GGML_TYPE_F16:
|
|
2255
|
+
switch (r1ptg) {
|
|
2256
|
+
case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
|
|
2257
|
+
case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break;
|
|
2258
|
+
case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break;
|
|
2259
|
+
case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break;
|
|
2260
|
+
default: WSP_GGML_ABORT("not implemented");
|
|
2261
|
+
} break;
|
|
2262
|
+
case WSP_GGML_TYPE_Q4_0:
|
|
2263
|
+
switch (r1ptg) {
|
|
2264
|
+
case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break;
|
|
2265
|
+
case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break;
|
|
2266
|
+
case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break;
|
|
2267
|
+
case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break;
|
|
2268
|
+
default: WSP_GGML_ABORT("not implemented");
|
|
2269
|
+
} break;
|
|
2270
|
+
case WSP_GGML_TYPE_Q4_1:
|
|
2271
|
+
switch (r1ptg) {
|
|
2272
|
+
case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break;
|
|
2273
|
+
case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break;
|
|
2274
|
+
case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break;
|
|
2275
|
+
case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break;
|
|
2276
|
+
default: WSP_GGML_ABORT("not implemented");
|
|
2277
|
+
} break;
|
|
2278
|
+
case WSP_GGML_TYPE_Q5_0:
|
|
2279
|
+
switch (r1ptg) {
|
|
2280
|
+
case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break;
|
|
2281
|
+
case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break;
|
|
2282
|
+
case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break;
|
|
2283
|
+
case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break;
|
|
2284
|
+
default: WSP_GGML_ABORT("not implemented");
|
|
2285
|
+
} break;
|
|
2286
|
+
case WSP_GGML_TYPE_Q5_1:
|
|
2287
|
+
switch (r1ptg) {
|
|
2288
|
+
case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break;
|
|
2289
|
+
case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break;
|
|
2290
|
+
case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break;
|
|
2291
|
+
case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break;
|
|
2292
|
+
default: WSP_GGML_ABORT("not implemented");
|
|
2293
|
+
} break;
|
|
2294
|
+
case WSP_GGML_TYPE_Q8_0:
|
|
2295
|
+
switch (r1ptg) {
|
|
2296
|
+
case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break;
|
|
2297
|
+
case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break;
|
|
2298
|
+
case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break;
|
|
2299
|
+
case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
|
|
2300
|
+
default: WSP_GGML_ABORT("not implemented");
|
|
2301
|
+
} break;
|
|
2302
|
+
case WSP_GGML_TYPE_Q4_K:
|
|
2303
|
+
switch (r1ptg) {
|
|
2304
|
+
case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
|
|
2305
|
+
case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3].pipeline; break;
|
|
2306
|
+
case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4].pipeline; break;
|
|
2307
|
+
case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5].pipeline; break;
|
|
2308
|
+
default: WSP_GGML_ABORT("not implemented");
|
|
2309
|
+
} break;
|
|
2310
|
+
case WSP_GGML_TYPE_Q5_K:
|
|
2311
|
+
switch (r1ptg) {
|
|
2312
|
+
case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2].pipeline; break;
|
|
2313
|
+
case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3].pipeline; break;
|
|
2314
|
+
case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4].pipeline; break;
|
|
2315
|
+
case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5].pipeline; break;
|
|
2316
|
+
default: WSP_GGML_ABORT("not implemented");
|
|
2317
|
+
} break;
|
|
2318
|
+
case WSP_GGML_TYPE_Q6_K:
|
|
2319
|
+
switch (r1ptg) {
|
|
2320
|
+
case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2].pipeline; break;
|
|
2321
|
+
case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3].pipeline; break;
|
|
2322
|
+
case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4].pipeline; break;
|
|
2323
|
+
case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5].pipeline; break;
|
|
2324
|
+
default: WSP_GGML_ABORT("not implemented");
|
|
2325
|
+
} break;
|
|
2326
|
+
case WSP_GGML_TYPE_IQ4_NL:
|
|
2327
|
+
switch (r1ptg) {
|
|
2328
|
+
case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2].pipeline; break;
|
|
2329
|
+
case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3].pipeline; break;
|
|
2330
|
+
case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4].pipeline; break;
|
|
2331
|
+
case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5].pipeline; break;
|
|
2332
|
+
default: WSP_GGML_ABORT("not implemented");
|
|
2333
|
+
} break;
|
|
2334
|
+
default: WSP_GGML_ABORT("not implemented");
|
|
2335
|
+
}
|
|
1972
2336
|
|
|
1973
|
-
|
|
1974
|
-
|
|
1975
|
-
|
|
1976
|
-
|
|
1977
|
-
|
|
1978
|
-
|
|
1979
|
-
|
|
1980
|
-
|
|
1981
|
-
|
|
1982
|
-
|
|
1983
|
-
|
|
1984
|
-
|
|
1985
|
-
|
|
1986
|
-
|
|
1987
|
-
|
|
1988
|
-
|
|
1989
|
-
|
|
1990
|
-
|
|
1991
|
-
|
|
1992
|
-
|
|
1993
|
-
|
|
1994
|
-
|
|
1995
|
-
|
|
1996
|
-
case WSP_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
|
|
1997
|
-
case WSP_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
|
|
1998
|
-
default: WSP_GGML_ABORT("MUL MAT-MAT not implemented");
|
|
1999
|
-
}
|
|
2337
|
+
wsp_ggml_metal_kargs_mul_mv_ext args = {
|
|
2338
|
+
/*.ne00 =*/ ne00,
|
|
2339
|
+
/*.ne01 =*/ ne01,
|
|
2340
|
+
/*.ne02 =*/ ne02,
|
|
2341
|
+
/*.nb00 =*/ nb00,
|
|
2342
|
+
/*.nb01 =*/ nb01,
|
|
2343
|
+
/*.nb02 =*/ nb02,
|
|
2344
|
+
/*.nb03 =*/ nb03,
|
|
2345
|
+
/*.ne10 =*/ ne10,
|
|
2346
|
+
/*.ne11 =*/ ne11,
|
|
2347
|
+
/*.ne12 =*/ ne12,
|
|
2348
|
+
/*.nb10 =*/ nb10,
|
|
2349
|
+
/*.nb11 =*/ nb11,
|
|
2350
|
+
/*.nb12 =*/ nb12,
|
|
2351
|
+
/*.nb13 =*/ nb13,
|
|
2352
|
+
/*.ne0 =*/ ne0,
|
|
2353
|
+
/*.ne1 =*/ ne1,
|
|
2354
|
+
/*.r2 =*/ r2,
|
|
2355
|
+
/*.r3 =*/ r3,
|
|
2356
|
+
/*.nsg =*/ nsg,
|
|
2357
|
+
/*.nxpsg =*/ nxpsg,
|
|
2358
|
+
/*.r1ptg =*/ r1ptg,
|
|
2359
|
+
};
|
|
2000
2360
|
|
|
2001
|
-
|
|
2002
|
-
|
|
2003
|
-
|
|
2004
|
-
|
|
2005
|
-
|
|
2006
|
-
|
|
2007
|
-
|
|
2008
|
-
|
|
2009
|
-
|
|
2010
|
-
|
|
2011
|
-
|
|
2012
|
-
|
|
2013
|
-
|
|
2014
|
-
|
|
2015
|
-
|
|
2016
|
-
|
|
2361
|
+
[encoder setComputePipelineState:pipeline];
|
|
2362
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
2363
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
2364
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
|
2365
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
2366
|
+
|
|
2367
|
+
//printf("ne01 = %lld r0ptg = %d\n", ne01, r0ptg);
|
|
2368
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
|
2369
|
+
} else
|
|
2370
|
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
|
2371
|
+
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
|
2372
|
+
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
|
2373
|
+
!wsp_ggml_is_transposed(src0) &&
|
|
2374
|
+
!wsp_ggml_is_transposed(src1) &&
|
|
2375
|
+
src1t == WSP_GGML_TYPE_F32 &&
|
|
2376
|
+
ne00 % 32 == 0 && ne00 >= 64 &&
|
|
2377
|
+
(ne11 > ne11_mm_min || (wsp_ggml_is_quantized(src0t) && ne12 > 1))) {
|
|
2378
|
+
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
|
2017
2379
|
|
|
2018
|
-
|
|
2019
|
-
|
|
2020
|
-
|
|
2021
|
-
|
|
2022
|
-
|
|
2023
|
-
|
|
2024
|
-
|
|
2025
|
-
|
|
2026
|
-
|
|
2027
|
-
|
|
2028
|
-
|
|
2029
|
-
|
|
2030
|
-
|
|
2031
|
-
|
|
2032
|
-
|
|
2033
|
-
|
|
2034
|
-
|
|
2035
|
-
|
|
2036
|
-
|
|
2037
|
-
|
|
2038
|
-
|
|
2039
|
-
|
|
2380
|
+
// some Metal matrix data types require aligned pointers
|
|
2381
|
+
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
|
2382
|
+
switch (src0->type) {
|
|
2383
|
+
case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(nb01 % 16 == 0); break;
|
|
2384
|
+
case WSP_GGML_TYPE_F16: WSP_GGML_ASSERT(nb01 % 8 == 0); break;
|
|
2385
|
+
case WSP_GGML_TYPE_BF16: WSP_GGML_ASSERT(nb01 % 8 == 0); break;
|
|
2386
|
+
default: break;
|
|
2387
|
+
}
|
|
2388
|
+
|
|
2389
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2390
|
+
|
|
2391
|
+
switch (src0->type) {
|
|
2392
|
+
case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
|
|
2393
|
+
case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
|
|
2394
|
+
case WSP_GGML_TYPE_BF16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
|
|
2395
|
+
case WSP_GGML_TYPE_Q4_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
|
|
2396
|
+
case WSP_GGML_TYPE_Q4_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
|
|
2397
|
+
case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
|
|
2398
|
+
case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
|
|
2399
|
+
case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
|
|
2400
|
+
case WSP_GGML_TYPE_Q2_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
|
|
2401
|
+
case WSP_GGML_TYPE_Q3_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
|
|
2402
|
+
case WSP_GGML_TYPE_Q4_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
|
|
2403
|
+
case WSP_GGML_TYPE_Q5_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
|
|
2404
|
+
case WSP_GGML_TYPE_Q6_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
|
|
2405
|
+
case WSP_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
|
|
2406
|
+
case WSP_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
|
|
2407
|
+
case WSP_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
|
|
2408
|
+
case WSP_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
|
|
2409
|
+
case WSP_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
|
|
2410
|
+
case WSP_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
|
|
2411
|
+
case WSP_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
|
|
2412
|
+
case WSP_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
|
|
2413
|
+
case WSP_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
|
|
2414
|
+
default: WSP_GGML_ABORT("MUL MAT-MAT not implemented");
|
|
2415
|
+
}
|
|
2416
|
+
|
|
2417
|
+
wsp_ggml_metal_kargs_mul_mm args = {
|
|
2418
|
+
/*.ne00 =*/ ne00,
|
|
2419
|
+
/*.ne02 =*/ ne02,
|
|
2420
|
+
/*.nb01 =*/ nb01,
|
|
2421
|
+
/*.nb02 =*/ nb02,
|
|
2422
|
+
/*.nb03 =*/ nb03,
|
|
2423
|
+
/*.ne12 =*/ ne12,
|
|
2424
|
+
/*.nb10 =*/ nb10,
|
|
2425
|
+
/*.nb11 =*/ nb11,
|
|
2426
|
+
/*.nb12 =*/ nb12,
|
|
2427
|
+
/*.nb13 =*/ nb13,
|
|
2428
|
+
/*.ne0 =*/ ne0,
|
|
2429
|
+
/*.ne1 =*/ ne1,
|
|
2430
|
+
/*.r2 =*/ r2,
|
|
2431
|
+
/*.r3 =*/ r3,
|
|
2432
|
+
};
|
|
2433
|
+
|
|
2434
|
+
[encoder setComputePipelineState:pipeline];
|
|
2435
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
2436
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
2437
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
|
2438
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
2439
|
+
|
|
2440
|
+
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
|
2441
|
+
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
|
2442
|
+
} else {
|
|
2443
|
+
int nth0 = 32;
|
|
2444
|
+
int nth1 = 1;
|
|
2445
|
+
int nrows = 1;
|
|
2446
|
+
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
|
2447
|
+
|
|
2448
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2449
|
+
|
|
2450
|
+
// use custom matrix x vector kernel
|
|
2451
|
+
switch (src0t) {
|
|
2452
|
+
case WSP_GGML_TYPE_F32:
|
|
2453
|
+
{
|
|
2454
|
+
WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
|
|
2455
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
|
|
2456
|
+
nrows = 4;
|
|
2457
|
+
} break;
|
|
2458
|
+
case WSP_GGML_TYPE_F16:
|
|
2459
|
+
{
|
|
2460
|
+
nth0 = 32;
|
|
2461
|
+
nth1 = 1;
|
|
2462
|
+
if (src1t == WSP_GGML_TYPE_F32) {
|
|
2463
|
+
if (ne11 * ne12 < 4) {
|
|
2464
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
|
2465
|
+
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
|
2466
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
|
2467
|
+
nrows = ne11;
|
|
2468
|
+
} else {
|
|
2469
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
|
|
2040
2470
|
nrows = 4;
|
|
2041
|
-
} break;
|
|
2042
|
-
case WSP_GGML_TYPE_F16:
|
|
2043
|
-
{
|
|
2044
|
-
nth0 = 32;
|
|
2045
|
-
nth1 = 1;
|
|
2046
|
-
if (src1t == WSP_GGML_TYPE_F32) {
|
|
2047
|
-
if (ne11 * ne12 < 4) {
|
|
2048
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
|
2049
|
-
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
|
2050
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
|
2051
|
-
nrows = ne11;
|
|
2052
|
-
} else {
|
|
2053
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
|
|
2054
|
-
nrows = 4;
|
|
2055
|
-
}
|
|
2056
|
-
} else {
|
|
2057
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
|
|
2058
|
-
nrows = 4;
|
|
2059
|
-
}
|
|
2060
|
-
} break;
|
|
2061
|
-
case WSP_GGML_TYPE_BF16:
|
|
2062
|
-
{
|
|
2063
|
-
nth0 = 32;
|
|
2064
|
-
nth1 = 1;
|
|
2065
|
-
if (src1t == WSP_GGML_TYPE_F32) {
|
|
2066
|
-
if (ne11 * ne12 < 4) {
|
|
2067
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
|
|
2068
|
-
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
|
2069
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
|
|
2070
|
-
nrows = ne11;
|
|
2071
|
-
} else {
|
|
2072
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
|
|
2073
|
-
nrows = 4;
|
|
2074
|
-
}
|
|
2075
|
-
} else {
|
|
2076
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
|
|
2077
|
-
nrows = 4;
|
|
2078
|
-
}
|
|
2079
|
-
} break;
|
|
2080
|
-
case WSP_GGML_TYPE_Q4_0:
|
|
2081
|
-
{
|
|
2082
|
-
nth0 = 8;
|
|
2083
|
-
nth1 = 8;
|
|
2084
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
|
|
2085
|
-
} break;
|
|
2086
|
-
case WSP_GGML_TYPE_Q4_1:
|
|
2087
|
-
{
|
|
2088
|
-
nth0 = 8;
|
|
2089
|
-
nth1 = 8;
|
|
2090
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
|
|
2091
|
-
} break;
|
|
2092
|
-
case WSP_GGML_TYPE_Q5_0:
|
|
2093
|
-
{
|
|
2094
|
-
nth0 = 8;
|
|
2095
|
-
nth1 = 8;
|
|
2096
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
|
|
2097
|
-
} break;
|
|
2098
|
-
case WSP_GGML_TYPE_Q5_1:
|
|
2099
|
-
{
|
|
2100
|
-
nth0 = 8;
|
|
2101
|
-
nth1 = 8;
|
|
2102
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
|
|
2103
|
-
} break;
|
|
2104
|
-
case WSP_GGML_TYPE_Q8_0:
|
|
2105
|
-
{
|
|
2106
|
-
nth0 = 8;
|
|
2107
|
-
nth1 = 8;
|
|
2108
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
|
|
2109
|
-
} break;
|
|
2110
|
-
case WSP_GGML_TYPE_Q2_K:
|
|
2111
|
-
{
|
|
2112
|
-
nth0 = 2;
|
|
2113
|
-
nth1 = 32;
|
|
2114
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
|
|
2115
|
-
} break;
|
|
2116
|
-
case WSP_GGML_TYPE_Q3_K:
|
|
2117
|
-
{
|
|
2118
|
-
nth0 = 2;
|
|
2119
|
-
nth1 = 32;
|
|
2120
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
|
|
2121
|
-
} break;
|
|
2122
|
-
case WSP_GGML_TYPE_Q4_K:
|
|
2123
|
-
{
|
|
2124
|
-
nth0 = 4; //1;
|
|
2125
|
-
nth1 = 8; //32;
|
|
2126
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
|
|
2127
|
-
} break;
|
|
2128
|
-
case WSP_GGML_TYPE_Q5_K:
|
|
2129
|
-
{
|
|
2130
|
-
nth0 = 2;
|
|
2131
|
-
nth1 = 32;
|
|
2132
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
|
|
2133
|
-
} break;
|
|
2134
|
-
case WSP_GGML_TYPE_Q6_K:
|
|
2135
|
-
{
|
|
2136
|
-
nth0 = 2;
|
|
2137
|
-
nth1 = 32;
|
|
2138
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
|
|
2139
|
-
} break;
|
|
2140
|
-
case WSP_GGML_TYPE_IQ2_XXS:
|
|
2141
|
-
{
|
|
2142
|
-
nth0 = 4;
|
|
2143
|
-
nth1 = 16;
|
|
2144
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
|
|
2145
|
-
} break;
|
|
2146
|
-
case WSP_GGML_TYPE_IQ2_XS:
|
|
2147
|
-
{
|
|
2148
|
-
nth0 = 4;
|
|
2149
|
-
nth1 = 16;
|
|
2150
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
|
|
2151
|
-
} break;
|
|
2152
|
-
case WSP_GGML_TYPE_IQ3_XXS:
|
|
2153
|
-
{
|
|
2154
|
-
nth0 = 4;
|
|
2155
|
-
nth1 = 16;
|
|
2156
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
|
2157
|
-
} break;
|
|
2158
|
-
case WSP_GGML_TYPE_IQ3_S:
|
|
2159
|
-
{
|
|
2160
|
-
nth0 = 4;
|
|
2161
|
-
nth1 = 16;
|
|
2162
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
|
|
2163
|
-
} break;
|
|
2164
|
-
case WSP_GGML_TYPE_IQ2_S:
|
|
2165
|
-
{
|
|
2166
|
-
nth0 = 4;
|
|
2167
|
-
nth1 = 16;
|
|
2168
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
|
|
2169
|
-
} break;
|
|
2170
|
-
case WSP_GGML_TYPE_IQ1_S:
|
|
2171
|
-
{
|
|
2172
|
-
nth0 = 4;
|
|
2173
|
-
nth1 = 16;
|
|
2174
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
|
|
2175
|
-
} break;
|
|
2176
|
-
case WSP_GGML_TYPE_IQ1_M:
|
|
2177
|
-
{
|
|
2178
|
-
nth0 = 4;
|
|
2179
|
-
nth1 = 16;
|
|
2180
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
|
|
2181
|
-
} break;
|
|
2182
|
-
case WSP_GGML_TYPE_IQ4_NL:
|
|
2183
|
-
{
|
|
2184
|
-
nth0 = 4;
|
|
2185
|
-
nth1 = 16;
|
|
2186
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
|
|
2187
|
-
} break;
|
|
2188
|
-
case WSP_GGML_TYPE_IQ4_XS:
|
|
2189
|
-
{
|
|
2190
|
-
nth0 = 4;
|
|
2191
|
-
nth1 = 16;
|
|
2192
|
-
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
|
|
2193
|
-
} break;
|
|
2194
|
-
default:
|
|
2195
|
-
{
|
|
2196
|
-
WSP_GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
|
2197
|
-
WSP_GGML_ABORT("not implemented");
|
|
2198
2471
|
}
|
|
2199
|
-
|
|
2472
|
+
} else {
|
|
2473
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
|
|
2474
|
+
nrows = 4;
|
|
2475
|
+
}
|
|
2476
|
+
} break;
|
|
2477
|
+
case WSP_GGML_TYPE_BF16:
|
|
2478
|
+
{
|
|
2479
|
+
nth0 = 32;
|
|
2480
|
+
nth1 = 1;
|
|
2481
|
+
if (src1t == WSP_GGML_TYPE_F32) {
|
|
2482
|
+
if (ne11 * ne12 < 4) {
|
|
2483
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
|
|
2484
|
+
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
|
2485
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
|
|
2486
|
+
nrows = ne11;
|
|
2487
|
+
} else {
|
|
2488
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
|
|
2489
|
+
nrows = 4;
|
|
2490
|
+
}
|
|
2491
|
+
} else {
|
|
2492
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
|
|
2493
|
+
nrows = 4;
|
|
2494
|
+
}
|
|
2495
|
+
} break;
|
|
2496
|
+
case WSP_GGML_TYPE_Q4_0:
|
|
2497
|
+
{
|
|
2498
|
+
nth0 = 8;
|
|
2499
|
+
nth1 = 8;
|
|
2500
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
|
|
2501
|
+
} break;
|
|
2502
|
+
case WSP_GGML_TYPE_Q4_1:
|
|
2503
|
+
{
|
|
2504
|
+
nth0 = 8;
|
|
2505
|
+
nth1 = 8;
|
|
2506
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
|
|
2507
|
+
} break;
|
|
2508
|
+
case WSP_GGML_TYPE_Q5_0:
|
|
2509
|
+
{
|
|
2510
|
+
nth0 = 8;
|
|
2511
|
+
nth1 = 8;
|
|
2512
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
|
|
2513
|
+
} break;
|
|
2514
|
+
case WSP_GGML_TYPE_Q5_1:
|
|
2515
|
+
{
|
|
2516
|
+
nth0 = 8;
|
|
2517
|
+
nth1 = 8;
|
|
2518
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
|
|
2519
|
+
} break;
|
|
2520
|
+
case WSP_GGML_TYPE_Q8_0:
|
|
2521
|
+
{
|
|
2522
|
+
nth0 = 8;
|
|
2523
|
+
nth1 = 8;
|
|
2524
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
|
|
2525
|
+
} break;
|
|
2526
|
+
case WSP_GGML_TYPE_Q2_K:
|
|
2527
|
+
{
|
|
2528
|
+
nth0 = 2;
|
|
2529
|
+
nth1 = 32;
|
|
2530
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
|
|
2531
|
+
} break;
|
|
2532
|
+
case WSP_GGML_TYPE_Q3_K:
|
|
2533
|
+
{
|
|
2534
|
+
nth0 = 2;
|
|
2535
|
+
nth1 = 32;
|
|
2536
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
|
|
2537
|
+
} break;
|
|
2538
|
+
case WSP_GGML_TYPE_Q4_K:
|
|
2539
|
+
{
|
|
2540
|
+
nth0 = 4; //1;
|
|
2541
|
+
nth1 = 8; //32;
|
|
2542
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
|
|
2543
|
+
} break;
|
|
2544
|
+
case WSP_GGML_TYPE_Q5_K:
|
|
2545
|
+
{
|
|
2546
|
+
nth0 = 2;
|
|
2547
|
+
nth1 = 32;
|
|
2548
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
|
|
2549
|
+
} break;
|
|
2550
|
+
case WSP_GGML_TYPE_Q6_K:
|
|
2551
|
+
{
|
|
2552
|
+
nth0 = 2;
|
|
2553
|
+
nth1 = 32;
|
|
2554
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
|
|
2555
|
+
} break;
|
|
2556
|
+
case WSP_GGML_TYPE_IQ2_XXS:
|
|
2557
|
+
{
|
|
2558
|
+
nth0 = 4;
|
|
2559
|
+
nth1 = 16;
|
|
2560
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
|
|
2561
|
+
} break;
|
|
2562
|
+
case WSP_GGML_TYPE_IQ2_XS:
|
|
2563
|
+
{
|
|
2564
|
+
nth0 = 4;
|
|
2565
|
+
nth1 = 16;
|
|
2566
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
|
|
2567
|
+
} break;
|
|
2568
|
+
case WSP_GGML_TYPE_IQ3_XXS:
|
|
2569
|
+
{
|
|
2570
|
+
nth0 = 4;
|
|
2571
|
+
nth1 = 16;
|
|
2572
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
|
2573
|
+
} break;
|
|
2574
|
+
case WSP_GGML_TYPE_IQ3_S:
|
|
2575
|
+
{
|
|
2576
|
+
nth0 = 4;
|
|
2577
|
+
nth1 = 16;
|
|
2578
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
|
|
2579
|
+
} break;
|
|
2580
|
+
case WSP_GGML_TYPE_IQ2_S:
|
|
2581
|
+
{
|
|
2582
|
+
nth0 = 4;
|
|
2583
|
+
nth1 = 16;
|
|
2584
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
|
|
2585
|
+
} break;
|
|
2586
|
+
case WSP_GGML_TYPE_IQ1_S:
|
|
2587
|
+
{
|
|
2588
|
+
nth0 = 4;
|
|
2589
|
+
nth1 = 16;
|
|
2590
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
|
|
2591
|
+
} break;
|
|
2592
|
+
case WSP_GGML_TYPE_IQ1_M:
|
|
2593
|
+
{
|
|
2594
|
+
nth0 = 4;
|
|
2595
|
+
nth1 = 16;
|
|
2596
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
|
|
2597
|
+
} break;
|
|
2598
|
+
case WSP_GGML_TYPE_IQ4_NL:
|
|
2599
|
+
{
|
|
2600
|
+
nth0 = 4;
|
|
2601
|
+
nth1 = 16;
|
|
2602
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
|
|
2603
|
+
} break;
|
|
2604
|
+
case WSP_GGML_TYPE_IQ4_XS:
|
|
2605
|
+
{
|
|
2606
|
+
nth0 = 4;
|
|
2607
|
+
nth1 = 16;
|
|
2608
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
|
|
2609
|
+
} break;
|
|
2610
|
+
default:
|
|
2611
|
+
{
|
|
2612
|
+
WSP_GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
|
2613
|
+
WSP_GGML_ABORT("not implemented");
|
|
2614
|
+
}
|
|
2615
|
+
};
|
|
2200
2616
|
|
|
2201
|
-
|
|
2202
|
-
|
|
2203
|
-
|
|
2204
|
-
|
|
2205
|
-
|
|
2206
|
-
|
|
2207
|
-
|
|
2208
|
-
|
|
2209
|
-
|
|
2210
|
-
|
|
2211
|
-
|
|
2212
|
-
|
|
2213
|
-
|
|
2214
|
-
|
|
2215
|
-
|
|
2216
|
-
|
|
2217
|
-
|
|
2218
|
-
|
|
2219
|
-
|
|
2220
|
-
|
|
2617
|
+
wsp_ggml_metal_kargs_mul_mv args = {
|
|
2618
|
+
/*.ne00 =*/ ne00,
|
|
2619
|
+
/*.ne01 =*/ ne01,
|
|
2620
|
+
/*.ne02 =*/ ne02,
|
|
2621
|
+
/*.nb00 =*/ nb00,
|
|
2622
|
+
/*.nb01 =*/ nb01,
|
|
2623
|
+
/*.nb02 =*/ nb02,
|
|
2624
|
+
/*.nb03 =*/ nb03,
|
|
2625
|
+
/*.ne10 =*/ ne10,
|
|
2626
|
+
/*.ne11 =*/ ne11,
|
|
2627
|
+
/*.ne12 =*/ ne12,
|
|
2628
|
+
/*.nb10 =*/ nb10,
|
|
2629
|
+
/*.nb11 =*/ nb11,
|
|
2630
|
+
/*.nb12 =*/ nb12,
|
|
2631
|
+
/*.nb13 =*/ nb13,
|
|
2632
|
+
/*.ne0 =*/ ne0,
|
|
2633
|
+
/*.ne1 =*/ ne1,
|
|
2634
|
+
/*.r2 =*/ r2,
|
|
2635
|
+
/*.r3 =*/ r3,
|
|
2636
|
+
};
|
|
2221
2637
|
|
|
2222
|
-
|
|
2223
|
-
|
|
2224
|
-
|
|
2225
|
-
|
|
2226
|
-
|
|
2638
|
+
[encoder setComputePipelineState:pipeline];
|
|
2639
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
2640
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
2641
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
|
2642
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
2227
2643
|
|
|
2228
|
-
|
|
2229
|
-
|
|
2230
|
-
|
|
2231
|
-
|
|
2232
|
-
|
|
2233
|
-
|
|
2234
|
-
|
|
2235
|
-
|
|
2236
|
-
|
|
2237
|
-
|
|
2238
|
-
|
|
2239
|
-
|
|
2240
|
-
|
|
2241
|
-
|
|
2242
|
-
|
|
2243
|
-
|
|
2244
|
-
|
|
2245
|
-
|
|
2246
|
-
|
|
2247
|
-
|
|
2248
|
-
|
|
2249
|
-
|
|
2250
|
-
|
|
2251
|
-
|
|
2252
|
-
|
|
2253
|
-
|
|
2254
|
-
|
|
2255
|
-
|
|
2256
|
-
|
|
2257
|
-
|
|
2258
|
-
|
|
2259
|
-
|
|
2260
|
-
|
|
2261
|
-
|
|
2262
|
-
|
|
2263
|
-
|
|
2644
|
+
if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 || src0t == WSP_GGML_TYPE_Q5_0 ||
|
|
2645
|
+
src0t == WSP_GGML_TYPE_Q5_1 || src0t == WSP_GGML_TYPE_Q8_0 || src0t == WSP_GGML_TYPE_Q2_K ||
|
|
2646
|
+
src0t == WSP_GGML_TYPE_IQ1_S || src0t == WSP_GGML_TYPE_IQ1_M || src0t == WSP_GGML_TYPE_IQ2_S) {
|
|
2647
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
2648
|
+
}
|
|
2649
|
+
else if (src0t == WSP_GGML_TYPE_IQ2_XXS || src0t == WSP_GGML_TYPE_IQ2_XS) {
|
|
2650
|
+
const int mem_size = src0t == WSP_GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
|
2651
|
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
2652
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
2653
|
+
}
|
|
2654
|
+
else if (src0t == WSP_GGML_TYPE_IQ3_XXS || src0t == WSP_GGML_TYPE_IQ3_S) {
|
|
2655
|
+
const int mem_size = src0t == WSP_GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
|
2656
|
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
2657
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
2658
|
+
}
|
|
2659
|
+
else if (src0t == WSP_GGML_TYPE_IQ4_NL || src0t == WSP_GGML_TYPE_IQ4_XS) {
|
|
2660
|
+
const int mem_size = 32*sizeof(float);
|
|
2661
|
+
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
2662
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
2663
|
+
}
|
|
2664
|
+
else if (src0t == WSP_GGML_TYPE_Q4_K) {
|
|
2665
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
2666
|
+
}
|
|
2667
|
+
else if (src0t == WSP_GGML_TYPE_Q3_K) {
|
|
2668
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
2669
|
+
}
|
|
2670
|
+
else if (src0t == WSP_GGML_TYPE_Q5_K) {
|
|
2671
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
2672
|
+
}
|
|
2673
|
+
else if (src0t == WSP_GGML_TYPE_Q6_K) {
|
|
2674
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
2675
|
+
} else {
|
|
2676
|
+
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
|
2677
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
2678
|
+
}
|
|
2679
|
+
}
|
|
2264
2680
|
} break;
|
|
2265
2681
|
case WSP_GGML_OP_MUL_MAT_ID:
|
|
2266
2682
|
{
|
|
@@ -2672,7 +3088,6 @@ static void wsp_ggml_metal_encode_node(
|
|
|
2672
3088
|
} break;
|
|
2673
3089
|
case WSP_GGML_OP_GROUP_NORM:
|
|
2674
3090
|
{
|
|
2675
|
-
WSP_GGML_ASSERT(ne00 % 4 == 0);
|
|
2676
3091
|
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
|
|
2677
3092
|
|
|
2678
3093
|
float eps;
|
|
@@ -2742,7 +3157,9 @@ static void wsp_ggml_metal_encode_node(
|
|
|
2742
3157
|
} break;
|
|
2743
3158
|
case WSP_GGML_OP_ROPE:
|
|
2744
3159
|
{
|
|
2745
|
-
|
|
3160
|
+
// make sure we have one or more position id(ne10) per token(ne02)
|
|
3161
|
+
WSP_GGML_ASSERT(ne10 % ne02 == 0);
|
|
3162
|
+
WSP_GGML_ASSERT(ne10 >= ne02);
|
|
2746
3163
|
|
|
2747
3164
|
const int nth = MIN(1024, ne00);
|
|
2748
3165
|
|
|
@@ -2908,6 +3325,49 @@ static void wsp_ggml_metal_encode_node(
|
|
|
2908
3325
|
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
|
2909
3326
|
}
|
|
2910
3327
|
} break;
|
|
3328
|
+
case WSP_GGML_OP_CONV_TRANSPOSE_1D:
|
|
3329
|
+
{
|
|
3330
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
|
|
3331
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
|
|
3332
|
+
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16 || src0->type == WSP_GGML_TYPE_F32);
|
|
3333
|
+
WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
|
|
3334
|
+
WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
|
|
3335
|
+
|
|
3336
|
+
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
|
3337
|
+
|
|
3338
|
+
const int32_t IC = src1->ne[1];
|
|
3339
|
+
const int32_t IL = src1->ne[0];
|
|
3340
|
+
|
|
3341
|
+
const int32_t K = src0->ne[0];
|
|
3342
|
+
|
|
3343
|
+
const int32_t OL = dst->ne[0];
|
|
3344
|
+
const int32_t OC = dst->ne[1];
|
|
3345
|
+
|
|
3346
|
+
id<MTLComputePipelineState> pipeline;
|
|
3347
|
+
|
|
3348
|
+
switch (src0->type) {
|
|
3349
|
+
case WSP_GGML_TYPE_F32: {
|
|
3350
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline;
|
|
3351
|
+
} break;
|
|
3352
|
+
case WSP_GGML_TYPE_F16: {
|
|
3353
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline;
|
|
3354
|
+
} break;
|
|
3355
|
+
default: WSP_GGML_ABORT("fatal error");
|
|
3356
|
+
};
|
|
3357
|
+
|
|
3358
|
+
[encoder setComputePipelineState:pipeline];
|
|
3359
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
3360
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
3361
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
3362
|
+
[encoder setBytes:&IC length:sizeof( int32_t) atIndex:3];
|
|
3363
|
+
[encoder setBytes:&IL length:sizeof( int32_t) atIndex:4];
|
|
3364
|
+
[encoder setBytes:&K length:sizeof( int32_t) atIndex:5];
|
|
3365
|
+
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6];
|
|
3366
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7];
|
|
3367
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8];
|
|
3368
|
+
|
|
3369
|
+
[encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
3370
|
+
} break;
|
|
2911
3371
|
case WSP_GGML_OP_UPSCALE:
|
|
2912
3372
|
{
|
|
2913
3373
|
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
|
|
@@ -2977,6 +3437,38 @@ static void wsp_ggml_metal_encode_node(
|
|
|
2977
3437
|
|
|
2978
3438
|
const int nth = MIN(1024, ne0);
|
|
2979
3439
|
|
|
3440
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
3441
|
+
} break;
|
|
3442
|
+
case WSP_GGML_OP_PAD_REFLECT_1D:
|
|
3443
|
+
{
|
|
3444
|
+
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
|
|
3445
|
+
|
|
3446
|
+
const int32_t p0 = ((const int32_t *)(dst->op_params))[0];
|
|
3447
|
+
const int32_t p1 = ((const int32_t *)(dst->op_params))[1];
|
|
3448
|
+
|
|
3449
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
|
|
3450
|
+
|
|
3451
|
+
[encoder setComputePipelineState:pipeline];
|
|
3452
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
3453
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
3454
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
3455
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
3456
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
3457
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
|
3458
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:6];
|
|
3459
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
|
3460
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
|
3461
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
|
3462
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
|
3463
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:11];
|
|
3464
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:12];
|
|
3465
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:13];
|
|
3466
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:14];
|
|
3467
|
+
[encoder setBytes:&p0 length:sizeof(p0) atIndex:15];
|
|
3468
|
+
[encoder setBytes:&p1 length:sizeof(p1) atIndex:16];
|
|
3469
|
+
|
|
3470
|
+
const int nth = MIN(1024, ne0);
|
|
3471
|
+
|
|
2980
3472
|
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2981
3473
|
} break;
|
|
2982
3474
|
case WSP_GGML_OP_ARANGE:
|
|
@@ -3439,10 +3931,6 @@ static void wsp_ggml_metal_encode_node(
|
|
|
3439
3931
|
case WSP_GGML_OP_CPY:
|
|
3440
3932
|
case WSP_GGML_OP_CONT:
|
|
3441
3933
|
{
|
|
3442
|
-
WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(src0->type) == 0);
|
|
3443
|
-
|
|
3444
|
-
int nth = MIN(1024, ne00/wsp_ggml_blck_size(src0->type));
|
|
3445
|
-
|
|
3446
3934
|
id<MTLComputePipelineState> pipeline = nil;
|
|
3447
3935
|
|
|
3448
3936
|
switch (src0t) {
|
|
@@ -3476,7 +3964,47 @@ static void wsp_ggml_metal_encode_node(
|
|
|
3476
3964
|
switch (dstt) {
|
|
3477
3965
|
case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
|
|
3478
3966
|
case WSP_GGML_TYPE_BF16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
|
|
3479
|
-
default:
|
|
3967
|
+
default: WSP_GGML_ABORT("not implemented");
|
|
3968
|
+
};
|
|
3969
|
+
} break;
|
|
3970
|
+
case WSP_GGML_TYPE_Q4_0:
|
|
3971
|
+
{
|
|
3972
|
+
switch (dstt) {
|
|
3973
|
+
case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
|
|
3974
|
+
case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break;
|
|
3975
|
+
default: WSP_GGML_ABORT("not implemented");
|
|
3976
|
+
};
|
|
3977
|
+
} break;
|
|
3978
|
+
case WSP_GGML_TYPE_Q4_1:
|
|
3979
|
+
{
|
|
3980
|
+
switch (dstt) {
|
|
3981
|
+
case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
|
|
3982
|
+
case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break;
|
|
3983
|
+
default: WSP_GGML_ABORT("not implemented");
|
|
3984
|
+
};
|
|
3985
|
+
} break;
|
|
3986
|
+
case WSP_GGML_TYPE_Q5_0:
|
|
3987
|
+
{
|
|
3988
|
+
switch (dstt) {
|
|
3989
|
+
case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
|
|
3990
|
+
case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break;
|
|
3991
|
+
default: WSP_GGML_ABORT("not implemented");
|
|
3992
|
+
};
|
|
3993
|
+
} break;
|
|
3994
|
+
case WSP_GGML_TYPE_Q5_1:
|
|
3995
|
+
{
|
|
3996
|
+
switch (dstt) {
|
|
3997
|
+
case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
|
|
3998
|
+
case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break;
|
|
3999
|
+
default: WSP_GGML_ABORT("not implemented");
|
|
4000
|
+
};
|
|
4001
|
+
} break;
|
|
4002
|
+
case WSP_GGML_TYPE_Q8_0:
|
|
4003
|
+
{
|
|
4004
|
+
switch (dstt) {
|
|
4005
|
+
case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
|
|
4006
|
+
case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break;
|
|
4007
|
+
default: WSP_GGML_ABORT("not implemented");
|
|
3480
4008
|
};
|
|
3481
4009
|
} break;
|
|
3482
4010
|
default: WSP_GGML_ABORT("not implemented");
|
|
@@ -3506,7 +4034,73 @@ static void wsp_ggml_metal_encode_node(
|
|
|
3506
4034
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
3507
4035
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
3508
4036
|
|
|
4037
|
+
WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(src0->type) == 0);
|
|
4038
|
+
int nth = MIN(1024, ne00/wsp_ggml_blck_size(src0->type));
|
|
4039
|
+
|
|
3509
4040
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
4041
|
+
|
|
4042
|
+
} break;
|
|
4043
|
+
case WSP_GGML_OP_SET:
|
|
4044
|
+
{
|
|
4045
|
+
WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst));
|
|
4046
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst) && wsp_ggml_is_contiguous(src0));
|
|
4047
|
+
|
|
4048
|
+
// src0 and dst as viewed during set
|
|
4049
|
+
const size_t dst_nb0 = wsp_ggml_element_size(src0);
|
|
4050
|
+
|
|
4051
|
+
const size_t dst_nb1 = ((int32_t *) dst->op_params)[0];
|
|
4052
|
+
const size_t dst_nb2 = ((int32_t *) dst->op_params)[1];
|
|
4053
|
+
const size_t dst_nb3 = ((int32_t *) dst->op_params)[2];
|
|
4054
|
+
const size_t offset = ((int32_t *) dst->op_params)[3];
|
|
4055
|
+
const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
|
|
4056
|
+
|
|
4057
|
+
if (!inplace) {
|
|
4058
|
+
memcpy(((char *) dst->data), ((char *) src0->data), wsp_ggml_nbytes(dst));
|
|
4059
|
+
}
|
|
4060
|
+
|
|
4061
|
+
const int im0 = (ne10 == 0 ? 0 : ne10-1);
|
|
4062
|
+
const int im1 = (ne11 == 0 ? 0 : ne11-1);
|
|
4063
|
+
const int im2 = (ne12 == 0 ? 0 : ne12-1);
|
|
4064
|
+
const int im3 = (ne13 == 0 ? 0 : ne13-1);
|
|
4065
|
+
|
|
4066
|
+
WSP_GGML_ASSERT(offset + im0*dst_nb0 + im1*dst_nb1 + im2*dst_nb2 + im3*dst_nb3 <= wsp_ggml_nbytes(dst));
|
|
4067
|
+
|
|
4068
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
4069
|
+
|
|
4070
|
+
switch (src0t) {
|
|
4071
|
+
case WSP_GGML_TYPE_F32:
|
|
4072
|
+
WSP_GGML_ASSERT(nb10 == sizeof(float));
|
|
4073
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break;
|
|
4074
|
+
case WSP_GGML_TYPE_I32:
|
|
4075
|
+
WSP_GGML_ASSERT(nb10 == sizeof(int32_t));
|
|
4076
|
+
pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break;
|
|
4077
|
+
default: WSP_GGML_ABORT("fatal error");
|
|
4078
|
+
}
|
|
4079
|
+
|
|
4080
|
+
wsp_ggml_metal_kargs_set args = {
|
|
4081
|
+
/*.ne10 =*/ ne10,
|
|
4082
|
+
/*.ne11 =*/ ne11,
|
|
4083
|
+
/*.ne12 =*/ ne12,
|
|
4084
|
+
/*.nb10 =*/ nb10,
|
|
4085
|
+
/*.nb11 =*/ nb11,
|
|
4086
|
+
/*.nb12 =*/ nb12,
|
|
4087
|
+
/*.nb13 =*/ nb13,
|
|
4088
|
+
/*.nb1 =*/ dst_nb1,
|
|
4089
|
+
/*.nb2 =*/ dst_nb2,
|
|
4090
|
+
/*.nb3 =*/ dst_nb3,
|
|
4091
|
+
/*.offs =*/ offset,
|
|
4092
|
+
/*.inplace =*/ inplace,
|
|
4093
|
+
};
|
|
4094
|
+
|
|
4095
|
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10);
|
|
4096
|
+
|
|
4097
|
+
[encoder setComputePipelineState:pipeline];
|
|
4098
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
4099
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
4100
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
|
4101
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
|
4102
|
+
|
|
4103
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
3510
4104
|
} break;
|
|
3511
4105
|
case WSP_GGML_OP_POOL_2D:
|
|
3512
4106
|
{
|
|
@@ -3567,6 +4161,31 @@ static void wsp_ggml_metal_encode_node(
|
|
|
3567
4161
|
|
|
3568
4162
|
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
|
3569
4163
|
} break;
|
|
4164
|
+
case WSP_GGML_OP_ARGMAX:
|
|
4165
|
+
{
|
|
4166
|
+
WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
|
|
4167
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
|
|
4168
|
+
WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(src0->type));
|
|
4169
|
+
|
|
4170
|
+
const int64_t nrows = wsp_ggml_nrows(src0);
|
|
4171
|
+
|
|
4172
|
+
int nth = 32; // SIMD width
|
|
4173
|
+
while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
|
|
4174
|
+
nth *= 2;
|
|
4175
|
+
}
|
|
4176
|
+
|
|
4177
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline;
|
|
4178
|
+
|
|
4179
|
+
[encoder setComputePipelineState:pipeline];
|
|
4180
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
4181
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
4182
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
4183
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
|
4184
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
4185
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1];
|
|
4186
|
+
|
|
4187
|
+
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
4188
|
+
} break;
|
|
3570
4189
|
default:
|
|
3571
4190
|
{
|
|
3572
4191
|
WSP_GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, wsp_ggml_op_name(dst->op));
|
|
@@ -3718,6 +4337,8 @@ static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t
|
|
|
3718
4337
|
for (int i = 0; i < ctx->n_buffers; i++) {
|
|
3719
4338
|
[ctx->buffers[i].metal release];
|
|
3720
4339
|
}
|
|
4340
|
+
|
|
4341
|
+
wsp_ggml_backend_metal_buffer_rset_free(ctx);
|
|
3721
4342
|
wsp_ggml_backend_metal_device_rel(buffer->buft->device->context);
|
|
3722
4343
|
|
|
3723
4344
|
if (ctx->owned) {
|
|
@@ -3740,19 +4361,19 @@ static void * wsp_ggml_backend_metal_buffer_get_base(wsp_ggml_backend_buffer_t b
|
|
|
3740
4361
|
// Fills `size` bytes of the tensor's data, starting `offset` bytes past tensor->data,
// with the byte `value`. The write is a plain memset on tensor->data; no bounds check
// is performed in this function.
static void wsp_ggml_backend_metal_buffer_memset_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
|
3741
4362
|
memset((char *)tensor->data + offset, value, size);
|
|
3742
4363
|
|
|
3743
|
-
|
|
4364
|
+
// `buffer` is not used by the memset above; this is its only mention in the body.
WSP_GGML_UNUSED(buffer);
|
|
3744
4365
|
}
|
|
3745
4366
|
|
|
3746
4367
|
// Copies `size` bytes from `data` into the tensor's storage, starting `offset` bytes
// past tensor->data. The copy is a plain memcpy; no bounds check is performed here.
static void wsp_ggml_backend_metal_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
|
3747
4368
|
memcpy((char *)tensor->data + offset, data, size);
|
|
3748
4369
|
|
|
3749
|
-
|
|
4370
|
+
// `buffer` is not used by the memcpy above; this is its only mention in the body.
WSP_GGML_UNUSED(buffer);
|
|
3750
4371
|
}
|
|
3751
4372
|
|
|
3752
4373
|
// Copies `size` bytes out of the tensor's storage, starting `offset` bytes past
// tensor->data, into `data`. The copy is a plain memcpy; no bounds check is performed here.
static void wsp_ggml_backend_metal_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
|
3753
4374
|
memcpy(data, (const char *)tensor->data + offset, size);
|
|
3754
4375
|
|
|
3755
|
-
|
|
4376
|
+
// `buffer` is not used by the memcpy above; this is its only mention in the body.
WSP_GGML_UNUSED(buffer);
|
|
3756
4377
|
}
|
|
3757
4378
|
|
|
3758
4379
|
static bool wsp_ggml_backend_metal_buffer_cpy_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
|
|
@@ -3762,7 +4383,7 @@ static bool wsp_ggml_backend_metal_buffer_cpy_tensor(wsp_ggml_backend_buffer_t b
|
|
|
3762
4383
|
}
|
|
3763
4384
|
return false;
|
|
3764
4385
|
|
|
3765
|
-
|
|
4386
|
+
WSP_GGML_UNUSED(buffer);
|
|
3766
4387
|
}
|
|
3767
4388
|
|
|
3768
4389
|
static void wsp_ggml_backend_metal_buffer_clear(wsp_ggml_backend_buffer_t buffer, uint8_t value) {
|
|
@@ -3788,7 +4409,7 @@ static struct wsp_ggml_backend_buffer_i wsp_ggml_backend_metal_buffer_i = {
|
|
|
3788
4409
|
// Returns the constant string "Metal" regardless of `buft`.
static const char * wsp_ggml_backend_metal_buffer_type_get_name(wsp_ggml_backend_buffer_type_t buft) {
|
|
3789
4410
|
return "Metal";
|
|
3790
4411
|
|
|
3791
|
-
|
|
4412
|
+
// Sits after the return, so it is never executed; it only mentions the otherwise-unused parameter.
WSP_GGML_UNUSED(buft);
|
|
3792
4413
|
}
|
|
3793
4414
|
|
|
3794
4415
|
static void wsp_ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
|
|
@@ -3812,8 +4433,8 @@ static void wsp_ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size
|
|
|
3812
4433
|
}
|
|
3813
4434
|
#endif
|
|
3814
4435
|
#endif
|
|
3815
|
-
|
|
3816
|
-
|
|
4436
|
+
WSP_GGML_UNUSED(device);
|
|
4437
|
+
WSP_GGML_UNUSED(size_aligned);
|
|
3817
4438
|
}
|
|
3818
4439
|
|
|
3819
4440
|
static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
|
|
@@ -3826,7 +4447,8 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer
|
|
|
3826
4447
|
size_aligned += (size_page - (size_aligned % size_page));
|
|
3827
4448
|
}
|
|
3828
4449
|
|
|
3829
|
-
|
|
4450
|
+
struct wsp_ggml_backend_metal_device_context * ctx_dev = (struct wsp_ggml_backend_metal_device_context *)buft->device->context;
|
|
4451
|
+
id<MTLDevice> device = wsp_ggml_backend_metal_device_acq(ctx_dev);
|
|
3830
4452
|
|
|
3831
4453
|
ctx->all_data = wsp_ggml_metal_host_malloc(size_aligned);
|
|
3832
4454
|
ctx->all_size = size_aligned;
|
|
@@ -3849,7 +4471,14 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer
|
|
|
3849
4471
|
if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
|
|
3850
4472
|
WSP_GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
|
|
3851
4473
|
free(ctx);
|
|
3852
|
-
wsp_ggml_backend_metal_device_rel(
|
|
4474
|
+
wsp_ggml_backend_metal_device_rel(ctx_dev);
|
|
4475
|
+
return NULL;
|
|
4476
|
+
}
|
|
4477
|
+
|
|
4478
|
+
if (!wsp_ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
|
4479
|
+
WSP_GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
|
4480
|
+
free(ctx);
|
|
4481
|
+
wsp_ggml_backend_metal_device_rel(ctx_dev);
|
|
3853
4482
|
return NULL;
|
|
3854
4483
|
}
|
|
3855
4484
|
|
|
@@ -3860,7 +4489,7 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer
|
|
|
3860
4489
|
|
|
3861
4490
|
// Returns the fixed value 32 regardless of `buft`.
// NOTE(review): presumably the required allocation alignment in bytes - confirm against the buffer-type interface.
static size_t wsp_ggml_backend_metal_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
|
|
3862
4491
|
return 32;
|
|
3863
|
-
|
|
4492
|
+
// Sits after the return, so it is never executed; it only mentions the otherwise-unused parameter.
WSP_GGML_UNUSED(buft);
|
|
3864
4493
|
}
|
|
3865
4494
|
|
|
3866
4495
|
static size_t wsp_ggml_backend_metal_buffer_type_get_max_size(wsp_ggml_backend_buffer_type_t buft) {
|
|
@@ -3870,13 +4499,13 @@ static size_t wsp_ggml_backend_metal_buffer_type_get_max_size(wsp_ggml_backend_b
|
|
|
3870
4499
|
|
|
3871
4500
|
return max_size;
|
|
3872
4501
|
|
|
3873
|
-
|
|
4502
|
+
WSP_GGML_UNUSED(buft);
|
|
3874
4503
|
}
|
|
3875
4504
|
|
|
3876
4505
|
// Always returns true, regardless of `buft`.
static bool wsp_ggml_backend_metal_buffer_type_is_host(wsp_ggml_backend_buffer_type_t buft) {
|
|
3877
4506
|
return true;
|
|
3878
4507
|
|
|
3879
|
-
|
|
4508
|
+
// Sits after the return, so it is never executed; it only mentions the otherwise-unused parameter.
WSP_GGML_UNUSED(buft);
|
|
3880
4509
|
}
|
|
3881
4510
|
|
|
3882
4511
|
wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void) {
|
|
@@ -3899,7 +4528,7 @@ wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void) {
|
|
|
3899
4528
|
// Returns the constant string "Metal_Mapped" regardless of `buft`.
static const char * wsp_ggml_backend_metal_buffer_from_ptr_type_get_name(wsp_ggml_backend_buffer_type_t buft) {
|
|
3900
4529
|
return "Metal_Mapped";
|
|
3901
4530
|
|
|
3902
|
-
|
|
4531
|
+
// Sits after the return, so it is never executed; it only mentions the otherwise-unused parameter.
WSP_GGML_UNUSED(buft);
|
|
3903
4532
|
}
|
|
3904
4533
|
|
|
3905
4534
|
static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_from_ptr_type(void) {
|
|
@@ -3942,7 +4571,8 @@ wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, si
|
|
|
3942
4571
|
size_aligned += (size_page - (size_aligned % size_page));
|
|
3943
4572
|
}
|
|
3944
4573
|
|
|
3945
|
-
|
|
4574
|
+
struct wsp_ggml_backend_metal_device_context * ctx_dev = &g_wsp_ggml_ctx_dev_main;
|
|
4575
|
+
id<MTLDevice> device = wsp_ggml_backend_metal_device_acq(ctx_dev);
|
|
3946
4576
|
|
|
3947
4577
|
// the buffer fits into the max buffer size allowed by the device
|
|
3948
4578
|
if (size_aligned <= device.maxBufferLength) {
|
|
@@ -3995,6 +4625,13 @@ wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, si
|
|
|
3995
4625
|
}
|
|
3996
4626
|
}
|
|
3997
4627
|
|
|
4628
|
+
if (!wsp_ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
|
4629
|
+
WSP_GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
|
4630
|
+
free(ctx);
|
|
4631
|
+
wsp_ggml_backend_metal_device_rel(ctx_dev);
|
|
4632
|
+
return NULL;
|
|
4633
|
+
}
|
|
4634
|
+
|
|
3998
4635
|
return wsp_ggml_backend_buffer_init(wsp_ggml_backend_metal_buffer_from_ptr_type(), wsp_ggml_backend_metal_buffer_i, ctx, size);
|
|
3999
4636
|
}
|
|
4000
4637
|
|
|
@@ -4003,7 +4640,7 @@ wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, si
|
|
|
4003
4640
|
// Returns the constant string "Metal" regardless of `backend`.
static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) {
|
|
4004
4641
|
return "Metal";
|
|
4005
4642
|
|
|
4006
|
-
|
|
4643
|
+
// Sits after the return, so it is never executed; it only mentions the otherwise-unused parameter.
WSP_GGML_UNUSED(backend);
|
|
4007
4644
|
}
|
|
4008
4645
|
|
|
4009
4646
|
static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) {
|
|
@@ -4308,6 +4945,13 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_device_buffer_from_ptr(w
|
|
|
4308
4945
|
}
|
|
4309
4946
|
}
|
|
4310
4947
|
|
|
4948
|
+
if (!wsp_ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
|
4949
|
+
WSP_GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
|
4950
|
+
free(ctx);
|
|
4951
|
+
wsp_ggml_backend_metal_device_rel(ctx_dev);
|
|
4952
|
+
return NULL;
|
|
4953
|
+
}
|
|
4954
|
+
|
|
4311
4955
|
return wsp_ggml_backend_buffer_init(wsp_ggml_backend_metal_buffer_from_ptr_type(), wsp_ggml_backend_metal_buffer_i, ctx, size);
|
|
4312
4956
|
}
|
|
4313
4957
|
|
|
@@ -4321,7 +4965,7 @@ static bool wsp_ggml_backend_metal_device_supports_buft(wsp_ggml_backend_dev_t d
|
|
|
4321
4965
|
return buft->iface.get_name == wsp_ggml_backend_metal_buffer_type_get_name ||
|
|
4322
4966
|
buft->iface.get_name == wsp_ggml_backend_metal_buffer_from_ptr_type_get_name;
|
|
4323
4967
|
|
|
4324
|
-
|
|
4968
|
+
WSP_GGML_UNUSED(dev);
|
|
4325
4969
|
}
|
|
4326
4970
|
|
|
4327
4971
|
static bool wsp_ggml_backend_metal_device_offload_op(wsp_ggml_backend_dev_t dev, const struct wsp_ggml_tensor * op) {
|
|
@@ -4372,19 +5016,45 @@ static wsp_ggml_backend_dev_t wsp_ggml_backend_metal_reg_device_get(wsp_ggml_bac
|
|
|
4372
5016
|
WSP_GGML_UNUSED(index);
|
|
4373
5017
|
}
|
|
4374
5018
|
|
|
5019
|
+
// Compile-time feature table for the Metal backend, returned by
// wsp_ggml_backend_metal_get_features(). Each entry is a pair of strings:
//   - { "EMBED_LIBRARY", "1" } is present only when WSP_GGML_METAL_EMBED_LIBRARY is defined
//   - { "BF16", "1" }          is present only when WSP_GGML_METAL_USE_BF16 is defined
// The final entry is always { nil, nil }.
// NOTE(review): presumably consumers treat the { nil, nil } entry as the end-of-list sentinel - confirm.
static struct wsp_ggml_backend_feature g_wsp_ggml_backend_metal_features[] = {
|
|
5020
|
+
#if defined(WSP_GGML_METAL_EMBED_LIBRARY)
|
|
5021
|
+
{ "EMBED_LIBRARY", "1" },
|
|
5022
|
+
#endif
|
|
5023
|
+
#if defined(WSP_GGML_METAL_USE_BF16)
|
|
5024
|
+
{ "BF16", "1" },
|
|
5025
|
+
#endif
|
|
5026
|
+
{ nil, nil },
|
|
5027
|
+
};
|
|
5028
|
+
|
|
5029
|
+
// Returns a pointer to the file-scope table g_wsp_ggml_backend_metal_features.
// The same static array is returned on every call; `reg` does not affect the result.
static struct wsp_ggml_backend_feature * wsp_ggml_backend_metal_get_features(wsp_ggml_backend_reg_t reg) {
|
|
5030
|
+
return g_wsp_ggml_backend_metal_features;
|
|
5031
|
+
|
|
5032
|
+
// Sits after the return, so it is never executed; it only mentions the otherwise-unused parameter.
WSP_GGML_UNUSED(reg);
|
|
5033
|
+
}
|
|
5034
|
+
|
|
5035
|
+
// Looks up a backend entry point by name.
// Returns the address of wsp_ggml_backend_metal_get_features (cast to void *) when
// `name` is exactly "wsp_ggml_backend_get_features"; returns NULL for any other name.
// `name` is passed straight to strcmp with no NULL check, so it must be a valid C string.
// `reg` does not affect the result.
static void * wsp_ggml_backend_metal_get_proc_address(wsp_ggml_backend_reg_t reg, const char * name) {
|
|
5036
|
+
if (strcmp(name, "wsp_ggml_backend_get_features") == 0) {
|
|
5037
|
+
return (void *)wsp_ggml_backend_metal_get_features;
|
|
5038
|
+
}
|
|
5039
|
+
|
|
5040
|
+
return NULL;
|
|
5041
|
+
|
|
5042
|
+
// Sits after the return, so it is never executed; it only mentions the otherwise-unused parameter.
WSP_GGML_UNUSED(reg);
|
|
5043
|
+
}
|
|
4375
5044
|
// Function table for the Metal backend registry entry.
// The initializer is positional (the field names appear only inside comments), so the
// order of the entries must match the field order of struct wsp_ggml_backend_reg_i.
// In the new version of the file the get_proc_address slot is filled with
// wsp_ggml_backend_metal_get_proc_address.
static struct wsp_ggml_backend_reg_i wsp_ggml_backend_metal_reg_i = {
|
|
4376
5045
|
/* .get_name = */ wsp_ggml_backend_metal_reg_get_name,
|
|
4377
5046
|
/* .device_count = */ wsp_ggml_backend_metal_reg_device_count,
|
|
4378
5047
|
/* .device_get = */ wsp_ggml_backend_metal_reg_device_get,
|
|
4379
|
-
/* .get_proc_address = */
|
|
5048
|
+
/* .get_proc_address = */ wsp_ggml_backend_metal_get_proc_address,
|
|
4380
5049
|
};
|
|
4381
5050
|
|
|
4382
5051
|
wsp_ggml_backend_reg_t wsp_ggml_backend_metal_reg(void) {
|
|
4383
5052
|
// TODO: make this thread-safe somehow?
|
|
4384
5053
|
{
|
|
4385
5054
|
g_wsp_ggml_backend_metal_reg = (struct wsp_ggml_backend_reg) {
|
|
4386
|
-
/* .
|
|
4387
|
-
/* .
|
|
5055
|
+
/* .api_version = */ WSP_GGML_BACKEND_API_VERSION,
|
|
5056
|
+
/* .iface = */ wsp_ggml_backend_metal_reg_i,
|
|
5057
|
+
/* .context = */ NULL,
|
|
4388
5058
|
};
|
|
4389
5059
|
|
|
4390
5060
|
g_wsp_ggml_backend_metal_device = (struct wsp_ggml_backend_device) {
|
|
@@ -4396,3 +5066,5 @@ wsp_ggml_backend_reg_t wsp_ggml_backend_metal_reg(void) {
|
|
|
4396
5066
|
|
|
4397
5067
|
return &g_wsp_ggml_backend_metal_reg;
|
|
4398
5068
|
}
|
|
5069
|
+
|
|
5070
|
+
WSP_GGML_BACKEND_DL_IMPL(wsp_ggml_backend_metal_reg)
|