llama_cpp 0.5.1 → 0.5.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +15 -3
- data/examples/prompt_jp.txt +1 -1
- data/ext/llama_cpp/extconf.rb +1 -1
- data/ext/llama_cpp/llama_cpp.cpp +32 -2
- data/ext/llama_cpp/src/ggml-alloc.c +6 -11
- data/ext/llama_cpp/src/ggml-cuda.cu +1108 -699
- data/ext/llama_cpp/src/ggml-metal.m +93 -24
- data/ext/llama_cpp/src/ggml-metal.metal +407 -174
- data/ext/llama_cpp/src/ggml-opencl.cpp +3 -3
- data/ext/llama_cpp/src/ggml.c +75 -43
- data/ext/llama_cpp/src/ggml.h +42 -32
- data/ext/llama_cpp/src/k_quants.c +4 -1
- data/ext/llama_cpp/src/llama.cpp +1040 -201
- data/ext/llama_cpp/src/llama.h +13 -7
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +4 -0
- metadata +2 -2
data/ext/llama_cpp/src/ggml-metal.metal

@@ -38,7 +38,7 @@ kernel void kernel_add_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant   int64_t & nb,
+        constant  int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
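The float4 element type means each thread adds four floats at once, and src1 is a single row of nb float4 values broadcast over every row of src0. A minimal CPU-side sketch of the same broadcast-add (reference code written for this page, not from the package):

```cpp
// Scalar reference for kernel_add_row's broadcast: one row of row_len
// elements from src1 is added element-wise to every row of src0.
#include <cstddef>

void add_row_ref(const float *src0, const float *src1, float *dst,
                 std::size_t n, std::size_t row_len) {
    for (std::size_t i = 0; i < n; ++i) {
        dst[i] = src0[i] + src1[i % row_len]; // mirrors src1[tpig % nb]
    }
}
```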
@@ -63,18 +63,18 @@ kernel void kernel_mul_row(
 }
 
 kernel void kernel_scale(
-        device const float * src0,
-        device       float * dst,
+        device const float4 * src0,
+        device       float4 * dst,
         constant     float & scale,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * scale;
 }
 
 kernel void kernel_silu(
-        device const float * src0,
-        device       float * dst,
+        device const float4 * src0,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];
     dst[tpig] = x / (1.0f + exp(-x));
 }
 
@@ -89,10 +89,10 @@ constant float GELU_COEF_A = 0.044715f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 kernel void kernel_gelu(
-        device const float * src0,
-        device       float * dst,
+        device const float4 * src0,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];
 
     // BEWARE !!!
     // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
@@ -107,7 +107,6 @@ kernel void kernel_soft_max(
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
-        threadgroup float  * buf [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
@@ -119,61 +118,67 @@ kernel void kernel_soft_max(
     device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    buf[tpitg[0]] = -INFINITY;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
+    float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+        lmax = MAX(lmax, psrc0[i00]);
     }
-
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
-
-    //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
-    // the loop, and when that is done, buf[0] has the correct (synchronized) value
-    //if (tpitg[0] == 0) {
-    //    buf[0] = buf[0];
-    //}
-
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    const float max = buf[0];
+    const float max = simd_max(lmax);
 
     // parallel sum
-    buf[tpitg[0]] = 0.0f;
+    float lsum = 0.0f;
     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
         const float exp_psrc0 = exp(psrc0[i00] - max);
-        buf[tpitg[0]] += exp_psrc0;
+        lsum += exp_psrc0;
         // Remember the result of exp here. exp is expensive, so we really do not
         // whish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
 
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] += buf[tpitg[0] + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float sum = simd_sum(lsum);
+
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        pdst[i00] /= sum;
+    }
+}
+
+kernel void kernel_soft_max_4(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+    // parallel max
+    float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        lmax4 = fmax(lmax4, psrc4[i00]);
     }
+    float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
-
-    //// broadcast
-    //if (tpitg[0] == 0) {
-    //    buf[0] = buf[0];
-    //}
+    const float max = simd_max(lmax);
 
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    // parallel sum
+    float4 lsum4 = 0.0f;
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        const float4 exp_psrc4 = exp(psrc4[i00] - max);
+        lsum4 += exp_psrc4;
+        pdst4[i00] = exp_psrc4;
+    }
+    float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
-    const float sum = buf[0];
+    const float sum = simd_sum(lsum);
 
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        pdst[i00] /= sum;
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        pdst4[i00] /= sum;
     }
 }
 
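The rewritten soft_max drops the threadgroup buffer and barrier-heavy tree reduction in favor of simd_max/simd_sum, which fold the per-lane partial results across the 32 lanes of a simdgroup in registers. What each threadgroup computes per row is the usual numerically stable softmax; a scalar C++ reference of that math, for orientation only:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>

// Scalar reference for one row: find the max (the kernel's "parallel max"
// phase), subtract it before exp() for numerical stability, cache the exp()
// results in dst, then normalize by the sum (the "parallel sum" phase).
void soft_max_row_ref(const float *src, float *dst, std::size_t ne00) {
    float max = -INFINITY;
    for (std::size_t i = 0; i < ne00; ++i) max = std::max(max, src[i]);
    float sum = 0.0f;
    for (std::size_t i = 0; i < ne00; ++i) {
        dst[i] = std::exp(src[i] - max); // exp is expensive; computed once
        sum += dst[i];
    }
    for (std::size_t i = 0; i < ne00; ++i) dst[i] /= sum;
}
```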
@@ -192,6 +197,33 @@ kernel void kernel_diag_mask_inf(
         dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
     } else {
         dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+    }
+}
+
+kernel void kernel_diag_mask_inf_8(
+        device const float4 * src0,
+        device       float4 * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant       int & n_past,
+        uint3 tpig[[thread_position_in_grid]]) {
+
+    const int64_t i = 2*tpig[0];
+
+    dst[i+0] = src0[i+0];
+    dst[i+1] = src0[i+1];
+    int64_t i4 = 4*i;
+    const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+    const int64_t i01 = i4/(ne00);      i4 -= i01*ne00;
+    const int64_t i00 = i4;
+    for (int k = 3; k >= 0; --k) {
+        if (i00 + 4 + k <= n_past + i01) {
+            break;
+        }
+        dst[i+1][k] = -INFINITY;
+        if (i00 + k > n_past + i01) {
+            dst[i][k] = -INFINITY;
+        }
     }
 }
 
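kernel_diag_mask_inf_8 applies the same causal mask as the scalar kernel but to eight floats (two float4) per thread, copying them first and then back-filling -INFINITY from the highest lane down. The mask rule itself is unchanged; a plain C++ reference of it (again ours, not from the package):

```cpp
#include <cmath>
#include <cstdint>

// Scalar reference for the causal mask: column i00 of row i01 survives
// only while i00 <= n_past + i01; later columns become -INFINITY.
void diag_mask_inf_ref(const float *src, float *dst,
                       int64_t ne00, int64_t ne01, int64_t ne02, int n_past) {
    for (int64_t i02 = 0; i02 < ne02; ++i02)
    for (int64_t i01 = 0; i01 < ne01; ++i01)
    for (int64_t i00 = 0; i00 < ne00; ++i00) {
        const int64_t idx = i02*ne01*ne00 + i01*ne00 + i00;
        dst[idx] = (i00 > n_past + i01) ? -INFINITY : src[idx];
    }
}
```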
@@ -491,6 +523,79 @@ kernel void kernel_mul_mat_q8_0_f32(
     }
 }
 
+#define N_F32_F32 4
+
+kernel void kernel_mul_mat_f32_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F32_F32;
+    const int64_t im = tgpig.z;
+
+    device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F32_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (float) x[i] * (float) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const float4 * x4 = (device const float4 *)x;
+        for (int row = 0; row < N_F32_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12);
+            device const float4 * y4 = (device const float4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
 kernel void kernel_mul_mat_f16_f32_1row(
         device const  char * src0,
         device const  char * src1,
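kernel_mul_mat_f32_f32 assigns one dst row r0 and up to N_F32_F32 src1 rows to each simdgroup; each of the 32 lanes accumulates a strided slice of the dot product and simd_sum folds the partials, with lane 0 writing the result (and picking up the ne00%4 tail in the float4 path). A CPU analogue of that lane-strided pattern, sketched here for orientation:

```cpp
#include <cstddef>

// CPU analogue of the kernel's strided dot product: each "lane" accumulates
// elements lane, lane+32, lane+64, ... and the partial sums are then folded,
// which is what simd_sum(sumf) does in registers on the GPU.
float dot_strided_ref(const float *x, const float *y, std::size_t n) {
    float partial[32] = {0.0f};
    for (std::size_t lane = 0; lane < 32; ++lane) {
        for (std::size_t i = lane; i < n; i += 32) {
            partial[lane] += x[i] * y[i];
        }
    }
    float sum = 0.0f;
    for (float p : partial) sum += p; // the simd_sum step
    return sum;
}
```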
@@ -616,6 +721,49 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
+// Assumes row size (ne00) is a multiple of 4
+kernel void kernel_mul_mat_f16_f32_l4(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int nrows = ne11;
+    const int64_t r0 = tgpig.x;
+    const int64_t im = tgpig.z;
+
+    device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    for (int r1 = 0; r1 < nrows; ++r1) {
+        device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+
+        float sumf = 0;
+        for (int i = tiisg; i < ne00/4; i += 32) {
+            for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+        }
+
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    }
+}
+
 kernel void kernel_alibi_f32(
         device const float * src0,
         device       float * dst,
@@ -1123,31 +1271,40 @@ kernel void kernel_mul_mat_q3_K_f32(
     device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
     device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
 
-    float yl[16];
+    float yl[32];
 
-    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask1 = 0x3030;
     const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = tiisg/2;
-    const int ix  = tiisg%2;
-    const int ip  = tid/8;         // 0 or 1
-    const int il  = tid/2 - 4*ip;  // 0 ... 3
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int ip  = tid/4;          // 0 or 1
+    const int il  = 2*((tid%4)/2);  // 0 or 2
     const int ir  = tid%2;
     const int n   = 8;
     const int l0  = n*ir;
 
-    const uint16_t m1 = 1 << (4*ip + il);
-    const uint16_t m2 = m1 << 8;
+    // One would think that the Metal compiler would figure out that ip and il can only have
+    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+    // with these two tales.
+    //
+    // Possible masks for the high bit
+    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
+                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
+                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
+                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+    // Possible masks for the low 2 bits
+    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+    const ushort4 hm = mm[2*ip + il/2];
 
     const int shift = 2*il;
-    const uint16_t qm1 = 0x0003 << shift;
-    const uint16_t qm2 = 0x0300 << shift;
-    const int32_t v1 = 4 << shift;
-    const int32_t v2 = 1024 << shift;
+    const float   v1 = il == 0 ? 4.f : 64.f;
+    const float   v2 = 4.f * v1;
 
     const uint16_t s_shift1 = 4*ip;
-    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
-    const int ik = 4 + (il%2);
+    const uint16_t s_shift2 = s_shift1 + il;
 
     const int q_offset = 32*ip + l0;
     const int y_offset = 128*ip + 32*il + l0;
@@ -1156,12 +1313,19 @@ kernel void kernel_mul_mat_q3_K_f32(
 
     device const float * y1 = yy + ix*QK_K + y_offset;
 
-    float sumf1[2] = {0.f}, sumf2[2] = {0.f};
-    for (int i = ix; i < nb; i += 2) {
+    uint32_t scales32, aux32;
+    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+    thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+    float sumf1[2] = {0.f};
+    float sumf2[2] = {0.f};
+    for (int i = ix; i < nb; i += 4) {
 
         for (int l = 0; l < 8; ++l) {
-            yl[l+0] = y1[l+ 0];
-            yl[l+8] = y1[l+16];
+            yl[l+ 0] = y1[l+ 0];
+            yl[l+ 8] = y1[l+16];
+            yl[l+16] = y1[l+32];
+            yl[l+24] = y1[l+48];
         }
 
         device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
@@ -1172,27 +1336,43 @@ kernel void kernel_mul_mat_q3_K_f32(
         for (int row = 0; row < 2; ++row) {
 
             const float d_all = (float)dh[0];
-            const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
 
-            float s1 = 0, s2 = 0;
+            scales16[0] = a[4];
+            scales16[1] = a[5];
+            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+            scales16[0] = a[il+0];
+            scales16[1] = a[il+1];
+            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2];
-                s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
-                s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2];
+                s1 += yl[l+0] * (qs & qm[il/2][0]);
+                s2 += yl[l+1] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+                s4 += yl[l+16] * (qs & qm[il/2][2]);
+                s5 += yl[l+17] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
             }
-            float d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[0];
-            sumf2[row] += d;
+            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[0] - 32);
+            sumf2[row] += d2 * (scales[2] - 32);
 
-            s1 = s2 = 0;
+            s1 = s2 = s3 = s4 = s5 = s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2+8];
-                s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
-                s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2+8];
+                s1 += yl[l+8] * (qs & qm[il/2][0]);
+                s2 += yl[l+9] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+                s4 += yl[l+24] * (qs & qm[il/2][2]);
+                s5 += yl[l+25] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
             }
-            d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[1];
-            sumf2[row] += d;
+            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[1] - 32);
+            sumf2[row] += d2 * (scales[3] - 32);
 
             q += step;
             h += step;
@@ -1201,15 +1381,17 @@ kernel void kernel_mul_mat_q3_K_f32(
 
         }
 
-        y1 += 2 * QK_K;
+        y1 += 4 * QK_K;
 
     }
 
     for (int row = 0; row < 2; ++row) {
-        const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
-        const float tot = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+        sumf1[row] = simd_sum(sumf);
+    }
+    if (tiisg == 0) {
+        for (int row = 0; row < 2; ++row) {
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
         }
     }
 }
@@ -1290,13 +1472,13 @@ kernel void kernel_mul_mat_q4_K_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne01 [[buffer(4)]],
+        constant   int64_t & ne02 [[buffer(5)]],
+        constant   int64_t & ne10 [[buffer(9)]],
+        constant   int64_t & ne12 [[buffer(11)]],
+        constant   int64_t & ne0  [[buffer(15)]],
+        constant   int64_t & ne1  [[buffer(16)]],
+        constant   uint    & gqa  [[buffer(17)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1564,17 +1746,25 @@ kernel void kernel_mul_mat_q5_K_f32(
             sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
             sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
 
-            float4 acc = {0.f, 0.f, 0.f, 0.f};
+            float4 acc1 = {0.f};
+            float4 acc2 = {0.f};
             for (int l = 0; l < n; ++l) {
                 uint8_t h = qh[l];
-                acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
-                acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
-                acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
-                acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+                acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+                acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+                acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+                acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+                acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+                acc2[3] += h & hm4 ? yh[l+8] : 0.f;
             }
             const float dall = dh[0];
             const float dmin = dh[1];
-            sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+            sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
+                                 sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+                                 sc8[4] * (acc1[2] + 16.f*acc2[2]) +
+                                 sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                          dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
             q1 += step;
@@ -1747,6 +1937,15 @@ kernel void kernel_mul_mat_q6_K_f32(
 
 //============================= templates and their specializations =============================
 
+// NOTE: this is not dequantizing - we are simply fitting the template
+template <typename type4x4>
+void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+    float4x4 temp = *(((device float4x4 *)src));
+    for (int i = 0; i < 16; i++){
+        reg[i/4][i%4] = temp[i/4][i%4];
+    }
+}
+
 template <typename type4x4>
 void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
     half4x4 temp = *(((device half4x4 *)src));
@@ -1758,28 +1957,30 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
 template <typename type4x4>
 void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
-    const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = -8.h * xb->d;
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.h * xb->d;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = il ? 0xF000 : 0x0F00;
+    const ushort mask1 = mask0 << 8;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)      ) * d) + m;
-        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
+        reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
+        reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
     }
 }
 
 template <typename type4x4>
 void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 2);
-    const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = xb->m;
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float m  = xb->m;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = il ? 0xF000 : 0x0F00;
+    const ushort mask1 = mask0 << 8;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)      ) * d) + m;
-        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
+        reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
+        reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
     }
 }
 
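Both dequantizers now decode two 4-bit values per 16-bit word without shifts: the low nibble is scaled by d1, and the value sitting eight bits higher by d2 = d1/256, which folds the old >> 8 into the scale. For orientation, a scalar sketch of plain q4_0 decoding, w = d * (q - 8), with the block layout simplified from ggml's q4_0 (reference code, not from the package):

```cpp
#include <cstdint>

// Plain scalar q4_0 decode: w = d * (int(nibble) - 8). The Metal kernel
// computes the same values shift-free: (qs & mask0)*d1 + md for the low
// nibble and (qs & mask1)*d2 + md for the nibble eight bits up, with
// md = -8*d playing the role of the "- 8" offset.
struct block_q4_0_ref {
    float   d;      // scale (fp16 in ggml; plain float here for simplicity)
    uint8_t qs[16]; // 32 4-bit quants, two per byte
};

void dequantize_q4_0_ref(const block_q4_0_ref &b, float out[32]) {
    for (int i = 0; i < 16; ++i) {
        out[i]      = b.d * ((b.qs[i] & 0x0F) - 8); // low nibbles: weights 0..15
        out[i + 16] = b.d * ((b.qs[i] >> 4)   - 8); // high nibbles: weights 16..31
    }
}
```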
@@ -1815,7 +2016,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
-    const float d_all = (float)(xb->d);
+    const half d_all = xb->d;
     device const uint8_t * q = (device const uint8_t *)xb->qs;
     device const uint8_t * h = (device const uint8_t *)xb->hmask;
     device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1828,16 +2029,18 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
                                  ((il/4)>0 ? 12  : 3);
     uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
     uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
-    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
-                                 (scale_2&kmask2) | ((scale_1&kmask1) << 4);
-    float dl = d_all * (dl_int - 32.f);
+    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+                               : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+    half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
+    const half ml = 4.h * dl;
 
-    il = (il/2)%4;
-    float   coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
-    uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    il = (il/2) & 3;
+    const half    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    const uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    dl *= coef;
 
     for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = coef * dl * (q[i] & mask) - (h[i] & m ? 0 : 4.f * dl);
+        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
     }
 #else
     float    kcoef = il&1 ? 1.f/16.f : 1.f;
@@ -1852,26 +2055,31 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
 #endif
 }
 
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+    return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+}
+
 template <typename type4x4>
 void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
-    device const uint8_t * q = xb->qs;
+    device const uchar * q = xb->qs;
 
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q = q + (il/4) * 32 + 16 * (il&1);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const half d   = il < 2 ? xb->d : xb->d / 16.h;
+    const half min = xb->dmin;
+    const half dl = d * sc[0];
+    const half ml = min * sc[1];
 #else
     q = q + 16 * (il&1);
     device const uint8_t * s = xb->scales;
     device const half2 * dh = (device const half2 *)xb->d;
     const float2 d = (float2)dh[0];
     const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
-    const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
+    const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
 #endif
     const ushort mask = il<2 ? 0x0F : 0xF0;
     for (int i = 0; i < 16; ++i) {
@@ -1885,19 +2093,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
     device const uint8_t * qh = xb->qh;
 
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q  = q  + 32 * (il/4) + 16 * (il&1);
     qh = qh + 16 * (il&1);
     uint8_t ul = 1 << (il/2);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const half d = il < 2 ? xb->d : xb->d / 16.h;
+    const half min = xb->dmin;
+    const half dl = d * sc[0];
+    const half ml = min * sc[1];
 
-    const ushort mask  = il<2 ? 0x0F : 0xF0;
-    const float qh_val = il<2 ? 16.f : 256.f;
+    const ushort mask = il<2 ? 0x0F : 0xF0;
+    const half qh_val = il<2 ? 16.h : 256.h;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
     }
@@ -1916,7 +2124,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
-    const float d_all = (float)(xb->d);
+    const half d_all = xb->d;
     device const uint8_t * ql = (device const uint8_t *)xb->ql;
     device const uint8_t * qh = (device const uint8_t *)xb->qh;
     device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1924,19 +2132,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
 #if QK_K == 256
     ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
     qh = qh + 32*(il/8) + 16*(il&1);
-    float sc = scales[(il%2) + 2 * ((il/2))];
-    il = (il/2)%4;
+    half sc = scales[(il%2) + 2 * ((il/2))];
+    il = (il/2) & 3;
 #else
     ql = ql + 16 * (il&1);
-    float sc = scales[il];
+    half sc = scales[il];
 #endif
+    const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+    const uint16_t kmask2 = il>1 ? 0xF0              : 0x0F;
+    const half     coef   = il>1 ? 1.f/16.h          : 1.h;
+    const half ml = d_all * sc * 32.h;
+    const half dl = d_all * sc * coef;
     for (int i = 0; i < 16; ++i) {
-        uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
-        uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
-        const float coef = il>1 ? 1.f/16.f : 1.f;
-        float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
-                         ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
-        reg[i/4][i%4] = d_all * sc * q * coef;
+        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+        reg[i/4][i%4] = dl * q - ml;
     }
 }
 
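The q6_K rewrite hoists the masks and coefficient out of the loop and replaces (q_bits - 32/coef) * d_all * sc * coef with the equivalent dl * q_bits - ml, where dl = d_all*sc*coef and ml = d_all*sc*32. A quick numeric check of that identity (arbitrary values, our code):

```cpp
#include <cmath>
#include <cstdio>

// Verifies: d_all*sc*(q - 32/coef)*coef == (d_all*sc*coef)*q - d_all*sc*32
int main() {
    const float d_all = 0.5f, sc = 3.0f, coef = 1.0f/16.0f;
    const float dl = d_all * sc * coef;
    const float ml = d_all * sc * 32.0f;
    float worst = 0.0f;
    for (int q = 0; q < 64; ++q) { // 6-bit quant range
        const float old_form = d_all * sc * (q - 32.0f/coef) * coef;
        worst = std::fmax(worst, std::fabs(old_form - (dl*q - ml)));
    }
    std::printf("max |old - new| = %g\n", worst); // 0 up to float rounding
    return 0;
}
```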
@@ -1976,22 +2186,25 @@ kernel void kernel_get_rows(
 // each block_q contains 16*nl weights
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm(device const  uchar * src0,
-                          device const  float * src1,
-                          device        float * dst,
-                          constant    int64_t & ne00,
-                          constant    int64_t & ne02,
-                          constant    int64_t & nb01,
-                          constant    int64_t & nb02,
-                          constant    int64_t & ne12,
-                          constant    int64_t & ne0,
-                          constant    int64_t & ne1,
-                          constant       uint & gqa,
-                          threadgroup   uchar * shared_memory [[threadgroup(0)]],
-                          uint3                 tgpig[[threadgroup_position_in_grid]],
-                          uint                  tiitg[[thread_index_in_threadgroup]],
-                          uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
-    threadgroup half * sa = ((threadgroup half *)shared_memory);
+                          device const  uchar * src1,
+                          device        float * dst,
+                          constant    int64_t & ne00,
+                          constant    int64_t & ne02,
+                          constant    int64_t & nb01,
+                          constant    int64_t & nb02,
+                          constant    int64_t & ne12,
+                          constant    int64_t & nb10,
+                          constant    int64_t & nb11,
+                          constant    int64_t & nb12,
+                          constant    int64_t & ne0,
+                          constant    int64_t & ne1,
+                          constant       uint & gqa,
+                          threadgroup   uchar * shared_memory [[threadgroup(0)]],
+                          uint3                 tgpig[[threadgroup_position_in_grid]],
+                          uint                  tiitg[[thread_index_in_threadgroup]],
+                          uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
     threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
 
     const uint r0 = tgpig.y;
@@ -2004,7 +2217,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
     short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
 
-    simdgroup_half8x8 ma[4];
+    simdgroup_half8x8  ma[4];
     simdgroup_float8x8 mb[2];
     simdgroup_float8x8 c_res[8];
     for (int i = 0; i < 8; i++){
@@ -2012,10 +2225,15 @@ kernel void kernel_mul_mm(device const uchar * src0,
     }
 
     short il = (tiitg % THREAD_PER_ROW);
-    uint   offset0 = im/gqa*nb02;
-    ushort offset1 = il/nl;
-    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
-    device const float   * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL);
+
+    uint   offset0 = im/gqa*nb02;
+    ushort offset1 = il/nl;
+
+    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+    device const float   * y = (device const float   *)(src1
+        + nb12 * im
+        + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
         //load data and store to threadgroup memory
@@ -2095,6 +2313,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
 typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
                           constant uint64_t &, constant uint64_t &, uint, uint, uint);
 
+template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
 template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
@@ -2105,14 +2324,28 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
 
-typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &, \
-                        constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
-                        constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
-
-template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1, dequantize_f16>;
-template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
+typedef void (mat_mm_t)(
+        device const  uchar * src0,
+        device const  uchar * src1,
+        device        float * dst,
+        constant    int64_t & ne00,
+        constant    int64_t & ne02,
+        constant    int64_t & nb01,
+        constant    int64_t & nb02,
+        constant    int64_t & ne12,
+        constant    int64_t & nb10,
+        constant    int64_t & nb11,
+        constant    int64_t & nb12,
+        constant    int64_t & ne0,
+        constant    int64_t & ne1,
+        constant       uint & gqa,
+        threadgroup uchar *, uint3, uint, uint);
+
+template [[host_name("kernel_mul_mm_f32_f32")]]  kernel mat_mm_t kernel_mul_mm<float4x4,   1, dequantize_f32>;
+template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;