llama_cpp 0.3.2 → 0.3.4

Sign up to get free protection for your applications and to get access to all the features.
@@ -34,9 +34,13 @@ extern "C" {
34
34
 
35
35
  struct ggml_metal_context;
36
36
 
37
- struct ggml_metal_context * ggml_metal_init(void);
37
+ // n_cb: number of command buffers to use (kept in the context's n_cb field)
38
+ struct ggml_metal_context * ggml_metal_init(int n_cb);
38
39
  void ggml_metal_free(struct ggml_metal_context * ctx);
39
40
 
41
+ // set the number of command buffers to use (overwrites the n_cb given to ggml_metal_init)
42
+ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);
43
+
40
44
  // creates a mapping between a host memory buffer and a device memory buffer
41
45
  // - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
42
46
  // - the mapping is used during computation to determine the arguments of the compute kernels
@@ -25,6 +25,8 @@ struct ggml_metal_buffer {
25
25
  };
26
26
 
27
27
  struct ggml_metal_context {
28
+ int n_cb;
29
+
28
30
  float * logits;
29
31
 
30
32
  id<MTLDevice> device;
@@ -86,11 +88,12 @@ static NSString * const msl_library_source = @"see metal.metal";
86
88
  @implementation GGMLMetalClass
87
89
  @end
88
90
 
89
- struct ggml_metal_context * ggml_metal_init(void) {
91
+ struct ggml_metal_context * ggml_metal_init(int n_cb) {
90
92
  fprintf(stderr, "%s: allocating\n", __func__);
91
93
 
92
94
  struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
93
95
 
96
+ ctx->n_cb = n_cb;
94
97
  ctx->device = MTLCreateSystemDefaultDevice();
95
98
  ctx->queue = [ctx->device newCommandQueue];
96
99
  ctx->n_buffers = 0;
@@ -208,6 +211,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
208
211
  free(ctx);
209
212
  }
210
213
 
214
+ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) { // replaces the command-buffer count stored by ggml_metal_init
215
+ ctx->n_cb = n_cb; // stored as-is (no range check here); ggml_metal_graph_compute reads ctx->n_cb as its n_cb
216
+ }
217
+
211
218
  // finds the Metal buffer that contains the tensor data on the GPU device
212
219
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
213
220
  // Metal buffer based on the host memory pointer
@@ -354,7 +361,7 @@ void ggml_metal_graph_compute(
354
361
  // create multiple command buffers and enqueue them
355
362
  // then, we encode the graph into the command buffers in parallel
356
363
 
357
- const int n_cb = gf->n_threads;
364
+ const int n_cb = ctx->n_cb;
358
365
 
359
366
  NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
360
367
 
@@ -386,8 +393,8 @@ void ggml_metal_graph_compute(
386
393
  for (int i = node_start; i < node_end; ++i) {
387
394
  metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
388
395
 
389
- struct ggml_tensor * src0 = gf->nodes[i]->src0;
390
- struct ggml_tensor * src1 = gf->nodes[i]->src1;
396
+ struct ggml_tensor * src0 = gf->nodes[i]->src[0];
397
+ struct ggml_tensor * src1 = gf->nodes[i]->src[1];
391
398
  struct ggml_tensor * dst = gf->nodes[i];
392
399
 
393
400
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
@@ -443,6 +450,7 @@ void ggml_metal_graph_compute(
443
450
  //}
444
451
 
445
452
  switch (dst->op) {
453
+ case GGML_OP_NONE:
446
454
  case GGML_OP_RESHAPE:
447
455
  case GGML_OP_VIEW:
448
456
  case GGML_OP_TRANSPOSE:
@@ -668,8 +676,8 @@ void ggml_metal_graph_compute(
668
676
  GGML_ASSERT(ne02 == 1);
669
677
  GGML_ASSERT(ne12 == 1);
670
678
 
671
- nth0 = 4;
672
- nth1 = 16;
679
+ nth0 = 2;
680
+ nth1 = 32;
673
681
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
674
682
  } break;
675
683
  case GGML_TYPE_Q3_K:
@@ -677,8 +685,8 @@ void ggml_metal_graph_compute(
677
685
  GGML_ASSERT(ne02 == 1);
678
686
  GGML_ASSERT(ne12 == 1);
679
687
 
680
- nth0 = 4;
681
- nth1 = 16;
688
+ nth0 = 2;
689
+ nth1 = 32;
682
690
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
683
691
  } break;
684
692
  case GGML_TYPE_Q4_K:
@@ -686,8 +694,8 @@ void ggml_metal_graph_compute(
686
694
  GGML_ASSERT(ne02 == 1);
687
695
  GGML_ASSERT(ne12 == 1);
688
696
 
689
- nth0 = 4;
690
- nth1 = 16;
697
+ nth0 = 2;
698
+ nth1 = 32;
691
699
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
692
700
  } break;
693
701
  case GGML_TYPE_Q5_K:
@@ -695,8 +703,8 @@ void ggml_metal_graph_compute(
695
703
  GGML_ASSERT(ne02 == 1);
696
704
  GGML_ASSERT(ne12 == 1);
697
705
 
698
- nth0 = 4;
699
- nth1 = 16;
706
+ nth0 = 2;
707
+ nth1 = 32;
700
708
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
701
709
  } break;
702
710
  case GGML_TYPE_Q6_K:
@@ -704,8 +712,8 @@ void ggml_metal_graph_compute(
704
712
  GGML_ASSERT(ne02 == 1);
705
713
  GGML_ASSERT(ne12 == 1);
706
714
 
707
- nth0 = 4;
708
- nth1 = 16;
715
+ nth0 = 2;
716
+ nth1 = 32;
709
717
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
710
718
  } break;
711
719
  default:
@@ -731,17 +739,22 @@ void ggml_metal_graph_compute(
731
739
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
732
740
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
733
741
 
734
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
735
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
736
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
742
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
743
+ src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
744
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
737
745
  }
738
- else if (src0t == GGML_TYPE_Q2_K ||
739
- src0t == GGML_TYPE_Q3_K ||
740
- src0t == GGML_TYPE_Q4_K ||
741
- src0t == GGML_TYPE_Q5_K ||
742
- src0t == GGML_TYPE_Q6_K) {
743
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
744
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
746
+ else if (src0t == GGML_TYPE_Q3_K) {
747
+ #ifdef GGML_QKK_64
748
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
749
+ #else
750
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
751
+ #endif
752
+ }
753
+ else if (src0t == GGML_TYPE_Q5_K) {
754
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
755
+ }
756
+ else if (src0t == GGML_TYPE_Q6_K) {
757
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
745
758
  } else {
746
759
  [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
747
760
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -785,7 +798,7 @@ void ggml_metal_graph_compute(
785
798
 
786
799
  const float eps = 1e-6f;
787
800
 
788
- const int nth = 256;
801
+ const int nth = 512;
789
802
 
790
803
  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
791
804
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -793,7 +806,7 @@ void ggml_metal_graph_compute(
793
806
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
794
807
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
795
808
  [encoder setBytes:&eps length:sizeof( float) atIndex:4];
796
- [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
809
+ [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
797
810
 
798
811
  const int64_t nrows = ggml_nrows(src0);
799
812
 
@@ -874,28 +887,35 @@ void ggml_metal_graph_compute(
874
887
 
875
888
  const int n_past = ((int32_t *)(src1->data))[0];
876
889
 
890
+ float freq_base;
891
+ float freq_scale;
892
+ memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
893
+ memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
894
+
877
895
  [encoder setComputePipelineState:ctx->pipeline_rope];
878
896
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
879
897
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
880
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
881
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
882
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
883
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
884
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
885
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
886
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
887
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
888
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
889
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
890
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
891
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
892
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
893
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
894
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
895
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
896
- [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
897
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
898
- [encoder setBytes:&mode length:sizeof( int) atIndex:20];
898
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
899
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
900
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
901
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
902
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
903
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
904
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
905
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
906
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
907
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
908
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
909
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
910
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
911
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
912
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
913
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
914
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
915
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
916
+ [encoder setBytes:&mode length:sizeof( int) atIndex:20];
917
+ [encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
918
+ [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
899
919
 
900
920
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
901
921
  } break;