llama_cpp 0.3.8 → 0.5.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -11,6 +11,7 @@
11
11
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
12
12
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
13
13
 
14
+ // TODO: temporary - reuse llama.cpp logging
14
15
  #ifdef GGML_METAL_NDEBUG
15
16
  #define metal_printf(...)
16
17
  #else
@@ -33,12 +34,15 @@ struct ggml_metal_buffer {
33
34
  struct ggml_metal_context {
34
35
  int n_cb;
35
36
 
36
- float * logits;
37
-
38
37
  id<MTLDevice> device;
39
38
  id<MTLCommandQueue> queue;
40
39
  id<MTLLibrary> library;
41
40
 
41
+ id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
42
+ id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
43
+
44
+ dispatch_queue_t d_queue;
45
+
42
46
  int n_buffers;
43
47
  struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
44
48
 
@@ -63,6 +67,7 @@ struct ggml_metal_context {
63
67
  GGML_METAL_DECL_KERNEL(get_rows_f16);
64
68
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
65
69
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
70
+ GGML_METAL_DECL_KERNEL(get_rows_q8_0);
66
71
  GGML_METAL_DECL_KERNEL(get_rows_q2_K);
67
72
  GGML_METAL_DECL_KERNEL(get_rows_q3_K);
68
73
  GGML_METAL_DECL_KERNEL(get_rows_q4_K);
@@ -73,6 +78,7 @@ struct ggml_metal_context {
73
78
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
74
79
  GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
75
80
  GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
81
+ GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
76
82
  GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
77
83
  GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
78
84
  GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
@@ -81,6 +87,7 @@ struct ggml_metal_context {
81
87
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
82
88
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
83
89
  GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
90
+ GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
84
91
  GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
85
92
  GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
86
93
  GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
@@ -107,16 +114,17 @@ static NSString * const msl_library_source = @"see metal.metal";
107
114
  @end
108
115
 
109
116
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
110
- fprintf(stderr, "%s: allocating\n", __func__);
117
+ metal_printf("%s: allocating\n", __func__);
111
118
 
112
119
  struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
113
120
 
114
- ctx->n_cb = n_cb;
121
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
115
122
  ctx->device = MTLCreateSystemDefaultDevice();
116
123
  ctx->queue = [ctx->device newCommandQueue];
117
124
  ctx->n_buffers = 0;
118
125
  ctx->concur_list_len = 0;
119
126
 
127
+ ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
120
128
 
121
129
  #if 0
122
130
  // compile from source string and show compile log
@@ -125,7 +133,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
125
133
 
126
134
  ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
127
135
  if (error) {
128
- fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
136
+ metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
129
137
  return NULL;
130
138
  }
131
139
  }
@@ -139,11 +147,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
139
147
  //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
140
148
  NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
141
149
  NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
142
- fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);
150
+ metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
143
151
 
144
152
  NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
145
153
  if (error) {
146
- fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
154
+ metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
147
155
  return NULL;
148
156
  }
149
157
 
@@ -155,7 +163,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
155
163
  ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
156
164
  #endif
157
165
  if (error) {
158
- fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
166
+ metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
159
167
  return NULL;
160
168
  }
161
169
  }
@@ -167,9 +175,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
167
175
  #define GGML_METAL_ADD_KERNEL(name) \
168
176
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
169
177
  ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
170
- fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \
178
+ metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
179
+ (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
180
+ (int) ctx->pipeline_##name.threadExecutionWidth); \
171
181
  if (error) { \
172
- fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
182
+ metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
173
183
  return NULL; \
174
184
  }
175
185
 
@@ -186,6 +196,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
186
196
  GGML_METAL_ADD_KERNEL(get_rows_f16);
187
197
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
188
198
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
199
+ GGML_METAL_ADD_KERNEL(get_rows_q8_0);
189
200
  GGML_METAL_ADD_KERNEL(get_rows_q2_K);
190
201
  GGML_METAL_ADD_KERNEL(get_rows_q3_K);
191
202
  GGML_METAL_ADD_KERNEL(get_rows_q4_K);
@@ -196,6 +207,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
196
207
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
197
208
  GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
198
209
  GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
210
+ GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
199
211
  GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
200
212
  GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
201
213
  GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
@@ -203,6 +215,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
203
215
  GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
204
216
  GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
205
217
  GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
218
+ GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
206
219
  GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
207
220
  GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
208
221
  GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
@@ -218,27 +231,100 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
218
231
  #undef GGML_METAL_ADD_KERNEL
219
232
  }
220
233
 
221
- fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
222
- fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
234
+ metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
235
+ metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
223
236
  if (ctx->device.maxTransferRate != 0) {
224
- fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
237
+ metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
225
238
  } else {
226
- fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
239
+ metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
227
240
  }
228
241
 
229
242
  return ctx;
230
243
  }
231
244
 
232
245
  void ggml_metal_free(struct ggml_metal_context * ctx) {
233
- fprintf(stderr, "%s: deallocating\n", __func__);
246
+ metal_printf("%s: deallocating\n", __func__);
247
+ #define GGML_METAL_DEL_KERNEL(name) \
248
+ [ctx->function_##name release]; \
249
+ [ctx->pipeline_##name release];
250
+
251
+ GGML_METAL_DEL_KERNEL(add);
252
+ GGML_METAL_DEL_KERNEL(add_row);
253
+ GGML_METAL_DEL_KERNEL(mul);
254
+ GGML_METAL_DEL_KERNEL(mul_row);
255
+ GGML_METAL_DEL_KERNEL(scale);
256
+ GGML_METAL_DEL_KERNEL(silu);
257
+ GGML_METAL_DEL_KERNEL(relu);
258
+ GGML_METAL_DEL_KERNEL(gelu);
259
+ GGML_METAL_DEL_KERNEL(soft_max);
260
+ GGML_METAL_DEL_KERNEL(diag_mask_inf);
261
+ GGML_METAL_DEL_KERNEL(get_rows_f16);
262
+ GGML_METAL_DEL_KERNEL(get_rows_q4_0);
263
+ GGML_METAL_DEL_KERNEL(get_rows_q4_1);
264
+ GGML_METAL_DEL_KERNEL(get_rows_q8_0);
265
+ GGML_METAL_DEL_KERNEL(get_rows_q2_K);
266
+ GGML_METAL_DEL_KERNEL(get_rows_q3_K);
267
+ GGML_METAL_DEL_KERNEL(get_rows_q4_K);
268
+ GGML_METAL_DEL_KERNEL(get_rows_q5_K);
269
+ GGML_METAL_DEL_KERNEL(get_rows_q6_K);
270
+ GGML_METAL_DEL_KERNEL(rms_norm);
271
+ GGML_METAL_DEL_KERNEL(norm);
272
+ GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
273
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
274
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
275
+ GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
276
+ GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
277
+ GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
278
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
279
+ GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
280
+ GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
281
+ GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
282
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
283
+ GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
284
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
285
+ GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
286
+ GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
287
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
288
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
289
+ GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
290
+ GGML_METAL_DEL_KERNEL(rope);
291
+ GGML_METAL_DEL_KERNEL(alibi_f32);
292
+ GGML_METAL_DEL_KERNEL(cpy_f32_f16);
293
+ GGML_METAL_DEL_KERNEL(cpy_f32_f32);
294
+ GGML_METAL_DEL_KERNEL(cpy_f16_f16);
295
+
296
+ #undef GGML_METAL_DEL_KERNEL
297
+
234
298
  for (int i = 0; i < ctx->n_buffers; ++i) {
235
299
  [ctx->buffers[i].metal release];
236
300
  }
301
+
302
+ [ctx->library release];
303
+ [ctx->queue release];
304
+ [ctx->device release];
305
+
306
+ dispatch_release(ctx->d_queue);
307
+
237
308
  free(ctx);
238
309
  }
239
310
 
311
+ void * ggml_metal_host_malloc(size_t n) {
312
+ void * data = NULL;
313
+ const int result = posix_memalign((void **) &data, getpagesize(), n);
314
+ if (result != 0) {
315
+ metal_printf("%s: error: posix_memalign failed\n", __func__);
316
+ return NULL;
317
+ }
318
+
319
+ return data;
320
+ }
321
+
322
+ void ggml_metal_host_free(void * data) {
323
+ free(data);
324
+ }
325
+
240
326
  void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
241
- ctx->n_cb = n_cb;
327
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
242
328
  }
243
329
 
244
330
  int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
@@ -254,7 +340,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
254
340
  // Metal buffer based on the host memory pointer
255
341
  //
256
342
  static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
257
- //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
343
+ //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
258
344
 
259
345
  const int64_t tsize = ggml_nbytes(t);
260
346
 
@@ -265,13 +351,13 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
265
351
  if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
266
352
  *offs = (size_t) ioffs;
267
353
 
268
- //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
354
+ //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
269
355
 
270
356
  return ctx->buffers[i].metal;
271
357
  }
272
358
  }
