llama_cpp 0.5.2 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,12 +24,59 @@ typedef struct {
24
24
  int8_t qs[QK8_0]; // quants
25
25
  } block_q8_0;
26
26
 
27
+ // general-purpose kernel for addition of two tensors
28
+ // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
29
+ // cons: not very efficient
27
30
  kernel void kernel_add(
28
- device const float4 * src0,
29
- device const float4 * src1,
30
- device float4 * dst,
31
- uint tpig[[thread_position_in_grid]]) {
32
- dst[tpig] = src0[tpig] + src1[tpig];
31
+ device const char * src0,
32
+ device const char * src1,
33
+ device char * dst,
34
+ constant int64_t & ne00,
35
+ constant int64_t & ne01,
36
+ constant int64_t & ne02,
37
+ constant int64_t & ne03,
38
+ constant int64_t & nb00,
39
+ constant int64_t & nb01,
40
+ constant int64_t & nb02,
41
+ constant int64_t & nb03,
42
+ constant int64_t & ne10,
43
+ constant int64_t & ne11,
44
+ constant int64_t & ne12,
45
+ constant int64_t & ne13,
46
+ constant int64_t & nb10,
47
+ constant int64_t & nb11,
48
+ constant int64_t & nb12,
49
+ constant int64_t & nb13,
50
+ constant int64_t & ne0,
51
+ constant int64_t & ne1,
52
+ constant int64_t & ne2,
53
+ constant int64_t & ne3,
54
+ constant int64_t & nb0,
55
+ constant int64_t & nb1,
56
+ constant int64_t & nb2,
57
+ constant int64_t & nb3,
58
+ uint3 tgpig[[threadgroup_position_in_grid]],
59
+ uint3 tpitg[[thread_position_in_threadgroup]],
60
+ uint3 ntg[[threads_per_threadgroup]]) {
61
+ const int64_t i03 = tgpig.z;
62
+ const int64_t i02 = tgpig.y;
63
+ const int64_t i01 = tgpig.x;
64
+
65
+ const int64_t i13 = i03 % ne13;
66
+ const int64_t i12 = i02 % ne12;
67
+ const int64_t i11 = i01 % ne11;
68
+
69
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
70
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
71
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
72
+
73
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
74
+ ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
75
+
76
+ src0_ptr += ntg.x*nb00;
77
+ src1_ptr += ntg.x*nb10;
78
+ dst_ptr += ntg.x*nb0;
79
+ }
33
80
  }
34
81
 
35
82
  // assumption: src1 is a row
@@ -38,7 +85,7 @@ kernel void kernel_add_row(
38
85
  device const float4 * src0,
39
86
  device const float4 * src1,
40
87
  device float4 * dst,
41
- constant int64_t & nb,
88
+ constant int64_t & nb [[buffer(27)]],
42
89
  uint tpig[[thread_position_in_grid]]) {
43
90
  dst[tpig] = src0[tpig] + src1[tpig % nb];
44
91
  }
@@ -118,7 +165,7 @@ kernel void kernel_soft_max(
118
165
  device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
119
166
 
120
167
  // parallel max
121
- float lmax = psrc0[tpitg[0]];
168
+ float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
122
169
  for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
123
170
  lmax = MAX(lmax, psrc0[i00]);
124
171
  }
@@ -158,7 +205,7 @@ kernel void kernel_soft_max_4(
158
205
  device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
159
206
 
160
207
  // parallel max
161
- float4 lmax4 = psrc4[tpitg[0]];
208
+ float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
162
209
  for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
163
210
  lmax4 = fmax(lmax4, psrc4[i00]);
164
211
  }
@@ -523,6 +570,79 @@ kernel void kernel_mul_mat_q8_0_f32(
523
570
  }
524
571
  }
525
572
 
