llama_cpp 0.3.4 → 0.3.6

@@ -27,6 +27,7 @@ void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
  void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
  void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
  void ggml_cuda_set_main_device(int main_device);
+ void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
  void ggml_cuda_set_scratch_size(size_t scratch_size);
  void ggml_cuda_free_scratch(void);
  bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
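The CUDA header gains a single new entry point, `ggml_cuda_set_mul_mat_q`. A minimal, hedged usage sketch follows; the assumption (taken from upstream llama.cpp of this period, not from this hunk) is that the flag selects the dedicated quantized mat-mul kernels over the dequantize + cuBLAS path, and `configure_cuda_backend` is an illustrative helper, not part of the diff:

```c
// Sketch only: assumes the flag toggles the custom quantized mat-mul kernels
// (true) vs. the dequantize + cuBLAS path (false), as in upstream llama.cpp.
#include <stdbool.h>

void ggml_cuda_set_mul_mat_q(bool mul_mat_q); // declared in the hunk above

static void configure_cuda_backend(bool use_mul_mat_q) {
    ggml_cuda_set_mul_mat_q(use_mul_mat_q); // hypothetical call site
}
```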
@@ -61,6 +61,13 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
  // get data from the device into host memory
  void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);

+ // try to find operations that can be run concurrently in the graph
+ // you should run it again if the topology of your graph changes
+ void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+
+ // if the graph has been optimized for concurrently dispatch
+ bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+
  // same as ggml_graph_compute but uses Metal
  // creates gf->n_threads command buffers in parallel
  void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
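The Metal header gains two entry points. The call order suggested by the comments above is: build the graph, run `ggml_metal_graph_find_concurrency` once (and again whenever the graph topology changes), then call `ggml_metal_graph_compute` as before; `ggml_metal_if_optimized` only reports whether a concurrency plan was produced. A hedged sketch (the caller function is illustrative, and buffer mapping via `ggml_metal_add_buffer` is omitted for brevity):

```c
// Usage sketch only: run_graph_on_metal() is a hypothetical caller.
static void run_graph_on_metal(struct ggml_metal_context * mctx, struct ggml_cgraph * gf) {
    // analyze the graph once; re-run only when the graph topology changes
    ggml_metal_graph_find_concurrency(mctx, gf);

    if (ggml_metal_if_optimized(mctx)) {
        // a concurrency plan exists: ggml_metal_graph_compute will dispatch
        // kernels concurrently, separated by memory barriers (see below)
    }

    ggml_metal_graph_compute(mctx, gf);
}
```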
@@ -36,12 +36,16 @@ struct ggml_metal_context {
  int n_buffers;
  struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];

+ int concur_list[GGML_MAX_NODES];
+ int concur_list_len;
+
  // custom kernels
  #define GGML_METAL_DECL_KERNEL(name) \
  id<MTLFunction> function_##name; \
  id<MTLComputePipelineState> pipeline_##name

  GGML_METAL_DECL_KERNEL(add);
+ GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
  GGML_METAL_DECL_KERNEL(mul);
  GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
  GGML_METAL_DECL_KERNEL(scale);
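The Metal context now carries the concurrency plan itself: `concur_list` holds graph-node indices grouped into levels that can be dispatched together, with `-1` acting as a level separator, and `concur_list_len` counts all entries including the separators (the plan is built by `ggml_metal_graph_find_concurrency` below). An illustrative, made-up plan for a small graph:

```c
// Illustrative contents only; the real values depend on the graph.
// Nodes 0 and 1 can run concurrently, then a barrier, then node 2,
// then nodes 3-5 together. The -1 separators count toward concur_list_len.
int concur_list_example[]   = { 0, 1, -1, 2, -1, 3, 4, 5, -1 };
int concur_list_len_example = 9;
```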
@@ -97,6 +101,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->device = MTLCreateSystemDefaultDevice();
  ctx->queue = [ctx->device newCommandQueue];
  ctx->n_buffers = 0;
+ ctx->concur_list_len = 0;

  // determine if we can use MPS
  if (MPSSupportsMTLDevice(ctx->device)) {
@@ -157,6 +162,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);

  GGML_METAL_ADD_KERNEL(add);
+ GGML_METAL_ADD_KERNEL(add_row);
  GGML_METAL_ADD_KERNEL(mul);
  GGML_METAL_ADD_KERNEL(mul_row);
  GGML_METAL_ADD_KERNEL(scale);
@@ -215,6 +221,13 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
  ctx->n_cb = n_cb;
  }

+ bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+ if (ctx->concur_list_len) {
+ return true;
+ }
+ return false;
+ }
+
  // finds the Metal buffer that contains the tensor data on the GPU device
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
  // Metal buffer based on the host memory pointer
@@ -353,11 +366,98 @@ void ggml_metal_get_tensor(
  memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
  }

+ void ggml_metal_graph_find_concurrency(
+ struct ggml_metal_context * ctx,
+ struct ggml_cgraph * gf) {
+ int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
+ int nodes_unused[GGML_MAX_NODES];
+
+ for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
+ for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
+ ctx->concur_list_len = 0;
+
+ int n_left = gf->n_nodes;
+ int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
+ int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
+
+ while (n_left > 0) {
+ // number of nodes at a layer (that can be issued concurrently)
+ int concurrency = 0;
+ for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
+ if (nodes_unused[i]) {
+ // if the requirements for gf->nodes[i] are satisfied
+ int exe_flag=1;
+ // scan all srcs
+ for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
+ struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
+ if (src_cur) {
+ // if is leaf nodes it's satisfied.
+ if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
+
+ // otherwise this src should be the output from previous nodes.
+ int is_found = 0;
+ // scan 2*search_depth back because we inserted barrier.
+ for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+ if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
+ }
+ if (is_found == 0) {exe_flag = 0; break;}
+ }
+ }
+ if (exe_flag) {
+ // check if nodes[i]'s data will be overwritten by a node before nodes[i].
+ // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
+ int64_t data_start = (int64_t) gf->nodes[i]->data;
+ int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
+ for (int j = n_start; j < i; j++) {
+ if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
+ && gf->nodes[j]->op != GGML_OP_VIEW \
+ && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
+ && gf->nodes[j]->op != GGML_OP_PERMUTE) {
+ if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
+ ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
+ continue;
+ } else {
+ exe_flag = 0;
+ }
+ }
+ }
+ }
+ if (exe_flag) {
+ ctx->concur_list[level_pos + concurrency] = i;
+ nodes_unused[i] = 0;
+ concurrency++;
+ ctx->concur_list_len++;
+ }
+ }
+ }
+ n_left -= concurrency;
+ // adding a barrier different layer
+ ctx->concur_list[level_pos + concurrency] = -1;
+ ctx->concur_list_len++;
+ // jump all sorted nodes at nodes_bak
+ while (!nodes_unused[n_start]) {n_start++;}
+ level_pos += concurrency + 1;
+ }
+
+ if (ctx->concur_list_len > GGML_MAX_NODES) {
+ fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+ }
+ }
+
  void ggml_metal_graph_compute(
  struct ggml_metal_context * ctx,
  struct ggml_cgraph * gf) {
  metal_printf("%s: evaluating graph\n", __func__);

+ // if there is ctx->concur_list, dispatch concurrently
+ // else fallback to serial dispatch
+ MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
+
+ const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+
+ const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
+ edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+
  // create multiple command buffers and enqueue them
  // then, we encode the graph into the command buffers in parallel
 
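`ggml_metal_graph_find_concurrency` performs a greedy leveling of the graph: walking the nodes in order (up to `search_depth` at a time), a node is appended to the current level when (a) every source it depends on is either a leaf or already present in `concur_list` from an earlier level, and (b) its output buffer does not overlap the output of any earlier, still-unscheduled node (earlier nodes that are only reshape/view/transpose/permute are ignored in that check since they do not write data). Each finished level is terminated with a `-1` entry. `ggml_metal_graph_compute` then walks `concur_list` instead of `gf->nodes`, creates its encoders with `MTLDispatchTypeConcurrent`, and turns every `-1` into a buffer memory barrier. A simplified, self-contained sketch of the leveling idea, using a toy node type rather than the real ggml structs:

```c
// Simplified sketch of the leveling pass described above. toy_node and
// find_levels() are illustrative only, not part of ggml.
#include <string.h>

#define MAX_TOY_NODES 16

typedef struct {
    int  src[2];      // indices of input nodes, -1 if unused (leaf inputs)
    long data_start;  // start of this node's output buffer
    long data_len;    // size of this node's output buffer
} toy_node;

static int buffers_overlap(const toy_node * a, const toy_node * b) {
    return !(a->data_start >= b->data_start + b->data_len ||
             a->data_start + a->data_len <= b->data_start);
}

// Fills `plan` (size >= 2*n) with node indices grouped into levels, each level
// terminated by -1, and returns the plan length. Assumes `nodes` is already in
// topological order, so every pass schedules at least one node.
static int find_levels(const toy_node * nodes, int n, int * plan) {
    if (n > MAX_TOY_NODES) { return 0; }

    int scheduled[MAX_TOY_NODES] = {0};
    int plan_len = 0;
    int n_left   = n;

    while (n_left > 0) {
        int done_before[MAX_TOY_NODES];              // finished in earlier levels
        memcpy(done_before, scheduled, sizeof(done_before));

        int level = 0;
        for (int i = 0; i < n; i++) {
            if (scheduled[i]) continue;

            int ok = 1;
            for (int s = 0; s < 2 && ok; s++) {      // (a) all inputs finished earlier
                const int src = nodes[i].src[s];
                if (src >= 0 && !done_before[src]) ok = 0;
            }
            for (int j = 0; j < i && ok; j++) {      // (b) no overlap with a pending earlier node
                if (!scheduled[j] && buffers_overlap(&nodes[i], &nodes[j])) ok = 0;
            }

            if (ok) {
                plan[plan_len++] = i;
                scheduled[i] = 1;
                level++;
            }
        }

        plan[plan_len++] = -1;                       // barrier between levels
        n_left -= level;
    }
    return plan_len;
}
```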
@@ -376,7 +476,7 @@ void ggml_metal_graph_compute(
  dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);

  for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
- const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+ const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;

  dispatch_async(queue, ^{
  size_t offs_src0 = 0;
@@ -387,10 +487,21 @@ void ggml_metal_graph_compute(

  id<MTLComputeCommandEncoder> encoder = nil;

- const int node_start = (cb_idx + 0) * n_nodes_per_cb;
- const int node_end = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
+ const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+ for (int ind = node_start; ind < node_end; ++ind) {
+ const int i = has_concur ? ctx->concur_list[ind] : ind;
+
+ if (i == -1) {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ continue;
+ }
+ [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+ continue;
+ }

- for (int i = node_start; i < node_end; ++i) {
  metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));

  struct ggml_tensor * src0 = gf->nodes[i]->src[0];
@@ -461,13 +572,19 @@ void ggml_metal_graph_compute(
  case GGML_OP_ADD:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

- [encoder setComputePipelineState:ctx->pipeline_add];
+ if (ggml_nelements(src1) == ne10) {
+ // src1 is a row
+ [encoder setComputePipelineState:ctx->pipeline_add_row];
+ } else {
+ [encoder setComputePipelineState:ctx->pipeline_add];
+ }
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];

  const int64_t n = ggml_nelements(dst);
 
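`GGML_OP_ADD` now mirrors what `GGML_OP_MUL` already did: when `src1` holds exactly one row (`ggml_nelements(src1) == ne10`), the new `add_row` kernel broadcasts that row across `src0`, and `ne00` is passed as a kernel argument so each thread can locate its column. A plain-C reference of the broadcast semantics (a sketch of the intended result, not the Metal kernel; the names are illustrative):

```c
// Reference semantics of the add_row path: src1 is a single row of ne00
// elements that is added to every row of src0.
static void add_row_reference(const float * src0, const float * src1_row,
                              float * dst, long n_rows, long ne00) {
    for (long r = 0; r < n_rows; r++) {
        for (long c = 0; c < ne00; c++) {
            dst[r*ne00 + c] = src0[r*ne00 + c] + src1_row[c];
        }
    }
}
```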
@@ -476,7 +593,7 @@ void ggml_metal_graph_compute(
  case GGML_OP_MUL:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

  if (ggml_nelements(src1) == ne10) {
@@ -497,7 +614,7 @@ void ggml_metal_graph_compute(
  case GGML_OP_SCALE:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

  const float scale = *(const float *) src1->data;
@@ -511,52 +628,60 @@ void ggml_metal_graph_compute(

  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
- case GGML_OP_SILU:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- [encoder setComputePipelineState:ctx->pipeline_silu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_RELU:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- [encoder setComputePipelineState:ctx->pipeline_relu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(gf->nodes[i])) {
+ case GGML_UNARY_OP_SILU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_silu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_RELU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_relu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_GELU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_gelu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ default:
+ {
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ASSERT(false);
+ }
  } break;
- case GGML_OP_GELU:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- [encoder setComputePipelineState:ctx->pipeline_gelu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
  case GGML_OP_SOFT_MAX:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

  const int nth = 32;
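The separate SILU/RELU/GELU cases are not dropped, only regrouped: upstream ggml folded the element-wise activations into a single `GGML_OP_UNARY` whose concrete activation is recovered with `ggml_get_unary_op()`, so the Metal dispatcher now switches on that value and keeps the same three pipelines. A hedged sketch of what the graph builder produces under this scheme (the helper below is hypothetical; the exact upstream builder behaviour is an assumption):

```c
// Sketch only: after the upstream refactor an activation node is built as a
// GGML_OP_UNARY node, and the concrete activation is queried at dispatch time.
#include "ggml.h"

static struct ggml_tensor * build_silu(struct ggml_context * ctx, struct ggml_tensor * x) {
    struct ggml_tensor * y = ggml_silu(ctx, x);          // y->op == GGML_OP_UNARY (assumed)
    GGML_ASSERT(ggml_get_unary_op(y) == GGML_UNARY_OP_SILU);
    return y;
}
```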
@@ -574,10 +699,10 @@ void ggml_metal_graph_compute(
  case GGML_OP_DIAG_MASK_INF:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

- const int n_past = ((int32_t *)(src1->data))[0];
+ const int n_past = ((int32_t *)(dst->op_params))[0];

  [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
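This hunk is the first of several below (RMS_NORM, ALIBI, ROPE) that follow the same upstream change: operator arguments that used to travel in an auxiliary `src1` tensor now live in the destination tensor's `op_params` array. Integer parameters are read by index; float parameters are copied out with `memcpy` to avoid type punning. Hedged helper sketches of the two read patterns (`op_param_i32`/`op_param_f32` are illustrative, not functions in the codebase):

```c
// Illustrative helpers only; ggml reads op_params inline as shown in the hunks.
#include <stdint.h>
#include <string.h>
#include "ggml.h"

static int32_t op_param_i32(const struct ggml_tensor * t, int i) {
    return ((const int32_t *) t->op_params)[i];
}

static float op_param_f32(const struct ggml_tensor * t, int i) {
    float v;
    memcpy(&v, (const int32_t *) t->op_params + i, sizeof(v)); // avoids aliasing issues
    return v;
}
```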
@@ -593,7 +718,8 @@ void ggml_metal_graph_compute(
  // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224

  GGML_ASSERT(ne00 == ne10);
- GGML_ASSERT(ne02 == ne12);
+ // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
+ GGML_ASSERT(ne03 == ne13);

  if (ggml_is_contiguous(src0) &&
  ggml_is_contiguous(src1) &&
@@ -621,11 +747,11 @@ void ggml_metal_graph_compute(
  initWithDevice:ctx->device transposeLeft:false transposeRight:true
  resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];

- // we need to do ne02 multiplications
+ // we need to do ne12 multiplications
  // TODO: is there a way to do this in parallel - currently very slow ..
  // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
- for (int64_t i02 = 0; i02 < ne02; ++i02) {
- size_t offs_src0_cur = offs_src0 + i02*nb02;
+ for (int64_t i02 = 0; i02 < ne12; ++i02) {
+ size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
  size_t offs_src1_cur = offs_src1 + i02*nb12;
  size_t offs_dst_cur = offs_dst + i02*nb2;
 
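The matrix-multiplication loop now runs over the `ne12` slices of `src1` rather than the `ne02` slices of `src0`, mapping each `src1` slice back to `src0` slice `i02 / (ne12/ne02)`. When `ne12 == ne02` this is the old behaviour; when `ne12` is a multiple of `ne02`, each `src0` slice is reused, which is the broadcast needed for grouped-query attention (the comment notes it is not exercised yet). A small worked example of the mapping, with illustrative values:

```c
// Worked example: with ne02 = 8 src0 slices and ne12 = 32 src1 slices,
// each src0 slice is reused ne12/ne02 = 4 times.
#include <stdio.h>

int main(void) {
    const long ne02 = 8;
    const long ne12 = 32;

    for (long i02 = 0; i02 < ne12; ++i02) {
        const long src0_slice = i02 / (ne12 / ne02); // 0,0,0,0,1,1,1,1,...,7,7,7,7
        printf("src1 slice %2ld -> src0 slice %ld\n", i02, src0_slice);
    }
    return 0;
}
```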
@@ -637,7 +763,7 @@ void ggml_metal_graph_compute(
  }
  } else {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

  int nth0 = 32;
@@ -647,8 +773,6 @@ void ggml_metal_graph_compute(
  switch (src0t) {
  case GGML_TYPE_F16:
  {
- GGML_ASSERT(ne02 == ne12);
-
  nth0 = 64;
  nth1 = 1;
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
@@ -728,16 +852,18 @@ void ggml_metal_graph_compute(
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];

  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
  src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
@@ -764,7 +890,7 @@ void ggml_metal_graph_compute(
  case GGML_OP_GET_ROWS:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

  switch (src0->type) {
@@ -793,10 +919,11 @@ void ggml_metal_graph_compute(
  case GGML_OP_RMS_NORM:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

- const float eps = 1e-6f;
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));

  const int nth = 512;

@@ -815,7 +942,7 @@ void ggml_metal_graph_compute(
  case GGML_OP_NORM:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

  const float eps = 1e-5f;
@@ -837,14 +964,15 @@ void ggml_metal_graph_compute(
  case GGML_OP_ALIBI:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

  GGML_ASSERT((src0t == GGML_TYPE_F32));

- const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past);
- const int n_head = ((int32_t *) src1->data)[1];
- const float max_bias = ((float *) src1->data)[2];
+ const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+ const int n_head = ((int32_t *) dst->op_params)[1];
+ float max_bias;
+ memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

  if (__builtin_popcount(n_head) != 1) {
  GGML_ASSERT(false && "only power-of-two n_head implemented");
@@ -879,18 +1007,17 @@ void ggml_metal_graph_compute(
  case GGML_OP_ROPE:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

- const int n_dims = ((int32_t *) src1->data)[1];
- const int mode = ((int32_t *) src1->data)[2];
-
- const int n_past = ((int32_t *)(src1->data))[0];
+ const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];

  float freq_base;
  float freq_scale;
- memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
- memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+ memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

  [encoder setComputePipelineState:ctx->pipeline_rope];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -919,10 +1046,12 @@ void ggml_metal_graph_compute(

  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
+ case GGML_OP_DUP:
  case GGML_OP_CPY:
+ case GGML_OP_CONT:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }

  const int nth = 32;
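`GGML_OP_DUP` and `GGML_OP_CONT` are now encoded through the same path as `GGML_OP_CPY`: all three materialize a copy of the source into the destination's layout, which is what the existing cpy pipelines already do. A plain-C reference of the shared semantics for the simplest contiguous f32 case (a sketch, not the Metal kernel, which also handles strided layouts and type conversion):

```c
// Reference semantics for the contiguous f32 -> f32 case shared by
// DUP / CPY / CONT; names are illustrative.
#include <string.h>

static void cpy_f32_f32_reference(const float * src, float * dst, long n) {
    memcpy(dst, src, (size_t) n * sizeof(float));
}
```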
@@ -969,8 +1098,10 @@ void ggml_metal_graph_compute(
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  default:
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
- GGML_ASSERT(false);
+ {
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ASSERT(false);
+ }
  }
  }