llama_cpp 0.5.2 → 0.6.0

Sign up to get free protection for your applications and access to all of the features.
@@ -11,11 +11,14 @@
11
11
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
12
12
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
13
13
 
14
- // TODO: temporary - reuse llama.cpp logging
15
14
  #ifdef GGML_METAL_NDEBUG
16
- #define metal_printf(...)
15
+ #define GGML_METAL_LOG_INFO(...)
16
+ #define GGML_METAL_LOG_WARN(...)
17
+ #define GGML_METAL_LOG_ERROR(...)
17
18
  #else
18
- #define metal_printf(...) fprintf(stderr, __VA_ARGS__)
19
+ #define GGML_METAL_LOG_INFO(...) ggml_metal_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
20
+ #define GGML_METAL_LOG_WARN(...) ggml_metal_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
21
+ #define GGML_METAL_LOG_ERROR(...) ggml_metal_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
19
22
  #endif
20
23
 
21
24
  #define UNUSED(x) (void)(x)
@@ -66,6 +69,7 @@ struct ggml_metal_context {
66
69
  GGML_METAL_DECL_KERNEL(soft_max_4);
67
70
  GGML_METAL_DECL_KERNEL(diag_mask_inf);
68
71
  GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
72
+ GGML_METAL_DECL_KERNEL(get_rows_f32);
69
73
  GGML_METAL_DECL_KERNEL(get_rows_f16);
70
74
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
71
75
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
@@ -77,6 +81,7 @@ struct ggml_metal_context {
77
81
  GGML_METAL_DECL_KERNEL(get_rows_q6_K);
78
82
  GGML_METAL_DECL_KERNEL(rms_norm);
79
83
  GGML_METAL_DECL_KERNEL(norm);
84
+ GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
80
85
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
81
86
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
82
87
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
@@ -88,6 +93,7 @@ struct ggml_metal_context {
88
93
  GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
89
94
  GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
90
95
  GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
96
+ GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
91
97
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
92
98
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
93
99
  GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
@@ -97,7 +103,8 @@ struct ggml_metal_context {
97
103
  GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
98
104
  GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
99
105
  GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
100
- GGML_METAL_DECL_KERNEL(rope);
106
+ GGML_METAL_DECL_KERNEL(rope_f32);
107
+ GGML_METAL_DECL_KERNEL(rope_f16);
101
108
  GGML_METAL_DECL_KERNEL(alibi_f32);
102
109
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
103
110
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@@ -117,8 +124,37 @@ static NSString * const msl_library_source = @"see metal.metal";
117
124
  @implementation GGMLMetalClass
118
125
  @end
119
126
 
127
+ ggml_log_callback ggml_metal_log_callback = NULL;
128
+ void * ggml_metal_log_user_data = NULL;
129
+
130
+ void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
131
+ ggml_metal_log_callback = log_callback;
132
+ ggml_metal_log_user_data = user_data;
133
+ }
134
+
135
// Forward a formatted message to the user-registered log callback, if any.
//
// Fixes vs. the original: the first vsnprintf consumes `args`, so reusing the
// same va_list for the heap-buffer retry is undefined behavior (C11 7.16.1) —
// a va_copy taken before the first pass is used for the retry instead.
// Also guards against malloc failure and a negative (error) vsnprintf return.
static void ggml_metal_log(enum ggml_log_level level, const char * format, ...) {
    if (ggml_metal_log_callback == NULL) {
        return;
    }

    va_list args;
    va_start(args, format);

    // Keep a pristine copy for the retry pass; `args` is spent after the
    // first vsnprintf call below.
    va_list args_copy;
    va_copy(args_copy, args);

    // First attempt into a small stack buffer. vsnprintf reports the full
    // length required, so an overflow can be retried with a heap buffer.
    char buffer[128];
    const int len = vsnprintf(buffer, sizeof(buffer), format, args);

    if (len < 0) {
        // encoding error — nothing sensible to report
    } else if ((size_t) len < sizeof(buffer)) {
        ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
    } else {
        char * buffer2 = malloc((size_t) len + 1);
        if (buffer2 != NULL) {
            vsnprintf(buffer2, (size_t) len + 1, format, args_copy);
            ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
            free(buffer2);
        }
    }

    va_end(args_copy);
    va_end(args);
}
153
+
154
+
155
+
120
156
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
121
- metal_printf("%s: allocating\n", __func__);
157
+ GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
122
158
 
123
159
  id <MTLDevice> device;
124
160
  NSString * s;
@@ -128,14 +164,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
128
164
  NSArray * devices = MTLCopyAllDevices();
129
165
  for (device in devices) {
130
166
  s = [device name];
131
- metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
167
+ GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
132
168
  }
133
169
  #endif
134
170
 
135
171
  // Pick and show default Metal device
136
172
  device = MTLCreateSystemDefaultDevice();
137
173
  s = [device name];
138
- metal_printf("%s: picking default device: %s\n", __func__, [s UTF8String]);
174
+ GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
139
175
 
140
176
  // Configure context
141
177
  struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
@@ -145,7 +181,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
145
181
  ctx->n_buffers = 0;
146
182
  ctx->concur_list_len = 0;
147
183
 
148
- ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
184
+ ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
149
185
 
150
186
  #ifdef GGML_SWIFT
151
187
  // load the default.metallib file
@@ -162,7 +198,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
162
198
  ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
163
199
 
164
200
  if (error) {
165
- metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
201
+ GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
166
202
  return NULL;
167
203
  }
168
204
  }
@@ -175,12 +211,12 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
175
211
 
176
212
  //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
177
213
  NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
178
- NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
179
- metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
214
+ NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
215
+ GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path UTF8String]);
180
216
 
