llama_cpp 0.4.0 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,6 +11,7 @@
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
  #define MAX(a, b) ((a) > (b) ? (a) : (b))

+ // TODO: temporary - reuse llama.cpp logging
  #ifdef GGML_METAL_NDEBUG
  #define metal_printf(...)
  #else
@@ -33,12 +34,15 @@ struct ggml_metal_buffer {
  struct ggml_metal_context {
  int n_cb;

- float * logits;
-
  id<MTLDevice> device;
  id<MTLCommandQueue> queue;
  id<MTLLibrary> library;

+ id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
+ id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
+
+ dispatch_queue_t d_queue;
+
  int n_buffers;
  struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];

@@ -110,16 +114,17 @@ static NSString * const msl_library_source = @"see metal.metal";
  @end

  struct ggml_metal_context * ggml_metal_init(int n_cb) {
- fprintf(stderr, "%s: allocating\n", __func__);
+ metal_printf("%s: allocating\n", __func__);

  struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));

- ctx->n_cb = n_cb;
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
  ctx->device = MTLCreateSystemDefaultDevice();
  ctx->queue = [ctx->device newCommandQueue];
  ctx->n_buffers = 0;
  ctx->concur_list_len = 0;

+ ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);

  #if 0
  // compile from source string and show compile log
@@ -128,7 +133,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {

  ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
  if (error) {
- fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+ metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
  return NULL;
  }
  }
@@ -142,11 +147,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
  NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
  NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
- fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);
+ metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);

  NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
  if (error) {
- fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+ metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
  return NULL;
  }

@@ -158,7 +163,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
  #endif
  if (error) {
- fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+ metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
  return NULL;
  }
  }
@@ -170,11 +175,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  #define GGML_METAL_ADD_KERNEL(name) \
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
  ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
- fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+ metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
  (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
  (int) ctx->pipeline_##name.threadExecutionWidth); \
  if (error) { \
- fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+ metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
  return NULL; \
  }

@@ -226,22 +231,80 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  #undef GGML_METAL_ADD_KERNEL
  }

- fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
- fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+ metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+ metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
  if (ctx->device.maxTransferRate != 0) {
- fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+ metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
  } else {
- fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
+ metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
  }

  return ctx;
  }

  void ggml_metal_free(struct ggml_metal_context * ctx) {
- fprintf(stderr, "%s: deallocating\n", __func__);
+ metal_printf("%s: deallocating\n", __func__);
+ #define GGML_METAL_DEL_KERNEL(name) \
+ [ctx->function_##name release]; \
+ [ctx->pipeline_##name release];
+
+ GGML_METAL_DEL_KERNEL(add);
+ GGML_METAL_DEL_KERNEL(add_row);
+ GGML_METAL_DEL_KERNEL(mul);
+ GGML_METAL_DEL_KERNEL(mul_row);
+ GGML_METAL_DEL_KERNEL(scale);
+ GGML_METAL_DEL_KERNEL(silu);
+ GGML_METAL_DEL_KERNEL(relu);
+ GGML_METAL_DEL_KERNEL(gelu);
+ GGML_METAL_DEL_KERNEL(soft_max);
+ GGML_METAL_DEL_KERNEL(diag_mask_inf);
+ GGML_METAL_DEL_KERNEL(get_rows_f16);
+ GGML_METAL_DEL_KERNEL(get_rows_q4_0);
+ GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+ GGML_METAL_DEL_KERNEL(get_rows_q8_0);
+ GGML_METAL_DEL_KERNEL(get_rows_q2_K);
+ GGML_METAL_DEL_KERNEL(get_rows_q3_K);
+ GGML_METAL_DEL_KERNEL(get_rows_q4_K);
+ GGML_METAL_DEL_KERNEL(get_rows_q5_K);
+ GGML_METAL_DEL_KERNEL(get_rows_q6_K);
+ GGML_METAL_DEL_KERNEL(rms_norm);
+ GGML_METAL_DEL_KERNEL(norm);
+ GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(rope);
+ GGML_METAL_DEL_KERNEL(alibi_f32);
+ GGML_METAL_DEL_KERNEL(cpy_f32_f16);
+ GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+ GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+
+ #undef GGML_METAL_DEL_KERNEL
+
  for (int i = 0; i < ctx->n_buffers; ++i) {
  [ctx->buffers[i].metal release];
  }
+
+ [ctx->library release];
+ [ctx->queue release];
+ [ctx->device release];
+
+ dispatch_release(ctx->d_queue);
+
  free(ctx);
  }

@@ -249,7 +312,7 @@ void * ggml_metal_host_malloc(size_t n) {
  void * data = NULL;
  const int result = posix_memalign((void **) &data, getpagesize(), n);
  if (result != 0) {
- fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
+ metal_printf("%s: error: posix_memalign failed\n", __func__);
  return NULL;
  }

@@ -261,7 +324,7 @@ void ggml_metal_host_free(void * data) {
  }

  void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
- ctx->n_cb = n_cb;
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
  }

  int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
@@ -277,7 +340,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
  // Metal buffer based on the host memory pointer
  //
  static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
- //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
+ //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);

  const int64_t tsize = ggml_nbytes(t);

@@ -288,13 +351,13 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
  if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
  *offs = (size_t) ioffs;

- //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
+ //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);

  return ctx->buffers[i].metal;
  }
  }

