llama_cpp 0.9.5 → 0.10.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -62,11 +62,15 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
  GGML_METAL_DECL_KERNEL(mul);
  GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
+ GGML_METAL_DECL_KERNEL(div);
+ GGML_METAL_DECL_KERNEL(div_row);
  GGML_METAL_DECL_KERNEL(scale);
  GGML_METAL_DECL_KERNEL(scale_4);
- GGML_METAL_DECL_KERNEL(silu);
+ GGML_METAL_DECL_KERNEL(tanh);
  GGML_METAL_DECL_KERNEL(relu);
  GGML_METAL_DECL_KERNEL(gelu);
+ GGML_METAL_DECL_KERNEL(gelu_quick);
+ GGML_METAL_DECL_KERNEL(silu);
  GGML_METAL_DECL_KERNEL(soft_max);
  GGML_METAL_DECL_KERNEL(soft_max_4);
  GGML_METAL_DECL_KERNEL(diag_mask_inf);
@@ -84,6 +88,7 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(get_rows_q5_K);
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
  GGML_METAL_DECL_KERNEL(rms_norm);
+ GGML_METAL_DECL_KERNEL(group_norm);
  GGML_METAL_DECL_KERNEL(norm);
  GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
  GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
@@ -100,6 +105,21 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -112,15 +132,39 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
  GGML_METAL_DECL_KERNEL(rope_f32);
  GGML_METAL_DECL_KERNEL(rope_f16);
  GGML_METAL_DECL_KERNEL(alibi_f32);
  GGML_METAL_DECL_KERNEL(im2col_f16);
+ GGML_METAL_DECL_KERNEL(upscale_f32);
+ GGML_METAL_DECL_KERNEL(pad_f32);
+ GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
+ GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
+ GGML_METAL_DECL_KERNEL(leaky_relu_f32);
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+ GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
+ GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
+ GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
+ //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
+ //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
  GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+ GGML_METAL_DECL_KERNEL(cpy_f16_f32);
  GGML_METAL_DECL_KERNEL(concat);
  GGML_METAL_DECL_KERNEL(sqr);
+ GGML_METAL_DECL_KERNEL(sum_rows);

  #undef GGML_METAL_DECL_KERNEL
  };
@@ -155,6 +199,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
  ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
  } else {
  char* buffer2 = malloc(len+1);
+ va_end(args);
+ va_start(args, format);
  vsnprintf(buffer2, len+1, format, args);
  buffer2[len] = 0;
  ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
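Note: the two added lines fix real undefined behavior. Earlier in ggml_metal_log the same va_list is consumed by a sizing vsnprintf to compute len, which leaves it indeterminate, so it must be re-armed with va_end/va_start before the second formatting pass. A minimal standalone C sketch of the same pattern (the function and message below are illustrative, not from the diff):

    #include <stdarg.h>
    #include <stdio.h>
    #include <stdlib.h>

    // Format into a freshly allocated buffer; shows why the va_list must be
    // re-initialized between the sizing pass and the writing pass.
    static char * format_alloc(const char * fmt, ...) {
        va_list args;

        va_start(args, fmt);
        const int len = vsnprintf(NULL, 0, fmt, args); // sizing pass consumes args
        va_end(args);                                  // args is now indeterminate

        char * buf = malloc(len + 1);
        if (buf == NULL) {
            return NULL;
        }

        va_start(args, fmt);                           // re-arm before reuse
        vsnprintf(buf, len + 1, fmt, args);
        va_end(args);

        return buf;
    }

    int main(void) {
        char * msg = format_alloc("allocated %d buffers of %zu bytes", 3, (size_t) 4096);
        if (msg) {
            puts(msg);
            free(msg);
        }
        return 0;
    }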
@@ -164,12 +210,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
  }
  }

-
-
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_LOG_INFO("%s: allocating\n", __func__);

- id <MTLDevice> device;
+ id<MTLDevice> device;
  NSString * s;

  #if TARGET_OS_OSX
@@ -215,6 +259,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {

  NSString * sourcePath;
  NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
+
+ GGML_METAL_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, ggmlMetalPathResources ? [ggmlMetalPathResources UTF8String] : "nil");
+
  if (ggmlMetalPathResources) {
  sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
  } else {
@@ -245,6 +292,29 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  }
  }

+ #if TARGET_OS_OSX
+ // print MTL GPU family:
+ GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
+
+ // determine max supported GPU family
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+ break;
+ }
+ }
+
+ GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+ GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+ if (ctx->device.maxTransferRate != 0) {
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
+ } else {
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
+ }
+ #endif
+
  // load kernels
  {
  NSError * error = nil;
@@ -266,11 +336,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(add_row);
  GGML_METAL_ADD_KERNEL(mul);
  GGML_METAL_ADD_KERNEL(mul_row);
+ GGML_METAL_ADD_KERNEL(div);
+ GGML_METAL_ADD_KERNEL(div_row);
  GGML_METAL_ADD_KERNEL(scale);
  GGML_METAL_ADD_KERNEL(scale_4);
- GGML_METAL_ADD_KERNEL(silu);
+ GGML_METAL_ADD_KERNEL(tanh);
  GGML_METAL_ADD_KERNEL(relu);
  GGML_METAL_ADD_KERNEL(gelu);
+ GGML_METAL_ADD_KERNEL(gelu_quick);
+ GGML_METAL_ADD_KERNEL(silu);
  GGML_METAL_ADD_KERNEL(soft_max);
  GGML_METAL_ADD_KERNEL(soft_max_4);
  GGML_METAL_ADD_KERNEL(diag_mask_inf);
@@ -288,6 +362,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(get_rows_q5_K);
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
  GGML_METAL_ADD_KERNEL(rms_norm);
+ GGML_METAL_ADD_KERNEL(group_norm);
  GGML_METAL_ADD_KERNEL(norm);
  GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
  GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
@@ -304,6 +379,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
  GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -317,43 +407,44 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
  }
  GGML_METAL_ADD_KERNEL(rope_f32);
  GGML_METAL_ADD_KERNEL(rope_f16);
  GGML_METAL_ADD_KERNEL(alibi_f32);
  GGML_METAL_ADD_KERNEL(im2col_f16);
+ GGML_METAL_ADD_KERNEL(upscale_f32);
+ GGML_METAL_ADD_KERNEL(pad_f32);
+ GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
+ GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
+ GGML_METAL_ADD_KERNEL(leaky_relu_f32);
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+ GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
+ GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
+ GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
+ //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
+ //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
  GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+ GGML_METAL_ADD_KERNEL(cpy_f16_f32);
  GGML_METAL_ADD_KERNEL(concat);
  GGML_METAL_ADD_KERNEL(sqr);
+ GGML_METAL_ADD_KERNEL(sum_rows);

  #undef GGML_METAL_ADD_KERNEL
  }

- #if TARGET_OS_OSX
- // print MTL GPU family:
- GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
-
- // determine max supported GPU family
- // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
- // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
- for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
- if ([ctx->device supportsFamily:i]) {
- GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
- break;
- }
- }
-
- GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
- GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MiB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
- if (ctx->device.maxTransferRate != 0) {
- GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MiB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
- } else {
- GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
- }
- #endif
-
  return ctx;
  }

@@ -367,11 +458,15 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(add_row);
  GGML_METAL_DEL_KERNEL(mul);
  GGML_METAL_DEL_KERNEL(mul_row);
+ GGML_METAL_DEL_KERNEL(div);
+ GGML_METAL_DEL_KERNEL(div_row);
  GGML_METAL_DEL_KERNEL(scale);
  GGML_METAL_DEL_KERNEL(scale_4);
- GGML_METAL_DEL_KERNEL(silu);
+ GGML_METAL_DEL_KERNEL(tanh);
  GGML_METAL_DEL_KERNEL(relu);
  GGML_METAL_DEL_KERNEL(gelu);
+ GGML_METAL_DEL_KERNEL(gelu_quick);
+ GGML_METAL_DEL_KERNEL(silu);
  GGML_METAL_DEL_KERNEL(soft_max);
  GGML_METAL_DEL_KERNEL(soft_max_4);
  GGML_METAL_DEL_KERNEL(diag_mask_inf);
@@ -389,6 +484,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(get_rows_q5_K);
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
  GGML_METAL_DEL_KERNEL(rms_norm);
+ GGML_METAL_DEL_KERNEL(group_norm);
  GGML_METAL_DEL_KERNEL(norm);
  GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
  GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
@@ -405,6 +501,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
  GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -418,16 +529,40 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
  }
  GGML_METAL_DEL_KERNEL(rope_f32);
  GGML_METAL_DEL_KERNEL(rope_f16);
  GGML_METAL_DEL_KERNEL(alibi_f32);
  GGML_METAL_DEL_KERNEL(im2col_f16);
+ GGML_METAL_DEL_KERNEL(upscale_f32);
+ GGML_METAL_DEL_KERNEL(pad_f32);
+ GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
+ GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
+ GGML_METAL_DEL_KERNEL(leaky_relu_f32);
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+ GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
+ GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
+ GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
+ //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
+ //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
  GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+ GGML_METAL_DEL_KERNEL(cpy_f16_f32);
  GGML_METAL_DEL_KERNEL(concat);
  GGML_METAL_DEL_KERNEL(sqr);
+ GGML_METAL_DEL_KERNEL(sum_rows);

  #undef GGML_METAL_DEL_KERNEL

@@ -471,6 +606,13 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
  return ctx->concur_list;
  }

+ // temporarily defined here for compatibility between ggml-backend and the old API
+ struct ggml_backend_metal_buffer_context {
+ void * data;
+
+ id<MTLBuffer> metal;
+ };
+
  // finds the Metal buffer that contains the tensor data on the GPU device
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
  // Metal buffer based on the host memory pointer
@@ -480,8 +622,17 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru

  const int64_t tsize = ggml_nbytes(t);

- if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
- ctx = t->buffer->backend->context;
+ // compatibility with ggml-backend
+ if (t->buffer && t->buffer->buft == ggml_backend_metal_buffer_type()) {
+ struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) t->buffer->context;
+
+ const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
+
+ GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
+
+ *offs = (size_t) ioffs;
+
+ return buf_ctx->metal;
  }

  // find the view that contains the tensor fully
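Note: the replacement above resolves a tensor's host pointer to a (Metal buffer, byte offset) pair by subtracting the buffer's host base address, with a bounds assert that the tensor lies fully inside the mapping. A minimal C sketch of just that arithmetic, assuming a hypothetical buffer descriptor (the struct and names below are illustrative, not from the diff):

    #include <assert.h>
    #include <stddef.h>
    #include <stdint.h>

    // Illustrative stand-in for a device buffer that mirrors a host allocation.
    struct host_mapped_buffer {
        void * base;   // host base address of the mapping
        size_t size;   // size of the mapping in bytes
    };

    // Resolve a host pointer inside the mapping to a byte offset, the same
    // arithmetic the compatibility path above performs.
    static size_t resolve_offset(const struct host_mapped_buffer * buf,
                                 const void * ptr, size_t nbytes) {
        const int64_t ioffs = (int64_t) ((const char *) ptr - (const char *) buf->base);

        // the tensor must lie fully inside the buffer
        assert(ioffs >= 0 && (size_t) ioffs + nbytes <= buf->size);

        return (size_t) ioffs;
    }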
@@ -706,6 +857,83 @@ void ggml_metal_graph_find_concurrency(
  }
  }

+ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
+ switch (op->op) {
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(op)) {
+ case GGML_UNARY_OP_TANH:
+ case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_GELU_QUICK:
+ case GGML_UNARY_OP_SILU:
+ return true;
+ default:
+ return false;
+ }
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_CONCAT:
+ case GGML_OP_ADD:
+ case GGML_OP_ACC:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_SCALE:
+ case GGML_OP_SQR:
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_GROUP_NORM:
+ case GGML_OP_NORM:
+ case GGML_OP_ALIBI:
+ case GGML_OP_ROPE:
+ case GGML_OP_IM2COL:
+ case GGML_OP_UPSCALE:
+ case GGML_OP_PAD:
+ case GGML_OP_ARGSORT:
+ case GGML_OP_LEAKY_RELU:
+ case GGML_OP_MUL_MAT:
+ case GGML_OP_MUL_MAT_ID:
+ return true;
+ case GGML_OP_CPY:
+ case GGML_OP_DUP:
+ case GGML_OP_CONT:
+ {
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F32:
+ switch (op->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ return true;
+ default:
+ return false;
+ }
+ case GGML_TYPE_F16:
+ switch (op->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ return true;
+ default:
+ return false;
+ }
+ default:
+ return false;
+ };
+ }
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_GET_ROWS:
+ {
+ return op->ne[3] == 1;
+ }
+ default:
+ return false;
+ }
+ }
  void ggml_metal_graph_compute(
  struct ggml_metal_context * ctx,
  struct ggml_cgraph * gf) {
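Note: ggml_metal_supports_op centralizes the Metal backend's op/type coverage in one predicate. In this diff it is only used to assert inside graph execution (next hunk); a scheduler could also walk the graph up front with such a predicate and report or reroute unsupported nodes. A hedged C sketch of that idea with stand-in types (nothing below is part of the diff):

    #include <stdbool.h>
    #include <stdio.h>

    // Minimal stand-ins so the sketch compiles on its own; in ggml these would
    // be the real tensor/graph types and the predicate added above.
    struct node { const char * name; bool metal_ok; };

    static bool supports_op(const struct node * n) { return n->metal_ok; }

    // Walk a graph once before dispatch and report anything the GPU path
    // cannot run, instead of asserting mid-execution.
    static bool graph_fully_supported(const struct node * nodes, int n_nodes) {
        bool ok = true;
        for (int i = 0; i < n_nodes; ++i) {
            if (!supports_op(&nodes[i])) {
                fprintf(stderr, "unsupported op: %s\n", nodes[i].name);
                ok = false;
            }
        }
        return ok;
    }

    int main(void) {
        const struct node g[] = { { "mul_mat", true }, { "conv_2d", false } };
        return graph_fully_supported(g, 2) ? 0 : 1;
    }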
@@ -776,6 +1004,11 @@ void ggml_metal_graph_compute(
  } break;
  }

+ if (!ggml_metal_supports_op(dst)) {
+ GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
+ GGML_ASSERT(!"unsupported op");
+ }
+
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -868,25 +1101,42 @@ void ggml_metal_graph_compute(
  [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
  {
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
+ const size_t offs = 0;

  bool bcast_row = false;

  int64_t nb = ne00;

- if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
  // src1 is a row
  GGML_ASSERT(ne11 == 1);

  nb = ne00 / 4;
- [encoder setComputePipelineState:ctx->pipeline_add_row];
+ switch (dst->op) {
+ case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
+ case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
+ case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
+ default: GGML_ASSERT(false);
+ }

  bcast_row = true;
  } else {
- [encoder setComputePipelineState:ctx->pipeline_add];
+ switch (dst->op) {
+ case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
+ case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
+ case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
+ default: GGML_ASSERT(false);
+ }
  }
+
+ [encoder setComputePipelineState:pipeline];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
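Note: ADD, MUL and DIV now share one encoding path; the row-broadcast fast path additionally requires src1 to be contiguous and both row sizes to be multiples of 4, because the *_row kernels consume float4 vectors. A small C sketch of the predicate and the resulting dispatch size (names illustrative):

    #include <stdbool.h>
    #include <stdint.h>

    // Decide between the float4 row-broadcast kernel and the generic kernel,
    // mirroring the condition in the ADD/MUL/DIV case above.
    static bool use_row_kernel(int64_t src1_nelements, int64_t ne10,
                               bool src1_contiguous, int64_t ne00) {
        return src1_nelements == ne10 && src1_contiguous &&
               ne00 % 4 == 0 && ne10 % 4 == 0;
    }

    // With the row kernel each thread handles one float4, so the grid is
    // nelements(dst)/4 threadgroups of a single thread.
    static int64_t row_kernel_groups(int64_t dst_nelements) {
        return dst_nelements / 4;
    }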
@@ -914,42 +1164,98 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
  [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
- [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:28];

  if (bcast_row) {
  const int64_t n = ggml_nelements(dst)/4;

  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } else {
- const int nth = MIN(1024, ne0);
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);

  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  }
  } break;
- case GGML_OP_MUL:
+ case GGML_OP_ACC:
  {
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ GGML_ASSERT(dstt == GGML_TYPE_F32);
+
  GGML_ASSERT(ggml_is_contiguous(src0));
  GGML_ASSERT(ggml_is_contiguous(src1));

- // utilize float4
- GGML_ASSERT(ne00 % 4 == 0);
- const int64_t nb = ne00/4;
+ const size_t pnb1 = ((int32_t *) dst->op_params)[0];
+ const size_t pnb2 = ((int32_t *) dst->op_params)[1];
+ const size_t pnb3 = ((int32_t *) dst->op_params)[2];
+ const size_t offs = ((int32_t *) dst->op_params)[3];
+
+ const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+ if (!inplace) {
+ // run a separate kernel to cpy src->dst
+ // not sure how to avoid this
+ // TODO: make a simpler cpy_bytes kernel
+
+ const int nth = MIN(1024, ne00);
+
+ [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];

- if (ggml_nelements(src1) == ne10) {
- // src1 is a row
- GGML_ASSERT(ne11 == 1);
- [encoder setComputePipelineState:ctx->pipeline_mul_row];
- } else {
- [encoder setComputePipelineState:ctx->pipeline_mul];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  }
+
+ [encoder setComputePipelineState:ctx->pipeline_add];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];

- const int64_t n = ggml_nelements(dst)/4;
+ const int nth = MIN(1024, ne0);

- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case GGML_OP_SCALE:
  {
@@ -974,16 +1280,15 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_UNARY:
  switch (ggml_get_unary_op(gf->nodes[i])) {
- case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_TANH:
  {
- [encoder setComputePipelineState:ctx->pipeline_silu];
+ [encoder setComputePipelineState:ctx->pipeline_tanh];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

  const int64_t n = ggml_nelements(dst);
- GGML_ASSERT(n % 4 == 0);

- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  case GGML_UNARY_OP_RELU:
  {
@@ -1004,6 +1309,28 @@ void ggml_metal_graph_compute(
  const int64_t n = ggml_nelements(dst);
  GGML_ASSERT(n % 4 == 0);

+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_GELU_QUICK:
+ {
+ [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+ GGML_ASSERT(n % 4 == 0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_SILU:
+ {
+ [encoder setComputePipelineState:ctx->pipeline_silu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+ GGML_ASSERT(n % 4 == 0);
+
  [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  default:
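Note: the new gelu_quick pipeline corresponds to ggml's cheaper sigmoid-based GELU approximation, x * sigmoid(1.702 * x); the kernel body lives in ggml-metal.metal, not in this file. A plain-C reference of the math for comparison with exact GELU (a sketch, not the Metal kernel; 1.702 is the standard approximation coefficient):

    #include <math.h>

    // Exact GELU (via erff) and the "quick" sigmoid approximation.
    static float gelu(float x) {
        return 0.5f * x * (1.0f + erff(x / sqrtf(2.0f)));
    }

    static float gelu_quick(float x) {
        return x / (1.0f + expf(-1.702f * x));
    }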
@@ -1023,6 +1350,40 @@ void ggml_metal_graph_compute(
  const int64_t n = ggml_nelements(dst);
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
+ case GGML_OP_SUM_ROWS:
+ {
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+
+ [encoder setComputePipelineState:ctx->pipeline_sum_rows];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
  case GGML_OP_SOFT_MAX:
  {
  int nth = 32; // SIMD width
@@ -1042,7 +1403,11 @@ void ggml_metal_graph_compute(
  const float scale = ((float *) dst->op_params)[0];

  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ if (id_src1) {
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ }
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
@@ -1077,9 +1442,13 @@ void ggml_metal_graph_compute(
  case GGML_OP_MUL_MAT:
  {
  GGML_ASSERT(ne00 == ne10);

- GGML_ASSERT(ne03 == ne13);
-
- const uint gqa = ne12/ne02;
+ // TODO: assert that dim2 and dim3 are contiguous
+ GGML_ASSERT(ne12 % ne02 == 0);
+ GGML_ASSERT(ne13 % ne03 == 0);
+
+ const uint r2 = ne12/ne02;
+ const uint r3 = ne13/ne03;

  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
  // to the matrix-vector kernel
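Note: replacing the single gqa factor with r2 = ne12/ne02 and r3 = ne13/ne03 generalizes grouped-query-attention style broadcasting to both batch dimensions: several src1 slices share one src0 slice along dims 2 and 3, so the dispatches later in this hunk use ne12*ne13 in the grid. The index mapping the kernels rely on amounts to this (a sketch of the arithmetic, not the Metal kernel):

    #include <stdint.h>

    // Map a src1/dst batch coordinate (i12, i13) to the src0 slice that
    // serves it when src0 is broadcast with ratios r2 and r3.
    static void src0_slice(int64_t i12, int64_t i13, uint32_t r2, uint32_t r3,
                           int64_t * i02, int64_t * i03) {
        *i02 = i12 / r2; // r2 = ne12/ne02 consecutive batches share a slice
        *i03 = i13 / r3; // r3 = ne13/ne03
    }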
@@ -1114,7 +1483,7 @@ void ggml_metal_graph_compute(
  !ggml_is_transposed(src1) &&
  src1t == GGML_TYPE_F32 &&
  ne00 % 32 == 0 && ne00 >= 64 &&
- ne11 > ne11_mm_min) {
+ (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
  //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
  switch (src0->type) {
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
@@ -1144,9 +1513,10 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  } else {
  int nth0 = 32;
  int nth1 = 1;
@@ -1182,90 +1552,60 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_TYPE_Q4_0:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
  } break;
  case GGML_TYPE_Q4_1:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
  } break;
  case GGML_TYPE_Q5_0:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
  } break;
  case GGML_TYPE_Q5_1:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
  } break;
  case GGML_TYPE_Q8_0:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
  } break;
  case GGML_TYPE_Q2_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
  } break;
  case GGML_TYPE_Q3_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
  } break;
  case GGML_TYPE_Q4_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 4; //1;
  nth1 = 8; //32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
  } break;
  case GGML_TYPE_Q5_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
  } break;
  case GGML_TYPE_Q6_K:
  {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
@@ -1294,31 +1634,281 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];

  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
  src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
  src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q3_K) {
  #ifdef GGML_QKK_64
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #else
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #endif
  }
  else if (src0t == GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ } else {
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ }
+ } break;
+ case GGML_OP_MUL_MAT_ID:
+ {
+ //GGML_ASSERT(ne00 == ne10);
+ //GGML_ASSERT(ne03 == ne13);
+
+ GGML_ASSERT(src0t == GGML_TYPE_I32);
+
+ const int n_as = ((int32_t *) dst->op_params)[1];
+
+ // TODO: make this more general
+ GGML_ASSERT(n_as <= 8);
+
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2];
+
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
+ const int64_t ne22 = src2 ? src2->ne[2] : 0;
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
+
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
+
+ const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+
+ GGML_ASSERT(!ggml_is_transposed(src2));
+ GGML_ASSERT(!ggml_is_transposed(src1));
+
+ GGML_ASSERT(ne20 % 32 == 0);
+ // !!!!!!!!! TODO: this assert is probably required but not sure!
+ //GGML_ASSERT(ne20 >= 64);
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+ const uint r2 = ne12/ne22;
+ const uint r3 = ne13/ne23;
+
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+ // to the matrix-vector kernel
+ int ne11_mm_min = 1;
+
+ const int idx = ((int32_t *) dst->op_params)[0];
+
+ // batch size
+ GGML_ASSERT(ne01 == ne11);
+
+ const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
+
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+ // !!!
+ // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+ // indirect matrix multiplication
+ // !!!
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
+ switch (src2->type) {
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
+ case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
+ case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
+ default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
+ }
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
+ // TODO: how to make this an array? read Metal docs
+ for (int j = 0; j < n_as; ++j) {
+ struct ggml_tensor * src_cur = dst->src[2 + j];
+
+ size_t offs_src_cur = 0;
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
+ }
+
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+
+ // TODO: processing one row at a time (ne11 -> 1) is not efficient
+ [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ } else {
+ int nth0 = 32;
+ int nth1 = 1;
+ int nrows = 1;
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+ // use custom matrix x vector kernel
+ switch (src2t) {
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
+ } break;
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ nth0 = 32;
+ nth1 = 1;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
+ } break;
+ case GGML_TYPE_Q4_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
+ } break;
+ case GGML_TYPE_Q4_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
+ } break;
+ case GGML_TYPE_Q5_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
+ } break;
+ case GGML_TYPE_Q8_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
+ } break;
+ case GGML_TYPE_Q2_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
+ } break;
+ case GGML_TYPE_Q3_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
+ } break;
+ case GGML_TYPE_Q4_K:
+ {
+ nth0 = 4; //1;
+ nth1 = 8; //32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
+ } break;
+ case GGML_TYPE_Q5_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
+ } break;
+ case GGML_TYPE_Q6_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
+ } break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+ GGML_ASSERT(false && "not implemented");
+ }
+ };
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
+ [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
+ // TODO: how to make this an array? read Metal docs
+ for (int j = 0; j < n_as; ++j) {
+ struct ggml_tensor * src_cur = dst->src[2 + j];
+
+ size_t offs_src_cur = 0;
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
+ }
+
+ if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
+ src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
+ src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q3_K) {
+ #ifdef GGML_QKK_64
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ #else
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ #endif
+ }
+ else if (src2t == GGML_TYPE_Q5_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q6_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  } else {
- int64_t ny = (ne11 + nrows - 1)/nrows;
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ const int64_t ny = (_ne1 + nrows - 1)/nrows;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  }
  } break;
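Note: GGML_OP_MUL_MAT_ID is the mixture-of-experts indirection added in this release: src0 holds I32 expert ids, op_params carry the selected row (idx) and the expert count (n_as, capped at 8 for now), and the expert weight matrices are bound as extra buffer arguments (indices 19+j or 23+j above). Conceptually the op computes dst = W[ids[idx]] · src1. A CPU sketch of just that indirection (names and layout here are illustrative, not the Metal kernel):

    #include <stdint.h>

    // One conceptual mixture-of-experts matvec:
    // pick expert e = ids[idx], then dst = W_e * x.
    static void mul_mat_id_row(const float * const * experts, // n_as row-major [rows x cols] matrices
                               const int32_t * ids, int idx,
                               const float * x, float * dst,
                               int rows, int cols) {
        const float * W = experts[ids[idx]];
        for (int r = 0; r < rows; ++r) {
            float acc = 0.0f;
            for (int c = 0; c < cols; ++c) {
                acc += W[r * cols + c] * x[c];
            }
            dst[r] = acc;
        }
    }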
@@ -1340,16 +1930,19 @@ void ggml_metal_graph_compute(
  default: GGML_ASSERT(false && "not implemented");
  }

- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
-
- const int64_t n = ggml_nelements(src1);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
  } break;
  case GGML_OP_RMS_NORM:
  {
@@ -1376,6 +1969,38 @@ void ggml_metal_graph_compute(

  [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
+ case GGML_OP_GROUP_NORM:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+
+ //float eps;
+ //memcpy(&eps, dst->op_params, sizeof(float));
+
+ const float eps = 1e-6f; // TODO: temporarily hardcoded
+
+ const int32_t n_groups = ((int32_t *) dst->op_params)[0];
+
+ int nth = 32; // SIMD width
+
+ //while (nth < ne00/4 && nth < 1024) {
+ // nth *= 2;
+ //}
+
+ [encoder setComputePipelineState:ctx->pipeline_group_norm];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:9];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
  case GGML_OP_NORM:
  {
  float eps;
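Note: group_norm splits the channel dimension into n_groups groups and normalizes each group to zero mean and unit variance; eps is hardcoded to 1e-6 here (the TODO notes it should come from op_params). A plain-C reference of the per-group math, as a sketch rather than the Metal kernel:

    #include <math.h>
    #include <stddef.h>

    // Normalize `n` values in place to zero mean / unit variance,
    // the per-group computation behind the group_norm kernel above.
    static void normalize_group(float * v, size_t n, float eps) {
        float mean = 0.0f;
        for (size_t i = 0; i < n; ++i) mean += v[i];
        mean /= (float) n;

        float var = 0.0f;
        for (size_t i = 0; i < n; ++i) {
            const float d = v[i] - mean;
            var += d * d;
        }
        var /= (float) n;

        const float scale = 1.0f / sqrtf(var + eps);
        for (size_t i = 0; i < n; ++i) v[i] = (v[i] - mean) * scale;
    }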
@@ -1545,18 +2170,123 @@ void ggml_metal_graph_compute(
1545
2170
 
1546
2171
  [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
1547
2172
  } break;
2173
+ case GGML_OP_UPSCALE:
2174
+ {
2175
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2176
+
2177
+ const int sf = dst->op_params[0];
2178
+
2179
+ [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
2180
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2181
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2182
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
2183
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
2184
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
2185
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
2186
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
2187
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
2188
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
2189
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
2190
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
2191
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
2192
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
2193
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
2194
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
2195
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2196
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2197
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2198
+ [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
2199
+
2200
+ const int nth = MIN(1024, ne0);
2201
+
2202
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2203
+ } break;
+ case GGML_OP_PAD:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ [encoder setComputePipelineState:ctx->pipeline_pad_f32];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
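GGML_OP_PAD uses the same dispatch shape: the destination extents drive the grid and each thread writes one output element. A per-plane sketch, assuming zero fill outside the source extents as in ggml's pad operator:

    // Copy a w x h float plane into the top-left corner of a W x H plane
    // (W >= w, H >= h), zero-filling the remainder. Reference sketch only.
    static void pad_ref(const float * src, float * dst, int w, int h, int W, int H) {
        for (int y = 0; y < H; y++) {
            for (int x = 0; x < W; x++) {
                dst[y*W + x] = (x < w && y < h) ? src[y*w + x] : 0.0f;
            }
        }
    }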
+ case GGML_OP_ARGSORT:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+ const int nrows = ggml_nrows(src0);
+
+ enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+ switch (order) {
+ case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
+ case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
+ default: GGML_ASSERT(false);
+ };
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
+ } break;
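GGML_OP_ARGSORT assigns one threadgroup of ne00 threads to each row, so the sortable row length is bounded by the device's maximum threads per threadgroup. A sequential reference for the result the kernel produces, namely the index permutation that orders a row:

    #include <stdint.h>

    // O(n^2) selection-sort argsort, for clarity; the GPU kernel sorts each
    // row in parallel within a threadgroup. Ties may be ordered differently.
    static void argsort_ref(const float * row, int32_t * idx, int n, int ascending) {
        for (int i = 0; i < n; i++) idx[i] = i;
        for (int i = 0; i < n - 1; i++) {
            int best = i;
            for (int j = i + 1; j < n; j++) {
                const int lt = row[idx[j]] < row[idx[best]];
                if (ascending ? lt : !lt) best = j;
            }
            const int32_t t = idx[i]; idx[i] = idx[best]; idx[best] = t;
        }
    }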
+ case GGML_OP_LEAKY_RELU:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ float slope;
+ memcpy(&slope, dst->op_params, sizeof(float));
+
+ [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
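GGML_OP_LEAKY_RELU reads the slope from op_params and launches one single-thread threadgroup per element, which is simple if not maximally efficient. The per-element function is just:

    // Element-wise leaky ReLU: identity for positive inputs, scaled otherwise.
    static inline float leaky_relu(float x, float slope) {
        return x > 0.0f ? x : slope*x;
    }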
  case GGML_OP_DUP:
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  {
- const int nth = MIN(1024, ne00);
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+
+ int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
 
  switch (src0t) {
  case GGML_TYPE_F32:
  {
+ GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
+
  switch (dstt) {
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
+ //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
+ //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
  default: GGML_ASSERT(false && "not implemented");
  };
  } break;
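The copy path can now quantize on the fly (f32 to Q8_0/Q4_0/Q4_1), so the threadgroup width is computed in blocks rather than elements: each thread writes one whole quantization block. A sketch of the thread-count choice, where qk stands for the block size of the type (32 for the Q4/Q8 formats, 1 for F32/F16):

    #include <stdint.h>

    // One thread per block of the row, capped at Metal's usual 1024-thread
    // threadgroup limit. ne00 must be a multiple of qk (asserted above).
    static int cpy_nth(int64_t ne00, int qk) {
        const int64_t nblk = ne00/qk;
        return nblk < 1024 ? (int) nblk : 1024;
    }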
@@ -1564,7 +2294,7 @@ void ggml_metal_graph_compute(
  {
  switch (dstt) {
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
- case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
  default: GGML_ASSERT(false && "not implemented");
  };
  } break;
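The previously unimplemented cpy_f16_f32 path now gets a pipeline. The Metal shader can do this conversion natively; a scalar C equivalent of the per-element conversion, for reference:

    #include <stdint.h>
    #include <string.h>

    // IEEE-754 binary16 -> binary32, covering normals, subnormals, zeros,
    // infinities and NaNs.
    static float half_to_float(uint16_t h) {
        const uint32_t sign = (uint32_t) (h & 0x8000) << 16;
        int32_t  exp  = (h >> 10) & 0x1f;
        uint32_t mant =  h & 0x3ff;
        uint32_t bits;
        if (exp == 0x1f) {
            bits = sign | 0x7f800000u | (mant << 13); // inf / NaN
        } else if (exp == 0 && mant == 0) {
            bits = sign;                              // signed zero
        } else {
            if (exp == 0) {                           // subnormal: renormalize
                exp = 1;
                while ((mant & 0x400) == 0) { mant <<= 1; exp--; }
                mant &= 0x3ff;
            }
            bits = sign | ((uint32_t) (exp - 15 + 127) << 23) | (mant << 13);
        }
        float f;
        memcpy(&f, &bits, sizeof(f));
        return f;
    }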
@@ -1631,81 +2361,150 @@ void ggml_metal_graph_compute(
 
  // backend interface
 
- static const char * ggml_backend_metal_name(ggml_backend_t backend) {
- return "Metal";
+ static id<MTLDevice> g_backend_device = nil;
+ static int g_backend_device_ref_count = 0;
 
- UNUSED(backend);
+ static id<MTLDevice> ggml_backend_metal_get_device(void) {
+ if (g_backend_device == nil) {
+ g_backend_device = MTLCreateSystemDefaultDevice();
+ }
+
+ g_backend_device_ref_count++;
+
+ return g_backend_device;
  }
 
- static void ggml_backend_metal_free(ggml_backend_t backend) {
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
- ggml_metal_free(ctx);
- free(backend);
+ static void ggml_backend_metal_free_device(void) {
+ assert(g_backend_device_ref_count > 0);
+
+ g_backend_device_ref_count--;
+
+ if (g_backend_device_ref_count == 0) {
+ [g_backend_device release];
+ g_backend_device = nil;
+ }
  }
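The MTLDevice becomes a process-wide, manually reference-counted singleton: every buffer allocation acquires it and every buffer free releases it, so the device is created lazily and destroyed with the last Metal buffer. The same pattern in plain C, with a malloc'd stand-in for the Objective-C device object:

    #include <assert.h>
    #include <stdlib.h>

    // Lazily-created, reference-counted singleton (illustrative stand-in).
    static void * g_device          = NULL;
    static int    g_device_refcount = 0;

    static void * device_acquire(void) {
        if (g_device == NULL) {
            g_device = malloc(1); // created on first use
        }
        g_device_refcount++;
        return g_device;
    }

    static void device_release(void) {
        assert(g_device_refcount > 0);
        if (--g_device_refcount == 0) {
            free(g_device); // destroyed with the last reference
            g_device = NULL;
        }
    }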
 
  static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
- return (void *)buffer->context;
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+ return ctx->data;
  }
 
  static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
- free(buffer->context);
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+ [ctx->metal release];
+ ggml_backend_metal_free_device();
+
+ free(ctx->data);
+ free(ctx);
+
+ UNUSED(buffer);
+ }
+
+ static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy((char *)tensor->data + offset, data, size);
+
+ UNUSED(buffer);
+ }
+
+ static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy(data, (const char *)tensor->data + offset, size);
+
+ UNUSED(buffer);
+ }
+
+ static void ggml_backend_metal_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+ ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+
+ UNUSED(buffer);
+ }
+
+ static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
+ ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
+
  UNUSED(buffer);
  }
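Buffer contexts now carry both the host allocation and the MTLBuffer that wraps it, and set/get_tensor reduce to plain memcpy because the storage mode is shared between CPU and GPU. A plausible shape for the context struct, inferred from the accesses above (ctx->data, ctx->metal); the actual definition lives elsewhere in the file:

    // Inferred from usage; not the verbatim definition.
    struct ggml_backend_metal_buffer_context {
        void *        data;  // page-aligned host allocation (shared storage)
        id<MTLBuffer> metal; // Metal buffer wrapping the same memory, no copy
    };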
 
  static struct ggml_backend_buffer_i metal_backend_buffer_i = {
- /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
- /* .get_base = */ ggml_backend_metal_buffer_get_base,
- /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
- /* .init_tensor = */ NULL, // no initialization required
- /* .free_tensor = */ NULL, // no cleanup required
+ /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_metal_buffer_get_base,
+ /* .init_tensor = */ NULL,
+ /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
+ /* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
+ /* .cpy_tensor_to = */ ggml_backend_metal_buffer_cpy_tensor_to,
  };
 
- static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) {
- struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
+
+ const size_t size_page = sysconf(_SC_PAGESIZE);
 
- void * data = ggml_metal_host_malloc(size);
+ size_t size_aligned = size;
+ if ((size_aligned % size_page) != 0) {
+ size_aligned += (size_page - (size_aligned % size_page));
+ }
 
- // TODO: set proper name of the buffers
- ggml_metal_add_buffer(ctx, "backend", data, size, 0);
+ ctx->data = ggml_metal_host_malloc(size);
+ ctx->metal = [ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
+ length:size_aligned
+ options:MTLResourceStorageModeShared
+ deallocator:nil];
 
- return ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
+ return ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
  }
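newBufferWithBytesNoCopy requires a page-aligned allocation whose length is a multiple of the page size, hence the round-up before wrapping the host pointer. The rounding step in isolation:

    #include <unistd.h>

    // Round size up to the next multiple of the system page size.
    static size_t align_to_page(size_t size) {
        const size_t page = (size_t) sysconf(_SC_PAGESIZE);
        return (size % page) == 0 ? size : size + (page - size % page);
    }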
 
- static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) {
+ static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
  return 32;
- UNUSED(backend);
+ UNUSED(buft);
  }
 
- static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
- memcpy((char *)tensor->data + offset, data, size);
+ static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+ return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
 
- UNUSED(backend);
+ GGML_UNUSED(buft);
  }
 
- static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
- GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
- memcpy(data, (const char *)tensor->data + offset, size);
+ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
+ static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
+ /* .iface = */ {
+ /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
+ /* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
+ },
+ /* .context = */ NULL,
+ };
 
- UNUSED(backend);
+ return &ggml_backend_buffer_type_metal;
  }
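ggml_backend_metal_buffer_type returns a static singleton, and supports_backend admits the CPU backend too: shared-storage buffers are directly addressable from the host. A hedged usage sketch; ggml_backend_buft_alloc_buffer is assumed here to be the generic entry point that dispatches to the .alloc_buffer member above:

    // Sketch: allocate a 16 MiB Metal buffer through the buffer-type interface.
    ggml_backend_buffer_type_t buft = ggml_backend_metal_buffer_type();
    ggml_backend_buffer_t      buf  = ggml_backend_buft_alloc_buffer(buft, 16*1024*1024);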
 
- static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
+ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
+ return "Metal";
+
  UNUSED(backend);
  }
 
- static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
- ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+ static void ggml_backend_metal_free(ggml_backend_t backend) {
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+ ggml_metal_free(ctx);
+ free(backend);
+ }
 
+ static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
  UNUSED(backend);
  }
 
- static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
- ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
+ static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
+ return ggml_backend_metal_buffer_type();
 
  UNUSED(backend);
  }
@@ -1717,32 +2516,43 @@ static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml
  }
 
  static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
- return true;
+ return ggml_metal_supports_op(op);
+
  UNUSED(backend);
- UNUSED(op);
  }
 
  static struct ggml_backend_i metal_backend_i = {
- /* .get_name = */ ggml_backend_metal_name,
- /* .free = */ ggml_backend_metal_free,
- /* .alloc_buffer = */ ggml_backend_metal_alloc_buffer,
- /* .get_alignment = */ ggml_backend_metal_get_alignment,
- /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
- /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
- /* .synchronize = */ ggml_backend_metal_synchronize,
- /* .cpy_tensor_from = */ ggml_backend_metal_cpy_tensor_from,
- /* .cpy_tensor_to = */ ggml_backend_metal_cpy_tensor_to,
- /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
- /* .graph_plan_free = */ NULL,
- /* .graph_plan_compute = */ NULL,
- /* .graph_compute = */ ggml_backend_metal_graph_compute,
- /* .supports_op = */ ggml_backend_metal_supports_op,
+ /* .get_name = */ ggml_backend_metal_name,
+ /* .free = */ ggml_backend_metal_free,
+ /* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
+ /* .set_tensor_async = */ NULL,
+ /* .get_tensor_async = */ NULL,
+ /* .cpy_tensor_from_async = */ NULL,
+ /* .cpy_tensor_to_async = */ NULL,
+ /* .synchronize = */ ggml_backend_metal_synchronize,
+ /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ ggml_backend_metal_graph_compute,
+ /* .supports_op = */ ggml_backend_metal_supports_op,
  };
 
+ // TODO: make a common log callback for all backends in ggml-backend
+ static void ggml_backend_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
+ fprintf(stderr, "%s", msg);
+
+ UNUSED(level);
+ UNUSED(user_data);
+ }
+
  ggml_backend_t ggml_backend_metal_init(void) {
- struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+ ggml_metal_log_set_callback(ggml_backend_log_callback, NULL);
 
- ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
+ struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
+
+ if (ctx == NULL) {
+ return NULL;
+ }
 
  ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
 
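ggml_backend_metal_init now reports failure instead of proceeding with a NULL context, so callers should check the result. A usage sketch, with the CPU backend as an illustrative fallback:

    // Sketch: prefer Metal, fall back to CPU if no device is available.
    ggml_backend_t backend = ggml_backend_metal_init();
    if (backend == NULL) {
        backend = ggml_backend_cpu_init();
    }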
@@ -1759,7 +2569,26 @@ bool ggml_backend_is_metal(ggml_backend_t backend) {
  }
 
  void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
+ GGML_ASSERT(ggml_backend_is_metal(backend));
+
  struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
 
  ggml_metal_set_n_cb(ctx, n_cb);
  }
+
+ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
+ GGML_ASSERT(ggml_backend_is_metal(backend));
+
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+
+ return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+ }
+
+ ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
+
+ ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
+ return ggml_backend_metal_init();
+
+ GGML_UNUSED(params);
+ GGML_UNUSED(user_data);
+ }
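ggml_backend_metal_supports_family maps a 1-based family index onto MTLGPUFamilyApple1 and up, and ggml_backend_reg_metal_init adapts the constructor to the backend registry's two-argument signature. Example check for the Apple7 family (M1-class GPUs), the tier llama.cpp uses to gate its simdgroup matrix-multiplication kernels:

    // Sketch: require an Apple7-class GPU before enabling the fast mul_mm path.
    if (ggml_backend_metal_supports_family(backend, 7)) {
        // simdgroup matrix multiplication is available
    }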