181
217
  NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
182
218
  if (error) {
183
- metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
219
+ GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
184
220
  return NULL;
185
221
  }
186
222
 
@@ -192,7 +228,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
192
228
  ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
193
229
  #endif
194
230
  if (error) {
195
- metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
231
+ GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
196
232
  return NULL;
197
233
  }
198
234
  }
@@ -204,11 +240,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
204
240
  #define GGML_METAL_ADD_KERNEL(name) \
205
241
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
206
242
  ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
207
- metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
243
+ GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
208
244
  (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
209
245
  (int) ctx->pipeline_##name.threadExecutionWidth); \
210
246
  if (error) { \
211
- metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
247
+ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
212
248
  return NULL; \
213
249
  }
214
250
 
@@ -224,6 +260,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
224
260
  GGML_METAL_ADD_KERNEL(soft_max_4);
225
261
  GGML_METAL_ADD_KERNEL(diag_mask_inf);
226
262
  GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
263
+ GGML_METAL_ADD_KERNEL(get_rows_f32);
227
264
  GGML_METAL_ADD_KERNEL(get_rows_f16);
228
265
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
229
266
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
@@ -235,6 +272,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
235
272
  GGML_METAL_ADD_KERNEL(get_rows_q6_K);
236
273
  GGML_METAL_ADD_KERNEL(rms_norm);
237
274
  GGML_METAL_ADD_KERNEL(norm);
275
+ GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
238
276
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
239
277
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
240
278
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
@@ -246,6 +284,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
246
284
  GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
247
285
  GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
248
286
  GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
287
+ GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
249
288
  GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
250
289
  GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
251
290
  GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
@@ -255,7 +294,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
255
294
  GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
256
295
  GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
257
296
  GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
258
- GGML_METAL_ADD_KERNEL(rope);
297
+ GGML_METAL_ADD_KERNEL(rope_f32);
298
+ GGML_METAL_ADD_KERNEL(rope_f16);
259
299
  GGML_METAL_ADD_KERNEL(alibi_f32);
260
300
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
261
301
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
@@ -264,13 +304,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
264
304
  #undef GGML_METAL_ADD_KERNEL
265
305
  }
266
306
 
267
- metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
307
+ GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
268
308
  #if TARGET_OS_OSX
269
- metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
309
+ GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
270
310
  if (ctx->device.maxTransferRate != 0) {
271
- metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
311
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
272
312
  } else {
273
- metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
313
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
274
314
  }
