llama_cpp 0.3.3 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -676,8 +676,8 @@ void ggml_metal_graph_compute(
676
676
  GGML_ASSERT(ne02 == 1);
677
677
  GGML_ASSERT(ne12 == 1);
678
678
 
679
- nth0 = 4;
680
- nth1 = 16;
679
+ nth0 = 2;
680
+ nth1 = 32;
681
681
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
682
682
  } break;
683
683
  case GGML_TYPE_Q3_K:
@@ -685,8 +685,8 @@ void ggml_metal_graph_compute(
685
685
  GGML_ASSERT(ne02 == 1);
686
686
  GGML_ASSERT(ne12 == 1);
687
687
 
688
- nth0 = 4;
689
- nth1 = 16;
688
+ nth0 = 2;
689
+ nth1 = 32;
690
690
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
691
691
  } break;
692
692
  case GGML_TYPE_Q4_K:
@@ -694,8 +694,8 @@ void ggml_metal_graph_compute(
694
694
  GGML_ASSERT(ne02 == 1);
695
695
  GGML_ASSERT(ne12 == 1);
696
696
 
697
- nth0 = 4;
698
- nth1 = 16;
697
+ nth0 = 2;
698
+ nth1 = 32;
699
699
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
700
700
  } break;
701
701
  case GGML_TYPE_Q5_K:
@@ -703,8 +703,8 @@ void ggml_metal_graph_compute(
703
703
  GGML_ASSERT(ne02 == 1);
704
704
  GGML_ASSERT(ne12 == 1);
705
705
 
706
- nth0 = 4;
707
- nth1 = 16;
706
+ nth0 = 2;
707
+ nth1 = 32;
708
708
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
709
709
  } break;
710
710
  case GGML_TYPE_Q6_K:
@@ -712,8 +712,8 @@ void ggml_metal_graph_compute(
712
712
  GGML_ASSERT(ne02 == 1);
713
713
  GGML_ASSERT(ne12 == 1);
714
714
 
715
- nth0 = 4;
716
- nth1 = 16;
715
+ nth0 = 2;
716
+ nth1 = 32;
717
717
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
718
718
  } break;
719
719
  default:
@@ -739,20 +739,22 @@ void ggml_metal_graph_compute(
739
739
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
740
740
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
741
741
 
742
- if (src0t == GGML_TYPE_Q4_0) {
743
- [encoder dispatchThreadgroups:MTLSizeMake(ne01 / 8+((ne01 % 8) & 0x01), ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
742
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
743
+ src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
744
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
744
745
  }
745
- else if (src0t == GGML_TYPE_Q4_1) {
746
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
747
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
746
+ else if (src0t == GGML_TYPE_Q3_K) {
747
+ #ifdef GGML_QKK_64
748
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
749
+ #else
750
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
751
+ #endif
752
+ }
753
+ else if (src0t == GGML_TYPE_Q5_K) {
754
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
748
755
  }
749
- else if (src0t == GGML_TYPE_Q2_K ||
750
- src0t == GGML_TYPE_Q3_K ||
751
- src0t == GGML_TYPE_Q4_K ||
752
- src0t == GGML_TYPE_Q5_K ||
753
- src0t == GGML_TYPE_Q6_K) {
754
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
755
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
756
+ else if (src0t == GGML_TYPE_Q6_K) {
757
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
756
758
  } else {
757
759
  [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
758
760
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -796,7 +798,7 @@ void ggml_metal_graph_compute(
796
798
 
797
799
  const float eps = 1e-6f;
798
800
 
799
- const int nth = 256;
801
+ const int nth = 512;
800
802
 
801
803
  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
802
804
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -804,7 +806,7 @@ void ggml_metal_graph_compute(
804
806
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
805
807
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
806
808
  [encoder setBytes:&eps length:sizeof( float) atIndex:4];
807
- [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
809
+ [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
808
810
 
809
811
  const int64_t nrows = ggml_nrows(src0);
810
812
 
@@ -885,28 +887,35 @@ void ggml_metal_graph_compute(
885
887
 
886
888
  const int n_past = ((int32_t *)(src1->data))[0];
887
889
 
890
+ float freq_base;
891
+ float freq_scale;
892
+ memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
893
+ memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
894
+
888
895
  [encoder setComputePipelineState:ctx->pipeline_rope];
889
896
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
890
897
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
891
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
892
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
893
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
894
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
895
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
896
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
897
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
898
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
899
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
900
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
901
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
902
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
903
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
904
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
905
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
906
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
907
- [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
908
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
909
- [encoder setBytes:&mode length:sizeof( int) atIndex:20];
898
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
899
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
900
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
901
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
902
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
903
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
904
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
905
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
906
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
907
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
908
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
909
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
910
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
911
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
912
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
913
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
914
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
915
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
916
+ [encoder setBytes:&mode length:sizeof( int) atIndex:20];
917
+ [encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
918
+ [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
910
919
 
911
920
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
912
921
  } break;