llama_cpp 0.3.8 → 0.4.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/README.md +1 -1
- data/examples/chat.rb +2 -4
- data/ext/llama_cpp/extconf.rb +2 -2
- data/ext/llama_cpp/llama_cpp.cpp +110 -117
- data/ext/llama_cpp/src/ggml-alloc.c +79 -65
- data/ext/llama_cpp/src/ggml-alloc.h +1 -1
- data/ext/llama_cpp/src/ggml-cuda.cu +330 -69
- data/ext/llama_cpp/src/ggml-cuda.h +13 -0
- data/ext/llama_cpp/src/ggml-metal.h +3 -0
- data/ext/llama_cpp/src/ggml-metal.m +102 -66
- data/ext/llama_cpp/src/ggml-metal.metal +113 -9
- data/ext/llama_cpp/src/ggml.c +2064 -233
- data/ext/llama_cpp/src/ggml.h +238 -13
- data/ext/llama_cpp/src/k_quants.c +110 -54
- data/ext/llama_cpp/src/llama.cpp +4520 -2978
- data/ext/llama_cpp/src/llama.h +133 -125
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +7 -8
- metadata +2 -2
@@ -2,6 +2,14 @@
|
|
2
2
|
|
3
3
|
#include "ggml.h"
|
4
4
|
|
5
|
+
#ifdef GGML_USE_HIPBLAS
|
6
|
+
#define GGML_CUDA_NAME "ROCm"
|
7
|
+
#define GGML_CUBLAS_NAME "hipBLAS"
|
8
|
+
#else
|
9
|
+
#define GGML_CUDA_NAME "CUDA"
|
10
|
+
#define GGML_CUBLAS_NAME "cuBLAS"
|
11
|
+
#endif
|
12
|
+
|
5
13
|
#ifdef __cplusplus
|
6
14
|
extern "C" {
|
7
15
|
#endif
|
@@ -16,9 +24,14 @@ GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const str
|
|
16
24
|
GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
|
17
25
|
GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
|
18
26
|
GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);
|
27
|
+
|
19
28
|
GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
|
20
29
|
GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
|
21
30
|
GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
|
31
|
+
|
32
|
+
GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
|
33
|
+
GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
|
34
|
+
|
22
35
|
GGML_API void ggml_cuda_set_main_device(int main_device);
|
23
36
|
GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
|
24
37
|
GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);
|
@@ -38,6 +38,9 @@ struct ggml_metal_context;
|
|
38
38
|
struct ggml_metal_context * ggml_metal_init(int n_cb);
|
39
39
|
void ggml_metal_free(struct ggml_metal_context * ctx);
|
40
40
|
|
41
|
+
void * ggml_metal_host_malloc(size_t n);
|
42
|
+
void ggml_metal_host_free (void * data);
|
43
|
+
|
41
44
|
// set the number of command buffers to use
|
42
45
|
void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);
|
43
46
|
|
@@ -63,6 +63,7 @@ struct ggml_metal_context {
|
|
63
63
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
64
64
|
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
65
65
|
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
66
|
+
GGML_METAL_DECL_KERNEL(get_rows_q8_0);
|
66
67
|
GGML_METAL_DECL_KERNEL(get_rows_q2_K);
|
67
68
|
GGML_METAL_DECL_KERNEL(get_rows_q3_K);
|
68
69
|
GGML_METAL_DECL_KERNEL(get_rows_q4_K);
|
@@ -73,6 +74,7 @@ struct ggml_metal_context {
|
|
73
74
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
74
75
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
75
76
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
77
|
+
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
|
76
78
|
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
|
77
79
|
GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
|
78
80
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
|
@@ -81,6 +83,7 @@ struct ggml_metal_context {
|
|
81
83
|
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
82
84
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
83
85
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
|
86
|
+
GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
|
84
87
|
GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
|
85
88
|
GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
|
86
89
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
|
@@ -167,7 +170,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
167
170
|
#define GGML_METAL_ADD_KERNEL(name) \
|
168
171
|
ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
|
169
172
|
ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
|
170
|
-
fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name
|
173
|
+
fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
|
174
|
+
(int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
|
175
|
+
(int) ctx->pipeline_##name.threadExecutionWidth); \
|
171
176
|
if (error) { \
|
172
177
|
fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
173
178
|
return NULL; \
|
@@ -186,6 +191,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
186
191
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
187
192
|
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
188
193
|
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
194
|
+
GGML_METAL_ADD_KERNEL(get_rows_q8_0);
|
189
195
|
GGML_METAL_ADD_KERNEL(get_rows_q2_K);
|
190
196
|
GGML_METAL_ADD_KERNEL(get_rows_q3_K);
|
191
197
|
GGML_METAL_ADD_KERNEL(get_rows_q4_K);
|
@@ -196,6 +202,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
196
202
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
197
203
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
198
204
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
205
|
+
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
|
199
206
|
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
|
200
207
|
GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
|
201
208
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
|
@@ -203,6 +210,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
203
210
|
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
|
204
211
|
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
205
212
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
|
213
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
|
206
214
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
|
207
215
|
GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
|
208
216
|
GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
|
@@ -218,12 +226,12 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
218
226
|
#undef GGML_METAL_ADD_KERNEL
|
219
227
|
}
|
220
228
|
|
221
|
-
fprintf(stderr, "%s: recommendedMaxWorkingSetSize
|
222
|
-
fprintf(stderr, "%s: hasUnifiedMemory
|
229
|
+
fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
230
|
+
fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
223
231
|
if (ctx->device.maxTransferRate != 0) {
|
224
|
-
fprintf(stderr, "%s: maxTransferRate
|
232
|
+
fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
225
233
|
} else {
|
226
|
-
fprintf(stderr, "%s: maxTransferRate
|
234
|
+
fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
|
227
235
|
}
|
228
236
|
|
229
237
|
return ctx;
|
@@ -237,6 +245,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
237
245
|
free(ctx);
|
238
246
|
}
|
239
247
|
|
248
|
+
void * ggml_metal_host_malloc(size_t n) {
|
249
|
+
void * data = NULL;
|
250
|
+
const int result = posix_memalign((void **) &data, getpagesize(), n);
|
251
|
+
if (result != 0) {
|
252
|
+
fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
|
253
|
+
return NULL;
|
254
|
+
}
|
255
|
+
|
256
|
+
return data;
|
257
|
+
}
|
258
|
+
|
259
|
+
void ggml_metal_host_free(void * data) {
|
260
|
+
free(data);
|
261
|
+
}
|
262
|
+
|
240
263
|
void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
|
241
264
|
ctx->n_cb = n_cb;
|
242
265
|
}
|
@@ -522,8 +545,8 @@ void ggml_metal_graph_compute(
|
|
522
545
|
|
523
546
|
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
524
547
|
|
525
|
-
const int node_start =
|
526
|
-
const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
|
548
|
+
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
|
549
|
+
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
|
527
550
|
|
528
551
|
for (int ind = node_start; ind < node_end; ++ind) {
|
529
552
|
const int i = has_concur ? ctx->concur_list[ind] : ind;
|
@@ -729,32 +752,32 @@ void ggml_metal_graph_compute(
|
|
729
752
|
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
730
753
|
ne00%32 == 0 &&
|
731
754
|
ne11 > 1) {
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
744
|
-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
745
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
746
|
-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
747
|
-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
748
|
-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
749
|
-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
750
|
-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
751
|
-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
|
752
|
-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
|
753
|
-
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
|
754
|
-
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
755
|
-
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
755
|
+
switch (src0->type) {
|
756
|
+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
757
|
+
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
|
758
|
+
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
|
759
|
+
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
|
760
|
+
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
|
761
|
+
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
|
762
|
+
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
|
763
|
+
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
|
764
|
+
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
|
765
|
+
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
756
766
|
}
|
757
|
-
|
767
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
768
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
769
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
770
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
771
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
772
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
773
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
774
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
775
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
|
776
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
|
777
|
+
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
|
778
|
+
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
779
|
+
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
780
|
+
} else {
|
758
781
|
int nth0 = 32;
|
759
782
|
int nth1 = 1;
|
760
783
|
|
@@ -784,6 +807,15 @@ void ggml_metal_graph_compute(
|
|
784
807
|
nth1 = 8;
|
785
808
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
|
786
809
|
} break;
|
810
|
+
case GGML_TYPE_Q8_0:
|
811
|
+
{
|
812
|
+
GGML_ASSERT(ne02 == 1);
|
813
|
+
GGML_ASSERT(ne12 == 1);
|
814
|
+
|
815
|
+
nth0 = 8;
|
816
|
+
nth1 = 8;
|
817
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
|
818
|
+
} break;
|
787
819
|
case GGML_TYPE_Q2_K:
|
788
820
|
{
|
789
821
|
GGML_ASSERT(ne02 == 1);
|
@@ -853,24 +885,24 @@ void ggml_metal_graph_compute(
|
|
853
885
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
854
886
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
855
887
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
856
|
-
[encoder setBytes:&gqa
|
888
|
+
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
|
857
889
|
|
858
|
-
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
890
|
+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
|
859
891
|
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
860
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)
|
892
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
861
893
|
}
|
862
894
|
else if (src0t == GGML_TYPE_Q3_K) {
|
863
895
|
#ifdef GGML_QKK_64
|
864
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
896
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
865
897
|
#else
|
866
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
898
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
867
899
|
#endif
|
868
900
|
}
|
869
901
|
else if (src0t == GGML_TYPE_Q5_K) {
|
870
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)
|
902
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
871
903
|
}
|
872
904
|
else if (src0t == GGML_TYPE_Q6_K) {
|
873
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
905
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
874
906
|
} else {
|
875
907
|
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
|
876
908
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
@@ -880,9 +912,10 @@ void ggml_metal_graph_compute(
|
|
880
912
|
case GGML_OP_GET_ROWS:
|
881
913
|
{
|
882
914
|
switch (src0->type) {
|
883
|
-
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];
|
915
|
+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
884
916
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
885
917
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
918
|
+
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
|
886
919
|
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
|
887
920
|
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
|
888
921
|
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
|
@@ -923,16 +956,17 @@ void ggml_metal_graph_compute(
|
|
923
956
|
} break;
|
924
957
|
case GGML_OP_NORM:
|
925
958
|
{
|
926
|
-
|
959
|
+
float eps;
|
960
|
+
memcpy(&eps, dst->op_params, sizeof(float));
|
927
961
|
|
928
962
|
const int nth = 256;
|
929
963
|
|
930
964
|
[encoder setComputePipelineState:ctx->pipeline_norm];
|
931
|
-
[encoder setBuffer:id_src0 offset:offs_src0
|
932
|
-
[encoder setBuffer:id_dst offset:offs_dst
|
933
|
-
[encoder setBytes:&ne00
|
934
|
-
[encoder setBytes:&nb01
|
935
|
-
[encoder setBytes:&eps
|
965
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
966
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
967
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
968
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
969
|
+
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
936
970
|
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
|
937
971
|
|
938
972
|
const int64_t nrows = ggml_nrows(src0);
|
@@ -975,7 +1009,9 @@ void ggml_metal_graph_compute(
|
|
975
1009
|
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
976
1010
|
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
977
1011
|
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
|
1012
|
+
|
978
1013
|
const int nth = 32;
|
1014
|
+
|
979
1015
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
980
1016
|
} break;
|
981
1017
|
case GGML_OP_ROPE:
|
@@ -990,8 +1026,8 @@ void ggml_metal_graph_compute(
|
|
990
1026
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
991
1027
|
|
992
1028
|
[encoder setComputePipelineState:ctx->pipeline_rope];
|
993
|
-
[encoder setBuffer:id_src0 offset:offs_src0
|
994
|
-
[encoder setBuffer:id_dst offset:offs_dst
|
1029
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1030
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
995
1031
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
996
1032
|
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
997
1033
|
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
@@ -1042,24 +1078,24 @@ void ggml_metal_graph_compute(
|
|
1042
1078
|
default: GGML_ASSERT(false && "not implemented");
|
1043
1079
|
}
|
1044
1080
|
|
1045
|
-
[encoder setBuffer:id_src0 offset:offs_src0
|
1046
|
-
[encoder setBuffer:id_dst offset:offs_dst
|
1047
|
-
[encoder setBytes:&ne00
|
1048
|
-
[encoder setBytes:&ne01
|
1049
|
-
[encoder setBytes:&ne02
|
1050
|
-
[encoder setBytes:&ne03
|
1051
|
-
[encoder setBytes:&nb00
|
1052
|
-
[encoder setBytes:&nb01
|
1053
|
-
[encoder setBytes:&nb02
|
1054
|
-
[encoder setBytes:&nb03
|
1055
|
-
[encoder setBytes:&ne0
|
1056
|
-
[encoder setBytes:&ne1
|
1057
|
-
[encoder setBytes:&ne2
|
1058
|
-
[encoder setBytes:&ne3
|
1059
|
-
[encoder setBytes:&nb0
|
1060
|
-
[encoder setBytes:&nb1
|
1061
|
-
[encoder setBytes:&nb2
|
1062
|
-
[encoder setBytes:&nb3
|
1081
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1082
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1083
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
1084
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
1085
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
1086
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
1087
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
1088
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
1089
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
1090
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
1091
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
1092
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
1093
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
1094
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
1095
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
1096
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
1097
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
1098
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
1063
1099
|
|
1064
1100
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1065
1101
|
} break;
|
@@ -18,6 +18,12 @@ typedef struct {
|
|
18
18
|
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
19
19
|
} block_q4_1;
|
20
20
|
|
21
|
+
#define QK8_0 32
|
22
|
+
typedef struct {
|
23
|
+
half d; // delta
|
24
|
+
int8_t qs[QK8_0]; // quants
|
25
|
+
} block_q8_0;
|
26
|
+
|
21
27
|
kernel void kernel_add(
|
22
28
|
device const float * src0,
|
23
29
|
device const float * src1,
|
@@ -87,7 +93,12 @@ kernel void kernel_gelu(
|
|
87
93
|
device float * dst,
|
88
94
|
uint tpig[[thread_position_in_grid]]) {
|
89
95
|
float x = src0[tpig];
|
90
|
-
|
96
|
+
|
97
|
+
// BEWARE !!!
|
98
|
+
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
99
|
+
// This was observed with Falcon 7B and 40B models
|
100
|
+
//
|
101
|
+
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
91
102
|
}
|
92
103
|
|
93
104
|
kernel void kernel_soft_max(
|
@@ -352,7 +363,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
|
352
363
|
const int first_row = (r0 * nsg + sgitg) * nr;
|
353
364
|
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
|
354
365
|
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
355
|
-
device const float
|
366
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
356
367
|
float yl[16]; // src1 vector cache
|
357
368
|
float sumf[nr]={0.f};
|
358
369
|
|
@@ -424,6 +435,68 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|
424
435
|
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
425
436
|
}
|
426
437
|
|
438
|
+
kernel void kernel_mul_mat_q8_0_f32(
|
439
|
+
device const void * src0,
|
440
|
+
device const float * src1,
|
441
|
+
device float * dst,
|
442
|
+
constant int64_t & ne00,
|
443
|
+
constant int64_t & ne01[[buffer(4)]],
|
444
|
+
constant int64_t & ne02[[buffer(5)]],
|
445
|
+
constant int64_t & ne10[[buffer(9)]],
|
446
|
+
constant int64_t & ne12[[buffer(11)]],
|
447
|
+
constant int64_t & ne0[[buffer(15)]],
|
448
|
+
constant int64_t & ne1[[buffer(16)]],
|
449
|
+
constant uint & gqa[[buffer(17)]],
|
450
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
451
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
452
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
453
|
+
const int nr = N_DST;
|
454
|
+
const int nsg = N_SIMDGROUP;
|
455
|
+
const int nw = N_SIMDWIDTH;
|
456
|
+
|
457
|
+
const int nb = ne00/QK8_0;
|
458
|
+
const int r0 = tgpig.x;
|
459
|
+
const int r1 = tgpig.y;
|
460
|
+
const int im = tgpig.z;
|
461
|
+
const int first_row = (r0 * nsg + sgitg) * nr;
|
462
|
+
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
|
463
|
+
device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
|
464
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
465
|
+
|
466
|
+
float yl[16];
|
467
|
+
float sumf[nr]={0.f};
|
468
|
+
|
469
|
+
const int ix = tiisg/2;
|
470
|
+
const int il = tiisg%2;
|
471
|
+
|
472
|
+
device const float * yb = y + ix * QK8_0 + 16*il;
|
473
|
+
|
474
|
+
// each thread in a SIMD group deals with half a block.
|
475
|
+
for (int ib = ix; ib < nb; ib += nw/2) {
|
476
|
+
for (int i = 0; i < 16; ++i) {
|
477
|
+
yl[i] = yb[i];
|
478
|
+
}
|
479
|
+
|
480
|
+
for (int row = 0; row < nr; row++) {
|
481
|
+
device const int8_t * qs = x[ib+row*nb].qs + 16*il;
|
482
|
+
float sumq = 0.f;
|
483
|
+
for (int iq = 0; iq < 16; ++iq) {
|
484
|
+
sumq += qs[iq] * yl[iq];
|
485
|
+
}
|
486
|
+
sumf[row] += sumq*x[ib+row*nb].d;
|
487
|
+
}
|
488
|
+
|
489
|
+
yb += QK8_0 * 16;
|
490
|
+
}
|
491
|
+
|
492
|
+
for (int row = 0; row < nr; ++row) {
|
493
|
+
const float tot = simd_sum(sumf[row]);
|
494
|
+
if (tiisg == 0 && first_row + row < ne01) {
|
495
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
496
|
+
}
|
497
|
+
}
|
498
|
+
}
|
499
|
+
|
427
500
|
kernel void kernel_mul_mat_f16_f32(
|
428
501
|
device const char * src0,
|
429
502
|
device const char * src1,
|
@@ -475,7 +548,6 @@ kernel void kernel_mul_mat_f16_f32(
|
|
475
548
|
}
|
476
549
|
}
|
477
550
|
|
478
|
-
|
479
551
|
kernel void kernel_alibi_f32(
|
480
552
|
device const float * src0,
|
481
553
|
device float * dst,
|
@@ -571,7 +643,25 @@ kernel void kernel_rope(
|
|
571
643
|
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
572
644
|
}
|
573
645
|
} else {
|
574
|
-
|
646
|
+
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
647
|
+
for (int64_t ic = 0; ic < n_dims; ic += 2) {
|
648
|
+
const float cos_theta = cos(theta);
|
649
|
+
const float sin_theta = sin(theta);
|
650
|
+
|
651
|
+
theta *= theta_scale;
|
652
|
+
|
653
|
+
const int64_t i0 = ib*n_dims + ic/2;
|
654
|
+
|
655
|
+
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
656
|
+
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
657
|
+
|
658
|
+
const float x0 = src[0];
|
659
|
+
const float x1 = src[n_dims/2];
|
660
|
+
|
661
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
662
|
+
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
663
|
+
}
|
664
|
+
}
|
575
665
|
}
|
576
666
|
}
|
577
667
|
|
@@ -1598,12 +1688,12 @@ template <typename type4x4>
|
|
1598
1688
|
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
1599
1689
|
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
1600
1690
|
const half d = il ? (xb->d / 16.h) : xb->d;
|
1601
|
-
const half m = il ? (-8.h * 16.h) : -8.h;
|
1691
|
+
const half m = il ? ( -8.h * 16.h) : -8.h;
|
1602
1692
|
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
1603
1693
|
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
1604
1694
|
|
1605
1695
|
for (int i=0;i<8;i++) {
|
1606
|
-
reg[i/2][2*(i%2)]
|
1696
|
+
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
|
1607
1697
|
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
|
1608
1698
|
}
|
1609
1699
|
}
|
@@ -1617,11 +1707,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
|
|
1617
1707
|
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
1618
1708
|
|
1619
1709
|
for (int i=0;i<8;i++) {
|
1620
|
-
reg[i/2][2*(i%2)]
|
1710
|
+
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
|
1621
1711
|
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
|
1622
1712
|
}
|
1623
1713
|
}
|
1624
1714
|
|
1715
|
+
template <typename type4x4>
|
1716
|
+
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
1717
|
+
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
1718
|
+
const half d = xb->d;
|
1719
|
+
|
1720
|
+
for (int i=0;i<16;i++) {
|
1721
|
+
reg[i/4][i%4] = (qs[i + 16*il] * d);
|
1722
|
+
}
|
1723
|
+
}
|
1724
|
+
|
1625
1725
|
template <typename type4x4>
|
1626
1726
|
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
1627
1727
|
const half d = xb->d;
|
@@ -1850,6 +1950,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
1850
1950
|
//load data and store to threadgroup memory
|
1851
1951
|
half4x4 temp_a;
|
1852
1952
|
dequantize_func(x, il, temp_a);
|
1953
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1853
1954
|
#pragma unroll(16)
|
1854
1955
|
for (int i = 0; i < 16; i++) {
|
1855
1956
|
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
@@ -1895,6 +1996,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
1895
1996
|
}
|
1896
1997
|
} else {
|
1897
1998
|
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
1999
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
1898
2000
|
threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
|
1899
2001
|
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
1900
2002
|
for (int i = 0; i < 8; i++) {
|
@@ -1922,9 +2024,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
1922
2024
|
typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
|
1923
2025
|
constant uint64_t &, constant uint64_t &, uint, uint, uint);
|
1924
2026
|
|
1925
|
-
template [[host_name("kernel_get_rows_f16")]]
|
2027
|
+
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
1926
2028
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
1927
2029
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
2030
|
+
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
|
1928
2031
|
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
|
1929
2032
|
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
|
1930
2033
|
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
|
@@ -1935,9 +2038,10 @@ typedef void (mat_mm_t)(device const uchar *, device const float *, device float
|
|
1935
2038
|
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
|
1936
2039
|
constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
|
1937
2040
|
|
1938
|
-
template [[host_name("kernel_mul_mm_f16_f32")]]
|
2041
|
+
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
1939
2042
|
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
1940
2043
|
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
2044
|
+
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
1941
2045
|
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
1942
2046
|
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
1943
2047
|
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|