llama_cpp 0.3.3 → 0.3.5

This diff shows the changes between publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
@@ -61,6 +61,13 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
 // get data from the device into host memory
 void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
 
+// try to find operations that can be run concurrently in the graph
+// you should run it again if the topology of your graph changes
+void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+
+// whether the graph has been optimized for concurrent dispatch
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+
 // same as ggml_graph_compute but uses Metal
 // creates gf->n_threads command buffers in parallel
 void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
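
The hunk above extends the Metal backend's public API (by the look of the declarations, ggml-metal.h; the remaining hunks patch the implementation). The new analysis pass is opt-in and sits between graph construction and compute. A minimal calling sketch, assuming a context and graph built with the existing ggml APIs (the setup shown is illustrative, not part of this diff):

    struct ggml_metal_context * ctx = ggml_metal_init(1 /* n_cb */);
    struct ggml_cgraph        * gf  = /* a graph built with the usual ggml API */;

    // analyze once up front; re-run only if the graph topology changes
    ggml_metal_graph_find_concurrency(ctx, gf);

    if (ggml_metal_if_optimized(ctx)) {
        // a concurrency plan exists: ggml_metal_graph_compute will dispatch
        // independent nodes concurrently, separated by memory barriers
    }

    ggml_metal_graph_compute(ctx, gf);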
@@ -36,12 +36,16 @@ struct ggml_metal_context {
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
+    int concur_list[GGML_MAX_NODES];
+    int concur_list_len;
+
     // custom kernels
 #define GGML_METAL_DECL_KERNEL(name) \
     id<MTLFunction> function_##name; \
     id<MTLComputePipelineState> pipeline_##name
 
     GGML_METAL_DECL_KERNEL(add);
+    GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(scale);
@@ -97,6 +101,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     ctx->device = MTLCreateSystemDefaultDevice();
     ctx->queue = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
+    ctx->concur_list_len = 0;
 
     // determine if we can use MPS
     if (MPSSupportsMTLDevice(ctx->device)) {
@@ -157,6 +162,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
 
         GGML_METAL_ADD_KERNEL(add);
+        GGML_METAL_ADD_KERNEL(add_row);
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul_row);
         GGML_METAL_ADD_KERNEL(scale);
@@ -215,6 +221,13 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
     ctx->n_cb = n_cb;
 }
 
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+    if (ctx->concur_list_len) {
+        return true;
+    }
+    return false;
+}
+
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is a 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
@@ -353,11 +366,98 @@ void ggml_metal_get_tensor(
     memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
 }
 
+void ggml_metal_graph_find_concurrency(
+        struct ggml_metal_context * ctx,
+        struct ggml_cgraph * gf) {
+    int search_depth = gf->n_nodes; // we only look for concurrency within this range, to avoid wasting too much time
+    int nodes_unused[GGML_MAX_NODES];
+
+    for (int i = 0; i < GGML_MAX_NODES; i++) { ctx->concur_list[i] = 0; }
+    for (int i = 0; i < gf->n_nodes;    i++) { nodes_unused[i]     = 1; }
+    ctx->concur_list_len = 0;
+
+    int n_left    = gf->n_nodes;
+    int n_start   = 0; // all nodes before n_start in the nodes_unused array have been sorted and stored back into ctx->concur_list
+    int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
+
+    while (n_left > 0) {
+        // number of nodes at a layer (that can be issued concurrently)
+        int concurrency = 0;
+        for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
+            if (nodes_unused[i]) {
+                // check whether the requirements for gf->nodes[i] are satisfied
+                int exe_flag = 1;
+                // scan all srcs
+                for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
+                    struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
+                    if (src_cur) {
+                        // if the src is a leaf node, it is satisfied
+                        if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) { continue; }
+
+                        // otherwise this src should be the output of a previously scheduled node
+                        int is_found = 0;
+                        // scan 2*search_depth back, because barriers have been inserted
+                        for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+                            if (gf->nodes[ctx->concur_list[j]] == src_cur) { is_found = 1; break; }
+                        }
+                        if (is_found == 0) { exe_flag = 0; break; }
+                    }
+                }
+                if (exe_flag) {
+                    // check whether nodes[i]'s data will be overwritten by a node before nodes[i]:
+                    // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
+                    int64_t data_start = (int64_t) gf->nodes[i]->data;
+                    int64_t length     = (int64_t) ggml_nbytes(gf->nodes[i]);
+                    for (int j = n_start; j < i; j++) {
+                        if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE
+                                            && gf->nodes[j]->op != GGML_OP_VIEW
+                                            && gf->nodes[j]->op != GGML_OP_TRANSPOSE
+                                            && gf->nodes[j]->op != GGML_OP_PERMUTE) {
+                            if (((int64_t)gf->nodes[j]->data) >= data_start + length ||
+                                ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
+                                continue;
+                            } else {
+                                exe_flag = 0;
+                            }
+                        }
+                    }
+                }
+                if (exe_flag) {
+                    ctx->concur_list[level_pos + concurrency] = i;
+                    nodes_unused[i] = 0;
+                    concurrency++;
+                    ctx->concur_list_len++;
+                }
+            }
+        }
+        n_left -= concurrency;
+        // add a barrier to separate this layer from the next
+        ctx->concur_list[level_pos + concurrency] = -1;
+        ctx->concur_list_len++;
+        // skip over the nodes that have already been scheduled
+        while (!nodes_unused[n_start]) { n_start++; }
+        level_pos += concurrency + 1;
+    }
+
+    if (ctx->concur_list_len > GGML_MAX_NODES) {
+        fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+    }
+}
+
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
         struct ggml_cgraph * gf) {
     metal_printf("%s: evaluating graph\n", __func__);
 
+    // if there is a ctx->concur_list, dispatch concurrently;
+    // otherwise fall back to serial dispatch
+    MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
+
+    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+
+    const int n_nodes  = has_concur ? ctx->concur_list_len : gf->n_nodes;
+    edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+
     // create multiple command buffers and enqueue them
     // then, we encode the graph into the command buffers in parallel
 
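The schedule produced by ggml_metal_graph_find_concurrency is a flat array of node indices in which each level (a run of nodes with no mutual dependencies) is terminated by a -1 sentinel. A small illustrative plan (hypothetical indices, not taken from a real graph):

    // hypothetical plan for a 6-node graph:
    //   concur_list     = { 0, 1, 2, -1, 3, -1, 4, 5, -1 }
    //   concur_list_len = 9
    // nodes 0..2 depend only on leaves and form level 0; node 3 consumes their
    // outputs and runs alone in level 1; nodes 4 and 5 form level 2. at encode
    // time, each -1 becomes a memory barrier (see the dispatch hunks below).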
@@ -376,7 +476,7 @@ void ggml_metal_graph_compute(
     dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
 
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+        const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
         dispatch_async(queue, ^{
             size_t offs_src0 = 0;
@@ -387,10 +487,21 @@ void ggml_metal_graph_compute(
 
             id<MTLComputeCommandEncoder> encoder = nil;
 
-            const int node_start = (cb_idx + 0) * n_nodes_per_cb;
-            const int node_end   = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+            const int node_start = (cb_idx + 0) * n_nodes_per_cb;
+            const int node_end   = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+            for (int ind = node_start; ind < node_end; ++ind) {
+                const int i = has_concur ? ctx->concur_list[ind] : ind;
+
+                if (i == -1) {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                        continue;
+                    }
+                    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+                    continue;
+                }
 
-            for (int i = node_start; i < node_end; ++i) {
                 metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
                 struct ggml_tensor * src0 = gf->nodes[i]->src[0];
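
Two details above are easy to miss. With a serial pass descriptor, computeCommandEncoderWithDescriptor: behaves like the plain computeCommandEncoder call it replaces throughout this diff; the concurrent dispatch type is what lets dispatches within a single encoder overlap, and it is also why each -1 sentinel must be lowered to an explicit barrier. A minimal standalone sketch of the encoder setup (illustrative, not part of the diff):

    MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
    edesc.dispatchType = MTLDispatchTypeConcurrent; // or MTLDispatchTypeSerial

    id<MTLComputeCommandEncoder> encoder =
        [command_buffer computeCommandEncoderWithDescriptor: edesc];

    // ... encode one level of independent dispatches ...

    // order all buffer writes of this level before the next level starts
    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];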
@@ -461,13 +572,19 @@ void ggml_metal_graph_compute(
                 case GGML_OP_ADD:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
-                        [encoder setComputePipelineState:ctx->pipeline_add];
+                        if (ggml_nelements(src1) == ne10) {
+                            // src1 is a row
+                            [encoder setComputePipelineState:ctx->pipeline_add_row];
+                        } else {
+                            [encoder setComputePipelineState:ctx->pipeline_add];
+                        }
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                         [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
 
                         const int64_t n = ggml_nelements(dst);
 
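The new row variant covers the common broadcast case in which src1 is a single row added to every row of src0 (detected by ggml_nelements(src1) == ne10); the extra ne00 argument gives the kernel the row length. A CPU model of what the add_row path computes, for illustration only (the real work happens in the Metal shader):

    // broadcast add: every element of src0 gets the matching element of the
    // single src1 row, wrapping at the row length
    for (int64_t i = 0; i < n; i++) {
        dst[i] = src0[i] + src1[i % ne00];
    }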
@@ -476,7 +593,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_MUL:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         if (ggml_nelements(src1) == ne10) {
@@ -497,7 +614,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_SCALE:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         const float scale = *(const float *) src1->data;
@@ -511,52 +628,60 @@ void ggml_metal_graph_compute(
 
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
-                case GGML_OP_SILU:
-                    {
-                        if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
-                        }
-
-                        [encoder setComputePipelineState:ctx->pipeline_silu];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
-                        const int64_t n = ggml_nelements(dst);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                    } break;
-                case GGML_OP_RELU:
-                    {
-                        if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
-                        }
-
-                        [encoder setComputePipelineState:ctx->pipeline_relu];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
-                        const int64_t n = ggml_nelements(dst);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                case GGML_OP_UNARY:
+                    switch (ggml_get_unary_op(gf->nodes[i])) {
+                        case GGML_UNARY_OP_SILU:
+                            {
+                                if (encoder == nil) {
+                                    encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                                }
+
+                                [encoder setComputePipelineState:ctx->pipeline_silu];
+                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                                const int64_t n = ggml_nelements(dst);
+
+                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            } break;
+                        case GGML_UNARY_OP_RELU:
+                            {
+                                if (encoder == nil) {
+                                    encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                                }
+
+                                [encoder setComputePipelineState:ctx->pipeline_relu];
+                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                                const int64_t n = ggml_nelements(dst);
+
+                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            } break;
+                        case GGML_UNARY_OP_GELU:
+                            {
+                                if (encoder == nil) {
+                                    encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                                }
+
+                                [encoder setComputePipelineState:ctx->pipeline_gelu];
+                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                                const int64_t n = ggml_nelements(dst);
+
+                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            } break;
+                        default:
+                            {
+                                fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                                GGML_ASSERT(false);
+                            }
                     } break;
-                case GGML_OP_GELU:
-                    {
-                        if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
-                        }
-
-                        [encoder setComputePipelineState:ctx->pipeline_gelu];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
-                        const int64_t n = ggml_nelements(dst);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                    } break;
                 case GGML_OP_SOFT_MAX:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         const int nth = 32;
@@ -574,10 +699,10 @@ void ggml_metal_graph_compute(
                 case GGML_OP_DIAG_MASK_INF:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
-                        const int n_past = ((int32_t *)(src1->data))[0];
+                        const int n_past = ((int32_t *)(dst->op_params))[0];
 
                         [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
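
This is the first of several hunks (RMS_NORM, ALIBI, and ROPE below follow the same pattern) that read operator parameters from the destination tensor's embedded op_params array instead of from a separate src1 tensor, so the parameters no longer need to live in a GPU-mapped buffer. op_params is an int32 array; float parameters are stored bit-for-bit, so they are recovered with memcpy rather than a pointer cast:

    // integer parameters can be indexed directly
    const int n_past = ((const int32_t *) dst->op_params)[0];

    // float parameters are copied out byte-wise, as in the RMS_NORM hunk below
    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));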
@@ -637,7 +762,7 @@ void ggml_metal_graph_compute(
                         }
                     } else {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         int nth0 = 32;
@@ -676,8 +801,8 @@ void ggml_metal_graph_compute(
                                    GGML_ASSERT(ne02 == 1);
                                    GGML_ASSERT(ne12 == 1);
 
-                                    nth0 = 4;
-                                    nth1 = 16;
+                                    nth0 = 2;
+                                    nth1 = 32;
                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
                                } break;
                            case GGML_TYPE_Q3_K:
@@ -685,8 +810,8 @@ void ggml_metal_graph_compute(
                                    GGML_ASSERT(ne02 == 1);
                                    GGML_ASSERT(ne12 == 1);
 
-                                    nth0 = 4;
-                                    nth1 = 16;
+                                    nth0 = 2;
+                                    nth1 = 32;
                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
                                } break;
                            case GGML_TYPE_Q4_K:
@@ -694,8 +819,8 @@ void ggml_metal_graph_compute(
                                    GGML_ASSERT(ne02 == 1);
                                    GGML_ASSERT(ne12 == 1);
 
-                                    nth0 = 4;
-                                    nth1 = 16;
+                                    nth0 = 2;
+                                    nth1 = 32;
                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
                                } break;
                            case GGML_TYPE_Q5_K:
@@ -703,8 +828,8 @@ void ggml_metal_graph_compute(
                                    GGML_ASSERT(ne02 == 1);
                                    GGML_ASSERT(ne12 == 1);
 
-                                    nth0 = 4;
-                                    nth1 = 16;
+                                    nth0 = 2;
+                                    nth1 = 32;
                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
                                } break;
                            case GGML_TYPE_Q6_K:
@@ -712,8 +837,8 @@ void ggml_metal_graph_compute(
                                    GGML_ASSERT(ne02 == 1);
                                    GGML_ASSERT(ne12 == 1);
 
-                                    nth0 = 4;
-                                    nth1 = 16;
+                                    nth0 = 2;
+                                    nth1 = 32;
                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
                                } break;
                            default:
@@ -739,20 +864,22 @@ void ggml_metal_graph_compute(
                        [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
                        [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
 
-                        if (src0t == GGML_TYPE_Q4_0) {
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne01 / 8 + ((ne01 % 8) & 0x01), ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+                            src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        }
+                        else if (src0t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1) / 2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
                        }
-                        else if (src0t == GGML_TYPE_Q4_1) {
-                            [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        else if (src0t == GGML_TYPE_Q5_K) {
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                        }
-                        else if (src0t == GGML_TYPE_Q2_K ||
-                                 src0t == GGML_TYPE_Q3_K ||
-                                 src0t == GGML_TYPE_Q4_K ||
-                                 src0t == GGML_TYPE_Q5_K ||
-                                 src0t == GGML_TYPE_Q6_K) {
-                            [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        else if (src0t == GGML_TYPE_Q6_K) {
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1) / 2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                        } else {
                            [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
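
Together with the nth0/nth1 changes above, the k-quant matrix-vector kernels now run 2x32 threadgroups that cover several rows each, and the per-threadgroup float scratch buffer is gone. A worked example with hypothetical shapes: a Q4_K matrix with ne01 = 4096 rows and ne11 = 1 dispatches MTLSizeMake((4096 + 7) / 8, 1, 1) = 512 threadgroups, each running nth0 * nth1 = 2 * 32 = 64 threads across 8 rows.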
@@ -762,7 +889,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_GET_ROWS:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         switch (src0->type) {
@@ -791,12 +918,13 @@ void ggml_metal_graph_compute(
                 case GGML_OP_RMS_NORM:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
-                        const float eps = 1e-6f;
+                        float eps;
+                        memcpy(&eps, dst->op_params, sizeof(float));
 
-                        const int nth = 256;
+                        const int nth = 512;
 
                         [encoder setComputePipelineState:ctx->pipeline_rms_norm];
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -804,7 +932,7 @@ void ggml_metal_graph_compute(
                         [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                         [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                         [encoder setBytes:&eps length:sizeof( float) atIndex:4];
-                        [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+                        [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
 
                         const int64_t nrows = ggml_nrows(src0);
 
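Note the arithmetic in the RMS_NORM hunks: the threadgroup doubles to nth = 512 threads while its shared memory shrinks from nth * sizeof(float) = 2048 bytes to nth / 32 * sizeof(float) = 64 bytes, i.e. one float per 32-wide SIMD group, which suggests the kernel now reduces within SIMD groups first and only passes the per-simdgroup partial sums through threadgroup memory.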
@@ -813,7 +941,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_NORM:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         const float eps = 1e-5f;
@@ -835,14 +963,15 @@ void ggml_metal_graph_compute(
                 case GGML_OP_ALIBI:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         GGML_ASSERT((src0t == GGML_TYPE_F32));
 
-                        const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past);
-                        const int n_head = ((int32_t *) src1->data)[1];
-                        const float max_bias = ((float *) src1->data)[2];
+                        const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+                        const int n_head = ((int32_t *) dst->op_params)[1];
+                        float max_bias;
+                        memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
                         if (__builtin_popcount(n_head) != 1) {
                             GGML_ASSERT(false && "only power-of-two n_head implemented");
@@ -877,43 +1006,51 @@ void ggml_metal_graph_compute(
                 case GGML_OP_ROPE:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
-                        const int n_dims = ((int32_t *) src1->data)[1];
-                        const int mode   = ((int32_t *) src1->data)[2];
+                        const int n_past = ((int32_t *) dst->op_params)[0];
+                        const int n_dims = ((int32_t *) dst->op_params)[1];
+                        const int mode   = ((int32_t *) dst->op_params)[2];
 
-                        const int n_past = ((int32_t *)(src1->data))[0];
+                        float freq_base;
+                        float freq_scale;
+                        memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
+                        memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
                         [encoder setComputePipelineState:ctx->pipeline_rope];
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                        [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                        [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-                        [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-                        [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-                        [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-                        [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-                        [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-                        [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-                        [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
-                        [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
-                        [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
-                        [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
-                        [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
-                        [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
-                        [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
-                        [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
-                        [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
-                        [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
-                        [encoder setBytes:&mode length:sizeof( int) atIndex:20];
+                        [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                        [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                        [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                        [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                        [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                        [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                        [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                        [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                        [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+                        [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+                        [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+                        [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+                        [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+                        [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+                        [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+                        [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+                        [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
+                        [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
+                        [encoder setBytes:&mode length:sizeof( int) atIndex:20];
+                        [encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
+                        [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
 
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
+                case GGML_OP_DUP:
                 case GGML_OP_CPY:
+                case GGML_OP_CONT:
                     {
                         if (encoder == nil) {
-                            encoder = [command_buffer computeCommandEncoder];
+                            encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                         }
 
                         const int nth = 32;
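
The two new ROPE floats in the hunk above parameterize the rotary embedding for context-length extension. Assuming the reference RoPE formulation, the rotation angle for position p and dimension pair i becomes

    theta_i(p) = freq_scale * p * freq_base^(-2*i/n_dims)

so the conventional defaults (freq_base = 10000, freq_scale = 1) reproduce the original behavior, a larger freq_base stretches the wavelengths, and freq_scale < 1 compresses positions.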
@@ -960,8 +1097,10 @@ void ggml_metal_graph_compute(
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
                 default:
-                    fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
-                    GGML_ASSERT(false);
+                    {
+                        fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                        GGML_ASSERT(false);
+                    }
                 }
             }