llama_cpp 0.5.2 → 0.6.0

@@ -24,12 +24,59 @@ typedef struct {
  int8_t qs[QK8_0]; // quants
  } block_q8_0;
 
+ // general-purpose kernel for addition of two tensors
+ // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
+ // cons: not very efficient
  kernel void kernel_add(
- device const float4 * src0,
- device const float4 * src1,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] + src1[tpig];
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant int64_t & nb00,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant int64_t & nb0,
+ constant int64_t & nb1,
+ constant int64_t & nb2,
+ constant int64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
+
+ src0_ptr += ntg.x*nb00;
+ src1_ptr += ntg.x*nb10;
+ dst_ptr += ntg.x*nb0;
+ }
  }
 
  // assumption: src1 is a row
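
The rewritten kernel_add walks one row per threadgroup and broadcasts src1 by wrapping its coordinates in dims 1..3, at the cost of byte-wise pointer arithmetic per element. A CPU-side C++ sketch of the same index math, for reference (argument names mirror the kernel's shape/stride parameters; this is an illustration, not code from the gem):

    #include <cstdint>

    // CPU model of kernel_add above: ne*/nb* are shapes (elements) and
    // strides (bytes); src1 is broadcast across dims 1..3 via modulo.
    static void add_broadcast(const char * src0, const char * src1, char * dst,
                              const int64_t ne[4],  const int64_t nb[4],   // dst/src0 shape, src0 strides
                              const int64_t ne1[4], const int64_t nb1[4],  // src1 shape, src1 strides
                              const int64_t nbd[4]) {                      // dst strides
        for (int64_t i3 = 0; i3 < ne[3]; ++i3)
        for (int64_t i2 = 0; i2 < ne[2]; ++i2)
        for (int64_t i1 = 0; i1 < ne[1]; ++i1)
        for (int64_t i0 = 0; i0 < ne[0]; ++i0) {
            // broadcast: wrap src1 coordinates into its own (smaller) extents
            const int64_t j3 = i3 % ne1[3], j2 = i2 % ne1[2], j1 = i1 % ne1[1];
            const float x = *(const float *)(src0 + i3*nb[3] + i2*nb[2] + i1*nb[1] + i0*nb[0]);
            const float y = *(const float *)(src1 + j3*nb1[3] + j2*nb1[2] + j1*nb1[1] + i0*nb1[0]);
            *(float *)(dst + i3*nbd[3] + i2*nbd[2] + i1*nbd[1] + i0*nbd[0]) = x + y;
        }
    }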
@@ -38,7 +85,7 @@ kernel void kernel_add_row(
  device const float4 * src0,
  device const float4 * src1,
  device float4 * dst,
- constant int64_t & nb,
+ constant int64_t & nb [[buffer(27)]],
  uint tpig[[thread_position_in_grid]]) {
  dst[tpig] = src0[tpig] + src1[tpig % nb];
  }
@@ -118,7 +165,7 @@ kernel void kernel_soft_max(
  device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
  // parallel max
- float lmax = psrc0[tpitg[0]];
+ float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
  for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
  lmax = MAX(lmax, psrc0[i00]);
  }
@@ -158,7 +205,7 @@ kernel void kernel_soft_max_4(
  device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
  // parallel max
- float4 lmax4 = psrc4[tpitg[0]];
+ float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
  for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
  lmax4 = fmax(lmax4, psrc4[i00]);
  }
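
Both soft_max fixes address the same out-of-bounds seed: when the threadgroup has more threads than the row has elements, a thread with tpitg[0] >= ne00 used to read past the row before the strided loop even started. Seeding with -INFINITY, the identity element for max, makes those threads harmless. A minimal C++ sketch of the guarded pattern (illustrative, not gem code):

    #include <algorithm>
    #include <cmath>

    // One thread's partial max over a row of length ne00. Threads whose
    // index exceeds the row length seed with -INFINITY instead of reading
    // out of bounds; the threadgroup reduction then proceeds as before.
    static float partial_row_max(const float * psrc0, int ne00, int tid, int ntg) {
        float lmax = tid < ne00 ? psrc0[tid] : -INFINITY; // the 0.6.0 guard
        for (int i00 = tid + ntg; i00 < ne00; i00 += ntg) {
            lmax = std::max(lmax, psrc0[i00]);
        }
        return lmax;
    }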
@@ -523,6 +570,79 @@ kernel void kernel_mul_mat_q8_0_f32(
  }
  }
 
+ #define N_F32_F32 4
+
+ kernel void kernel_mul_mat_f32_f32(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+
+ const int64_t r0 = tgpig.x;
+ const int64_t rb = tgpig.y*N_F32_F32;
+ const int64_t im = tgpig.z;
+
+ device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+ if (ne00 < 128) {
+ for (int row = 0; row < N_F32_F32; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00; i += 32) {
+ sumf += (float) x[i] * (float) y[i];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ } else {
+ device const float4 * x4 = (device const float4 *)x;
+ for (int row = 0; row < N_F32_F32; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+ device const float4 * y4 = (device const float4 *) y;
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ }
+ }
+
  kernel void kernel_mul_mat_f16_f32_1row(
  device const char * src0,
  device const char * src1,
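
In the new kernel_mul_mat_f32_f32, each output element is a dot product spread over the 32 lanes of a simdgroup: long rows (ne00 >= 128) take a float4-vectorized path, simd_sum folds the lane partials, and lane 0 adds the scalar tail for lengths not divisible by 4. A serial C++ model of that reduction shape (the lane loop stands in for tiisg/simd_sum; illustrative only):

    // Serial model of the simdgroup dot product: 32 lanes accumulate
    // strided float4 partials, the partials are summed (simd_sum), and
    // the scalar tail covers the ne00 % 4 leftover elements.
    static float simd_dot_model(const float * x, const float * y, int ne00) {
        const int lanes = 32; // Apple GPU simdgroup width
        float sum = 0.0f;
        for (int lane = 0; lane < lanes; ++lane) {
            float partial = 0.0f; // per-lane sumf
            for (int i = lane; i < ne00/4; i += lanes) {
                for (int k = 0; k < 4; ++k) {
                    partial += x[4*i + k] * y[4*i + k];
                }
            }
            sum += partial; // simd_sum equivalent
        }
        for (int i = 4*(ne00/4); i < ne00; ++i) {
            sum += x[i] * y[i]; // scalar tail, added by lane 0 in the kernel
        }
        return sum;
    }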
@@ -733,30 +853,61 @@ kernel void kernel_alibi_f32(
  }
  }
 
+ typedef void (rope_t)(
+ device const void * src0,
+ device const int32_t * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int & n_past,
+ constant int & n_dims,
+ constant int & mode,
+ constant float & freq_base,
+ constant float & freq_scale,
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg[[threads_per_threadgroup]],
+ uint3 tgpig[[threadgroup_position_in_grid]]);
+
+ template<typename T>
  kernel void kernel_rope(
- device const void * src0,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- constant int & n_past,
- constant int & n_dims,
- constant int & mode,
- constant float & freq_base,
- constant float & freq_scale,
+ device const void * src0,
+ device const int32_t * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int & n_past,
+ constant int & n_dims,
+ constant int & mode,
+ constant float & freq_base,
+ constant float & freq_scale,
  uint tiitg[[thread_index_in_threadgroup]],
  uint3 tptg[[threads_per_threadgroup]],
  uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -766,7 +917,9 @@ kernel void kernel_rope(
 
  const bool is_neox = mode & 2;
 
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
+ device const int32_t * pos = src1;
+
+ const int64_t p = pos[i2];
 
  const float theta_0 = freq_scale * (float)p;
  const float inv_ndims = -1.f/n_dims;
@@ -778,11 +931,11 @@ kernel void kernel_rope(
  const float cos_theta = cos(theta);
  const float sin_theta = sin(theta);
 
- device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
- const float x0 = src[0];
- const float x1 = src[1];
+ const T x0 = src[0];
+ const T x1 = src[1];
 
  dst_data[0] = x0*cos_theta - x1*sin_theta;
  dst_data[1] = x0*sin_theta + x1*cos_theta;
@@ -797,8 +950,8 @@ kernel void kernel_rope(
 
  const int64_t i0 = ib*n_dims + ic/2;
 
- device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
  const float x0 = src[0];
  const float x1 = src[n_dims/2];
@@ -810,6 +963,9 @@ kernel void kernel_rope(
  }
  }
 
+ template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
+ template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
+
  kernel void kernel_cpy_f16_f16(
  device const half * src0,
  device half * dst,
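
The rope rework does two things: token positions now arrive as an explicit src1 tensor (p = pos[i2]) instead of being derived from n_past + i2, which allows arbitrary per-token positions, and the kernel is templated over the element type with float and half instantiations. A scalar C++ model of the non-neox rotation path (the theta recurrence is reconstructed from theta_0 and inv_ndims above; illustrative, not gem code):

    #include <cmath>
    #include <cstdint>

    // Scalar model of the templated rope path: each consecutive pair
    // (x0, x1) is rotated by an angle that decays geometrically along
    // the head dimension. pos plays the role of the new src1 tensor;
    // in Metal, T is float or half.
    template <typename T>
    static void rope_row(const T * src, T * dst, const int32_t * pos, int64_t i2,
                         int n_dims, float freq_base, float freq_scale) {
        const float theta_0   = freq_scale * (float) pos[i2]; // was freq_scale * (n_past + i2)
        const float inv_ndims = -1.0f / n_dims;
        for (int i0 = 0; i0 < n_dims; i0 += 2) {
            const float theta     = theta_0 * std::pow(freq_base, inv_ndims * i0);
            const float cos_theta = std::cos(theta);
            const float sin_theta = std::sin(theta);
            const T x0 = src[i0 + 0];
            const T x1 = src[i0 + 1];
            dst[i0 + 0] = x0 * cos_theta - x1 * sin_theta;
            dst[i0 + 1] = x0 * sin_theta + x1 * cos_theta;
        }
    }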
@@ -1200,8 +1356,8 @@ kernel void kernel_mul_mat_q3_K_f32(
 
  float yl[32];
 
- const uint16_t kmask1 = 0x3030;
- const uint16_t kmask2 = 0x0f0f;
+ //const uint16_t kmask1 = 0x3030;
+ //const uint16_t kmask2 = 0x0f0f;
 
  const int tid = tiisg/4;
  const int ix = tiisg%4;
@@ -1321,7 +1477,6 @@ kernel void kernel_mul_mat_q3_K_f32(
  dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
  }
  }
-
  }
  #else
  kernel void kernel_mul_mat_q3_K_f32(
@@ -1400,13 +1555,13 @@ kernel void kernel_mul_mat_q4_K_f32(
  device const float * src1,
  device float * dst,
  constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne01 [[buffer(4)]],
+ constant int64_t & ne02 [[buffer(5)]],
+ constant int64_t & ne10 [[buffer(9)]],
+ constant int64_t & ne12 [[buffer(11)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & gqa [[buffer(17)]],
  uint3 tgpig[[threadgroup_position_in_grid]],
  uint tiisg[[thread_index_in_simdgroup]],
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1865,6 +2020,15 @@ kernel void kernel_mul_mat_q6_K_f32(
 
  //============================= templates and their specializations =============================
 
+ // NOTE: this is not dequantizing - we are simply fitting the template
+ template <typename type4x4>
+ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+ float4x4 temp = *(((device float4x4 *)src));
+ for (int i = 0; i < 16; i++){
+ reg[i/4][i%4] = temp[i/4][i%4];
+ }
+ }
+
  template <typename type4x4>
  void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
  half4x4 temp = *(((device half4x4 *)src));
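
As the NOTE says, dequantize_f32 decodes nothing: it is an identity copy whose only job is to let plain f32 tensors flow through the same templated kernel_get_rows/kernel_mul_mm machinery as quantized block types. A C++ model of the template slot it fills (the type alias is an assumption for illustration):

    #include <array>

    // Model of the "dequantize" hook: quantized block types decode 16
    // weights here; the f32 variant copies them through unchanged so the
    // templates stay uniform across element types.
    using type4x4 = std::array<std::array<float, 4>, 4>;

    static void dequantize_f32_model(const type4x4 * src, short /*il*/, type4x4 & reg) {
        for (int i = 0; i < 16; i++) {
            reg[i/4][i%4] = (*src)[i/4][i%4]; // straight copy, no scale or offset
        }
    }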
@@ -1875,7 +2039,6 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
 
  template <typename type4x4>
  void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
-
  device const uint16_t * qs = ((device const uint16_t *)xb + 1);
  const float d1 = il ? (xb->d / 16.h) : xb->d;
  const float d2 = d1 / 256.f;
@@ -1887,12 +2050,10 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
  reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
  reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
  }
-
  }
 
  template <typename type4x4>
  void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
-
  device const uint16_t * qs = ((device const uint16_t *)xb + 2);
  const float d1 = il ? (xb->d / 16.h) : xb->d;
  const float d2 = d1 / 256.f;
@@ -1964,7 +2125,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
  for (int i = 0; i < 16; ++i) {
  reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
  }
-
  #else
  float kcoef = il&1 ? 1.f/16.f : 1.f;
  uint16_t kmask = il&1 ? 0xF0 : 0x0F;
@@ -2008,7 +2168,6 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
  for (int i = 0; i < 16; ++i) {
  reg[i/4][i%4] = dl * (q[i] & mask) - ml;
  }
-
  }
 
  template <typename type4x4>
@@ -2110,22 +2269,25 @@ kernel void kernel_get_rows(
  // each block_q contains 16*nl weights
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
  kernel void kernel_mul_mm(device const uchar * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne02,
- constant int64_t & nb01,
- constant int64_t & nb02,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & gqa,
- threadgroup uchar * shared_memory [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- threadgroup half * sa = ((threadgroup half *)shared_memory);
+ device const uchar * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & ne12,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & gqa,
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ threadgroup half * sa = (threadgroup half *)(shared_memory);
  threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
 
  const uint r0 = tgpig.y;
@@ -2138,7 +2300,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
  short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
  short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
 
- simdgroup_half8x8 ma[4];
+ simdgroup_half8x8 ma[4];
  simdgroup_float8x8 mb[2];
  simdgroup_float8x8 c_res[8];
  for (int i = 0; i < 8; i++){
@@ -2146,10 +2308,15 @@ kernel void kernel_mul_mm(device const uchar * src0,
  }
 
  short il = (tiitg % THREAD_PER_ROW);
- uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
- device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
- + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
+
+ uint offset0 = im/gqa*nb02;
+ ushort offset1 = il/nl;
+
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+ device const float * y = (device const float *)(src1
+ + nb12 * im
+ + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+ + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
  for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
  //load data and store to threadgroup memory
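
The key change in kernel_mul_mm: src1 is no longer assumed to be a dense float matrix indexed by element counts (ne00, ne1); it is now addressed as raw bytes through the new nb10/nb11/nb12 strides, so non-contiguous src1 layouts work directly. A C++ sketch of the before/after addressing (names mirror the kernel arguments; illustrative only):

    #include <cstdint>

    // Before: element arithmetic, valid only for a dense float src1.
    static const float * y_before(const float * src1, int64_t col0, int64_t row,
                                  int64_t im, int64_t ne00, int64_t ne1) {
        return src1 + row * ne00 + col0 + im * ne00 * ne1;
    }

    // After: byte-stride arithmetic, valid for any layout ggml can describe.
    static const float * y_after(const uint8_t * src1, int64_t col0, int64_t row, int64_t im,
                                 int64_t nb10, int64_t nb11, int64_t nb12) {
        return (const float *)(src1 + im * nb12 + row * nb11 + col0 * nb10);
    }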
@@ -2229,6 +2396,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
  typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
  constant uint64_t &, constant uint64_t &, uint, uint, uint);
 
+ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
  template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
@@ -2239,14 +2407,28 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
 
- typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
- constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
- constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
-
- template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
- template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
- template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
- template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
+ typedef void (mat_mm_t)(
+ device const uchar * src0,
+ device const uchar * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & ne12,
+ constant int64_t & nb10,
+ constant int64_t & nb11,
+ constant int64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & gqa,
+ threadgroup uchar *, uint3, uint, uint);
+
+ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
+ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
+ template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
+ template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+ template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
  template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
  template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
@@ -847,7 +847,7 @@ std::array<std::string, 2> mul_str_values = {
  "mul_f32", "float"
  };
 
- std::string& replace(std::string& s, const std::string& from, const std::string& to) {
+ static std::string& replace(std::string& s, const std::string& from, const std::string& to) {
  size_t pos = 0;
  while ((pos = s.find(from, pos)) != std::string::npos) {
  s.replace(pos, from.length(), to);
@@ -856,7 +856,7 @@ std::string& replace(std::string& s, const std::string& from, const std::string&
  return s;
  }
 
- std::string generate_kernels() {
+ static std::string generate_kernels() {
  std::stringstream src;
  src << program_source << '\n';
  src << k_quants_source << '\n';
@@ -1788,7 +1788,7 @@ bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
  return false;
  }
 
- bool ggml_cl_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
+ static bool ggml_cl_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
  // If device doesn't support FP16
  if (!fp16_support) {
  return false;
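
The remaining hunks, in the OpenCL backend's host-side C++, are uniform: file-local helpers such as replace, generate_kernels, and ggml_cl_mul_mat_use_f16 gain static, giving them internal linkage so they cannot collide with identically named symbols in other translation units. A minimal C++ illustration (hypothetical file, not gem code):

    // helpers.cpp: `static` limits the definition to this translation
    // unit, so another .cpp may define its own generate_kernels() without
    // a duplicate-symbol error at link time.
    #include <string>

    static std::string generate_kernels() {
        return "kernel source";
    }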