llama_cpp 0.3.4 → 0.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -27,6 +27,7 @@ void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
  void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
  void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
  void ggml_cuda_set_main_device(int main_device);
+ void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
  void ggml_cuda_set_scratch_size(size_t scratch_size);
  void ggml_cuda_free_scratch(void);
  bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
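
Note: the ggml_cuda_set_mul_mat_q declaration added above exposes a runtime toggle for the backend's mul_mat_q kernels. A minimal, hypothetical usage sketch; the helper name configure_cuda_backend and the surrounding setup are assumptions, not part of this diff:

    #include <stdbool.h>
    #include "ggml-cuda.h"

    /* hypothetical helper: pick a device and opt in or out of the mul_mat_q kernels */
    void configure_cuda_backend(bool use_mul_mat_q) {
        ggml_cuda_set_main_device(0);            /* primary GPU */
        ggml_cuda_set_mul_mat_q(use_mul_mat_q);  /* new setter introduced in this release */
    }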
@@ -61,6 +61,13 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
  // get data from the device into host memory
  void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
 
+ // try to find operations that can be run concurrently in the graph
+ // you should run it again if the topology of your graph changes
+ void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+
+ // if the graph has been optimized for concurrently dispatch
+ bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+
  // same as ggml_graph_compute but uses Metal
  // creates gf->n_threads command buffers in parallel
  void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
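
Note: the two declarations added above let a caller compute a concurrency plan once per graph topology and reuse it across evaluations. A minimal, hypothetical usage sketch; the wrapper function and a valid ctx/gf are assumptions, not part of this diff:

    /* hypothetical wrapper: plan concurrency once, then evaluate */
    void metal_compute_with_plan(struct ggml_metal_context * ctx, struct ggml_cgraph * gf) {
        ggml_metal_graph_find_concurrency(ctx, gf);  /* re-run whenever the graph topology changes */
        if (!ggml_metal_if_optimized(ctx)) {
            /* no plan was produced; graph_compute falls back to serial dispatch */
        }
        ggml_metal_graph_compute(ctx, gf);
    }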
@@ -36,12 +36,16 @@ struct ggml_metal_context {
  int n_buffers;
  struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
+ int concur_list[GGML_MAX_NODES];
+ int concur_list_len;
+
  // custom kernels
  #define GGML_METAL_DECL_KERNEL(name) \
  id<MTLFunction> function_##name; \
  id<MTLComputePipelineState> pipeline_##name
 
  GGML_METAL_DECL_KERNEL(add);
+ GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
  GGML_METAL_DECL_KERNEL(mul);
  GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
  GGML_METAL_DECL_KERNEL(scale);
@@ -97,6 +101,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  ctx->device = MTLCreateSystemDefaultDevice();
  ctx->queue = [ctx->device newCommandQueue];
  ctx->n_buffers = 0;
+ ctx->concur_list_len = 0;
 
  // determine if we can use MPS
  if (MPSSupportsMTLDevice(ctx->device)) {
@@ -157,6 +162,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
 
  GGML_METAL_ADD_KERNEL(add);
+ GGML_METAL_ADD_KERNEL(add_row);
  GGML_METAL_ADD_KERNEL(mul);
  GGML_METAL_ADD_KERNEL(mul_row);
  GGML_METAL_ADD_KERNEL(scale);
@@ -215,6 +221,13 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
  ctx->n_cb = n_cb;
  }
 
+ bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+ if (ctx->concur_list_len) {
+ return true;
+ }
+ return false;
+ }
+
  // finds the Metal buffer that contains the tensor data on the GPU device
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
  // Metal buffer based on the host memory pointer
@@ -353,11 +366,98 @@ void ggml_metal_get_tensor(
  memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
  }
 
+ void ggml_metal_graph_find_concurrency(
+ struct ggml_metal_context * ctx,
+ struct ggml_cgraph * gf) {
+ int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
+ int nodes_unused[GGML_MAX_NODES];
+
+ for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
+ for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
+ ctx->concur_list_len = 0;
+
+ int n_left = gf->n_nodes;
+ int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
+ int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
+
+ while (n_left > 0) {
+ // number of nodes at a layer (that can be issued concurrently)
+ int concurrency = 0;
+ for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
+ if (nodes_unused[i]) {
+ // if the requirements for gf->nodes[i] are satisfied
+ int exe_flag=1;
+ // scan all srcs
+ for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
+ struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
+ if (src_cur) {
+ // if is leaf nodes it's satisfied.
+ if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
+
+ // otherwise this src should be the output from previous nodes.
+ int is_found = 0;
+ // scan 2*search_depth back because we inserted barrier.
+ for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+ if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
+ }
+ if (is_found == 0) {exe_flag = 0; break;}
+ }
+ }
+ if (exe_flag) {
+ // check if nodes[i]'s data will be overwritten by a node before nodes[i].
+ // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
+ int64_t data_start = (int64_t) gf->nodes[i]->data;
+ int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
+ for (int j = n_start; j < i; j++) {
+ if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
+ && gf->nodes[j]->op != GGML_OP_VIEW \
+ && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
+ && gf->nodes[j]->op != GGML_OP_PERMUTE) {
+ if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
+ ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
+ continue;
+ } else {
+ exe_flag = 0;
+ }
+ }
+ }
+ }
+ if (exe_flag) {
+ ctx->concur_list[level_pos + concurrency] = i;
+ nodes_unused[i] = 0;
+ concurrency++;
+ ctx->concur_list_len++;
+ }
+ }
+ }
+ n_left -= concurrency;
+ // adding a barrier different layer
+ ctx->concur_list[level_pos + concurrency] = -1;
+ ctx->concur_list_len++;
+ // jump all sorted nodes at nodes_bak
+ while (!nodes_unused[n_start]) {n_start++;}
+ level_pos += concurrency + 1;
+ }
+
+ if (ctx->concur_list_len > GGML_MAX_NODES) {
+ fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+ }
+ }
+
  void ggml_metal_graph_compute(
  struct ggml_metal_context * ctx,
  struct ggml_cgraph * gf) {
  metal_printf("%s: evaluating graph\n", __func__);
 
+ // if there is ctx->concur_list, dispatch concurrently
+ // else fallback to serial dispatch
+ MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
+
+ const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+
+ const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
+ edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+
  // create multiple command buffers and enqueue them
  // then, we encode the graph into the command buffers in parallel
 
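Note: the plan built by ggml_metal_graph_find_concurrency above is a flat array of node indices grouped into levels, with -1 entries acting as barriers between levels; ggml_metal_graph_compute walks this array instead of the plain node order whenever a plan exists. A purely illustrative sketch of the resulting layout (the node indices are made up):

    /* hypothetical concur_list for an 8-node graph:
       nodes 0, 1, 2 are mutually independent, node 3 depends on them, and so on;
       each -1 marks a barrier separating one level from the next */
    int concur_list_example[] = { 0, 1, 2, -1, 3, 4, -1, 5, 6, 7, -1 };
    int concur_list_len_example = 11;
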
@@ -376,7 +476,7 @@ void ggml_metal_graph_compute(
  dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
 
  for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
- const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+ const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
  dispatch_async(queue, ^{
  size_t offs_src0 = 0;
@@ -387,10 +487,21 @@ void ggml_metal_graph_compute(
 
  id<MTLComputeCommandEncoder> encoder = nil;
 
- const int node_start = (cb_idx + 0) * n_nodes_per_cb;
- const int node_end = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
+ const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+ for (int ind = node_start; ind < node_end; ++ind) {
+ const int i = has_concur ? ctx->concur_list[ind] : ind;
+
+ if (i == -1) {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ continue;
+ }
+ [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+ continue;
+ }
 
- for (int i = node_start; i < node_end; ++i) {
  metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
  struct ggml_tensor * src0 = gf->nodes[i]->src[0];
@@ -461,13 +572,19 @@ void ggml_metal_graph_compute(
  case GGML_OP_ADD:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
- [encoder setComputePipelineState:ctx->pipeline_add];
+ if (ggml_nelements(src1) == ne10) {
+ // src1 is a row
+ [encoder setComputePipelineState:ctx->pipeline_add_row];
+ } else {
+ [encoder setComputePipelineState:ctx->pipeline_add];
+ }
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
 
  const int64_t n = ggml_nelements(dst);
 
@@ -476,7 +593,7 @@ void ggml_metal_graph_compute(
  case GGML_OP_MUL:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
  if (ggml_nelements(src1) == ne10) {
@@ -497,7 +614,7 @@ void ggml_metal_graph_compute(
  case GGML_OP_SCALE:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
  const float scale = *(const float *) src1->data;
@@ -511,52 +628,60 @@ void ggml_metal_graph_compute(
 
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
- case GGML_OP_SILU:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- [encoder setComputePipelineState:ctx->pipeline_silu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
- case GGML_OP_RELU:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- [encoder setComputePipelineState:ctx->pipeline_relu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(gf->nodes[i])) {
+ case GGML_UNARY_OP_SILU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_silu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_RELU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_relu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_GELU:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_gelu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ default:
+ {
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ASSERT(false);
+ }
  } break;
- case GGML_OP_GELU:
- {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
- }
-
- [encoder setComputePipelineState:ctx->pipeline_gelu];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-
- const int64_t n = ggml_nelements(dst);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
- } break;
  case GGML_OP_SOFT_MAX:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
  const int nth = 32;
@@ -574,10 +699,10 @@ void ggml_metal_graph_compute(
  case GGML_OP_DIAG_MASK_INF:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
- const int n_past = ((int32_t *)(src1->data))[0];
+ const int n_past = ((int32_t *)(dst->op_params))[0];
 
  [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -593,7 +718,8 @@ void ggml_metal_graph_compute(
  // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
 
  GGML_ASSERT(ne00 == ne10);
- GGML_ASSERT(ne02 == ne12);
+ // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
+ GGML_ASSERT(ne03 == ne13);
 
  if (ggml_is_contiguous(src0) &&
  ggml_is_contiguous(src1) &&
@@ -621,11 +747,11 @@ void ggml_metal_graph_compute(
  initWithDevice:ctx->device transposeLeft:false transposeRight:true
  resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
 
- // we need to do ne02 multiplications
+ // we need to do ne12 multiplications
  // TODO: is there a way to do this in parallel - currently very slow ..
  // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
- for (int64_t i02 = 0; i02 < ne02; ++i02) {
- size_t offs_src0_cur = offs_src0 + i02*nb02;
+ for (int64_t i02 = 0; i02 < ne12; ++i02) {
+ size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
  size_t offs_src1_cur = offs_src1 + i02*nb12;
  size_t offs_dst_cur = offs_dst + i02*nb2;
 
@@ -637,7 +763,7 @@ void ggml_metal_graph_compute(
  }
  } else {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
  int nth0 = 32;
@@ -647,8 +773,6 @@ void ggml_metal_graph_compute(
  switch (src0t) {
  case GGML_TYPE_F16:
  {
- GGML_ASSERT(ne02 == ne12);
-
  nth0 = 64;
  nth1 = 1;
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
@@ -728,16 +852,18 @@ void ggml_metal_graph_compute(
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
 
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
  src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
@@ -764,7 +890,7 @@ void ggml_metal_graph_compute(
  case GGML_OP_GET_ROWS:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
  switch (src0->type) {
@@ -793,10 +919,11 @@ void ggml_metal_graph_compute(
  case GGML_OP_RMS_NORM:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
- const float eps = 1e-6f;
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
 
  const int nth = 512;
 
@@ -815,7 +942,7 @@ void ggml_metal_graph_compute(
  case GGML_OP_NORM:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
  const float eps = 1e-5f;
@@ -837,14 +964,15 @@ void ggml_metal_graph_compute(
  case GGML_OP_ALIBI:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
  GGML_ASSERT((src0t == GGML_TYPE_F32));
 
- const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past);
- const int n_head = ((int32_t *) src1->data)[1];
- const float max_bias = ((float *) src1->data)[2];
+ const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+ const int n_head = ((int32_t *) dst->op_params)[1];
+ float max_bias;
+ memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
  if (__builtin_popcount(n_head) != 1) {
  GGML_ASSERT(false && "only power-of-two n_head implemented");
@@ -879,18 +1007,17 @@ void ggml_metal_graph_compute(
  case GGML_OP_ROPE:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
- const int n_dims = ((int32_t *) src1->data)[1];
- const int mode = ((int32_t *) src1->data)[2];
-
- const int n_past = ((int32_t *)(src1->data))[0];
+ const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
 
  float freq_base;
  float freq_scale;
- memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
- memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+ memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
  [encoder setComputePipelineState:ctx->pipeline_rope];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -919,10 +1046,12 @@ void ggml_metal_graph_compute(
 
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
+ case GGML_OP_DUP:
  case GGML_OP_CPY:
+ case GGML_OP_CONT:
  {
  if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
  }
 
  const int nth = 32;
@@ -969,8 +1098,10 @@ void ggml_metal_graph_compute(
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  default:
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
- GGML_ASSERT(false);
+ {
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ASSERT(false);
+ }
  }
  }