llama_cpp 0.3.8 → 0.4.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/README.md +1 -1
- data/examples/chat.rb +2 -4
- data/ext/llama_cpp/extconf.rb +2 -2
- data/ext/llama_cpp/llama_cpp.cpp +110 -117
- data/ext/llama_cpp/src/ggml-alloc.c +79 -65
- data/ext/llama_cpp/src/ggml-alloc.h +1 -1
- data/ext/llama_cpp/src/ggml-cuda.cu +330 -69
- data/ext/llama_cpp/src/ggml-cuda.h +13 -0
- data/ext/llama_cpp/src/ggml-metal.h +3 -0
- data/ext/llama_cpp/src/ggml-metal.m +102 -66
- data/ext/llama_cpp/src/ggml-metal.metal +113 -9
- data/ext/llama_cpp/src/ggml.c +2064 -233
- data/ext/llama_cpp/src/ggml.h +238 -13
- data/ext/llama_cpp/src/k_quants.c +110 -54
- data/ext/llama_cpp/src/llama.cpp +4520 -2978
- data/ext/llama_cpp/src/llama.h +133 -125
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +7 -8
- metadata +2 -2
data/ext/llama_cpp/src/ggml-cuda.h:

@@ -2,6 +2,14 @@

 #include "ggml.h"

+#ifdef GGML_USE_HIPBLAS
+#define GGML_CUDA_NAME "ROCm"
+#define GGML_CUBLAS_NAME "hipBLAS"
+#else
+#define GGML_CUDA_NAME "CUDA"
+#define GGML_CUBLAS_NAME "cuBLAS"
+#endif
+
 #ifdef __cplusplus
 extern "C" {
 #endif
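
With the new name macros, code shared between the CUDA and ROCm builds can report which backend it was compiled for. A minimal sketch of that use (the log message itself is illustrative, not from this package):

    #include <stdio.h>
    #include "ggml-cuda.h"

    static void print_backend(void) {
        // expands to "ROCm"/"hipBLAS" when built with GGML_USE_HIPBLAS,
        // and to "CUDA"/"cuBLAS" otherwise
        printf("offloading via " GGML_CUDA_NAME " (" GGML_CUBLAS_NAME ")\n");
    }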
@@ -16,9 +24,14 @@ GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const str
 GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
 GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
 GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);
+
 GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
 GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
 GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
+
+GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
+GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
+
 GGML_API void ggml_cuda_set_main_device(int main_device);
 GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
 GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);
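
The two new entry points split scratch-buffer assignment from allocation: judging by the names and the size_t offset parameter, tensors can first be marked without allocating, and later pinned to an offset once the layout is known. A hedged sketch of that two-phase pattern (gf is a struct ggml_cgraph *; the packing policy below is an assumption, not code from this package):

    // phase 1: mark the graph's tensors; nothing is allocated yet
    for (int i = 0; i < gf->n_nodes; ++i) {
        ggml_cuda_assign_buffers_no_alloc(gf->nodes[i]);
    }

    // phase 2: once sizes are known, pin each tensor to its scratch slot
    size_t offset = 0;
    for (int i = 0; i < gf->n_nodes; ++i) {
        ggml_cuda_assign_scratch_offset(gf->nodes[i], offset);
        offset += ggml_nbytes(gf->nodes[i]); // hypothetical packing, no alignment handling
    }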
data/ext/llama_cpp/src/ggml-metal.h:

@@ -38,6 +38,9 @@ struct ggml_metal_context;
 struct ggml_metal_context * ggml_metal_init(int n_cb);
 void ggml_metal_free(struct ggml_metal_context * ctx);

+void * ggml_metal_host_malloc(size_t n);
+void   ggml_metal_host_free (void * data);
+
 // set the number of command buffers to use
 void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);

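
These two functions expose page-aligned host allocation, which Metal can wrap in a buffer without copying. A minimal usage sketch based only on the signatures above:

    #include <string.h>
    #include "ggml-metal.h"

    size_t n   = 16 * 1024 * 1024;          // 16 MiB
    void * buf = ggml_metal_host_malloc(n); // page-aligned, or NULL on failure
    if (buf != NULL) {
        memset(buf, 0, n);
        // ... hand the buffer to ggml / the Metal backend ...
        ggml_metal_host_free(buf);
    }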
data/ext/llama_cpp/src/ggml-metal.m:

@@ -63,6 +63,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DECL_KERNEL(get_rows_q8_0);
     GGML_METAL_DECL_KERNEL(get_rows_q2_K);
     GGML_METAL_DECL_KERNEL(get_rows_q3_K);
     GGML_METAL_DECL_KERNEL(get_rows_q4_K);
@@ -73,6 +74,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
@@ -81,6 +83,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
@@ -167,7 +170,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #define GGML_METAL_ADD_KERNEL(name) \
     ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
     ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
-    fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name
+    fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+            (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
+            (int) ctx->pipeline_##name.threadExecutionWidth); \
     if (error) { \
         fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
         return NULL; \
@@ -186,6 +191,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_ADD_KERNEL(get_rows_f16);
     GGML_METAL_ADD_KERNEL(get_rows_q4_0);
     GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+    GGML_METAL_ADD_KERNEL(get_rows_q8_0);
     GGML_METAL_ADD_KERNEL(get_rows_q2_K);
     GGML_METAL_ADD_KERNEL(get_rows_q3_K);
     GGML_METAL_ADD_KERNEL(get_rows_q4_K);
@@ -196,6 +202,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
     GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
+    GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
     GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
     GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
     GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
@@ -203,6 +210,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
     GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
     GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+    GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
     GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
     GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
     GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
@@ -218,12 +226,12 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #undef GGML_METAL_ADD_KERNEL
     }

-    fprintf(stderr, "%s: recommendedMaxWorkingSetSize
-    fprintf(stderr, "%s: hasUnifiedMemory
+    fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+    fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
     if (ctx->device.maxTransferRate != 0) {
-        fprintf(stderr, "%s: maxTransferRate
+        fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
     } else {
-        fprintf(stderr, "%s: maxTransferRate
+        fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
     }

     return ctx;
@@ -237,6 +245,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     free(ctx);
 }

+void * ggml_metal_host_malloc(size_t n) {
+    void * data = NULL;
+    const int result = posix_memalign((void **) &data, getpagesize(), n);
+    if (result != 0) {
+        fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
+        return NULL;
+    }
+
+    return data;
+}
+
+void ggml_metal_host_free(void * data) {
+    free(data);
+}
+
 void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
     ctx->n_cb = n_cb;
 }
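
The implementation relies on posix_memalign, whose alignment argument must be a power of two and a multiple of sizeof(void *); the page size returned by getpagesize() satisfies both, and memory obtained this way is released with plain free(), which is why ggml_metal_host_free is a one-line wrapper. An equivalent standalone C sketch:

    #include <stdio.h>
    #include <stdlib.h>
    #include <unistd.h>

    static void * page_aligned_alloc(size_t n) {
        void * data = NULL;
        // alignment must be a power of two and a multiple of sizeof(void *);
        // one memory page qualifies on both counts
        if (posix_memalign(&data, (size_t) getpagesize(), n) != 0) {
            fprintf(stderr, "posix_memalign failed\n");
            return NULL;
        }
        return data; // release with free()
    }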
@@ -522,8 +545,8 @@ void ggml_metal_graph_compute(

         id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];

-        const int node_start =
-        const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+        const int node_start = (cb_idx + 0) * n_nodes_per_cb;
+        const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);

         for (int ind = node_start; ind < node_end; ++ind) {
             const int i = has_concur ? ctx->concur_list[ind] : ind;
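
The fix clamps node_end with MIN so that no command buffer walks past n_nodes when n_nodes_per_cb has been rounded up. A standalone illustration of the arithmetic (the sizes are arbitrary; with 5 nodes over 4 command buffers, cb 2 is clamped from 6 down to 5 and cb 3 gets an empty range):

    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    int main(void) {
        const int n_nodes = 5, n_cb = 4;
        const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; // 2
        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
            const int node_start = (cb_idx + 0) * n_nodes_per_cb;
            const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
            printf("cb %d: nodes [%d, %d)\n", cb_idx, node_start, node_end);
        }
        return 0; // prints [0,2) [2,4) [4,5) [6,5); the last range is empty
    }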
@@ -729,32 +752,32 @@ void ggml_metal_graph_compute(
 [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
 ne00%32 == 0 &&
 ne11 > 1) {
-
-
-
-
-
-
-
-
-
-
-
-    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-    [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
-    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
-    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
-    [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
-    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
-    [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
-    [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-    [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+    switch (src0->type) {
+        case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
+        case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
+        case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+        case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
+        case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
+        case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
+        case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
+        case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
+        case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
+        default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
     }
-
+    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+    [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
+    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
+    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+    [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
+    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
+    [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
+    [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+    [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+} else {
     int nth0 = 32;
     int nth1 = 1;

@@ -784,6 +807,15 @@ void ggml_metal_graph_compute(
         nth1 = 8;
         [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
     } break;
+case GGML_TYPE_Q8_0:
+    {
+        GGML_ASSERT(ne02 == 1);
+        GGML_ASSERT(ne12 == 1);
+
+        nth0 = 8;
+        nth1 = 8;
+        [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
+    } break;
 case GGML_TYPE_Q2_K:
     {
         GGML_ASSERT(ne02 == 1);
@@ -853,24 +885,24 @@ void ggml_metal_graph_compute(
 [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
 [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
 [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
-[encoder setBytes:&gqa
+[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];

-if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
     src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
-    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)
+    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 }
 else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
-    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-    [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
 }
 else if (src0t == GGML_TYPE_Q5_K) {
-    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)
+    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 }
 else if (src0t == GGML_TYPE_Q6_K) {
-    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 } else {
     [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -880,9 +912,10 @@ void ggml_metal_graph_compute(
 case GGML_OP_GET_ROWS:
     {
         switch (src0->type) {
-            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];
+            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
             case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
             case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+            case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
             case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
             case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
             case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
@@ -923,16 +956,17 @@ void ggml_metal_graph_compute(
 } break;
 case GGML_OP_NORM:
     {
-
+        float eps;
+        memcpy(&eps, dst->op_params, sizeof(float));

         const int nth = 256;

         [encoder setComputePipelineState:ctx->pipeline_norm];
-        [encoder setBuffer:id_src0 offset:offs_src0
-        [encoder setBuffer:id_dst offset:offs_dst
-        [encoder setBytes:&ne00
-        [encoder setBytes:&nb01
-        [encoder setBytes:&eps
+        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+        [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+        [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+        [encoder setBytes:&eps length:sizeof( float) atIndex:4];
         [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];

         const int64_t nrows = ggml_nrows(src0);
@@ -975,7 +1009,9 @@ void ggml_metal_graph_compute(
     [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
     [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
     [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
+
     const int nth = 32;
+
     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 } break;
 case GGML_OP_ROPE:
@@ -990,8 +1026,8 @@ void ggml_metal_graph_compute(
     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

     [encoder setComputePipelineState:ctx->pipeline_rope];
-    [encoder setBuffer:id_src0 offset:offs_src0
-    [encoder setBuffer:id_dst offset:offs_dst
+    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
     [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
     [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
     [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
@@ -1042,24 +1078,24 @@ void ggml_metal_graph_compute(
     default: GGML_ASSERT(false && "not implemented");
 }

-[encoder setBuffer:id_src0 offset:offs_src0
-[encoder setBuffer:id_dst offset:offs_dst
-[encoder setBytes:&ne00
-[encoder setBytes:&ne01
-[encoder setBytes:&ne02
-[encoder setBytes:&ne03
-[encoder setBytes:&nb00
-[encoder setBytes:&nb01
-[encoder setBytes:&nb02
-[encoder setBytes:&nb03
-[encoder setBytes:&ne0
-[encoder setBytes:&ne1
-[encoder setBytes:&ne2
-[encoder setBytes:&ne3
-[encoder setBytes:&nb0
-[encoder setBytes:&nb1
-[encoder setBytes:&nb2
-[encoder setBytes:&nb3
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];

 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 } break;
data/ext/llama_cpp/src/ggml-metal.metal:

@@ -18,6 +18,12 @@ typedef struct {
     uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;

+#define QK8_0 32
+typedef struct {
+    half   d;         // delta
+    int8_t qs[QK8_0]; // quants
+} block_q8_0;
+
 kernel void kernel_add(
     device const float * src0,
     device const float * src1,
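
block_q8_0 matches the CPU-side Q8_0 layout in ggml: one f16 scale plus 32 signed 8-bit quants, i.e. 34 bytes per 32 weights, with dequantization simply x[i] = d * qs[i]. A reference C version of that mapping (names with the _ref suffix are mine; the scale is widened to float for clarity):

    #define QK8_0 32

    typedef struct {
        float  d;          // delta (stored as half in the Metal struct above)
        int8_t qs[QK8_0];  // quants
    } block_q8_0_ref;

    // y[0..k) receives the dequantized weights; k must be a multiple of QK8_0
    static void dequantize_row_q8_0_ref(const block_q8_0_ref * x, float * y, int k) {
        for (int ib = 0; ib < k / QK8_0; ++ib) {
            for (int i = 0; i < QK8_0; ++i) {
                y[ib * QK8_0 + i] = x[ib].d * x[ib].qs[i];
            }
        }
    }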
@@ -87,7 +93,12 @@ kernel void kernel_gelu(
     device float * dst,
     uint tpig[[thread_position_in_grid]]) {
     float x = src0[tpig];
-
+
+    // BEWARE !!!
+    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
+    // This was observed with Falcon 7B and 40B models
+    //
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 }

 kernel void kernel_soft_max(
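
The kernel computes the usual tanh approximation of GELU, 0.5*x*(1 + tanh(sqrt(2/pi)*x*(1 + 0.044715*x*x))); the change only forces Metal's precise::tanh, because the fast-math variant produced NaNs with Falcon models. A CPU-side reference in C, with the constants as defined in ggml.c:

    #include <math.h>

    #define GELU_COEF_A    0.044715f
    #define SQRT_2_OVER_PI 0.79788456080286535587989211986876f

    static float gelu_ref(float x) {
        return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
    }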
@@ -352,7 +363,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
     const int first_row = (r0 * nsg + sgitg) * nr;
     const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
     device const block_q_type * x = (device const block_q_type *) src0 + offset0;
-    device const float
+    device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
     float yl[16]; // src1 vector cache
     float sumf[nr]={0.f};

@@ -424,6 +435,68 @@ kernel void kernel_mul_mat_q4_1_f32(
     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }

+kernel void kernel_mul_mat_q8_0_f32(
+    device const void * src0,
+    device const float * src1,
+    device float * dst,
+    constant int64_t & ne00,
+    constant int64_t & ne01[[buffer(4)]],
+    constant int64_t & ne02[[buffer(5)]],
+    constant int64_t & ne10[[buffer(9)]],
+    constant int64_t & ne12[[buffer(11)]],
+    constant int64_t & ne0[[buffer(15)]],
+    constant int64_t & ne1[[buffer(16)]],
+    constant uint & gqa[[buffer(17)]],
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint tiisg[[thread_index_in_simdgroup]],
+    uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    const int nr = N_DST;
+    const int nsg = N_SIMDGROUP;
+    const int nw = N_SIMDWIDTH;
+
+    const int nb = ne00/QK8_0;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int first_row = (r0 * nsg + sgitg) * nr;
+    const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+    device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
+    device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[16];
+    float sumf[nr]={0.f};
+
+    const int ix = tiisg/2;
+    const int il = tiisg%2;
+
+    device const float * yb = y + ix * QK8_0 + 16*il;
+
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += nw/2) {
+        for (int i = 0; i < 16; ++i) {
+            yl[i] = yb[i];
+        }
+
+        for (int row = 0; row < nr; row++) {
+            device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+            float sumq = 0.f;
+            for (int iq = 0; iq < 16; ++iq) {
+                sumq += qs[iq] * yl[iq];
+            }
+            sumf[row] += sumq*x[ib+row*nb].d;
+        }
+
+        yb += QK8_0 * 16;
+    }
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+        }
+    }
+}
+
 kernel void kernel_mul_mat_f16_f32(
     device const char * src0,
     device const char * src1,
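
Per output row, kernel_mul_mat_q8_0_f32 computes the sum over blocks of d_b * (sum_i qs_i * y_i), with each SIMD lane owning a 16-element half block and simd_sum folding the partial results. A scalar C sketch of the same reduction, reusing the hypothetical block_q8_0_ref layout shown earlier:

    // dot product of one quantized row x (ne00/QK8_0 blocks) with a float vector y
    static float vec_dot_q8_0_ref(const block_q8_0_ref * x, const float * y, int ne00) {
        float sumf = 0.0f;
        for (int ib = 0; ib < ne00 / QK8_0; ++ib) {
            float sumq = 0.0f;
            for (int i = 0; i < QK8_0; ++i) {
                sumq += (float) x[ib].qs[i] * y[ib * QK8_0 + i];
            }
            sumf += sumq * x[ib].d;
        }
        return sumf;
    }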
@@ -475,7 +548,6 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }

-
 kernel void kernel_alibi_f32(
     device const float * src0,
     device float * dst,
@@ -571,7 +643,25 @@ kernel void kernel_rope(
             dst_data[1] = x0*sin_theta + x1*cos_theta;
         }
     } else {
-
+        for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+            for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                const float cos_theta = cos(theta);
+                const float sin_theta = sin(theta);
+
+                theta *= theta_scale;
+
+                const int64_t i0 = ib*n_dims + ic/2;
+
+                device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                const float x0 = src[0];
+                const float x1 = src[n_dims/2];
+
+                dst_data[0] = x0*cos_theta - x1*sin_theta;
+                dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+            }
+        }
     }
 }

@@ -1598,12 +1688,12 @@ template <typename type4x4>
 void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
     const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = il ? (-8.h * 16.h) : -8.h;
+    const half m = il ? ( -8.h * 16.h) : -8.h;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
     const ushort mask1 = il ? 0xF000 : 0x0F00;

     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]
+        reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
         reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
     }
 }
@@ -1617,11 +1707,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
     const ushort mask1 = il ? 0xF000 : 0x0F00;

     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]
+        reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
         reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
     }
 }

+template <typename type4x4>
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+    device const int8_t * qs = ((device const int8_t *)xb->qs);
+    const half d = xb->d;
+
+    for (int i=0;i<16;i++) {
+        reg[i/4][i%4] = (qs[i + 16*il] * d);
+    }
+}
+
 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
     const half d = xb->d;
@@ -1850,6 +1950,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
     //load data and store to threadgroup memory
     half4x4 temp_a;
     dequantize_func(x, il, temp_a);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
     #pragma unroll(16)
     for (int i = 0; i < 16; i++) {
         *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
@@ -1895,6 +1996,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
     }
 } else {
     // block is smaller than 64x32, we should avoid writing data outside of the matrix
+    threadgroup_barrier(mem_flags::mem_threadgroup);
     threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
         + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
     for (int i = 0; i < 8; i++) {
@@ -1922,9 +2024,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
 typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
                           constant uint64_t &, constant uint64_t &, uint, uint, uint);

-template [[host_name("kernel_get_rows_f16")]]
+template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
@@ -1935,9 +2038,10 @@ typedef void (mat_mm_t)(device const uchar *, device const float *, device float
                         constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
                         constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);

-template [[host_name("kernel_mul_mm_f16_f32")]]
+template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
 template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;