llama_cpp 0.7.0 → 0.7.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -81,18 +81,18 @@ struct ggml_metal_context {
81
81
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
82
82
  GGML_METAL_DECL_KERNEL(rms_norm);
83
83
  GGML_METAL_DECL_KERNEL(norm);
84
- GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
85
- GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
86
- GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
87
- GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
88
- GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
89
- GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
90
- GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
91
- GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
92
- GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
93
- GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
94
- GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
95
- GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
84
+ GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
85
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
86
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
87
+ GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
88
+ GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
89
+ GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
90
+ GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
91
+ GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
92
+ GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
93
+ GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
94
+ GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
95
+ GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
96
96
  GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
97
97
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
98
98
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -109,6 +109,8 @@ struct ggml_metal_context {
109
109
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
110
110
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
111
111
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
112
+ GGML_METAL_DECL_KERNEL(concat);
113
+ GGML_METAL_DECL_KERNEL(sqr);
112
114
 
113
115
  #undef GGML_METAL_DECL_KERNEL
114
116
  };
@@ -183,56 +185,44 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
183
185
 
184
186
  ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
185
187
 
186
- #ifdef GGML_SWIFT
187
- // load the default.metallib file
188
+ // load library
188
189
  {
189
- NSError * error = nil;
190
-
191
- NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
192
- NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
193
- NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
194
- NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
195
- NSURL * libURL = [NSURL fileURLWithPath:libPath];
196
-
197
- // Load the metallib file into a Metal library
198
- ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
199
-
200
- if (error) {
201
- GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
202
- return NULL;
203
- }
204
- }
190
+ NSBundle * bundle = nil;
191
+ #ifdef SWIFT_PACKAGE
192
+ bundle = SWIFTPM_MODULE_BUNDLE;
205
193
  #else
206
- UNUSED(msl_library_source);
207
-
208
- // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource
209
- {
194
+ bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
195
+ #endif
210
196
  NSError * error = nil;
197
+ NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
198
+ if (libPath != nil) {
199
+ NSURL * libURL = [NSURL fileURLWithPath:libPath];
200
+ GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
201
+ ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
202
+ } else {
203
+ GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
204
+
205
+ NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
206
+ GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [sourcePath UTF8String]);
207
+ NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error];
208
+ if (error) {
209
+ GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
210
+ return NULL;
211
+ }
211
212
 
212
- //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
213
- NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
214
- NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
215
- GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path UTF8String]);
216
-
217
- NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
218
- if (error) {
219
- GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
220
- return NULL;
221
- }
222
-
213
+ MTLCompileOptions* options = nil;
223
214
  #ifdef GGML_QKK_64
224
- MTLCompileOptions* options = [MTLCompileOptions new];
225
- options.preprocessorMacros = @{ @"QK_K" : @(64) };
226
- ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
227
- #else
228
- ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
215
+ options = [MTLCompileOptions new];
216
+ options.preprocessorMacros = @{ @"QK_K" : @(64) };
229
217
  #endif
218
+ ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
219
+ }
220
+
230
221
  if (error) {
231
222
  GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
232
223
  return NULL;
233
224
  }
234
225
  }
235
- #endif
236
226
 
237
227
  // load kernels
238
228
  {
@@ -272,40 +262,57 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
272
262
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
273
263
  GGML_METAL_ADD_KERNEL(rms_norm);
274
264
  GGML_METAL_ADD_KERNEL(norm);
275
- GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
276
- GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
277
- GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
278
- GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
279
- GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
280
- GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
281
- GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
282
- GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
283
- GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
284
- GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
285
- GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
286
- GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
287
- GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
288
- GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
289
- GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
290
- GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
291
- GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
292
- GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
293
- GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
294
- GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
295
- GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
296
- GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
265
+ GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
266
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
267
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
268
+ GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
269
+ GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
270
+ GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
271
+ GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
272
+ GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
273
+ GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
274
+ GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
275
+ GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
276
+ GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
277
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
278
+ GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
279
+ GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
280
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
281
+ GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
282
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
283
+ GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
284
+ GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
285
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
286
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
287
+ GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
288
+ }
297
289
  GGML_METAL_ADD_KERNEL(rope_f32);
