llama_cpp 0.3.8 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,6 +11,7 @@
11
11
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
12
12
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
13
13
 
14
+ // TODO: temporary - reuse llama.cpp logging
14
15
  #ifdef GGML_METAL_NDEBUG
15
16
  #define metal_printf(...)
16
17
  #else
@@ -33,12 +34,15 @@ struct ggml_metal_buffer {
33
34
  struct ggml_metal_context {
34
35
  int n_cb;
35
36
 
36
- float * logits;
37
-
38
37
  id<MTLDevice> device;
39
38
  id<MTLCommandQueue> queue;
40
39
  id<MTLLibrary> library;
41
40
 
41
+ id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
42
+ id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
43
+
44
+ dispatch_queue_t d_queue;
45
+
42
46
  int n_buffers;
43
47
  struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
44
48
 
@@ -63,6 +67,7 @@ struct ggml_metal_context {
63
67
  GGML_METAL_DECL_KERNEL(get_rows_f16);
64
68
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
65
69
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
70
+ GGML_METAL_DECL_KERNEL(get_rows_q8_0);
66
71
  GGML_METAL_DECL_KERNEL(get_rows_q2_K);
67
72
  GGML_METAL_DECL_KERNEL(get_rows_q3_K);
68
73
  GGML_METAL_DECL_KERNEL(get_rows_q4_K);
@@ -73,6 +78,7 @@ struct ggml_metal_context {
73
78
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
74
79
  GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
75
80
  GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
81
+ GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
76
82
  GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
77
83
  GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
78
84
  GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
@@ -81,6 +87,7 @@ struct ggml_metal_context {
81
87
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
82
88
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
83
89
  GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
90
+ GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
84
91
  GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
85
92
  GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
86
93
  GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
@@ -107,16 +114,17 @@ static NSString * const msl_library_source = @"see metal.metal";
107
114
  @end
108
115
 
109
116
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
110
- fprintf(stderr, "%s: allocating\n", __func__);
117
+ metal_printf("%s: allocating\n", __func__);
111
118
 
112
119
  struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
113
120
 
114
- ctx->n_cb = n_cb;
121
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
115
122
  ctx->device = MTLCreateSystemDefaultDevice();
116
123
  ctx->queue = [ctx->device newCommandQueue];
117
124
  ctx->n_buffers = 0;
118
125
  ctx->concur_list_len = 0;
119
126
 
127
+ ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
120
128
 
121
129
  #if 0
122
130
  // compile from source string and show compile log
@@ -125,7 +133,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
125
133
 
126
134
  ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
127
135
  if (error) {
128
- fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
136
+ metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
129
137
  return NULL;
130
138
  }
131
139
  }
@@ -139,11 +147,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
139
147
  //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
140
148
  NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
141
149
  NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
142
- fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);
150
+ metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
143
151
 
144
152
  NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
145
153
  if (error) {
146
- fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
154
+ metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
147
155
  return NULL;
148
156
  }
149
157
 
@@ -155,7 +163,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
155
163
  ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
156
164
  #endif
157
165
  if (error) {
158
- fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
166
+ metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
159
167
  return NULL;
160
168
  }
161
169
  }
@@ -167,9 +175,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
167
175
  #define GGML_METAL_ADD_KERNEL(name) \
168
176
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
169
177
  ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
170
- fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \
178
+ metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
179
+ (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
180
+ (int) ctx->pipeline_##name.threadExecutionWidth); \
171
181
  if (error) { \
172
- fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
182
+ metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
173
183
  return NULL; \
174
184
  }
175
185
 
@@ -186,6 +196,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
186
196
  GGML_METAL_ADD_KERNEL(get_rows_f16);
187
197
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
188
198
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
199
+ GGML_METAL_ADD_KERNEL(get_rows_q8_0);
189
200
  GGML_METAL_ADD_KERNEL(get_rows_q2_K);
190
201
  GGML_METAL_ADD_KERNEL(get_rows_q3_K);
191
202
  GGML_METAL_ADD_KERNEL(get_rows_q4_K);
@@ -196,6 +207,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
196
207
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
197
208
  GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
198
209
  GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
210
+ GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
199
211
  GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
200
212
  GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
201
213
  GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
@@ -203,6 +215,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
203
215
  GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
204
216
  GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
205
217
  GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
218
+ GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
206
219
  GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
207
220
  GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
208
221
  GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
@@ -218,27 +231,100 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
218
231
  #undef GGML_METAL_ADD_KERNEL
219
232
  }
220
233
 
221
- fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
222
- fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
234
+ metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
235
+ metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
223
236
  if (ctx->device.maxTransferRate != 0) {
224
- fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
237
+ metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
225
238
  } else {
226
- fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
239
+ metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
227
240
  }
228
241
 
229
242
  return ctx;
230
243
  }
231
244
 
232
245
  void ggml_metal_free(struct ggml_metal_context * ctx) {
233
- fprintf(stderr, "%s: deallocating\n", __func__);
246
+ metal_printf("%s: deallocating\n", __func__);
247
+ #define GGML_METAL_DEL_KERNEL(name) \
248
+ [ctx->function_##name release]; \
249
+ [ctx->pipeline_##name release];
250
+
251
+ GGML_METAL_DEL_KERNEL(add);
252
+ GGML_METAL_DEL_KERNEL(add_row);
253
+ GGML_METAL_DEL_KERNEL(mul);
254
+ GGML_METAL_DEL_KERNEL(mul_row);
255
+ GGML_METAL_DEL_KERNEL(scale);
256
+ GGML_METAL_DEL_KERNEL(silu);
257
+ GGML_METAL_DEL_KERNEL(relu);
258
+ GGML_METAL_DEL_KERNEL(gelu);
259
+ GGML_METAL_DEL_KERNEL(soft_max);
260
+ GGML_METAL_DEL_KERNEL(diag_mask_inf);
261
+ GGML_METAL_DEL_KERNEL(get_rows_f16);
262
+ GGML_METAL_DEL_KERNEL(get_rows_q4_0);
263
+ GGML_METAL_DEL_KERNEL(get_rows_q4_1);
264
+ GGML_METAL_DEL_KERNEL(get_rows_q8_0);
265
+ GGML_METAL_DEL_KERNEL(get_rows_q2_K);
266
+ GGML_METAL_DEL_KERNEL(get_rows_q3_K);
267
+ GGML_METAL_DEL_KERNEL(get_rows_q4_K);
268
+ GGML_METAL_DEL_KERNEL(get_rows_q5_K);
269
+ GGML_METAL_DEL_KERNEL(get_rows_q6_K);
270
+ GGML_METAL_DEL_KERNEL(rms_norm);
271
+ GGML_METAL_DEL_KERNEL(norm);
272
+ GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
273
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
274
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
275
+ GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
276
+ GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
277
+ GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
278
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
279
+ GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
280
+ GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
281
+ GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
282
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
283
+ GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
284
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
285
+ GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
286
+ GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
287
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
288
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
289
+ GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
290
+ GGML_METAL_DEL_KERNEL(rope);
291
+ GGML_METAL_DEL_KERNEL(alibi_f32);
292
+ GGML_METAL_DEL_KERNEL(cpy_f32_f16);
293
+ GGML_METAL_DEL_KERNEL(cpy_f32_f32);
294
+ GGML_METAL_DEL_KERNEL(cpy_f16_f16);
295
+
296
+ #undef GGML_METAL_DEL_KERNEL
297
+
234
298
  for (int i = 0; i < ctx->n_buffers; ++i) {
235
299
  [ctx->buffers[i].metal release];
236
300
  }
301
+
302
+ [ctx->library release];
303
+ [ctx->queue release];
304
+ [ctx->device release];
305
+
306
+ dispatch_release(ctx->d_queue);
307
+
237
308
  free(ctx);
238
309
  }
239
310
 
311
+ void * ggml_metal_host_malloc(size_t n) {
312
+ void * data = NULL;
313
+ const int result = posix_memalign((void **) &data, getpagesize(), n);
314
+ if (result != 0) {
315
+ metal_printf("%s: error: posix_memalign failed\n", __func__);
316
+ return NULL;
317
+ }
318
+
319
+ return data;
320
+ }
321
+
322
+ void ggml_metal_host_free(void * data) {
323
+ free(data);
324
+ }
325
+
240
326
  void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
241
- ctx->n_cb = n_cb;
327
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
242
328
  }
243
329
 
244
330
  int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
@@ -254,7 +340,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
254
340
  // Metal buffer based on the host memory pointer
255
341
  //
256
342
  static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
257
- //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
343
+ //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
258
344
 
259
345
  const int64_t tsize = ggml_nbytes(t);
260
346
 
@@ -265,13 +351,13 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
265
351
  if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
266
352
  *offs = (size_t) ioffs;
267
353
 
268
- //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
354
+ //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
269
355
 
270
356
  return ctx->buffers[i].metal;
271
357
  }
272
358
  }
