llama_cpp 0.4.0 → 0.5.0

ggml-metal.m
@@ -11,6 +11,7 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+// TODO: temporary - reuse llama.cpp logging
 #ifdef GGML_METAL_NDEBUG
 #define metal_printf(...)
 #else
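
For reference, `metal_printf` (defined just below this hunk) is the usual compile-time logging gate: it expands to nothing when `GGML_METAL_NDEBUG` is defined. A minimal C sketch of the pattern; the `fprintf` fallback shown here is an assumption about the `#else` branch, which this diff does not include:

    #include <stdio.h>

    // Compile-time logging gate: defining GGML_METAL_NDEBUG strips every call.
    #ifdef GGML_METAL_NDEBUG
    #define metal_printf(...)
    #else
    #define metal_printf(...) fprintf(stderr, __VA_ARGS__)
    #endif

    int main(void) {
        metal_printf("%s: allocating\n", __func__); // vanishes under NDEBUG
        return 0;
    }

This is why the rest of the diff can replace `fprintf(stderr, ...)` calls wholesale: in release builds they now compile away.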
@@ -33,12 +34,15 @@ struct ggml_metal_buffer {
 struct ggml_metal_context {
     int n_cb;
 
-    float * logits;
-
     id<MTLDevice>       device;
     id<MTLCommandQueue> queue;
     id<MTLLibrary>      library;
 
+    id<MTLCommandBuffer>         command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
+    id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
+
+    dispatch_queue_t d_queue;
+
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
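
The context now owns its command buffers, compute encoders, and dispatch queue, instead of creating them on every graph evaluation. A minimal C sketch of the same pattern, hoisting per-call resources into a long-lived context (all names illustrative, not from the diff):

    #include <stdio.h>

    #define MAX_CB 32   // stands in for GGML_METAL_MAX_COMMAND_BUFFERS

    // Long-lived context: fixed-size slots are filled in on each compute
    // call instead of allocating a fresh container every call (the old
    // code built an NSMutableArray per graph evaluation).
    typedef struct {
        int n_cb;
        int command_buffers[MAX_CB];   // stands in for id<MTLCommandBuffer>
    } context_t;

    static void compute(context_t *ctx) {
        for (int i = 0; i < ctx->n_cb; ++i) {
            ctx->command_buffers[i] = i;   // "create" a command buffer
        }
    }

    int main(void) {
        context_t ctx = { .n_cb = 8 };
        compute(&ctx);                     // reuses ctx slots on every call
        printf("cb[7] = %d\n", ctx.command_buffers[7]);
        return 0;
    }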
@@ -110,16 +114,17 @@ static NSString * const msl_library_source = @"see metal.metal";
 @end
 
 struct ggml_metal_context * ggml_metal_init(int n_cb) {
-    fprintf(stderr, "%s: allocating\n", __func__);
+    metal_printf("%s: allocating\n", __func__);
 
     struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
 
-    ctx->n_cb = n_cb;
+    ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
     ctx->device = MTLCreateSystemDefaultDevice();
     ctx->queue = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
     ctx->concur_list_len = 0;
 
+    ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
 
 #if 0
     // compile from source string and show compile log
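
Clamping `n_cb` guards the new fixed-size arrays in the context. Note that the clamp uses `GGML_METAL_MAX_BUFFERS` rather than `GGML_METAL_MAX_COMMAND_BUFFERS`, so it is only safe as long as the former does not exceed the latter. A minimal C sketch of the invariant (the constant values are illustrative, not taken from the diff):

    #include <assert.h>

    #define GGML_METAL_MAX_BUFFERS          16   // illustrative values,
    #define GGML_METAL_MAX_COMMAND_BUFFERS  32   // not taken from the diff

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    // The clamp must keep n_cb within the command-buffer array bounds.
    int set_n_cb(int n_cb) {
        int clamped = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
        assert(clamped <= GGML_METAL_MAX_COMMAND_BUFFERS);
        return clamped;
    }

    int main(void) { return set_n_cb(64) == 16 ? 0 : 1; }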
@@ -128,7 +133,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 
         ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
         if (error) {
-            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
     }
@@ -142,11 +147,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
         NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
         NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
-        fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);
+        metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
 
         NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
         if (error) {
-            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
 
@@ -158,7 +163,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
 #endif
         if (error) {
-            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
     }
@@ -170,11 +175,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #define GGML_METAL_ADD_KERNEL(name) \
         ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
         ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
-        fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+        metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
                 (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
                 (int) ctx->pipeline_##name.threadExecutionWidth); \
         if (error) { \
-            fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+            metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
             return NULL; \
         }
 
@@ -226,22 +231,80 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #undef GGML_METAL_ADD_KERNEL
     }
 
-    fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-    fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+    metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
     if (ctx->device.maxTransferRate != 0) {
-        fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+        metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
     } else {
-        fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
+        metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
     }
 
     return ctx;
 }
 
 void ggml_metal_free(struct ggml_metal_context * ctx) {
-    fprintf(stderr, "%s: deallocating\n", __func__);
+    metal_printf("%s: deallocating\n", __func__);
+#define GGML_METAL_DEL_KERNEL(name) \
+    [ctx->function_##name release]; \
+    [ctx->pipeline_##name release];
+
+    GGML_METAL_DEL_KERNEL(add);
+    GGML_METAL_DEL_KERNEL(add_row);
+    GGML_METAL_DEL_KERNEL(mul);
+    GGML_METAL_DEL_KERNEL(mul_row);
+    GGML_METAL_DEL_KERNEL(scale);
+    GGML_METAL_DEL_KERNEL(silu);
+    GGML_METAL_DEL_KERNEL(relu);
+    GGML_METAL_DEL_KERNEL(gelu);
+    GGML_METAL_DEL_KERNEL(soft_max);
+    GGML_METAL_DEL_KERNEL(diag_mask_inf);
+    GGML_METAL_DEL_KERNEL(get_rows_f16);
+    GGML_METAL_DEL_KERNEL(get_rows_q4_0);
+    GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DEL_KERNEL(get_rows_q8_0);
+    GGML_METAL_DEL_KERNEL(get_rows_q2_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q3_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q4_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q5_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q6_K);
+    GGML_METAL_DEL_KERNEL(rms_norm);
+    GGML_METAL_DEL_KERNEL(norm);
+    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(rope);
+    GGML_METAL_DEL_KERNEL(alibi_f32);
+    GGML_METAL_DEL_KERNEL(cpy_f32_f16);
+    GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+    GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+
+#undef GGML_METAL_DEL_KERNEL
+
     for (int i = 0; i < ctx->n_buffers; ++i) {
         [ctx->buffers[i].metal release];
     }
+
+    [ctx->library release];
+    [ctx->queue release];
+    [ctx->device release];
+
+    dispatch_release(ctx->d_queue);
+
     free(ctx);
 }
 
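
`GGML_METAL_DEL_KERNEL` mirrors the existing `GGML_METAL_ADD_KERNEL`, so every function/pipeline pair created in `ggml_metal_init` is released in `ggml_metal_free`. A minimal C sketch of the same token-pasting pattern that keeps acquire and release lists in sync (the `res_` names are illustrative):

    #include <stdio.h>
    #include <stdlib.h>

    // Token-pasting macros keep acquire/release lists symmetric: each name
    // appears once per list, and ##name expands to the matching field.
    struct ctx { int *res_add, *res_mul; };

    #define RES_ACQUIRE(name) ctx->res_##name = malloc(sizeof(int));
    #define RES_RELEASE(name) free(ctx->res_##name);

    static void ctx_init(struct ctx *ctx) {
        RES_ACQUIRE(add);
        RES_ACQUIRE(mul);
    }

    static void ctx_free(struct ctx *ctx) {
        RES_RELEASE(add);
        RES_RELEASE(mul);
    }

    int main(void) {
        struct ctx c;
        ctx_init(&c);
        ctx_free(&c);
        printf("acquire/release lists stayed symmetric\n");
        return 0;
    }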
@@ -249,7 +312,7 @@ void * ggml_metal_host_malloc(size_t n) {
     void * data = NULL;
     const int result = posix_memalign((void **) &data, getpagesize(), n);
     if (result != 0) {
-        fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
+        metal_printf("%s: error: posix_memalign failed\n", __func__);
         return NULL;
     }
 
@@ -261,7 +324,7 @@ void ggml_metal_host_free(void * data) {
 }
 
 void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
-    ctx->n_cb = n_cb;
+    ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
 }
 
 int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
@@ -277,7 +340,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
 // Metal buffer based on the host memory pointer
 //
 static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
-    //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
+    //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
 
     const int64_t tsize = ggml_nbytes(t);
 
@@ -288,13 +351,13 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
         if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
             *offs = (size_t) ioffs;
 
-            //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
+            //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
 
             return ctx->buffers[i].metal;
         }
     }
 
-    fprintf(stderr, "%s: error: buffer is nil\n", __func__);
+    metal_printf("%s: error: buffer is nil\n", __func__);
 
     return nil;
 }
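
`ggml_metal_get_buffer` resolves a host tensor pointer to whichever registered buffer covers that memory range, returning the offset within it. A minimal C sketch of the range lookup (struct and names simplified from the diff; the MTLBuffer handle is omitted):

    #include <stdint.h>
    #include <stddef.h>
    #include <stdio.h>

    // One registered host range: base pointer and size
    // (stands in for a ggml_metal_buffer).
    typedef struct { void *data; size_t size; } buffer_t;

    // Find which registered buffer contains [ptr, ptr + tsize) and return
    // its index, writing the offset into *offs; -1 when nothing matches.
    static int find_buffer(const buffer_t *bufs, int n, const void *ptr,
                           size_t tsize, size_t *offs) {
        for (int i = 0; i < n; ++i) {
            const int64_t ioffs = (int64_t) ((const uint8_t *) ptr - (const uint8_t *) bufs[i].data);
            if (ioffs >= 0 && ioffs + (int64_t) tsize <= (int64_t) bufs[i].size) {
                *offs = (size_t) ioffs;
                return i;
            }
        }
        return -1;
    }

    int main(void) {
        uint8_t pool[256];
        buffer_t bufs[1] = { { pool, sizeof(pool) } };
        size_t offs;
        int idx = find_buffer(bufs, 1, pool + 32, 64, &offs);
        printf("buffer %d, offset %zu\n", idx, offs);  // buffer 0, offset 32
        return 0;
    }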
@@ -306,7 +369,7 @@ bool ggml_metal_add_buffer(
         size_t size,
         size_t max_size) {
     if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
-        fprintf(stderr, "%s: too many buffers\n", __func__);
+        metal_printf("%s: too many buffers\n", __func__);
         return false;
     }
 
@@ -316,7 +379,7 @@ bool ggml_metal_add_buffer(
         const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
 
         if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
-            fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
+            metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
            return false;
         }
     }
@@ -337,11 +400,11 @@ bool ggml_metal_add_buffer(
         ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
         if (ctx->buffers[ctx->n_buffers].metal == nil) {
-            fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
+            metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
             return false;
         }
 
-        fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
+        metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
 
         ++ctx->n_buffers;
     } else {
@@ -361,27 +424,27 @@ bool ggml_metal_add_buffer(
             ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
             if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
+                metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
                 return false;
             }
 
-            fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
+            metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
             if (i + size_step < size) {
-                fprintf(stderr, "\n");
+                metal_printf("\n");
             }
 
             ++ctx->n_buffers;
         }
     }
 
-    fprintf(stderr, ", (%8.2f / %8.2f)",
+    metal_printf(", (%8.2f / %8.2f)",
             ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
             ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
 
     if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
-        fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
+        metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
     } else {
-        fprintf(stderr, "\n");
+        metal_printf("\n");
     }
     }
 
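
In the split-buffer branch above (taken when a region exceeds the device's maximum buffer length), each chunk is mapped as a no-copy view at offset `i`, with its length padded up to a page boundary. A minimal C sketch of just that chunking arithmetic; the constants and the exact step computation are illustrative, not taken from the diff:

    #include <stdio.h>
    #include <stddef.h>

    // Split a large host range into device-mappable chunks, mirroring the
    // loop shape in ggml_metal_add_buffer (constants are illustrative).
    int main(void) {
        const size_t page_size = 4096;
        const size_t size      = 3 * 4096 * 1000 + 123; // total bytes to map
        const size_t size_step = 4096 * 1000;           // stands in for max_size

        for (size_t i = 0; i < size; i += size_step) {
            // the last chunk may be short; align its length to a full page
            size_t remaining         = size - i;
            size_t chunk             = remaining < size_step ? remaining : size_step;
            size_t size_step_aligned = ((chunk + page_size - 1) / page_size) * page_size;
            printf("chunk at offs %9zu, mapped length %9zu\n", i, size_step_aligned);
        }
        return 0;
    }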
@@ -391,8 +454,6 @@ bool ggml_metal_add_buffer(
 void ggml_metal_set_tensor(
         struct ggml_metal_context * ctx,
         struct ggml_tensor * t) {
-    metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
-
     size_t offs;
     id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
 
@@ -402,8 +463,6 @@ void ggml_metal_set_tensor(
 void ggml_metal_get_tensor(
         struct ggml_metal_context * ctx,
         struct ggml_tensor * t) {
-    metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
-
     size_t offs;
     id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
 
@@ -498,14 +557,14 @@ void ggml_metal_graph_find_concurrency(
     }
 
     if (ctx->concur_list_len > GGML_MAX_CONCUR) {
-        fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+        metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
     }
 }
 
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
         struct ggml_cgraph * gf) {
-    metal_printf("%s: evaluating graph\n", __func__);
+    @autoreleasepool {
 
     // if there is ctx->concur_list, dispatch concurrently
     // else fallback to serial dispatch
@@ -521,29 +580,25 @@ void ggml_metal_graph_compute(
 
     const int n_cb = ctx->n_cb;
 
-    NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
-
     for (int i = 0; i < n_cb; ++i) {
-        command_buffers[i] = [ctx->queue commandBuffer];
+        ctx->command_buffers[i] = [ctx->queue commandBuffer];
 
         // enqueue the command buffers in order to specify their execution order
-        [command_buffers[i] enqueue];
-    }
+        [ctx->command_buffers[i] enqueue];
 
-    // TODO: is this the best way to start threads?
-    dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+        ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
+    }
 
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
         const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
-        dispatch_async(queue, ^{
+        dispatch_async(ctx->d_queue, ^{
             size_t offs_src0 = 0;
             size_t offs_src1 = 0;
             size_t offs_dst  = 0;
 
-            id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
-
-            id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+            id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
+            id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
 
             const int node_start = (cb_idx + 0) * n_nodes_per_cb;
             const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
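
Each command buffer encodes a contiguous slice of the graph: `n_nodes_per_cb` is a ceiling division, and the last buffer absorbs the remainder. A minimal standalone C sketch of that partitioning (the node and buffer counts are illustrative):

    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    // Partition n_nodes graph nodes across n_cb command buffers the way
    // ggml_metal_graph_compute does: ceiling division, last slice clamped.
    int main(void) {
        const int n_nodes = 10;
        const int n_cb    = 4;
        const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;

        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
            const int node_start = (cb_idx + 0) * n_nodes_per_cb;
            const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes
                                                            : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
            printf("cb %d encodes nodes [%d, %d)\n", cb_idx, node_start, node_end);
        }
        return 0;
    }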
@@ -556,7 +611,7 @@ void ggml_metal_graph_compute(
                 continue;
             }
 
-            metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+            //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
             struct ggml_tensor * src0 = gf->nodes[i]->src[0];
             struct ggml_tensor * src1 = gf->nodes[i]->src[1];
@@ -625,6 +680,12 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_ADD:
                 {
+                    GGML_ASSERT(ggml_is_contiguous(src0));
+
+                    // utilize float4
+                    GGML_ASSERT(ne00 % 4 == 0);
+                    const int64_t nb = ne00/4;
+
                     if (ggml_nelements(src1) == ne10) {
                         // src1 is a row
                         [encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -634,14 +695,20 @@ void ggml_metal_graph_compute(
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                    [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
 
-                    const int64_t n = ggml_nelements(dst);
+                    const int64_t n = ggml_nelements(dst)/4;
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
             case GGML_OP_MUL:
                 {
+                    GGML_ASSERT(ggml_is_contiguous(src0));
+
+                    // utilize float4
+                    GGML_ASSERT(ne00 % 4 == 0);
+                    const int64_t nb = ne00/4;
+
                     if (ggml_nelements(src1) == ne10) {
                         // src1 is a row
                         [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -651,9 +718,9 @@ void ggml_metal_graph_compute(
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                    [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
 
-                    const int64_t n = ggml_nelements(dst);
+                    const int64_t n = ggml_nelements(dst)/4;
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
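
With the kernels switched to `float4` (see the ggml-metal.metal hunks below), the host now asserts `ne00 % 4 == 0`, passes the row length in `float4` units (`nb = ne00/4`), and launches a quarter as many threads (`n = ggml_nelements(dst)/4`). A minimal C simulation of the broadcast-row add, with each outer iteration standing in for one `float4` thread:

    #include <stdio.h>

    // Broadcast-add a row (src1) over a contiguous src0, 4 floats per
    // "thread"; nb = ne00/4 is the row length in 4-wide units.
    int main(void) {
        enum { ne00 = 8, nrows = 2 };                  // ne00 % 4 == 0 required
        float src0[nrows * ne00], src1[ne00], dst[nrows * ne00];
        for (int i = 0; i < nrows * ne00; ++i) src0[i] = (float) i;
        for (int i = 0; i < ne00; ++i)         src1[i] = 1.0f;

        const int nb = ne00 / 4;                       // row length in float4s
        const int n  = (nrows * ne00) / 4;             // total float4 "threads"
        for (int tpig = 0; tpig < n; ++tpig) {
            for (int k = 0; k < 4; ++k) {              // one float4, lane by lane
                int idx = tpig * 4 + k;
                dst[idx] = src0[idx] + src1[(tpig % nb) * 4 + k];
            }
        }
        printf("dst[0] = %g, dst[%d] = %g\n", dst[0], nrows * ne00 - 1, dst[nrows * ne00 - 1]);
        return 0;
    }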
@@ -704,7 +771,7 @@ void ggml_metal_graph_compute(
                 } break;
             default:
                 {
-                    fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                    metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
                     GGML_ASSERT(false);
                 }
             } break;
@@ -785,7 +852,7 @@ void ggml_metal_graph_compute(
                 switch (src0t) {
                     case GGML_TYPE_F16:
                         {
-                            nth0 = 64;
+                            nth0 = 32;
                             nth1 = 1;
                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
                         } break;
@@ -863,7 +930,7 @@ void ggml_metal_graph_compute(
                         } break;
                     default:
                         {
-                            fprintf(stderr, "Asserting on type %d\n", (int)src0t);
+                            metal_printf("Asserting on type %d\n", (int)src0t);
                            GGML_ASSERT(false && "not implemented");
                         }
                 };
@@ -1101,7 +1168,7 @@ void ggml_metal_graph_compute(
                 } break;
             default:
                 {
-                    fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                    metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
                     GGML_ASSERT(false);
                 }
             }
@@ -1117,17 +1184,19 @@ void ggml_metal_graph_compute(
     }
 
     // wait for all threads to finish
-    dispatch_barrier_sync(queue, ^{});
-
-    [command_buffers[n_cb - 1] waitUntilCompleted];
+    dispatch_barrier_sync(ctx->d_queue, ^{});
 
     // check status of command buffers
     // needed to detect if the device ran out-of-memory for example (#1881)
     for (int i = 0; i < n_cb; i++) {
-        MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
+        [ctx->command_buffers[i] waitUntilCompleted];
+
+        MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
         if (status != MTLCommandBufferStatusCompleted) {
-            fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
+            metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
             GGML_ASSERT(false);
         }
     }
+
+    }
 }
ggml-metal.metal
@@ -25,9 +25,9 @@ typedef struct {
 } block_q8_0;
 
 kernel void kernel_add(
-        device const float  * src0,
-        device const float  * src1,
-        device       float  * dst,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig];
 }
@@ -35,18 +35,18 @@ kernel void kernel_add(
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_add_row(
-        device const float  * src0,
-        device const float  * src1,
-        device       float  * dst,
-        constant   int64_t & ne00,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] + src1[tpig % ne00];
+    dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 
 kernel void kernel_mul(
-        device const float  * src0,
-        device const float  * src1,
-        device       float  * dst,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * src1[tpig];
 }
@@ -54,12 +54,12 @@ kernel void kernel_mul(
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_mul_row(
-        device const float  * src0,
-        device const float  * src1,
-        device       float  * dst,
-        constant   int64_t & ne00,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src1[tpig % ne00];
+    dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
 
 kernel void kernel_scale(
@@ -528,24 +528,42 @@ kernel void kernel_mul_mat_f16_f32(
     device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
-    sum[tpitg.x] = 0.0f;
+    uint ith = tpitg.x;
+    uint nth = tptg.x;
+
+    sum[ith] = 0.0f;
 
-    for (int i = tpitg.x; i < ne00; i += tptg.x) {
-        sum[tpitg.x] += (float) x[i] * (float) y[i];
+    for (int i = ith; i < ne00; i += nth) {
+        sum[ith] += (float) x[i] * (float) y[i];
     }
 
     // accumulate the sum from all threads in the threadgroup
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = tptg.x/2; i > 0; i /= 2) {
-        if (tpitg.x < i) {
-            sum[tpitg.x] += sum[tpitg.x + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
     }
-
-    if (tpitg.x == 0) {
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
         dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
     }
+
+    // Original implementation. Left behind commented out for now
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (uint i = tptg.x/2; i > 0; i /= 2) {
+    //    if (tpitg.x < i) {
+    //        sum[tpitg.x] += sum[tpitg.x + i];
+    //    }
+    //    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //}
+    //
+    //if (tpitg.x == 0) {
+    //    dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
+    //}
 }
 
 kernel void kernel_alibi_f32(
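
The f16×f32 mat-vec kernel replaces the halving-tree reduction (one barrier per level, log2(nth) levels) with a three-stage reduction at strides 4 and 16 that needs only two barriers; together with the host-side `nth0` drop from 64 to 32 this shortens the reduction path. A minimal single-threaded C simulation showing the staged scheme reduces to the same total (nth = 32 matches the new `nth0`):

    #include <stdio.h>

    #define NTH 32  // threads per threadgroup (matches the new nth0 = 32)

    int main(void) {
        float sum[NTH];
        for (int i = 0; i < NTH; ++i) sum[i] = 1.0f;  // each "thread" contributed 1

        // Stage 1: every 4th lane folds in its 3 neighbours.
        for (int ith = 0; ith < NTH; ith += 4)
            for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
        // (threadgroup barrier here on the GPU)

        // Stage 2: every 16th lane folds in the stride-4 partials.
        for (int ith = 0; ith < NTH; ith += 16)
            for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
        // (threadgroup barrier here on the GPU)

        // Stage 3: lane 0 folds in the stride-16 partials.
        for (int i = 16; i < NTH; i += 16) sum[0] += sum[i];

        printf("reduced sum = %g (expected %d)\n", sum[0], NTH);
        return 0;
    }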