275
315
  #endif
276
316
 
@@ -278,7 +318,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
278
318
  }
279
319
 
280
320
  void ggml_metal_free(struct ggml_metal_context * ctx) {
281
- metal_printf("%s: deallocating\n", __func__);
321
+ GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
282
322
  #define GGML_METAL_DEL_KERNEL(name) \
283
323
  [ctx->function_##name release]; \
284
324
  [ctx->pipeline_##name release];
@@ -293,7 +333,9 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
293
333
  GGML_METAL_DEL_KERNEL(gelu);
294
334
  GGML_METAL_DEL_KERNEL(soft_max);
295
335
  GGML_METAL_DEL_KERNEL(soft_max_4);
336
+ GGML_METAL_DEL_KERNEL(diag_mask_inf);
296
337
  GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
338
+ GGML_METAL_DEL_KERNEL(get_rows_f32);
297
339
  GGML_METAL_DEL_KERNEL(get_rows_f16);
298
340
  GGML_METAL_DEL_KERNEL(get_rows_q4_0);
299
341
  GGML_METAL_DEL_KERNEL(get_rows_q4_1);
@@ -305,6 +347,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
305
347
  GGML_METAL_DEL_KERNEL(get_rows_q6_K);
306
348
  GGML_METAL_DEL_KERNEL(rms_norm);
307
349
  GGML_METAL_DEL_KERNEL(norm);
350
+ GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
308
351
  GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
309
352
  GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
310
353
  GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
@@ -316,6 +359,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
316
359
  GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
317
360
  GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
318
361
  GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
362
+ GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
319
363
  GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
320
364
  GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
321
365
  GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
@@ -325,7 +369,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
325
369
  GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
326
370
  GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
327
371
  GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
328
- GGML_METAL_DEL_KERNEL(rope);
372
+ GGML_METAL_DEL_KERNEL(rope_f32);
373
+ GGML_METAL_DEL_KERNEL(rope_f16);
329
374
  GGML_METAL_DEL_KERNEL(alibi_f32);
330
375
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
331
376
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
@@ -350,7 +395,7 @@ void * ggml_metal_host_malloc(size_t n) {
350
395
  void * data = NULL;
351
396
  const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
352
397
  if (result != 0) {
353
- metal_printf("%s: error: posix_memalign failed\n", __func__);
398
+ GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
354
399
  return NULL;
355
400
  }
356
401
 
@@ -378,7 +423,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
378
423
  // Metal buffer based on the host memory pointer
379
424
  //
380
425
  static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
381
- //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
426
+ //GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
382
427
 
383
428
  const int64_t tsize = ggml_nbytes(t);
384
429
 
@@ -386,16 +431,17 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
386
431
  for (int i = 0; i < ctx->n_buffers; ++i) {
387
432
  const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
388
433
 
434
+ //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
389
435
  if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
390
436
  *offs = (size_t) ioffs;
391
437
 
392
- //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
438
+ //GGML_METAL_LOG_INFO("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
393
439
 
394
440
  return ctx->buffers[i].metal;
395
441
  }
396
442
  }
397
443
 
398
- metal_printf("%s: error: buffer is nil\n", __func__);
444
+ GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__);
399
445
 
400
446
  return nil;
401
447
  }
@@ -407,7 +453,7 @@ bool ggml_metal_add_buffer(
407
453
  size_t size,
408
454
  size_t max_size) {
409
455
  if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
410
- metal_printf("%s: too many buffers\n", __func__);
456
+ GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
411
457
  return false;
412
458
  }
413
459
 
@@ -417,7 +463,7 @@ bool ggml_metal_add_buffer(
417
463
  const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
418
464
 
419
465
  if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
420
- metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
466
+ GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
421
467
  return false;
422
468
  }
423
469
  }
@@ -438,11 +484,11 @@ bool ggml_metal_add_buffer(
438
484
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
439
485
 
440
486
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
441
- metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
487
+ GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
442
488
  return false;
443
489
  }