573
+ #define N_F32_F32 4
574
+
575
+ kernel void kernel_mul_mat_f32_f32(
576
+ device const char * src0,
577
+ device const char * src1,
578
+ device float * dst,
579
+ constant int64_t & ne00,
580
+ constant int64_t & ne01,
581
+ constant int64_t & ne02,
582
+ constant uint64_t & nb00,
583
+ constant uint64_t & nb01,
584
+ constant uint64_t & nb02,
585
+ constant int64_t & ne10,
586
+ constant int64_t & ne11,
587
+ constant int64_t & ne12,
588
+ constant uint64_t & nb10,
589
+ constant uint64_t & nb11,
590
+ constant uint64_t & nb12,
591
+ constant int64_t & ne0,
592
+ constant int64_t & ne1,
593
+ uint3 tgpig[[threadgroup_position_in_grid]],
594
+ uint tiisg[[thread_index_in_simdgroup]]) {
595
+
596
+ const int64_t r0 = tgpig.x;
597
+ const int64_t rb = tgpig.y*N_F32_F32;
598
+ const int64_t im = tgpig.z;
599
+
600
+ device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
601
+
602
+ if (ne00 < 128) {
603
+ for (int row = 0; row < N_F32_F32; ++row) {
604
+ int r1 = rb + row;
605
+ if (r1 >= ne11) {
606
+ break;
607
+ }
608
+
609
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
610
+
611
+ float sumf = 0;
612
+ for (int i = tiisg; i < ne00; i += 32) {
613
+ sumf += (float) x[i] * (float) y[i];
614
+ }
615
+
616
+ float all_sum = simd_sum(sumf);
617
+ if (tiisg == 0) {
618
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
619
+ }
620
+ }
621
+ } else {
622
+ device const float4 * x4 = (device const float4 *)x;
623
+ for (int row = 0; row < N_F32_F32; ++row) {
624
+ int r1 = rb + row;
625
+ if (r1 >= ne11) {
626
+ break;
627
+ }
628
+
629
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
630
+ device const float4 * y4 = (device const float4 *) y;
631
+
632
+ float sumf = 0;
633
+ for (int i = tiisg; i < ne00/4; i += 32) {
634
+ for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
635
+ }
636
+
637
+ float all_sum = simd_sum(sumf);
638
+ if (tiisg == 0) {
639
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
640
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
641
+ }
642
+ }
643
+ }
644
+ }
645
+
526
646
  kernel void kernel_mul_mat_f16_f32_1row(
527
647
  device const char * src0,
528
648
  device const char * src1,
@@ -733,30 +853,61 @@ kernel void kernel_alibi_f32(
733
853
  }
734
854
  }
735
855
 