273
359
 
274
- fprintf(stderr, "%s: error: buffer is nil\n", __func__);
360
+ metal_printf("%s: error: buffer is nil\n", __func__);
275
361
 
276
362
  return nil;
277
363
  }
@@ -283,7 +369,7 @@ bool ggml_metal_add_buffer(
283
369
  size_t size,
284
370
  size_t max_size) {
285
371
  if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
286
- fprintf(stderr, "%s: too many buffers\n", __func__);
372
+ metal_printf("%s: too many buffers\n", __func__);
287
373
  return false;
288
374
  }
289
375
 
@@ -293,7 +379,7 @@ bool ggml_metal_add_buffer(
293
379
  const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
294
380
 
295
381
  if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
296
- fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
382
+ metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
297
383
  return false;
298
384
  }
299
385
  }
@@ -314,11 +400,11 @@ bool ggml_metal_add_buffer(
314
400
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
315
401
 
316
402
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
317
- fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
403
+ metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
318
404
  return false;
319
405
  }
320
406
 
321
- fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
407
+ metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
322
408
 
323
409
  ++ctx->n_buffers;
324
410
  } else {
@@ -338,27 +424,27 @@ bool ggml_metal_add_buffer(
338
424
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
339
425
 
340
426
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
341
- fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
427
+ metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
342
428
  return false;
343
429
  }
344
430
 
345
- fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
431
+ metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
346
432
  if (i + size_step < size) {
347
- fprintf(stderr, "\n");
433
+ metal_printf("\n");
348
434
  }
349
435
 
350
436
  ++ctx->n_buffers;
351
437
  }