444
490
 
445
- metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
491
+ GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
446
492
 
447
493
  ++ctx->n_buffers;
448
494
  } else {
@@ -462,13 +508,13 @@ bool ggml_metal_add_buffer(
462
508
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
463
509
 
464
510
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
465
- metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
511
+ GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
466
512
  return false;
467
513
  }
468
514
 
469
- metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
515
+ GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
470
516
  if (i + size_step < size) {
471
- metal_printf("\n");
517
+ GGML_METAL_LOG_INFO("\n");
472
518
  }
473
519
 
474
520
  ++ctx->n_buffers;
@@ -476,17 +522,17 @@ bool ggml_metal_add_buffer(
476
522
  }
477
523
 
478
524
  #if TARGET_OS_OSX
479
- metal_printf(", (%8.2f / %8.2f)",
525
+ GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
480
526
  ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
481
527
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
482
528
 
483
529
  if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
484
- metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
530
+ GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
485
531
  } else {
486
- metal_printf("\n");
532
+ GGML_METAL_LOG_INFO("\n");
487
533
  }
488
534
  #else
489
- metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
535
+ GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
490
536
  #endif
491
537
  }
492
538
 
@@ -599,7 +645,7 @@ void ggml_metal_graph_find_concurrency(
599
645
  }
600
646
 
601
647
  if (ctx->concur_list_len > GGML_MAX_CONCUR) {
602
- metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
648
+ GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__);
603
649
  }
604
650
  }
605
651
 
@@ -653,7 +699,7 @@ void ggml_metal_graph_compute(
653
699
  continue;
654
700
  }
655
701
 
656
- //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
702
+ //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
657
703
 
658
704
  struct ggml_tensor * src0 = gf->nodes[i]->src[0];
659
705
  struct ggml_tensor * src1 = gf->nodes[i]->src[1];