298
290
  GGML_METAL_ADD_KERNEL(rope_f16);
299
291
  GGML_METAL_ADD_KERNEL(alibi_f32);
300
292
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
301
293
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
302
294
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
295
+ GGML_METAL_ADD_KERNEL(concat);
296
+ GGML_METAL_ADD_KERNEL(sqr);
303
297
 
304
298
  #undef GGML_METAL_ADD_KERNEL
305
299
  }
306
300
 
307
- GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
308
301
  #if TARGET_OS_OSX
302
+ // print MTL GPU family:
303
+ GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
304
+
305
+ // determine max supported GPU family
306
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
307
+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
308
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
309
+ if ([ctx->device supportsFamily:i]) {
310
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
311
+ break;
312
+ }
313
+ }
314
+
315
+ GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
309
316
  GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
310
317
  if (ctx->device.maxTransferRate != 0) {
311
318
  GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
@@ -347,34 +354,38 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
347
354
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
348
355
  GGML_METAL_DEL_KERNEL(rms_norm);
349
356
  GGML_METAL_DEL_KERNEL(norm);
350
- GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
351
- GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
352
- GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
353
- GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
354
- GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
355
- GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
356
- GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
357
- GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
358
- GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
359
- GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
360
- GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
361
- GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
362
- GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
363
- GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
364
- GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
365
- GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
366
- GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
367
- GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
368
- GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
369
- GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
370
- GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
371
- GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
357
+ GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
358
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
359
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
360
+ GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
361
+ GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
362
+ GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
363
+ GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
364
+ GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
365
+ GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
366
+ GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
367
+ GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
368
+ GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
369
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
370
+ GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
371
+ GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
372
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
373
+ GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
374
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
375
+ GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
376
+ GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
377
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
378
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
379
+ GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
380
+ }
372
381
  GGML_METAL_DEL_KERNEL(rope_f32);
373
382
  GGML_METAL_DEL_KERNEL(rope_f16);
374
383
  GGML_METAL_DEL_KERNEL(alibi_f32);
375
384
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
376
385
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
377
386
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
387
+ GGML_METAL_DEL_KERNEL(concat);
388
+ GGML_METAL_DEL_KERNEL(sqr);
378
389
 
379
390
  #undef GGML_METAL_DEL_KERNEL
380
391
 
