llama_cpp 0.5.3 → 0.6.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -24,12 +24,59 @@ typedef struct {
24
24
  int8_t qs[QK8_0]; // quants
25
25
  } block_q8_0;
26
26
 
27
+ // general-purpose kernel for addition of two tensors
28
+ // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
29
+ // cons: not very efficient
27
30
  kernel void kernel_add(
28
- device const float4 * src0,
29
- device const float4 * src1,
30
- device float4 * dst,
31
- uint tpig[[thread_position_in_grid]]) {
32
- dst[tpig] = src0[tpig] + src1[tpig];
31
+ device const char * src0,
32
+ device const char * src1,
33
+ device char * dst,
34
+ constant int64_t & ne00,
35
+ constant int64_t & ne01,
36
+ constant int64_t & ne02,
37
+ constant int64_t & ne03,
38
+ constant int64_t & nb00,
39
+ constant int64_t & nb01,
40
+ constant int64_t & nb02,
41
+ constant int64_t & nb03,
42
+ constant int64_t & ne10,
43
+ constant int64_t & ne11,
44
+ constant int64_t & ne12,
45
+ constant int64_t & ne13,
46
+ constant int64_t & nb10,
47
+ constant int64_t & nb11,
48
+ constant int64_t & nb12,
49
+ constant int64_t & nb13,
50
+ constant int64_t & ne0,
51
+ constant int64_t & ne1,
52
+ constant int64_t & ne2,
53
+ constant int64_t & ne3,
54
+ constant int64_t & nb0,
55
+ constant int64_t & nb1,
56
+ constant int64_t & nb2,
57
+ constant int64_t & nb3,
58
+ uint3 tgpig[[threadgroup_position_in_grid]],
59
+ uint3 tpitg[[thread_position_in_threadgroup]],
60
+ uint3 ntg[[threads_per_threadgroup]]) {
61
+ const int64_t i03 = tgpig.z;
62
+ const int64_t i02 = tgpig.y;
63
+ const int64_t i01 = tgpig.x;
64
+
65
+ const int64_t i13 = i03 % ne13;
66
+ const int64_t i12 = i02 % ne12;
67
+ const int64_t i11 = i01 % ne11;
68
+
69
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
70
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
71
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
72
+
73
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
74
+ ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
75
+
76
+ src0_ptr += ntg.x*nb00;
77
+ src1_ptr += ntg.x*nb10;
78
+ dst_ptr += ntg.x*nb0;
79
+ }
33
80
  }
34
81
 
35
82
  // assumption: src1 is a row
@@ -38,7 +85,7 @@ kernel void kernel_add_row(
38
85
  device const float4 * src0,
39
86
  device const float4 * src1,
40
87
  device float4 * dst,
41
- constant int64_t & nb,
88
+ constant int64_t & nb [[buffer(27)]],
42
89
  uint tpig[[thread_position_in_grid]]) {
43
90
  dst[tpig] = src0[tpig] + src1[tpig % nb];
44
91
  }
@@ -806,30 +853,61 @@ kernel void kernel_alibi_f32(
806
853
  }
807
854
  }
808
855
 