- fprintf(stderr, "%s: error: buffer is nil\n", __func__);
+ metal_printf("%s: error: buffer is nil\n", __func__);

  return nil;
  }
@@ -306,7 +369,7 @@ bool ggml_metal_add_buffer(
  size_t size,
  size_t max_size) {
  if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
- fprintf(stderr, "%s: too many buffers\n", __func__);
+ metal_printf("%s: too many buffers\n", __func__);
  return false;
  }

@@ -316,7 +379,7 @@ bool ggml_metal_add_buffer(
  const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;

  if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
- fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
+ metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
  return false;
  }
  }
@@ -337,11 +400,11 @@ bool ggml_metal_add_buffer(
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];

  if (ctx->buffers[ctx->n_buffers].metal == nil) {
- fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
+ metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
  return false;
  }

- fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
+ metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);

  ++ctx->n_buffers;
  } else {
@@ -361,27 +424,27 @@ bool ggml_metal_add_buffer(
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];

  if (ctx->buffers[ctx->n_buffers].metal == nil) {
- fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
+ metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
  return false;
  }

- fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
+ metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
  if (i + size_step < size) {
- fprintf(stderr, "\n");
+ metal_printf("\n");
  }

  ++ctx->n_buffers;
  }
  }

- fprintf(stderr, ", (%8.2f / %8.2f)",
+ metal_printf(", (%8.2f / %8.2f)",
  ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);

  if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
- fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
+ metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
  } else {
- fprintf(stderr, "\n");
+ metal_printf("\n");
  }
  }

