llama_cpp 0.7.0 → 0.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -81,18 +81,18 @@ struct ggml_metal_context {
81
81
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
82
82
  GGML_METAL_DECL_KERNEL(rms_norm);
83
83
  GGML_METAL_DECL_KERNEL(norm);
84
- GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
85
- GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
86
- GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
87
- GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
88
- GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
89
- GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
90
- GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
91
- GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
92
- GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
93
- GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
94
- GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
95
- GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
84
+ GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
85
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
86
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
87
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
88
+ GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
89
+ GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
90
+ GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
91
+ GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
92
+ GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
93
+ GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
94
+ GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
95
+ GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
96
96
  GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
97
97
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
98
98
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -109,6 +109,8 @@ struct ggml_metal_context {
109
109
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
110
110
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
111
111
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
112
+ GGML_METAL_DECL_KERNEL(concat);
113
+ GGML_METAL_DECL_KERNEL(sqr);
112
114
 
113
115
  #undef GGML_METAL_DECL_KERNEL
114
116
  };
@@ -183,56 +185,44 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
183
185
 
184
186
  ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
185
187
 
186
- #ifdef GGML_SWIFT
187
- // load the default.metallib file
188
+ // load library
188
189
  {
189
- NSError * error = nil;
190
-
191
- NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
192
- NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
193
- NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
194
- NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
195
- NSURL * libURL = [NSURL fileURLWithPath:libPath];
196
-
197
- // Load the metallib file into a Metal library
198
- ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
199
-
200
- if (error) {
201
- GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
202
- return NULL;
203
- }
204
- }
190
+ NSBundle * bundle = nil;
191
+ #ifdef SWIFT_PACKAGE
192
+ bundle = SWIFTPM_MODULE_BUNDLE;
205
193
  #else
206
- UNUSED(msl_library_source);
207
-
208
- // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource
209
- {
194
+ bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
195
+ #endif
210
196
  NSError * error = nil;
197
+ NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
198
+ if (libPath != nil) {
199
+ NSURL * libURL = [NSURL fileURLWithPath:libPath];
200
+ GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
201
+ ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
202
+ } else {
203
+ GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
204
+
205
+ NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
206
+ GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [sourcePath UTF8String]);
207
+ NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error];
208
+ if (error) {
209
+ GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
210
+ return NULL;
211
+ }
211
212
 
212
- //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
213
- NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
214
- NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
215
- GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path UTF8String]);
216
-
217
- NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
218
- if (error) {
219
- GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
220
- return NULL;
221
- }
222
-
213
+ MTLCompileOptions* options = nil;
223
214
  #ifdef GGML_QKK_64
224
- MTLCompileOptions* options = [MTLCompileOptions new];
225
- options.preprocessorMacros = @{ @"QK_K" : @(64) };
226
- ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
227
- #else
228
- ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
215
+ options = [MTLCompileOptions new];
216
+ options.preprocessorMacros = @{ @"QK_K" : @(64) };
229
217
  #endif
218
+ ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
219
+ }
220
+
230
221
  if (error) {
231
222
  GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
232
223
  return NULL;
233
224
  }
234
225
  }
235
- #endif
236
226
 
237
227
  // load kernels
238
228
  {
@@ -272,40 +262,57 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
272
262
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
273
263
  GGML_METAL_ADD_KERNEL(rms_norm);
274
264
  GGML_METAL_ADD_KERNEL(norm);
275
- GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
276
- GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
277
- GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
278
- GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
279
- GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
280
- GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
281
- GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
282
- GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
283
- GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
284
- GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
285
- GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
286
- GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
287
- GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
288
- GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
289
- GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
290
- GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
291
- GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
292
- GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
293
- GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
294
- GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
295
- GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
296
- GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
265
+ GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
266
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
267
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
268
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
269
+ GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
270
+ GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
271
+ GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
272
+ GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
273
+ GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
274
+ GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
275
+ GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
276
+ GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
277
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
278
+ GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
279
+ GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
280
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
281
+ GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
282
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
283
+ GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
284
+ GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
285
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
286
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
287
+ GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
288
+ }
297
289
  GGML_METAL_ADD_KERNEL(rope_f32);
