llama_cpp 0.5.3 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,11 +11,14 @@
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
  #define MAX(a, b) ((a) > (b) ? (a) : (b))

- // TODO: temporary - reuse llama.cpp logging
  #ifdef GGML_METAL_NDEBUG
- #define metal_printf(...)
+ #define GGML_METAL_LOG_INFO(...)
+ #define GGML_METAL_LOG_WARN(...)
+ #define GGML_METAL_LOG_ERROR(...)
  #else
- #define metal_printf(...) fprintf(stderr, __VA_ARGS__)
+ #define GGML_METAL_LOG_INFO(...) ggml_metal_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
+ #define GGML_METAL_LOG_WARN(...) ggml_metal_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
+ #define GGML_METAL_LOG_ERROR(...) ggml_metal_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
  #endif

  #define UNUSED(x) (void)(x)
@@ -100,7 +103,8 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_DECL_KERNEL(rope);
+ GGML_METAL_DECL_KERNEL(rope_f32);
+ GGML_METAL_DECL_KERNEL(rope_f16);
  GGML_METAL_DECL_KERNEL(alibi_f32);
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
  GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@@ -120,8 +124,37 @@ static NSString * const msl_library_source = @"see metal.metal";
  @implementation GGMLMetalClass
  @end

+ ggml_log_callback ggml_metal_log_callback = NULL;
+ void * ggml_metal_log_user_data = NULL;
+
+ void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
+     ggml_metal_log_callback = log_callback;
+     ggml_metal_log_user_data = user_data;
+ }
+
+ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
+     if (ggml_metal_log_callback != NULL) {
+         va_list args;
+         va_start(args, format);
+         char buffer[128];
+         int len = vsnprintf(buffer, 128, format, args);
+         if (len < 128) {
+             ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
+         } else {
+             char* buffer2 = malloc(len+1);
+             vsnprintf(buffer2, len+1, format, args);
+             buffer2[len] = 0;
+             ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
+             free(buffer2);
+         }
+         va_end(args);
+     }
+ }
+
+
+
  struct ggml_metal_context * ggml_metal_init(int n_cb) {
- metal_printf("%s: allocating\n", __func__);
+ GGML_METAL_LOG_INFO("%s: allocating\n", __func__);

  id <MTLDevice> device;
  NSString * s;
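With the hooks above, an embedding application can capture the Metal backend's log output instead of having it written to `stderr`. A minimal sketch, assuming `ggml_metal_log_set_callback`, `ggml_metal_init`/`ggml_metal_free` and the `ggml_log_level` values are exposed through `ggml.h`/`ggml-metal.h`, with the callback signature `(level, text, user_data)` inferred from how `ggml_metal_log_callback` is invoked in the hunk above; `my_metal_log` is a hypothetical name:

```c
#include <stdio.h>

#include "ggml.h"
#include "ggml-metal.h"   // assumed to declare ggml_metal_log_set_callback

// Hypothetical sink: forward Metal backend log lines to the FILE * passed
// as user_data, prefixed with the numeric log level.
static void my_metal_log(enum ggml_log_level level, const char * text, void * user_data) {
    fprintf((FILE *) user_data, "[metal:%d] %s", (int) level, text);
}

int main(void) {
    // Route all GGML_METAL_LOG_{INFO,WARN,ERROR} output through my_metal_log.
    ggml_metal_log_set_callback(my_metal_log, stdout);

    struct ggml_metal_context * ctx = ggml_metal_init(1); // logs "allocating", device info, kernel loads
    if (ctx != NULL) {
        ggml_metal_free(ctx);                             // logs "deallocating"
    }
    return 0;
}
```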
@@ -131,14 +164,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  NSArray * devices = MTLCopyAllDevices();
  for (device in devices) {
  s = [device name];
- metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
+ GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
  }
  #endif

  // Pick and show default Metal device
  device = MTLCreateSystemDefaultDevice();
  s = [device name];
- metal_printf("%s: picking default device: %s\n", __func__, [s UTF8String]);
+ GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);

  // Configure context
  struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
@@ -165,7 +198,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];

  if (error) {
- metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
+ GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
  return NULL;
  }
  }
@@ -179,11 +212,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
  NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
  NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
- metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
+ GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path UTF8String]);

  NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
  if (error) {
- metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
+ GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
  return NULL;
  }

@@ -195,7 +228,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
  #endif
  if (error) {
- metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
+ GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
  return NULL;
  }
  }
@@ -207,11 +240,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  #define GGML_METAL_ADD_KERNEL(name) \
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
  ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
- metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+ GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
  (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
  (int) ctx->pipeline_##name.threadExecutionWidth); \
  if (error) { \
- metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
  return NULL; \
  }

@@ -261,7 +294,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_ADD_KERNEL(rope);
+ GGML_METAL_ADD_KERNEL(rope_f32);
+ GGML_METAL_ADD_KERNEL(rope_f16);
  GGML_METAL_ADD_KERNEL(alibi_f32);
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
  GGML_METAL_ADD_KERNEL(cpy_f32_f32);
@@ -270,13 +304,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  #undef GGML_METAL_ADD_KERNEL
  }

- metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+ GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
  #if TARGET_OS_OSX
- metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+ GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
  if (ctx->device.maxTransferRate != 0) {
- metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
  } else {
- metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
  }
  #endif

@@ -284,7 +318,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  }

  void ggml_metal_free(struct ggml_metal_context * ctx) {
- metal_printf("%s: deallocating\n", __func__);
+ GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
  #define GGML_METAL_DEL_KERNEL(name) \
  [ctx->function_##name release]; \
  [ctx->pipeline_##name release];
@@ -335,7 +369,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
- GGML_METAL_DEL_KERNEL(rope);
+ GGML_METAL_DEL_KERNEL(rope_f32);
+ GGML_METAL_DEL_KERNEL(rope_f16);
  GGML_METAL_DEL_KERNEL(alibi_f32);
  GGML_METAL_DEL_KERNEL(cpy_f32_f16);
  GGML_METAL_DEL_KERNEL(cpy_f32_f32);
@@ -360,7 +395,7 @@ void * ggml_metal_host_malloc(size_t n) {
  void * data = NULL;
  const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
  if (result != 0) {
- metal_printf("%s: error: posix_memalign failed\n", __func__);
+ GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
  return NULL;
  }

@@ -388,7 +423,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
  // Metal buffer based on the host memory pointer
  //
  static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
- //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
+ //GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);

  const int64_t tsize = ggml_nbytes(t);

@@ -400,13 +435,13 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
  if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
  *offs = (size_t) ioffs;

- //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
+ //GGML_METAL_LOG_INFO("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);

  return ctx->buffers[i].metal;
  }
  }

- metal_printf("%s: error: buffer is nil\n", __func__);
+ GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__);

  return nil;
  }
@@ -418,7 +453,7 @@ bool ggml_metal_add_buffer(
  size_t size,
  size_t max_size) {
  if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
- metal_printf("%s: too many buffers\n", __func__);
+ GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
  return false;
  }

@@ -428,7 +463,7 @@ bool ggml_metal_add_buffer(
  const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;

  if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
- metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
+ GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
  return false;
  }
  }
@@ -449,11 +484,11 @@ bool ggml_metal_add_buffer(
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];

  if (ctx->buffers[ctx->n_buffers].metal == nil) {
- metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
+ GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
  return false;
  }

- metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
+ GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);

  ++ctx->n_buffers;
  } else {
@@ -473,13 +508,13 @@ bool ggml_metal_add_buffer(
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];

  if (ctx->buffers[ctx->n_buffers].metal == nil) {
- metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
+ GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
  return false;
  }

- metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
+ GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
  if (i + size_step < size) {
- metal_printf("\n");
+ GGML_METAL_LOG_INFO("\n");
  }

  ++ctx->n_buffers;
@@ -487,17 +522,17 @@ bool ggml_metal_add_buffer(
  }

  #if TARGET_OS_OSX
- metal_printf(", (%8.2f / %8.2f)",
+ GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
  ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);

  if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
- metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
+ GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
  } else {
- metal_printf("\n");
+ GGML_METAL_LOG_INFO("\n");
  }
  #else
- metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
+ GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
  #endif
  }

@@ -610,7 +645,7 @@ void ggml_metal_graph_find_concurrency(
  }

  if (ctx->concur_list_len > GGML_MAX_CONCUR) {
- metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
+ GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__);
  }
  }

@@ -664,7 +699,7 @@ void ggml_metal_graph_compute(
  continue;
  }

- //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+ //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));

  struct ggml_tensor * src0 = gf->nodes[i]->src[0];
  struct ggml_tensor * src1 = gf->nodes[i]->src[1];
@@ -708,17 +743,17 @@ void ggml_metal_graph_compute(
  id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
  id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;

- //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+ //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
  //if (src0) {
- // metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
+ // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
  // ggml_is_contiguous(src0), src0->name);
  //}
  //if (src1) {
- // metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
+ // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
  // ggml_is_contiguous(src1), src1->name);
  //}
  //if (dst) {
- // metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
+ // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
  // dst->name);
  //}

@@ -736,25 +771,59 @@ void ggml_metal_graph_compute(
  GGML_ASSERT(ggml_is_contiguous(src0));
  GGML_ASSERT(ggml_is_contiguous(src1));

- // utilize float4
- GGML_ASSERT(ne00 % 4 == 0);
- const int64_t nb = ne00/4;
+ bool bcast_row = false;

- if (ggml_nelements(src1) == ne10) {
+ int64_t nb = ne00;
+
+ if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
  // src1 is a row
  GGML_ASSERT(ne11 == 1);
+
+ nb = ne00 / 4;
  [encoder setComputePipelineState:ctx->pipeline_add_row];
+
+ bcast_row = true;
  } else {
  [encoder setComputePipelineState:ctx->pipeline_add];
  }
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
-
- const int64_t n = ggml_nelements(dst)/4;
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
+
+ if (bcast_row) {
+ const int64_t n = ggml_nelements(dst)/4;
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } else {
+ const int nth = MIN(1024, ne0);

- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ }
  } break;
  case GGML_OP_MUL:
  {
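In this release the add path drops the hard `ne00 % 4 == 0` requirement: the float4 `add_row` kernel is still used when `src1` is a single broadcast row with a 4-aligned width, while every other shape now falls through to the general `add` kernel, which receives the full shape and stride arguments and is dispatched over `(ne01, ne02, ne03)` threadgroups. A minimal sketch of that decision in plain C; `use_add_row_kernel` is a hypothetical helper, not a function in the package:

```c
#include <stdbool.h>
#include <stdint.h>

// Hypothetical distillation of the GGML_OP_ADD gating above: the float4
// "add_row" fast path applies only when src1 holds a single row (ne11 == 1
// is asserted in the hunk) that is broadcast over src0 and the row width
// ne00 is a multiple of 4, so each thread can process nb = ne00/4 float4s.
static bool use_add_row_kernel(int64_t ne00, int64_t ne10, int64_t src1_nelements) {
    return src1_nelements == ne10 && ne00 % 4 == 0;
}
```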
@@ -830,13 +899,13 @@ void ggml_metal_graph_compute(
  } break;
  default:
  {
- metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
  GGML_ASSERT(false);
  }
  } break;
  case GGML_OP_SOFT_MAX:
  {
- const int nth = 32;
+ const int nth = MIN(32, ne00);

  if (ne00%4 == 0) {
  [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
@@ -889,7 +958,7 @@ void ggml_metal_graph_compute(
  src1t == GGML_TYPE_F32 &&
  [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
  ne00%32 == 0 &&
- ne11 > 1) {
+ ne11 > 2) {
  switch (src0->type) {
  case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
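The matrix-matrix (`mul_mm`) pipelines are now selected only when `src1` has more than two rows (`ne11 > 2`, previously `ne11 > 1`); batches of one or two rows fall through to the matrix-vector kernels instead. A rough sketch of the visible part of that predicate; the hunk shows only some of the conditions, and `prefer_mul_mm` is a hypothetical name:

```c
#include <stdbool.h>
#include <stdint.h>

// Partial, hypothetical restatement of the mul_mm gating shown above:
// src1 must be F32, the device must support MTLGPUFamilyApple7, the src0
// row width must be a multiple of 32, and src1 must now carry more than
// two rows for the matrix-matrix kernel to be chosen.
static bool prefer_mul_mm(bool src1_is_f32, bool supports_apple7,
                          int64_t ne00, int64_t ne11) {
    return src1_is_f32 && supports_apple7 && ne00 % 32 == 0 && ne11 > 2;
}
```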
@@ -1019,7 +1088,7 @@ void ggml_metal_graph_compute(
  } break;
  default:
  {
- metal_printf("Asserting on type %d\n",(int)src0t);
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
  GGML_ASSERT(false && "not implemented");
  }
  };
@@ -1100,7 +1169,7 @@ void ggml_metal_graph_compute(
  float eps;
  memcpy(&eps, dst->op_params, sizeof(float));

- const int nth = 512;
+ const int nth = MIN(512, ne00);

  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1119,7 +1188,7 @@ void ggml_metal_graph_compute(
  float eps;
  memcpy(&eps, dst->op_params, sizeof(float));

- const int nth = 256;
+ const int nth = MIN(256, ne00);

  [encoder setComputePipelineState:ctx->pipeline_norm];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
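These hunks, together with the SOFT_MAX, ALIBI, ROPE and DUP/CPY/CONT cases around them, replace fixed threadgroup sizes with `MIN(cap, ne00)`, so a dispatch never requests more threads per threadgroup than there are elements in a row. A minimal sketch of the pattern; `clamp_nth` is a hypothetical helper standing in for the inline `MIN(...)` expressions:

```c
#include <stdint.h>

// Clamp the requested threads-per-threadgroup to the row width ne00.
// `cap` is the kernel-specific ceiling used in the diff: 32 for soft_max,
// 256 for norm, 512 for rms_norm, 1024 for add/alibi/rope/cpy.
static int clamp_nth(int cap, int64_t ne00) {
    return (int64_t) cap < ne00 ? cap : (int) ne00;   // same as MIN(cap, ne00)
}
```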
@@ -1137,6 +1206,8 @@ void ggml_metal_graph_compute(
  {
  GGML_ASSERT((src0t == GGML_TYPE_F32));

+ const int nth = MIN(1024, ne00);
+
  const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
  const int n_head = ((int32_t *) dst->op_params)[1];
  float max_bias;
@@ -1170,12 +1241,14 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
  [encoder setBytes:&m0 length:sizeof( float) atIndex:18];

- const int nth = 32;
-
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case GGML_OP_ROPE:
  {
+ GGML_ASSERT(ne10 == ne02);
+
+ const int nth = MIN(1024, ne00);
+
  const int n_past = ((int32_t *) dst->op_params)[0];
  const int n_dims = ((int32_t *) dst->op_params)[1];
  const int mode = ((int32_t *) dst->op_params)[2];
@@ -1185,38 +1258,44 @@
  memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
  memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

- [encoder setComputePipelineState:ctx->pipeline_rope];
+ switch (src0->type) {
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
+ default: GGML_ASSERT(false);
+ };
+
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
- [encoder setBytes:&mode length:sizeof( int) atIndex:20];
- [encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
- [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
+ [encoder setBytes:&mode length:sizeof( int) atIndex:21];
+ [encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
+ [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];

- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case GGML_OP_DUP:
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  {
- const int nth = 32;
+ const int nth = MIN(1024, ne00);

  switch (src0t) {
  case GGML_TYPE_F32:
@@ -1261,7 +1340,7 @@ void ggml_metal_graph_compute(
  } break;
  default:
  {
- metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
  GGML_ASSERT(false);
  }
  }
@@ -1286,7 +1365,7 @@ void ggml_metal_graph_compute(

  MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
  if (status != MTLCommandBufferStatusCompleted) {
- metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
+ GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
  GGML_ASSERT(false);
  }
  }