llama_cpp 0.9.5 → 0.10.1

@@ -62,11 +62,15 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
  GGML_METAL_DECL_KERNEL(mul);
  GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
+ GGML_METAL_DECL_KERNEL(div);
+ GGML_METAL_DECL_KERNEL(div_row);
  GGML_METAL_DECL_KERNEL(scale);
  GGML_METAL_DECL_KERNEL(scale_4);
- GGML_METAL_DECL_KERNEL(silu);
+ GGML_METAL_DECL_KERNEL(tanh);
  GGML_METAL_DECL_KERNEL(relu);
  GGML_METAL_DECL_KERNEL(gelu);
+ GGML_METAL_DECL_KERNEL(gelu_quick);
+ GGML_METAL_DECL_KERNEL(silu);
  GGML_METAL_DECL_KERNEL(soft_max);
  GGML_METAL_DECL_KERNEL(soft_max_4);
  GGML_METAL_DECL_KERNEL(diag_mask_inf);
@@ -84,6 +88,7 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(get_rows_q5_K);
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
  GGML_METAL_DECL_KERNEL(rms_norm);
+ GGML_METAL_DECL_KERNEL(group_norm);
  GGML_METAL_DECL_KERNEL(norm);
  GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
  GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
@@ -100,6 +105,21 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -112,15 +132,39 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
  GGML_METAL_DECL_KERNEL(rope_f32);
  GGML_METAL_DECL_KERNEL(rope_f16);
  GGML_METAL_DECL_KERNEL(alibi_f32);
  GGML_METAL_DECL_KERNEL(im2col_f16);
+ GGML_METAL_DECL_KERNEL(upscale_f32);
+ GGML_METAL_DECL_KERNEL(pad_f32);
+ GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
+ GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
+ GGML_METAL_DECL_KERNEL(leaky_relu_f32);
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+ GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
+ GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
+ GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
+ //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
+ //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+ GGML_METAL_DECL_KERNEL(cpy_f16_f32);
  GGML_METAL_DECL_KERNEL(concat);
  GGML_METAL_DECL_KERNEL(sqr);
+ GGML_METAL_DECL_KERNEL(sum_rows);

  #undef GGML_METAL_DECL_KERNEL
  };
@@ -155,6 +199,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
  ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
  } else {
  char* buffer2 = malloc(len+1);
+ va_end(args);
+ va_start(args, format);
  vsnprintf(buffer2, len+1, format, args);
  buffer2[len] = 0;
  ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
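
Note: this hunk fixes a genuine bug, not just style. The `args` va_list was already consumed by an earlier `vsnprintf` call used to compute `len`, and reusing a `va_list` after it has been passed to a `v*printf` function is undefined behavior, so the patch re-arms it with `va_end`/`va_start` before formatting into `buffer2`. A minimal sketch of the two-pass pattern (a standalone example under that assumption, not the exact ggml code):

```c
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>

// Two-pass formatting: measure, allocate, then format for real.
static char * format_message(const char * fmt, ...) {
    va_list args;

    va_start(args, fmt);
    const int len = vsnprintf(NULL, 0, fmt, args); // pass 1: length only
    va_end(args);                                  // args is now indeterminate

    char * buf = malloc(len + 1);

    va_start(args, fmt);                // must re-arm before the second pass
    vsnprintf(buf, len + 1, fmt, args); // pass 2: actually format
    va_end(args);

    return buf; // caller frees
}
```

Taking a `va_copy` before the first pass would work equally well.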
@@ -164,12 +210,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
  }
  }

-
-
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_LOG_INFO("%s: allocating\n", __func__);

- id <MTLDevice> device;
+ id<MTLDevice> device;
  NSString * s;

  #if TARGET_OS_OSX
@@ -215,6 +259,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {

  NSString * sourcePath;
  NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
+
+ GGML_METAL_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, ggmlMetalPathResources ? [ggmlMetalPathResources UTF8String] : "nil");
+
  if (ggmlMetalPathResources) {
  sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
  } else {
@@ -245,6 +292,29 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  }
  }

+ #if TARGET_OS_OSX
+ // print MTL GPU family:
+ GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
+
+ // determine max supported GPU family
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+ break;
+ }
+ }
+
+ GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+ GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+ if (ctx->device.maxTransferRate != 0) {
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
+ } else {
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
+ }
+ #endif
+
  // load kernels
  {
  NSError * error = nil;
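
Note: this block is the GPU-capability logging moved up from the end of `ggml_metal_init` (its old copy is removed further down), with the working-set and transfer-rate figures switching from MiB (/1024/1024) to MB (/1e6). The family probe walks downward from `MTLGPUFamilyApple1 + 20` so the loop never needs updating for a newer chip. A hedged standalone sketch of the same probe (assuming `MTLGPUFamilyApple1` is 1001 in the `MTLGPUFamily` enum, so an Apple8-class device would log `MTLGPUFamilyApple8 (1008)`):

```objc
#import <Metal/Metal.h>

// Returns the highest MTLGPUFamilyAppleN the device supports (1-based), or 0
// if it supports none (e.g. AMD GPUs in Intel Macs report only Mac/Common).
static int highest_apple_family(id<MTLDevice> device) {
    for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
        if ([device supportsFamily:i]) {
            return i - (int) MTLGPUFamilyApple1 + 1;
        }
    }
    return 0;
}
```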
@@ -266,11 +336,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(add_row);
  GGML_METAL_ADD_KERNEL(mul);
  GGML_METAL_ADD_KERNEL(mul_row);
+ GGML_METAL_ADD_KERNEL(div);
+ GGML_METAL_ADD_KERNEL(div_row);
  GGML_METAL_ADD_KERNEL(scale);
  GGML_METAL_ADD_KERNEL(scale_4);
- GGML_METAL_ADD_KERNEL(silu);
+ GGML_METAL_ADD_KERNEL(tanh);
  GGML_METAL_ADD_KERNEL(relu);
  GGML_METAL_ADD_KERNEL(gelu);
+ GGML_METAL_ADD_KERNEL(gelu_quick);
+ GGML_METAL_ADD_KERNEL(silu);
  GGML_METAL_ADD_KERNEL(soft_max);
  GGML_METAL_ADD_KERNEL(soft_max_4);
  GGML_METAL_ADD_KERNEL(diag_mask_inf);
@@ -288,6 +362,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(get_rows_q5_K);
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
  GGML_METAL_ADD_KERNEL(rms_norm);
+ GGML_METAL_ADD_KERNEL(group_norm);
  GGML_METAL_ADD_KERNEL(norm);
  GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
  GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
@@ -304,6 +379,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
  GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -317,43 +407,44 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
  }
  GGML_METAL_ADD_KERNEL(rope_f32);
  GGML_METAL_ADD_KERNEL(rope_f16);
  GGML_METAL_ADD_KERNEL(alibi_f32);
  GGML_METAL_ADD_KERNEL(im2col_f16);
+ GGML_METAL_ADD_KERNEL(upscale_f32);
+ GGML_METAL_ADD_KERNEL(pad_f32);
+ GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
+ GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
+ GGML_METAL_ADD_KERNEL(leaky_relu_f32);
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+ GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
+ GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
+ GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
+ //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
+ //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+ GGML_METAL_ADD_KERNEL(cpy_f16_f32);
  GGML_METAL_ADD_KERNEL(concat);
  GGML_METAL_ADD_KERNEL(sqr);
+ GGML_METAL_ADD_KERNEL(sum_rows);

  #undef GGML_METAL_ADD_KERNEL
  }

- #if TARGET_OS_OSX
- // print MTL GPU family:
- GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
-
- // determine max supported GPU family
- // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
- // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
- for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
- if ([ctx->device supportsFamily:i]) {
- GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
- break;
- }
- }
-
- GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
- GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MiB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
- if (ctx->device.maxTransferRate != 0) {
- GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MiB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
- } else {
- GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
- }
- #endif
-
  return ctx;
  }

@@ -367,11 +458,15 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(add_row);
  GGML_METAL_DEL_KERNEL(mul);
  GGML_METAL_DEL_KERNEL(mul_row);
+ GGML_METAL_DEL_KERNEL(div);
+ GGML_METAL_DEL_KERNEL(div_row);
  GGML_METAL_DEL_KERNEL(scale);
  GGML_METAL_DEL_KERNEL(scale_4);
- GGML_METAL_DEL_KERNEL(silu);
+ GGML_METAL_DEL_KERNEL(tanh);
  GGML_METAL_DEL_KERNEL(relu);
  GGML_METAL_DEL_KERNEL(gelu);
+ GGML_METAL_DEL_KERNEL(gelu_quick);
+ GGML_METAL_DEL_KERNEL(silu);
  GGML_METAL_DEL_KERNEL(soft_max);
  GGML_METAL_DEL_KERNEL(soft_max_4);
  GGML_METAL_DEL_KERNEL(diag_mask_inf);
@@ -389,6 +484,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(get_rows_q5_K);
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
  GGML_METAL_DEL_KERNEL(rms_norm);
+ GGML_METAL_DEL_KERNEL(group_norm);
  GGML_METAL_DEL_KERNEL(norm);
  GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
  GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
@@ -405,6 +501,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
  GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -418,16 +529,40 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
  }
  GGML_METAL_DEL_KERNEL(rope_f32);
  GGML_METAL_DEL_KERNEL(rope_f16);
  GGML_METAL_DEL_KERNEL(alibi_f32);
  GGML_METAL_DEL_KERNEL(im2col_f16);
+ GGML_METAL_DEL_KERNEL(upscale_f32);
+ GGML_METAL_DEL_KERNEL(pad_f32);
+ GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
+ GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
+ GGML_METAL_DEL_KERNEL(leaky_relu_f32);
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+ GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
+ GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
+ GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
+ //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
+ //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+ GGML_METAL_DEL_KERNEL(cpy_f16_f32);
  GGML_METAL_DEL_KERNEL(concat);
  GGML_METAL_DEL_KERNEL(sqr);
+ GGML_METAL_DEL_KERNEL(sum_rows);

  #undef GGML_METAL_DEL_KERNEL

@@ -471,6 +606,13 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
  return ctx->concur_list;
  }

+ // temporarily defined here for compatibility between ggml-backend and the old API
+ struct ggml_backend_metal_buffer_context {
+ void * data;
+
+ id<MTLBuffer> metal;
+ };
+
  // finds the Metal buffer that contains the tensor data on the GPU device
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
  // Metal buffer based on the host memory pointer
@@ -480,8 +622,17 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru

  const int64_t tsize = ggml_nbytes(t);

- if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
- ctx = t->buffer->backend->context;
+ // compatibility with ggml-backend
+ if (t->buffer && t->buffer->buft == ggml_backend_metal_buffer_type()) {
+ struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) t->buffer->context;
+
+ const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
+
+ GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
+
+ *offs = (size_t) ioffs;
+
+ return buf_ctx->metal;
  }

  // find the view that contains the tensor fully
@@ -706,6 +857,83 @@ void ggml_metal_graph_find_concurrency(
  }
  }

+ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
+ switch (op->op) {
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(op)) {
+ case GGML_UNARY_OP_TANH:
+ case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_GELU_QUICK:
+ case GGML_UNARY_OP_SILU:
+ return true;
+ default:
+ return false;
+ }
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_CONCAT:
+ case GGML_OP_ADD:
+ case GGML_OP_ACC:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_SCALE:
+ case GGML_OP_SQR:
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_GROUP_NORM:
+ case GGML_OP_NORM:
+ case GGML_OP_ALIBI:
+ case GGML_OP_ROPE:
+ case GGML_OP_IM2COL:
+ case GGML_OP_UPSCALE:
+ case GGML_OP_PAD:
+ case GGML_OP_ARGSORT:
+ case GGML_OP_LEAKY_RELU:
+ case GGML_OP_MUL_MAT:
+ case GGML_OP_MUL_MAT_ID:
+ return true;
+ case GGML_OP_CPY:
+ case GGML_OP_DUP:
+ case GGML_OP_CONT:
+ {
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F32:
+ switch (op->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ return true;
+ default:
+ return false;
+ }
+ case GGML_TYPE_F16:
+ switch (op->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ return true;
+ default:
+ return false;
+ }
+ default:
+ return false;
+ };
+ }
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_GET_ROWS:
+ {
+ return op->ne[3] == 1;
+ }
+ default:
+ return false;
+ }
+ }
  void ggml_metal_graph_compute(
  struct ggml_metal_context * ctx,
  struct ggml_cgraph * gf) {
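
Note: `ggml_metal_supports_op` centralizes, per op (and per type pair for CPY/DUP/CONT), what the Metal kernels can actually run; the next hunk has `ggml_metal_graph_compute` consult it and abort with a clear error instead of silently mis-dispatching. A hedged sketch of how such a predicate could also drive a fallback rather than an assert (the `compute_node_on_cpu` helper is hypothetical, not part of this patch):

```c
// Hypothetical fallback driver: route nodes the Metal backend cannot run
// to a CPU implementation instead of asserting.
static void compute_with_fallback(struct ggml_cgraph * gf) {
    for (int i = 0; i < gf->n_nodes; ++i) {
        struct ggml_tensor * node = gf->nodes[i];
        if (!ggml_metal_supports_op(node)) {
            compute_node_on_cpu(node); // hypothetical helper
        }
    }
}
```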
@@ -776,6 +1004,11 @@ void ggml_metal_graph_compute(
  } break;
  }

+ if (!ggml_metal_supports_op(dst)) {
+ GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
+ GGML_ASSERT(!"unsupported op");
+ }
+
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -868,25 +1101,42 @@ void ggml_metal_graph_compute(
  [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
  {
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
+ const size_t offs = 0;

  bool bcast_row = false;

  int64_t nb = ne00;

- if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
  // src1 is a row
  GGML_ASSERT(ne11 == 1);

  nb = ne00 / 4;
- [encoder setComputePipelineState:ctx->pipeline_add_row];
+ switch (dst->op) {
+ case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
+ case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
+ case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
+ default: GGML_ASSERT(false);
+ }

  bcast_row = true;
  } else {
- [encoder setComputePipelineState:ctx->pipeline_add];
+ switch (dst->op) {
+ case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
+ case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
+ case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
+ default: GGML_ASSERT(false);
+ }
  }
+
+ [encoder setComputePipelineState:pipeline];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
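
Note: ADD, MUL and DIV now share one encode path, and the old ADD-only fast path got stricter: the `*_row` kernels broadcast a single src1 row across every row of src0 using float4 loads, so src1 must additionally be contiguous and both row widths must divide by 4. The predicate, restated as a hedged standalone sketch (requires ggml.h):

```c
// src1 qualifies for the float4 row-broadcast kernel only if it is one
// contiguous row and both row widths are multiples of 4.
static bool use_row_kernel(const struct ggml_tensor * src0, const struct ggml_tensor * src1) {
    const int64_t ne00 = src0->ne[0]; // src0/dst row width
    const int64_t ne10 = src1->ne[0]; // src1 row width
    return ggml_nelements(src1) == ne10 && // src1 is a single row
           ggml_is_contiguous(src1)      &&
           ne00 % 4 == 0 && ne10 % 4 == 0;
}
```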
@@ -914,42 +1164,98 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
  [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
- [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:28];

  if (bcast_row) {
  const int64_t n = ggml_nelements(dst)/4;

  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } else {
- const int nth = MIN(1024, ne0);
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);

  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  }
  } break;
- case GGML_OP_MUL:
+ case GGML_OP_ACC:
  {
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ GGML_ASSERT(dstt == GGML_TYPE_F32);
+
  GGML_ASSERT(ggml_is_contiguous(src0));
  GGML_ASSERT(ggml_is_contiguous(src1));

- // utilize float4
- GGML_ASSERT(ne00 % 4 == 0);
- const int64_t nb = ne00/4;
+ const size_t pnb1 = ((int32_t *) dst->op_params)[0];
+ const size_t pnb2 = ((int32_t *) dst->op_params)[1];
+ const size_t pnb3 = ((int32_t *) dst->op_params)[2];
+ const size_t offs = ((int32_t *) dst->op_params)[3];
+
+ const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+ if (!inplace) {
+ // run a separate kernel to cpy src->dst
+ // not sure how to avoid this
+ // TODO: make a simpler cpy_bytes kernel
+
+ const int nth = MIN(1024, ne00);
+
+ [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];

- if (ggml_nelements(src1) == ne10) {
- // src1 is a row
- GGML_ASSERT(ne11 == 1);
- [encoder setComputePipelineState:ctx->pipeline_mul_row];
- } else {
- [encoder setComputePipelineState:ctx->pipeline_mul];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  }
+
+ [encoder setComputePipelineState:ctx->pipeline_add];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];

- const int64_t n = ggml_nelements(dst)/4;
+ const int nth = MIN(1024, ne0);

- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case GGML_OP_SCALE:
  {
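
Note: the new `GGML_OP_ACC` case just above adds src1 into a strided view of src0: op_params carry the view strides `pnb1..pnb3`, a byte offset `offs`, and an `inplace` flag. When not in place, the encode first copies src0 to dst with the `cpy_f32_f32` kernel, then reuses the plain `add` kernel over the view. A hedged CPU reference of the semantics (f32 only, simplified indexing):

```c
#include <stdint.h>
#include <string.h>

// dst = src0, then the view of dst described by (nb1, nb2, nb3, offs) += src1.
static void acc_ref(float * dst, const float * src0, const float * src1,
                    int64_t n,                                              // elements in src0/dst
                    int64_t ne10, int64_t ne11, int64_t ne12, int64_t ne13, // src1 shape
                    size_t nb1, size_t nb2, size_t nb3, size_t offs) {      // view strides/offset, in bytes
    memcpy(dst, src0, n * sizeof(float)); // the "not inplace" copy pass
    for (int64_t i3 = 0; i3 < ne13; ++i3)
    for (int64_t i2 = 0; i2 < ne12; ++i2)
    for (int64_t i1 = 0; i1 < ne11; ++i1) {
        float       * d = (float *) ((char *) dst + offs + i1*nb1 + i2*nb2 + i3*nb3);
        const float * s = src1 + ((i3*ne12 + i2)*ne11 + i1)*ne10;
        for (int64_t i0 = 0; i0 < ne10; ++i0) {
            d[i0] += s[i0];
        }
    }
}
```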
@@ -974,16 +1280,15 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_UNARY:
  switch (ggml_get_unary_op(gf->nodes[i])) {
- case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_TANH:
  {
- [encoder setComputePipelineState:ctx->pipeline_silu];
+ [encoder setComputePipelineState:ctx->pipeline_tanh];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

  const int64_t n = ggml_nelements(dst);
- GGML_ASSERT(n % 4 == 0);

- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  case GGML_UNARY_OP_RELU:
  {
@@ -1004,6 +1309,28 @@ void ggml_metal_graph_compute(
  const int64_t n = ggml_nelements(dst);
  GGML_ASSERT(n % 4 == 0);

+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_GELU_QUICK:
+ {
+ [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+ GGML_ASSERT(n % 4 == 0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_SILU:
+ {
+ [encoder setComputePipelineState:ctx->pipeline_silu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+ GGML_ASSERT(n % 4 == 0);
+
  [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  default:
@@ -1023,6 +1350,40 @@ void ggml_metal_graph_compute(
  const int64_t n = ggml_nelements(dst);
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
+ case GGML_OP_SUM_ROWS:
+ {
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+
+ [encoder setComputePipelineState:ctx->pipeline_sum_rows];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
  case GGML_OP_SOFT_MAX:
  {
  int nth = 32; // SIMD width
@@ -1042,7 +1403,11 @@ void ggml_metal_graph_compute(
  const float scale = ((float *) dst->op_params)[0];

  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ if (id_src1) {
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ }
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
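
Note: the soft_max mask input (src1) is now optional. Metal flags unbound buffer arguments, so when there is no mask the encode binds src0 again at index 1 as a placeholder the kernel presumably ignores. Semantically, this op appears to compute a scaled, optionally masked softmax along each row; a hedged CPU reference under that reading:

```c
#include <float.h>
#include <math.h>

// y = softmax(x*scale + mask) over one row of length n; mask may be NULL.
static void soft_max_row(float * y, const float * x, const float * mask, int n, float scale) {
    float vmax = -FLT_MAX;
    for (int i = 0; i < n; ++i) {
        const float v = x[i]*scale + (mask ? mask[i] : 0.0f);
        if (v > vmax) vmax = v;
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        y[i] = expf(x[i]*scale + (mask ? mask[i] : 0.0f) - vmax); // subtract max for stability
        sum += y[i];
    }
    for (int i = 0; i < n; ++i) {
        y[i] /= sum;
    }
}
```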
@@ -1077,9 +1442,13 @@ void ggml_metal_graph_compute(
  case GGML_OP_MUL_MAT:
  {
  GGML_ASSERT(ne00 == ne10);
- GGML_ASSERT(ne03 == ne13);

- const uint gqa = ne12/ne02;
+ // TODO: assert that dim2 and dim3 are contiguous
+ GGML_ASSERT(ne12 % ne02 == 0);
+ GGML_ASSERT(ne13 % ne03 == 0);
+
+ const uint r2 = ne12/ne02;
+ const uint r3 = ne13/ne03;

  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
  // to the matrix-vector kernel
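
Note: the single `gqa` broadcast factor only covered dim2; it becomes two repeat factors, `r2` and `r3`, so src0's batch dims can be broadcast along both dim2 and dim3 — which is also why later hunks multiply the dispatch grid by `ne12*ne13` and drop the per-type `ne02 == 1`/`ne12 == 1` asserts. A hedged sketch of the index mapping, with a grouped-query-attention example:

```c
#include <stdint.h>

// Map a dst/src1 batch index back to the broadcast src0 batch index, as a
// kernel using r2/r3 would. Example (grouped-query attention): ne12 = 32
// query heads over ne02 = 8 KV heads gives r2 = 4, so query heads 0..3 all
// read KV head 0.
static int64_t src0_batch_index(int64_t i12, uint32_t r2) {
    return i12 / r2; // r2 = ne12/ne02; likewise i03 = i13 / (ne13/ne03)
}
```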
@@ -1114,7 +1483,7 @@ void ggml_metal_graph_compute(
  !ggml_is_transposed(src1) &&
  src1t == GGML_TYPE_F32 &&
  ne00 % 32 == 0 && ne00 >= 64 &&
- ne11 > ne11_mm_min) {
+ (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
  //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
  switch (src0->type) {
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
@@ -1144,9 +1513,10 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  } else {
  int nth0 = 32;
  int nth1 = 1;
@@ -1182,90 +1552,60 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_TYPE_Q4_0:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
  } break;
  case GGML_TYPE_Q4_1:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
  } break;
  case GGML_TYPE_Q5_0:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
  } break;
  case GGML_TYPE_Q5_1:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
  } break;
  case GGML_TYPE_Q8_0:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
  } break;
  case GGML_TYPE_Q2_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
  } break;
  case GGML_TYPE_Q3_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
  } break;
  case GGML_TYPE_Q4_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 4; //1;
  nth1 = 8; //32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
  } break;
  case GGML_TYPE_Q5_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
  } break;
  case GGML_TYPE_Q6_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
@@ -1294,31 +1634,281 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];

  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
  src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
  src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q3_K) {
  #ifdef GGML_QKK_64
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #else
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #endif
  }
  else if (src0t == GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ } else {
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ }
+ } break;
+ case GGML_OP_MUL_MAT_ID:
+ {
+ //GGML_ASSERT(ne00 == ne10);
+ //GGML_ASSERT(ne03 == ne13);
+
+ GGML_ASSERT(src0t == GGML_TYPE_I32);
+
+ const int n_as = ((int32_t *) dst->op_params)[1];
+
+ // TODO: make this more general
+ GGML_ASSERT(n_as <= 8);
+
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2];
+
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
+ const int64_t ne22 = src2 ? src2->ne[2] : 0;
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
+
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
+
+ const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+
+ GGML_ASSERT(!ggml_is_transposed(src2));
+ GGML_ASSERT(!ggml_is_transposed(src1));
+
+ GGML_ASSERT(ne20 % 32 == 0);
+ // !!!!!!!!! TODO: this assert is probably required but not sure!
+ //GGML_ASSERT(ne20 >= 64);
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+ const uint r2 = ne12/ne22;
+ const uint r3 = ne13/ne23;
+
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+ // to the matrix-vector kernel
+ int ne11_mm_min = 1;
+
+ const int idx = ((int32_t *) dst->op_params)[0];
+
+ // batch size
+ GGML_ASSERT(ne01 == ne11);
+
+ const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
+
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+ // !!!
+ // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+ // indirect matrix multiplication
+ // !!!
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
+ switch (src2->type) {
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
+ case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
+ case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
+ default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
+ }
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
+ // TODO: how to make this an array? read Metal docs
+ for (int j = 0; j < n_as; ++j) {
+ struct ggml_tensor * src_cur = dst->src[2 + j];
+
+ size_t offs_src_cur = 0;
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
+ }
+
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+
+ // TODO: processing one row at a time (ne11 -> 1) is not efficient
+ [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ } else {
+ int nth0 = 32;
+ int nth1 = 1;
+ int nrows = 1;
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+ // use custom matrix x vector kernel
+ switch (src2t) {
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
+ } break;
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ nth0 = 32;
+ nth1 = 1;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
+ } break;
+ case GGML_TYPE_Q4_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
+ } break;
+ case GGML_TYPE_Q4_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
+ } break;
+ case GGML_TYPE_Q5_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
+ } break;
+ case GGML_TYPE_Q8_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
+ } break;
+ case GGML_TYPE_Q2_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
+ } break;
+ case GGML_TYPE_Q3_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
+ } break;
+ case GGML_TYPE_Q4_K:
+ {
+ nth0 = 4; //1;
+ nth1 = 8; //32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
+ } break;
+ case GGML_TYPE_Q5_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
+ } break;
+ case GGML_TYPE_Q6_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
+ } break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+ GGML_ASSERT(false && "not implemented");
+ }
+ };
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
+ [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
+ // TODO: how to make this an array? read Metal docs
+ for (int j = 0; j < n_as; ++j) {
+ struct ggml_tensor * src_cur = dst->src[2 + j];
+
+ size_t offs_src_cur = 0;
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
+ }
+
+ if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
+ src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
+ src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q3_K) {
+ #ifdef GGML_QKK_64
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ #else
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ #endif
+ }
+ else if (src2t == GGML_TYPE_Q5_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q6_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  } else {
- int64_t ny = (ne11 + nrows - 1)/nrows;
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ const int64_t ny = (_ne1 + nrows - 1)/nrows;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  }
  } break;
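
Note: `GGML_OP_MUL_MAT_ID` is the new indirect matmul used by mixture-of-experts models: src0 holds i32 routing ids, op_params carry the id row to use (`idx`) and the number of candidate matrices (`n_as`, capped at 8 here), and the expert matrices arrive as `dst->src[2 + j]`, each bound to its own buffer slot because Metal buffer arrays are still a TODO. Per the in-code TODO, the encode currently always takes the matrix-vector path (`_ne1 = 1` makes the matrix-matrix condition false). Conceptually, and heavily simplified (a hypothetical CPU pseudo-reference — the exact ids layout and the `matvec`/row helpers are assumptions, not this patch's code):

```c
// Each src1 row is multiplied by whichever candidate matrix its id selects.
for (int64_t i1 = 0; i1 < ne11; ++i1) {
    const int32_t id = ids[i1];                           // routing id from src0 (row chosen by idx)
    const struct ggml_tensor * expert = dst->src[2 + id]; // selected expert matrix
    matvec(dst_row(dst, i1), expert, src1_row(src1, i1)); // hypothetical helpers
}
```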
@@ -1340,16 +1930,19 @@ void ggml_metal_graph_compute(
  default: GGML_ASSERT(false && "not implemented");
  }

- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
-
- const int64_t n = ggml_nelements(src1);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
  } break;
  case GGML_OP_RMS_NORM:
  {
@@ -1376,6 +1969,38 @@ void ggml_metal_graph_compute(

  [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
+ case GGML_OP_GROUP_NORM:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+
+ //float eps;
+ //memcpy(&eps, dst->op_params, sizeof(float));
+
+ const float eps = 1e-6f; // TODO: temporarily hardcoded
+
+ const int32_t n_groups = ((int32_t *) dst->op_params)[0];
+
+ int nth = 32; // SIMD width
+
+ //while (nth < ne00/4 && nth < 1024) {
+ // nth *= 2;
+ //}
+
+ [encoder setComputePipelineState:ctx->pipeline_group_norm];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:9];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
  case GGML_OP_NORM:
  {
  float eps;
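
Note: GROUP_NORM lands with `eps` temporarily hardcoded to 1e-6 (the commented-out `memcpy` from op_params is the eventual plan) and one 32-thread SIMD group per channel group. A hedged CPU reference of what it computes, assuming contiguous f32 data and ne02 dividing evenly into n_groups:

```c
#include <math.h>
#include <stdint.h>

// Normalize each of n_groups channel groups (split along dim2) to zero
// mean / unit variance, in place.
static void group_norm_ref(float * x, int64_t ne00, int64_t ne01, int64_t ne02,
                           int n_groups, float eps) {
    const int64_t group_sz = ne00 * ne01 * (ne02 / n_groups); // assumes ne02 % n_groups == 0
    for (int g = 0; g < n_groups; ++g) {
        float * p = x + g*group_sz;
        double sum = 0.0;
        for (int64_t i = 0; i < group_sz; ++i) sum += p[i];
        const float mean = (float) (sum/group_sz);
        double sum2 = 0.0;
        for (int64_t i = 0; i < group_sz; ++i) sum2 += (p[i] - mean)*(p[i] - mean);
        const float scale = 1.0f/sqrtf((float) (sum2/group_sz) + eps);
        for (int64_t i = 0; i < group_sz; ++i) p[i] = (p[i] - mean)*scale;
    }
}
```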
@@ -1545,18 +2170,123 @@ void ggml_metal_graph_compute(
1545
2170
 
1546
2171
  [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
1547
2172
  } break;
2173
+ case GGML_OP_UPSCALE:
2174
+ {
2175
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2176
+
2177
+ const int sf = dst->op_params[0];
2178
+
2179
+ [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
2180
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2181
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2182
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2183
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2184
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2185
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2186
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2187
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2188
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2189
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2190
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2191
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2192
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2193
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2194
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2195
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2196
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2197
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2198
+ [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
2199
+
2200
+ const int nth = MIN(1024, ne0);
2201
+
2202
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2203
+ } break;
2204
+ case GGML_OP_PAD:
2205
+ {
2206
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2207
+
2208
+ [encoder setComputePipelineState:ctx->pipeline_pad_f32];
2209
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2210
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2211
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2212
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2213
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2214
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2215
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2216
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2217
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2218
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2219
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2220
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2221
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2222
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2223
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2224
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2225
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2226
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2227
+
2228
+ const int nth = MIN(1024, ne0);
2229
+
2230
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2231
+ } break;
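
PAD passes nothing beyond the two shapes: the pad amounts are implied by the difference between source and destination extents. A sketch of the assumed semantics, zero-filling at the high end of each dimension:

    #include <stdint.h>

    // zero-pad src up to the dst extents; contiguous f32 sketch
    static void pad_ref(const float * src, float * dst,
                        int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
                        int64_t ne0,  int64_t ne1,  int64_t ne2,  int64_t ne3) {
        for (int64_t i3 = 0; i3 < ne3; ++i3)
        for (int64_t i2 = 0; i2 < ne2; ++i2)
        for (int64_t i1 = 0; i1 < ne1; ++i1)
        for (int64_t i0 = 0; i0 < ne0; ++i0) {
            const int inside = i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03;
            dst[((i3*ne2 + i2)*ne1 + i1)*ne0 + i0] = inside
                ? src[((i3*ne02 + i2)*ne01 + i1)*ne00 + i0]
                : 0.0f;
        }
    }
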
2232
+ case GGML_OP_ARGSORT:
2233
+ {
2234
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2235
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
2236
+
2237
+ const int nrows = ggml_nrows(src0);
2238
+
2239
+ enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
2240
+
2241
+ switch (order) {
2242
+ case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
2243
+ case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
2244
+ default: GGML_ASSERT(false);
2245
+ };
2246
+
2247
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2248
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2249
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2250
+
2251
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
2252
+ } break;
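
ARGSORT runs one threadgroup per row with ne00 threads, a shape that suits an in-threadgroup bitonic sort (assumed here from the dispatch alone) and caps the sortable row length at the device's maximum threadgroup size. Reference semantics for one row (tie order unspecified, as with any unstable sort):

    #include <stdint.h>
    #include <stdlib.h>

    static const float * g_row; // row under comparison (plain qsort has no context argument)

    static int cmp_asc(const void * pa, const void * pb) {
        const int32_t a = *(const int32_t *)pa;
        const int32_t b = *(const int32_t *)pb;
        return (g_row[a] > g_row[b]) - (g_row[a] < g_row[b]);
    }

    // dst receives the permutation that sorts one row of length ne00 ascending
    static void argsort_row_ref(const float * row, int32_t * dst, int64_t ne00) {
        for (int64_t i = 0; i < ne00; ++i) dst[i] = (int32_t)i;
        g_row = row;
        qsort(dst, (size_t)ne00, sizeof(int32_t), cmp_asc);
    }
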
2253
+ case GGML_OP_LEAKY_RELU:
2254
+ {
2255
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2256
+
2257
+ float slope;
2258
+ memcpy(&slope, dst->op_params, sizeof(float));
2259
+
2260
+ [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
2261
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2262
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2263
+ [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
2264
+
2265
+ const int64_t n = ggml_nelements(dst);
2266
+
2267
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2268
+ } break;
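
LEAKY_RELU pulls the slope out of op_params with memcpy (the params array is integer-typed, so memcpy is the aliasing-safe way to reinterpret the bits as a float) and dispatches one single-thread threadgroup per element, which is simple but leaves most of each SIMD group idle. The elementwise function:

    // y = x for positive x, slope*x otherwise
    static inline float leaky_relu(float x, float slope) {
        return x > 0.0f ? x : slope*x;
    }
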
1548
2269
  case GGML_OP_DUP:
1549
2270
  case GGML_OP_CPY:
1550
2271
  case GGML_OP_CONT:
1551
2272
  {
1552
- const int nth = MIN(1024, ne00);
2273
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
2274
+
2275
+ int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
1553
2276
 
1554
2277
  switch (src0t) {
1555
2278
  case GGML_TYPE_F32:
1556
2279
  {
2280
+ GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
2281
+
1557
2282
  switch (dstt) {
1558
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
1559
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
2283
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
2284
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
2285
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
2286
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
2287
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
2288
+ //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
2289
+ //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
1560
2290
  default: GGML_ASSERT(false && "not implemented");
1561
2291
  };
1562
2292
  } break;
@@ -1564,7 +2294,7 @@ void ggml_metal_graph_compute(
1564
2294
  {
1565
2295
  switch (dstt) {
1566
2296
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
1567
- case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
2297
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
1568
2298
  default: GGML_ASSERT(false && "not implemented");
1569
2299
  };
1570
2300
  } break;
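
The copy path now asserts that row lengths are multiples of each type's block size and scales the threadgroup width by it, since a thread writing a quantized destination must produce whole blocks (32 values for Q4_0/Q4_1/Q8_0); f16→f32 also gains a real pipeline where it previously asserted. With the new f32→quantized pipelines, GGML_OP_CPY doubles as an on-device quantizer. A sketch of the Q8_0 case under the standard ggml block layout (one absmax scale per 32 int8 quants; the fp16 conversion of the scale is elided):

    #include <math.h>
    #include <stdint.h>

    #define QK8_0 32

    // quantize one block of 32 floats to Q8_0: absmax scale + int8 quants
    static void quantize_block_q8_0(const float * x, int8_t * qs, float * d) {
        float amax = 0.0f;
        for (int j = 0; j < QK8_0; ++j) {
            const float v = fabsf(x[j]);
            if (v > amax) amax = v;
        }
        *d = amax/127.0f;
        const float id = *d != 0.0f ? 1.0f/(*d) : 0.0f;
        for (int j = 0; j < QK8_0; ++j) {
            qs[j] = (int8_t)roundf(x[j]*id);
        }
    }
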
@@ -1631,81 +2361,150 @@ void ggml_metal_graph_compute(
1631
2361
 
1632
2362
  // backend interface
1633
2363
 
1634
- static const char * ggml_backend_metal_name(ggml_backend_t backend) {
1635
- return "Metal";
2364
+ static id<MTLDevice> g_backend_device = nil;
2365
+ static int g_backend_device_ref_count = 0;
1636
2366
 
1637
- UNUSED(backend);
2367
+ static id<MTLDevice> ggml_backend_metal_get_device(void) {
2368
+ if (g_backend_device == nil) {
2369
+ g_backend_device = MTLCreateSystemDefaultDevice();
2370
+ }
2371
+
2372
+ g_backend_device_ref_count++;
2373
+
2374
+ return g_backend_device;
1638
2375
  }
1639
2376
 
1640
- static void ggml_backend_metal_free(ggml_backend_t backend) {
1641
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
1642
- ggml_metal_free(ctx);
1643
- free(backend);
2377
+ static void ggml_backend_metal_free_device(void) {
2378
+ assert(g_backend_device_ref_count > 0);
2379
+
2380
+ g_backend_device_ref_count--;
2381
+
2382
+ if (g_backend_device_ref_count == 0) {
2383
+ [g_backend_device release];
2384
+ g_backend_device = nil;
2385
+ }
1644
2386
  }
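
The MTLDevice is now shared through a manually reference-counted global: every buffer allocation acquires it and every buffer free releases it, so the device object lives exactly as long as the last Metal buffer. The counter is a plain int with no atomics, so the pairing is assumed to happen on a single thread. The same pattern in plain C, for illustration only:

    static void * g_dev      = NULL;
    static int    g_dev_refs = 0;

    static void * dev_acquire(void) {
        if (g_dev == NULL) g_dev = /* create the shared resource */ (void *)1;
        g_dev_refs++;
        return g_dev;
    }

    static void dev_release(void) {
        if (--g_dev_refs == 0) { /* release the resource */ g_dev = NULL; }
    }
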
1645
2387
 
1646
2388
  static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
1647
- return (void *)buffer->context;
2389
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
2390
+
2391
+ return ctx->data;
1648
2392
  }
1649
2393
 
1650
2394
  static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
1651
- free(buffer->context);
2395
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
2396
+
2397
+ [ctx->metal release];
2398
+ ggml_backend_metal_free_device();
2399
+
2400
+ free(ctx->data);
2401
+ free(ctx);
2402
+
2403
+ UNUSED(buffer);
2404
+ }
2405
+
2406
+ static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
2407
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
2408
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
2409
+
2410
+ memcpy((char *)tensor->data + offset, data, size);
2411
+
2412
+ UNUSED(buffer);
2413
+ }
2414
+
2415
+ static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
2416
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
2417
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
2418
+
2419
+ memcpy(data, (const char *)tensor->data + offset, size);
2420
+
2421
+ UNUSED(buffer);
2422
+ }
2423
+
2424
+ static void ggml_backend_metal_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
2425
+ ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
2426
+
2427
+ UNUSED(buffer);
2428
+ }
2429
+
2430
+ static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
2431
+ ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
2432
+
1652
2433
  UNUSED(buffer);
1653
2434
  }
1654
2435
 
1655
2436
  static struct ggml_backend_buffer_i metal_backend_buffer_i = {
1656
- /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
1657
- /* .get_base = */ ggml_backend_metal_buffer_get_base,
1658
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
1659
- /* .init_tensor = */ NULL, // no initialization required
1660
- /* .free_tensor = */ NULL, // no cleanup required
2437
+ /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
2438
+ /* .get_base = */ ggml_backend_metal_buffer_get_base,
2439
+ /* .init_tensor = */ NULL,
2440
+ /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
2441
+ /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
2442
+ /* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
2443
+ /* .cpy_tensor_to = */ ggml_backend_metal_buffer_cpy_tensor_to,
1661
2444
  };
1662
2445
 
1663
- static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) {
1664
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
2446
+ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
2447
+ struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
2448
+
2449
+ const size_t size_page = sysconf(_SC_PAGESIZE);
1665
2450
 
1666
- void * data = ggml_metal_host_malloc(size);
2451
+ size_t size_aligned = size;
2452
+ if ((size_aligned % size_page) != 0) {
2453
+ size_aligned += (size_page - (size_aligned % size_page));
2454
+ }
1667
2455
 
1668
- // TODO: set proper name of the buffers
1669
- ggml_metal_add_buffer(ctx, "backend", data, size, 0);
2456
+ ctx->data = ggml_metal_host_malloc(size);
2457
+ ctx->metal = [ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
2458
+ length:size_aligned
2459
+ options:MTLResourceStorageModeShared
2460
+ deallocator:nil];
1670
2461
 
1671
- return ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
2462
+ return ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
1672
2463
  }
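
newBufferWithBytesNoCopy wraps the host allocation in a shared-storage MTLBuffer without copying, which requires a page-multiple length (hence the round-up of size_aligned) and a page-aligned base pointer, assumed here to be guaranteed by ggml_metal_host_malloc. When the page size is a power of two, the same round-up can be written branch-free:

    #include <stddef.h>

    // round n up to the next multiple of page (page must be a power of two)
    static size_t align_up(size_t n, size_t page) {
        return (n + page - 1) & ~(page - 1);
    }
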
1673
2464
 
1674
- static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) {
2465
+ static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
1675
2466
  return 32;
1676
- UNUSED(backend);
2467
+ UNUSED(buft);
1677
2468
  }
1678
2469
 
1679
- static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
1680
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
1681
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
1682
-
1683
- memcpy((char *)tensor->data + offset, data, size);
2470
+ static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
2471
+ return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
1684
2472
 
1685
- UNUSED(backend);
2473
+ GGML_UNUSED(buft);
1686
2474
  }
1687
2475
 
1688
- static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
1689
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
1690
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
1691
-
1692
- memcpy(data, (const char *)tensor->data + offset, size);
2476
+ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
2477
+ static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
2478
+ /* .iface = */ {
2479
+ /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
2480
+ /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
2481
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
2482
+ /* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
2483
+ },
2484
+ /* .context = */ NULL,
2485
+ };
1693
2486
 
1694
- UNUSED(backend);
2487
+ return &ggml_backend_buffer_type_metal;
1695
2488
  }
1696
2489
 
1697
- static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
2490
+ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
2491
+ return "Metal";
2492
+
1698
2493
  UNUSED(backend);
1699
2494
  }
1700
2495
 
1701
- static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
1702
- ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
2496
+ static void ggml_backend_metal_free(ggml_backend_t backend) {
2497
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
2498
+ ggml_metal_free(ctx);
2499
+ free(backend);
2500
+ }
1703
2501
 
2502
+ static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
1704
2503
  UNUSED(backend);
1705
2504
  }
1706
2505
 
1707
- static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
1708
- ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
2506
+ static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
2507
+ return ggml_backend_metal_buffer_type();
1709
2508
 
1710
2509
  UNUSED(backend);
1711
2510
  }
@@ -1717,32 +2516,43 @@ static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml
1717
2516
  }
1718
2517
 
1719
2518
  static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
1720
- return true;
2519
+ return ggml_metal_supports_op(op);
2520
+
1721
2521
  UNUSED(backend);
1722
- UNUSED(op);
1723
2522
  }
1724
2523
 
1725
2524
  static struct ggml_backend_i metal_backend_i = {
1726
- /* .get_name = */ ggml_backend_metal_name,
1727
- /* .free = */ ggml_backend_metal_free,
1728
- /* .alloc_buffer = */ ggml_backend_metal_alloc_buffer,
1729
- /* .get_alignment = */ ggml_backend_metal_get_alignment,
1730
- /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
1731
- /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
1732
- /* .synchronize = */ ggml_backend_metal_synchronize,
1733
- /* .cpy_tensor_from = */ ggml_backend_metal_cpy_tensor_from,
1734
- /* .cpy_tensor_to = */ ggml_backend_metal_cpy_tensor_to,
1735
- /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
1736
- /* .graph_plan_free = */ NULL,
1737
- /* .graph_plan_compute = */ NULL,
1738
- /* .graph_compute = */ ggml_backend_metal_graph_compute,
1739
- /* .supports_op = */ ggml_backend_metal_supports_op,
2525
+ /* .get_name = */ ggml_backend_metal_name,
2526
+ /* .free = */ ggml_backend_metal_free,
2527
+ /* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
2528
+ /* .set_tensor_async = */ NULL,
2529
+ /* .get_tensor_async = */ NULL,
2530
+ /* .cpy_tensor_from_async = */ NULL,
2531
+ /* .cpy_tensor_to_async = */ NULL,
2532
+ /* .synchronize = */ ggml_backend_metal_synchronize,
2533
+ /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
2534
+ /* .graph_plan_free = */ NULL,
2535
+ /* .graph_plan_compute = */ NULL,
2536
+ /* .graph_compute = */ ggml_backend_metal_graph_compute,
2537
+ /* .supports_op = */ ggml_backend_metal_supports_op,
1740
2538
  };
1741
2539
 
2540
+ // TODO: make a common log callback for all backends in ggml-backend
2541
+ static void ggml_backend_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
2542
+ fprintf(stderr, "%s", msg);
2543
+
2544
+ UNUSED(level);
2545
+ UNUSED(user_data);
2546
+ }
2547
+
1742
2548
  ggml_backend_t ggml_backend_metal_init(void) {
1743
- struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
2549
+ ggml_metal_log_set_callback(ggml_backend_log_callback, NULL);
1744
2550
 
1745
- ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
2551
+ struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
2552
+
2553
+ if (ctx == NULL) {
2554
+ return NULL;
2555
+ }
1746
2556
 
1747
2557
  ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
1748
2558
 
@@ -1759,7 +2569,26 @@ bool ggml_backend_is_metal(ggml_backend_t backend) {
1759
2569
  }
1760
2570
 
1761
2571
  void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
2572
+ GGML_ASSERT(ggml_backend_is_metal(backend));
2573
+
1762
2574
  struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
1763
2575
 
1764
2576
  ggml_metal_set_n_cb(ctx, n_cb);
1765
2577
  }
2578
+
2579
+ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
2580
+ GGML_ASSERT(ggml_backend_is_metal(backend));
2581
+
2582
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
2583
+
2584
+ return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
2585
+ }
2586
+
2587
+ ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
2588
+
2589
+ ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
2590
+ return ggml_backend_metal_init();
2591
+
2592
+ GGML_UNUSED(params);
2593
+ GGML_UNUSED(user_data);
2594
+ }
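
ggml_backend_reg_metal_init is a thin adapter so the generic backend registry can construct the Metal backend; params and user_data are ignored. A hypothetical direct-use sketch of the API surface touched by this release (function names as they appear above; ggml_backend_free comes from ggml-backend):

    ggml_backend_t backend = ggml_backend_metal_init();
    if (backend != NULL) {
        ggml_backend_metal_set_n_cb(backend, 4); // split the graph across 4 command buffers

        if (!ggml_backend_metal_supports_family(backend, 7)) {
            // pre-Apple7 GPU: plan fallbacks (e.g. no simdgroup matrix multiply)
        }

        ggml_backend_buffer_type_t buft = ggml_backend_metal_buffer_type();
        (void) buft; // hand this to ggml-alloc when allocating tensors

        ggml_backend_free(backend);
    }
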