llama_cpp 0.5.1 → 0.5.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +15 -3
- data/examples/prompt_jp.txt +1 -1
- data/ext/llama_cpp/extconf.rb +1 -1
- data/ext/llama_cpp/llama_cpp.cpp +32 -2
- data/ext/llama_cpp/src/ggml-alloc.c +6 -11
- data/ext/llama_cpp/src/ggml-cuda.cu +1108 -699
- data/ext/llama_cpp/src/ggml-metal.m +93 -24
- data/ext/llama_cpp/src/ggml-metal.metal +407 -174
- data/ext/llama_cpp/src/ggml-opencl.cpp +3 -3
- data/ext/llama_cpp/src/ggml.c +75 -43
- data/ext/llama_cpp/src/ggml.h +42 -32
- data/ext/llama_cpp/src/k_quants.c +4 -1
- data/ext/llama_cpp/src/llama.cpp +1040 -201
- data/ext/llama_cpp/src/llama.h +13 -7
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +4 -0
- metadata +2 -2
@@ -38,7 +38,7 @@ kernel void kernel_add_row(
|
|
38
38
|
device const float4 * src0,
|
39
39
|
device const float4 * src1,
|
40
40
|
device float4 * dst,
|
41
|
-
constant
|
41
|
+
constant int64_t & nb,
|
42
42
|
uint tpig[[thread_position_in_grid]]) {
|
43
43
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
44
44
|
}
|
@@ -63,18 +63,18 @@ kernel void kernel_mul_row(
|
|
63
63
|
}
|
64
64
|
|
65
65
|
kernel void kernel_scale(
|
66
|
-
device const
|
67
|
-
device
|
66
|
+
device const float4 * src0,
|
67
|
+
device float4 * dst,
|
68
68
|
constant float & scale,
|
69
69
|
uint tpig[[thread_position_in_grid]]) {
|
70
70
|
dst[tpig] = src0[tpig] * scale;
|
71
71
|
}
|
72
72
|
|
73
73
|
kernel void kernel_silu(
|
74
|
-
device const
|
75
|
-
device
|
74
|
+
device const float4 * src0,
|
75
|
+
device float4 * dst,
|
76
76
|
uint tpig[[thread_position_in_grid]]) {
|
77
|
-
|
77
|
+
device const float4 & x = src0[tpig];
|
78
78
|
dst[tpig] = x / (1.0f + exp(-x));
|
79
79
|
}
|
80
80
|
|
@@ -89,10 +89,10 @@ constant float GELU_COEF_A = 0.044715f;
|
|
89
89
|
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
90
90
|
|
91
91
|
kernel void kernel_gelu(
|
92
|
-
device const
|
93
|
-
device
|
92
|
+
device const float4 * src0,
|
93
|
+
device float4 * dst,
|
94
94
|
uint tpig[[thread_position_in_grid]]) {
|
95
|
-
|
95
|
+
device const float4 & x = src0[tpig];
|
96
96
|
|
97
97
|
// BEWARE !!!
|
98
98
|
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
@@ -107,7 +107,6 @@ kernel void kernel_soft_max(
|
|
107
107
|
constant int64_t & ne00,
|
108
108
|
constant int64_t & ne01,
|
109
109
|
constant int64_t & ne02,
|
110
|
-
threadgroup float * buf [[threadgroup(0)]],
|
111
110
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
112
111
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
113
112
|
uint3 ntg[[threads_per_threadgroup]]) {
|
@@ -119,61 +118,67 @@ kernel void kernel_soft_max(
|
|
119
118
|
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
120
119
|
|
121
120
|
// parallel max
|
122
|
-
|
123
|
-
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
124
|
-
|
121
|
+
float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
|
122
|
+
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
|
123
|
+
lmax = MAX(lmax, psrc0[i00]);
|
125
124
|
}
|
126
|
-
|
127
|
-
// reduce
|
128
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
129
|
-
for (uint i = ntg[0]/2; i > 0; i /= 2) {
|
130
|
-
if (tpitg[0] < i) {
|
131
|
-
buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
|
132
|
-
}
|
133
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
134
|
-
}
|
135
|
-
|
136
|
-
//// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
|
137
|
-
// the loop, and when that is done, buf[0] has the correct (synchronized) value
|
138
|
-
//if (tpitg[0] == 0) {
|
139
|
-
// buf[0] = buf[0];
|
140
|
-
//}
|
141
|
-
|
142
|
-
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
143
|
-
|
144
|
-
const float max = buf[0];
|
125
|
+
const float max = simd_max(lmax);
|
145
126
|
|
146
127
|
// parallel sum
|
147
|
-
|
128
|
+
float lsum = 0.0f;
|
148
129
|
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
149
130
|
const float exp_psrc0 = exp(psrc0[i00] - max);
|
150
|
-
|
131
|
+
lsum += exp_psrc0;
|
151
132
|
// Remember the result of exp here. exp is expensive, so we really do not
|
152
133
|
// whish to compute it twice.
|
153
134
|
pdst[i00] = exp_psrc0;
|
154
135
|
}
|
155
136
|
|
156
|
-
|
157
|
-
|
158
|
-
for (
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
137
|
+
const float sum = simd_sum(lsum);
|
138
|
+
|
139
|
+
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
140
|
+
pdst[i00] /= sum;
|
141
|
+
}
|
142
|
+
}
|
143
|
+
|
144
|
+
kernel void kernel_soft_max_4(
|
145
|
+
device const float * src0,
|
146
|
+
device float * dst,
|
147
|
+
constant int64_t & ne00,
|
148
|
+
constant int64_t & ne01,
|
149
|
+
constant int64_t & ne02,
|
150
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
151
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
152
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
153
|
+
const int64_t i03 = tgpig[2];
|
154
|
+
const int64_t i02 = tgpig[1];
|
155
|
+
const int64_t i01 = tgpig[0];
|
156
|
+
|
157
|
+
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
158
|
+
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
159
|
+
|
160
|
+
// parallel max
|
161
|
+
float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
|
162
|
+
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
163
|
+
lmax4 = fmax(lmax4, psrc4[i00]);
|
163
164
|
}
|
165
|
+
float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
164
166
|
|
165
|
-
|
166
|
-
//// broadcast
|
167
|
-
//if (tpitg[0] == 0) {
|
168
|
-
// buf[0] = buf[0];
|
169
|
-
//}
|
167
|
+
const float max = simd_max(lmax);
|
170
168
|
|
171
|
-
//
|
169
|
+
// parallel sum
|
170
|
+
float4 lsum4 = 0.0f;
|
171
|
+
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
172
|
+
const float4 exp_psrc4 = exp(psrc4[i00] - max);
|
173
|
+
lsum4 += exp_psrc4;
|
174
|
+
pdst4[i00] = exp_psrc4;
|
175
|
+
}
|
176
|
+
float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
172
177
|
|
173
|
-
const float sum =
|
178
|
+
const float sum = simd_sum(lsum);
|
174
179
|
|
175
|
-
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
176
|
-
|
180
|
+
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
181
|
+
pdst4[i00] /= sum;
|
177
182
|
}
|
178
183
|
}
|
179
184
|
|
@@ -192,6 +197,33 @@ kernel void kernel_diag_mask_inf(
|
|
192
197
|
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
|
193
198
|
} else {
|
194
199
|
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
|
200
|
+
}
|
201
|
+
}
|
202
|
+
|
203
|
+
kernel void kernel_diag_mask_inf_8(
|
204
|
+
device const float4 * src0,
|
205
|
+
device float4 * dst,
|
206
|
+
constant int64_t & ne00,
|
207
|
+
constant int64_t & ne01,
|
208
|
+
constant int & n_past,
|
209
|
+
uint3 tpig[[thread_position_in_grid]]) {
|
210
|
+
|
211
|
+
const int64_t i = 2*tpig[0];
|
212
|
+
|
213
|
+
dst[i+0] = src0[i+0];
|
214
|
+
dst[i+1] = src0[i+1];
|
215
|
+
int64_t i4 = 4*i;
|
216
|
+
const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
|
217
|
+
const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
|
218
|
+
const int64_t i00 = i4;
|
219
|
+
for (int k = 3; k >= 0; --k) {
|
220
|
+
if (i00 + 4 + k <= n_past + i01) {
|
221
|
+
break;
|
222
|
+
}
|
223
|
+
dst[i+1][k] = -INFINITY;
|
224
|
+
if (i00 + k > n_past + i01) {
|
225
|
+
dst[i][k] = -INFINITY;
|
226
|
+
}
|
195
227
|
}
|
196
228
|
}
|
197
229
|
|
@@ -491,6 +523,79 @@ kernel void kernel_mul_mat_q8_0_f32(
|
|
491
523
|
}
|
492
524
|
}
|
493
525
|
|
526
|
+
#define N_F32_F32 4
|
527
|
+
|
528
|
+
kernel void kernel_mul_mat_f32_f32(
|
529
|
+
device const char * src0,
|
530
|
+
device const char * src1,
|
531
|
+
device float * dst,
|
532
|
+
constant int64_t & ne00,
|
533
|
+
constant int64_t & ne01,
|
534
|
+
constant int64_t & ne02,
|
535
|
+
constant uint64_t & nb00,
|
536
|
+
constant uint64_t & nb01,
|
537
|
+
constant uint64_t & nb02,
|
538
|
+
constant int64_t & ne10,
|
539
|
+
constant int64_t & ne11,
|
540
|
+
constant int64_t & ne12,
|
541
|
+
constant uint64_t & nb10,
|
542
|
+
constant uint64_t & nb11,
|
543
|
+
constant uint64_t & nb12,
|
544
|
+
constant int64_t & ne0,
|
545
|
+
constant int64_t & ne1,
|
546
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
547
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
548
|
+
|
549
|
+
const int64_t r0 = tgpig.x;
|
550
|
+
const int64_t rb = tgpig.y*N_F32_F32;
|
551
|
+
const int64_t im = tgpig.z;
|
552
|
+
|
553
|
+
device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
554
|
+
|
555
|
+
if (ne00 < 128) {
|
556
|
+
for (int row = 0; row < N_F32_F32; ++row) {
|
557
|
+
int r1 = rb + row;
|
558
|
+
if (r1 >= ne11) {
|
559
|
+
break;
|
560
|
+
}
|
561
|
+
|
562
|
+
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
563
|
+
|
564
|
+
float sumf = 0;
|
565
|
+
for (int i = tiisg; i < ne00; i += 32) {
|
566
|
+
sumf += (float) x[i] * (float) y[i];
|
567
|
+
}
|
568
|
+
|
569
|
+
float all_sum = simd_sum(sumf);
|
570
|
+
if (tiisg == 0) {
|
571
|
+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
572
|
+
}
|
573
|
+
}
|
574
|
+
} else {
|
575
|
+
device const float4 * x4 = (device const float4 *)x;
|
576
|
+
for (int row = 0; row < N_F32_F32; ++row) {
|
577
|
+
int r1 = rb + row;
|
578
|
+
if (r1 >= ne11) {
|
579
|
+
break;
|
580
|
+
}
|
581
|
+
|
582
|
+
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
583
|
+
device const float4 * y4 = (device const float4 *) y;
|
584
|
+
|
585
|
+
float sumf = 0;
|
586
|
+
for (int i = tiisg; i < ne00/4; i += 32) {
|
587
|
+
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
588
|
+
}
|
589
|
+
|
590
|
+
float all_sum = simd_sum(sumf);
|
591
|
+
if (tiisg == 0) {
|
592
|
+
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
593
|
+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
594
|
+
}
|
595
|
+
}
|
596
|
+
}
|
597
|
+
}
|
598
|
+
|
494
599
|
kernel void kernel_mul_mat_f16_f32_1row(
|
495
600
|
device const char * src0,
|
496
601
|
device const char * src1,
|
@@ -616,6 +721,49 @@ kernel void kernel_mul_mat_f16_f32(
|
|
616
721
|
}
|
617
722
|
}
|
618
723
|
|
724
|
+
// Assumes row size (ne00) is a multiple of 4
|
725
|
+
kernel void kernel_mul_mat_f16_f32_l4(
|
726
|
+
device const char * src0,
|
727
|
+
device const char * src1,
|
728
|
+
device float * dst,
|
729
|
+
constant int64_t & ne00,
|
730
|
+
constant int64_t & ne01,
|
731
|
+
constant int64_t & ne02,
|
732
|
+
constant uint64_t & nb00,
|
733
|
+
constant uint64_t & nb01,
|
734
|
+
constant uint64_t & nb02,
|
735
|
+
constant int64_t & ne10,
|
736
|
+
constant int64_t & ne11,
|
737
|
+
constant int64_t & ne12,
|
738
|
+
constant uint64_t & nb10,
|
739
|
+
constant uint64_t & nb11,
|
740
|
+
constant uint64_t & nb12,
|
741
|
+
constant int64_t & ne0,
|
742
|
+
constant int64_t & ne1,
|
743
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
744
|
+
uint tiisg[[thread_index_in_simdgroup]]) {
|
745
|
+
|
746
|
+
const int nrows = ne11;
|
747
|
+
const int64_t r0 = tgpig.x;
|
748
|
+
const int64_t im = tgpig.z;
|
749
|
+
|
750
|
+
device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
751
|
+
|
752
|
+
for (int r1 = 0; r1 < nrows; ++r1) {
|
753
|
+
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
|
754
|
+
|
755
|
+
float sumf = 0;
|
756
|
+
for (int i = tiisg; i < ne00/4; i += 32) {
|
757
|
+
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
758
|
+
}
|
759
|
+
|
760
|
+
float all_sum = simd_sum(sumf);
|
761
|
+
if (tiisg == 0) {
|
762
|
+
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
763
|
+
}
|
764
|
+
}
|
765
|
+
}
|
766
|
+
|
619
767
|
kernel void kernel_alibi_f32(
|
620
768
|
device const float * src0,
|
621
769
|
device float * dst,
|
@@ -1123,31 +1271,40 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1123
1271
|
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
|
1124
1272
|
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
1125
1273
|
|
1126
|
-
float yl[
|
1274
|
+
float yl[32];
|
1127
1275
|
|
1128
|
-
const uint16_t kmask1 =
|
1276
|
+
const uint16_t kmask1 = 0x3030;
|
1129
1277
|
const uint16_t kmask2 = 0x0f0f;
|
1130
1278
|
|
1131
|
-
const int tid = tiisg/
|
1132
|
-
const int ix = tiisg%
|
1133
|
-
const int ip = tid/
|
1134
|
-
const int il = tid/2
|
1279
|
+
const int tid = tiisg/4;
|
1280
|
+
const int ix = tiisg%4;
|
1281
|
+
const int ip = tid/4; // 0 or 1
|
1282
|
+
const int il = 2*((tid%4)/2); // 0 or 2
|
1135
1283
|
const int ir = tid%2;
|
1136
1284
|
const int n = 8;
|
1137
1285
|
const int l0 = n*ir;
|
1138
1286
|
|
1139
|
-
|
1140
|
-
|
1287
|
+
// One would think that the Metal compiler would figure out that ip and il can only have
|
1288
|
+
// 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
|
1289
|
+
// with these two tales.
|
1290
|
+
//
|
1291
|
+
// Possible masks for the high bit
|
1292
|
+
const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0
|
1293
|
+
{0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2
|
1294
|
+
{0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0
|
1295
|
+
{0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
|
1296
|
+
|
1297
|
+
// Possible masks for the low 2 bits
|
1298
|
+
const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
|
1299
|
+
|
1300
|
+
const ushort4 hm = mm[2*ip + il/2];
|
1141
1301
|
|
1142
1302
|
const int shift = 2*il;
|
1143
|
-
const
|
1144
|
-
const
|
1145
|
-
const int32_t v1 = 4 << shift;
|
1146
|
-
const int32_t v2 = 1024 << shift;
|
1303
|
+
const float v1 = il == 0 ? 4.f : 64.f;
|
1304
|
+
const float v2 = 4.f * v1;
|
1147
1305
|
|
1148
1306
|
const uint16_t s_shift1 = 4*ip;
|
1149
|
-
const uint16_t s_shift2 = s_shift1 +
|
1150
|
-
const int ik = 4 + (il%2);
|
1307
|
+
const uint16_t s_shift2 = s_shift1 + il;
|
1151
1308
|
|
1152
1309
|
const int q_offset = 32*ip + l0;
|
1153
1310
|
const int y_offset = 128*ip + 32*il + l0;
|
@@ -1156,12 +1313,19 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1156
1313
|
|
1157
1314
|
device const float * y1 = yy + ix*QK_K + y_offset;
|
1158
1315
|
|
1159
|
-
|
1160
|
-
|
1316
|
+
uint32_t scales32, aux32;
|
1317
|
+
thread uint16_t * scales16 = (thread uint16_t *)&scales32;
|
1318
|
+
thread const int8_t * scales = (thread const int8_t *)&scales32;
|
1319
|
+
|
1320
|
+
float sumf1[2] = {0.f};
|
1321
|
+
float sumf2[2] = {0.f};
|
1322
|
+
for (int i = ix; i < nb; i += 4) {
|
1161
1323
|
|
1162
1324
|
for (int l = 0; l < 8; ++l) {
|
1163
|
-
yl[l+0] = y1[l+ 0];
|
1164
|
-
yl[l+8] = y1[l+16];
|
1325
|
+
yl[l+ 0] = y1[l+ 0];
|
1326
|
+
yl[l+ 8] = y1[l+16];
|
1327
|
+
yl[l+16] = y1[l+32];
|
1328
|
+
yl[l+24] = y1[l+48];
|
1165
1329
|
}
|
1166
1330
|
|
1167
1331
|
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
|
@@ -1172,27 +1336,43 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1172
1336
|
for (int row = 0; row < 2; ++row) {
|
1173
1337
|
|
1174
1338
|
const float d_all = (float)dh[0];
|
1175
|
-
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
|
1176
1339
|
|
1177
|
-
|
1340
|
+
scales16[0] = a[4];
|
1341
|
+
scales16[1] = a[5];
|
1342
|
+
aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
|
1343
|
+
scales16[0] = a[il+0];
|
1344
|
+
scales16[1] = a[il+1];
|
1345
|
+
scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
|
1346
|
+
|
1347
|
+
float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
|
1178
1348
|
for (int l = 0; l < n; l += 2) {
|
1179
|
-
const
|
1180
|
-
s1 += yl[l+0] * (
|
1181
|
-
s2 += yl[l+1] * (
|
1349
|
+
const int32_t qs = q[l/2];
|
1350
|
+
s1 += yl[l+0] * (qs & qm[il/2][0]);
|
1351
|
+
s2 += yl[l+1] * (qs & qm[il/2][1]);
|
1352
|
+
s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
|
1353
|
+
s4 += yl[l+16] * (qs & qm[il/2][2]);
|
1354
|
+
s5 += yl[l+17] * (qs & qm[il/2][3]);
|
1355
|
+
s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
|
1182
1356
|
}
|
1183
|
-
float
|
1184
|
-
|
1185
|
-
|
1357
|
+
float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
|
1358
|
+
float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
|
1359
|
+
sumf1[row] += d1 * (scales[0] - 32);
|
1360
|
+
sumf2[row] += d2 * (scales[2] - 32);
|
1186
1361
|
|
1187
|
-
s1 = s2 = 0;
|
1362
|
+
s1 = s2 = s3 = s4 = s5 = s6 = 0;
|
1188
1363
|
for (int l = 0; l < n; l += 2) {
|
1189
|
-
const
|
1190
|
-
s1 += yl[l+8] * (
|
1191
|
-
s2 += yl[l+9] * (
|
1364
|
+
const int32_t qs = q[l/2+8];
|
1365
|
+
s1 += yl[l+8] * (qs & qm[il/2][0]);
|
1366
|
+
s2 += yl[l+9] * (qs & qm[il/2][1]);
|
1367
|
+
s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
|
1368
|
+
s4 += yl[l+24] * (qs & qm[il/2][2]);
|
1369
|
+
s5 += yl[l+25] * (qs & qm[il/2][3]);
|
1370
|
+
s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
|
1192
1371
|
}
|
1193
|
-
|
1194
|
-
|
1195
|
-
|
1372
|
+
d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
|
1373
|
+
d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
|
1374
|
+
sumf1[row] += d1 * (scales[1] - 32);
|
1375
|
+
sumf2[row] += d2 * (scales[3] - 32);
|
1196
1376
|
|
1197
1377
|
q += step;
|
1198
1378
|
h += step;
|
@@ -1201,15 +1381,17 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|
1201
1381
|
|
1202
1382
|
}
|
1203
1383
|
|
1204
|
-
y1 +=
|
1384
|
+
y1 += 4 * QK_K;
|
1205
1385
|
|
1206
1386
|
}
|
1207
1387
|
|
1208
1388
|
for (int row = 0; row < 2; ++row) {
|
1209
|
-
const float sumf = (sumf1[row]
|
1210
|
-
|
1211
|
-
|
1212
|
-
|
1389
|
+
const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
|
1390
|
+
sumf1[row] = simd_sum(sumf);
|
1391
|
+
}
|
1392
|
+
if (tiisg == 0) {
|
1393
|
+
for (int row = 0; row < 2; ++row) {
|
1394
|
+
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
|
1213
1395
|
}
|
1214
1396
|
}
|
1215
1397
|
}
|
@@ -1290,13 +1472,13 @@ kernel void kernel_mul_mat_q4_K_f32(
|
|
1290
1472
|
device const float * src1,
|
1291
1473
|
device float * dst,
|
1292
1474
|
constant int64_t & ne00,
|
1293
|
-
constant int64_t & ne01[[buffer(4)]],
|
1294
|
-
constant int64_t & ne02[[buffer(5)]],
|
1295
|
-
constant int64_t & ne10[[buffer(9)]],
|
1296
|
-
constant int64_t & ne12[[buffer(11)]],
|
1297
|
-
constant int64_t & ne0[[buffer(15)]],
|
1298
|
-
constant int64_t & ne1[[buffer(16)]],
|
1299
|
-
constant uint & gqa[[buffer(17)]],
|
1475
|
+
constant int64_t & ne01 [[buffer(4)]],
|
1476
|
+
constant int64_t & ne02 [[buffer(5)]],
|
1477
|
+
constant int64_t & ne10 [[buffer(9)]],
|
1478
|
+
constant int64_t & ne12 [[buffer(11)]],
|
1479
|
+
constant int64_t & ne0 [[buffer(15)]],
|
1480
|
+
constant int64_t & ne1 [[buffer(16)]],
|
1481
|
+
constant uint & gqa [[buffer(17)]],
|
1300
1482
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1301
1483
|
uint tiisg[[thread_index_in_simdgroup]],
|
1302
1484
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -1564,17 +1746,25 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|
1564
1746
|
sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
|
1565
1747
|
sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
|
1566
1748
|
|
1567
|
-
float4
|
1749
|
+
float4 acc1 = {0.f};
|
1750
|
+
float4 acc2 = {0.f};
|
1568
1751
|
for (int l = 0; l < n; ++l) {
|
1569
1752
|
uint8_t h = qh[l];
|
1570
|
-
|
1571
|
-
|
1572
|
-
|
1573
|
-
|
1753
|
+
acc1[0] += yl[l+0] * (q1[l] & 0x0F);
|
1754
|
+
acc1[1] += yl[l+8] * (q1[l] & 0xF0);
|
1755
|
+
acc1[2] += yh[l+0] * (q2[l] & 0x0F);
|
1756
|
+
acc1[3] += yh[l+8] * (q2[l] & 0xF0);
|
1757
|
+
acc2[0] += h & hm1 ? yl[l+0] : 0.f;
|
1758
|
+
acc2[1] += h & hm2 ? yl[l+8] : 0.f;
|
1759
|
+
acc2[2] += h & hm3 ? yh[l+0] : 0.f;
|
1760
|
+
acc2[3] += h & hm4 ? yh[l+8] : 0.f;
|
1574
1761
|
}
|
1575
1762
|
const float dall = dh[0];
|
1576
1763
|
const float dmin = dh[1];
|
1577
|
-
sumf[row] += dall * (
|
1764
|
+
sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
|
1765
|
+
sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
|
1766
|
+
sc8[4] * (acc1[2] + 16.f*acc2[2]) +
|
1767
|
+
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
|
1578
1768
|
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
1579
1769
|
|
1580
1770
|
q1 += step;
|
@@ -1747,6 +1937,15 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|
1747
1937
|
|
1748
1938
|
//============================= templates and their specializations =============================
|
1749
1939
|
|
1940
|
+
// NOTE: this is not dequantizing - we are simply fitting the template
|
1941
|
+
template <typename type4x4>
|
1942
|
+
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
1943
|
+
float4x4 temp = *(((device float4x4 *)src));
|
1944
|
+
for (int i = 0; i < 16; i++){
|
1945
|
+
reg[i/4][i%4] = temp[i/4][i%4];
|
1946
|
+
}
|
1947
|
+
}
|
1948
|
+
|
1750
1949
|
template <typename type4x4>
|
1751
1950
|
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
|
1752
1951
|
half4x4 temp = *(((device half4x4 *)src));
|
@@ -1758,28 +1957,30 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
|
|
1758
1957
|
template <typename type4x4>
|
1759
1958
|
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
1760
1959
|
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
1761
|
-
const
|
1762
|
-
const
|
1960
|
+
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
1961
|
+
const float d2 = d1 / 256.f;
|
1962
|
+
const float md = -8.h * xb->d;
|
1763
1963
|
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
1764
|
-
const ushort mask1 =
|
1964
|
+
const ushort mask1 = mask0 << 8;
|
1765
1965
|
|
1766
1966
|
for (int i=0;i<8;i++) {
|
1767
|
-
reg[i/2][2*(i%2)]
|
1768
|
-
reg[i/2][2*(i%2)+1] = (
|
1967
|
+
reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
|
1968
|
+
reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
|
1769
1969
|
}
|
1770
1970
|
}
|
1771
1971
|
|
1772
1972
|
template <typename type4x4>
|
1773
1973
|
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
|
1774
1974
|
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
1775
|
-
const
|
1776
|
-
const
|
1975
|
+
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
1976
|
+
const float d2 = d1 / 256.f;
|
1977
|
+
const float m = xb->m;
|
1777
1978
|
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
1778
|
-
const ushort mask1 =
|
1979
|
+
const ushort mask1 = mask0 << 8;
|
1779
1980
|
|
1780
1981
|
for (int i=0;i<8;i++) {
|
1781
|
-
reg[i/2][2*(i%2)]
|
1782
|
-
reg[i/2][2*(i%2)+1] = ((
|
1982
|
+
reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
|
1983
|
+
reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
|
1783
1984
|
}
|
1784
1985
|
}
|
1785
1986
|
|
@@ -1815,7 +2016,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
|
|
1815
2016
|
|
1816
2017
|
template <typename type4x4>
|
1817
2018
|
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
|
1818
|
-
const
|
2019
|
+
const half d_all = xb->d;
|
1819
2020
|
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
1820
2021
|
device const uint8_t * h = (device const uint8_t *)xb->hmask;
|
1821
2022
|
device const int8_t * scales = (device const int8_t *)xb->scales;
|
@@ -1828,16 +2029,18 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|
1828
2029
|
((il/4)>0 ? 12 : 3);
|
1829
2030
|
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
|
1830
2031
|
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
|
1831
|
-
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
|
1832
|
-
|
1833
|
-
|
2032
|
+
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
|
2033
|
+
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
|
2034
|
+
half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
|
2035
|
+
const half ml = 4.h * dl;
|
1834
2036
|
|
1835
|
-
il = (il/2)
|
1836
|
-
|
1837
|
-
uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
2037
|
+
il = (il/2) & 3;
|
2038
|
+
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
2039
|
+
const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
2040
|
+
dl *= coef;
|
1838
2041
|
|
1839
2042
|
for (int i = 0; i < 16; ++i) {
|
1840
|
-
reg[i/4][i%4] =
|
2043
|
+
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
|
1841
2044
|
}
|
1842
2045
|
#else
|
1843
2046
|
float kcoef = il&1 ? 1.f/16.f : 1.f;
|
@@ -1852,26 +2055,31 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|
1852
2055
|
#endif
|
1853
2056
|
}
|
1854
2057
|
|
2058
|
+
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
|
2059
|
+
return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
|
2060
|
+
: uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
|
2061
|
+
}
|
2062
|
+
|
1855
2063
|
template <typename type4x4>
|
1856
2064
|
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
|
1857
|
-
device const
|
2065
|
+
device const uchar * q = xb->qs;
|
1858
2066
|
|
1859
2067
|
#if QK_K == 256
|
1860
|
-
const float d = (float)(xb->d);
|
1861
|
-
const float min = (float)(xb->dmin);
|
1862
2068
|
short is = (il/4) * 2;
|
1863
2069
|
q = q + (il/4) * 32 + 16 * (il&1);
|
1864
|
-
il = il
|
1865
|
-
const
|
1866
|
-
const
|
1867
|
-
const
|
2070
|
+
il = il & 3;
|
2071
|
+
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
2072
|
+
const half d = il < 2 ? xb->d : xb->d / 16.h;
|
2073
|
+
const half min = xb->dmin;
|
2074
|
+
const half dl = d * sc[0];
|
2075
|
+
const half ml = min * sc[1];
|
1868
2076
|
#else
|
1869
2077
|
q = q + 16 * (il&1);
|
1870
2078
|
device const uint8_t * s = xb->scales;
|
1871
2079
|
device const half2 * dh = (device const half2 *)xb->d;
|
1872
2080
|
const float2 d = (float2)dh[0];
|
1873
2081
|
const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
|
1874
|
-
const float ml = il<2 ? d[1] * (s[0]>>4) : d[1
|
2082
|
+
const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
|
1875
2083
|
#endif
|
1876
2084
|
const ushort mask = il<2 ? 0x0F : 0xF0;
|
1877
2085
|
for (int i = 0; i < 16; ++i) {
|
@@ -1885,19 +2093,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
1885
2093
|
device const uint8_t * qh = xb->qh;
|
1886
2094
|
|
1887
2095
|
#if QK_K == 256
|
1888
|
-
const float d = (float)(xb->d);
|
1889
|
-
const float min = (float)(xb->dmin);
|
1890
2096
|
short is = (il/4) * 2;
|
1891
2097
|
q = q + 32 * (il/4) + 16 * (il&1);
|
1892
2098
|
qh = qh + 16 * (il&1);
|
1893
2099
|
uint8_t ul = 1 << (il/2);
|
1894
|
-
il = il
|
1895
|
-
const
|
1896
|
-
const
|
1897
|
-
const
|
2100
|
+
il = il & 3;
|
2101
|
+
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
2102
|
+
const half d = il < 2 ? xb->d : xb->d / 16.h;
|
2103
|
+
const half min = xb->dmin;
|
2104
|
+
const half dl = d * sc[0];
|
2105
|
+
const half ml = min * sc[1];
|
1898
2106
|
|
1899
|
-
const ushort mask
|
1900
|
-
const
|
2107
|
+
const ushort mask = il<2 ? 0x0F : 0xF0;
|
2108
|
+
const half qh_val = il<2 ? 16.h : 256.h;
|
1901
2109
|
for (int i = 0; i < 16; ++i) {
|
1902
2110
|
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
1903
2111
|
}
|
@@ -1916,7 +2124,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
1916
2124
|
|
1917
2125
|
template <typename type4x4>
|
1918
2126
|
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
|
1919
|
-
const
|
2127
|
+
const half d_all = xb->d;
|
1920
2128
|
device const uint8_t * ql = (device const uint8_t *)xb->ql;
|
1921
2129
|
device const uint8_t * qh = (device const uint8_t *)xb->qh;
|
1922
2130
|
device const int8_t * scales = (device const int8_t *)xb->scales;
|
@@ -1924,19 +2132,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
1924
2132
|
#if QK_K == 256
|
1925
2133
|
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
1926
2134
|
qh = qh + 32*(il/8) + 16*(il&1);
|
1927
|
-
|
1928
|
-
il = (il/2)
|
2135
|
+
half sc = scales[(il%2) + 2 * ((il/2))];
|
2136
|
+
il = (il/2) & 3;
|
1929
2137
|
#else
|
1930
2138
|
ql = ql + 16 * (il&1);
|
1931
|
-
|
2139
|
+
half sc = scales[il];
|
1932
2140
|
#endif
|
2141
|
+
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
2142
|
+
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
2143
|
+
const half coef = il>1 ? 1.f/16.h : 1.h;
|
2144
|
+
const half ml = d_all * sc * 32.h;
|
2145
|
+
const half dl = d_all * sc * coef;
|
1933
2146
|
for (int i = 0; i < 16; ++i) {
|
1934
|
-
|
1935
|
-
|
1936
|
-
|
1937
|
-
float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
|
1938
|
-
((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
|
1939
|
-
reg[i/4][i%4] = d_all * sc * q * coef;
|
2147
|
+
const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
|
2148
|
+
: ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
|
2149
|
+
reg[i/4][i%4] = dl * q - ml;
|
1940
2150
|
}
|
1941
2151
|
}
|
1942
2152
|
|
@@ -1976,22 +2186,25 @@ kernel void kernel_get_rows(
|
|
1976
2186
|
// each block_q contains 16*nl weights
|
1977
2187
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
1978
2188
|
kernel void kernel_mul_mm(device const uchar * src0,
|
1979
|
-
|
1980
|
-
|
1981
|
-
|
1982
|
-
|
1983
|
-
|
1984
|
-
|
1985
|
-
|
1986
|
-
|
1987
|
-
|
1988
|
-
|
1989
|
-
|
1990
|
-
|
1991
|
-
|
1992
|
-
|
1993
|
-
|
1994
|
-
|
2189
|
+
device const uchar * src1,
|
2190
|
+
device float * dst,
|
2191
|
+
constant int64_t & ne00,
|
2192
|
+
constant int64_t & ne02,
|
2193
|
+
constant int64_t & nb01,
|
2194
|
+
constant int64_t & nb02,
|
2195
|
+
constant int64_t & ne12,
|
2196
|
+
constant int64_t & nb10,
|
2197
|
+
constant int64_t & nb11,
|
2198
|
+
constant int64_t & nb12,
|
2199
|
+
constant int64_t & ne0,
|
2200
|
+
constant int64_t & ne1,
|
2201
|
+
constant uint & gqa,
|
2202
|
+
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
2203
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
2204
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
2205
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
2206
|
+
|
2207
|
+
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
1995
2208
|
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
1996
2209
|
|
1997
2210
|
const uint r0 = tgpig.y;
|
@@ -2004,7 +2217,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2004
2217
|
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
2005
2218
|
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
2006
2219
|
|
2007
|
-
simdgroup_half8x8
|
2220
|
+
simdgroup_half8x8 ma[4];
|
2008
2221
|
simdgroup_float8x8 mb[2];
|
2009
2222
|
simdgroup_float8x8 c_res[8];
|
2010
2223
|
for (int i = 0; i < 8; i++){
|
@@ -2012,10 +2225,15 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2012
2225
|
}
|
2013
2226
|
|
2014
2227
|
short il = (tiitg % THREAD_PER_ROW);
|
2015
|
-
|
2016
|
-
|
2017
|
-
|
2018
|
-
|
2228
|
+
|
2229
|
+
uint offset0 = im/gqa*nb02;
|
2230
|
+
ushort offset1 = il/nl;
|
2231
|
+
|
2232
|
+
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
2233
|
+
device const float * y = (device const float *)(src1
|
2234
|
+
+ nb12 * im
|
2235
|
+
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
|
2236
|
+
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
2019
2237
|
|
2020
2238
|
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
2021
2239
|
//load data and store to threadgroup memory
|
@@ -2095,6 +2313,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
2095
2313
|
typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
|
2096
2314
|
constant uint64_t &, constant uint64_t &, uint, uint, uint);
|
2097
2315
|
|
2316
|
+
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
2098
2317
|
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
2099
2318
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
2100
2319
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
@@ -2105,14 +2324,28 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
|
|
2105
2324
|
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
2106
2325
|
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
2107
2326
|
|
2108
|
-
typedef void (mat_mm_t)(
|
2109
|
-
|
2110
|
-
|
2111
|
-
|
2112
|
-
|
2113
|
-
|
2114
|
-
|
2115
|
-
|
2327
|
+
typedef void (mat_mm_t)(
|
2328
|
+
device const uchar * src0,
|
2329
|
+
device const uchar * src1,
|
2330
|
+
device float * dst,
|
2331
|
+
constant int64_t & ne00,
|
2332
|
+
constant int64_t & ne02,
|
2333
|
+
constant int64_t & nb01,
|
2334
|
+
constant int64_t & nb02,
|
2335
|
+
constant int64_t & ne12,
|
2336
|
+
constant int64_t & nb10,
|
2337
|
+
constant int64_t & nb11,
|
2338
|
+
constant int64_t & nb12,
|
2339
|
+
constant int64_t & ne0,
|
2340
|
+
constant int64_t & ne1,
|
2341
|
+
constant uint & gqa,
|
2342
|
+
threadgroup uchar *, uint3, uint, uint);
|
2343
|
+
|
2344
|
+
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
2345
|
+
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
2346
|
+
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
2347
|
+
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
2348
|
+
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
2116
2349
|
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
2117
2350
|
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
2118
2351
|
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|