@@ -697,17 +743,17 @@ void ggml_metal_graph_compute(
697
743
  id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
698
744
  id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
699
745
 
700
- //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
746
+ //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
701
747
  //if (src0) {
702
- // metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
748
+ // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
703
749
  // ggml_is_contiguous(src0), src0->name);
704
750
  //}
705
751
  //if (src1) {
706
- // metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
752
+ // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
707
753
  // ggml_is_contiguous(src1), src1->name);
708
754
  //}
709
755
  //if (dst) {
710
- // metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
756
+ // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
711
757
  // dst->name);
712
758
  //}
713
759
 
@@ -723,29 +769,66 @@ void ggml_metal_graph_compute(
723
769
  case GGML_OP_ADD:
724
770
  {
725
771
  GGML_ASSERT(ggml_is_contiguous(src0));
772
+ GGML_ASSERT(ggml_is_contiguous(src1));
726
773
 
727
- // utilize float4
728
- GGML_ASSERT(ne00 % 4 == 0);
729
- const int64_t nb = ne00/4;
774
+ bool bcast_row = false;
730
775
 
731
- if (ggml_nelements(src1) == ne10) {
776
+ int64_t nb = ne00;
777
+
778
+ if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
732
779
  // src1 is a row
780
+ GGML_ASSERT(ne11 == 1);
781
+
782
+ nb = ne00 / 4;
733
783
  [encoder setComputePipelineState:ctx->pipeline_add_row];
784
+
785
+ bcast_row = true;
734
786
  } else {
735
787
  [encoder setComputePipelineState:ctx->pipeline_add];
736
788
  }
737
789
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
738
790
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
739
791
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
740
- [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
741
-
742
- const int64_t n = ggml_nelements(dst)/4;
792
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
793
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
794
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
795
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
796
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
797
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
798
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
799
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
800
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
801
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
802
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
803
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
804
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
805
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
806
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
807
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
808
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
809
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
810
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
811
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
812
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
813
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
814
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
815
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
816
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
817
+
818
+ if (bcast_row) {
819
+ const int64_t n = ggml_nelements(dst)/4;
820
+
821
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
822
+ } else {
823
+ const int nth = MIN(1024, ne0);
743
824
 
744
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
825
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
826
+ }
745
827
  } break;
746
828
  case GGML_OP_MUL:
747
829
  {
748
830
  GGML_ASSERT(ggml_is_contiguous(src0));
831
+ GGML_ASSERT(ggml_is_contiguous(src1));
749
832
 
750
833
  // utilize float4
751
834
  GGML_ASSERT(ne00 % 4 == 0);
@@ -753,6 +836,7 @@ void ggml_metal_graph_compute(
753
836
 
754
837
  if (ggml_nelements(src1) == ne10) {
755
838
  // src1 is a row
839
+ GGML_ASSERT(ne11 == 1);
756
840
  [encoder setComputePipelineState:ctx->pipeline_mul_row];
757
841
  } else {
758
842
  [encoder setComputePipelineState:ctx->pipeline_mul];
@@ -768,6 +852,8 @@ void ggml_metal_graph_compute(
768
852
  } break;
769
853
  case GGML_OP_SCALE:
770
854
  {
855
+ GGML_ASSERT(ggml_is_contiguous(src0));
856
+
771
857
  const float scale = *(const float *) src1->data;
772
858
 
773
859
  [encoder setComputePipelineState:ctx->pipeline_scale];
@@ -813,13 +899,13 @@ void ggml_metal_graph_compute(
813
899
  } break;
814
900
  default:
815
901
  {
816
- metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
902
+ GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
817
903
  GGML_ASSERT(false);
818
904
  }
819
905
  } break;
820
906
  case GGML_OP_SOFT_MAX:
821
907
  {
822
- const int nth = 32;
908
+ const int nth = MIN(32, ne00);
823
909
 
824
910
  if (ne00%4 == 0) {
825
911
  [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
@@ -867,13 +953,14 @@ void ggml_metal_graph_compute(
867
953
 
868
954
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
869
955
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
870
- if (ggml_is_contiguous(src0) &&
871
- ggml_is_contiguous(src1) &&
956
+ if (!ggml_is_transposed(src0) &&
957
+ !ggml_is_transposed(src1) &&
872
958
  src1t == GGML_TYPE_F32 &&
873
959
  [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
874
960
  ne00%32 == 0 &&
875
- ne11 > 1) {
961
+ ne11 > 2) {
876
962
  switch (src0->type) {
963
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
877
964
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
878
965
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
879
966
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
@@ -893,9 +980,12 @@ void ggml_metal_graph_compute(
893
980
  [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
894
981
  [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
895
982
  [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
896
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
897
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
898
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
983
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
984
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
985
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
986
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
987
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
988
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
899
989
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
900
990
  [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
901
991
  } else {
@@ -905,6 +995,11 @@ void ggml_metal_graph_compute(
905
995
 
906
996
  // use custom matrix x vector kernel
907
997
  switch (src0t) {
998
+ case GGML_TYPE_F32:
999
+ {
1000
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
1001
+ nrows = 4;
1002
+ } break;
908
1003
  case GGML_TYPE_F16:
909
1004
  {
910
1005
  nth0 = 32;
@@ -993,7 +1088,7 @@ void ggml_metal_graph_compute(
993
1088
  } break;
994
1089
  default:
995
1090
  {
996
- metal_printf("Asserting on type %d\n",(int)src0t);
1091
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
997
1092
  GGML_ASSERT(false && "not implemented");
998
1093
  }
999
1094
  };
@@ -1045,6 +1140,7 @@ void ggml_metal_graph_compute(
1045
1140
  case GGML_OP_GET_ROWS:
1046
1141
  {
1047
1142
  switch (src0->type) {
1143
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
1048
1144
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
1049
1145
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
1050
1146
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
@@ -1060,9 +1156,9 @@ void ggml_metal_graph_compute(
1060
1156
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1061
1157
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1062
1158
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1063
- [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
1064
- [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
1065
- [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];
1159
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1160
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
1161
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
1066
1162
 
1067
1163
  const int64_t n = ggml_nelements(src1);
1068
1164
 
@@ -1073,7 +1169,7 @@ void ggml_metal_graph_compute(
1073
1169
  float eps;
1074
1170
  memcpy(&eps, dst->op_params, sizeof(float));
1075
1171
 
1076
- const int nth = 512;
1172
+ const int nth = MIN(512, ne00);
1077
1173
 
1078
1174
  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
1079
1175
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1092,7 +1188,7 @@ void ggml_metal_graph_compute(
1092
1188
  float eps;
1093
1189
  memcpy(&eps, dst->op_params, sizeof(float));
1094
1190
 
1095
- const int nth = 256;
1191
+ const int nth = MIN(256, ne00);
1096
1192
 
1097
1193
  [encoder setComputePipelineState:ctx->pipeline_norm];
1098
1194
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1110,6 +1206,8 @@ void ggml_metal_graph_compute(
1110
1206
  {
1111
1207
  GGML_ASSERT((src0t == GGML_TYPE_F32));
1112
1208
 
1209
+ const int nth = MIN(1024, ne00);
1210
+
1113
1211
  const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
1114
1212
  const int n_head = ((int32_t *) dst->op_params)[1];
1115
1213
  float max_bias;
@@ -1143,12 +1241,14 @@ void ggml_metal_graph_compute(
1143
1241
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1144
1242
  [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
1145
1243
 
1146
- const int nth = 32;
1147
-
1148
1244
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1149
1245
  } break;
1150
1246
  case GGML_OP_ROPE:
1151
1247
  {
1248
+ GGML_ASSERT(ne10 == ne02);
1249
+
1250
+ const int nth = MIN(1024, ne00);
1251
+
1152
1252
  const int n_past = ((int32_t *) dst->op_params)[0];
1153
1253
  const int n_dims = ((int32_t *) dst->op_params)[1];
1154
1254
  const int mode = ((int32_t *) dst->op_params)[2];
@@ -1158,38 +1258,44 @@ void ggml_metal_graph_compute(
1158
1258
  memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
1159
1259
  memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
1160
1260
 
1161
- [encoder setComputePipelineState:ctx->pipeline_rope];
1261
+ switch (src0->type) {
1262
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
1263
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
1264
+ default: GGML_ASSERT(false);
1265
+ };
1266
+
1162
1267
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1163
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1164
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1165
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1166
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1167
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1168
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1169
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1170
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1171
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1172
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1173
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1174
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1175
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1176
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1177
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1178
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1179
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1180
- [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
1181
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
1182
- [encoder setBytes:&mode length:sizeof( int) atIndex:20];
1183
- [encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
1184
- [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
1268
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1269
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1270
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1271
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
1272
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
1273
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
1274
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
1275
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
1276
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
1277
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
1278
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
1279
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
1280
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
1281
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
1282
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
1283
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
1284
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
1285
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
1286
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
1287
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
1288
+ [encoder setBytes:&mode length:sizeof( int) atIndex:21];
1289
+ [encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
1290
+ [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];
1185
1291
 
1186
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
1292
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1187
1293
  } break;
1188
1294
  case GGML_OP_DUP:
1189
1295
  case GGML_OP_CPY:
1190
1296
  case GGML_OP_CONT:
1191
1297
  {
1192
- const int nth = 32;
1298
+ const int nth = MIN(1024, ne00);
1193
1299
 
1194
1300
  switch (src0t) {
1195
1301
  case GGML_TYPE_F32:
@@ -1234,7 +1340,7 @@ void ggml_metal_graph_compute(
1234
1340
  } break;
1235
1341
  default:
1236
1342
  {
1237
- metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1343
+ GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1238
1344
  GGML_ASSERT(false);
1239
1345
  }
1240
1346
  }
@@ -1259,7 +1365,7 @@ void ggml_metal_graph_compute(
1259
1365
 
1260
1366
  MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
1261
1367
  if (status != MTLCommandBufferStatusCompleted) {
1262
- metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
1368
+ GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
1263
1369
  GGML_ASSERT(false);
1264
1370
  }
1265
1371
  }