llama_cpp 0.2.0 → 0.2.1

@@ -52,14 +52,18 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
  GGML_METAL_DECL_KERNEL(get_rows_q2_k);
+ GGML_METAL_DECL_KERNEL(get_rows_q3_k);
  GGML_METAL_DECL_KERNEL(get_rows_q4_k);
+ GGML_METAL_DECL_KERNEL(get_rows_q5_k);
  GGML_METAL_DECL_KERNEL(get_rows_q6_k);
  GGML_METAL_DECL_KERNEL(rms_norm);
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
  GGML_METAL_DECL_KERNEL(rope);
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -86,6 +90,7 @@ struct ggml_metal_context * ggml_metal_init(void) {
 
  ctx->device = MTLCreateSystemDefaultDevice();
  ctx->queue = [ctx->device newCommandQueue];
+ ctx->n_buffers = 0;
 
  // determine if we can use MPS
  if (MPSSupportsMTLDevice(ctx->device)) {
@@ -152,14 +157,18 @@ struct ggml_metal_context * ggml_metal_init(void) {
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
  GGML_METAL_ADD_KERNEL(get_rows_q2_k);
+ GGML_METAL_ADD_KERNEL(get_rows_q3_k);
  GGML_METAL_ADD_KERNEL(get_rows_q4_k);
+ GGML_METAL_ADD_KERNEL(get_rows_q5_k);
  GGML_METAL_ADD_KERNEL(get_rows_q6_k);
  GGML_METAL_ADD_KERNEL(rms_norm);
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
  GGML_METAL_ADD_KERNEL(rope);
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
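
Note: the GGML_METAL_DECL_KERNEL / GGML_METAL_ADD_KERNEL pairs above follow the usual token-pasting pattern, which is why supporting q3_k and q5_k is one macro line per site: the first declares a function/pipeline pair in ggml_metal_context, the second looks the kernel up in the compiled Metal library and builds its compute pipeline state. A minimal sketch of how such a pair can be defined (the real definitions live elsewhere in ggml-metal.m and may differ in detail):

#define GGML_METAL_DECL_KERNEL(name) \
    id<MTLFunction>             function_##name; \
    id<MTLComputePipelineState> pipeline_##name

#define GGML_METAL_ADD_KERNEL(name) \
    ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
    ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:nil]

Under this scheme, each kernel registered here must also exist as a kernel_<name> function in the Metal shader source; the q3_k/q5_k additions assume matching shader functions shipped in the same release.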
@@ -275,509 +284,551 @@ void ggml_metal_get_tensor(
 
  void ggml_metal_graph_compute(
  struct ggml_metal_context * ctx,
- struct ggml_cgraph * gf) {
+ struct ggml_cgraph * gf) {
  metal_printf("%s: evaluating graph\n", __func__);
 
- size_t offs_src0 = 0;
- size_t offs_src1 = 0;
- size_t offs_dst = 0;
-
- id<MTLCommandBuffer> command_buffer = [ctx->queue commandBuffer];
- id<MTLComputeCommandEncoder> encoder = nil;
-
- for (int i = 0; i < gf->n_nodes; ++i) {
- //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
-
- struct ggml_tensor * src0 = gf->nodes[i]->src0;
- struct ggml_tensor * src1 = gf->nodes[i]->src1;
- struct ggml_tensor * dst = gf->nodes[i];
-
- const int64_t ne00 = src0 ? src0->ne[0] : 0;
- const int64_t ne01 = src0 ? src0->ne[1] : 0;
- const int64_t ne02 = src0 ? src0->ne[2] : 0;
- const int64_t ne03 = src0 ? src0->ne[3] : 0;
-
- const uint64_t nb00 = src0 ? src0->nb[0] : 0;
- const uint64_t nb01 = src0 ? src0->nb[1] : 0;
- const uint64_t nb02 = src0 ? src0->nb[2] : 0;
- const uint64_t nb03 = src0 ? src0->nb[3] : 0;
-
- const int64_t ne10 = src1 ? src1->ne[0] : 0;
- const int64_t ne11 = src1 ? src1->ne[1] : 0;
- const int64_t ne12 = src1 ? src1->ne[2] : 0;
- const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
-
- const uint64_t nb10 = src1 ? src1->nb[0] : 0;
- const uint64_t nb11 = src1 ? src1->nb[1] : 0;
- const uint64_t nb12 = src1 ? src1->nb[2] : 0;
- const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
-
- const int64_t ne0 = dst ? dst->ne[0] : 0;
- const int64_t ne1 = dst ? dst->ne[1] : 0;
- const int64_t ne2 = dst ? dst->ne[2] : 0;
- const int64_t ne3 = dst ? dst->ne[3] : 0;
-
- const uint64_t nb0 = dst ? dst->nb[0] : 0;
- const uint64_t nb1 = dst ? dst->nb[1] : 0;
- const uint64_t nb2 = dst ? dst->nb[2] : 0;
- const uint64_t nb3 = dst ? dst->nb[3] : 0;
-
- const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
- const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
- const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
-
- id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
- id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
- id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
-
- //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
- //if (src0) {
- // metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
- // ggml_is_contiguous(src0), src0->name);
- //}
- //if (src1) {
- // metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
- // ggml_is_contiguous(src1), src1->name);
- //}
- //if (dst) {
- // metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
- // dst->name);
- //}
-
- switch (dst->op) {
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_PERMUTE:
- {
- // noop
- } break;
- case GGML_OP_ADD:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- [encoder setComputePipelineState:ctx->pipeline_add];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_MUL:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- if (ggml_nelements(src1) == ne10) {
- // src1 is a row
- [encoder setComputePipelineState:ctx->pipeline_mul_row];
- } else {
- [encoder setComputePipelineState:ctx->pipeline_mul];
- }
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_SCALE:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- const float scale = *(const float *) src1->data;
-
- [encoder setComputePipelineState:ctx->pipeline_scale];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_SILU:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- [encoder setComputePipelineState:ctx->pipeline_silu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_RELU:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- [encoder setComputePipelineState:ctx->pipeline_relu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_GELU:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- [encoder setComputePipelineState:ctx->pipeline_gelu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_SOFT_MAX:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- const int nth = 32;
-
- [encoder setComputePipelineState:ctx->pipeline_soft_max];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_DIAG_MASK_INF:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- const int n_past = ((int32_t *)(src1->data))[0];
-
- [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_MUL_MAT:
- {
- // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
-
- GGML_ASSERT(ne00 == ne10);
- GGML_ASSERT(ne02 == ne12);
-
- if (ggml_is_contiguous(src0) &&
- ggml_is_contiguous(src1) &&
- (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
-
- if (encoder != nil) {
- [encoder endEncoding];
- encoder = nil;
- }
-
- MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
- MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
-
- // for F32 x F32 we use MPS
- MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
- matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
-
- MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
- matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
-
- MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
- matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
-
- MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
- initWithDevice:ctx->device transposeLeft:false transposeRight:true
- resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
-
- // we need to do ne02 multiplications
- // TODO: is there a way to do this in parallel - currently very slow ..
- // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
- for (int64_t i02 = 0; i02 < ne02; ++i02) {
- size_t offs_src0_cur = offs_src0 + i02*nb02;
- size_t offs_src1_cur = offs_src1 + i02*nb12;
- size_t offs_dst_cur = offs_dst + i02*nb2;
-
- MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
- MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
- MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
-
- [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
- }
- } else {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- int nth0 = 32;
- int nth1 = 1;
-
- // use custom matrix x vector kernel
- switch (src0t) {
- case GGML_TYPE_F16:
- {
- GGML_ASSERT(ne02 == ne12);
-
- nth0 = 64;
- nth1 = 1;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
- } break;
- case GGML_TYPE_Q4_0:
- {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
- nth0 = 8;
- nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
- } break;
- case GGML_TYPE_Q4_1:
- {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
- nth0 = 8;
- nth1 = 8;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
- } break;
- case GGML_TYPE_Q2_K:
- {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
- nth0 = 4;
- nth1 = 16;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
- } break;
- case GGML_TYPE_Q4_K:
- {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
- nth0 = 4;
- nth1 = 16;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
- } break;
- case GGML_TYPE_Q6_K:
- {
- GGML_ASSERT(ne02 == 1);
- GGML_ASSERT(ne12 == 1);
-
- nth0 = 4;
- nth1 = 16;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
- } break;
- default:
- {
- fprintf(stderr, "Asserting on type %d\n",(int)src0t);
- GGML_ASSERT(false && "not implemented");
- }
- };
-
-
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
-
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- } else if (src0t == GGML_TYPE_Q2_K) {
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- } else if (src0t == GGML_TYPE_Q4_K) {
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- } else if (src0t == GGML_TYPE_Q6_K) {
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- } else {
- [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
- }
- }
- } break;
- case GGML_OP_GET_ROWS:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- switch (src0->type) {
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
- default: GGML_ASSERT(false && "not implemented");
- }
-
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
- [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];
-
- const int64_t n = ggml_nelements(src1);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_RMS_NORM:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- const float eps = 1e-6f;
-
- const int nth = 256;
-
- [encoder setComputePipelineState:ctx->pipeline_rms_norm];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
- [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
-
- const int64_t nrows = ggml_nrows(src0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- case GGML_OP_ROPE:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- const int n_dims = ((int32_t *) src1->data)[1];
- const int mode = ((int32_t *) src1->data)[2];
-
- const int n_past = ((int32_t *)(src1->data))[0];
-
- [encoder setComputePipelineState:ctx->pipeline_rope];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
- [encoder setBytes:&mode length:sizeof( int) atIndex:20];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_CPY:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- const int nth = 32;
-
- switch (src0t) {
- case GGML_TYPE_F32:
- {
- switch (dstt) {
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
- case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
- default: GGML_ASSERT(false && "not implemented");
- };
- } break;
- default: GGML_ASSERT(false && "not implemented");
- }
-
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
- } break;
- default:
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
- GGML_ASSERT(false);
- }
- }
+ // create multiple command buffers and enqueue them
+ // then, we encode the graph into the command buffers in parallel
+
+ const int n_cb = gf->n_threads;
+
+ NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
+
+ for (int i = 0; i < n_cb; ++i) {
+ command_buffers[i] = [ctx->queue commandBuffer];
 
- if (encoder != nil) {
- [encoder endEncoding];
- encoder = nil;
+ // enqueue the command buffers in order to specify their execution order
+ [command_buffers[i] enqueue];
  }
 
- [command_buffer commit];
- [command_buffer waitUntilCompleted];
+ // TODO: is this the best way to start threads?
+ dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+
+ for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+ const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+
+ dispatch_async(queue, ^{
+ size_t offs_src0 = 0;
+ size_t offs_src1 = 0;
+ size_t offs_dst = 0;
+
+ id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
+
+ id<MTLComputeCommandEncoder> encoder = nil;
+
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
+ const int node_end = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+ for (int i = node_start; i < node_end; ++i) {
+ metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+
+ struct ggml_tensor * src0 = gf->nodes[i]->src0;
+ struct ggml_tensor * src1 = gf->nodes[i]->src1;
+ struct ggml_tensor * dst = gf->nodes[i];
+
+ const int64_t ne00 = src0 ? src0->ne[0] : 0;
+ const int64_t ne01 = src0 ? src0->ne[1] : 0;
+ const int64_t ne02 = src0 ? src0->ne[2] : 0;
+ const int64_t ne03 = src0 ? src0->ne[3] : 0;
+
+ const uint64_t nb00 = src0 ? src0->nb[0] : 0;
+ const uint64_t nb01 = src0 ? src0->nb[1] : 0;
+ const uint64_t nb02 = src0 ? src0->nb[2] : 0;
+ const uint64_t nb03 = src0 ? src0->nb[3] : 0;
+
+ const int64_t ne10 = src1 ? src1->ne[0] : 0;
+ const int64_t ne11 = src1 ? src1->ne[1] : 0;
+ const int64_t ne12 = src1 ? src1->ne[2] : 0;
+ const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+
+ const uint64_t nb10 = src1 ? src1->nb[0] : 0;
+ const uint64_t nb11 = src1 ? src1->nb[1] : 0;
+ const uint64_t nb12 = src1 ? src1->nb[2] : 0;
+ const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
+
+ const int64_t ne0 = dst ? dst->ne[0] : 0;
+ const int64_t ne1 = dst ? dst->ne[1] : 0;
+ const int64_t ne2 = dst ? dst->ne[2] : 0;
+ const int64_t ne3 = dst ? dst->ne[3] : 0;
+
+ const uint64_t nb0 = dst ? dst->nb[0] : 0;
+ const uint64_t nb1 = dst ? dst->nb[1] : 0;
+ const uint64_t nb2 = dst ? dst->nb[2] : 0;
+ const uint64_t nb3 = dst ? dst->nb[3] : 0;
+
+ const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
+ const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+ const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
+
+ id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
+ id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
+ id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
+
+ //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+ //if (src0) {
+ // metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
+ // ggml_is_contiguous(src0), src0->name);
+ //}
+ //if (src1) {
+ // metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
+ // ggml_is_contiguous(src1), src1->name);
+ //}
+ //if (dst) {
+ // metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
+ // dst->name);
+ //}
+
+ switch (dst->op) {
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_PERMUTE:
+ {
+ // noop
+ } break;
+ case GGML_OP_ADD:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_add];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_MUL:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ if (ggml_nelements(src1) == ne10) {
+ // src1 is a row
+ [encoder setComputePipelineState:ctx->pipeline_mul_row];
+ } else {
+ [encoder setComputePipelineState:ctx->pipeline_mul];
+ }
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_SCALE:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ const float scale = *(const float *) src1->data;
+
+ [encoder setComputePipelineState:ctx->pipeline_scale];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_SILU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_silu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_RELU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_relu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_GELU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_gelu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_SOFT_MAX:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ const int nth = 32;
+
+ [encoder setComputePipelineState:ctx->pipeline_soft_max];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_DIAG_MASK_INF:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ const int n_past = ((int32_t *)(src1->data))[0];
+
+ [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_MUL_MAT:
+ {
+ // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
+
+ GGML_ASSERT(ne00 == ne10);
+ GGML_ASSERT(ne02 == ne12);
+
+ if (ggml_is_contiguous(src0) &&
+ ggml_is_contiguous(src1) &&
+ (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
+
+ if (encoder != nil) {
+ [encoder endEncoding];
+ encoder = nil;
+ }
 
- {
- const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
- UNUSED(time_elapsed);
+ MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
+ MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
+
+ // for F32 x F32 we use MPS
+ MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
+ matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
+
+ MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
+ matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
+
+ MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
+ matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
+
+ MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
+ initWithDevice:ctx->device transposeLeft:false transposeRight:true
+ resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
+
+ // we need to do ne02 multiplications
+ // TODO: is there a way to do this in parallel - currently very slow ..
+ // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
+ for (int64_t i02 = 0; i02 < ne02; ++i02) {
+ size_t offs_src0_cur = offs_src0 + i02*nb02;
+ size_t offs_src1_cur = offs_src1 + i02*nb12;
+ size_t offs_dst_cur = offs_dst + i02*nb2;
 
- metal_printf("%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
+ MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
+ MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
+ MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
+
+ [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
+ }
+ } else {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ int nth0 = 32;
+ int nth1 = 1;
+
+ // use custom matrix x vector kernel
+ switch (src0t) {
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(ne02 == ne12);
+
+ nth0 = 64;
+ nth1 = 1;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+ } break;
+ case GGML_TYPE_Q4_0:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
+ } break;
+ case GGML_TYPE_Q4_1:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
+ } break;
+ case GGML_TYPE_Q2_K:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 4;
+ nth1 = 16;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
+ } break;
+ case GGML_TYPE_Q3_K:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 4;
+ nth1 = 16;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32];
+ } break;
+ case GGML_TYPE_Q4_K:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 4;
+ nth1 = 16;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
+ } break;
+ case GGML_TYPE_Q5_K:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 4;
+ nth1 = 16;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32];
+ } break;
+ case GGML_TYPE_Q6_K:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 4;
+ nth1 = 16;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
+ } break;
+ default:
+ {
+ fprintf(stderr, "Asserting on type %d\n",(int)src0t);
+ GGML_ASSERT(false && "not implemented");
+ }
+ };
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
+
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
+ [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q2_K ||
+ src0t == GGML_TYPE_Q3_K ||
+ src0t == GGML_TYPE_Q4_K ||
+ src0t == GGML_TYPE_Q5_K ||
+ src0t == GGML_TYPE_Q6_K) {
+ [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ } else {
+ [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ }
+ } break;
+ case GGML_OP_GET_ROWS:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ switch (src0->type) {
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break;
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break;
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
+ default: GGML_ASSERT(false && "not implemented");
+ }
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
+ [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];
+
+ const int64_t n = ggml_nelements(src1);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_RMS_NORM:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ const float eps = 1e-6f;
+
+ const int nth = 256;
+
+ [encoder setComputePipelineState:ctx->pipeline_rms_norm];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+ const int64_t nrows = ggml_nrows(src0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ROPE:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ const int n_dims = ((int32_t *) src1->data)[1];
+ const int mode = ((int32_t *) src1->data)[2];
+
+ const int n_past = ((int32_t *)(src1->data))[0];
+
+ [encoder setComputePipelineState:ctx->pipeline_rope];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
+ [encoder setBytes:&mode length:sizeof( int) atIndex:20];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_CPY:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ const int nth = 32;
+
+ switch (src0t) {
+ case GGML_TYPE_F32:
+ {
+ switch (dstt) {
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+ default: GGML_ASSERT(false && "not implemented");
+ };
+ } break;
+ default: GGML_ASSERT(false && "not implemented");
+ }
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ default:
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ASSERT(false);
+ }
+ }
+
+ if (encoder != nil) {
+ [encoder endEncoding];
+ encoder = nil;
+ }
+
+ [command_buffer commit];
+ });
  }
+
+ // wait for all threads to finish
+ dispatch_barrier_sync(queue, ^{});
+
+ [command_buffers[n_cb - 1] waitUntilCompleted];
  }
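
The rewritten ggml_metal_graph_compute above replaces 0.2.0's single serial command buffer with gf->n_threads command buffers: each buffer is enqueued up front to pin its position in the GPU execution order, a concurrent GCD queue encodes a contiguous slice of graph nodes into each buffer in parallel, and a dispatch barrier plus waitUntilCompleted on the last buffer forms the synchronization point. A minimal, self-contained sketch of that pattern (encode_in_parallel and its arguments are hypothetical names for illustration, not llama_cpp API):

#import <Metal/Metal.h>

static void encode_in_parallel(id<MTLCommandQueue> mtl_queue, int n_cb) {
    NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
    for (int i = 0; i < n_cb; ++i) {
        command_buffers[i] = [mtl_queue commandBuffer];
        [command_buffers[i] enqueue]; // reserves this buffer's slot in GPU execution order
    }

    dispatch_queue_t queue = dispatch_queue_create("example", DISPATCH_QUEUE_CONCURRENT);
    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
        dispatch_async(queue, ^{
            id<MTLCommandBuffer> cb = command_buffers[cb_idx];
            // ... create a compute encoder on cb and encode this slice of the work ...
            [cb commit]; // commit order no longer matters: enqueue already fixed it
        });
    }

    dispatch_barrier_sync(queue, ^{});              // all encoding blocks have finished
    [command_buffers[n_cb - 1] waitUntilCompleted]; // buffers on one queue run in enqueue order
}

Because the buffers were enqueued in index order on a single queue, waiting on the last one is enough in this scheme to know the whole graph has executed.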