856
+ typedef void (rope_t)(
857
+ device const void * src0,
858
+ device const int32_t * src1,
859
+ device float * dst,
860
+ constant int64_t & ne00,
861
+ constant int64_t & ne01,
862
+ constant int64_t & ne02,
863
+ constant int64_t & ne03,
864
+ constant uint64_t & nb00,
865
+ constant uint64_t & nb01,
866
+ constant uint64_t & nb02,
867
+ constant uint64_t & nb03,
868
+ constant int64_t & ne0,
869
+ constant int64_t & ne1,
870
+ constant int64_t & ne2,
871
+ constant int64_t & ne3,
872
+ constant uint64_t & nb0,
873
+ constant uint64_t & nb1,
874
+ constant uint64_t & nb2,
875
+ constant uint64_t & nb3,
876
+ constant int & n_past,
877
+ constant int & n_dims,
878
+ constant int & mode,
879
+ constant float & freq_base,
880
+ constant float & freq_scale,
881
+ uint tiitg[[thread_index_in_threadgroup]],
882
+ uint3 tptg[[threads_per_threadgroup]],
883
+ uint3 tgpig[[threadgroup_position_in_grid]]);
884
+
885
+ template<typename T>
736
886
  kernel void kernel_rope(
737
- device const void * src0,
738
- device float * dst,
739
- constant int64_t & ne00,
740
- constant int64_t & ne01,
741
- constant int64_t & ne02,
742
- constant int64_t & ne03,
743
- constant uint64_t & nb00,
744
- constant uint64_t & nb01,
745
- constant uint64_t & nb02,
746
- constant uint64_t & nb03,
747
- constant int64_t & ne0,
748
- constant int64_t & ne1,
749
- constant int64_t & ne2,
750
- constant int64_t & ne3,
751
- constant uint64_t & nb0,
752
- constant uint64_t & nb1,
753
- constant uint64_t & nb2,
754
- constant uint64_t & nb3,
755
- constant int & n_past,
756
- constant int & n_dims,
757
- constant int & mode,
758
- constant float & freq_base,
759
- constant float & freq_scale,
887
+ device const void * src0,
888
+ device const int32_t * src1,
889
+ device float * dst,
890
+ constant int64_t & ne00,
891
+ constant int64_t & ne01,
892
+ constant int64_t & ne02,
893
+ constant int64_t & ne03,
894
+ constant uint64_t & nb00,
895
+ constant uint64_t & nb01,
896
+ constant uint64_t & nb02,
897
+ constant uint64_t & nb03,
898
+ constant int64_t & ne0,
899
+ constant int64_t & ne1,
900
+ constant int64_t & ne2,
901
+ constant int64_t & ne3,
902
+ constant uint64_t & nb0,
903
+ constant uint64_t & nb1,
904
+ constant uint64_t & nb2,
905
+ constant uint64_t & nb3,
906
+ constant int & n_past,
907
+ constant int & n_dims,
908
+ constant int & mode,
909
+ constant float & freq_base,
910
+ constant float & freq_scale,
760
911
  uint tiitg[[thread_index_in_threadgroup]],
761
912
  uint3 tptg[[threads_per_threadgroup]],
762
913
  uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -766,7 +917,9 @@ kernel void kernel_rope(
766
917
 
767
918
  const bool is_neox = mode & 2;
768
919
 
769
- const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
920
+ device const int32_t * pos = src1;
921
+
922
+ const int64_t p = pos[i2];
770
923
 
771
924
  const float theta_0 = freq_scale * (float)p;
772
925
  const float inv_ndims = -1.f/n_dims;
@@ -778,11 +931,11 @@ kernel void kernel_rope(
778
931
  const float cos_theta = cos(theta);
779
932
  const float sin_theta = sin(theta);
780
933
 
781
- device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
782
- device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
934
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
935
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
783
936
 
784
- const float x0 = src[0];
785
- const float x1 = src[1];
937
+ const T x0 = src[0];
938
+ const T x1 = src[1];
786
939
 
787
940
  dst_data[0] = x0*cos_theta - x1*sin_theta;
788
941
  dst_data[1] = x0*sin_theta + x1*cos_theta;
@@ -797,8 +950,8 @@ kernel void kernel_rope(
797
950
 
798
951
  const int64_t i0 = ib*n_dims + ic/2;
799
952
 
800
- device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
801
- device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
953
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
954
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
802
955
 
803
956
  const float x0 = src[0];
804
957
  const float x1 = src[n_dims/2];
@@ -810,6 +963,9 @@ kernel void kernel_rope(
810
963
  }
811
964
  }
812
965
 
966
+ template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
967
+ template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
968
+
813
969
  kernel void kernel_cpy_f16_f16(
814
970
  device const half * src0,
815
971
  device half * dst,
@@ -1200,8 +1356,8 @@ kernel void kernel_mul_mat_q3_K_f32(
1200
1356
 
1201
1357
  float yl[32];
1202
1358
 
1203
- const uint16_t kmask1 = 0x3030;
1204
- const uint16_t kmask2 = 0x0f0f;
1359
+ //const uint16_t kmask1 = 0x3030;
1360
+ //const uint16_t kmask2 = 0x0f0f;
1205
1361
 
1206
1362
  const int tid = tiisg/4;
1207
1363
  const int ix = tiisg%4;
@@ -1321,7 +1477,6 @@ kernel void kernel_mul_mat_q3_K_f32(
1321
1477
  dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
1322
1478
  }
1323
1479
  }
1324
-
1325
1480
  }
1326
1481
  #else
1327
1482
  kernel void kernel_mul_mat_q3_K_f32(
@@ -1400,13 +1555,13 @@ kernel void kernel_mul_mat_q4_K_f32(
1400
1555
  device const float * src1,
1401
1556
  device float * dst,
1402
1557
  constant int64_t & ne00,
1403
- constant int64_t & ne01[[buffer(4)]],
1404
- constant int64_t & ne02[[buffer(5)]],
1405
- constant int64_t & ne10[[buffer(9)]],
1406
- constant int64_t & ne12[[buffer(11)]],
1407
- constant int64_t & ne0[[buffer(15)]],
1408
- constant int64_t & ne1[[buffer(16)]],
1409
- constant uint & gqa[[buffer(17)]],
1558
+ constant int64_t & ne01 [[buffer(4)]],
1559
+ constant int64_t & ne02 [[buffer(5)]],
1560
+ constant int64_t & ne10 [[buffer(9)]],
1561
+ constant int64_t & ne12 [[buffer(11)]],
1562
+ constant int64_t & ne0 [[buffer(15)]],
1563
+ constant int64_t & ne1 [[buffer(16)]],
1564
+ constant uint & gqa [[buffer(17)]],
1410
1565
  uint3 tgpig[[threadgroup_position_in_grid]],
1411
1566
  uint tiisg[[thread_index_in_simdgroup]],
1412
1567
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1865,6 +2020,15 @@ kernel void kernel_mul_mat_q6_K_f32(
1865
2020
 
1866
2021
  //============================= templates and their specializations =============================
1867
2022
 
2023
+ // NOTE: this is not dequantizing - we are simply fitting the template
2024
+ template <typename type4x4>
2025
+ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
2026
+ float4x4 temp = *(((device float4x4 *)src));
2027
+ for (int i = 0; i < 16; i++){
2028
+ reg[i/4][i%4] = temp[i/4][i%4];
2029
+ }
2030
+ }
2031
+
1868
2032
  template <typename type4x4>
1869
2033
  void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
1870
2034
  half4x4 temp = *(((device half4x4 *)src));
@@ -1875,7 +2039,6 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
1875
2039
 
1876
2040
  template <typename type4x4>
1877
2041
  void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
1878
-
1879
2042
  device const uint16_t * qs = ((device const uint16_t *)xb + 1);
1880
2043
  const float d1 = il ? (xb->d / 16.h) : xb->d;
1881
2044
  const float d2 = d1 / 256.f;
@@ -1887,12 +2050,10 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
1887
2050
  reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
1888
2051
  reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
1889
2052
  }
1890
-
1891
2053
  }
1892
2054
 
1893
2055
  template <typename type4x4>
1894
2056
  void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
1895
-
1896
2057
  device const uint16_t * qs = ((device const uint16_t *)xb + 2);
1897
2058
  const float d1 = il ? (xb->d / 16.h) : xb->d;
1898
2059
  const float d2 = d1 / 256.f;
@@ -1964,7 +2125,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
1964
2125
  for (int i = 0; i < 16; ++i) {
1965
2126
  reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
1966
2127
  }
1967
-
1968
2128
  #else
1969
2129
  float kcoef = il&1 ? 1.f/16.f : 1.f;
1970
2130
  uint16_t kmask = il&1 ? 0xF0 : 0x0F;
@@ -2008,7 +2168,6 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
2008
2168
  for (int i = 0; i < 16; ++i) {
2009
2169
  reg[i/4][i%4] = dl * (q[i] & mask) - ml;
2010
2170
  }
2011
-
2012
2171
  }
2013
2172
 
2014
2173
  template <typename type4x4>
@@ -2110,22 +2269,25 @@ kernel void kernel_get_rows(
2110
2269
  // each block_q contains 16*nl weights
2111
2270
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
2112
2271
  kernel void kernel_mul_mm(device const uchar * src0,
2113
- device const float * src1,
2114
- device float * dst,
2115
- constant int64_t & ne00,
2116
- constant int64_t & ne02,
2117
- constant int64_t & nb01,
2118
- constant int64_t & nb02,
2119
- constant int64_t & ne12,
2120
- constant int64_t & ne0,
2121
- constant int64_t & ne1,
2122
- constant uint & gqa,
2123
- threadgroup uchar * shared_memory [[threadgroup(0)]],
2124
- uint3 tgpig[[threadgroup_position_in_grid]],
2125
- uint tiitg[[thread_index_in_threadgroup]],
2126
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2127
-
2128
- threadgroup half * sa = ((threadgroup half *)shared_memory);
2272
+ device const uchar * src1,
2273
+ device float * dst,
2274
+ constant int64_t & ne00,
2275
+ constant int64_t & ne02,
2276
+ constant int64_t & nb01,
2277
+ constant int64_t & nb02,
2278
+ constant int64_t & ne12,
2279
+ constant int64_t & nb10,
2280
+ constant int64_t & nb11,
2281
+ constant int64_t & nb12,
2282
+ constant int64_t & ne0,
2283
+ constant int64_t & ne1,
2284
+ constant uint & gqa,
2285
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
2286
+ uint3 tgpig[[threadgroup_position_in_grid]],
2287
+ uint tiitg[[thread_index_in_threadgroup]],
2288
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
2289
+
2290
+ threadgroup half * sa = (threadgroup half *)(shared_memory);
2129
2291
  threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
2130
2292
 
2131
2293
  const uint r0 = tgpig.y;
@@ -2138,7 +2300,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
2138
2300
  short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
2139
2301
  short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
2140
2302
 
2141
- simdgroup_half8x8 ma[4];
2303
+ simdgroup_half8x8 ma[4];
2142
2304
  simdgroup_float8x8 mb[2];
2143
2305
  simdgroup_float8x8 c_res[8];
2144
2306
  for (int i = 0; i < 8; i++){
@@ -2146,10 +2308,15 @@ kernel void kernel_mul_mm(device const uchar * src0,
2146
2308
  }
2147
2309
 
2148
2310
  short il = (tiitg % THREAD_PER_ROW);
2149
- uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
2150
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
2151
- device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
2152
- + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
2311
+
2312
+ uint offset0 = im/gqa*nb02;
2313
+ ushort offset1 = il/nl;
2314
+
2315
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
2316
+ device const float * y = (device const float *)(src1
2317
+ + nb12 * im
2318
+ + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
2319
+ + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
2153
2320
 
2154
2321
  for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
2155
2322
  //load data and store to threadgroup memory
@@ -2229,6 +2396,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
2229
2396
  typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
2230
2397
  constant uint64_t &, constant uint64_t &, uint, uint, uint);
2231
2398
 
2399
+ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
2232
2400
  template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
2233
2401
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
2234
2402
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
@@ -2239,14 +2407,28 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
2239
2407
  template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
2240
2408
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
2241
2409
 
2242
- typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
2243
- constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
2244
- constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
2245
-
2246
- template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
2247
- template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
2248
- template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
2249
- template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
2410
+ typedef void (mat_mm_t)(
2411
+ device const uchar * src0,
2412
+ device const uchar * src1,
2413
+ device float * dst,
2414
+ constant int64_t & ne00,
2415
+ constant int64_t & ne02,
2416
+ constant int64_t & nb01,
2417
+ constant int64_t & nb02,
2418
+ constant int64_t & ne12,
2419
+ constant int64_t & nb10,
2420
+ constant int64_t & nb11,
2421
+ constant int64_t & nb12,
2422
+ constant int64_t & ne0,
2423
+ constant int64_t & ne1,
2424
+ constant uint & gqa,
2425
+ threadgroup uchar *, uint3, uint, uint);
2426
+
2427
+ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
2428
+ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
2429
+ template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
2430
+ template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
2431
+ template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
2250
2432
  template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
2251
2433
  template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
2252
2434
  template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
@@ -847,7 +847,7 @@ std::array<std::string, 2> mul_str_values = {
847
847
  "mul_f32", "float"
848
848
  };
849
849
 
850
- std::string& replace(std::string& s, const std::string& from, const std::string& to) {
850
+ static std::string& replace(std::string& s, const std::string& from, const std::string& to) {
851
851
  size_t pos = 0;
852
852
  while ((pos = s.find(from, pos)) != std::string::npos) {
853
853
  s.replace(pos, from.length(), to);
@@ -856,7 +856,7 @@ std::string& replace(std::string& s, const std::string& from, const std::string&
856
856
  return s;
857
857
  }
858
858
 
859
- std::string generate_kernels() {
859
+ static std::string generate_kernels() {
860
860
  std::stringstream src;
861
861
  src << program_source << '\n';
862
862
  src << k_quants_source << '\n';
@@ -1788,7 +1788,7 @@ bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
1788
1788
  return false;
1789
1789
  }
1790
1790
 
1791
- bool ggml_cl_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
1791
+ static bool ggml_cl_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
1792
1792
  // If device doesn't support FP16
1793
1793
  if (!fp16_support) {
1794
1794
  return false;