llama_cpp 0.3.8 → 0.4.0

--- ggml-cuda.h
+++ ggml-cuda.h
@@ -2,6 +2,14 @@

  #include "ggml.h"

+ #ifdef GGML_USE_HIPBLAS
+ #define GGML_CUDA_NAME "ROCm"
+ #define GGML_CUBLAS_NAME "hipBLAS"
+ #else
+ #define GGML_CUDA_NAME "CUDA"
+ #define GGML_CUBLAS_NAME "cuBLAS"
+ #endif
+
  #ifdef __cplusplus
  extern "C" {
  #endif
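
These name macros appear intended for backend-agnostic log output. A minimal C sketch of how a caller might use them (the banner function and message wording are ours, not the library's):

    #include <stdio.h>
    #include "ggml-cuda.h"

    static void print_backend_banner(void) {
        // GGML_CUDA_NAME/GGML_CUBLAS_NAME expand to "ROCm"/"hipBLAS" when
        // built with GGML_USE_HIPBLAS, and to "CUDA"/"cuBLAS" otherwise.
        fprintf(stderr, "offloading via %s (%s)\n", GGML_CUDA_NAME, GGML_CUBLAS_NAME);
    }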
@@ -16,9 +24,14 @@ GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const str
  GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
  GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
  GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);
+
  GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
  GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
  GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
+
+ GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
+ GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
+
  GGML_API void ggml_cuda_set_main_device(int main_device);
  GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
  GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);
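
The two new entry points split buffer assignment from allocation: a tensor can be marked as scratch-resident while the graph is still being planned, and the concrete offset bound later. A hedged sketch of that call pattern, assuming the offset comes from a caller-side allocator (the helper name is ours):

    #include "ggml.h"
    #include "ggml-cuda.h"

    // Sketch only, not the library's own allocator logic: mark a node for the
    // scratch buffer without allocating, then commit the planned offset.
    static void plan_node_on_gpu(struct ggml_tensor * node, size_t planned_offset) {
        ggml_cuda_assign_buffers_no_alloc(node);               // no device memory touched yet
        ggml_cuda_assign_scratch_offset(node, planned_offset); // bind the chosen offset
    }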
--- ggml-metal.h
+++ ggml-metal.h
@@ -38,6 +38,9 @@ struct ggml_metal_context;
  struct ggml_metal_context * ggml_metal_init(int n_cb);
  void ggml_metal_free(struct ggml_metal_context * ctx);

+ void * ggml_metal_host_malloc(size_t n);
+ void ggml_metal_host_free (void * data);
+
  // set the number of command buffers to use
  void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);

--- ggml-metal.m
+++ ggml-metal.m
@@ -63,6 +63,7 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(get_rows_f16);
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+ GGML_METAL_DECL_KERNEL(get_rows_q8_0);
  GGML_METAL_DECL_KERNEL(get_rows_q2_K);
  GGML_METAL_DECL_KERNEL(get_rows_q3_K);
  GGML_METAL_DECL_KERNEL(get_rows_q4_K);
@@ -73,6 +74,7 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
@@ -81,6 +83,7 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
@@ -167,7 +170,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  #define GGML_METAL_ADD_KERNEL(name) \
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
  ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
- fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \
+ fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+ (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
+ (int) ctx->pipeline_##name.threadExecutionWidth); \
  if (error) { \
  fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
  return NULL; \
@@ -186,6 +191,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(get_rows_f16);
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+ GGML_METAL_ADD_KERNEL(get_rows_q8_0);
  GGML_METAL_ADD_KERNEL(get_rows_q2_K);
  GGML_METAL_ADD_KERNEL(get_rows_q3_K);
  GGML_METAL_ADD_KERNEL(get_rows_q4_K);
@@ -196,6 +202,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
@@ -203,6 +210,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
@@ -218,12 +226,12 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  #undef GGML_METAL_ADD_KERNEL
  }

- fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
- fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+ fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
  if (ctx->device.maxTransferRate != 0) {
- fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
  } else {
- fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
+ fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
  }

  return ctx;
@@ -237,6 +245,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  free(ctx);
  }

+ void * ggml_metal_host_malloc(size_t n) {
+ void * data = NULL;
+ const int result = posix_memalign((void **) &data, getpagesize(), n);
+ if (result != 0) {
+ fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
+ return NULL;
+ }
+
+ return data;
+ }
+
+ void ggml_metal_host_free(void * data) {
+ free(data);
+ }
+
  void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
  ctx->n_cb = n_cb;
  }
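
Because the allocation is page-aligned (posix_memalign with getpagesize()), the returned buffer is suitable for mapping into a shared Metal buffer without copying. A small usage sketch (the wrapper name is ours; error handling kept minimal):

    #include <string.h>
    #include "ggml-metal.h"

    static void * alloc_shared(size_t size) {
        void * buf = ggml_metal_host_malloc(size); // page-aligned, NULL on failure
        if (buf != NULL) {
            memset(buf, 0, size); // contents are uninitialized on return
        }
        return buf; // release later with ggml_metal_host_free(buf)
    }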
@@ -522,8 +545,8 @@ void ggml_metal_graph_compute(

  id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];

- const int node_start = (cb_idx + 0) * n_nodes_per_cb;
- const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
+ const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);

  for (int ind = node_start; ind < node_end; ++ind) {
  const int i = has_concur ? ctx->concur_list[ind] : ind;
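
The added MIN matters when there are more command buffers than nodes: with the per-buffer count rounded up, a buffer's nominal end index can run past n_nodes, and clamping turns that into an empty range instead of an out-of-bounds node read. A host-side C sketch of the partition arithmetic, assuming the rounded-up count used by the surrounding code:

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    // Illustrative only: split n_nodes across n_cb command buffers.
    static void partition_nodes(int n_nodes, int n_cb) {
        const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; // rounded up
        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
            const int node_start = (cb_idx + 0) * n_nodes_per_cb;
            const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
            for (int ind = node_start; ind < node_end; ++ind) {
                // encode node `ind` into command buffer `cb_idx`
            }
        }
    }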
@@ -729,32 +752,32 @@ void ggml_metal_graph_compute(
  [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
  ne00%32 == 0 &&
  ne11 > 1) {
- switch (src0->type) {
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
- default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
- }
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ switch (src0->type) {
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
+ default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
  }
- else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ } else {
  int nth0 = 32;
  int nth1 = 1;

@@ -784,6 +807,15 @@ void ggml_metal_graph_compute(
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
  } break;
+ case GGML_TYPE_Q8_0:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
+ } break;
  case GGML_TYPE_Q2_K:
  {
  GGML_ASSERT(ne02 == 1);
@@ -853,24 +885,24 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];

- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
  src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q3_K) {
  #ifdef GGML_QKK_64
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #else
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #endif
  }
  else if (src0t == GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  } else {
  [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -880,9 +912,10 @@ void ggml_metal_graph_compute(
  case GGML_OP_GET_ROWS:
  {
  switch (src0->type) {
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
  case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
  case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
  case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
@@ -923,16 +956,17 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_NORM:
  {
- const float eps = 1e-5f;
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));

  const int nth = 256;

  [encoder setComputePipelineState:ctx->pipeline_norm];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
  [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];

  const int64_t nrows = ggml_nrows(src0);
@@ -975,7 +1009,9 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
  [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
+
  const int nth = 32;
+
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case GGML_OP_ROPE:
@@ -990,8 +1026,8 @@ void ggml_metal_graph_compute(
  memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

  [encoder setComputePipelineState:ctx->pipeline_rope];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
  [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
  [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
@@ -1042,24 +1078,24 @@ void ggml_metal_graph_compute(
  default: GGML_ASSERT(false && "not implemented");
  }

- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];

  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
--- ggml-metal.metal
+++ ggml-metal.metal
@@ -18,6 +18,12 @@ typedef struct {
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
  } block_q4_1;

+ #define QK8_0 32
+ typedef struct {
+ half d; // delta
+ int8_t qs[QK8_0]; // quants
+ } block_q8_0;
+
  kernel void kernel_add(
  device const float * src0,
  device const float * src1,
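
Each Q8_0 block stores one half-precision scale d and 32 signed byte quants, so a value decodes as qs[i] * d. A scalar host-side reference of that layout in C (the host typedef mirrors the Metal struct above; its name and the helper are ours, with ggml_fp16_to_fp32 from ggml.h doing the half conversion):

    #include "ggml.h" // ggml_fp16_t, ggml_fp16_to_fp32

    #define QK8_0 32

    typedef struct {
        ggml_fp16_t d;         // delta (scale)
        int8_t      qs[QK8_0]; // quants
    } block_q8_0_host;

    // decode one block: out[i] = qs[i] * d
    static void dequantize_block_q8_0_ref(const block_q8_0_host * b, float * out) {
        const float d = ggml_fp16_to_fp32(b->d);
        for (int i = 0; i < QK8_0; ++i) {
            out[i] = b->qs[i] * d;
        }
    }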
@@ -87,7 +93,12 @@ kernel void kernel_gelu(
  device float * dst,
  uint tpig[[thread_position_in_grid]]) {
  float x = src0[tpig];
- dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+
+ // BEWARE !!!
+ // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
+ // This was observed with the Falcon 7B and 40B models
+ //
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
  }

  kernel void kernel_soft_max(
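
The kernel evaluates the standard tanh approximation of GELU; the switch to precise::tanh avoids the NaNs that Metal's fast-math tanh can produce (presumably from overflow at large |x|). For reference, the same formula as host-side C, with the constants matching those defined in this file:

    #include <math.h>

    #define GELU_COEF_A    0.044715f
    #define SQRT_2_OVER_PI 0.79788456080286535587989211986876f

    // GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    static float gelu_ref(float x) {
        return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
    }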
@@ -352,7 +363,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
  const int first_row = (r0 * nsg + sgitg) * nr;
  const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
  device const block_q_type * x = (device const block_q_type *) src0 + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
  float yl[16]; // src1 vector cache
  float sumf[nr]={0.f};

@@ -424,6 +435,68 @@ kernel void kernel_mul_mat_q4_1_f32(
  mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
  }

+ kernel void kernel_mul_mat_q8_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ const int nr = N_DST;
+ const int nsg = N_SIMDGROUP;
+ const int nw = N_SIMDWIDTH;
+
+ const int nb = ne00/QK8_0;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+ const int first_row = (r0 * nsg + sgitg) * nr;
+ const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+ device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[16];
+ float sumf[nr]={0.f};
+
+ const int ix = tiisg/2;
+ const int il = tiisg%2;
+
+ device const float * yb = y + ix * QK8_0 + 16*il;
+
+ // each thread in a SIMD group deals with half a block.
+ for (int ib = ix; ib < nb; ib += nw/2) {
+ for (int i = 0; i < 16; ++i) {
+ yl[i] = yb[i];
+ }
+
+ for (int row = 0; row < nr; row++) {
+ device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+ float sumq = 0.f;
+ for (int iq = 0; iq < 16; ++iq) {
+ sumq += qs[iq] * yl[iq];
+ }
+ sumf[row] += sumq*x[ib+row*nb].d;
+ }
+
+ yb += QK8_0 * 16;
+ }
+
+ for (int row = 0; row < nr; ++row) {
+ const float tot = simd_sum(sumf[row]);
+ if (tiisg == 0 && first_row + row < ne01) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+ }
+ }
+ }
+
  kernel void kernel_mul_mat_f16_f32(
  device const char * src0,
  device const char * src1,
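
In the kernel above, each SIMD group accumulates N_DST output rows and each pair of lanes splits one 32-quant block (16 values per lane) before the simd_sum reduction. What it parallelizes is a plain per-block scaled dot product; a scalar C reference, reusing the block_q8_0_host type from the sketch after the struct definition:

    // tot = sum over blocks of d_block * sum_i(qs[i] * y[i])
    static float vec_dot_q8_0_ref(const block_q8_0_host * x, const float * y, int nb) {
        float tot = 0.0f;
        for (int ib = 0; ib < nb; ++ib) {
            float sumq = 0.0f;
            for (int i = 0; i < QK8_0; ++i) {
                sumq += x[ib].qs[i] * y[ib*QK8_0 + i];
            }
            tot += sumq * ggml_fp16_to_fp32(x[ib].d);
        }
        return tot;
    }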
@@ -475,7 +548,6 @@ kernel void kernel_mul_mat_f16_f32(
  }
  }

-
  kernel void kernel_alibi_f32(
  device const float * src0,
  device float * dst,
@@ -571,7 +643,25 @@ kernel void kernel_rope(
  dst_data[1] = x0*sin_theta + x1*cos_theta;
  }
  } else {
- // TODO: implement
+ for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+ for (int64_t ic = 0; ic < n_dims; ic += 2) {
+ const float cos_theta = cos(theta);
+ const float sin_theta = sin(theta);
+
+ theta *= theta_scale;
+
+ const int64_t i0 = ib*n_dims + ic/2;
+
+ device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ const float x0 = src[0];
+ const float x1 = src[n_dims/2];
+
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+ }
+ }
  }
  }
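
This fills in the branch of kernel_rope that was previously a TODO: here the two components of each rotated pair sit n_dims/2 apart rather than adjacent, and theta is advanced by theta_scale between pairs. The underlying 2-D rotation, as a host-side C helper for reference:

    #include <math.h>

    // rotate the pair (x0, x1) by theta:
    //   x0' = x0*cos(theta) - x1*sin(theta)
    //   x1' = x0*sin(theta) + x1*cos(theta)
    static void rope_rotate(float * x0, float * x1, float theta) {
        const float c = cosf(theta);
        const float s = sinf(theta);
        const float a = *x0;
        const float b = *x1;
        *x0 = a*c - b*s;
        *x1 = a*s + b*c;
    }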
@@ -1598,12 +1688,12 @@ template <typename type4x4>
  void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
  device const uint16_t * qs = ((device const uint16_t *)xb + 1);
  const half d = il ? (xb->d / 16.h) : xb->d;
- const half m = il ? (-8.h * 16.h) : -8.h;
+ const half m = il ? ( -8.h * 16.h) : -8.h;
  const ushort mask0 = il ? 0x00F0 : 0x000F;
  const ushort mask1 = il ? 0xF000 : 0x0F00;

  for (int i=0;i<8;i++) {
- reg[i/2][2*(i%2)] = (((qs[i] & mask0)) + m) * d;
+ reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
  reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
  }
  }
@@ -1617,11 +1707,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
  const ushort mask1 = il ? 0xF000 : 0x0F00;

  for (int i=0;i<8;i++) {
- reg[i/2][2*(i%2)] = (((qs[i] & mask0)) * d) + m;
+ reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
  reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
  }
  }

+ template <typename type4x4>
+ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+ device const int8_t * qs = ((device const int8_t *)xb->qs);
+ const half d = xb->d;
+
+ for (int i=0;i<16;i++) {
+ reg[i/4][i%4] = (qs[i + 16*il] * d);
+ }
+ }
+
  template <typename type4x4>
  void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
  const half d = xb->d;
@@ -1850,6 +1950,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
  //load data and store to threadgroup memory
  half4x4 temp_a;
  dequantize_func(x, il, temp_a);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
  #pragma unroll(16)
  for (int i = 0; i < 16; i++) {
  *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
@@ -1895,6 +1996,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
  }
  } else {
  // block is smaller than 64x32, we should avoid writing data outside of the matrix
+ threadgroup_barrier(mem_flags::mem_threadgroup);
  threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
  + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
  for (int i = 0; i < 8; i++) {
@@ -1922,9 +2024,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
  typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
  constant uint64_t &, constant uint64_t &, uint, uint, uint);

- template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
+ template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+ template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
  template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
  template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
  template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
@@ -1935,9 +2038,10 @@ typedef void (mat_mm_t)(device const uchar *, device const float *, device float
  constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
  constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);

- template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
+ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
  template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
  template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+ template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
  template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
  template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;