298
290
  GGML_METAL_ADD_KERNEL(rope_f16);
299
291
  GGML_METAL_ADD_KERNEL(alibi_f32);
300
292
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
301
293
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
302
294
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
295
+ GGML_METAL_ADD_KERNEL(concat);
296
+ GGML_METAL_ADD_KERNEL(sqr);
303
297
 
304
298
  #undef GGML_METAL_ADD_KERNEL
305
299
  }
306
300
 
307
- GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
308
301
  #if TARGET_OS_OSX
302
+ // print MTL GPU family:
303
+ GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
304
+
305
+ // determine max supported GPU family
306
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
307
+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
308
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
309
+ if ([ctx->device supportsFamily:i]) {
310
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
311
+ break;
312
+ }
313
+ }
314
+
315
+ GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
309
316
  GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
310
317
  if (ctx->device.maxTransferRate != 0) {
311
318
  GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
@@ -347,34 +354,38 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
347
354
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
348
355
  GGML_METAL_DEL_KERNEL(rms_norm);
349
356
  GGML_METAL_DEL_KERNEL(norm);
350
- GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
351
- GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
352
- GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
353
- GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
354
- GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
355
- GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
356
- GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
357
- GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
358
- GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
359
- GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
360
- GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
361
- GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
362
- GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
363
- GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
364
- GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
365
- GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
366
- GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
367
- GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
368
- GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
369
- GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
370
- GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
371
- GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
357
+ GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
358
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
359
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
360
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
361
+ GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
362
+ GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
363
+ GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
364
+ GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
365
+ GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
366
+ GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
367
+ GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
368
+ GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
369
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
370
+ GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
371
+ GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
372
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
373
+ GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
374
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
375
+ GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
376
+ GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
377
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
378
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
379
+ GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
380
+ }
372
381
  GGML_METAL_DEL_KERNEL(rope_f32);
373
382
  GGML_METAL_DEL_KERNEL(rope_f16);
374
383
  GGML_METAL_DEL_KERNEL(alibi_f32);
375
384
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
376
385
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
377
386
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
387
+ GGML_METAL_DEL_KERNEL(concat);
388
+ GGML_METAL_DEL_KERNEL(sqr);
378
389
 
379
390
  #undef GGML_METAL_DEL_KERNEL
380
391
 
