llama_cpp 0.3.2 → 0.3.4

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects those changes as the package versions appear in their respective public registries.
@@ -34,9 +34,13 @@ extern "C" {
34
34
 
35
35
  struct ggml_metal_context;
36
36
 
37
- struct ggml_metal_context * ggml_metal_init(void);
37
+ // number of command buffers to use
38
+ struct ggml_metal_context * ggml_metal_init(int n_cb);
38
39
  void ggml_metal_free(struct ggml_metal_context * ctx);
39
40
 
41
+ // set the number of command buffers to use
42
+ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);
43
+
40
44
  // creates a mapping between a host memory buffer and a device memory buffer
41
45
  // - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
42
46
  // - the mapping is used during computation to determine the arguments of the compute kernels
@@ -25,6 +25,8 @@ struct ggml_metal_buffer {
25
25
  };
26
26
 
27
27
  struct ggml_metal_context {
28
+ int n_cb;
29
+
28
30
  float * logits;
29
31
 
30
32
  id<MTLDevice> device;
@@ -86,11 +88,12 @@ static NSString * const msl_library_source = @"see metal.metal";
86
88
  @implementation GGMLMetalClass
87
89
  @end
88
90
 
89
- struct ggml_metal_context * ggml_metal_init(void) {
91
+ struct ggml_metal_context * ggml_metal_init(int n_cb) {
90
92
  fprintf(stderr, "%s: allocating\n", __func__);
91
93
 
92
94
  struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
93
95
 
96
+ ctx->n_cb = n_cb;
94
97
  ctx->device = MTLCreateSystemDefaultDevice();
95
98
  ctx->queue = [ctx->device newCommandQueue];
96
99
  ctx->n_buffers = 0;
@@ -208,6 +211,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
208
211
  free(ctx);
209
212
  }
210
213
 
214
+ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
215
+ ctx->n_cb = n_cb;
216
+ }
217
+
211
218
  // finds the Metal buffer that contains the tensor data on the GPU device
212
219
  // the assumption is that there is a 1-to-1 mapping between the host and device memory buffers, so we can find the
213
220
  // Metal buffer based on the host memory pointer
@@ -354,7 +361,7 @@ void ggml_metal_graph_compute(
354
361
  // create multiple command buffers and enqueue them
355
362
  // then, we encode the graph into the command buffers in parallel
356
363
 
357
- const int n_cb = gf->n_threads;
364
+ const int n_cb = ctx->n_cb;
358
365
 
359
366
  NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
360
367
 
@@ -386,8 +393,8 @@ void ggml_metal_graph_compute(
386
393
  for (int i = node_start; i < node_end; ++i) {
387
394
  metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
388
395
 
389
- struct ggml_tensor * src0 = gf->nodes[i]->src0;
390
- struct ggml_tensor * src1 = gf->nodes[i]->src1;
396
+ struct ggml_tensor * src0 = gf->nodes[i]->src[0];
397
+ struct ggml_tensor * src1 = gf->nodes[i]->src[1];
391
398
  struct ggml_tensor * dst = gf->nodes[i];
392
399
 
393
400
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
@@ -443,6 +450,7 @@ void ggml_metal_graph_compute(
443
450
  //}
444
451
 
445
452
  switch (dst->op) {
453
+ case GGML_OP_NONE:
446
454
  case GGML_OP_RESHAPE:
447
455
  case GGML_OP_VIEW:
448
456
  case GGML_OP_TRANSPOSE:
@@ -668,8 +676,8 @@ void ggml_metal_graph_compute(
668
676
  GGML_ASSERT(ne02 == 1);
669
677
  GGML_ASSERT(ne12 == 1);
670
678
 
671
- nth0 = 4;
672
- nth1 = 16;
679
+ nth0 = 2;
680
+ nth1 = 32;
673
681
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
674
682
  } break;
675
683
  case GGML_TYPE_Q3_K:
@@ -677,8 +685,8 @@ void ggml_metal_graph_compute(
677
685
  GGML_ASSERT(ne02 == 1);
678
686
  GGML_ASSERT(ne12 == 1);
679
687
 
680
- nth0 = 4;
681
- nth1 = 16;
688
+ nth0 = 2;
689
+ nth1 = 32;
682
690
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
683
691
  } break;
684
692
  case GGML_TYPE_Q4_K:
@@ -686,8 +694,8 @@ void ggml_metal_graph_compute(
686
694
  GGML_ASSERT(ne02 == 1);
687
695
  GGML_ASSERT(ne12 == 1);
688
696
 
689
- nth0 = 4;
690
- nth1 = 16;
697
+ nth0 = 2;
698
+ nth1 = 32;
691
699
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
692
700
  } break;
693
701
  case GGML_TYPE_Q5_K:
@@ -695,8 +703,8 @@ void ggml_metal_graph_compute(
695
703
  GGML_ASSERT(ne02 == 1);
696
704
  GGML_ASSERT(ne12 == 1);
697
705
 
698
- nth0 = 4;
699
- nth1 = 16;
706
+ nth0 = 2;
707
+ nth1 = 32;
700
708
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
701
709
  } break;
702
710
  case GGML_TYPE_Q6_K:
@@ -704,8 +712,8 @@ void ggml_metal_graph_compute(
704
712
  GGML_ASSERT(ne02 == 1);
705
713
  GGML_ASSERT(ne12 == 1);
706
714
 
707
- nth0 = 4;
708
- nth1 = 16;
715
+ nth0 = 2;
716
+ nth1 = 32;
709
717
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
710
718
  } break;
711
719
  default:
@@ -731,17 +739,22 @@ void ggml_metal_graph_compute(
731
739
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
732
740
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
733
741
 
734
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
735
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
736
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
742
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
743
+ src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
744
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
737
745
  }
738
- else if (src0t == GGML_TYPE_Q2_K ||
739
- src0t == GGML_TYPE_Q3_K ||
740
- src0t == GGML_TYPE_Q4_K ||
741
- src0t == GGML_TYPE_Q5_K ||
742
- src0t == GGML_TYPE_Q6_K) {
743
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
744
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
746
+ else if (src0t == GGML_TYPE_Q3_K) {
747
+ #ifdef GGML_QKK_64
748
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
749
+ #else
750
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
751
+ #endif
752
+ }
753
+ else if (src0t == GGML_TYPE_Q5_K) {
754
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
755
+ }
756
+ else if (src0t == GGML_TYPE_Q6_K) {
757
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
745
758
  } else {
746
759
  [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
747
760
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -785,7 +798,7 @@ void ggml_metal_graph_compute(
785
798
 
786
799
  const float eps = 1e-6f;
787
800
 
788
- const int nth = 256;
801
+ const int nth = 512;
789
802
 
790
803
  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
791
804
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -793,7 +806,7 @@ void ggml_metal_graph_compute(
793
806
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
794
807
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
795
808
  [encoder setBytes:&eps length:sizeof( float) atIndex:4];
796
- [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
809
+ [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
797
810
 
798
811
  const int64_t nrows = ggml_nrows(src0);
799
812
 
@@ -874,28 +887,35 @@ void ggml_metal_graph_compute(
874
887
 
875
888
  const int n_past = ((int32_t *)(src1->data))[0];
876
889
 
890
+ float freq_base;
891
+ float freq_scale;
892
+ memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
893
+ memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
894
+
877
895
  [encoder setComputePipelineState:ctx->pipeline_rope];
878
896
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
879
897
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
880
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
881
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
882
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
883
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
884
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
885
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
886
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
887
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
888
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
889
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
890
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
891
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
892
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
893
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
894
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
895
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
896
- [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
897
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
898
- [encoder setBytes:&mode length:sizeof( int) atIndex:20];
898
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
899
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
900
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
901
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
902
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
903
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
904
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
905
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
906
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
907
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
908
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
909
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
910
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
911
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
912
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
913
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
914
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
915
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
916
+ [encoder setBytes:&mode length:sizeof( int) atIndex:20];
917
+ [encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
918
+ [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
899
919
 
900
920
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
901
921
  } break;