whisper.rn 0.4.0-rc.10 → 0.4.0-rc.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/cpp/ggml-metal.m CHANGED
@@ -19,7 +19,17 @@
  // max number of MTLCommandBuffer used to submit a graph for processing
  #define WSP_GGML_METAL_MAX_COMMAND_BUFFERS 8

- #define UNUSED(x) (void)(x)
+ #ifndef TARGET_OS_VISION
+ #define TARGET_OS_VISION 0
+ #endif
+
+ // create residency sets only on macOS >= 15.0
+ #if !TARGET_CPU_X86_64 && TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \
+ TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \
+ TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \
+ TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000
+ #define WSP_GGML_METAL_HAS_RESIDENCY_SETS 1
+ #endif

  // globals

@@ -39,6 +49,7 @@ static struct wsp_ggml_backend_metal_device_context {

  bool has_simdgroup_reduction;
  bool has_simdgroup_mm;
+ bool has_residency_sets;
  bool has_bfloat;
  bool use_bfloat;

@@ -48,6 +59,7 @@ static struct wsp_ggml_backend_metal_device_context {
  /*.mtl_device_ref_count =*/ 0,
  /*.has_simdgroup_reduction =*/ false,
  /*.has_simdgroup_mm =*/ false,
+ /*.has_residency_sets =*/ false,
  /*.has_bfloat =*/ false,
  /*.use_bfloat =*/ false,
  /*.name =*/ "",
@@ -59,12 +71,18 @@ static id<MTLDevice> wsp_ggml_backend_metal_device_acq(struct wsp_ggml_backend_m

  if (ctx->mtl_device == nil) {
  ctx->mtl_device = MTLCreateSystemDefaultDevice();
+ }

+ if (ctx->mtl_device) {
  ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
  ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];

  ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];

+ #if defined(WSP_GGML_METAL_HAS_RESIDENCY_SETS)
+ ctx->has_residency_sets = getenv("WSP_GGML_METAL_NO_RESIDENCY") == NULL;
+ #endif
+
  ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
  ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];

@@ -90,8 +108,10 @@ static void wsp_ggml_backend_metal_device_rel(struct wsp_ggml_backend_metal_devi
  ctx->mtl_device_ref_count--;

  if (ctx->mtl_device_ref_count == 0) {
- [ctx->mtl_device release];
- ctx->mtl_device = nil;
+ if (ctx->mtl_device) {
+ [ctx->mtl_device release];
+ ctx->mtl_device = nil;
+ }
  }
  }

@@ -175,6 +195,46 @@ enum wsp_ggml_metal_kernel_type {
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4,
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5,
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
@@ -266,8 +326,11 @@ enum wsp_ggml_metal_kernel_type {
  WSP_GGML_METAL_KERNEL_TYPE_IM2COL_F32,
  WSP_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
  WSP_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
  WSP_GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
  WSP_GGML_METAL_KERNEL_TYPE_PAD_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,
  WSP_GGML_METAL_KERNEL_TYPE_ARANGE_F32,
  WSP_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
  WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
@@ -329,6 +392,8 @@ enum wsp_ggml_metal_kernel_type {
  WSP_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
  WSP_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
  WSP_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
+ WSP_GGML_METAL_KERNEL_TYPE_SET_I32,
+ WSP_GGML_METAL_KERNEL_TYPE_SET_F32,
  WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
  WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
  WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
@@ -342,6 +407,16 @@ enum wsp_ggml_metal_kernel_type {
  WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
  WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
  WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
+ WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,
+ WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,
+ WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,
+ WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,
+ WSP_GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
  WSP_GGML_METAL_KERNEL_TYPE_CONCAT,
  WSP_GGML_METAL_KERNEL_TYPE_SQR,
  WSP_GGML_METAL_KERNEL_TYPE_SQRT,
@@ -350,6 +425,7 @@ enum wsp_ggml_metal_kernel_type {
  WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
  WSP_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
  WSP_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
+ WSP_GGML_METAL_KERNEL_TYPE_ARGMAX,

  WSP_GGML_METAL_KERNEL_TYPE_COUNT
  };
@@ -437,6 +513,11 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
  WSP_GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);

  ctx->queue = [device newCommandQueue];
+ if (ctx->queue == nil) {
+ WSP_GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
+ return NULL;
+ }
+
  ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

  id<MTLLibrary> metal_library;
@@ -464,6 +545,35 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
  #endif

  NSString * path_lib = [bundle pathForResource:@"ggml-whisper" ofType:@"metallib"];
+ if (path_lib == nil) {
+ // Try to find the resource in the directory where the current binary located.
+ NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
+ NSString * bin_dir = [current_binary stringByDeletingLastPathComponent];
+ NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
+ if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
+ WSP_GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]);
+ NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error];
+ if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
+ // Optionally, if this is a symlink, try to resolve it.
+ default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error];
+ if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) {
+ // It is a relative path, adding the binary directory as directory prefix.
+ default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]];
+ }
+ if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
+ // Link to the resource could not be resolved.
+ default_metallib_path = nil;
+ } else {
+ WSP_GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]);
+ }
+ }
+ } else {
+ // The resource couldn't be found in the binary's directory.
+ default_metallib_path = nil;
+ }
+ path_lib = default_metallib_path;
+ }
+
  if (try_metallib && path_lib != nil) {
  // pre-compiled library found
  NSURL * libURL = [NSURL fileURLWithPath:path_lib];
@@ -574,6 +684,7 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back

  WSP_GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false");
  WSP_GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false");
+ WSP_GGML_LOG_INFO("%s: has residency sets = %s\n", __func__, ctx_dev->has_residency_sets ? "true" : "false");
  WSP_GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
  WSP_GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
  WSP_GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
@@ -699,6 +810,46 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
@@ -790,8 +941,11 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
@@ -853,6 +1007,8 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat);
@@ -866,12 +1022,23 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SIN, sin, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_COS, cos, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
  }
@@ -914,8 +1081,70 @@ struct wsp_ggml_backend_metal_buffer_context {
  // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
  int n_buffers;
  struct wsp_ggml_backend_metal_buffer buffers[WSP_GGML_METAL_MAX_BUFFERS];
+
+ // optional MTLResidencySet
+ id rset;
  };

+ // rset init
+ static bool wsp_ggml_backend_metal_buffer_rset_init(
+ struct wsp_ggml_backend_metal_buffer_context * ctx,
+ struct wsp_ggml_backend_metal_device_context * ctx_dev,
+ id<MTLDevice> device) {
+ ctx->rset = nil;
+
+ if (!ctx_dev->has_residency_sets) {
+ return true;
+ }
+
+ #if defined(WSP_GGML_METAL_HAS_RESIDENCY_SETS)
+ if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
+ MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init];
+ desc.label = @"wsp_ggml_backend_metal";
+ desc.initialCapacity = ctx->n_buffers;
+
+ NSError * error;
+ ctx->rset = [device newResidencySetWithDescriptor:desc error:&error];
+ if (error) {
+ WSP_GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+ [desc release];
+ return false;
+ }
+
+ [desc release];
+
+ for (int i = 0; i < ctx->n_buffers; i++) {
+ [ctx->rset addAllocation:ctx->buffers[i].metal];
+ }
+
+ [ctx->rset commit];
+ [ctx->rset requestResidency];
+
+ return true;
+ }
+ #else
+ WSP_GGML_UNUSED(ctx_dev);
+ WSP_GGML_UNUSED(device);
+ #endif
+
+ return true;
+ }
+
+ // rset free
+ static void wsp_ggml_backend_metal_buffer_rset_free(struct wsp_ggml_backend_metal_buffer_context * ctx) {
+ #if defined(WSP_GGML_METAL_HAS_RESIDENCY_SETS)
+ if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
+ if (ctx->rset) {
+ [ctx->rset endResidency];
+ [ctx->rset removeAllAllocations];
+ [ctx->rset release];
+ }
+ }
+ #else
+ WSP_GGML_UNUSED(ctx);
+ #endif
+ }
+
  // finds the Metal buffer that contains the tensor data on the GPU device
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
  // Metal buffer based on the host memory pointer
@@ -989,6 +1218,7 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
  case WSP_GGML_OP_REPEAT:
  case WSP_GGML_OP_SCALE:
  case WSP_GGML_OP_CLAMP:
+ case WSP_GGML_OP_CONV_TRANSPOSE_1D:
  return true;
  case WSP_GGML_OP_SQR:
  case WSP_GGML_OP_SQRT:
@@ -997,12 +1227,25 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
  return wsp_ggml_is_contiguous(op->src[0]);
  case WSP_GGML_OP_SUM_ROWS:
  case WSP_GGML_OP_SOFT_MAX:
- case WSP_GGML_OP_RMS_NORM:
  case WSP_GGML_OP_GROUP_NORM:
- return has_simdgroup_reduction;
+ return has_simdgroup_reduction && wsp_ggml_is_contiguous(op->src[0]);
+ case WSP_GGML_OP_RMS_NORM:
+ return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && wsp_ggml_is_contiguous_1(op->src[0]));
+ case WSP_GGML_OP_ARGMAX:
+ return true;
  case WSP_GGML_OP_NORM:
+ return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && wsp_ggml_is_contiguous_1(op->src[0]));
  case WSP_GGML_OP_ROPE:
- return true;
+ {
+ const int mode = ((const int32_t *) op->op_params)[2];
+ if (mode & WSP_GGML_ROPE_TYPE_MROPE) {
+ return false;
+ }
+ if (mode & WSP_GGML_ROPE_TYPE_VISION) {
+ return false;
+ }
+ return true;
+ }
  case WSP_GGML_OP_IM2COL:
  return op->src[0]->type == WSP_GGML_TYPE_F16;
  case WSP_GGML_OP_POOL_1D:
@@ -1010,6 +1253,7 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
  case WSP_GGML_OP_POOL_2D:
  case WSP_GGML_OP_UPSCALE:
  case WSP_GGML_OP_PAD:
+ case WSP_GGML_OP_PAD_REFLECT_1D:
  case WSP_GGML_OP_ARANGE:
  case WSP_GGML_OP_TIMESTEP_EMBEDDING:
  case WSP_GGML_OP_ARGSORT:
@@ -1063,6 +1307,28 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
  default:
  return false;
  }
+ case WSP_GGML_TYPE_Q4_0:
+ case WSP_GGML_TYPE_Q4_1:
+ case WSP_GGML_TYPE_Q5_0:
+ case WSP_GGML_TYPE_Q5_1:
+ case WSP_GGML_TYPE_Q8_0:
+ switch (op->type) {
+ case WSP_GGML_TYPE_F32:
+ case WSP_GGML_TYPE_F16:
+ return true;
+ default:
+ return false;
+ }
+ default:
+ return false;
+ };
+ }
+ case WSP_GGML_OP_SET:
+ {
+ switch (op->src[0]->type) {
+ case WSP_GGML_TYPE_F32:
+ case WSP_GGML_TYPE_I32:
+ return true;
  default:
  return false;
  };
@@ -1749,7 +2015,7 @@ static void wsp_ggml_metal_encode_node(
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

  // TODO: add wsp_ggml_metal_kargs struct
- // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
+ // TODO: optimize (see https://github.com/ggml-org/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
  [encoder setComputePipelineState:pipeline];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  if (id_src1) {
@@ -1922,345 +2188,495 @@ static void wsp_ggml_metal_encode_node(
  WSP_GGML_ASSERT(ne12 % ne02 == 0);
  WSP_GGML_ASSERT(ne13 % ne03 == 0);

- const uint r2 = ne12/ne02;
- const uint r3 = ne13/ne03;
+ const uint32_t r2 = ne12/ne02;
+ const uint32_t r3 = ne13/ne03;

  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
  // to the matrix-vector kernel
- int ne11_mm_min = 1;
+ const int ne11_mm_min = 4;
+
+ // first try to use small-batch mat-mv kernels
+ // these should be efficient for BS [2, ~8]
+ if (src1t == WSP_GGML_TYPE_F32 && (ne00%256 == 0) &&
+ (
+ (
+ (
+ src0t == WSP_GGML_TYPE_F16 || // TODO: helper function
+ src0t == WSP_GGML_TYPE_Q4_0 ||
+ src0t == WSP_GGML_TYPE_Q4_1 ||
+ src0t == WSP_GGML_TYPE_Q5_0 ||
+ src0t == WSP_GGML_TYPE_Q5_1 ||
+ src0t == WSP_GGML_TYPE_Q8_0 ||
+ src0t == WSP_GGML_TYPE_IQ4_NL ||
+ false) && (ne11 >= 2 && ne11 <= 8)
+ ) ||
+ (
+ (
+ src0t == WSP_GGML_TYPE_Q4_K ||
+ src0t == WSP_GGML_TYPE_Q5_K ||
+ src0t == WSP_GGML_TYPE_Q6_K ||
+ false) && (ne11 >= 4 && ne11 <= 8)
+ )
+ )
+ ) {
+ // TODO: determine the optimal parameters based on grid utilization
+ // I still don't know why we should not always use the maximum available threads:
+ //
+ // nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
+ //
+ // my current hypothesis is that the work grid is not evenly divisible for different nsg
+ // values and there can be some tail effects when nsg is high. need to confirm this
+ //
+ const int nsg = 2; // num simdgroups per threadgroup
+ const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
+ const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
+ const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
+ int r1ptg = 4; // num src1 rows per threadgroup
+
+ // note: not sure how optimal are those across all different hardware. there might be someting cleverer
+ switch (ne11) {
+ case 2:
+ r1ptg = 2; break;
+ case 3:
+ case 6:
+ r1ptg = 3; break;
+ case 4:
+ case 7:
+ case 8:
+ r1ptg = 4; break;
+ case 5:
+ r1ptg = 5; break;
+ };

1932
- #if 0
1933
- // the numbers below are measured on M2 Ultra for 7B and 13B models
1934
- // these numbers do not translate to other devices or model sizes
1935
- // TODO: need to find a better approach
1936
- if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
1937
- switch (src0t) {
1938
- case WSP_GGML_TYPE_F16: ne11_mm_min = 2; break;
1939
- case WSP_GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1940
- case WSP_GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1941
- case WSP_GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1942
- case WSP_GGML_TYPE_Q4_0:
1943
- case WSP_GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1944
- case WSP_GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1945
- case WSP_GGML_TYPE_Q5_0: // not tested yet
1946
- case WSP_GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1947
- case WSP_GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1948
- case WSP_GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1949
- default: ne11_mm_min = 1; break;
1950
- }
1951
- }
1952
- #endif
2251
+ id<MTLComputePipelineState> pipeline = nil;
1953
2252
 
1954
- // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1955
- // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1956
- if ([device supportsFamily:MTLGPUFamilyApple7] &&
1957
- !wsp_ggml_is_transposed(src0) &&
1958
- !wsp_ggml_is_transposed(src1) &&
1959
- src1t == WSP_GGML_TYPE_F32 &&
1960
- ne00 % 32 == 0 && ne00 >= 64 &&
1961
- (ne11 > ne11_mm_min || (wsp_ggml_is_quantized(src0t) && ne12 > 1))) {
1962
- //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1963
-
1964
- // some Metal matrix data types require aligned pointers
1965
- // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1966
- switch (src0->type) {
1967
- case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(nb01 % 16 == 0); break;
1968
- case WSP_GGML_TYPE_F16: WSP_GGML_ASSERT(nb01 % 8 == 0); break;
1969
- case WSP_GGML_TYPE_BF16: WSP_GGML_ASSERT(nb01 % 8 == 0); break;
1970
- default: break;
1971
- }
2253
+ switch (src0->type) {
2254
+ case WSP_GGML_TYPE_F16:
2255
+ switch (r1ptg) {
2256
+ case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
2257
+ case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break;
2258
+ case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break;
2259
+ case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break;
2260
+ default: WSP_GGML_ABORT("not implemented");
2261
+ } break;
2262
+ case WSP_GGML_TYPE_Q4_0:
2263
+ switch (r1ptg) {
2264
+ case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break;
2265
+ case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break;
2266
+ case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break;
2267
+ case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break;
2268
+ default: WSP_GGML_ABORT("not implemented");
2269
+ } break;
2270
+ case WSP_GGML_TYPE_Q4_1:
2271
+ switch (r1ptg) {
2272
+ case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break;
2273
+ case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break;
2274
+ case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break;
2275
+ case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break;
2276
+ default: WSP_GGML_ABORT("not implemented");
2277
+ } break;
2278
+ case WSP_GGML_TYPE_Q5_0:
2279
+ switch (r1ptg) {
2280
+ case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break;
2281
+ case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break;
2282
+ case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break;
2283
+ case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break;
2284
+ default: WSP_GGML_ABORT("not implemented");
2285
+ } break;
2286
+ case WSP_GGML_TYPE_Q5_1:
2287
+ switch (r1ptg) {
2288
+ case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break;
2289
+ case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break;
2290
+ case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break;
2291
+ case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break;
2292
+ default: WSP_GGML_ABORT("not implemented");
2293
+ } break;
2294
+ case WSP_GGML_TYPE_Q8_0:
2295
+ switch (r1ptg) {
2296
+ case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break;
2297
+ case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break;
2298
+ case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break;
2299
+ case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
2300
+ default: WSP_GGML_ABORT("not implemented");
2301
+ } break;
2302
+ case WSP_GGML_TYPE_Q4_K:
2303
+ switch (r1ptg) {
2304
+ case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
2305
+ case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3].pipeline; break;
2306
+ case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4].pipeline; break;
2307
+ case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5].pipeline; break;
2308
+ default: WSP_GGML_ABORT("not implemented");
2309
+ } break;
2310
+ case WSP_GGML_TYPE_Q5_K:
2311
+ switch (r1ptg) {
2312
+ case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2].pipeline; break;
2313
+ case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3].pipeline; break;
2314
+ case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4].pipeline; break;
2315
+ case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5].pipeline; break;
2316
+ default: WSP_GGML_ABORT("not implemented");
2317
+ } break;
2318
+ case WSP_GGML_TYPE_Q6_K:
2319
+ switch (r1ptg) {
2320
+ case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2].pipeline; break;
2321
+ case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3].pipeline; break;
2322
+ case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4].pipeline; break;
2323
+ case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5].pipeline; break;
2324
+ default: WSP_GGML_ABORT("not implemented");
2325
+ } break;
2326
+ case WSP_GGML_TYPE_IQ4_NL:
2327
+ switch (r1ptg) {
2328
+ case 2: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2].pipeline; break;
2329
+ case 3: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3].pipeline; break;
2330
+ case 4: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4].pipeline; break;
2331
+ case 5: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5].pipeline; break;
2332
+ default: WSP_GGML_ABORT("not implemented");
2333
+ } break;
2334
+ default: WSP_GGML_ABORT("not implemented");
2335
+ }
1972
2336
 
1973
- id<MTLComputePipelineState> pipeline = nil;
1974
-
1975
- switch (src0->type) {
1976
- case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
1977
- case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
1978
- case WSP_GGML_TYPE_BF16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
1979
- case WSP_GGML_TYPE_Q4_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
1980
- case WSP_GGML_TYPE_Q4_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
1981
- case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
1982
- case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
1983
- case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
1984
- case WSP_GGML_TYPE_Q2_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
1985
- case WSP_GGML_TYPE_Q3_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
1986
- case WSP_GGML_TYPE_Q4_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
1987
- case WSP_GGML_TYPE_Q5_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
1988
- case WSP_GGML_TYPE_Q6_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
1989
- case WSP_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
1990
- case WSP_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
1991
- case WSP_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
1992
- case WSP_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
1993
- case WSP_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
1994
- case WSP_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
1995
- case WSP_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
1996
- case WSP_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
1997
- case WSP_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
1998
- default: WSP_GGML_ABORT("MUL MAT-MAT not implemented");
1999
- }
2337
+ wsp_ggml_metal_kargs_mul_mv_ext args = {
2338
+ /*.ne00 =*/ ne00,
2339
+ /*.ne01 =*/ ne01,
2340
+ /*.ne02 =*/ ne02,
2341
+ /*.nb00 =*/ nb00,
2342
+ /*.nb01 =*/ nb01,
2343
+ /*.nb02 =*/ nb02,
2344
+ /*.nb03 =*/ nb03,
2345
+ /*.ne10 =*/ ne10,
2346
+ /*.ne11 =*/ ne11,
2347
+ /*.ne12 =*/ ne12,
2348
+ /*.nb10 =*/ nb10,
2349
+ /*.nb11 =*/ nb11,
2350
+ /*.nb12 =*/ nb12,
2351
+ /*.nb13 =*/ nb13,
2352
+ /*.ne0 =*/ ne0,
2353
+ /*.ne1 =*/ ne1,
2354
+ /*.r2 =*/ r2,
2355
+ /*.r3 =*/ r3,
2356
+ /*.nsg =*/ nsg,
2357
+ /*.nxpsg =*/ nxpsg,
2358
+ /*.r1ptg =*/ r1ptg,
2359
+ };
2000
2360
 
2001
- wsp_ggml_metal_kargs_mul_mm args = {
2002
- /*.ne00 =*/ ne00,
2003
- /*.ne02 =*/ ne02,
2004
- /*.nb01 =*/ nb01,
2005
- /*.nb02 =*/ nb02,
2006
- /*.nb03 =*/ nb03,
2007
- /*.ne12 =*/ ne12,
2008
- /*.nb10 =*/ nb10,
2009
- /*.nb11 =*/ nb11,
2010
- /*.nb12 =*/ nb12,
2011
- /*.nb13 =*/ nb13,
2012
- /*.ne0 =*/ ne0,
2013
- /*.ne1 =*/ ne1,
2014
- /*.r2 =*/ r2,
2015
- /*.r3 =*/ r3,
2016
- };
2361
+ [encoder setComputePipelineState:pipeline];
2362
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
2363
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2364
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2365
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2366
+
2367
+ //printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg);
2368
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
2369
+ } else
2370
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
2371
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
2372
+ if ([device supportsFamily:MTLGPUFamilyApple7] &&
2373
+ !wsp_ggml_is_transposed(src0) &&
2374
+ !wsp_ggml_is_transposed(src1) &&
2375
+ src1t == WSP_GGML_TYPE_F32 &&
2376
+ ne00 % 32 == 0 && ne00 >= 64 &&
2377
+ (ne11 > ne11_mm_min || (wsp_ggml_is_quantized(src0t) && ne12 > 1))) {
2378
+ //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
2017
2379
 
2018
- [encoder setComputePipelineState:pipeline];
2019
- [encoder setBytes:&args length:sizeof(args) atIndex:0];
2020
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2021
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2022
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2023
-
2024
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
2025
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
2026
- } else {
2027
- int nth0 = 32;
2028
- int nth1 = 1;
2029
- int nrows = 1;
2030
- //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
2031
-
2032
- id<MTLComputePipelineState> pipeline = nil;
2033
-
2034
- // use custom matrix x vector kernel
2035
- switch (src0t) {
2036
- case WSP_GGML_TYPE_F32:
2037
- {
2038
- WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
2039
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
2380
+ // some Metal matrix data types require aligned pointers
2381
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
2382
+ switch (src0->type) {
2383
+ case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(nb01 % 16 == 0); break;
2384
+ case WSP_GGML_TYPE_F16: WSP_GGML_ASSERT(nb01 % 8 == 0); break;
2385
+ case WSP_GGML_TYPE_BF16: WSP_GGML_ASSERT(nb01 % 8 == 0); break;
2386
+ default: break;
2387
+ }
2388
+
2389
+ id<MTLComputePipelineState> pipeline = nil;
2390
+
2391
+ switch (src0->type) {
2392
+ case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
2393
+ case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
2394
+ case WSP_GGML_TYPE_BF16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
2395
+ case WSP_GGML_TYPE_Q4_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
2396
+ case WSP_GGML_TYPE_Q4_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
2397
+ case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
2398
+ case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
2399
+ case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
2400
+ case WSP_GGML_TYPE_Q2_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
2401
+ case WSP_GGML_TYPE_Q3_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
2402
+ case WSP_GGML_TYPE_Q4_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
2403
+ case WSP_GGML_TYPE_Q5_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
2404
+ case WSP_GGML_TYPE_Q6_K: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
2405
+ case WSP_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
2406
+ case WSP_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
2407
+ case WSP_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
2408
+ case WSP_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
2409
+ case WSP_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
2410
+ case WSP_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
2411
+ case WSP_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
2412
+ case WSP_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
2413
+ case WSP_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
2414
+ default: WSP_GGML_ABORT("MUL MAT-MAT not implemented");
2415
+ }
2416
+
2417
+ wsp_ggml_metal_kargs_mul_mm args = {
2418
+ /*.ne00 =*/ ne00,
2419
+ /*.ne02 =*/ ne02,
2420
+ /*.nb01 =*/ nb01,
2421
+ /*.nb02 =*/ nb02,
2422
+ /*.nb03 =*/ nb03,
2423
+ /*.ne12 =*/ ne12,
2424
+ /*.nb10 =*/ nb10,
2425
+ /*.nb11 =*/ nb11,
2426
+ /*.nb12 =*/ nb12,
2427
+ /*.nb13 =*/ nb13,
2428
+ /*.ne0 =*/ ne0,
2429
+ /*.ne1 =*/ ne1,
2430
+ /*.r2 =*/ r2,
2431
+ /*.r3 =*/ r3,
2432
+ };
2433
+
2434
+ [encoder setComputePipelineState:pipeline];
2435
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
2436
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2437
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2438
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2439
+
2440
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
2441
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
2442
+ } else {
2443
+ int nth0 = 32;
2444
+ int nth1 = 1;
2445
+ int nrows = 1;
2446
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
2447
+
2448
+ id<MTLComputePipelineState> pipeline = nil;
2449
+
2450
+ // use custom matrix x vector kernel
2451
+ switch (src0t) {
2452
+ case WSP_GGML_TYPE_F32:
2453
+ {
2454
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
2455
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
2456
+ nrows = 4;
2457
+ } break;
2458
+ case WSP_GGML_TYPE_F16:
2459
+ {
2460
+ nth0 = 32;
2461
+ nth1 = 1;
2462
+ if (src1t == WSP_GGML_TYPE_F32) {
2463
+ if (ne11 * ne12 < 4) {
2464
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
2465
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
2466
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
2467
+ nrows = ne11;
2468
+ } else {
2469
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
2040
2470
  nrows = 4;
2041
- } break;
2042
- case WSP_GGML_TYPE_F16:
2043
- {
2044
- nth0 = 32;
2045
- nth1 = 1;
2046
- if (src1t == WSP_GGML_TYPE_F32) {
2047
- if (ne11 * ne12 < 4) {
2048
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
2049
- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
2050
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
2051
- nrows = ne11;
2052
- } else {
2053
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
2054
- nrows = 4;
2055
- }
2056
- } else {
2057
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
2058
- nrows = 4;
2059
- }
2060
- } break;
2061
- case WSP_GGML_TYPE_BF16:
2062
- {
2063
- nth0 = 32;
2064
- nth1 = 1;
2065
- if (src1t == WSP_GGML_TYPE_F32) {
2066
- if (ne11 * ne12 < 4) {
2067
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
2068
- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
2069
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
2070
- nrows = ne11;
2071
- } else {
2072
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
2073
- nrows = 4;
2074
- }
2075
- } else {
2076
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
2077
- nrows = 4;
2078
- }
2079
- } break;
2080
- case WSP_GGML_TYPE_Q4_0:
2081
- {
2082
- nth0 = 8;
2083
- nth1 = 8;
2084
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
2085
- } break;
2086
- case WSP_GGML_TYPE_Q4_1:
2087
- {
2088
- nth0 = 8;
2089
- nth1 = 8;
2090
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
2091
- } break;
2092
- case WSP_GGML_TYPE_Q5_0:
2093
- {
2094
- nth0 = 8;
2095
- nth1 = 8;
2096
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
2097
- } break;
2098
- case WSP_GGML_TYPE_Q5_1:
2099
- {
2100
- nth0 = 8;
2101
- nth1 = 8;
2102
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
2103
- } break;
2104
- case WSP_GGML_TYPE_Q8_0:
2105
- {
2106
- nth0 = 8;
2107
- nth1 = 8;
2108
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
2109
- } break;
2110
- case WSP_GGML_TYPE_Q2_K:
2111
- {
2112
- nth0 = 2;
2113
- nth1 = 32;
2114
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
2115
- } break;
2116
- case WSP_GGML_TYPE_Q3_K:
2117
- {
2118
- nth0 = 2;
2119
- nth1 = 32;
2120
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
2121
- } break;
2122
- case WSP_GGML_TYPE_Q4_K:
2123
- {
2124
- nth0 = 4; //1;
2125
- nth1 = 8; //32;
2126
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
2127
- } break;
2128
- case WSP_GGML_TYPE_Q5_K:
2129
- {
2130
- nth0 = 2;
2131
- nth1 = 32;
2132
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
2133
- } break;
2134
- case WSP_GGML_TYPE_Q6_K:
2135
- {
2136
- nth0 = 2;
2137
- nth1 = 32;
2138
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
2139
- } break;
2140
- case WSP_GGML_TYPE_IQ2_XXS:
2141
- {
2142
- nth0 = 4;
2143
- nth1 = 16;
2144
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
2145
- } break;
2146
- case WSP_GGML_TYPE_IQ2_XS:
2147
- {
2148
- nth0 = 4;
2149
- nth1 = 16;
2150
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
2151
- } break;
2152
- case WSP_GGML_TYPE_IQ3_XXS:
2153
- {
2154
- nth0 = 4;
2155
- nth1 = 16;
2156
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
2157
- } break;
2158
- case WSP_GGML_TYPE_IQ3_S:
2159
- {
2160
- nth0 = 4;
2161
- nth1 = 16;
2162
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
2163
- } break;
2164
- case WSP_GGML_TYPE_IQ2_S:
2165
- {
2166
- nth0 = 4;
2167
- nth1 = 16;
2168
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
2169
- } break;
2170
- case WSP_GGML_TYPE_IQ1_S:
2171
- {
2172
- nth0 = 4;
2173
- nth1 = 16;
2174
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
2175
- } break;
2176
- case WSP_GGML_TYPE_IQ1_M:
2177
- {
2178
- nth0 = 4;
2179
- nth1 = 16;
2180
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
2181
- } break;
2182
- case WSP_GGML_TYPE_IQ4_NL:
2183
- {
2184
- nth0 = 4;
2185
- nth1 = 16;
2186
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
2187
- } break;
2188
- case WSP_GGML_TYPE_IQ4_XS:
2189
- {
2190
- nth0 = 4;
2191
- nth1 = 16;
2192
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
2193
- } break;
2194
- default:
2195
- {
2196
- WSP_GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
2197
- WSP_GGML_ABORT("not implemented");
2198
2471
  }
2199
- };
2472
+ } else {
2473
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
2474
+ nrows = 4;
2475
+ }
2476
+ } break;
2477
+ case WSP_GGML_TYPE_BF16:
2478
+ {
2479
+ nth0 = 32;
2480
+ nth1 = 1;
2481
+ if (src1t == WSP_GGML_TYPE_F32) {
2482
+ if (ne11 * ne12 < 4) {
2483
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
2484
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
2485
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
2486
+ nrows = ne11;
2487
+ } else {
2488
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
2489
+ nrows = 4;
2490
+ }
2491
+ } else {
2492
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
2493
+ nrows = 4;
2494
+ }
2495
+ } break;
2496
+ case WSP_GGML_TYPE_Q4_0:
2497
+ {
2498
+ nth0 = 8;
2499
+ nth1 = 8;
2500
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
2501
+ } break;
2502
+ case WSP_GGML_TYPE_Q4_1:
2503
+ {
2504
+ nth0 = 8;
2505
+ nth1 = 8;
2506
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
2507
+ } break;
2508
+ case WSP_GGML_TYPE_Q5_0:
2509
+ {
2510
+ nth0 = 8;
2511
+ nth1 = 8;
2512
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
2513
+ } break;
2514
+ case WSP_GGML_TYPE_Q5_1:
2515
+ {
2516
+ nth0 = 8;
2517
+ nth1 = 8;
2518
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
2519
+ } break;
2520
+ case WSP_GGML_TYPE_Q8_0:
2521
+ {
2522
+ nth0 = 8;
2523
+ nth1 = 8;
2524
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
2525
+ } break;
2526
+ case WSP_GGML_TYPE_Q2_K:
2527
+ {
2528
+ nth0 = 2;
2529
+ nth1 = 32;
2530
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
2531
+ } break;
2532
+ case WSP_GGML_TYPE_Q3_K:
2533
+ {
2534
+ nth0 = 2;
2535
+ nth1 = 32;
2536
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
2537
+ } break;
2538
+ case WSP_GGML_TYPE_Q4_K:
2539
+ {
2540
+ nth0 = 4; //1;
2541
+ nth1 = 8; //32;
2542
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
2543
+ } break;
2544
+ case WSP_GGML_TYPE_Q5_K:
2545
+ {
2546
+ nth0 = 2;
2547
+ nth1 = 32;
2548
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
2549
+ } break;
2550
+ case WSP_GGML_TYPE_Q6_K:
2551
+ {
2552
+ nth0 = 2;
2553
+ nth1 = 32;
2554
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
2555
+ } break;
2556
+ case WSP_GGML_TYPE_IQ2_XXS:
2557
+ {
2558
+ nth0 = 4;
2559
+ nth1 = 16;
2560
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
2561
+ } break;
2562
+ case WSP_GGML_TYPE_IQ2_XS:
2563
+ {
2564
+ nth0 = 4;
2565
+ nth1 = 16;
2566
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
2567
+ } break;
2568
+ case WSP_GGML_TYPE_IQ3_XXS:
2569
+ {
2570
+ nth0 = 4;
2571
+ nth1 = 16;
2572
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
2573
+ } break;
2574
+ case WSP_GGML_TYPE_IQ3_S:
2575
+ {
2576
+ nth0 = 4;
2577
+ nth1 = 16;
2578
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
2579
+ } break;
2580
+ case WSP_GGML_TYPE_IQ2_S:
2581
+ {
2582
+ nth0 = 4;
2583
+ nth1 = 16;
2584
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
2585
+ } break;
2586
+ case WSP_GGML_TYPE_IQ1_S:
2587
+ {
2588
+ nth0 = 4;
2589
+ nth1 = 16;
2590
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
2591
+ } break;
2592
+ case WSP_GGML_TYPE_IQ1_M:
2593
+ {
2594
+ nth0 = 4;
2595
+ nth1 = 16;
2596
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
2597
+ } break;
2598
+ case WSP_GGML_TYPE_IQ4_NL:
2599
+ {
2600
+ nth0 = 4;
2601
+ nth1 = 16;
2602
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
2603
+ } break;
2604
+ case WSP_GGML_TYPE_IQ4_XS:
2605
+ {
2606
+ nth0 = 4;
2607
+ nth1 = 16;
2608
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
2609
+ } break;
2610
+ default:
2611
+ {
2612
+ WSP_GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
2613
+ WSP_GGML_ABORT("not implemented");
2614
+ }
2615
+ };
2200
2616
 
2201
- wsp_ggml_metal_kargs_mul_mv args = {
2202
- /*.ne00 =*/ ne00,
2203
- /*.ne01 =*/ ne01,
2204
- /*.ne02 =*/ ne02,
2205
- /*.nb00 =*/ nb00,
2206
- /*.nb01 =*/ nb01,
2207
- /*.nb02 =*/ nb02,
2208
- /*.nb03 =*/ nb03,
2209
- /*.ne10 =*/ ne10,
2210
- /*.ne11 =*/ ne11,
2211
- /*.ne12 =*/ ne12,
2212
- /*.nb10 =*/ nb10,
2213
- /*.nb11 =*/ nb11,
2214
- /*.nb12 =*/ nb12,
2215
- /*.nb13 =*/ nb13,
2216
- /*.ne0 =*/ ne0,
2217
- /*.ne1 =*/ ne1,
2218
- /*.r2 =*/ r2,
2219
- /*.r3 =*/ r3,
2220
- };
2617
+ wsp_ggml_metal_kargs_mul_mv args = {
2618
+ /*.ne00 =*/ ne00,
2619
+ /*.ne01 =*/ ne01,
2620
+ /*.ne02 =*/ ne02,
2621
+ /*.nb00 =*/ nb00,
2622
+ /*.nb01 =*/ nb01,
2623
+ /*.nb02 =*/ nb02,
2624
+ /*.nb03 =*/ nb03,
2625
+ /*.ne10 =*/ ne10,
2626
+ /*.ne11 =*/ ne11,
2627
+ /*.ne12 =*/ ne12,
2628
+ /*.nb10 =*/ nb10,
2629
+ /*.nb11 =*/ nb11,
2630
+ /*.nb12 =*/ nb12,
2631
+ /*.nb13 =*/ nb13,
2632
+ /*.ne0 =*/ ne0,
2633
+ /*.ne1 =*/ ne1,
2634
+ /*.r2 =*/ r2,
2635
+ /*.r3 =*/ r3,
2636
+ };
2221
2637
 
2222
- [encoder setComputePipelineState:pipeline];
2223
- [encoder setBytes:&args length:sizeof(args) atIndex:0];
2224
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2225
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2226
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2638
+ [encoder setComputePipelineState:pipeline];
2639
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
2640
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2641
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2642
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2227
2643
 
2228
- if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 || src0t == WSP_GGML_TYPE_Q5_0 ||
2229
- src0t == WSP_GGML_TYPE_Q5_1 || src0t == WSP_GGML_TYPE_Q8_0 || src0t == WSP_GGML_TYPE_Q2_K ||
2230
- src0t == WSP_GGML_TYPE_IQ1_S || src0t == WSP_GGML_TYPE_IQ1_M || src0t == WSP_GGML_TYPE_IQ2_S) {
2231
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2232
- }
2233
- else if (src0t == WSP_GGML_TYPE_IQ2_XXS || src0t == WSP_GGML_TYPE_IQ2_XS) {
2234
- const int mem_size = src0t == WSP_GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
2235
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2236
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2237
- }
2238
- else if (src0t == WSP_GGML_TYPE_IQ3_XXS || src0t == WSP_GGML_TYPE_IQ3_S) {
2239
- const int mem_size = src0t == WSP_GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
2240
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2241
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2242
- }
2243
- else if (src0t == WSP_GGML_TYPE_IQ4_NL || src0t == WSP_GGML_TYPE_IQ4_XS) {
2244
- const int mem_size = 32*sizeof(float);
2245
- [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2246
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2247
- }
2248
- else if (src0t == WSP_GGML_TYPE_Q4_K) {
2249
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2250
- }
2251
- else if (src0t == WSP_GGML_TYPE_Q3_K) {
2252
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2253
- }
2254
- else if (src0t == WSP_GGML_TYPE_Q5_K) {
2255
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2256
- }
2257
- else if (src0t == WSP_GGML_TYPE_Q6_K) {
2258
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2259
- } else {
2260
- const int64_t ny = (ne11 + nrows - 1)/nrows;
2261
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2262
- }
2263
- }
2644
+ if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 || src0t == WSP_GGML_TYPE_Q5_0 ||
2645
+ src0t == WSP_GGML_TYPE_Q5_1 || src0t == WSP_GGML_TYPE_Q8_0 || src0t == WSP_GGML_TYPE_Q2_K ||
2646
+ src0t == WSP_GGML_TYPE_IQ1_S || src0t == WSP_GGML_TYPE_IQ1_M || src0t == WSP_GGML_TYPE_IQ2_S) {
2647
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2648
+ }
2649
+ else if (src0t == WSP_GGML_TYPE_IQ2_XXS || src0t == WSP_GGML_TYPE_IQ2_XS) {
2650
+ const int mem_size = src0t == WSP_GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
2651
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2652
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2653
+ }
2654
+ else if (src0t == WSP_GGML_TYPE_IQ3_XXS || src0t == WSP_GGML_TYPE_IQ3_S) {
2655
+ const int mem_size = src0t == WSP_GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
2656
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2657
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2658
+ }
2659
+ else if (src0t == WSP_GGML_TYPE_IQ4_NL || src0t == WSP_GGML_TYPE_IQ4_XS) {
2660
+ const int mem_size = 32*sizeof(float);
2661
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2662
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2663
+ }
2664
+ else if (src0t == WSP_GGML_TYPE_Q4_K) {
2665
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2666
+ }
2667
+ else if (src0t == WSP_GGML_TYPE_Q3_K) {
2668
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2669
+ }
2670
+ else if (src0t == WSP_GGML_TYPE_Q5_K) {
2671
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2672
+ }
2673
+ else if (src0t == WSP_GGML_TYPE_Q6_K) {
2674
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2675
+ } else {
2676
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
2677
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2678
+ }
2679
+ }
2264
2680
  } break;
2265
2681
  case WSP_GGML_OP_MUL_MAT_ID:
2266
2682
  {
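For reference, the threadgroup grid in the relocated mat-vec dispatch above is pure shape arithmetic: the quantized kernels cover a fixed number of src0 rows per threadgroup, while the F16/BF16 row kernels batch nrows src1 rows per threadgroup. A minimal plain-C sketch of that arithmetic (illustrative only, not part of ggml-metal.m):

    #include <stdint.h>
    #include <stdio.h>

    // Illustrative sketch of the threadgroup-grid math used by the mat-vec
    // dispatch above; the helper and example shapes are not package code.
    static int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

    int main(void) {
        const int64_t ne01 = 4096;                   // rows of src0
        const int64_t ne11 = 3, ne12 = 1, ne13 = 1;  // src1 columns and batch dims

        // Q4_0..Q8_0, Q2_K and the IQ1/IQ2/IQ3 types: grid.x = (ne01 + 7)/8
        const int64_t gx8 = ceil_div(ne01, 8);
        // Q3_K/Q4_K/Q5_K/IQ4_NL/IQ4_XS:              grid.x = (ne01 + 3)/4
        const int64_t gx4 = ceil_div(ne01, 4);
        // Q6_K:                                      grid.x = (ne01 + 1)/2
        const int64_t gx2 = ceil_div(ne01, 2);
        // F16/BF16 row kernels: grid.y = (ne11 + nrows - 1)/nrows
        const int64_t nrows = 4;
        const int64_t ny    = ceil_div(ne11, nrows);

        printf("grid.x: %lld / %lld / %lld, grid.y (row kernels): %lld, grid.z: %lld\n",
               (long long) gx8, (long long) gx4, (long long) gx2,
               (long long) ny, (long long) (ne12*ne13));
        return 0;
    }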
@@ -2672,7 +3088,6 @@ static void wsp_ggml_metal_encode_node(
2672
3088
  } break;
2673
3089
  case WSP_GGML_OP_GROUP_NORM:
2674
3090
  {
2675
- WSP_GGML_ASSERT(ne00 % 4 == 0);
2676
3091
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
2677
3092
 
2678
3093
  float eps;
@@ -2742,7 +3157,9 @@ static void wsp_ggml_metal_encode_node(
2742
3157
  } break;
2743
3158
  case WSP_GGML_OP_ROPE:
2744
3159
  {
2745
- WSP_GGML_ASSERT(ne10 == ne02);
3160
+ // make sure we have one or more position id(ne10) per token(ne02)
3161
+ WSP_GGML_ASSERT(ne10 % ne02 == 0);
3162
+ WSP_GGML_ASSERT(ne10 >= ne02);
2746
3163
 
2747
3164
  const int nth = MIN(1024, ne00);
2748
3165
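The ROPE change above relaxes the old exact-match requirement: src1 may now carry any positive whole multiple of one position id per token. A one-function plain-C restatement of the new precondition (the helper name is illustrative, not package API):

    #include <stdbool.h>
    #include <stdint.h>

    // Sketch of the relaxed ROPE precondition: ne10 position ids must be a
    // positive multiple of the ne02 tokens (illustrative helper only).
    static bool rope_positions_ok(int64_t ne10, int64_t ne02) {
        return ne02 > 0 && ne10 >= ne02 && ne10 % ne02 == 0;
    }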
 
@@ -2908,6 +3325,49 @@ static void wsp_ggml_metal_encode_node(
2908
3325
  [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
2909
3326
  }
2910
3327
  } break;
3328
+ case WSP_GGML_OP_CONV_TRANSPOSE_1D:
3329
+ {
3330
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
3331
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
3332
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16 || src0->type == WSP_GGML_TYPE_F32);
3333
+ WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
3334
+ WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32);
3335
+
3336
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
3337
+
3338
+ const int32_t IC = src1->ne[1];
3339
+ const int32_t IL = src1->ne[0];
3340
+
3341
+ const int32_t K = src0->ne[0];
3342
+
3343
+ const int32_t OL = dst->ne[0];
3344
+ const int32_t OC = dst->ne[1];
3345
+
3346
+ id<MTLComputePipelineState> pipeline;
3347
+
3348
+ switch (src0->type) {
3349
+ case WSP_GGML_TYPE_F32: {
3350
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline;
3351
+ } break;
3352
+ case WSP_GGML_TYPE_F16: {
3353
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline;
3354
+ } break;
3355
+ default: WSP_GGML_ABORT("fatal error");
3356
+ };
3357
+
3358
+ [encoder setComputePipelineState:pipeline];
3359
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3360
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3361
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3362
+ [encoder setBytes:&IC length:sizeof( int32_t) atIndex:3];
3363
+ [encoder setBytes:&IL length:sizeof( int32_t) atIndex:4];
3364
+ [encoder setBytes:&K length:sizeof( int32_t) atIndex:5];
3365
+ [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6];
3366
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7];
3367
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8];
3368
+
3369
+ [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
3370
+ } break;
2911
3371
  case WSP_GGML_OP_UPSCALE:
2912
3372
  {
2913
3373
  WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
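The new CONV_TRANSPOSE_1D path above launches one threadgroup per (output position, output channel) pair. As a shape reference, here is a naive CPU sketch of a 1-D transposed convolution under the usual definition (stride s0, no padding or dilation); the function, the flat memory layout, and the output-length formula are assumptions for illustration, not code from the package:

    #include <stdint.h>

    // Naive reference for a 1-D transposed convolution (assumed semantics):
    // src1: [IC][IL] input, src0: [OC][IC][K] kernel, dst: [OC][OL] output,
    // with OL = (IL - 1)*s0 + K when there is no padding or dilation.
    static void conv_transpose_1d_ref(const float * src0, const float * src1, float * dst,
                                      int IC, int IL, int OC, int K, int s0) {
        const int OL = (IL - 1) * s0 + K;
        for (int oc = 0; oc < OC; oc++) {
            for (int ol = 0; ol < OL; ol++) {
                dst[oc*OL + ol] = 0.0f;
            }
            for (int ic = 0; ic < IC; ic++) {
                for (int il = 0; il < IL; il++) {
                    for (int k = 0; k < K; k++) {
                        // each input sample scatters into K output positions
                        dst[oc*OL + il*s0 + k] += src1[ic*IL + il] * src0[oc*IC*K + ic*K + k];
                    }
                }
            }
        }
    }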
@@ -2977,6 +3437,38 @@ static void wsp_ggml_metal_encode_node(
2977
3437
 
2978
3438
  const int nth = MIN(1024, ne0);
2979
3439
 
3440
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3441
+ } break;
3442
+ case WSP_GGML_OP_PAD_REFLECT_1D:
3443
+ {
3444
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
3445
+
3446
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[0];
3447
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[1];
3448
+
3449
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
3450
+
3451
+ [encoder setComputePipelineState:pipeline];
3452
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3453
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3454
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
3455
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
3456
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
3457
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
3458
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:6];
3459
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
3460
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
3461
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
3462
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
3463
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:11];
3464
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:12];
3465
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:13];
3466
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:14];
3467
+ [encoder setBytes:&p0 length:sizeof(p0) atIndex:15];
3468
+ [encoder setBytes:&p1 length:sizeof(p1) atIndex:16];
3469
+
3470
+ const int nth = MIN(1024, ne0);
3471
+
2980
3472
  [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2981
3473
  } break;
2982
3474
  case WSP_GGML_OP_ARANGE:
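PAD_REFLECT_1D above mirrors p0 elements on the left and p1 on the right of each row. A small plain-C sketch of reflect padding for a single row, assuming the conventional mirror that does not repeat the edge sample (illustrative helper; the Metal kernel's actual indexing lives in ggml-metal.metal):

    // Reflect-pad one row of length n by p0 on the left and p1 on the right
    // (assumed "reflect" convention: the edge sample itself is not repeated).
    // dst must hold p0 + n + p1 floats; requires p0 < n and p1 < n.
    static void pad_reflect_1d_row(const float * src, float * dst, int n, int p0, int p1) {
        for (int i = 0; i < p0; i++) {
            dst[i] = src[p0 - i];             // left mirror: ..., src[2], src[1]
        }
        for (int i = 0; i < n; i++) {
            dst[p0 + i] = src[i];             // original samples
        }
        for (int i = 0; i < p1; i++) {
            dst[p0 + n + i] = src[n - 2 - i]; // right mirror: src[n-2], src[n-3], ...
        }
    }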
@@ -3439,10 +3931,6 @@ static void wsp_ggml_metal_encode_node(
3439
3931
  case WSP_GGML_OP_CPY:
3440
3932
  case WSP_GGML_OP_CONT:
3441
3933
  {
3442
- WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(src0->type) == 0);
3443
-
3444
- int nth = MIN(1024, ne00/wsp_ggml_blck_size(src0->type));
3445
-
3446
3934
  id<MTLComputePipelineState> pipeline = nil;
3447
3935
 
3448
3936
  switch (src0t) {
@@ -3476,7 +3964,47 @@ static void wsp_ggml_metal_encode_node(
3476
3964
  switch (dstt) {
3477
3965
  case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
3478
3966
  case WSP_GGML_TYPE_BF16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
3479
- default: WSP_GGML_ASSERT(false && "not implemented");
3967
+ default: WSP_GGML_ABORT("not implemented");
3968
+ };
3969
+ } break;
3970
+ case WSP_GGML_TYPE_Q4_0:
3971
+ {
3972
+ switch (dstt) {
3973
+ case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
3974
+ case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break;
3975
+ default: WSP_GGML_ABORT("not implemented");
3976
+ };
3977
+ } break;
3978
+ case WSP_GGML_TYPE_Q4_1:
3979
+ {
3980
+ switch (dstt) {
3981
+ case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
3982
+ case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break;
3983
+ default: WSP_GGML_ABORT("not implemented");
3984
+ };
3985
+ } break;
3986
+ case WSP_GGML_TYPE_Q5_0:
3987
+ {
3988
+ switch (dstt) {
3989
+ case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
3990
+ case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break;
3991
+ default: WSP_GGML_ABORT("not implemented");
3992
+ };
3993
+ } break;
3994
+ case WSP_GGML_TYPE_Q5_1:
3995
+ {
3996
+ switch (dstt) {
3997
+ case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
3998
+ case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break;
3999
+ default: WSP_GGML_ABORT("not implemented");
4000
+ };
4001
+ } break;
4002
+ case WSP_GGML_TYPE_Q8_0:
4003
+ {
4004
+ switch (dstt) {
4005
+ case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
4006
+ case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break;
4007
+ default: WSP_GGML_ABORT("not implemented");
3480
4008
  };
3481
4009
  } break;
3482
4010
  default: WSP_GGML_ABORT("not implemented");
@@ -3506,7 +4034,73 @@ static void wsp_ggml_metal_encode_node(
3506
4034
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3507
4035
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3508
4036
 
4037
+ WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(src0->type) == 0);
4038
+ int nth = MIN(1024, ne00/wsp_ggml_blck_size(src0->type));
4039
+
3509
4040
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
4041
+
4042
+ } break;
4043
+ case WSP_GGML_OP_SET:
4044
+ {
4045
+ WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, dst));
4046
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst) && wsp_ggml_is_contiguous(src0));
4047
+
4048
+ // src0 and dst as viewed during set
4049
+ const size_t dst_nb0 = wsp_ggml_element_size(src0);
4050
+
4051
+ const size_t dst_nb1 = ((int32_t *) dst->op_params)[0];
4052
+ const size_t dst_nb2 = ((int32_t *) dst->op_params)[1];
4053
+ const size_t dst_nb3 = ((int32_t *) dst->op_params)[2];
4054
+ const size_t offset = ((int32_t *) dst->op_params)[3];
4055
+ const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
4056
+
4057
+ if (!inplace) {
4058
+ memcpy(((char *) dst->data), ((char *) src0->data), wsp_ggml_nbytes(dst));
4059
+ }
4060
+
4061
+ const int im0 = (ne10 == 0 ? 0 : ne10-1);
4062
+ const int im1 = (ne11 == 0 ? 0 : ne11-1);
4063
+ const int im2 = (ne12 == 0 ? 0 : ne12-1);
4064
+ const int im3 = (ne13 == 0 ? 0 : ne13-1);
4065
+
4066
+ WSP_GGML_ASSERT(offset + im0*dst_nb0 + im1*dst_nb1 + im2*dst_nb2 + im3*dst_nb3 <= wsp_ggml_nbytes(dst));
4067
+
4068
+ id<MTLComputePipelineState> pipeline = nil;
4069
+
4070
+ switch (src0t) {
4071
+ case WSP_GGML_TYPE_F32:
4072
+ WSP_GGML_ASSERT(nb10 == sizeof(float));
4073
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break;
4074
+ case WSP_GGML_TYPE_I32:
4075
+ WSP_GGML_ASSERT(nb10 == sizeof(int32_t));
4076
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break;
4077
+ default: WSP_GGML_ABORT("fatal error");
4078
+ }
4079
+
4080
+ wsp_ggml_metal_kargs_set args = {
4081
+ /*.ne10 =*/ ne10,
4082
+ /*.ne11 =*/ ne11,
4083
+ /*.ne12 =*/ ne12,
4084
+ /*.nb10 =*/ nb10,
4085
+ /*.nb11 =*/ nb11,
4086
+ /*.nb12 =*/ nb12,
4087
+ /*.nb13 =*/ nb13,
4088
+ /*.nb1 =*/ dst_nb1,
4089
+ /*.nb2 =*/ dst_nb2,
4090
+ /*.nb3 =*/ dst_nb3,
4091
+ /*.offs =*/ offset,
4092
+ /*.inplace =*/ inplace,
4093
+ };
4094
+
4095
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10);
4096
+
4097
+ [encoder setComputePipelineState:pipeline];
4098
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
4099
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
4100
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
4101
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
4102
+
4103
+ [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3510
4104
  } break;
3511
4105
  case WSP_GGML_OP_POOL_2D:
3512
4106
  {
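In the relocated CPY/CONT sizing above, the per-threadgroup width is derived from the number of quantization blocks per row rather than raw elements, which is what makes the new quantized-to-float copy kernels dispatchable. A quick plain-C illustration, assuming ggml's usual 32-element block size for Q4_0/Q4_1/Q5_0/Q5_1/Q8_0:

    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    int main(void) {
        const int blck = 32;     // assumed block size of the quantized row type
        const int ne00 = 4096;   // row length in elements; must divide evenly

        // mirrors: nth = MIN(1024, ne00 / blck_size)
        const int nth = MIN(1024, ne00 / blck);

        printf("ne00=%d -> %d blocks per row, %d threads per threadgroup\n",
               ne00, ne00 / blck, nth); // 128 blocks -> nth = 128
        return 0;
    }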
@@ -3567,6 +4161,31 @@ static void wsp_ggml_metal_encode_node(
3567
4161
 
3568
4162
  [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
3569
4163
  } break;
4164
+ case WSP_GGML_OP_ARGMAX:
4165
+ {
4166
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
4167
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
4168
+ WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(src0->type));
4169
+
4170
+ const int64_t nrows = wsp_ggml_nrows(src0);
4171
+
4172
+ int nth = 32; // SIMD width
4173
+ while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
4174
+ nth *= 2;
4175
+ }
4176
+
4177
+ id<MTLComputePipelineState> pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline;
4178
+
4179
+ [encoder setComputePipelineState:pipeline];
4180
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
4181
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
4182
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
4183
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
4184
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
4185
+ [encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1];
4186
+
4187
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
4188
+ } break;
3570
4189
  default:
3571
4190
  {
3572
4191
  WSP_GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, wsp_ggml_op_name(dst->op));
@@ -3718,6 +4337,8 @@ static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t
3718
4337
  for (int i = 0; i < ctx->n_buffers; i++) {
3719
4338
  [ctx->buffers[i].metal release];
3720
4339
  }
4340
+
4341
+ wsp_ggml_backend_metal_buffer_rset_free(ctx);
3721
4342
  wsp_ggml_backend_metal_device_rel(buffer->buft->device->context);
3722
4343
 
3723
4344
  if (ctx->owned) {
@@ -3740,19 +4361,19 @@ static void * wsp_ggml_backend_metal_buffer_get_base(wsp_ggml_backend_buffer_t b
3740
4361
  static void wsp_ggml_backend_metal_buffer_memset_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
3741
4362
  memset((char *)tensor->data + offset, value, size);
3742
4363
 
3743
- UNUSED(buffer);
4364
+ WSP_GGML_UNUSED(buffer);
3744
4365
  }
3745
4366
 
3746
4367
  static void wsp_ggml_backend_metal_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
3747
4368
  memcpy((char *)tensor->data + offset, data, size);
3748
4369
 
3749
- UNUSED(buffer);
4370
+ WSP_GGML_UNUSED(buffer);
3750
4371
  }
3751
4372
 
3752
4373
  static void wsp_ggml_backend_metal_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
3753
4374
  memcpy(data, (const char *)tensor->data + offset, size);
3754
4375
 
3755
- UNUSED(buffer);
4376
+ WSP_GGML_UNUSED(buffer);
3756
4377
  }
3757
4378
 
3758
4379
  static bool wsp_ggml_backend_metal_buffer_cpy_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
@@ -3762,7 +4383,7 @@ static bool wsp_ggml_backend_metal_buffer_cpy_tensor(wsp_ggml_backend_buffer_t b
3762
4383
  }
3763
4384
  return false;
3764
4385
 
3765
- UNUSED(buffer);
4386
+ WSP_GGML_UNUSED(buffer);
3766
4387
  }
3767
4388
 
3768
4389
  static void wsp_ggml_backend_metal_buffer_clear(wsp_ggml_backend_buffer_t buffer, uint8_t value) {
@@ -3788,7 +4409,7 @@ static struct wsp_ggml_backend_buffer_i wsp_ggml_backend_metal_buffer_i = {
3788
4409
  static const char * wsp_ggml_backend_metal_buffer_type_get_name(wsp_ggml_backend_buffer_type_t buft) {
3789
4410
  return "Metal";
3790
4411
 
3791
- UNUSED(buft);
4412
+ WSP_GGML_UNUSED(buft);
3792
4413
  }
3793
4414
 
3794
4415
  static void wsp_ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
@@ -3812,8 +4433,8 @@ static void wsp_ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size
3812
4433
  }
3813
4434
  #endif
3814
4435
  #endif
3815
- UNUSED(device);
3816
- UNUSED(size_aligned);
4436
+ WSP_GGML_UNUSED(device);
4437
+ WSP_GGML_UNUSED(size_aligned);
3817
4438
  }
3818
4439
 
3819
4440
  static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
@@ -3826,7 +4447,8 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer
3826
4447
  size_aligned += (size_page - (size_aligned % size_page));
3827
4448
  }
3828
4449
 
3829
- id<MTLDevice> device = wsp_ggml_backend_metal_device_acq(buft->device->context);
4450
+ struct wsp_ggml_backend_metal_device_context * ctx_dev = (struct wsp_ggml_backend_metal_device_context *)buft->device->context;
4451
+ id<MTLDevice> device = wsp_ggml_backend_metal_device_acq(ctx_dev);
3830
4452
 
3831
4453
  ctx->all_data = wsp_ggml_metal_host_malloc(size_aligned);
3832
4454
  ctx->all_size = size_aligned;
@@ -3849,7 +4471,14 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer
3849
4471
  if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
3850
4472
  WSP_GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
3851
4473
  free(ctx);
3852
- wsp_ggml_backend_metal_device_rel(buft->device->context);
4474
+ wsp_ggml_backend_metal_device_rel(ctx_dev);
4475
+ return NULL;
4476
+ }
4477
+
4478
+ if (!wsp_ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
4479
+ WSP_GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
4480
+ free(ctx);
4481
+ wsp_ggml_backend_metal_device_rel(ctx_dev);
3853
4482
  return NULL;
3854
4483
  }
3855
4484
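Buffer allocation can now fail after the host memory was obtained, if the residency set cannot be initialized; in that case the context is freed, the device is released, and NULL is returned. A hedged caller-side sketch, assuming wsp_ggml_backend_buft_alloc_buffer() (the generic ggml-backend allocator) is the entry point that routes into wsp_ggml_backend_metal_buffer_type_alloc_buffer():

    #include <stdio.h>   // plus the package's ggml / ggml-backend headers

    // Hedged usage sketch of the new NULL failure path; helper name is illustrative.
    static wsp_ggml_backend_buffer_t alloc_metal_buffer_checked(size_t size) {
        wsp_ggml_backend_buffer_type_t buft = wsp_ggml_backend_metal_buffer_type();
        wsp_ggml_backend_buffer_t      buf  = wsp_ggml_backend_buft_alloc_buffer(buft, size);
        if (buf == NULL) {
            // either the Metal allocation or the residency-set initialization failed
            fprintf(stderr, "metal buffer allocation failed (size=%zu)\n", size);
        }
        return buf;
    }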
 
@@ -3860,7 +4489,7 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer
3860
4489
 
3861
4490
  static size_t wsp_ggml_backend_metal_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
3862
4491
  return 32;
3863
- UNUSED(buft);
4492
+ WSP_GGML_UNUSED(buft);
3864
4493
  }
3865
4494
 
3866
4495
  static size_t wsp_ggml_backend_metal_buffer_type_get_max_size(wsp_ggml_backend_buffer_type_t buft) {
@@ -3870,13 +4499,13 @@ static size_t wsp_ggml_backend_metal_buffer_type_get_max_size(wsp_ggml_backend_b
3870
4499
 
3871
4500
  return max_size;
3872
4501
 
3873
- UNUSED(buft);
4502
+ WSP_GGML_UNUSED(buft);
3874
4503
  }
3875
4504
 
3876
4505
  static bool wsp_ggml_backend_metal_buffer_type_is_host(wsp_ggml_backend_buffer_type_t buft) {
3877
4506
  return true;
3878
4507
 
3879
- UNUSED(buft);
4508
+ WSP_GGML_UNUSED(buft);
3880
4509
  }
3881
4510
 
3882
4511
  wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void) {
@@ -3899,7 +4528,7 @@ wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void) {
3899
4528
  static const char * wsp_ggml_backend_metal_buffer_from_ptr_type_get_name(wsp_ggml_backend_buffer_type_t buft) {
3900
4529
  return "Metal_Mapped";
3901
4530
 
3902
- UNUSED(buft);
4531
+ WSP_GGML_UNUSED(buft);
3903
4532
  }
3904
4533
 
3905
4534
  static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_from_ptr_type(void) {
@@ -3942,7 +4571,8 @@ wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, si
3942
4571
  size_aligned += (size_page - (size_aligned % size_page));
3943
4572
  }
3944
4573
 
3945
- id<MTLDevice> device = wsp_ggml_backend_metal_device_acq(&g_wsp_ggml_ctx_dev_main);
4574
+ struct wsp_ggml_backend_metal_device_context * ctx_dev = &g_wsp_ggml_ctx_dev_main;
4575
+ id<MTLDevice> device = wsp_ggml_backend_metal_device_acq(ctx_dev);
3946
4576
 
3947
4577
  // the buffer fits into the max buffer size allowed by the device
3948
4578
  if (size_aligned <= device.maxBufferLength) {
@@ -3995,6 +4625,13 @@ wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, si
3995
4625
  }
3996
4626
  }
3997
4627
 
4628
+ if (!wsp_ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
4629
+ WSP_GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
4630
+ free(ctx);
4631
+ wsp_ggml_backend_metal_device_rel(ctx_dev);
4632
+ return NULL;
4633
+ }
4634
+
3998
4635
  return wsp_ggml_backend_buffer_init(wsp_ggml_backend_metal_buffer_from_ptr_type(), wsp_ggml_backend_metal_buffer_i, ctx, size);
3999
4636
  }
4000
4637
 
@@ -4003,7 +4640,7 @@ wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, si
4003
4640
  static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) {
4004
4641
  return "Metal";
4005
4642
 
4006
- UNUSED(backend);
4643
+ WSP_GGML_UNUSED(backend);
4007
4644
  }
4008
4645
 
4009
4646
  static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) {
@@ -4308,6 +4945,13 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_device_buffer_from_ptr(w
4308
4945
  }
4309
4946
  }
4310
4947
 
4948
+ if (!wsp_ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
4949
+ WSP_GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
4950
+ free(ctx);
4951
+ wsp_ggml_backend_metal_device_rel(ctx_dev);
4952
+ return NULL;
4953
+ }
4954
+
4311
4955
  return wsp_ggml_backend_buffer_init(wsp_ggml_backend_metal_buffer_from_ptr_type(), wsp_ggml_backend_metal_buffer_i, ctx, size);
4312
4956
  }
4313
4957
 
@@ -4321,7 +4965,7 @@ static bool wsp_ggml_backend_metal_device_supports_buft(wsp_ggml_backend_dev_t d
4321
4965
  return buft->iface.get_name == wsp_ggml_backend_metal_buffer_type_get_name ||
4322
4966
  buft->iface.get_name == wsp_ggml_backend_metal_buffer_from_ptr_type_get_name;
4323
4967
 
4324
- UNUSED(dev);
4968
+ WSP_GGML_UNUSED(dev);
4325
4969
  }
4326
4970
 
4327
4971
  static bool wsp_ggml_backend_metal_device_offload_op(wsp_ggml_backend_dev_t dev, const struct wsp_ggml_tensor * op) {
@@ -4372,19 +5016,45 @@ static wsp_ggml_backend_dev_t wsp_ggml_backend_metal_reg_device_get(wsp_ggml_bac
4372
5016
  WSP_GGML_UNUSED(index);
4373
5017
  }
4374
5018
 
5019
+ static struct wsp_ggml_backend_feature g_wsp_ggml_backend_metal_features[] = {
5020
+ #if defined(WSP_GGML_METAL_EMBED_LIBRARY)
5021
+ { "EMBED_LIBRARY", "1" },
5022
+ #endif
5023
+ #if defined(WSP_GGML_METAL_USE_BF16)
5024
+ { "BF16", "1" },
5025
+ #endif
5026
+ { nil, nil },
5027
+ };
5028
+
5029
+ static struct wsp_ggml_backend_feature * wsp_ggml_backend_metal_get_features(wsp_ggml_backend_reg_t reg) {
5030
+ return g_wsp_ggml_backend_metal_features;
5031
+
5032
+ WSP_GGML_UNUSED(reg);
5033
+ }
5034
+
5035
+ static void * wsp_ggml_backend_metal_get_proc_address(wsp_ggml_backend_reg_t reg, const char * name) {
5036
+ if (strcmp(name, "wsp_ggml_backend_get_features") == 0) {
5037
+ return (void *)wsp_ggml_backend_metal_get_features;
5038
+ }
5039
+
5040
+ return NULL;
5041
+
5042
+ WSP_GGML_UNUSED(reg);
5043
+ }
4375
5044
  static struct wsp_ggml_backend_reg_i wsp_ggml_backend_metal_reg_i = {
4376
5045
  /* .get_name = */ wsp_ggml_backend_metal_reg_get_name,
4377
5046
  /* .device_count = */ wsp_ggml_backend_metal_reg_device_count,
4378
5047
  /* .device_get = */ wsp_ggml_backend_metal_reg_device_get,
4379
- /* .get_proc_address = */ NULL,
5048
+ /* .get_proc_address = */ wsp_ggml_backend_metal_get_proc_address,
4380
5049
  };
4381
5050
 
4382
5051
  wsp_ggml_backend_reg_t wsp_ggml_backend_metal_reg(void) {
4383
5052
  // TODO: make this thread-safe somehow?
4384
5053
  {
4385
5054
  g_wsp_ggml_backend_metal_reg = (struct wsp_ggml_backend_reg) {
4386
- /* .iface = */ wsp_ggml_backend_metal_reg_i,
4387
- /* .context = */ NULL,
5055
+ /* .api_version = */ WSP_GGML_BACKEND_API_VERSION,
5056
+ /* .iface = */ wsp_ggml_backend_metal_reg_i,
5057
+ /* .context = */ NULL,
4388
5058
  };
4389
5059
 
4390
5060
  g_wsp_ggml_backend_metal_device = (struct wsp_ggml_backend_device) {
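The registry above now exposes a feature table through get_proc_address under the name "wsp_ggml_backend_get_features". A hedged sketch of how a caller could query it, assuming wsp_ggml_backend_reg_get_proc_address() and the name/value fields of struct wsp_ggml_backend_feature mirror upstream ggml's registry API; the typedef is illustrative:

    #include <stdio.h>   // plus the package's ggml-backend header

    // Hedged sketch: querying the Metal backend's feature list via the new hook.
    typedef struct wsp_ggml_backend_feature * (*wsp_get_features_fn)(wsp_ggml_backend_reg_t);

    static void print_metal_features(void) {
        wsp_ggml_backend_reg_t reg = wsp_ggml_backend_metal_reg();

        wsp_get_features_fn get_features = (wsp_get_features_fn)
            wsp_ggml_backend_reg_get_proc_address(reg, "wsp_ggml_backend_get_features");
        if (get_features == NULL) {
            return; // registries without the hook return NULL
        }

        // the table above is terminated by a { nil, nil } entry
        for (struct wsp_ggml_backend_feature * f = get_features(reg); f->name != NULL; f++) {
            printf("%s = %s\n", f->name, f->value);
        }
    }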
@@ -4396,3 +5066,5 @@ wsp_ggml_backend_reg_t wsp_ggml_backend_metal_reg(void) {
4396
5066
 
4397
5067
  return &g_wsp_ggml_backend_metal_reg;
4398
5068
  }
5069
+
5070
+ WSP_GGML_BACKEND_DL_IMPL(wsp_ggml_backend_metal_reg)