llama_cpp 0.3.3 → 0.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +24 -0
- data/ext/llama_cpp/llama_cpp.cpp +146 -9
- data/ext/llama_cpp/src/ggml-cuda.cu +485 -67
- data/ext/llama_cpp/src/ggml-metal.m +52 -43
- data/ext/llama_cpp/src/ggml-metal.metal +587 -470
- data/ext/llama_cpp/src/ggml.c +105 -79
- data/ext/llama_cpp/src/ggml.h +13 -1
- data/ext/llama_cpp/src/k_quants.h +8 -0
- data/ext/llama_cpp/src/llama.cpp +123 -66
- data/ext/llama_cpp/src/llama.h +34 -5
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -0
- data/sig/llama_cpp.rbs +12 -1
- metadata +2 -2
@@ -676,8 +676,8 @@ void ggml_metal_graph_compute(
|
|
676
676
|
GGML_ASSERT(ne02 == 1);
|
677
677
|
GGML_ASSERT(ne12 == 1);
|
678
678
|
|
679
|
-
nth0 =
|
680
|
-
nth1 =
|
679
|
+
nth0 = 2;
|
680
|
+
nth1 = 32;
|
681
681
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
|
682
682
|
} break;
|
683
683
|
case GGML_TYPE_Q3_K:
|
@@ -685,8 +685,8 @@ void ggml_metal_graph_compute(
|
|
685
685
|
GGML_ASSERT(ne02 == 1);
|
686
686
|
GGML_ASSERT(ne12 == 1);
|
687
687
|
|
688
|
-
nth0 =
|
689
|
-
nth1 =
|
688
|
+
nth0 = 2;
|
689
|
+
nth1 = 32;
|
690
690
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
|
691
691
|
} break;
|
692
692
|
case GGML_TYPE_Q4_K:
|
@@ -694,8 +694,8 @@ void ggml_metal_graph_compute(
|
|
694
694
|
GGML_ASSERT(ne02 == 1);
|
695
695
|
GGML_ASSERT(ne12 == 1);
|
696
696
|
|
697
|
-
nth0 =
|
698
|
-
nth1 =
|
697
|
+
nth0 = 2;
|
698
|
+
nth1 = 32;
|
699
699
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
|
700
700
|
} break;
|
701
701
|
case GGML_TYPE_Q5_K:
|
@@ -703,8 +703,8 @@ void ggml_metal_graph_compute(
|
|
703
703
|
GGML_ASSERT(ne02 == 1);
|
704
704
|
GGML_ASSERT(ne12 == 1);
|
705
705
|
|
706
|
-
nth0 =
|
707
|
-
nth1 =
|
706
|
+
nth0 = 2;
|
707
|
+
nth1 = 32;
|
708
708
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
|
709
709
|
} break;
|
710
710
|
case GGML_TYPE_Q6_K:
|
@@ -712,8 +712,8 @@ void ggml_metal_graph_compute(
|
|
712
712
|
GGML_ASSERT(ne02 == 1);
|
713
713
|
GGML_ASSERT(ne12 == 1);
|
714
714
|
|
715
|
-
nth0 =
|
716
|
-
nth1 =
|
715
|
+
nth0 = 2;
|
716
|
+
nth1 = 32;
|
717
717
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
|
718
718
|
} break;
|
719
719
|
default:
|
@@ -739,20 +739,22 @@ void ggml_metal_graph_compute(
|
|
739
739
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
740
740
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
741
741
|
|
742
|
-
if (src0t == GGML_TYPE_Q4_0
|
743
|
-
|
742
|
+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
743
|
+
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
744
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
744
745
|
}
|
745
|
-
else if (src0t ==
|
746
|
-
|
747
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
746
|
+
else if (src0t == GGML_TYPE_Q3_K) {
|
747
|
+
#ifdef GGML_QKK_64
|
748
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
749
|
+
#else
|
750
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
751
|
+
#endif
|
752
|
+
}
|
753
|
+
else if (src0t == GGML_TYPE_Q5_K) {
|
754
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
748
755
|
}
|
749
|
-
else if (src0t ==
|
750
|
-
|
751
|
-
src0t == GGML_TYPE_Q4_K ||
|
752
|
-
src0t == GGML_TYPE_Q5_K ||
|
753
|
-
src0t == GGML_TYPE_Q6_K) {
|
754
|
-
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
|
755
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
756
|
+
else if (src0t == GGML_TYPE_Q6_K) {
|
757
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
756
758
|
} else {
|
757
759
|
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
|
758
760
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
@@ -796,7 +798,7 @@ void ggml_metal_graph_compute(
|
|
796
798
|
|
797
799
|
const float eps = 1e-6f;
|
798
800
|
|
799
|
-
const int nth =
|
801
|
+
const int nth = 512;
|
800
802
|
|
801
803
|
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
|
802
804
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
@@ -804,7 +806,7 @@ void ggml_metal_graph_compute(
|
|
804
806
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
805
807
|
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
806
808
|
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
807
|
-
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
|
809
|
+
[encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
|
808
810
|
|
809
811
|
const int64_t nrows = ggml_nrows(src0);
|
810
812
|
|
@@ -885,28 +887,35 @@ void ggml_metal_graph_compute(
|
|
885
887
|
|
886
888
|
const int n_past = ((int32_t *)(src1->data))[0];
|
887
889
|
|
890
|
+
float freq_base;
|
891
|
+
float freq_scale;
|
892
|
+
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
|
893
|
+
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
|
894
|
+
|
888
895
|
[encoder setComputePipelineState:ctx->pipeline_rope];
|
889
896
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
890
897
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
891
|
-
[encoder setBytes:&ne00
|
892
|
-
[encoder setBytes:&ne01
|
893
|
-
[encoder setBytes:&ne02
|
894
|
-
[encoder setBytes:&ne03
|
895
|
-
[encoder setBytes:&nb00
|
896
|
-
[encoder setBytes:&nb01
|
897
|
-
[encoder setBytes:&nb02
|
898
|
-
[encoder setBytes:&nb03
|
899
|
-
[encoder setBytes:&ne0
|
900
|
-
[encoder setBytes:&ne1
|
901
|
-
[encoder setBytes:&ne2
|
902
|
-
[encoder setBytes:&ne3
|
903
|
-
[encoder setBytes:&nb0
|
904
|
-
[encoder setBytes:&nb1
|
905
|
-
[encoder setBytes:&nb2
|
906
|
-
[encoder setBytes:&nb3
|
907
|
-
[encoder setBytes:&n_past
|
908
|
-
[encoder setBytes:&n_dims
|
909
|
-
[encoder setBytes:&mode
|
898
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
899
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
900
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
901
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
902
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
903
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
904
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
905
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
906
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
907
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
908
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
909
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
910
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
911
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
912
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
913
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
914
|
+
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
|
915
|
+
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
|
916
|
+
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
|
917
|
+
[encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
|
918
|
+
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
|
910
919
|
|
911
920
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
912
921
|
} break;
|