856
+ typedef void (rope_t)(
857
+ device const void * src0,
858
+ device const int32_t * src1,
859
+ device float * dst,
860
+ constant int64_t & ne00,
861
+ constant int64_t & ne01,
862
+ constant int64_t & ne02,
863
+ constant int64_t & ne03,
864
+ constant uint64_t & nb00,
865
+ constant uint64_t & nb01,
866
+ constant uint64_t & nb02,
867
+ constant uint64_t & nb03,
868
+ constant int64_t & ne0,
869
+ constant int64_t & ne1,
870
+ constant int64_t & ne2,
871
+ constant int64_t & ne3,
872
+ constant uint64_t & nb0,
873
+ constant uint64_t & nb1,
874
+ constant uint64_t & nb2,
875
+ constant uint64_t & nb3,
876
+ constant int & n_past,
877
+ constant int & n_dims,
878
+ constant int & mode,
879
+ constant float & freq_base,
880
+ constant float & freq_scale,
881
+ uint tiitg[[thread_index_in_threadgroup]],
882
+ uint3 tptg[[threads_per_threadgroup]],
883
+ uint3 tgpig[[threadgroup_position_in_grid]]);
884
+
885
+ template<typename T>
809
886
  kernel void kernel_rope(
810
- device const void * src0,
811
- device float * dst,
812
- constant int64_t & ne00,
813
- constant int64_t & ne01,
814
- constant int64_t & ne02,
815
- constant int64_t & ne03,
816
- constant uint64_t & nb00,
817
- constant uint64_t & nb01,
818
- constant uint64_t & nb02,
819
- constant uint64_t & nb03,
820
- constant int64_t & ne0,
821
- constant int64_t & ne1,
822
- constant int64_t & ne2,
823
- constant int64_t & ne3,
824
- constant uint64_t & nb0,
825
- constant uint64_t & nb1,
826
- constant uint64_t & nb2,
827
- constant uint64_t & nb3,
828
- constant int & n_past,
829
- constant int & n_dims,
830
- constant int & mode,
831
- constant float & freq_base,
832
- constant float & freq_scale,
887
+ device const void * src0,
888
+ device const int32_t * src1,
889
+ device float * dst,
890
+ constant int64_t & ne00,
891
+ constant int64_t & ne01,
892
+ constant int64_t & ne02,
893
+ constant int64_t & ne03,
894
+ constant uint64_t & nb00,
895
+ constant uint64_t & nb01,
896
+ constant uint64_t & nb02,
897
+ constant uint64_t & nb03,
898
+ constant int64_t & ne0,
899
+ constant int64_t & ne1,
900
+ constant int64_t & ne2,
901
+ constant int64_t & ne3,
902
+ constant uint64_t & nb0,
903
+ constant uint64_t & nb1,
904
+ constant uint64_t & nb2,
905
+ constant uint64_t & nb3,
906
+ constant int & n_past,
907
+ constant int & n_dims,
908
+ constant int & mode,
909
+ constant float & freq_base,
910
+ constant float & freq_scale,
833
911
  uint tiitg[[thread_index_in_threadgroup]],
834
912
  uint3 tptg[[threads_per_threadgroup]],
835
913
  uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -839,7 +917,9 @@ kernel void kernel_rope(
839
917
 
840
918
  const bool is_neox = mode & 2;
841
919
 
842
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
920
+ device const int32_t * pos = src1;
921
+
922
+ const int64_t p = pos[i2];
843
923
 
844
924
  const float theta_0 = freq_scale * (float)p;
845
925
  const float inv_ndims = -1.f/n_dims;
@@ -851,11 +931,11 @@ kernel void kernel_rope(
851
931
  const float cos_theta = cos(theta);
852
932
  const float sin_theta = sin(theta);
853
933
 
854
- device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
855
- device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
934
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
935
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
856
936
 
857
- const float x0 = src[0];
858
- const float x1 = src[1];
937
+ const T x0 = src[0];
938
+ const T x1 = src[1];
859
939
 
860
940
  dst_data[0] = x0*cos_theta - x1*sin_theta;
861
941
  dst_data[1] = x0*sin_theta + x1*cos_theta;
@@ -870,8 +950,8 @@ kernel void kernel_rope(
870
950
 
871
951
  const int64_t i0 = ib*n_dims + ic/2;
872
952
 
873
- device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
874
- device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
953
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
954
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
875
955
 
876
956
  const float x0 = src[0];
877
957
  const float x1 = src[n_dims/2];
@@ -883,6 +963,9 @@ kernel void kernel_rope(
883
963
  }
884
964
  }
885
965
 
966
+ template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
967
+ template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
968
+
886
969
  kernel void kernel_cpy_f16_f16(
887
970
  device const half * src0,
888
971
  device half * dst,
@@ -1273,8 +1356,8 @@ kernel void kernel_mul_mat_q3_K_f32(
1273
1356
 
1274
1357
  float yl[32];
1275
1358
 
1276
- const uint16_t kmask1 = 0x3030;
1277
- const uint16_t kmask2 = 0x0f0f;
1359
+ //const uint16_t kmask1 = 0x3030;
1360
+ //const uint16_t kmask2 = 0x0f0f;
1278
1361
 
1279
1362
  const int tid = tiisg/4;
1280
1363
  const int ix = tiisg%4;