273
359
 
274
- fprintf(stderr, "%s: error: buffer is nil\n", __func__);
360
+ metal_printf("%s: error: buffer is nil\n", __func__);
275
361
 
276
362
  return nil;
277
363
  }
@@ -283,7 +369,7 @@ bool ggml_metal_add_buffer(
283
369
  size_t size,
284
370
  size_t max_size) {
285
371
  if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
286
- fprintf(stderr, "%s: too many buffers\n", __func__);
372
+ metal_printf("%s: too many buffers\n", __func__);
287
373
  return false;
288
374
  }
289
375
 
@@ -293,7 +379,7 @@ bool ggml_metal_add_buffer(
293
379
  const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
294
380
 
295
381
  if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
296
- fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
382
+ metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
297
383
  return false;
298
384
  }
299
385
  }
@@ -314,11 +400,11 @@ bool ggml_metal_add_buffer(
314
400
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
315
401
 
316
402
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
317
- fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
403
+ metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
318
404
  return false;
319
405
  }
320
406
 
321
- fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
407
+ metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
322
408
 
323
409
  ++ctx->n_buffers;
324
410
  } else {
@@ -338,27 +424,27 @@ bool ggml_metal_add_buffer(
338
424
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
339
425
 
340
426
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
341
- fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
427
+ metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
342
428
  return false;
343
429
  }
344
430
 
345
- fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
431
+ metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
346
432
  if (i + size_step < size) {
347
- fprintf(stderr, "\n");
433
+ metal_printf("\n");
348
434
  }
349
435
 
350
436
  ++ctx->n_buffers;
351
437
  }