@@ -391,8 +454,6 @@ bool ggml_metal_add_buffer(
  void ggml_metal_set_tensor(
  struct ggml_metal_context * ctx,
  struct ggml_tensor * t) {
- metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
-
  size_t offs;
  id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);

@@ -402,8 +463,6 @@ void ggml_metal_set_tensor(
  void ggml_metal_get_tensor(
  struct ggml_metal_context * ctx,
  struct ggml_tensor * t) {
- metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
-
  size_t offs;
  id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);

@@ -498,14 +557,14 @@ void ggml_metal_graph_find_concurrency(
  }

  if (ctx->concur_list_len > GGML_MAX_CONCUR) {
- fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+ metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
  }
  }

  void ggml_metal_graph_compute(
  struct ggml_metal_context * ctx,
  struct ggml_cgraph * gf) {
- metal_printf("%s: evaluating graph\n", __func__);
+ @autoreleasepool {

  // if there is ctx->concur_list, dispatch concurrently
  // else fallback to serial dispatch
@@ -521,29 +580,25 @@ void ggml_metal_graph_compute(

  const int n_cb = ctx->n_cb;

- NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
-
  for (int i = 0; i < n_cb; ++i) {
- command_buffers[i] = [ctx->queue commandBuffer];
+ ctx->command_buffers[i] = [ctx->queue commandBuffer];

  // enqueue the command buffers in order to specify their execution order
- [command_buffers[i] enqueue];
- }
+ [ctx->command_buffers[i] enqueue];

- // TODO: is this the best way to start threads?
- dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+ ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
+ }

  for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
  const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;

- dispatch_async(queue, ^{
+ dispatch_async(ctx->d_queue, ^{
  size_t offs_src0 = 0;
  size_t offs_src1 = 0;
  size_t offs_dst = 0;

- id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
-
- id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
+ id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];

  const int node_start = (cb_idx + 0) * n_nodes_per_cb;
  const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
@@ -556,7 +611,7 @@ void ggml_metal_graph_compute(
  continue;
  }

- metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+ //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));

  struct ggml_tensor * src0 = gf->nodes[i]->src[0];
  struct ggml_tensor * src1 = gf->nodes[i]->src[1];
@@ -625,6 +680,12 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_ADD:
  {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ // utilize float4
+ GGML_ASSERT(ne00 % 4 == 0);
+ const int64_t nb = ne00/4;
+
  if (ggml_nelements(src1) == ne10) {
  // src1 is a row
  [encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -634,14 +695,20 @@ void ggml_metal_graph_compute(
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:3];

- const int64_t n = ggml_nelements(dst);
+ const int64_t n = ggml_nelements(dst)/4;

  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  case GGML_OP_MUL:
  {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ // utilize float4
+ GGML_ASSERT(ne00 % 4 == 0);
+ const int64_t nb = ne00/4;
+
  if (ggml_nelements(src1) == ne10) {
  // src1 is a row
  [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -651,9 +718,9 @@ void ggml_metal_graph_compute(
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:3];

- const int64_t n = ggml_nelements(dst);
+ const int64_t n = ggml_nelements(dst)/4;

  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
@@ -704,7 +771,7 @@ void ggml_metal_graph_compute(
  } break;
  default:
  {
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
  GGML_ASSERT(false);
  }
  } break;
@@ -785,7 +852,7 @@ void ggml_metal_graph_compute(
  switch (src0t) {
  case GGML_TYPE_F16:
  {
- nth0 = 64;
+ nth0 = 32;
  nth1 = 1;
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
  } break;
@@ -863,7 +930,7 @@ void ggml_metal_graph_compute(
  } break;
  default:
  {
- fprintf(stderr, "Asserting on type %d\n",(int)src0t);
+ metal_printf("Asserting on type %d\n",(int)src0t);
  GGML_ASSERT(false && "not implemented");
  }
  };
@@ -1101,7 +1168,7 @@ void ggml_metal_graph_compute(
  } break;
  default:
  {
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
  GGML_ASSERT(false);
  }
  }
@@ -1117,17 +1184,19 @@ void ggml_metal_graph_compute(
  }

  // wait for all threads to finish
- dispatch_barrier_sync(queue, ^{});
-
- [command_buffers[n_cb - 1] waitUntilCompleted];
+ dispatch_barrier_sync(ctx->d_queue, ^{});

  // check status of command buffers
  // needed to detect if the device ran out-of-memory for example (#1881)
  for (int i = 0; i < n_cb; i++) {
- MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
+ [ctx->command_buffers[i] waitUntilCompleted];
+
+ MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
  if (status != MTLCommandBufferStatusCompleted) {
- fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
+ metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
  GGML_ASSERT(false);
  }
  }
+
+ }
  }
@@ -25,9 +25,9 @@ typedef struct {
  } block_q8_0;

  kernel void kernel_add(
- device const float * src0,
- device const float * src1,
- device float * dst,
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
  uint tpig[[thread_position_in_grid]]) {
  dst[tpig] = src0[tpig] + src1[tpig];
  }
@@ -35,18 +35,18 @@ kernel void kernel_add(
  // assumption: src1 is a row
  // broadcast src1 into src0
  kernel void kernel_add_row(
- device const float * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ constant int64_t & nb,
  uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] + src1[tpig % ne00];
+ dst[tpig] = src0[tpig] + src1[tpig % nb];
  }

  kernel void kernel_mul(
- device const float * src0,
- device const float * src1,
- device float * dst,
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
  uint tpig[[thread_position_in_grid]]) {
  dst[tpig] = src0[tpig] * src1[tpig];
  }
@@ -54,12 +54,12 @@ kernel void kernel_mul(
  // assumption: src1 is a row
  // broadcast src1 into src0
  kernel void kernel_mul_row(
- device const float * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ constant int64_t & nb,
  uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * src1[tpig % ne00];
+ dst[tpig] = src0[tpig] * src1[tpig % nb];
  }

  kernel void kernel_scale(
@@ -528,24 +528,42 @@ kernel void kernel_mul_mat_f16_f32(
  device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
  device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);

- sum[tpitg.x] = 0.0f;
+ uint ith = tpitg.x;
+ uint nth = tptg.x;
+
+ sum[ith] = 0.0f;

- for (int i = tpitg.x; i < ne00; i += tptg.x) {
- sum[tpitg.x] += (float) x[i] * (float) y[i];
+ for (int i = ith; i < ne00; i += nth) {
+ sum[ith] += (float) x[i] * (float) y[i];
  }

  // accumulate the sum from all threads in the threadgroup
  threadgroup_barrier(mem_flags::mem_threadgroup);
- for (uint i = tptg.x/2; i > 0; i /= 2) {
- if (tpitg.x < i) {
- sum[tpitg.x] += sum[tpitg.x + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (ith%4 == 0) {
+ for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
  }
-
- if (tpitg.x == 0) {
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (ith%16 == 0) {
+ for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (ith == 0) {
+ for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
  dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
  }
+
+ // Original implementation. Left behind commented out for now
+ //threadgroup_barrier(mem_flags::mem_threadgroup);
+ //for (uint i = tptg.x/2; i > 0; i /= 2) {
+ // if (tpitg.x < i) {
+ // sum[tpitg.x] += sum[tpitg.x + i];
+ // }
+ // threadgroup_barrier(mem_flags::mem_threadgroup);
+ //}
+ //
+ //if (tpitg.x == 0) {
+ // dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
+ //}
  }

  kernel void kernel_alibi_f32(