llama_cpp 0.5.0 → 0.5.2
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -2
- data/examples/prompt_jp.txt +1 -1
- data/ext/llama_cpp/extconf.rb +1 -1
- data/ext/llama_cpp/llama_cpp.cpp +30 -0
- data/ext/llama_cpp/src/ggml-alloc.c +101 -24
- data/ext/llama_cpp/src/ggml-cuda.cu +1094 -678
- data/ext/llama_cpp/src/ggml-metal.m +89 -23
- data/ext/llama_cpp/src/ggml-metal.metal +398 -211
- data/ext/llama_cpp/src/ggml-opencl.cpp +7 -7
- data/ext/llama_cpp/src/ggml.c +32 -56
- data/ext/llama_cpp/src/ggml.h +1 -1
- data/ext/llama_cpp/src/k_quants.c +49 -13
- data/ext/llama_cpp/src/llama.cpp +833 -281
- data/ext/llama_cpp/src/llama.h +11 -6
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +4 -0
- metadata +2 -2
--- a/data/ext/llama_cpp/src/ggml-metal.metal
+++ b/data/ext/llama_cpp/src/ggml-metal.metal
@@ -63,18 +63,18 @@ kernel void kernel_mul_row(
 }
 
 kernel void kernel_scale(
-        device const float * src0,
-        device       float * dst,
+        device const float4 * src0,
+        device       float4 * dst,
         constant     float & scale,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * scale;
 }
 
 kernel void kernel_silu(
-        device const float * src0,
-        device       float * dst,
+        device const float4 * src0,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];
     dst[tpig] = x / (1.0f + exp(-x));
 }
 
@@ -89,10 +89,10 @@ constant float GELU_COEF_A = 0.044715f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 kernel void kernel_gelu(
-        device const float * src0,
-        device       float * dst,
+        device const float4 * src0,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];
 
     // BEWARE !!!
     // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
@@ -107,7 +107,6 @@ kernel void kernel_soft_max(
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
-        threadgroup float  * buf [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 ntg[[threads_per_threadgroup]]) {
@@ -119,55 +118,67 @@ kernel void kernel_soft_max(
     device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    buf[tpitg[0]] = -INFINITY;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
+    float lmax = psrc0[tpitg[0]];
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+        lmax = MAX(lmax, psrc0[i00]);
     }
+    const float max = simd_max(lmax);
 
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    // parallel sum
+    float lsum = 0.0f;
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        const float exp_psrc0 = exp(psrc0[i00] - max);
+        lsum += exp_psrc0;
+        // Remember the result of exp here. exp is expensive, so we really do not
+        // whish to compute it twice.
+        pdst[i00] = exp_psrc0;
     }
 
-    // broadcast
-    if (tpitg[0] == 0) {
-        buf[0] = buf[0];
+    const float sum = simd_sum(lsum);
+
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        pdst[i00] /= sum;
     }
+}
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+kernel void kernel_soft_max_4(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
 
-    const float max = buf[0];
+    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
-    // parallel sum
-    buf[tpitg[0]] = 0.0f;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        buf[tpitg[0]] += exp(psrc0[i00] - max);
+    // parallel max
+    float4 lmax4 = psrc4[tpitg[0]];
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        lmax4 = fmax(lmax4, psrc4[i00]);
     }
+    float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] += buf[tpitg[0] + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
+    const float max = simd_max(lmax);
 
-    // broadcast
-    if (tpitg[0] == 0) {
-        buf[0] = buf[0];
+    // parallel sum
+    float4 lsum4 = 0.0f;
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        const float4 exp_psrc4 = exp(psrc4[i00] - max);
+        lsum4 += exp_psrc4;
+        pdst4[i00] = exp_psrc4;
     }
+    float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    const float sum = buf[0];
+    const float sum = simd_sum(lsum);
 
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        pdst[i00] = exp(psrc0[i00] - max) / sum;
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        pdst4[i00] /= sum;
     }
 }
 
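For reference, the change above replaces the threadgroup-memory reduction with simd_max/simd_sum and caches the result of exp so it is computed only once per element. A minimal plain-C++ sketch of the row-wise softmax both kernels compute (not the Metal code itself, just the math):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>

    // numerically stable softmax over one row, with the exp results cached
    void softmax_row(const float * src, float * dst, std::size_t n) {
        float max = src[0];
        for (std::size_t i = 1; i < n; ++i) max = std::max(max, src[i]);

        float sum = 0.0f;
        for (std::size_t i = 0; i < n; ++i) {
            const float e = std::exp(src[i] - max); // subtract max so exp cannot overflow
            dst[i] = e;                             // cache exp; it is expensive
            sum += e;
        }
        for (std::size_t i = 0; i < n; ++i) dst[i] /= sum;
    }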
@@ -186,6 +197,33 @@ kernel void kernel_diag_mask_inf(
         dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
     } else {
         dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+    }
+}
+
+kernel void kernel_diag_mask_inf_8(
+        device const float4 * src0,
+        device       float4 * dst,
+        constant    int64_t & ne00,
+        constant    int64_t & ne01,
+        constant        int & n_past,
+        uint3 tpig[[thread_position_in_grid]]) {
+
+    const int64_t i = 2*tpig[0];
+
+    dst[i+0] = src0[i+0];
+    dst[i+1] = src0[i+1];
+    int64_t i4 = 4*i;
+    const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+    const int64_t i01 = i4/(ne00);      i4 -= i01*ne00;
+    const int64_t i00 = i4;
+    for (int k = 3; k >= 0; --k) {
+        if (i00 + 4 + k <= n_past + i01) {
+            break;
+        }
+        dst[i+1][k] = -INFINITY;
+        if (i00 + k > n_past + i01) {
+            dst[i][k] = -INFINITY;
+        }
     }
 }
 
@@ -214,25 +252,17 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-
-    if (tpitg == 0) {
-        sum[0] /= ne00;
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    const float mean = sum[0];
+    const float mean = sum[0] / ne00;
 
-    // recenter
+    // recenter and VARIANCE
+    threadgroup_barrier(mem_flags::mem_threadgroup);
     device float * y = dst + tgpig*ne00;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        y[i00] = x[i00] - mean;
-    }
-
-    // VARIANCE
-    // parallel sum
     sum[tpitg] = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] - mean;
         sum[tpitg] += y[i00] * y[i00];
     }
+
     // reduce
     threadgroup_barrier(mem_flags::mem_threadgroup);
     for (uint i = ntg/2; i > 0; i /= 2) {
@@ -241,12 +271,7 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-
-    if (tpitg == 0) {
-        sum[0] /= ne00;
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    const float variance = sum[0];
+    const float variance = sum[0] / ne00;
 
     const float scale = 1.0f/sqrt(variance + eps);
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
@@ -254,7 +279,6 @@ kernel void kernel_norm(
     }
 }
 
-
 kernel void kernel_rms_norm(
         device const void * src0,
         device      float * dst,
@@ -435,6 +459,8 @@ kernel void kernel_mul_mat_q4_1_f32(
     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
+#define NB_Q8_0 8
+
 kernel void kernel_mul_mat_q8_0_f32(
         device const  void * src0,
         device const float * src1,
@@ -463,30 +489,30 @@ kernel void kernel_mul_mat_q8_0_f32(
     device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
     device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
 
-    float yl[16];
+    float yl[NB_Q8_0];
     float sumf[nr]={0.f};
 
-    const int ix = tiisg/2;
-    const int il = tiisg%2;
+    const int ix = tiisg/4;
+    const int il = tiisg%4;
 
-    device const float * yb = y + ix * QK8_0 + 16*il;
+    device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
 
-    // each thread in a SIMD group deals with half a block.
-    for (int ib = ix; ib < nb; ib += nw/2) {
-        for (int i = 0; i < 16; ++i) {
+    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
+    for (int ib = ix; ib < nb; ib += nw/4) {
+        for (int i = 0; i < NB_Q8_0; ++i) {
             yl[i] = yb[i];
         }
 
         for (int row = 0; row < nr; row++) {
-            device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+            device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
             float sumq = 0.f;
-            for (int iq = 0; iq < 16; ++iq) {
+            for (int iq = 0; iq < NB_Q8_0; ++iq) {
                 sumq += qs[iq] * yl[iq];
             }
             sumf[row] += sumq*x[ib+row*nb].d;
         }
 
-        yb += QK8_0 * nw/2;
+        yb += NB_Q8_0 * nw;
     }
 
     for (int row = 0; row < nr; ++row) {
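The NB_Q8_0 rework above only changes how the work is split across SIMD lanes; the arithmetic is unchanged. As a rough plain-C++ sketch of what one row of the q8_0 dot product computes (the struct below is a simplified stand-in; the real block_q8_0 stores d as fp16):

    #include <cstdint>

    constexpr int QK8_0 = 32; // quants per block, as in ggml

    struct block_q8_0_sketch {
        float  d;          // per-block scale
        int8_t qs[QK8_0];  // quantized values
    };

    // dot product of a quantized row with a float vector:
    // sum over blocks b of d_b * sum_i qs[i]*y[i]
    float dot_q8_0(const block_q8_0_sketch * x, const float * y, int nb) {
        float sumf = 0.0f;
        for (int ib = 0; ib < nb; ++ib) {
            float sumq = 0.0f;
            for (int i = 0; i < QK8_0; ++i) {
                sumq += x[ib].qs[i] * y[ib*QK8_0 + i];
            }
            sumf += sumq * x[ib].d;
        }
        return sumf;
    }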
@@ -497,7 +523,7 @@ kernel void kernel_mul_mat_q8_0_f32(
     }
 }
 
-kernel void kernel_mul_mat_f16_f32(
+kernel void kernel_mul_mat_f16_f32_1row(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -515,11 +541,8 @@ kernel void kernel_mul_mat_f16_f32(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        threadgroup float  * sum [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3  tpig[[thread_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3  tptg[[threads_per_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]]) {
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
@@ -528,42 +551,144 @@ kernel void kernel_mul_mat_f16_f32(
     device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
-    uint ith = tpitg.x;
-    uint nth = tptg.x;
+    float sumf = 0;
+    if (ne00 < 128) {
+        for (int i = tiisg; i < ne00; i += 32) {
+            sumf += (float) x[i] * (float) y[i];
+        }
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    } else {
+        device const half4  * x4 = (device const half4 *) x;
+        device const float4 * y4 = (device const float4 *) y;
+        for (int i = tiisg; i < ne00/4; i += 32) {
+            for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
+        }
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    }
 
-    sum[ith] = 0.0f;
+}
 
-    for (int i = ith; i < ne00; i += nth) {
-        sum[ith] += (float) x[i] * (float) y[i];
-    }
+#define N_F16_F32 4
 
-    // accumulate the sum from all threads in the threadgroup
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        sum[ith] = sum[ith] + sum[ith+1] + sum[ith+2] + sum[ith+3];
+kernel void kernel_mul_mat_f16_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F16_F32;
+    const int64_t im = tgpig.z;
+
+    device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (float) x[i] * (float) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const half4 * x4 = (device const half4 *)x;
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12);
+            device const float4 * y4 = (device const float4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
     }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        sum[ith] = sum[ith] + sum[ith+4] + sum[ith+8] + sum[ith+12];
+}
+
+// Assumes row size (ne00) is a multiple of 4
+kernel void kernel_mul_mat_f16_f32_l4(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int nrows = ne11;
+    const int64_t r0 = tgpig.x;
+    const int64_t im = tgpig.z;
+
+    device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    for (int r1 = 0; r1 < nrows; ++r1) {
+        device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+
+        float sumf = 0;
+        for (int i = tiisg; i < ne00/4; i += 32) {
+            for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+        }
+
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
     }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
-    }
-
-    // Original implementation. Left behind commented out for now
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //for (uint i = tptg.x/2; i > 0; i /= 2) {
-    //    if (tpitg.x < i) {
-    //        sum[tpitg.x] += sum[tpitg.x + i];
-    //    }
-    //    threadgroup_barrier(mem_flags::mem_threadgroup);
-    //}
-    //
-    //if (tpitg.x == 0) {
-    //    dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
-    //}
 }
 
 kernel void kernel_alibi_f32(
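A note on the pattern shared by the three f16 x f32 kernels above: each 32-lane SIMD group strides through one row, the per-lane partial sums are combined with simd_sum, and in the float4 path the ne00 % 4 leftover elements are added exactly once by lane 0 before the store. A plain-C++ sketch of that vector-plus-tail summation (assuming nothing about the Metal runtime):

    // dot product processed four elements at a time, with a scalar tail
    float dot_with_tail(const float * x, const float * y, int n) {
        float sum = 0.0f;
        const int n4 = n/4;
        for (int i = 0; i < n4; ++i) {
            for (int k = 0; k < 4; ++k) sum += x[4*i + k] * y[4*i + k]; // the "float4" body
        }
        for (int i = 4*n4; i < n; ++i) sum += x[i] * y[i]; // tail: the 4*(ne00/4) loop
        return sum;
    }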
@@ -632,25 +757,27 @@ kernel void kernel_rope(
         constant       int & mode,
         constant     float & freq_base,
         constant     float & freq_scale,
-        uint3 tpig[[thread_position_in_grid]]) {
-    const int64_t i3 = tpig[2];
-    const int64_t i2 = tpig[1];
-    const int64_t i1 = tpig[0];
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint3  tptg[[threads_per_threadgroup]],
+        uint3 tgpig[[threadgroup_position_in_grid]]) {
+    const int64_t i3 = tgpig[2];
+    const int64_t i2 = tgpig[1];
+    const int64_t i1 = tgpig[0];
 
     const bool is_neox = mode & 2;
-    const float theta_scale = pow(freq_base, -2.0f/n_dims);
 
     const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
 
-    float theta = freq_scale * (float)p;
+    const float theta_0 = freq_scale * (float)p;
+    const float inv_ndims = -1.f/n_dims;
 
     if (!is_neox) {
-        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+        for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+
+            const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
             const float cos_theta = cos(theta);
             const float sin_theta = sin(theta);
 
-            theta *= theta_scale;
-
             device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
             device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
@@ -662,12 +789,12 @@ kernel void kernel_rope(
         }
     } else {
         for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
-            for (int64_t ic = 0; ic < n_dims; ic += 2) {
+            for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
+
+                const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
                 const float cos_theta = cos(theta);
                 const float sin_theta = sin(theta);
 
-                theta *= theta_scale;
-
                 const int64_t i0 = ib*n_dims + ic/2;
 
                 device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
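The rope change replaces the incremental theta *= theta_scale update, which forced every thread to start at i0 = 0, with a closed-form angle per element, so the loop can be split across the threadgroup. A small plain-C++ sketch of the rotation being applied (standard RoPE; the helper name is ours):

    #include <cmath>

    // rotate consecutive pairs of one row by
    // theta(i0) = freq_scale * p * freq_base^(-i0/n_dims)
    void rope_row(float * x, int n_dims, int p, float freq_base, float freq_scale) {
        const float theta_0   = freq_scale * (float) p;
        const float inv_ndims = -1.0f / (float) n_dims;
        for (int i0 = 0; i0 < n_dims; i0 += 2) {
            const float theta = theta_0 * std::pow(freq_base, inv_ndims * (float) i0);
            const float c = std::cos(theta);
            const float s = std::sin(theta);
            const float x0 = x[i0 + 0];
            const float x1 = x[i0 + 1];
            x[i0 + 0] = x0*c - x1*s;
            x[i0 + 1] = x0*s + x1*c;
        }
    }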
@@ -1071,31 +1198,40 @@ kernel void kernel_mul_mat_q3_K_f32(
     device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
     device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
 
-    float yl[16];
+    float yl[32];
 
-    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask1 = 0x3030;
     const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = tiisg/2;
-    const int ix  = tiisg%2;
-    const int ip  = tid/8;          // 0 or 1
-    const int il  = tid/2 - 4*ip;   // 0...3
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int ip  = tid/4;          // 0 or 1
+    const int il  = 2*((tid%4)/2);  // 0 or 2
     const int ir  = tid%2;
     const int n   = 8;
     const int l0  = n*ir;
 
-    const uint16_t m1 = 1 << (4*ip + il);
-    const uint16_t m2 = m1 << 8;
+    // One would think that the Metal compiler would figure out that ip and il can only have
+    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+    // with these two tales.
+    //
+    // Possible masks for the high bit
+    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
+                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
+                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
+                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+    // Possible masks for the low 2 bits
+    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+    const ushort4 hm = mm[2*ip + il/2];
 
     const int shift = 2*il;
-    const uint16_t qm1 = 0x0003 << shift;
-    const uint16_t qm2 = 0x0300 << shift;
-    const int32_t v1 = 4 << shift;
-    const int32_t v2 = 1024 << shift;
+    const float    v1 = il == 0 ? 4.f : 64.f;
+    const float    v2 = 4.f * v1;
 
     const uint16_t s_shift1 = 4*ip;
-    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
-    const int ik = 4 + (il%2);
+    const uint16_t s_shift2 = s_shift1 + il;
 
     const int q_offset = 32*ip + l0;
     const int y_offset = 128*ip + 32*il + l0;
@@ -1104,12 +1240,19 @@ kernel void kernel_mul_mat_q3_K_f32(
 
     device const float * y1 = yy + ix*QK_K + y_offset;
 
-    float sumf1[2] = {0.f}, sumf2[2] = {0.f};
-    for (int i = ix; i < nb; i += 2) {
+    uint32_t scales32, aux32;
+    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+    thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+    float sumf1[2] = {0.f};
+    float sumf2[2] = {0.f};
+    for (int i = ix; i < nb; i += 4) {
 
         for (int l = 0; l < 8; ++l) {
-            yl[l+0] = y1[l+ 0];
-            yl[l+8] = y1[l+16];
+            yl[l+ 0] = y1[l+ 0];
+            yl[l+ 8] = y1[l+16];
+            yl[l+16] = y1[l+32];
+            yl[l+24] = y1[l+48];
         }
 
         device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
@@ -1120,27 +1263,43 @@ kernel void kernel_mul_mat_q3_K_f32(
         for (int row = 0; row < 2; ++row) {
 
             const float d_all = (float)dh[0];
-            const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
 
-            float s1 = 0, s2 = 0;
+            scales16[0] = a[4];
+            scales16[1] = a[5];
+            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+            scales16[0] = a[il+0];
+            scales16[1] = a[il+1];
+            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2];
-                s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
-                s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2];
+                s1 += yl[l+0] * (qs & qm[il/2][0]);
+                s2 += yl[l+1] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+                s4 += yl[l+16] * (qs & qm[il/2][2]);
+                s5 += yl[l+17] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
             }
-            float d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[0];
-            sumf2[row] += d;
+            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[0] - 32);
+            sumf2[row] += d2 * (scales[2] - 32);
 
-            s1 = s2 = 0;
+            s1 = s2 = s3 = s4 = s5 = s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2+8];
-                s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
-                s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2+8];
+                s1 += yl[l+8] * (qs & qm[il/2][0]);
+                s2 += yl[l+9] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+                s4 += yl[l+24] * (qs & qm[il/2][2]);
+                s5 += yl[l+25] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
             }
-            d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[1];
-            sumf2[row] += d;
+            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[1] - 32);
+            sumf2[row] += d2 * (scales[3] - 32);
 
             q += step;
             h += step;
@@ -1149,17 +1308,20 @@ kernel void kernel_mul_mat_q3_K_f32(
 
         }
 
-        y1 += 2 * QK_K;
+        y1 += 4 * QK_K;
 
     }
 
     for (int row = 0; row < 2; ++row) {
-        const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
-        const float tot = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+        sumf1[row] = simd_sum(sumf);
+    }
+    if (tiisg == 0) {
+        for (int row = 0; row < 2; ++row) {
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
         }
     }
+
 }
 #else
 kernel void kernel_mul_mat_q3_K_f32(
@@ -1262,7 +1424,8 @@ kernel void kernel_mul_mat_q4_K_f32(
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int r2 = tgpig.z;
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = r0 * N_DST;
     const int ib_row = first_row * nb;
     const uint offset0 = r2/gqa*(nb*ne0);
     device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
@@ -1511,17 +1674,25 @@ kernel void kernel_mul_mat_q5_K_f32(
             sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
             sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
 
-            float4 acc = {0.f, 0.f, 0.f, 0.f};
+            float4 acc1 = {0.f};
+            float4 acc2 = {0.f};
             for (int l = 0; l < n; ++l) {
                 uint8_t h = qh[l];
-                acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
-                acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
-                acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
-                acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+                acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+                acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+                acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+                acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+                acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+                acc2[3] += h & hm4 ? yh[l+8] : 0.f;
             }
             const float dall = dh[0];
             const float dmin = dh[1];
-            sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+            sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
+                                 sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+                                 sc8[4] * (acc1[2] + 16.f*acc2[2]) +
+                                 sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                          dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
             q1 += step;
|
|
1704
1875
|
|
1705
1876
|
template <typename type4x4>
|
1706
1877
|
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
1878
|
+
|
1707
1879
|
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
1708
|
-
const
|
1709
|
-
const
|
1880
|
+
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
1881
|
+
const float d2 = d1 / 256.f;
|
1882
|
+
const float md = -8.h * xb->d;
|
1710
1883
|
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
1711
|
-
const ushort mask1 =
|
1884
|
+
const ushort mask1 = mask0 << 8;
|
1712
1885
|
|
1713
1886
|
for (int i=0;i<8;i++) {
|
1714
|
-
reg[i/2][2*(i%2)]
|
1715
|
-
reg[i/2][2*(i%2)+1] = (
|
1887
|
+
reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
|
1888
|
+
reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
|
1716
1889
|
}
|
1890
|
+
|
1717
1891
|
}
|
1718
1892
|
|
1719
1893
|
template <typename type4x4>
|
1720
1894
|
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
|
1895
|
+
|
1721
1896
|
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
1722
|
-
const
|
1723
|
-
const
|
1897
|
+
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
1898
|
+
const float d2 = d1 / 256.f;
|
1899
|
+
const float m = xb->m;
|
1724
1900
|
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
1725
|
-
const ushort mask1 =
|
1901
|
+
const ushort mask1 = mask0 << 8;
|
1726
1902
|
|
1727
1903
|
for (int i=0;i<8;i++) {
|
1728
|
-
reg[i/2][2*(i%2)]
|
1729
|
-
reg[i/2][2*(i%2)+1] = ((
|
1904
|
+
reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
|
1905
|
+
reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
|
1730
1906
|
}
|
1731
1907
|
}
|
1732
1908
|
|
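In the q4_0/q4_1 rewrite above, d1/d2 fold the nibble position into the scale (d2 = d1/256 absorbs the old >> 8) and md = -8*d turns the old (q + m)*d form into d*q + md. A plain-C++ sketch of the underlying q4_0 mapping (simplified layout; the real block stores d as fp16):

    #include <cstdint>

    struct block_q4_0_sketch {
        float   d;       // scale
        uint8_t qs[16];  // 32 4-bit quants, two per byte
    };

    // each 4-bit quant q in [0,15] decodes to d*q - 8*d
    void dequant_q4_0(const block_q4_0_sketch & b, float * out) {
        const float md = -8.0f * b.d;
        for (int i = 0; i < 16; ++i) {
            out[i +  0] = b.d * (b.qs[i] & 0x0F) + md; // low nibbles: elements 0..15
            out[i + 16] = b.d * (b.qs[i] >>   4) + md; // high nibbles: elements 16..31
        }
    }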
@@ -1762,7 +1938,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
-    const float d_all = (float)(xb->d);
+    const half d_all = xb->d;
     device const uint8_t * q = (device const uint8_t *)xb->qs;
     device const uint8_t * h = (device const uint8_t *)xb->hmask;
     device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1775,17 +1951,20 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
                                  ((il/4)>0 ? 12  : 3);
     uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
     uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
-    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
-                               : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
-    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+                               : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+    half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
+    const half ml = 4.h * dl;
 
-    il = (il/2)%4;
-    float   coef  = il>1 ? (il>2 ? 1.f/64.f : 1.f/16.f) : (il>0 ? 1.f/4.f : 1.f);
-    uint8_t mask  = il>1 ? (il>2 ? 192     : 48)        : (il>0 ? 12      : 3);
+    il = (il/2) & 3;
+    const half    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    const uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    dl *= coef;
 
     for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
+        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
     }
+
 #else
     float    kcoef = il&1 ? 1.f/16.f : 1.f;
     uint16_t kmask = il&1 ? 0xF0 : 0x0F;
@@ -1799,31 +1978,37 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
 #endif
 }
 
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+    return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+}
+
 template <typename type4x4>
 void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
-    device const uint8_t * q = xb->qs;
+    device const uchar * q = xb->qs;
 
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q = q + (il/4) * 32 + 16 * (il&1);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0]   : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const half d   = il < 2 ? xb->d : xb->d / 16.h;
+    const half min = xb->dmin;
+    const half dl = d * sc[0];
+    const half ml = min * sc[1];
 #else
     q = q + 16 * (il&1);
     device const uint8_t * s = xb->scales;
     device const half2 * dh = (device const half2 *)xb->d;
     const float2 d = (float2)dh[0];
     const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
-    const float ml = il<2 ? d[1] * (s[0]>>4)  : d[1] * (s[1]>>4)/16.h;
+    const float ml = il<2 ? d[1] * (s[0]>>4)  : d[1] * (s[1]>>4);
 #endif
     const ushort mask = il<2 ? 0x0F : 0xF0;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * (q[i] & mask) - ml;
     }
+
 }
 
 template <typename type4x4>
@@ -1832,19 +2017,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
     device const uint8_t * qh = xb->qh;
 
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q = q + 32 * (il/4) + 16 * (il&1);
     qh = qh + 16 * (il&1);
     uint8_t ul = 1 << (il/2);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0]   : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const half d   = il < 2 ? xb->d : xb->d / 16.h;
+    const half min = xb->dmin;
+    const half dl = d * sc[0];
+    const half ml = min * sc[1];
 
-    const ushort mask  = il<2 ? 0x0F : 0xF0;
-    const float qh_val = il<2 ? 16.f : 256.f;
+    const ushort mask = il<2 ? 0x0F : 0xF0;
+    const half qh_val = il<2 ? 16.h : 256.h;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
     }
@@ -1863,7 +2048,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
-    const float d_all = (float)(xb->d);
+    const half d_all = xb->d;
     device const uint8_t * ql = (device const uint8_t *)xb->ql;
    device const uint8_t * qh = (device const uint8_t *)xb->qh;
     device const int8_t * scales = (device const int8_t *)xb->scales;
|
@@ -1871,19 +2056,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
1871
2056
|
#if QK_K == 256
|
1872
2057
|
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
1873
2058
|
qh = qh + 32*(il/8) + 16*(il&1);
|
1874
|
-
|
1875
|
-
il = (il/2)
|
2059
|
+
half sc = scales[(il%2) + 2 * ((il/2))];
|
2060
|
+
il = (il/2) & 3;
|
1876
2061
|
#else
|
1877
2062
|
ql = ql + 16 * (il&1);
|
1878
|
-
|
2063
|
+
half sc = scales[il];
|
1879
2064
|
#endif
|
2065
|
+
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
2066
|
+
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
2067
|
+
const half coef = il>1 ? 1.f/16.h : 1.h;
|
2068
|
+
const half ml = d_all * sc * 32.h;
|
2069
|
+
const half dl = d_all * sc * coef;
|
1880
2070
|
for (int i = 0; i < 16; ++i) {
|
1881
|
-
|
1882
|
-
|
1883
|
-
|
1884
|
-
float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
|
1885
|
-
((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
|
1886
|
-
reg[i/4][i%4] = d_all * sc * q * coef;
|
2071
|
+
const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
|
2072
|
+
: ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
|
2073
|
+
reg[i/4][i%4] = dl * q - ml;
|
1887
2074
|
}
|
1888
2075
|
}
|
1889
2076
|
|