llama_cpp 0.3.3 → 0.3.4

Sign up to get free protection for your applications and to get access to all the features.
@@ -676,8 +676,8 @@ void ggml_metal_graph_compute(
676
676
  GGML_ASSERT(ne02 == 1);
677
677
  GGML_ASSERT(ne12 == 1);
678
678
 
679
- nth0 = 4;
680
- nth1 = 16;
679
+ nth0 = 2;
680
+ nth1 = 32;
681
681
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
682
682
  } break;
683
683
  case GGML_TYPE_Q3_K:
@@ -685,8 +685,8 @@ void ggml_metal_graph_compute(
685
685
  GGML_ASSERT(ne02 == 1);
686
686
  GGML_ASSERT(ne12 == 1);
687
687
 
688
- nth0 = 4;
689
- nth1 = 16;
688
+ nth0 = 2;
689
+ nth1 = 32;
690
690
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
691
691
  } break;
692
692
  case GGML_TYPE_Q4_K:
@@ -694,8 +694,8 @@ void ggml_metal_graph_compute(
694
694
  GGML_ASSERT(ne02 == 1);
695
695
  GGML_ASSERT(ne12 == 1);
696
696
 
697
- nth0 = 4;
698
- nth1 = 16;
697
+ nth0 = 2;
698
+ nth1 = 32;
699
699
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
700
700
  } break;
701
701
  case GGML_TYPE_Q5_K:
@@ -703,8 +703,8 @@ void ggml_metal_graph_compute(
703
703
  GGML_ASSERT(ne02 == 1);
704
704
  GGML_ASSERT(ne12 == 1);
705
705
 
706
- nth0 = 4;
707
- nth1 = 16;
706
+ nth0 = 2;
707
+ nth1 = 32;
708
708
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
709
709
  } break;
710
710
  case GGML_TYPE_Q6_K:
@@ -712,8 +712,8 @@ void ggml_metal_graph_compute(
712
712
  GGML_ASSERT(ne02 == 1);
713
713
  GGML_ASSERT(ne12 == 1);
714
714
 
715
- nth0 = 4;
716
- nth1 = 16;
715
+ nth0 = 2;
716
+ nth1 = 32;
717
717
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
718
718
  } break;
719
719
  default:
@@ -739,20 +739,22 @@ void ggml_metal_graph_compute(
739
739
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
740
740
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
741
741
 
742
- if (src0t == GGML_TYPE_Q4_0) {
743
- [encoder dispatchThreadgroups:MTLSizeMake(ne01 / 8+((ne01 % 8) & 0x01), ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
742
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
743
+ src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
744
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
744
745
  }
745
- else if (src0t == GGML_TYPE_Q4_1) {
746
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
747
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
746
+ else if (src0t == GGML_TYPE_Q3_K) {
747
+ #ifdef GGML_QKK_64
748
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
749
+ #else
750
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
751
+ #endif
752
+ }
753
+ else if (src0t == GGML_TYPE_Q5_K) {
754
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
748
755
  }
749
- else if (src0t == GGML_TYPE_Q2_K ||
750
- src0t == GGML_TYPE_Q3_K ||
751
- src0t == GGML_TYPE_Q4_K ||
752
- src0t == GGML_TYPE_Q5_K ||
753
- src0t == GGML_TYPE_Q6_K) {
754
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
755
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
756
+ else if (src0t == GGML_TYPE_Q6_K) {
757
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
756
758
  } else {
757
759
  [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
758
760
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -796,7 +798,7 @@ void ggml_metal_graph_compute(
796
798
 
797
799
  const float eps = 1e-6f;
798
800
 
799
- const int nth = 256;
801
+ const int nth = 512;
800
802
 
801
803
  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
802
804
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -804,7 +806,7 @@ void ggml_metal_graph_compute(
804
806
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
805
807
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
806
808
  [encoder setBytes:&eps length:sizeof( float) atIndex:4];
807
- [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
809
+ [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
808
810
 
809
811
  const int64_t nrows = ggml_nrows(src0);
810
812
 
@@ -885,28 +887,35 @@ void ggml_metal_graph_compute(
885
887
 
886
888
  const int n_past = ((int32_t *)(src1->data))[0];
887
889
 
890
+ float freq_base;
891
+ float freq_scale;
892
+ memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
893
+ memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
894
+
888
895
  [encoder setComputePipelineState:ctx->pipeline_rope];
889
896
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
890
897
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
891
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
892
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
893
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
894
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
895
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
896
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
897
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
898
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
899
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
900
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
901
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
902
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
903
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
904
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
905
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
906
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
907
- [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
908
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
909
- [encoder setBytes:&mode length:sizeof( int) atIndex:20];
898
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
899
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
900
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
901
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
902
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
903
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
904
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
905
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
906
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
907
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
908
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
909
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
910
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
911
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
912
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
913
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
914
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
915
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
916
+ [encoder setBytes:&mode length:sizeof( int) atIndex:20];
917
+ [encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
918
+ [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
910
919
 
911
920
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
912
921
  } break;