@@ -431,7 +442,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
431
442
  for (int i = 0; i < ctx->n_buffers; ++i) {
432
443
  const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
433
444
 
434
- //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
445
+ //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
435
446
  if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
436
447
  *offs = (size_t) ioffs;
437
448
 
@@ -766,6 +777,44 @@ void ggml_metal_graph_compute(
766
777
  {
767
778
  // noop
768
779
  } break;
780
+ case GGML_OP_CONCAT:
781
+ {
782
+ const int64_t nb = ne00;
783
+
784
+ [encoder setComputePipelineState:ctx->pipeline_concat];
785
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
786
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
787
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
788
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
789
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
790
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
791
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
792
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
793
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
794
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
795
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
796
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
797
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
798
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
799
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
800
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
801
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
802
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
803
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
804
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
805
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
806
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
807
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
808
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
809
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
810
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
811
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
812
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
813
+
814
+ const int nth = MIN(1024, ne0);
815
+
816
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
817
+ } break;
769
818
  case GGML_OP_ADD:
770
819
  {
771
820
  GGML_ASSERT(ggml_is_contiguous(src0));
@@ -861,9 +910,10 @@ void ggml_metal_graph_compute(
861
910
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
862
911
  [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
863
912
 
864
- const int64_t n = ggml_nelements(dst)/4;
913
+ const int64_t n = ggml_nelements(dst);
914
+ GGML_ASSERT(n % 4 == 0);
865
915
 
866
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
916
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
867
917
  } break;
868
918
  case GGML_OP_UNARY:
869
919
  switch (ggml_get_unary_op(gf->nodes[i])) {
@@ -873,9 +923,10 @@ void ggml_metal_graph_compute(
873
923
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
874
924
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
875
925
 
876
- const int64_t n = ggml_nelements(dst)/4;
926
+ const int64_t n = ggml_nelements(dst);
927
+ GGML_ASSERT(n % 4 == 0);
877
928
 
878
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
929
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
879
930
  } break;
880
931
  case GGML_UNARY_OP_RELU:
881
932
  {
@@ -893,9 +944,10 @@ void ggml_metal_graph_compute(
893
944
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
894
945
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
895
946
 
896
- const int64_t n = ggml_nelements(dst)/4;
947
+ const int64_t n = ggml_nelements(dst);
948
+ GGML_ASSERT(n % 4 == 0);
897
949
 
898
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
950
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
899
951
  } break;
900
952
  default:
901
953
  {
@@ -903,6 +955,17 @@ void ggml_metal_graph_compute(
903
955
  GGML_ASSERT(false);
904
956
  }
905
957
  } break;
958
+ case GGML_OP_SQR:
959
+ {
960
+ GGML_ASSERT(ggml_is_contiguous(src0));
961
+
962
+ [encoder setComputePipelineState:ctx->pipeline_sqr];
963
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
964
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
965
+
966
+ const int64_t n = ggml_nelements(dst);
967
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
968
+ } break;
906
969
  case GGML_OP_SOFT_MAX:
907
970
  {
908
971
  const int nth = MIN(32, ne00);
@@ -944,21 +1007,46 @@ void ggml_metal_graph_compute(
944
1007
  } break;
945
1008
  case GGML_OP_MUL_MAT:
946
1009
  {
947
- // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
948
-
949
1010
  GGML_ASSERT(ne00 == ne10);
950
- // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
951
- uint gqa = ne12/ne02;
952
1011
  GGML_ASSERT(ne03 == ne13);
953
1012
 
1013
+ const uint gqa = ne12/ne02;
1014
+
1015
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1016
+ // to the matrix-vector kernel
1017
+ int ne11_mm_min = 1;
1018
+
1019
+ #if 0
1020
+ // the numbers below are measured on M2 Ultra for 7B and 13B models
1021
+ // these numbers do not translate to other devices or model sizes
1022
+ // TODO: need to find a better approach
1023
+ if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
1024
+ switch (src0t) {
1025
+ case GGML_TYPE_F16: ne11_mm_min = 2; break;
1026
+ case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1027
+ case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1028
+ case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1029
+ case GGML_TYPE_Q4_0:
1030
+ case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1031
+ case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1032
+ case GGML_TYPE_Q5_0: // not tested yet
1033
+ case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1034
+ case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1035
+ case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1036
+ default: ne11_mm_min = 1; break;
1037
+ }
1038
+ }
1039
+ #endif
1040
+
954
1041
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
955
1042
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
956
- if (!ggml_is_transposed(src0) &&
1043
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1044
+ !ggml_is_transposed(src0) &&
957
1045
  !ggml_is_transposed(src1) &&
958
1046
  src1t == GGML_TYPE_F32 &&
959
- [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
960
- ne00%32 == 0 &&
961
- ne11 > 2) {
1047
+ ne00 % 32 == 0 && ne00 >= 64 &&
1048
+ ne11 > ne11_mm_min) {
1049
+ //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
962
1050
  switch (src0->type) {
963
1051
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
964
1052
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
@@ -987,17 +1075,18 @@ void ggml_metal_graph_compute(
987
1075
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
988
1076
  [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
989
1077
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
990
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1078
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
991
1079
  } else {
992
1080
  int nth0 = 32;
993
1081
  int nth1 = 1;
994
1082
  int nrows = 1;
1083
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
995
1084
 
996
1085
  // use custom matrix x vector kernel
997
1086
  switch (src0t) {
998
1087
  case GGML_TYPE_F32:
999
1088
  {
1000
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
1089
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
1001
1090
  nrows = 4;
1002
1091
  } break;
1003
1092
  case GGML_TYPE_F16:
@@ -1005,12 +1094,12 @@ void ggml_metal_graph_compute(
1005
1094
  nth0 = 32;
1006
1095
  nth1 = 1;
1007
1096
  if (ne11 * ne12 < 4) {
1008
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
1097
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
1009
1098
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1010
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
1099
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
1011
1100
  nrows = ne11;
1012
1101
  } else {
1013
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
1102
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
1014
1103
  nrows = 4;
1015
1104
  }
1016
1105
  } break;
@@ -1021,7 +1110,7 @@ void ggml_metal_graph_compute(
1021
1110
 
1022
1111
  nth0 = 8;
1023
1112
  nth1 = 8;
1024
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
1113
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
1025
1114
  } break;
1026
1115
  case GGML_TYPE_Q4_1:
1027
1116
  {
@@ -1030,7 +1119,7 @@ void ggml_metal_graph_compute(
1030
1119
 
1031
1120
  nth0 = 8;
1032
1121
  nth1 = 8;
1033
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
1122
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
1034
1123
  } break;
1035
1124
  case GGML_TYPE_Q8_0:
1036
1125
  {
@@ -1039,7 +1128,7 @@ void ggml_metal_graph_compute(
1039
1128
 
1040
1129
  nth0 = 8;
1041
1130
  nth1 = 8;
1042
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
1131
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
1043
1132
  } break;
1044
1133
  case GGML_TYPE_Q2_K:
1045
1134
  {
@@ -1048,7 +1137,7 @@ void ggml_metal_graph_compute(
1048
1137
 
1049
1138
  nth0 = 2;
1050
1139
  nth1 = 32;
1051
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
1140
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
1052
1141
  } break;
1053
1142
  case GGML_TYPE_Q3_K:
1054
1143
  {
@@ -1057,7 +1146,7 @@ void ggml_metal_graph_compute(
1057
1146
 
1058
1147
  nth0 = 2;
1059
1148
  nth1 = 32;
1060
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
1149
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
1061
1150
  } break;
1062
1151
  case GGML_TYPE_Q4_K:
1063
1152
  {
@@ -1066,7 +1155,7 @@ void ggml_metal_graph_compute(
1066
1155
 
1067
1156
  nth0 = 4; //1;
1068
1157
  nth1 = 8; //32;
1069
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
1158
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
1070
1159
  } break;
1071
1160
  case GGML_TYPE_Q5_K:
1072
1161
  {
@@ -1075,7 +1164,7 @@ void ggml_metal_graph_compute(
1075
1164
 
1076
1165
  nth0 = 2;
1077
1166
  nth1 = 32;
1078
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
1167
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
1079
1168
  } break;
1080
1169
  case GGML_TYPE_Q6_K:
1081
1170
  {
@@ -1084,7 +1173,7 @@ void ggml_metal_graph_compute(
1084
1173
 
1085
1174
  nth0 = 2;
1086
1175
  nth1 = 32;
1087
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
1176
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
1088
1177
  } break;
1089
1178
  default:
1090
1179
  {
@@ -1113,7 +1202,7 @@ void ggml_metal_graph_compute(
1113
1202
  [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
1114
1203
 
1115
1204
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
1116
- src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
1205
+ src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
1117
1206
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1118
1207
  }
1119
1208
  else if (src0t == GGML_TYPE_Q4_K) {
@@ -1166,6 +1255,8 @@ void ggml_metal_graph_compute(
1166
1255
  } break;
1167
1256
  case GGML_OP_RMS_NORM:
1168
1257
  {
1258
+ GGML_ASSERT(ne00 % 4 == 0);
1259
+
1169
1260
  float eps;
1170
1261
  memcpy(&eps, dst->op_params, sizeof(float));
1171
1262
 
@@ -1208,7 +1299,7 @@ void ggml_metal_graph_compute(
1208
1299
 
1209
1300
  const int nth = MIN(1024, ne00);
1210
1301
 
1211
- const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
1302
+ //const int n_past = ((int32_t *) dst->op_params)[0];
1212
1303
  const int n_head = ((int32_t *) dst->op_params)[1];
1213
1304
  float max_bias;
1214
1305
  memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
@@ -1371,3 +1462,140 @@ void ggml_metal_graph_compute(
1371
1462
 
1372
1463
  }
1373
1464
  }
1465
+
1466
+ ////////////////////////////////////////////////////////////////////////////////
1467
+
1468
+ // backend interface
1469
+
1470
+ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
1471
+ return "Metal";
1472
+
1473
+ UNUSED(backend);
1474
+ }
1475
+
1476
+ static void ggml_backend_metal_free(ggml_backend_t backend) {
1477
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
1478
+ ggml_metal_free(ctx);
1479
+ free(backend);
1480
+ }
1481
+
1482
+ static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
1483
+ return (void *)buffer->context;
1484
+ }
1485
+
1486
+ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
1487
+ free(buffer->context);
1488
+ UNUSED(buffer);
1489
+ }
1490
+
1491
+ static struct ggml_backend_buffer_i metal_backend_buffer_i = {
1492
+ /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
1493
+ /* .get_base = */ ggml_backend_metal_buffer_get_base,
1494
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
1495
+ /* .init_tensor = */ NULL, // no initialization required
1496
+ /* .free_tensor = */ NULL, // no cleanup required
1497
+ };
1498
+
1499
+ static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) {
1500
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
1501
+
1502
+ void * data = ggml_metal_host_malloc(size);
1503
+
1504
+ // TODO: set proper name of the buffers
1505
+ ggml_metal_add_buffer(ctx, "backend", data, size, 0);
1506
+
1507
+ return ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
1508
+ }
1509
+
1510
+ static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) {
1511
+ return 32;
1512
+ UNUSED(backend);
1513
+ }
1514
+
1515
+ static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
1516
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
1517
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
1518
+
1519
+ memcpy((char *)tensor->data + offset, data, size);
1520
+
1521
+ UNUSED(backend);
1522
+ }
1523
+
1524
+ static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
1525
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
1526
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
1527
+
1528
+ memcpy(data, (const char *)tensor->data + offset, size);
1529
+
1530
+ UNUSED(backend);
1531
+ }
1532
+
1533
+ static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
1534
+ UNUSED(backend);
1535
+ }
1536
+
1537
+ static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
1538
+ ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
1539
+
1540
+ UNUSED(backend);
1541
+ }
1542
+
1543
+ static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
1544
+ ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
1545
+
1546
+ UNUSED(backend);
1547
+ }
1548
+
1549
+ static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
1550
+ struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
1551
+
1552
+ ggml_metal_graph_compute(metal_ctx, cgraph);
1553
+ }
1554
+
1555
+ static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
1556
+ return true;
1557
+ UNUSED(backend);
1558
+ UNUSED(op);
1559
+ }
1560
+
1561
+ static struct ggml_backend_i metal_backend_i = {
1562
+ /* .get_name = */ ggml_backend_metal_name,
1563
+ /* .free = */ ggml_backend_metal_free,
1564
+ /* .alloc_buffer = */ ggml_backend_metal_alloc_buffer,
1565
+ /* .get_alignment = */ ggml_backend_metal_get_alignment,
1566
+ /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
1567
+ /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
1568
+ /* .synchronize = */ ggml_backend_metal_synchronize,
1569
+ /* .cpy_tensor_from = */ ggml_backend_metal_cpy_tensor_from,
1570
+ /* .cpy_tensor_to = */ ggml_backend_metal_cpy_tensor_to,
1571
+ /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
1572
+ /* .graph_plan_free = */ NULL,
1573
+ /* .graph_plan_compute = */ NULL,
1574
+ /* .graph_compute = */ ggml_backend_metal_graph_compute,
1575
+ /* .supports_op = */ ggml_backend_metal_supports_op,
1576
+ };
1577
+
1578
+ ggml_backend_t ggml_backend_metal_init(void) {
1579
+ struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
1580
+
1581
+ ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
1582
+
1583
+ ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
1584
+
1585
+ *metal_backend = (struct ggml_backend) {
1586
+ /* .interface = */ metal_backend_i,
1587
+ /* .context = */ ctx,
1588
+ };
1589
+
1590
+ return metal_backend;
1591
+ }
1592
+
1593
+ bool ggml_backend_is_metal(ggml_backend_t backend) {
1594
+ return backend->iface.get_name == ggml_backend_metal_name;
1595
+ }
1596
+
1597
+ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
1598
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
1599
+
1600
+ ggml_metal_set_n_cb(ctx, n_cb);
1601
+ }