@@ -431,7 +442,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
431
442
  for (int i = 0; i < ctx->n_buffers; ++i) {
432
443
  const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
433
444
 
434
- //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
445
+ //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
435
446
  if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
436
447
  *offs = (size_t) ioffs;
437
448
 
@@ -766,6 +777,44 @@ void ggml_metal_graph_compute(
766
777
  {
767
778
  // noop
768
779
  } break;
780
+ case GGML_OP_CONCAT:
781
+ {
782
+ const int64_t nb = ne00;
783
+
784
+ [encoder setComputePipelineState:ctx->pipeline_concat];
785
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
786
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
787
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
788
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
789
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
790
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
791
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
792
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
793
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
794
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
795
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
796
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
797
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
798
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
799
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
800
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
801
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
802
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
803
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
804
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
805
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
806
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
807
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
808
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
809
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
810
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
811
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
812
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
813
+
814
+ const int nth = MIN(1024, ne0);
815
+
816
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
817
+ } break;
769
818
  case GGML_OP_ADD:
770
819
  {
771
820
  GGML_ASSERT(ggml_is_contiguous(src0));
@@ -861,9 +910,10 @@ void ggml_metal_graph_compute(
861
910
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
862
911
  [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
863
912
 
864
- const int64_t n = ggml_nelements(dst)/4;
913
+ const int64_t n = ggml_nelements(dst);
914
+ GGML_ASSERT(n % 4 == 0);
865
915
 
866
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
916
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
867
917
  } break;
868
918
  case GGML_OP_UNARY:
869
919
  switch (ggml_get_unary_op(gf->nodes[i])) {
@@ -873,9 +923,10 @@ void ggml_metal_graph_compute(
873
923
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
874
924
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
875
925
 
876
- const int64_t n = ggml_nelements(dst)/4;
926
+ const int64_t n = ggml_nelements(dst);
927
+ GGML_ASSERT(n % 4 == 0);
877
928
 
878
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
929
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
879
930
  } break;
880
931
  case GGML_UNARY_OP_RELU:
881
932
  {
@@ -893,9 +944,10 @@ void ggml_metal_graph_compute(
893
944
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
894
945
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
895
946
 
896
- const int64_t n = ggml_nelements(dst)/4;
947
+ const int64_t n = ggml_nelements(dst);
948
+ GGML_ASSERT(n % 4 == 0);
897
949
 
898
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
950
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
899
951
  } break;
900
952
  default:
901
953
  {
@@ -903,6 +955,17 @@ void ggml_metal_graph_compute(
903
955
  GGML_ASSERT(false);
904
956
  }
905
957
  } break;
958
+ case GGML_OP_SQR:
959
+ {
960
+ GGML_ASSERT(ggml_is_contiguous(src0));
961
+
962
+ [encoder setComputePipelineState:ctx->pipeline_sqr];
963
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
964
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
965
+
966
+ const int64_t n = ggml_nelements(dst);
967
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
968
+ } break;
906
969
  case GGML_OP_SOFT_MAX:
907
970
  {
908
971
  const int nth = MIN(32, ne00);
@@ -944,21 +1007,46 @@ void ggml_metal_graph_compute(
944
1007
  } break;
945
1008
  case GGML_OP_MUL_MAT:
946
1009
  {
947
- // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
948
-
949
1010
  GGML_ASSERT(ne00 == ne10);
950
- // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
951
- uint gqa = ne12/ne02;
952
1011
  GGML_ASSERT(ne03 == ne13);
953
1012
 
1013
+ const uint gqa = ne12/ne02;
1014
+
1015
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1016
+ // to the matrix-vector kernel
1017
+ int ne11_mm_min = 1;
1018
+
1019
+ #if 0
1020
+ // the numbers below are measured on M2 Ultra for 7B and 13B models
1021
+ // these numbers do not translate to other devices or model sizes
1022
+ // TODO: need to find a better approach
1023
+ if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
1024
+ switch (src0t) {
1025
+ case GGML_TYPE_F16: ne11_mm_min = 2; break;
1026
+ case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1027
+ case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1028
+ case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1029
+ case GGML_TYPE_Q4_0:
1030
+ case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1031
+ case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1032
+ case GGML_TYPE_Q5_0: // not tested yet
1033
+ case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1034
+ case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1035
+ case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1036
+ default: ne11_mm_min = 1; break;
1037
+ }
1038
+ }
1039
+ #endif
1040
+
954
1041
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
955
1042
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
956
- if (!ggml_is_transposed(src0) &&
1043
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1044
+ !ggml_is_transposed(src0) &&
957
1045
  !ggml_is_transposed(src1) &&
958
1046
  src1t == GGML_TYPE_F32 &&
959
- [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
960
- ne00%32 == 0 &&
961
- ne11 > 2) {
1047
+ ne00 % 32 == 0 && ne00 >= 64 &&
1048
+ ne11 > ne11_mm_min) {
1049
+ //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
962
1050
  switch (src0->type) {
963
1051
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
964
1052
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
@@ -987,17 +1075,18 @@ void ggml_metal_graph_compute(
987
1075
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
988
1076
  [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
989
1077
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
990
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1078
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
991
1079
  } else {
992
1080
  int nth0 = 32;
993
1081
  int nth1 = 1;
994
1082
  int nrows = 1;
1083
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
995
1084
 
996
1085
  // use custom matrix x vector kernel
997
1086
  switch (src0t) {
998
1087
  case GGML_TYPE_F32:
999
1088
  {
1000
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
1089
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
1001
1090
  nrows = 4;
1002
1091
  } break;
1003
1092
  case GGML_TYPE_F16:
@@ -1005,12 +1094,12 @@ void ggml_metal_graph_compute(
1005
1094
  nth0 = 32;
1006
1095
  nth1 = 1;
1007
1096
  if (ne11 * ne12 < 4) {
1008
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
1097
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
1009
1098
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1010
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
1099
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
1011
1100
  nrows = ne11;
1012
1101
  } else {
1013
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
1102
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
1014
1103
  nrows = 4;
1015
1104
  }
1016
1105
  } break;
@@ -1021,7 +1110,7 @@ void ggml_metal_graph_compute(
1021
1110
 
1022
1111
  nth0 = 8;
1023
1112
  nth1 = 8;
1024
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
1113
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
1025
1114
  } break;
1026
1115
  case GGML_TYPE_Q4_1:
1027
1116
  {
@@ -1030,7 +1119,7 @@ void ggml_metal_graph_compute(
1030
1119
 
1031
1120
  nth0 = 8;
1032
1121
  nth1 = 8;
1033
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
1122
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
1034
1123
  } break;
1035
1124
  case GGML_TYPE_Q8_0:
1036
1125
  {
@@ -1039,7 +1128,7 @@ void ggml_metal_graph_compute(
1039
1128
 
1040
1129
  nth0 = 8;
1041
1130
  nth1 = 8;
1042
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
1131
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
1043
1132
  } break;
1044
1133
  case GGML_TYPE_Q2_K:
1045
1134
  {
@@ -1048,7 +1137,7 @@ void ggml_metal_graph_compute(
1048
1137
 
1049
1138
  nth0 = 2;
1050
1139
  nth1 = 32;
1051
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
1140
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
1052
1141
  } break;
1053
1142
  case GGML_TYPE_Q3_K:
1054
1143
  {
@@ -1057,7 +1146,7 @@ void ggml_metal_graph_compute(
1057
1146
 
1058
1147
  nth0 = 2;
1059
1148
  nth1 = 32;
1060
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
1149
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
1061
1150
  } break;
1062
1151
  case GGML_TYPE_Q4_K:
1063
1152
  {
@@ -1066,7 +1155,7 @@ void ggml_metal_graph_compute(
1066
1155
 
1067
1156
  nth0 = 4; //1;
1068
1157
  nth1 = 8; //32;
1069
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
1158
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
1070
1159
  } break;
1071
1160
  case GGML_TYPE_Q5_K:
1072
1161
  {
@@ -1075,7 +1164,7 @@ void ggml_metal_graph_compute(
1075
1164
 
1076
1165
  nth0 = 2;
1077
1166
  nth1 = 32;
1078
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
1167
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
1079
1168
  } break;
1080
1169
  case GGML_TYPE_Q6_K:
1081
1170
  {
@@ -1084,7 +1173,7 @@ void ggml_metal_graph_compute(
1084
1173
 
1085
1174
  nth0 = 2;
1086
1175
  nth1 = 32;
1087
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
1176
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
1088
1177
  } break;
1089
1178
  default:
1090
1179
  {
@@ -1113,7 +1202,7 @@ void ggml_metal_graph_compute(
1113
1202
  [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
1114
1203
 
1115
1204
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
1116
- src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
1205
+ src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
1117
1206
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1118
1207
  }
1119
1208
  else if (src0t == GGML_TYPE_Q4_K) {
@@ -1166,6 +1255,8 @@ void ggml_metal_graph_compute(
1166
1255
  } break;
1167
1256
  case GGML_OP_RMS_NORM:
1168
1257
  {
1258
+ GGML_ASSERT(ne00 % 4 == 0);
1259
+
1169
1260
  float eps;
1170
1261
  memcpy(&eps, dst->op_params, sizeof(float));
1171
1262
 
@@ -1208,7 +1299,7 @@ void ggml_metal_graph_compute(
1208
1299
 
1209
1300
  const int nth = MIN(1024, ne00);
1210
1301
 
1211
- const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
1302
+ //const int n_past = ((int32_t *) dst->op_params)[0];
1212
1303
  const int n_head = ((int32_t *) dst->op_params)[1];
1213
1304
  float max_bias;
1214
1305
  memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
@@ -1371,3 +1462,140 @@ void ggml_metal_graph_compute(
1371
1462
 
1372
1463
  }
1373
1464
  }
1465
+
1466
+ ////////////////////////////////////////////////////////////////////////////////
1467
+
1468
+ // backend interface
1469
+
1470
// get_name callback of the Metal backend interface.
// Besides providing a display name, the address of this function is what
// ggml_backend_is_metal compares against to recognise Metal backends.
static const char * ggml_backend_metal_name(ggml_backend_t backend) {
    UNUSED(backend);

    const char * name = "Metal";
    return name;
}
1475
+
1476
// free callback of the Metal backend interface.
// Releases the Metal context stored in backend->context via ggml_metal_free,
// then releases the backend struct itself (allocated with malloc in
// ggml_backend_metal_init).
static void ggml_backend_metal_free(ggml_backend_t backend) {
    ggml_metal_free((struct ggml_metal_context *) backend->context);
    free(backend);
}
1481
+
1482
// get_base callback for Metal backend buffers: the base address is whatever
// pointer is stored as the buffer's context (presumably the host allocation
// passed to ggml_backend_buffer_init by ggml_backend_metal_alloc_buffer —
// TODO confirm against ggml_backend_buffer_init).
static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
    void * base = (void *) buffer->context;
    return base;
}
1485
+
1486
// free_buffer callback for Metal backend buffers: releases the memory stored
// as the buffer's context with free().
// The parameter is used, so the former trailing UNUSED(buffer) was dead code
// and has been dropped.
// NOTE(review): the memory was obtained from ggml_metal_host_malloc; this
// assumes that allocator is free()-compatible — TODO confirm.
// NOTE(review): ggml_backend_metal_alloc_buffer also registers this memory with
// the Metal context through ggml_metal_add_buffer, and nothing here removes
// that registration, so the context may keep an entry referring to freed
// memory — TODO confirm and unregister if so.
static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    free(buffer->context);
}
1490
+
1491
// Buffer vtable used for every buffer created by ggml_backend_metal_alloc_buffer.
// Only free_buffer and get_base are implemented; the remaining slots are NULL.
static struct ggml_backend_buffer_i metal_backend_buffer_i = {
    ggml_backend_metal_buffer_free_buffer, // .free_buffer
    ggml_backend_metal_buffer_get_base,    // .get_base
    NULL,                                  // .get_alloc_size -> defaults to ggml_nbytes
    NULL,                                  // .init_tensor    -> no initialization required
    NULL,                                  // .free_tensor    -> no cleanup required
};
1498
+
1499
// alloc_buffer callback of the Metal backend interface.
// Allocates `size` bytes with ggml_metal_host_malloc, registers that memory
// with the backend's Metal context under the name "backend", and wraps it in a
// ggml_backend_buffer driven by metal_backend_buffer_i.
// NOTE(review): neither a failed host allocation nor the outcome of
// ggml_metal_add_buffer is checked here — TODO confirm failure handling.
static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) {
    struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *) backend->context;
    void * host_data = ggml_metal_host_malloc(size);

    // TODO: set proper name of the buffers
    ggml_metal_add_buffer(metal_ctx, "backend", host_data, size, 0);

    ggml_backend_buffer_t buffer = ggml_backend_buffer_init(backend, metal_backend_buffer_i, host_data, size);
    return buffer;
}
1509
+
1510
// get_alignment callback of the Metal backend interface; always reports 32,
// independent of the backend instance.
static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) {
    UNUSED(backend);

    const size_t alignment = 32;
    return alignment;
}
1514
+
1515
// set_tensor_async callback: writes `size` bytes from `data` into `tensor`,
// starting `offset` bytes into its storage.
// Despite the name, the copy is a plain memcpy into tensor->data and has
// finished by the time this function returns.
static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    UNUSED(backend);

    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");

    char * dst = (char *) tensor->data + offset;
    memcpy(dst, data, size);
}
1523
+
1524
// get_tensor_async callback: reads `size` bytes from `tensor`, starting
// `offset` bytes into its storage, into the caller-provided `data`.
// Like the setter, this is a synchronous memcpy from tensor->data.
static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    UNUSED(backend);

    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");

    const char * src = (const char *) tensor->data + offset;
    memcpy(data, src, size);
}
1532
+
1533
// synchronize callback: intentionally a no-op. The tensor get/set callbacks
// of this backend copy synchronously with memcpy.
// NOTE(review): presumably ggml_metal_graph_compute also blocks until the GPU
// work completes — TODO confirm.
static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
    UNUSED(backend);
}
1536
+
1537
// cpy_tensor_from callback: fills `dst` with the contents of `src`, using the
// generic ggml_backend_tensor_get to read ggml_nbytes(src) bytes from `src`
// directly into dst->data.
static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
    UNUSED(backend);

    const size_t nbytes = ggml_nbytes(src);
    ggml_backend_tensor_get(src, dst->data, 0, nbytes);
}
1542
+
1543
// cpy_tensor_to callback: pushes the contents of `src` into `dst`, using the
// generic ggml_backend_tensor_set_async to write ggml_nbytes(src) bytes taken
// from src->data.
static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
    UNUSED(backend);

    const size_t nbytes = ggml_nbytes(src);
    ggml_backend_tensor_set_async(dst, src->data, 0, nbytes);
}
1548
+
1549
// graph_compute callback: runs `cgraph` on the Metal context owned by this
// backend by delegating to ggml_metal_graph_compute.
static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
    ggml_metal_graph_compute((struct ggml_metal_context *) backend->context, cgraph);
}
1554
+
1555
// supports_op callback: currently claims support for every op without
// inspecting `op`.
// NOTE(review): ggml_metal_graph_compute reaches GGML_ASSERT(false) for some
// unhandled cases (e.g. the default branch of the unary-op switch), so this
// blanket answer may be optimistic — TODO confirm.
static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
    UNUSED(backend);
    UNUSED(op);

    const bool supported = true;
    return supported;
}
1560
+
1561
// Backend vtable installed by ggml_backend_metal_init. The graph-plan slots are
// left NULL because the Metal implementation does not require creating graph
// plans at the moment; graphs go straight through graph_compute.
static struct ggml_backend_i metal_backend_i = {
    ggml_backend_metal_name,              // .get_name
    ggml_backend_metal_free,              // .free
    ggml_backend_metal_alloc_buffer,      // .alloc_buffer
    ggml_backend_metal_get_alignment,     // .get_alignment
    ggml_backend_metal_set_tensor_async,  // .set_tensor_async
    ggml_backend_metal_get_tensor_async,  // .get_tensor_async
    ggml_backend_metal_synchronize,       // .synchronize
    ggml_backend_metal_cpy_tensor_from,   // .cpy_tensor_from
    ggml_backend_metal_cpy_tensor_to,     // .cpy_tensor_to
    NULL,                                 // .graph_plan_create
    NULL,                                 // .graph_plan_free
    NULL,                                 // .graph_plan_compute
    ggml_backend_metal_graph_compute,     // .graph_compute
    ggml_backend_metal_supports_op,       // .supports_op
};
1577
+
1578
// Creates a ggml backend that executes graphs on Metal.
//
// The Metal context comes from ggml_metal_init(GGML_DEFAULT_N_THREADS) and is
// stored as the backend context; the backend uses metal_backend_i as its vtable.
//
// Returns NULL if ggml_metal_init fails (it returns NULL on library/kernel load
// errors) or if the backend struct cannot be allocated. Release a successful
// result through the backend's free callback (ggml_backend_metal_free).
ggml_backend_t ggml_backend_metal_init(void) {
    // ggml_metal_init returns its own context. The previous code malloc'ed a
    // context here first and then overwrote the pointer, leaking that block.
    struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
    if (ctx == NULL) {
        return NULL;
    }

    ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
    if (metal_backend == NULL) {
        // don't leak the Metal context if the backend wrapper can't be created
        ggml_metal_free(ctx);
        return NULL;
    }

    *metal_backend = (struct ggml_backend) {
        /* .iface   = */ metal_backend_i,
        /* .context = */ ctx,
    };

    return metal_backend;
}
1592
+
1593
// Returns true when `backend` was created by ggml_backend_metal_init.
// Identification is by vtable: Metal backends install metal_backend_i, whose
// get_name slot is ggml_backend_metal_name.
bool ggml_backend_is_metal(ggml_backend_t backend) {
    const bool is_metal = (backend->iface.get_name == ggml_backend_metal_name);
    return is_metal;
}
1596
+
1597
// Forwards `n_cb` (presumably the number of command buffers — TODO confirm) to
// the Metal context owned by `backend` via ggml_metal_set_n_cb.
void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
    ggml_metal_set_n_cb((struct ggml_metal_context *) backend->context, n_cb);
}