llama_cpp 0.5.3 → 0.6.0

@@ -11,11 +11,14 @@
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
  #define MAX(a, b) ((a) > (b) ? (a) : (b))

- // TODO: temporary - reuse llama.cpp logging
  #ifdef GGML_METAL_NDEBUG
- #define metal_printf(...)
+ #define GGML_METAL_LOG_INFO(...)
+ #define GGML_METAL_LOG_WARN(...)
+ #define GGML_METAL_LOG_ERROR(...)
  #else
- #define metal_printf(...) fprintf(stderr, __VA_ARGS__)
+ #define GGML_METAL_LOG_INFO(...) ggml_metal_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
+ #define GGML_METAL_LOG_WARN(...) ggml_metal_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
+ #define GGML_METAL_LOG_ERROR(...) ggml_metal_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
  #endif

  #define UNUSED(x) (void)(x)
@@ -100,7 +103,8 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_DECL_KERNEL(rope);
+ GGML_METAL_DECL_KERNEL(rope_f32);
+ GGML_METAL_DECL_KERNEL(rope_f16);
  GGML_METAL_DECL_KERNEL(alibi_f32);
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@@ -120,8 +124,37 @@ static NSString * const msl_library_source = @"see metal.metal";
  @implementation GGMLMetalClass
  @end

+ ggml_log_callback ggml_metal_log_callback = NULL;
+ void * ggml_metal_log_user_data = NULL;
+
+ void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
+ ggml_metal_log_callback = log_callback;
+ ggml_metal_log_user_data = user_data;
+ }
+
+ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
+ if (ggml_metal_log_callback != NULL) {
+ va_list args;
+ va_start(args, format);
+ char buffer[128];
+ int len = vsnprintf(buffer, 128, format, args);
+ if (len < 128) {
+ ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
+ } else {
+ char* buffer2 = malloc(len+1);
+ vsnprintf(buffer2, len+1, format, args);
+ buffer2[len] = 0;
+ ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
+ free(buffer2);
+ }
+ va_end(args);
+ }
+ }
+
+
+
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
- metal_printf("%s: allocating\n", __func__);
+ GGML_METAL_LOG_INFO("%s: allocating\n", __func__);

  id <MTLDevice> device;
  NSString * s;
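The new ggml_metal_log_set_callback / ggml_metal_log pair routes every message from the Metal backend through a user-supplied sink instead of a hard-coded fprintf(stderr, ...); with no callback registered, ggml_metal_log simply drops the message. A minimal sketch of registering a sink from application code, assuming ggml_log_callback and enum ggml_log_level come from the ggml headers as the calls above imply; my_metal_logger and its behavior are hypothetical:

    #include <stdio.h>
    #include "ggml.h"          // assumed home of enum ggml_log_level / ggml_log_callback
    #include "ggml-metal.h"    // assumed to declare ggml_metal_log_set_callback

    // The (level, text, user_data) signature mirrors how ggml_metal_log invokes the callback above.
    static void my_metal_logger(enum ggml_log_level level, const char * text, void * user_data) {
        FILE * out = (FILE *) user_data;   // user_data carries the destination stream
        fprintf(out, "[metal %d] %s", (int) level, text);
    }

    // During application setup:
    //     ggml_metal_log_set_callback(my_metal_logger, stderr);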
@@ -131,14 +164,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  NSArray * devices = MTLCopyAllDevices();
  for (device in devices) {
  s = [device name];
- metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
+ GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
  }
  #endif

  // Pick and show default Metal device
  device = MTLCreateSystemDefaultDevice();
  s = [device name];
- metal_printf("%s: picking default device: %s\n", __func__, [s UTF8String]);
+ GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);

  // Configure context
  struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
@@ -165,7 +198,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];

  if (error) {
- metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
+ GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
  return NULL;
  }
  }
@@ -179,11 +212,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
  NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
  NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
- metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
+ GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path UTF8String]);

  NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
  if (error) {
- metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
+ GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
  return NULL;
  }
@@ -195,7 +228,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
  #endif
  if (error) {
- metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
+ GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
  return NULL;
  }
  }
@@ -207,11 +240,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  #define GGML_METAL_ADD_KERNEL(name) \
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
  ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
- metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+ GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
  (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
  (int) ctx->pipeline_##name.threadExecutionWidth); \
  if (error) { \
- metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
  return NULL; \
  }
@@ -261,7 +294,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_ADD_KERNEL(rope);
+ GGML_METAL_ADD_KERNEL(rope_f32);
+ GGML_METAL_ADD_KERNEL(rope_f16);
  GGML_METAL_ADD_KERNEL(alibi_f32);
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
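For reference, each GGML_METAL_ADD_KERNEL(name) call above expands via the macro from the previous hunk; the token pasting and stringification mean that the new rope_f32 entry becomes, roughly:

    ctx->function_rope_f32 = [ctx->library newFunctionWithName:@"kernel_rope_f32"];
    ctx->pipeline_rope_f32 = [ctx->device newComputePipelineStateWithFunction:ctx->function_rope_f32 error:&error];
    GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_rope_f32", (void *) ctx->pipeline_rope_f32,
            (int) ctx->pipeline_rope_f32.maxTotalThreadsPerThreadgroup,
            (int) ctx->pipeline_rope_f32.threadExecutionWidth);
    if (error) {
        GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]);
        return NULL;
    }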
@@ -270,13 +304,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  #undef GGML_METAL_ADD_KERNEL
  }

- metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+ GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
  #if TARGET_OS_OSX
- metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+ GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
  if (ctx->device.maxTransferRate != 0) {
- metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
  } else {
- metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
  }
  #endif
@@ -284,7 +318,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  }

  void ggml_metal_free(struct ggml_metal_context * ctx) {
- metal_printf("%s: deallocating\n", __func__);
+ GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
  #define GGML_METAL_DEL_KERNEL(name) \
  [ctx->function_##name release]; \
  [ctx->pipeline_##name release];
@@ -335,7 +369,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_DEL_KERNEL(rope);
+ GGML_METAL_DEL_KERNEL(rope_f32);
+ GGML_METAL_DEL_KERNEL(rope_f16);
  GGML_METAL_DEL_KERNEL(alibi_f32);
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
@@ -360,7 +395,7 @@ void * ggml_metal_host_malloc(size_t n) {
  void * data = NULL;
  const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
  if (result != 0) {
- metal_printf("%s: error: posix_memalign failed\n", __func__);
+ GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
  return NULL;
  }
@@ -388,7 +423,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
  // Metal buffer based on the host memory pointer
  //
  static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
- //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
+ //GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);

  const int64_t tsize = ggml_nbytes(t);
@@ -400,13 +435,13 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
  if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
  *offs = (size_t) ioffs;

- //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
+ //GGML_METAL_LOG_INFO("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);

  return ctx->buffers[i].metal;
  }
  }

- metal_printf("%s: error: buffer is nil\n", __func__);
+ GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__);

  return nil;
  }
@@ -418,7 +453,7 @@ bool ggml_metal_add_buffer(
  size_t size,
  size_t max_size) {
  if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
- metal_printf("%s: too many buffers\n", __func__);
+ GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
  return false;
  }
@@ -428,7 +463,7 @@ bool ggml_metal_add_buffer(
  const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;

  if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
- metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
+ GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
  return false;
  }
  }
@@ -449,11 +484,11 @@ bool ggml_metal_add_buffer(
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];

  if (ctx->buffers[ctx->n_buffers].metal == nil) {
- metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
+ GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
  return false;
  }

- metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
+ GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);

  ++ctx->n_buffers;
  } else {
@@ -473,13 +508,13 @@ bool ggml_metal_add_buffer(
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];

  if (ctx->buffers[ctx->n_buffers].metal == nil) {
- metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
+ GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
  return false;
  }

- metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
+ GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
  if (i + size_step < size) {
- metal_printf("\n");
+ GGML_METAL_LOG_INFO("\n");
  }

  ++ctx->n_buffers;
@@ -487,17 +522,17 @@ bool ggml_metal_add_buffer(
  }

  #if TARGET_OS_OSX
- metal_printf(", (%8.2f / %8.2f)",
+ GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
  ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);

  if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
- metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
+ GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
  } else {
- metal_printf("\n");
+ GGML_METAL_LOG_INFO("\n");
  }
  #else
- metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
+ GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
  #endif
  }
@@ -610,7 +645,7 @@ void ggml_metal_graph_find_concurrency(
  }

  if (ctx->concur_list_len > GGML_MAX_CONCUR) {
- metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
+ GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__);
  }
  }
@@ -664,7 +699,7 @@ void ggml_metal_graph_compute(
  continue;
  }

- //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+ //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));

  struct ggml_tensor * src0 = gf->nodes[i]->src[0];
  struct ggml_tensor * src1 = gf->nodes[i]->src[1];
@@ -708,17 +743,17 @@ void ggml_metal_graph_compute(
  id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
  id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;

- //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+ //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
  //if (src0) {
- // metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
+ // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
  // ggml_is_contiguous(src0), src0->name);
  //}
  //if (src1) {
- // metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
+ // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
  // ggml_is_contiguous(src1), src1->name);
  //}
  //if (dst) {
- // metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
+ // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
  // dst->name);
  //}
@@ -736,25 +771,59 @@ void ggml_metal_graph_compute(
  GGML_ASSERT(ggml_is_contiguous(src0));
  GGML_ASSERT(ggml_is_contiguous(src1));

- // utilize float4
- GGML_ASSERT(ne00 % 4 == 0);
- const int64_t nb = ne00/4;
+ bool bcast_row = false;

- if (ggml_nelements(src1) == ne10) {
+ int64_t nb = ne00;
+
+ if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
  // src1 is a row
  GGML_ASSERT(ne11 == 1);
+
+ nb = ne00 / 4;
  [encoder setComputePipelineState:ctx->pipeline_add_row];
+
+ bcast_row = true;
  } else {
  [encoder setComputePipelineState:ctx->pipeline_add];
  }
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
-
- const int64_t n = ggml_nelements(dst)/4;
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
+
+ if (bcast_row) {
+ const int64_t n = ggml_nelements(dst)/4;
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } else {
+ const int nth = MIN(1024, ne0);

- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ }
  } break;
  case GGML_OP_MUL:
  {
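The reworked GGML_OP_ADD path above gains full shape/stride broadcasting: the float4 row kernel (pipeline_add_row) is taken only when src1 is a single row and ne00 is divisible by 4, otherwise the general pipeline_add kernel runs with every dimension and stride bound at buffer indices 3 through 27. A minimal sketch of the host-side decision and dispatch geometry, with hypothetical names, restating the conditions from the hunk above:

    #include <stdbool.h>
    #include <stdint.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    // Hypothetical summary type; ggml itself encodes this directly into the Metal encoder calls.
    typedef struct { int64_t tg_x, tg_y, tg_z; int threads; } add_dispatch;

    static add_dispatch choose_add_dispatch(int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
                                            int64_t ne0,  int64_t ne10, int64_t ne11,
                                            int64_t src1_nelements, int64_t dst_nelements) {
        (void) ne11; // asserted to be 1 on the broadcast path in the code above
        // Row broadcast: src1 is one contiguous row of ne10 elements and rows are float4-aligned.
        const bool bcast_row = (src1_nelements == ne10) && (ne00 % 4 == 0);
        if (bcast_row) {
            // pipeline_add_row: one single-thread threadgroup per float4 of dst
            return (add_dispatch){ dst_nelements / 4, 1, 1, 1 };
        }
        // pipeline_add: one threadgroup per (row, plane, batch), at most 1024 threads across ne0
        return (add_dispatch){ ne01, ne02, ne03, (int) MIN(1024, ne0) };
    }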
@@ -830,13 +899,13 @@ void ggml_metal_graph_compute(
  } break;
  default:
  {
- metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
  GGML_ASSERT(false);
  }
  } break;
  case GGML_OP_SOFT_MAX:
  {
- const int nth = 32;
+ const int nth = MIN(32, ne00);

  if (ne00%4 == 0) {
  [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
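This hunk and several later ones (rms_norm, norm, alibi, rope, and the dup/cpy/cont case) apply the same fix: the threads-per-threadgroup count nth is clamped with MIN(cap, ne00) instead of being a fixed constant, so a tensor whose rows are shorter than the cap never launches more threads than it has elements. A small illustration of the pattern, reusing the MIN macro from the top of the file; the shapes in the asserts are hypothetical:

    #include <assert.h>
    #include <stdint.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    // caps seen in this diff: 32 (soft_max), 256 (norm), 512 (rms_norm), 1024 (alibi/rope/cpy)
    static int threads_per_row(int cap, int64_t ne00) {
        return (int) MIN((int64_t) cap, ne00);
    }

    int main(void) {
        assert(threads_per_row(512, 4096) == 512); // wide row: the cap applies
        assert(threads_per_row(512, 8)    == 8);   // narrow row: clamped to the row length
        return 0;
    }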
@@ -889,7 +958,7 @@ void ggml_metal_graph_compute(
  src1t == GGML_TYPE_F32 &&
  [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
  ne00%32 == 0 &&
- ne11 > 1) {
+ ne11 > 2) {
  switch (src0->type) {
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
@@ -1019,7 +1088,7 @@ void ggml_metal_graph_compute(
  } break;
  default:
  {
- metal_printf("Asserting on type %d\n",(int)src0t);
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
  GGML_ASSERT(false && "not implemented");
  }
  };
@@ -1100,7 +1169,7 @@ void ggml_metal_graph_compute(
  float eps;
  memcpy(&eps, dst->op_params, sizeof(float));

- const int nth = 512;
+ const int nth = MIN(512, ne00);

  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1119,7 +1188,7 @@ void ggml_metal_graph_compute(
  float eps;
  memcpy(&eps, dst->op_params, sizeof(float));

- const int nth = 256;
+ const int nth = MIN(256, ne00);

  [encoder setComputePipelineState:ctx->pipeline_norm];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1137,6 +1206,8 @@ void ggml_metal_graph_compute(
  {
  GGML_ASSERT((src0t == GGML_TYPE_F32));

+ const int nth = MIN(1024, ne00);
+
  const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
  const int n_head = ((int32_t *) dst->op_params)[1];
  float max_bias;
@@ -1170,12 +1241,14 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
  [encoder setBytes:&m0 length:sizeof( float) atIndex:18];

- const int nth = 32;
-
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case GGML_OP_ROPE:
  {
+ GGML_ASSERT(ne10 == ne02);
+
+ const int nth = MIN(1024, ne00);
+
  const int n_past = ((int32_t *) dst->op_params)[0];
  const int n_dims = ((int32_t *) dst->op_params)[1];
  const int mode = ((int32_t *) dst->op_params)[2];
@@ -1185,38 +1258,44 @@ void ggml_metal_graph_compute(
  memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
  memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

- [encoder setComputePipelineState:ctx->pipeline_rope];
+ switch (src0->type) {
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
+ default: GGML_ASSERT(false);
+ };
+
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
- [encoder setBytes:&mode length:sizeof( int) atIndex:20];
- [encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
- [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
+ [encoder setBytes:&mode length:sizeof( int) atIndex:21];
+ [encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
+ [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];

- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case GGML_OP_DUP:
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  {
- const int nth = 32;
+ const int nth = MIN(1024, ne00);

  switch (src0t) {
  case GGML_TYPE_F32:
@@ -1261,7 +1340,7 @@ void ggml_metal_graph_compute(
  } break;
  default:
  {
- metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
  GGML_ASSERT(false);
  }
  }
@@ -1286,7 +1365,7 @@ void ggml_metal_graph_compute(

  MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
  if (status != MTLCommandBufferStatusCompleted) {
- metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
+ GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
  GGML_ASSERT(false);
  }
  }