llama_cpp 0.3.3 → 0.3.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +24 -0
- data/ext/llama_cpp/llama_cpp.cpp +146 -9
- data/ext/llama_cpp/src/ggml-cuda.cu +485 -67
- data/ext/llama_cpp/src/ggml-metal.m +52 -43
- data/ext/llama_cpp/src/ggml-metal.metal +587 -470
- data/ext/llama_cpp/src/ggml.c +105 -79
- data/ext/llama_cpp/src/ggml.h +13 -1
- data/ext/llama_cpp/src/k_quants.h +8 -0
- data/ext/llama_cpp/src/llama.cpp +123 -66
- data/ext/llama_cpp/src/llama.h +34 -5
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -0
- data/sig/llama_cpp.rbs +12 -1
- metadata +2 -2
@@ -676,8 +676,8 @@ void ggml_metal_graph_compute(
|
|
676
676
|
GGML_ASSERT(ne02 == 1);
|
677
677
|
GGML_ASSERT(ne12 == 1);
|
678
678
|
|
679
|
-
nth0 =
|
680
|
-
nth1 =
|
679
|
+
nth0 = 2;
|
680
|
+
nth1 = 32;
|
681
681
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
|
682
682
|
} break;
|
683
683
|
case GGML_TYPE_Q3_K:
|
@@ -685,8 +685,8 @@ void ggml_metal_graph_compute(
|
|
685
685
|
GGML_ASSERT(ne02 == 1);
|
686
686
|
GGML_ASSERT(ne12 == 1);
|
687
687
|
|
688
|
-
nth0 =
|
689
|
-
nth1 =
|
688
|
+
nth0 = 2;
|
689
|
+
nth1 = 32;
|
690
690
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
|
691
691
|
} break;
|
692
692
|
case GGML_TYPE_Q4_K:
|
@@ -694,8 +694,8 @@ void ggml_metal_graph_compute(
|
|
694
694
|
GGML_ASSERT(ne02 == 1);
|
695
695
|
GGML_ASSERT(ne12 == 1);
|
696
696
|
|
697
|
-
nth0 =
|
698
|
-
nth1 =
|
697
|
+
nth0 = 2;
|
698
|
+
nth1 = 32;
|
699
699
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
|
700
700
|
} break;
|
701
701
|
case GGML_TYPE_Q5_K:
|
@@ -703,8 +703,8 @@ void ggml_metal_graph_compute(
|
|
703
703
|
GGML_ASSERT(ne02 == 1);
|
704
704
|
GGML_ASSERT(ne12 == 1);
|
705
705
|
|
706
|
-
nth0 =
|
707
|
-
nth1 =
|
706
|
+
nth0 = 2;
|
707
|
+
nth1 = 32;
|
708
708
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
|
709
709
|
} break;
|
710
710
|
case GGML_TYPE_Q6_K:
|
@@ -712,8 +712,8 @@ void ggml_metal_graph_compute(
|
|
712
712
|
GGML_ASSERT(ne02 == 1);
|
713
713
|
GGML_ASSERT(ne12 == 1);
|
714
714
|
|
715
|
-
nth0 =
|
716
|
-
nth1 =
|
715
|
+
nth0 = 2;
|
716
|
+
nth1 = 32;
|
717
717
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
|
718
718
|
} break;
|
719
719
|
default:
|
@@ -739,20 +739,22 @@ void ggml_metal_graph_compute(
|
|
739
739
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
740
740
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
741
741
|
|
742
|
-
if (src0t == GGML_TYPE_Q4_0
|
743
|
-
|
742
|
+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
743
|
+
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
744
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
744
745
|
}
|
745
|
-
else if (src0t ==
|
746
|
-
|
747
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
746
|
+
else if (src0t == GGML_TYPE_Q3_K) {
|
747
|
+
#ifdef GGML_QKK_64
|
748
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
749
|
+
#else
|
750
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
751
|
+
#endif
|
752
|
+
}
|
753
|
+
else if (src0t == GGML_TYPE_Q5_K) {
|
754
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
748
755
|
}
|
749
|
-
else if (src0t ==
|
750
|
-
|
751
|
-
src0t == GGML_TYPE_Q4_K ||
|
752
|
-
src0t == GGML_TYPE_Q5_K ||
|
753
|
-
src0t == GGML_TYPE_Q6_K) {
|
754
|
-
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
|
755
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
756
|
+
else if (src0t == GGML_TYPE_Q6_K) {
|
757
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
756
758
|
} else {
|
757
759
|
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
|
758
760
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
@@ -796,7 +798,7 @@ void ggml_metal_graph_compute(
|
|
796
798
|
|
797
799
|
const float eps = 1e-6f;
|
798
800
|
|
799
|
-
const int nth =
|
801
|
+
const int nth = 512;
|
800
802
|
|
801
803
|
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
|
802
804
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
@@ -804,7 +806,7 @@ void ggml_metal_graph_compute(
|
|
804
806
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
805
807
|
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
806
808
|
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
807
|
-
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
|
809
|
+
[encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
|
808
810
|
|
809
811
|
const int64_t nrows = ggml_nrows(src0);
|
810
812
|
|
@@ -885,28 +887,35 @@ void ggml_metal_graph_compute(
|
|
885
887
|
|
886
888
|
const int n_past = ((int32_t *)(src1->data))[0];
|
887
889
|
|
890
|
+
float freq_base;
|
891
|
+
float freq_scale;
|
892
|
+
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
|
893
|
+
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
|
894
|
+
|
888
895
|
[encoder setComputePipelineState:ctx->pipeline_rope];
|
889
896
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
890
897
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
891
|
-
[encoder setBytes:&ne00
|
892
|
-
[encoder setBytes:&ne01
|
893
|
-
[encoder setBytes:&ne02
|
894
|
-
[encoder setBytes:&ne03
|
895
|
-
[encoder setBytes:&nb00
|
896
|
-
[encoder setBytes:&nb01
|
897
|
-
[encoder setBytes:&nb02
|
898
|
-
[encoder setBytes:&nb03
|
899
|
-
[encoder setBytes:&ne0
|
900
|
-
[encoder setBytes:&ne1
|
901
|
-
[encoder setBytes:&ne2
|
902
|
-
[encoder setBytes:&ne3
|
903
|
-
[encoder setBytes:&nb0
|
904
|
-
[encoder setBytes:&nb1
|
905
|
-
[encoder setBytes:&nb2
|
906
|
-
[encoder setBytes:&nb3
|
907
|
-
[encoder setBytes:&n_past
|
908
|
-
[encoder setBytes:&n_dims
|
909
|
-
[encoder setBytes:&mode
|
898
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
899
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
900
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
901
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
902
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
903
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
904
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
905
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
906
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
907
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
908
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
909
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
910
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
911
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
912
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
913
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
914
|
+
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
|
915
|
+
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
|
916
|
+
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
|
917
|
+
[encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
|
918
|
+
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
|
910
919
|
|
911
920
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
912
921
|
} break;
|