352
438
  }
353
439
 
354
- fprintf(stderr, ", (%8.2f / %8.2f)",
440
+ metal_printf(", (%8.2f / %8.2f)",
355
441
  ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
356
442
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
357
443
 
358
444
  if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
359
- fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
445
+ metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
360
446
  } else {
361
- fprintf(stderr, "\n");
447
+ metal_printf("\n");
362
448
  }
363
449
  }
364
450
 
@@ -368,8 +454,6 @@ bool ggml_metal_add_buffer(
368
454
  void ggml_metal_set_tensor(
369
455
  struct ggml_metal_context * ctx,
370
456
  struct ggml_tensor * t) {
371
- metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
372
-
373
457
  size_t offs;
374
458
  id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
375
459
 
@@ -379,8 +463,6 @@ void ggml_metal_set_tensor(
379
463
  void ggml_metal_get_tensor(
380
464
  struct ggml_metal_context * ctx,
381
465
  struct ggml_tensor * t) {
382
- metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
383
-
384
466
  size_t offs;
385
467
  id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
386
468
 
@@ -475,14 +557,14 @@ void ggml_metal_graph_find_concurrency(
475
557
  }
476
558
 
477
559
  if (ctx->concur_list_len > GGML_MAX_CONCUR) {
478
- fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
560
+ metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
479
561
  }
480
562
  }
481
563
 
482
564
  void ggml_metal_graph_compute(
483
565
  struct ggml_metal_context * ctx,
484
566
  struct ggml_cgraph * gf) {
485
- metal_printf("%s: evaluating graph\n", __func__);
567
+ @autoreleasepool {
486
568
 
487
569
  // if there is ctx->concur_list, dispatch concurrently
488
570
  // else fallback to serial dispatch
@@ -498,32 +580,28 @@ void ggml_metal_graph_compute(
498
580
 
499
581
  const int n_cb = ctx->n_cb;
500
582
 
501
- NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
502
-
503
583
  for (int i = 0; i < n_cb; ++i) {
504
- command_buffers[i] = [ctx->queue commandBuffer];
584
+ ctx->command_buffers[i] = [ctx->queue commandBuffer];
505
585
 
506
586
  // enqueue the command buffers in order to specify their execution order
507
- [command_buffers[i] enqueue];
508
- }
587
+ [ctx->command_buffers[i] enqueue];
509
588
 
510
- // TODO: is this the best way to start threads?
511
- dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
589
+ ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
590
+ }
512
591
 
513
592
  for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
514
593
  const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
515
594
 
516
- dispatch_async(queue, ^{
595
+ dispatch_async(ctx->d_queue, ^{
517
596
  size_t offs_src0 = 0;
518
597
  size_t offs_src1 = 0;
519
598
  size_t offs_dst = 0;
520
599
 
521
- id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
522
-
523
- id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
600
+ id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
601
+ id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
524
602
 
525
- const int node_start = (cb_idx + 0) * n_nodes_per_cb;
526
- const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
603
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
604
+ const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
527
605
 
528
606
  for (int ind = node_start; ind < node_end; ++ind) {
529
607
  const int i = has_concur ? ctx->concur_list[ind] : ind;
@@ -533,7 +611,7 @@ void ggml_metal_graph_compute(
533
611
  continue;
534
612
  }
535
613
 
536
- metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
614
+ //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
537
615
 
538
616
  struct ggml_tensor * src0 = gf->nodes[i]->src[0];
539
617
  struct ggml_tensor * src1 = gf->nodes[i]->src[1];
@@ -602,6 +680,12 @@ void ggml_metal_graph_compute(
602
680
  } break;
603
681
  case GGML_OP_ADD:
604
682
  {
683
+ GGML_ASSERT(ggml_is_contiguous(src0));
684
+
685
+ // utilize float4
686
+ GGML_ASSERT(ne00 % 4 == 0);
687
+ const int64_t nb = ne00/4;
688
+
605
689
  if (ggml_nelements(src1) == ne10) {
606
690
  // src1 is a row
607
691
  [encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -611,14 +695,20 @@ void ggml_metal_graph_compute(
611
695
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
612
696
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
613
697
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
614
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
698
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
615
699
 
616
- const int64_t n = ggml_nelements(dst);
700
+ const int64_t n = ggml_nelements(dst)/4;
617
701
 
618
702
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
619
703
  } break;
620
704
  case GGML_OP_MUL:
621
705
  {
706
+ GGML_ASSERT(ggml_is_contiguous(src0));
707
+
708
+ // utilize float4
709
+ GGML_ASSERT(ne00 % 4 == 0);
710
+ const int64_t nb = ne00/4;
711
+
622
712
  if (ggml_nelements(src1) == ne10) {
623
713
  // src1 is a row
624
714
  [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -628,9 +718,9 @@ void ggml_metal_graph_compute(
628
718
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
629
719
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
630
720
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
631
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
721
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
632
722
 
633
- const int64_t n = ggml_nelements(dst);
723
+ const int64_t n = ggml_nelements(dst)/4;
634
724
 
635
725
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
636
726
  } break;
@@ -681,7 +771,7 @@ void ggml_metal_graph_compute(
681
771
  } break;
682
772
  default:
683
773
  {
684
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
774
+ metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
685
775
  GGML_ASSERT(false);
686
776
  }
687
777
  } break;
@@ -729,32 +819,32 @@ void ggml_metal_graph_compute(
729
819
  [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
730
820
  ne00%32 == 0 &&
731
821
  ne11 > 1) {
732
- switch (src0->type) {
733
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
734
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
735
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
736
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
737
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
738
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
739
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
740
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
741
- default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
742
- }
743
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
744
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
745
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
746
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
747
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
748
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
749
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
750
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
751
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
752
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
753
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
754
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
755
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
822
+ switch (src0->type) {
823
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
824
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
825
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
826
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
827
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
828
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
829
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
830
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
831
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
832
+ default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
756
833
  }
757
- else {
834
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
835
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
836
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
837
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
838
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
839
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
840
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
841
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
842
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
843
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
844
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
845
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
846
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
847
+ } else {
758
848
  int nth0 = 32;
759
849
  int nth1 = 1;
760
850
 
@@ -762,7 +852,7 @@ void ggml_metal_graph_compute(
762
852
  switch (src0t) {
763
853
  case GGML_TYPE_F16:
764
854
  {
765
- nth0 = 64;
855
+ nth0 = 32;
766
856
  nth1 = 1;
767
857
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
768
858
  } break;
@@ -784,6 +874,15 @@ void ggml_metal_graph_compute(
784
874
  nth1 = 8;
785
875
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
786
876
  } break;
877
+ case GGML_TYPE_Q8_0:
878
+ {
879
+ GGML_ASSERT(ne02 == 1);
880
+ GGML_ASSERT(ne12 == 1);
881
+
882
+ nth0 = 8;
883
+ nth1 = 8;
884
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
885
+ } break;
787
886
  case GGML_TYPE_Q2_K:
788
887
  {
789
888
  GGML_ASSERT(ne02 == 1);
@@ -831,7 +930,7 @@ void ggml_metal_graph_compute(
831
930
  } break;
832
931
  default:
833
932
  {
834
- fprintf(stderr, "Asserting on type %d\n",(int)src0t);
933
+ metal_printf("Asserting on type %d\n",(int)src0t);
835
934
  GGML_ASSERT(false && "not implemented");
836
935
  }
837
936
  };
@@ -853,24 +952,24 @@ void ggml_metal_graph_compute(
853
952
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
854
953
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
855
954
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
856
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
955
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
857
956
 
858
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
957
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
859
958
  src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
860
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
959
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
861
960
  }
862
961
  else if (src0t == GGML_TYPE_Q3_K) {
863
962
  #ifdef GGML_QKK_64
864
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
963
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
865
964
  #else
866
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
965
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
867
966
  #endif
868
967
  }
869
968
  else if (src0t == GGML_TYPE_Q5_K) {
870
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
969
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
871
970
  }
872
971
  else if (src0t == GGML_TYPE_Q6_K) {
873
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
972
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
874
973
  } else {
875
974
  [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
876
975
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -880,9 +979,10 @@ void ggml_metal_graph_compute(
880
979
  case GGML_OP_GET_ROWS:
881
980
  {
882
981
  switch (src0->type) {
883
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
982
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
884
983
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
885
984
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
985
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
886
986
  case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
887
987
  case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
888
988
  case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
@@ -923,16 +1023,17 @@ void ggml_metal_graph_compute(
923
1023
  } break;
924
1024
  case GGML_OP_NORM:
925
1025
  {
926
- const float eps = 1e-5f;
1026
+ float eps;
1027
+ memcpy(&eps, dst->op_params, sizeof(float));
927
1028
 
928
1029
  const int nth = 256;
929
1030
 
930
1031
  [encoder setComputePipelineState:ctx->pipeline_norm];
931
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
932
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
933
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
934
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
935
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1032
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1033
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1034
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1035
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1036
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
936
1037
  [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
937
1038
 
938
1039
  const int64_t nrows = ggml_nrows(src0);
@@ -975,7 +1076,9 @@ void ggml_metal_graph_compute(
975
1076
  [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
976
1077
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
977
1078
  [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
1079
+
978
1080
  const int nth = 32;
1081
+
979
1082
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
980
1083
  } break;
981
1084
  case GGML_OP_ROPE:
@@ -990,8 +1093,8 @@ void ggml_metal_graph_compute(
990
1093
  memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
991
1094
 
992
1095
  [encoder setComputePipelineState:ctx->pipeline_rope];
993
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
994
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1096
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1097
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
995
1098
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
996
1099
  [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
997
1100
  [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
@@ -1042,30 +1145,30 @@ void ggml_metal_graph_compute(
1042
1145
  default: GGML_ASSERT(false && "not implemented");
1043
1146
  }
1044
1147
 
1045
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1046
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1047
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1048
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1049
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1050
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1051
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1052
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1053
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1054
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1055
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1056
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1057
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1058
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1059
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1060
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1061
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1062
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1148
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1149
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1150
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1151
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1152
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1153
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1154
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1155
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1156
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1157
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1158
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1159
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1160
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1161
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1162
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1163
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1164
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1165
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1063
1166
 
1064
1167
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1065
1168
  } break;
1066
1169
  default:
1067
1170
  {
1068
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1171
+ metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1069
1172
  GGML_ASSERT(false);
1070
1173
  }
1071
1174
  }
@@ -1081,17 +1184,19 @@ void ggml_metal_graph_compute(
1081
1184
  }
1082
1185
 
1083
1186
  // wait for all threads to finish
1084
- dispatch_barrier_sync(queue, ^{});
1085
-
1086
- [command_buffers[n_cb - 1] waitUntilCompleted];
1187
+ dispatch_barrier_sync(ctx->d_queue, ^{});
1087
1188
 
1088
1189
  // check status of command buffers
1089
1190
  // needed to detect if the device ran out-of-memory for example (#1881)
1090
1191
  for (int i = 0; i < n_cb; i++) {
1091
- MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
1192
+ [ctx->command_buffers[i] waitUntilCompleted];
1193
+
1194
+ MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
1092
1195
  if (status != MTLCommandBufferStatusCompleted) {
1093
- fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
1196
+ metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
1094
1197
  GGML_ASSERT(false);
1095
1198
  }
1096
1199
  }
1200
+
1201
+ }
1097
1202
  }