llama_cpp 0.3.3 → 0.3.5

This diff updates the llama.cpp Metal backend vendored by the llama_cpp gem (ggml-metal.h and ggml-metal.m). It adds a concurrency-analysis pass (ggml_metal_graph_find_concurrency) that groups independent graph nodes into levels so they can be dispatched concurrently with memory barriers in between, introduces a broadcast add_row kernel, moves operator parameters (n_past, eps, freq_base, freq_scale) from src1 tensors into dst->op_params, folds SILU/RELU/GELU under the new GGML_OP_UNARY dispatch, and retunes the K-quant matrix-multiplication threadgroup sizes.
@@ -61,6 +61,13 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
 // get data from the device into host memory
 void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
 
+// try to find operations that can be run concurrently in the graph
+// you should run it again if the topology of your graph changes
+void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+
+// whether the graph has been optimized for concurrent dispatch
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+
 // same as ggml_graph_compute but uses Metal
 // creates gf->n_threads command buffers in parallel
 void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
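Taken together with ggml_metal_graph_compute, the new entry points imply a simple call order: run the analysis once per graph topology, then compute. A minimal usage sketch (the wrapper name is ours; ctx and gf are assumed to be a context and graph already set up the usual way, with the ggml Metal headers included):

    // Hypothetical helper, not part of this diff: shows the intended call order.
    static void metal_compute_with_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf) {
        // analyze once; re-run only when the graph topology changes
        ggml_metal_graph_find_concurrency(ctx, gf);

        // true once a schedule exists; compute falls back to serial dispatch otherwise
        const bool concurrent = ggml_metal_if_optimized(ctx);
        fprintf(stderr, "dispatch mode: %s\n", concurrent ? "concurrent" : "serial");

        ggml_metal_graph_compute(ctx, gf);
    }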
@@ -36,12 +36,16 @@ struct ggml_metal_context {
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
+    int concur_list[GGML_MAX_NODES];
+    int concur_list_len;
+
     // custom kernels
 #define GGML_METAL_DECL_KERNEL(name) \
     id<MTLFunction> function_##name; \
     id<MTLComputePipelineState> pipeline_##name
 
     GGML_METAL_DECL_KERNEL(add);
+    GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(scale);
@@ -97,6 +101,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     ctx->device = MTLCreateSystemDefaultDevice();
     ctx->queue = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
+    ctx->concur_list_len = 0;
 
     // determine if we can use MPS
     if (MPSSupportsMTLDevice(ctx->device)) {
@@ -157,6 +162,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
 
         GGML_METAL_ADD_KERNEL(add);
+        GGML_METAL_ADD_KERNEL(add_row);
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul_row);
         GGML_METAL_ADD_KERNEL(scale);
@@ -215,6 +221,13 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
     ctx->n_cb = n_cb;
 }
 
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+    if (ctx->concur_list_len) {
+        return true;
+    }
+    return false;
+}
+
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
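Since concur_list_len is zero until the analysis has produced a schedule, the new predicate is equivalent to a one-liner; a condensed sketch of the same logic (stylistic only, not what the diff ships):

    bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
        return ctx->concur_list_len > 0; // a non-empty schedule means the graph was optimized
    }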
@@ -353,11 +366,98 @@ void ggml_metal_get_tensor(
     memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
 }
 
+void ggml_metal_graph_find_concurrency(
+        struct ggml_metal_context * ctx,
+        struct ggml_cgraph * gf) {
+    int search_depth = gf->n_nodes; // we only look for concurrency within this window to avoid wasting too much time
+    int nodes_unused[GGML_MAX_NODES];
+
+    for (int i = 0; i < GGML_MAX_NODES; i++) { ctx->concur_list[i] = 0; }
+    for (int i = 0; i < gf->n_nodes;    i++) { nodes_unused[i]     = 1; }
+    ctx->concur_list_len = 0;
+
+    int n_left    = gf->n_nodes;
+    int n_start   = 0; // all nodes before n_start in nodes_unused have been sorted and stored back into ctx->concur_list
+    int level_pos = 0; // position in ctx->concur_list where the last layer (level) ends
+
+    while (n_left > 0) {
+        // number of nodes at a layer (that can be issued concurrently)
+        int concurrency = 0;
+        for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
+            if (nodes_unused[i]) {
+                // check whether the requirements for gf->nodes[i] are satisfied
+                int exe_flag = 1;
+                // scan all srcs
+                for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
+                    struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
+                    if (src_cur) {
+                        // a leaf node is always satisfied
+                        if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) { continue; }
+
+                        // otherwise this src must be the output of a previously scheduled node
+                        int is_found = 0;
+                        // scan 2*search_depth entries back because barriers were inserted
+                        for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+                            if (gf->nodes[ctx->concur_list[j]] == src_cur) { is_found = 1; break; }
+                        }
+                        if (is_found == 0) { exe_flag = 0; break; }
+                    }
+                }
+                if (exe_flag) {
+                    // check whether nodes[i]'s data would be overwritten by a node issued before it:
+                    // if node[5] and node[3] write to the same memory region, node[5] cannot be issued before node[3]
+                    int64_t data_start = (int64_t) gf->nodes[i]->data;
+                    int64_t length     = (int64_t) ggml_nbytes(gf->nodes[i]);
+                    for (int j = n_start; j < i; j++) {
+                        if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
+                                            && gf->nodes[j]->op != GGML_OP_VIEW \
+                                            && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
+                                            && gf->nodes[j]->op != GGML_OP_PERMUTE) {
+                            if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
+                                ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
+                                continue;
+                            } else {
+                                exe_flag = 0;
+                            }
+                        }
+                    }
+                }
+                if (exe_flag) {
+                    ctx->concur_list[level_pos + concurrency] = i;
+                    nodes_unused[i] = 0;
+                    concurrency++;
+                    ctx->concur_list_len++;
+                }
+            }
+        }
+        n_left -= concurrency;
+        // add a barrier (-1) to separate this layer from the next
+        ctx->concur_list[level_pos + concurrency] = -1;
+        ctx->concur_list_len++;
+        // skip over the nodes that have already been scheduled
+        while (!nodes_unused[n_start]) { n_start++; }
+        level_pos += concurrency + 1;
+    }
+
+    if (ctx->concur_list_len > GGML_MAX_NODES) {
+        fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+    }
+}
+
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
         struct ggml_cgraph * gf) {
     metal_printf("%s: evaluating graph\n", __func__);
 
+    // if there is a concurrency schedule in ctx->concur_list, dispatch concurrently;
+    // otherwise fall back to serial dispatch
+    MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
+
+    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+
+    const int n_nodes  = has_concur ? ctx->concur_list_len : gf->n_nodes;
+    edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+
     // create multiple command buffers and enqueue them
     // then, we encode the graph into the command buffers in parallel
 
@@ -376,7 +476,7 @@ void ggml_metal_graph_compute(
     dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
 
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+        const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
         dispatch_async(queue, ^{
             size_t offs_src0 = 0;
@@ -387,10 +487,21 @@ void ggml_metal_graph_compute(
 
             id<MTLComputeCommandEncoder> encoder = nil;
 
-            const int node_start = (cb_idx + 0) * n_nodes_per_cb;
-            const int node_end   = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+            const int node_start = (cb_idx + 0) * n_nodes_per_cb;
+            const int node_end   = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+            for (int ind = node_start; ind < node_end; ++ind) {
+                const int i = has_concur ? ctx->concur_list[ind] : ind;
+
+                if (i == -1) {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                        continue;
+                    }
+                    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+                    continue;
+                }
 
-            for (int i = node_start; i < node_end; ++i) {
                 metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
                 struct ggml_tensor * src0 = gf->nodes[i]->src[0];
@@ -461,13 +572,19 @@ void ggml_metal_graph_compute(
             case GGML_OP_ADD:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                     }
 
-                    [encoder setComputePipelineState:ctx->pipeline_add];
+                    if (ggml_nelements(src1) == ne10) {
+                        // src1 is a row
+                        [encoder setComputePipelineState:ctx->pipeline_add_row];
+                    } else {
+                        [encoder setComputePipelineState:ctx->pipeline_add];
+                    }
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                     [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
 
                     const int64_t n = ggml_nelements(dst);
 
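The new add_row path mirrors the existing mul_row kernel: when src1 holds exactly one row (ggml_nelements(src1) == ne10), that row is broadcast across every row of src0, which is why ne00 is now bound at index 3. The intended element-wise semantics, as a scalar C sketch assuming contiguous f32 data:

    #include <stdint.h>

    // Reference semantics of the broadcast row-add (sketch, contiguous f32 only).
    // n is the element count of src0/dst; ne00 is the row length, equal to ggml_nelements(src1).
    static void add_row_ref(float * dst, const float * src0, const float * src1, int64_t n, int64_t ne00) {
        for (int64_t i = 0; i < n; ++i) {
            dst[i] = src0[i] + src1[i % ne00]; // the single src1 row is reused for every row of src0
        }
    }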
@@ -476,7 +593,7 @@ void ggml_metal_graph_compute(
             case GGML_OP_MUL:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                     }
 
                     if (ggml_nelements(src1) == ne10) {
@@ -497,7 +614,7 @@ void ggml_metal_graph_compute(
             case GGML_OP_SCALE:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                     }
 
                     const float scale = *(const float *) src1->data;
@@ -511,52 +628,60 @@ void ggml_metal_graph_compute(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
-            case GGML_OP_SILU:
-                {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
-                    }
-
-                    [encoder setComputePipelineState:ctx->pipeline_silu];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
-                    const int64_t n = ggml_nelements(dst);
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                } break;
-            case GGML_OP_RELU:
-                {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
-                    }
-
-                    [encoder setComputePipelineState:ctx->pipeline_relu];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
-                    const int64_t n = ggml_nelements(dst);
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            case GGML_OP_UNARY:
+                switch (ggml_get_unary_op(gf->nodes[i])) {
+                    case GGML_UNARY_OP_SILU:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                            }
+
+                            [encoder setComputePipelineState:ctx->pipeline_silu];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_UNARY_OP_RELU:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                            }
+
+                            [encoder setComputePipelineState:ctx->pipeline_relu];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_UNARY_OP_GELU:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                            }
+
+                            [encoder setComputePipelineState:ctx->pipeline_gelu];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    default:
+                        {
+                            fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                            GGML_ASSERT(false);
+                        }
                 } break;
-            case GGML_OP_GELU:
-                {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
-                    }
-
-                    [encoder setComputePipelineState:ctx->pipeline_gelu];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
-                    const int64_t n = ggml_nelements(dst);
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                } break;
 
             case GGML_OP_SOFT_MAX:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                     }
 
                     const int nth = 32;
@@ -574,10 +699,10 @@ void ggml_metal_graph_compute(
             case GGML_OP_DIAG_MASK_INF:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                     }
 
-                    const int n_past = ((int32_t *)(src1->data))[0];
+                    const int n_past = ((int32_t *)(dst->op_params))[0];
 
                     [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
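This hunk, like the RMS_NORM, ALIBI, and ROPE hunks below, reads small scalar operator parameters from dst->op_params instead of a separate src1 tensor. op_params is a small int32_t-backed array embedded in struct ggml_tensor: integers are indexed directly, while floats are extracted with memcpy so an int32_t slot is never dereferenced through a float pointer (a strict-aliasing violation). The pattern in isolation, as a self-contained sketch:

    #include <stdint.h>
    #include <string.h>

    // Mirrors how the hunks below read dst->op_params (slot indices are per-op conventions).
    static void read_op_params(const int32_t * op_params, int * n_past, float * freq_base) {
        *n_past = op_params[0];                          // integer: direct indexing
        memcpy(freq_base, op_params + 4, sizeof(float)); // float: memcpy avoids aliasing UB
    }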
@@ -637,7 +762,7 @@ void ggml_metal_graph_compute(
                     }
                 } else {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                     }
 
                     int nth0 = 32;
@@ -676,8 +801,8 @@ void ggml_metal_graph_compute(
                             GGML_ASSERT(ne02 == 1);
                             GGML_ASSERT(ne12 == 1);
 
-                            nth0 = 4;
-                            nth1 = 16;
+                            nth0 = 2;
+                            nth1 = 32;
                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
                         } break;
                     case GGML_TYPE_Q3_K:
@@ -685,8 +810,8 @@ void ggml_metal_graph_compute(
                             GGML_ASSERT(ne02 == 1);
                             GGML_ASSERT(ne12 == 1);
 
-                            nth0 = 4;
-                            nth1 = 16;
+                            nth0 = 2;
+                            nth1 = 32;
                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
                         } break;
                     case GGML_TYPE_Q4_K:
@@ -694,8 +819,8 @@ void ggml_metal_graph_compute(
                             GGML_ASSERT(ne02 == 1);
                             GGML_ASSERT(ne12 == 1);
 
-                            nth0 = 4;
-                            nth1 = 16;
+                            nth0 = 2;
+                            nth1 = 32;
                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
                         } break;
                     case GGML_TYPE_Q5_K:
@@ -703,8 +828,8 @@ void ggml_metal_graph_compute(
                             GGML_ASSERT(ne02 == 1);
                             GGML_ASSERT(ne12 == 1);
 
-                            nth0 = 4;
-                            nth1 = 16;
+                            nth0 = 2;
+                            nth1 = 32;
                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
                         } break;
                     case GGML_TYPE_Q6_K:
@@ -712,8 +837,8 @@ void ggml_metal_graph_compute(
                             GGML_ASSERT(ne02 == 1);
                             GGML_ASSERT(ne12 == 1);
 
-                            nth0 = 4;
-                            nth1 = 16;
+                            nth0 = 2;
+                            nth1 = 32;
                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
                         } break;
                     default:
@@ -739,20 +864,22 @@ void ggml_metal_graph_compute(
                     [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
                     [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
 
-                    if (src0t == GGML_TYPE_Q4_0) {
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01 / 8+((ne01 % 8) & 0x01), ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+                        src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1) / 2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
                     }
-                    else if (src0t == GGML_TYPE_Q4_1) {
-                        [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    else if (src0t == GGML_TYPE_Q5_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
-                    else if (src0t == GGML_TYPE_Q2_K ||
-                             src0t == GGML_TYPE_Q3_K ||
-                             src0t == GGML_TYPE_Q4_K ||
-                             src0t == GGML_TYPE_Q5_K ||
-                             src0t == GGML_TYPE_Q6_K) {
-                        [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    else if (src0t == GGML_TYPE_Q6_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1) / 2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     } else {
                         [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
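The new grid sizes are plain ceiling divisions: each threadgroup now covers a fixed number of output rows (8, 4, or 2 depending on the quantization type), so the group count is ceil(ne01 / rows_per_group). Note that the old Q4_0 expression ne01 / 8 + ((ne01 % 8) & 0x01) only added an extra group when the remainder was odd, so row counts with remainder 2, 4, or 6 were left uncovered. A standalone check (our harness, not from the diff):

    #include <stdio.h>

    int main(void) {
        for (int ne01 = 1; ne01 <= 24; ++ne01) {
            const int old_groups = ne01 / 8 + ((ne01 % 8) & 0x01); // misses even remainders 2, 4, 6
            const int new_groups = (ne01 + 7) / 8;                 // proper ceiling division
            if (old_groups != new_groups) {
                printf("ne01 = %2d: old = %d, new = %d\n", ne01, old_groups, new_groups);
            }
        }
        return 0;
    }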
@@ -762,7 +889,7 @@ void ggml_metal_graph_compute(
             case GGML_OP_GET_ROWS:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                     }
 
                     switch (src0->type) {
@@ -791,12 +918,13 @@ void ggml_metal_graph_compute(
             case GGML_OP_RMS_NORM:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                     }
 
-                    const float eps = 1e-6f;
+                    float eps;
+                    memcpy(&eps, dst->op_params, sizeof(float));
 
-                    const int nth = 256;
+                    const int nth = 512;
 
                     [encoder setComputePipelineState:ctx->pipeline_rms_norm];
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -804,7 +932,7 @@ void ggml_metal_graph_compute(
                     [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                     [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                     [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
-                    [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+                    [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
 
                     const int64_t nrows = ggml_nrows(src0);
 
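For reference, RMS norm scales each row by the reciprocal root of its mean square, y = x / sqrt(mean(x^2) + eps), and eps is now supplied through op_params instead of a hard-coded 1e-6f. The scratch space shrinking from nth to nth/32 floats is consistent with a kernel that reduces within 32-lane SIMD groups first and keeps one partial sum per simdgroup (our reading of the change; the Metal kernel itself is not part of this diff). A scalar C sketch of the operation:

    #include <math.h>
    #include <stdint.h>

    // Per-row RMS normalization: y[i] = x[i] / sqrt(mean(x^2) + eps).
    static void rms_norm_row_ref(float * y, const float * x, int64_t ne00, float eps) {
        float sum = 0.0f;
        for (int64_t i = 0; i < ne00; ++i) {
            sum += x[i] * x[i];
        }
        const float scale = 1.0f / sqrtf(sum / ne00 + eps);
        for (int64_t i = 0; i < ne00; ++i) {
            y[i] = x[i] * scale;
        }
    }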
@@ -813,7 +941,7 @@ void ggml_metal_graph_compute(
             case GGML_OP_NORM:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                     }
 
                     const float eps = 1e-5f;
@@ -835,14 +963,15 @@ void ggml_metal_graph_compute(
             case GGML_OP_ALIBI:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                     }
 
                     GGML_ASSERT((src0t == GGML_TYPE_F32));
 
-                    const int   n_past   = ((int32_t *) src1->data)[0]; UNUSED(n_past);
-                    const int   n_head   = ((int32_t *) src1->data)[1];
-                    const float max_bias = ((float *) src1->data)[2];
+                    const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+                    const int n_head = ((int32_t *) dst->op_params)[1];
+                    float max_bias;
+                    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
                     if (__builtin_popcount(n_head) != 1) {
                         GGML_ASSERT(false && "only power-of-two n_head implemented");
@@ -877,43 +1006,51 @@ void ggml_metal_graph_compute(
             case GGML_OP_ROPE:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                     }
 
-                    const int n_dims = ((int32_t *) src1->data)[1];
-                    const int mode   = ((int32_t *) src1->data)[2];
+                    const int n_past = ((int32_t *) dst->op_params)[0];
+                    const int n_dims = ((int32_t *) dst->op_params)[1];
+                    const int mode   = ((int32_t *) dst->op_params)[2];
 
-                    const int n_past = ((int32_t *)(src1->data))[0];
+                    float freq_base;
+                    float freq_scale;
+                    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
+                    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
                     [encoder setComputePipelineState:ctx->pipeline_rope];
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-                    [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-                    [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
-                    [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
-                    [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
-                    [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
-                    [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
-                    [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
-                    [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
-                    [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
-                    [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
-                    [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
-                    [encoder setBytes:&mode   length:sizeof( int) atIndex:20];
+                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                    [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                    [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
+                    [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
+                    [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
+                    [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
+                    [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
+                    [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
+                    [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
+                    [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
+                    [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
+                    [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
+                    [encoder setBytes:&mode   length:sizeof( int) atIndex:20];
+                    [encoder setBytes:&freq_base  length:sizeof(float) atIndex:21];
+                    [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
 
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+            case GGML_OP_DUP:
             case GGML_OP_CPY:
+            case GGML_OP_CONT:
                 {
                     if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                     }
 
                     const int nth = 32;
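freq_base and freq_scale generalize the rotary-embedding frequencies that were previously fixed (base 10000.0, scale 1.0). Following ggml's CPU implementation, the rotation angle for position pos and even dimension index i0 is theta = freq_scale * pos * freq_base^(-i0/n_dims); a hedged scalar sketch of rotating one (even, odd) pair:

    #include <math.h>

    // Rotate one (x0, x1) dimension pair; i0 is the even dimension index (0, 2, 4, ...).
    // Angle convention as in ggml's CPU RoPE: theta = freq_scale * pos * freq_base^(-i0/n_dims).
    static void rope_pair_ref(float * x0, float * x1, int pos, int i0, int n_dims,
                              float freq_base, float freq_scale) {
        const float theta = freq_scale * pos * powf(freq_base, -(float) i0 / n_dims);
        const float c = cosf(theta);
        const float s = sinf(theta);
        const float v0 = *x0;
        const float v1 = *x1;
        *x0 = v0*c - v1*s;
        *x1 = v0*s + v1*c;
    }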
@@ -960,8 +1097,10 @@ void ggml_metal_graph_compute(
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
             default:
-                fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
-                GGML_ASSERT(false);
+                {
+                    fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                    GGML_ASSERT(false);
+                }
             }
         }