llama_cpp 0.3.8 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ggml-cuda.h

@@ -2,6 +2,14 @@
 
  #include "ggml.h"
 
+ #ifdef GGML_USE_HIPBLAS
+ #define GGML_CUDA_NAME "ROCm"
+ #define GGML_CUBLAS_NAME "hipBLAS"
+ #else
+ #define GGML_CUDA_NAME "CUDA"
+ #define GGML_CUBLAS_NAME "cuBLAS"
+ #endif
+
  #ifdef __cplusplus
  extern "C" {
  #endif
@@ -16,9 +24,14 @@ GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const str
  GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
  GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
  GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);
+
  GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
  GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
  GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
+
+ GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
+ GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
+
  GGML_API void ggml_cuda_set_main_device(int main_device);
  GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
  GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);
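The two new declarations, read together with the existing ggml_cuda_set_scratch_size, suggest a two-step placement scheme: a tensor is first marked as scratch-resident without allocating, then given an explicit byte offset into the pool. A hypothetical usage sketch inferred from the names alone (not taken from this diff):

    /* Hypothetical caller; semantics inferred from the declarations above. */
    ggml_cuda_set_scratch_size(64 * 1024 * 1024);  // existing API: size the scratch pool

    struct ggml_tensor * t = /* ... a tensor in the graph ... */;
    ggml_cuda_assign_buffers_no_alloc(t);          // mark as scratch-resident, no allocation yet
    ggml_cuda_assign_scratch_offset(t, 0);         // place it at byte offset 0 of the pool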
ggml-metal.h

@@ -38,6 +38,9 @@ struct ggml_metal_context;
  struct ggml_metal_context * ggml_metal_init(int n_cb);
  void ggml_metal_free(struct ggml_metal_context * ctx);
 
+ void * ggml_metal_host_malloc(size_t n);
+ void ggml_metal_host_free (void * data);
+
  // set the number of command buffers to use
  void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);
 

ggml-metal.m

@@ -63,6 +63,7 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(get_rows_f16);
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+ GGML_METAL_DECL_KERNEL(get_rows_q8_0);
  GGML_METAL_DECL_KERNEL(get_rows_q2_K);
  GGML_METAL_DECL_KERNEL(get_rows_q3_K);
  GGML_METAL_DECL_KERNEL(get_rows_q4_K);
@@ -73,6 +74,7 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
@@ -81,6 +83,7 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
@@ -167,7 +170,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  #define GGML_METAL_ADD_KERNEL(name) \
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
  ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
- fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \
+ fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+ (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
+ (int) ctx->pipeline_##name.threadExecutionWidth); \
  if (error) { \
  fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
  return NULL; \
@@ -186,6 +191,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(get_rows_f16);
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+ GGML_METAL_ADD_KERNEL(get_rows_q8_0);
  GGML_METAL_ADD_KERNEL(get_rows_q2_K);
  GGML_METAL_ADD_KERNEL(get_rows_q3_K);
  GGML_METAL_ADD_KERNEL(get_rows_q4_K);
@@ -196,6 +202,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
@@ -203,6 +210,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
@@ -218,12 +226,12 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  #undef GGML_METAL_ADD_KERNEL
  }
 
- fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
- fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+ fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
  if (ctx->device.maxTransferRate != 0) {
- fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
  } else {
- fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
+ fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
  }
 
  return ctx;
@@ -237,6 +245,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  free(ctx);
  }
 
+ void * ggml_metal_host_malloc(size_t n) {
+ void * data = NULL;
+ const int result = posix_memalign((void **) &data, getpagesize(), n);
+ if (result != 0) {
+ fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
+ return NULL;
+ }
+
+ return data;
+ }
+
+ void ggml_metal_host_free(void * data) {
+ free(data);
+ }
+
  void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
  ctx->n_cb = n_cb;
  }
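The new allocator returns page-aligned memory (posix_memalign with getpagesize()), which is what Metal requires before a host pointer can be wrapped zero-copy with newBufferWithBytesNoCopy:length:options:deallocator:. A minimal caller sketch (hypothetical, in C):

    size_t n = 16u * 1024 * 1024;
    void * data = ggml_metal_host_malloc(n);   // page-aligned, or NULL on failure
    if (data == NULL) {
        fprintf(stderr, "ggml_metal_host_malloc failed\n");
        return 1;
    }
    /* ... fill the buffer, hand it to the Metal backend ... */
    ggml_metal_host_free(data);                // plain free(), matching the allocator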
@@ -522,8 +545,8 @@ void ggml_metal_graph_compute(
 
  id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
 
- const int node_start = (cb_idx + 0) * n_nodes_per_cb;
- const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
+ const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
 
  for (int ind = node_start; ind < node_end; ++ind) {
  const int i = has_concur ? ctx->concur_list[ind] : ind;
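The new MIN clamp matters when n_nodes_per_cb is rounded up: a command buffer that is not the last one could otherwise compute an end index past the node list. A worked example, assuming n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb:

    // n_nodes = 5, n_cb = 4  =>  n_nodes_per_cb = 2
    // cb_idx = 2 (not the last buffer):
    //   before: node_end = (2 + 1) * 2 = 6   -> walks past node 5
    //   after:  node_end = MIN(6, 5)   = 5   -> stops at the graph end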
@@ -729,32 +752,32 @@ void ggml_metal_graph_compute(
  [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
  ne00%32 == 0 &&
  ne11 > 1) {
- switch (src0->type) {
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
- default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
- }
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
- [encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ switch (src0->type) {
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
+ default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
  }
- else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ } else {
  int nth0 = 32;
  int nth1 = 1;
 
@@ -784,6 +807,15 @@ void ggml_metal_graph_compute(
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
  } break;
+ case GGML_TYPE_Q8_0:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
+ } break;
  case GGML_TYPE_Q2_K:
  {
  GGML_ASSERT(ne02 == 1);
@@ -853,24 +885,24 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
 
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
  src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q3_K) {
  #ifdef GGML_QKK_64
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #else
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #endif
  }
  else if (src0t == GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  } else {
  [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -880,9 +912,10 @@ void ggml_metal_graph_compute(
  case GGML_OP_GET_ROWS:
  {
  switch (src0->type) {
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
  case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
  case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
  case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
@@ -923,16 +956,17 @@ void ggml_metal_graph_compute(
  } break;
  case GGML_OP_NORM:
  {
- const float eps = 1e-5f;
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
 
  const int nth = 256;
 
  [encoder setComputePipelineState:ctx->pipeline_norm];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
  [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
  const int64_t nrows = ggml_nrows(src0);
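The norm epsilon is no longer hard-coded at 1e-5f; it now travels with the operation in dst->op_params, so the graph builder decides the value and every backend reads the same one. The read side is shown above; the write side presumably mirrors it, something like this hypothetical sketch:

    /* Hypothetical producer side (C): pack eps into the op's parameter
       block when the norm node is created. */
    float eps = 1e-5f;
    memcpy(dst->op_params, &eps, sizeof(eps));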
@@ -975,7 +1009,9 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
  [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
+
  const int nth = 32;
+
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case GGML_OP_ROPE:
@@ -990,8 +1026,8 @@ void ggml_metal_graph_compute(
  memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
  [encoder setComputePipelineState:ctx->pipeline_rope];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
  [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
  [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
@@ -1042,24 +1078,24 @@ void ggml_metal_graph_compute(
  default: GGML_ASSERT(false && "not implemented");
  }
 
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
 
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
ggml-metal.metal

@@ -18,6 +18,12 @@ typedef struct {
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
  } block_q4_1;
 
+ #define QK8_0 32
+ typedef struct {
+ half d; // delta
+ int8_t qs[QK8_0]; // quants
+ } block_q8_0;
+
  kernel void kernel_add(
  device const float * src0,
  device const float * src1,
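block_q8_0 is the simplest quantization block in this file: 32 signed 8-bit quants plus a single half-precision scale, so each element dequantizes with one multiply. A scalar C restatement of the layout and the reconstruction rule (a sketch; float stands in for Metal's half, and the _ref names are ours):

    #define QK8_0 32

    typedef struct {
        float  d;            // delta (half in the Metal source)
        int8_t qs[QK8_0];    // quants
    } block_q8_0_ref;

    // element i of a block reconstructs as d * qs[i]
    static float q8_0_dequant_one(const block_q8_0_ref * b, int i) {
        return b->d * (float) b->qs[i];
    }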
@@ -87,7 +93,12 @@ kernel void kernel_gelu(
  device float * dst,
  uint tpig[[thread_position_in_grid]]) {
  float x = src0[tpig];
- dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+
+ // BEWARE !!!
+ // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
+ // This was observed with Falcon 7B and 40B models
+ //
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
  }
 
  kernel void kernel_soft_max(
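Metal's bare tanh resolves to a fast approximation under the default compile options, which is what the new comment warns about; precise::tanh forces the accurate variant. For reference, the scalar form of the approximation the kernel computes, as a C sketch (constant values assumed from the standard tanh-based GELU, matching the kernel's SQRT_2_OVER_PI and GELU_COEF_A names):

    #include <math.h>

    // gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    static float gelu_ref(float x) {
        const float SQRT_2_OVER_PI = 0.7978845608f;  // assumed value
        const float GELU_COEF_A    = 0.044715f;      // assumed value
        return 0.5f * x * (1.0f + tanhf(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x)));
    }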
@@ -352,7 +363,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
  const int first_row = (r0 * nsg + sgitg) * nr;
  const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
  device const block_q_type * x = (device const block_q_type *) src0 + offset0;
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
  float yl[16]; // src1 vector cache
  float sumf[nr]={0.f};
 
@@ -424,6 +435,68 @@ kernel void kernel_mul_mat_q4_1_f32(
  mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
  }
 
+ kernel void kernel_mul_mat_q8_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ const int nr = N_DST;
+ const int nsg = N_SIMDGROUP;
+ const int nw = N_SIMDWIDTH;
+
+ const int nb = ne00/QK8_0;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+ const int first_row = (r0 * nsg + sgitg) * nr;
+ const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+ device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[16];
+ float sumf[nr]={0.f};
+
+ const int ix = tiisg/2;
+ const int il = tiisg%2;
+
+ device const float * yb = y + ix * QK8_0 + 16*il;
+
+ // each thread in a SIMD group deals with half a block.
+ for (int ib = ix; ib < nb; ib += nw/2) {
+ for (int i = 0; i < 16; ++i) {
+ yl[i] = yb[i];
+ }
+
+ for (int row = 0; row < nr; row++) {
+ device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+ float sumq = 0.f;
+ for (int iq = 0; iq < 16; ++iq) {
+ sumq += qs[iq] * yl[iq];
+ }
+ sumf[row] += sumq*x[ib+row*nb].d;
+ }
+
+ yb += QK8_0 * 16;
+ }
+
+ for (int row = 0; row < nr; ++row) {
+ const float tot = simd_sum(sumf[row]);
+ if (tiisg == 0 && first_row + row < ne01) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+ }
+ }
+ }
+
  kernel void kernel_mul_mat_f16_f32(
  device const char * src0,
  device const char * src1,
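In the kernel above, each SIMD lane owns half a block (il selects which 16 quants) and lanes stride over blocks in pairs; simd_sum then folds the per-lane partials. With the threading stripped away, one output row reduces to this scalar C reference (block_q8_0_ref as sketched earlier; nb = ne00/QK8_0):

    static float q8_0_dot_row(const block_q8_0_ref * x, const float * y, int nb) {
        float sumf = 0.0f;
        for (int ib = 0; ib < nb; ++ib) {
            float sumq = 0.0f;
            for (int i = 0; i < QK8_0; ++i) {
                sumq += (float) x[ib].qs[i] * y[ib * QK8_0 + i];  // int8 quant times f32 activation
            }
            sumf += sumq * x[ib].d;                               // scale once per block
        }
        return sumf;
    }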
@@ -475,7 +548,6 @@ kernel void kernel_mul_mat_f16_f32(
  }
  }
 
-
  kernel void kernel_alibi_f32(
  device const float * src0,
  device float * dst,
@@ -571,7 +643,25 @@ kernel void kernel_rope(
  dst_data[1] = x0*sin_theta + x1*cos_theta;
  }
  } else {
- // TODO: implement
+ for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+ for (int64_t ic = 0; ic < n_dims; ic += 2) {
+ const float cos_theta = cos(theta);
+ const float sin_theta = sin(theta);
+
+ theta *= theta_scale;
+
+ const int64_t i0 = ib*n_dims + ic/2;
+
+ device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ const float x0 = src[0];
+ const float x1 = src[n_dims/2];
+
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+ }
+ }
  }
  }
 
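The branch that was previously a TODO applies the same rotation as the first branch, but pairs element i0 with i0 + n_dims/2 rather than adjacent elements. The rotation itself is an ordinary 2-D rotation in both branches; in scalar C:

    #include <math.h>

    // Rotate one RoPE pair (x0, x1) by theta; mirrors dst_data[0] and
    // dst_data[n_dims/2] in the kernel above.
    static void rope_rotate_pair(float theta, float x0, float x1,
                                 float * out0, float * out1) {
        const float c = cosf(theta);
        const float s = sinf(theta);
        *out0 = x0 * c - x1 * s;
        *out1 = x0 * s + x1 * c;
    }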
@@ -1598,12 +1688,12 @@ template <typename type4x4>
  void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
  device const uint16_t * qs = ((device const uint16_t *)xb + 1);
  const half d = il ? (xb->d / 16.h) : xb->d;
- const half m = il ? (-8.h * 16.h) : -8.h;
+ const half m = il ? ( -8.h * 16.h) : -8.h;
  const ushort mask0 = il ? 0x00F0 : 0x000F;
  const ushort mask1 = il ? 0xF000 : 0x0F00;
 
  for (int i=0;i<8;i++) {
- reg[i/2][2*(i%2)] = (((qs[i] & mask0)) + m) * d;
+ reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
  reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
  }
  }
@@ -1617,11 +1707,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
  const ushort mask1 = il ? 0xF000 : 0x0F00;
 
  for (int i=0;i<8;i++) {
- reg[i/2][2*(i%2)] = (((qs[i] & mask0)) * d) + m;
+ reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
  reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
  }
  }
 
+ template <typename type4x4>
+ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+ device const int8_t * qs = ((device const int8_t *)xb->qs);
+ const half d = xb->d;
+
+ for (int i=0;i<16;i++) {
+ reg[i/4][i%4] = (qs[i + 16*il] * d);
+ }
+ }
+
  template <typename type4x4>
  void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
  const half d = xb->d;
@@ -1850,6 +1950,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
  //load data and store to threadgroup memory
  half4x4 temp_a;
  dequantize_func(x, il, temp_a);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
  #pragma unroll(16)
  for (int i = 0; i < 16; i++) {
  *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
@@ -1895,6 +1996,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
  }
  } else {
  // block is smaller than 64x32, we should avoid writing data outside of the matrix
+ threadgroup_barrier(mem_flags::mem_threadgroup);
  threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
  + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
  for (int i = 0; i < 8; i++) {
@@ -1922,9 +2024,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
  typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
  constant uint64_t &, constant uint64_t &, uint, uint, uint);
 
- template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
+ template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+ template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
  template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
  template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
  template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
@@ -1935,9 +2038,10 @@ typedef void (mat_mm_t)(device const uchar *, device const float *, device float
  constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
  constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
 
- template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
+ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
  template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
  template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+ template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
  template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
  template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;