352
438
  }
353
439
 
354
- fprintf(stderr, ", (%8.2f / %8.2f)",
440
+ metal_printf(", (%8.2f / %8.2f)",
355
441
  ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
356
442
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
357
443
 
358
444
  if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
359
- fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
445
+ metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
360
446
  } else {
361
- fprintf(stderr, "\n");
447
+ metal_printf("\n");
362
448
  }
363
449
  }
364
450
 
@@ -368,8 +454,6 @@ bool ggml_metal_add_buffer(
368
454
  void ggml_metal_set_tensor(
369
455
  struct ggml_metal_context * ctx,
370
456
  struct ggml_tensor * t) {
371
- metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
372
-
373
457
  size_t offs;
374
458
  id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
375
459
 
@@ -379,8 +463,6 @@ void ggml_metal_set_tensor(
379
463
  void ggml_metal_get_tensor(
380
464
  struct ggml_metal_context * ctx,
381
465
  struct ggml_tensor * t) {
382
- metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
383
-
384
466
  size_t offs;
385
467
  id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
386
468
 
@@ -475,14 +557,14 @@ void ggml_metal_graph_find_concurrency(
475
557
  }
476
558
 
477
559
  if (ctx->concur_list_len > GGML_MAX_CONCUR) {
478
- fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
560
+ metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
479
561
  }
480
562
  }
481
563
 
482
564
  void ggml_metal_graph_compute(
483
565
  struct ggml_metal_context * ctx,
484
566
  struct ggml_cgraph * gf) {
485
- metal_printf("%s: evaluating graph\n", __func__);
567
+ @autoreleasepool {
486
568
 
487
569
  // if there is ctx->concur_list, dispatch concurrently
488
570
  // else fallback to serial dispatch
@@ -498,32 +580,28 @@ void ggml_metal_graph_compute(
498
580
 
499
581
  const int n_cb = ctx->n_cb;
500
582
 
501
- NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
502
-
503
583
  for (int i = 0; i < n_cb; ++i) {
504
- command_buffers[i] = [ctx->queue commandBuffer];
584
+ ctx->command_buffers[i] = [ctx->queue commandBuffer];
505
585
 
506
586
  // enqueue the command buffers in order to specify their execution order
507
- [command_buffers[i] enqueue];
508
- }
587
+ [ctx->command_buffers[i] enqueue];
509
588
 
510
- // TODO: is this the best way to start threads?
511
- dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
589
+ ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
590
+ }
512
591
 
513
592
  for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
514
593
  const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
515
594
 
516
- dispatch_async(queue, ^{
595
+ dispatch_async(ctx->d_queue, ^{
517
596
  size_t offs_src0 = 0;
518
597
  size_t offs_src1 = 0;
519
598
  size_t offs_dst = 0;
520
599
 
521
- id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
522
-
523
- id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
600
+ id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
601
+ id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
524
602
 
525
- const int node_start = (cb_idx + 0) * n_nodes_per_cb;
526
- const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
603
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
604
+ const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
527
605
 
528
606
  for (int ind = node_start; ind < node_end; ++ind) {
529
607
  const int i = has_concur ? ctx->concur_list[ind] : ind;
@@ -533,7 +611,7 @@ void ggml_metal_graph_compute(
533
611
  continue;
534
612
  }
535
613
 
536
- metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
614
+ //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
537
615
 
538
616
  struct ggml_tensor * src0 = gf->nodes[i]->src[0];
539
617
  struct ggml_tensor * src1 = gf->nodes[i]->src[1];
@@ -602,6 +680,12 @@ void ggml_metal_graph_compute(
602
680
  } break;
603
681
  case GGML_OP_ADD:
604
682
  {
683
+ GGML_ASSERT(ggml_is_contiguous(src0));
684
+
685
+ // utilize float4
686
+ GGML_ASSERT(ne00 % 4 == 0);
687
+ const int64_t nb = ne00/4;
688
+
605
689
  if (ggml_nelements(src1) == ne10) {
606
690
  // src1 is a row
607
691
  [encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -611,14 +695,20 @@ void ggml_metal_graph_compute(
611
695
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
612
696
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
613
697
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
614
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
698
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
615
699
 
616
- const int64_t n = ggml_nelements(dst);
700
+ const int64_t n = ggml_nelements(dst)/4;
617
701
 
618
702
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
619
703
  } break;
620
704
  case GGML_OP_MUL:
621
705
  {
706
+ GGML_ASSERT(ggml_is_contiguous(src0));
707
+
708
+ // utilize float4
709
+ GGML_ASSERT(ne00 % 4 == 0);
710
+ const int64_t nb = ne00/4;
711
+
622
712
  if (ggml_nelements(src1) == ne10) {
623
713
  // src1 is a row
624
714
  [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -628,9 +718,9 @@ void ggml_metal_graph_compute(
628
718
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
629
719
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
630
720
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
631
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
721
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
632
722
 
633
- const int64_t n = ggml_nelements(dst);
723
+ const int64_t n = ggml_nelements(dst)/4;
634
724
 
635
725
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
636
726
  } break;
@@ -681,7 +771,7 @@ void ggml_metal_graph_compute(
681
771
  } break;
682
772
  default:
683
773
  {
684
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
774
+ metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
685
775
  GGML_ASSERT(false);
686
776
  }
687
777
  } break;
@@ -729,32 +819,32 @@ void ggml_metal_graph_compute(
729
819
  [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
730
820
  ne00%32 == 0 &&
731
821
  ne11 > 1) {
732
- switch (src0->type) {
733
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
734
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
735
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
736
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
737
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
738
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
739
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
740
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
741
- default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
742
- }
743
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
744
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
745
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
746
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
747
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
748
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
749
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
750
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
751
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
752
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
753
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
754
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
755
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
822
+ switch (src0->type) {
823
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
824
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
825
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
826
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
827
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
828
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
829
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
830
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
831
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
832
+ default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
756
833
  }
757
- else {
834
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
835
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
836
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
837
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
838
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
839
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
840
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
841
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
842
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
843
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
844
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
845
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
846
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
847
+ } else {
758
848
  int nth0 = 32;
759
849
  int nth1 = 1;
760
850
 
@@ -762,7 +852,7 @@ void ggml_metal_graph_compute(
762
852
  switch (src0t) {
763
853
  case GGML_TYPE_F16:
764
854
  {
765
- nth0 = 64;
855
+ nth0 = 32;
766
856
  nth1 = 1;
767
857
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
768
858
  } break;
@@ -784,6 +874,15 @@ void ggml_metal_graph_compute(
784
874
  nth1 = 8;
785
875
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
786
876
  } break;
877
+ case GGML_TYPE_Q8_0:
878
+ {
879
+ GGML_ASSERT(ne02 == 1);
880
+ GGML_ASSERT(ne12 == 1);
881
+
882
+ nth0 = 8;
883
+ nth1 = 8;
884
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
885
+ } break;
787
886
  case GGML_TYPE_Q2_K:
788
887
  {
789
888
  GGML_ASSERT(ne02 == 1);
@@ -831,7 +930,7 @@ void ggml_metal_graph_compute(
831
930
  } break;
832
931
  default:
833
932
  {
834
- fprintf(stderr, "Asserting on type %d\n",(int)src0t);
933
+ metal_printf("Asserting on type %d\n",(int)src0t);
835
934
  GGML_ASSERT(false && "not implemented");
836
935
  }
837
936
  };
@@ -853,24 +952,24 @@ void ggml_metal_graph_compute(
853
952
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
854
953
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
855
954
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
856
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
955
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
857
956
 
858
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
957
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
859
958
  src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
860
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
959
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
861
960
  }
862
961
  else if (src0t == GGML_TYPE_Q3_K) {
863
962
  #ifdef GGML_QKK_64
864
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
963
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
865
964
  #else
866
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
965
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
867
966
  #endif
868
967
  }
869
968
  else if (src0t == GGML_TYPE_Q5_K) {
870
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
969
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
871
970
  }
872
971
  else if (src0t == GGML_TYPE_Q6_K) {
873
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
972
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
874
973
  } else {
875
974
  [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
876
975
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -880,9 +979,10 @@ void ggml_metal_graph_compute(
880
979
  case GGML_OP_GET_ROWS:
881
980
  {
882
981
  switch (src0->type) {
883
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
982
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
884
983
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
885
984
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
985
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
886
986
  case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
887
987
  case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
888
988
  case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
@@ -923,16 +1023,17 @@ void ggml_metal_graph_compute(
923
1023
  } break;
924
1024
  case GGML_OP_NORM:
925
1025
  {
926
- const float eps = 1e-5f;
1026
+ float eps;
1027
+ memcpy(&eps, dst->op_params, sizeof(float));
927
1028
 
928
1029
  const int nth = 256;
929
1030
 
930
1031
  [encoder setComputePipelineState:ctx->pipeline_norm];
931
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
932
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
933
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
934
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
935
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1032
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1033
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1034
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1035
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1036
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
936
1037
  [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
937
1038
 
938
1039
  const int64_t nrows = ggml_nrows(src0);
@@ -975,7 +1076,9 @@ void ggml_metal_graph_compute(
975
1076
  [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
976
1077
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
977
1078
  [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
1079
+
978
1080
  const int nth = 32;
1081
+
979
1082
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
980
1083
  } break;
981
1084
  case GGML_OP_ROPE:
@@ -990,8 +1093,8 @@ void ggml_metal_graph_compute(
990
1093
  memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
991
1094
 
992
1095
  [encoder setComputePipelineState:ctx->pipeline_rope];
993
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
994
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1096
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1097
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
995
1098
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
996
1099
  [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
997
1100
  [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
@@ -1042,30 +1145,30 @@ void ggml_metal_graph_compute(
1042
1145
  default: GGML_ASSERT(false && "not implemented");
1043
1146
  }
1044
1147
 
1045
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1046
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1047
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1048
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1049
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1050
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1051
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1052
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1053
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1054
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1055
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1056
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1057
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1058
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1059
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1060
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1061
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1062
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1148
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1149
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1150
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1151
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1152
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1153
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1154
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1155
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1156
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1157
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1158
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1159
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1160
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1161
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1162
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1163
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1164
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1165
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1063
1166
 
1064
1167
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1065
1168
  } break;
1066
1169
  default:
1067
1170
  {
1068
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1171
+ metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1069
1172
  GGML_ASSERT(false);
1070
1173
  }
1071
1174
  }
@@ -1081,17 +1184,19 @@ void ggml_metal_graph_compute(
1081
1184
  }
1082
1185
 
1083
1186
  // wait for all threads to finish
1084
- dispatch_barrier_sync(queue, ^{});
1085
-
1086
- [command_buffers[n_cb - 1] waitUntilCompleted];
1187
+ dispatch_barrier_sync(ctx->d_queue, ^{});
1087
1188
 
1088
1189
  // check status of command buffers
1089
1190
  // needed to detect if the device ran out-of-memory for example (#1881)
1090
1191
  for (int i = 0; i < n_cb; i++) {
1091
- MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
1192
+ [ctx->command_buffers[i] waitUntilCompleted];
1193
+
1194
+ MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
1092
1195
  if (status != MTLCommandBufferStatusCompleted) {
1093
- fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
1196
+ metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
1094
1197
  GGML_ASSERT(false);
1095
1198
  }
1096
1199
  }
1200
+
1201
+ }
1097
1202
  }