llama_cpp 0.5.0 → 0.5.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -2
- data/examples/prompt_jp.txt +1 -1
- data/ext/llama_cpp/extconf.rb +1 -1
- data/ext/llama_cpp/llama_cpp.cpp +30 -0
- data/ext/llama_cpp/src/ggml-alloc.c +101 -24
- data/ext/llama_cpp/src/ggml-cuda.cu +1094 -678
- data/ext/llama_cpp/src/ggml-metal.m +89 -23
- data/ext/llama_cpp/src/ggml-metal.metal +398 -211
- data/ext/llama_cpp/src/ggml-opencl.cpp +7 -7
- data/ext/llama_cpp/src/ggml.c +32 -56
- data/ext/llama_cpp/src/ggml.h +1 -1
- data/ext/llama_cpp/src/k_quants.c +49 -13
- data/ext/llama_cpp/src/llama.cpp +833 -281
- data/ext/llama_cpp/src/llama.h +11 -6
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +1 -1
- data/sig/llama_cpp.rbs +4 -0
- metadata +2 -2
The excerpted hunks below are from data/ext/llama_cpp/src/ggml-metal.metal:

--- a/data/ext/llama_cpp/src/ggml-metal.metal
+++ b/data/ext/llama_cpp/src/ggml-metal.metal
@@ -63,18 +63,18 @@ kernel void kernel_mul_row(
 }
 
 kernel void kernel_scale(
-        device const float * src0,
-        device       float * dst,
+        device const float4 * src0,
+        device       float4 * dst,
         constant     float & scale,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * scale;
 }
 
 kernel void kernel_silu(
-        device const float * src0,
-        device       float * dst,
+        device const float4 * src0,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];
     dst[tpig] = x / (1.0f + exp(-x));
 }
 
@@ -89,10 +89,10 @@ constant float GELU_COEF_A = 0.044715f;
 constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
 
 kernel void kernel_gelu(
-    device const float * src0,
-    device       float * dst,
+    device const float4 * src0,
+    device       float4 * dst,
     uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];
 
     // BEWARE !!!
     // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
@@ -107,7 +107,6 @@ kernel void kernel_soft_max(
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
-        threadgroup float  * buf [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
@@ -119,55 +118,67 @@ kernel void kernel_soft_max(
     device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    buf[tpitg[0]] = -INFINITY;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
+    float lmax = psrc0[tpitg[0]];
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+        lmax = MAX(lmax, psrc0[i00]);
     }
+    const float max = simd_max(lmax);
 
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    // parallel sum
+    float lsum = 0.0f;
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        const float exp_psrc0 = exp(psrc0[i00] - max);
+        lsum += exp_psrc0;
+        // Remember the result of exp here. exp is expensive, so we really do not
+        // whish to compute it twice.
+        pdst[i00] = exp_psrc0;
     }
 
-    // broadcast
-    if (tpitg[0] == 0) {
-        buf[0] = buf[0];
+    const float sum = simd_sum(lsum);
+
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        pdst[i00] /= sum;
     }
+}
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+kernel void kernel_soft_max_4(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
 
-    const float max = buf[0];
+    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
-    // parallel sum
-    buf[tpitg[0]] = 0.0f;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        buf[tpitg[0]] += exp(psrc0[i00] - max);
+    // parallel max
+    float4 lmax4 = psrc4[tpitg[0]];
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        lmax4 = fmax(lmax4, psrc4[i00]);
     }
+    float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] += buf[tpitg[0] + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
+    const float max = simd_max(lmax);
 
-    // broadcast
-    if (tpitg[0] == 0) {
-        buf[0] = buf[0];
+    // parallel sum
+    float4 lsum4 = 0.0f;
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        const float4 exp_psrc4 = exp(psrc4[i00] - max);
+        lsum4 += exp_psrc4;
+        pdst4[i00] = exp_psrc4;
     }
+    float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    const float sum = buf[0];
+    const float sum = simd_sum(lsum);
 
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        pdst[i00] = exp(psrc0[i00] - max) / sum;
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        pdst4[i00] /= sum;
     }
 }
 
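Editor's note (not part of the package diff): the rewritten soft_max kernels keep per-thread partials in registers and combine them with simd_max/simd_sum instead of staging a tree reduction through threadgroup memory, and they cache the expensive exp() result in dst so it is computed once per element. A minimal CPU-side C++ sketch of the same three-pass scheme, with illustrative names:

    #include <algorithm>
    #include <cmath>

    void soft_max_row(const float * src, float * dst, int ne00) {
        float max = src[0];
        for (int i = 1; i < ne00; ++i) max = std::max(max, src[i]); // pass 1: row max

        float sum = 0.0f;
        for (int i = 0; i < ne00; ++i) {
            const float e = std::exp(src[i] - max); // pass 2: exp once, cached in dst
            dst[i] = e;
            sum += e;
        }

        for (int i = 0; i < ne00; ++i) dst[i] /= sum; // pass 3: normalize
    }
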
@@ -186,6 +197,33 @@ kernel void kernel_diag_mask_inf(
         dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
     } else {
         dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+    }
+}
+
+kernel void kernel_diag_mask_inf_8(
+        device const float4 * src0,
+        device       float4 * dst,
+        constant    int64_t & ne00,
+        constant    int64_t & ne01,
+        constant        int & n_past,
+        uint3 tpig[[thread_position_in_grid]]) {
+
+    const int64_t i = 2*tpig[0];
+
+    dst[i+0] = src0[i+0];
+    dst[i+1] = src0[i+1];
+    int64_t i4 = 4*i;
+    const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+    const int64_t i01 = i4/(ne00);      i4 -= i01*ne00;
+    const int64_t i00 = i4;
+    for (int k = 3; k >= 0; --k) {
+        if (i00 + 4 + k <= n_past + i01) {
+            break;
+        }
+        dst[i+1][k] = -INFINITY;
+        if (i00 + k > n_past + i01) {
+            dst[i][k] = -INFINITY;
+        }
     }
 }
 
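Editor's note (not part of the package diff): kernel_diag_mask_inf_8 masks 8 consecutive floats (two float4 values) per thread, so it first recovers the (i02, i01, i00) coordinates from the flat element index. A C++ sketch of that index decomposition; decompose() is a hypothetical stand-in name:

    #include <cstdint>

    struct Coords { int64_t i02, i01, i00; };

    Coords decompose(int64_t i4, int64_t ne00, int64_t ne01) {
        const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; // matrix in the batch
        const int64_t i01 = i4/ne00;        i4 -= i01*ne00;      // row
        return {i02, i01, i4};                                   // i00 = column
    }
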
@@ -214,25 +252,17 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-
-    if (tpitg == 0) {
-        sum[0] /= ne00;
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    const float mean  = sum[0];
+    const float mean  = sum[0] / ne00;
 
-    // recenter
+    // recenter and VARIANCE
+    threadgroup_barrier(mem_flags::mem_threadgroup);
     device float * y = dst + tgpig*ne00;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        y[i00] = x[i00] - mean;
-    }
-
-    // VARIANCE
-    // parallel sum
     sum[tpitg] = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] - mean;
         sum[tpitg] += y[i00] * y[i00];
     }
+
     // reduce
     threadgroup_barrier(mem_flags::mem_threadgroup);
     for (uint i = ntg/2; i > 0; i /= 2) {
@@ -241,12 +271,7 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-
-    if (tpitg == 0) {
-        sum[0] /= ne00;
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    const float variance = sum[0];
+    const float variance = sum[0] / ne00;
 
     const float scale = 1.0f/sqrt(variance + eps);
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
@@ -254,7 +279,6 @@ kernel void kernel_norm(
     }
 }
 
-
 kernel void kernel_rms_norm(
         device const  void * src0,
         device       float * dst,
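Editor's note (not part of the package diff): the kernel_norm changes fold the division by ne00 into the mean/variance reads (dropping one single-thread step and one barrier each) and merge the recenter loop into the variance accumulation. A CPU-side C++ sketch of the fused form:

    #include <cmath>

    void norm_row(const float * x, float * y, int ne00, float eps) {
        float sum = 0.0f;
        for (int i = 0; i < ne00; ++i) sum += x[i];
        const float mean = sum / ne00;

        float variance = 0.0f;
        for (int i = 0; i < ne00; ++i) {
            y[i] = x[i] - mean;       // recenter and ...
            variance += y[i] * y[i];  // ... accumulate variance in the same pass
        }
        variance /= ne00;

        const float scale = 1.0f / std::sqrt(variance + eps);
        for (int i = 0; i < ne00; ++i) y[i] *= scale;
    }
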
@@ -435,6 +459,8 @@ kernel void kernel_mul_mat_q4_1_f32(
     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
+#define NB_Q8_0 8
+
 kernel void kernel_mul_mat_q8_0_f32(
         device const  void * src0,
         device const float * src1,
@@ -463,30 +489,30 @@ kernel void kernel_mul_mat_q8_0_f32(
     device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
     device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
 
-    float yl[16];
+    float yl[NB_Q8_0];
     float sumf[nr]={0.f};
 
-    const int ix = tiisg/2;
-    const int il = tiisg%2;
+    const int ix = tiisg/4;
+    const int il = tiisg%4;
 
-    device const float * yb = y + ix * QK8_0 + 16*il;
+    device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
 
-    // each thread in a SIMD group deals with half a block.
-    for (int ib = ix; ib < nb; ib += nw/2) {
-        for (int i = 0; i < 16; ++i) {
+    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
+    for (int ib = ix; ib < nb; ib += nw/4) {
+        for (int i = 0; i < NB_Q8_0; ++i) {
             yl[i] = yb[i];
         }
 
         for (int row = 0; row < nr; row++) {
-            device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+            device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
             float sumq = 0.f;
-            for (int iq = 0; iq < 16; ++iq) {
+            for (int iq = 0; iq < NB_Q8_0; ++iq) {
                 sumq += qs[iq] * yl[iq];
             }
             sumf[row] += sumq*x[ib+row*nb].d;
         }
 
-        yb += 16 * nw;
+        yb += NB_Q8_0 * nw;
     }
 
     for (int row = 0; row < nr; ++row) {
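Editor's note (not part of the package diff): a q8_0 block holds QK8_0 = 32 int8 quants plus one scale, so with NB_Q8_0 = 8 four threads (tiisg/4, tiisg%4) now cover one block, instead of two threads covering half a block each. A scalar C++ reference for the per-block dot product the kernel accumulates:

    #include <cstdint>

    constexpr int QK8_0 = 32;

    struct block_q8_0 {
        float  d;         // block scale (stored as fp16 in ggml; float here for brevity)
        int8_t qs[QK8_0]; // quants
    };

    float dot_q8_0(const block_q8_0 & x, const float * y) {
        float sumq = 0.0f;
        for (int i = 0; i < QK8_0; ++i) sumq += x.qs[i] * y[i];
        return sumq * x.d; // scale applied once per block
    }
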
@@ -497,7 +523,7 @@ kernel void kernel_mul_mat_q8_0_f32(
     }
 }
 
-kernel void kernel_mul_mat_f16_f32(
+kernel void kernel_mul_mat_f16_f32_1row(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -515,11 +541,8 @@ kernel void kernel_mul_mat_f16_f32(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        threadgroup float  * sum [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3  tpig[[thread_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3  tptg[[threads_per_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]]) {
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
@@ -528,42 +551,144 @@ kernel void kernel_mul_mat_f16_f32(
     device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
-    uint ith = tpitg.x;
-    uint nth = tptg.x;
+    float sumf = 0;
+    if (ne00 < 128) {
+        for (int i = tiisg; i < ne00; i += 32) {
+            sumf += (float) x[i] * (float) y[i];
+        }
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    } else {
+        device const half4  * x4 = (device const half4  *) x;
+        device const float4 * y4 = (device const float4 *) y;
+        for (int i = tiisg; i < ne00/4; i += 32) {
+            for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
+        }
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    }
 
-    sum[ith] = 0.0f;
+}
 
-    for (int i = ith; i < ne00; i += nth) {
-        sum[ith] += (float) x[i] * (float) y[i];
-    }
+#define N_F16_F32 4
 
-    // accumulate the sum from all threads in the threadgroup
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
+kernel void kernel_mul_mat_f16_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F16_F32;
+    const int64_t im = tgpig.z;
+
+    device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (float) x[i] * (float) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const half4 * x4 = (device const half4 *)x;
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12);
+            device const float4 * y4 = (device const float4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
     }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
+}
+
+// Assumes row size (ne00) is a multiple of 4
+kernel void kernel_mul_mat_f16_f32_l4(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int nrows = ne11;
+    const int64_t r0 = tgpig.x;
+    const int64_t im = tgpig.z;
+
+    device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    for (int r1 = 0; r1 < nrows; ++r1) {
+        device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+
+        float sumf = 0;
+        for (int i = tiisg; i < ne00/4; i += 32) {
+            for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+        }
+
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
     }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
-    }
-
-    // Original implementation. Left behind commented out for now
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //for (uint i = tptg.x/2; i > 0; i /= 2) {
-    //    if (tpitg.x < i) {
-    //        sum[tpitg.x] += sum[tpitg.x + i];
-    //    }
-    //    threadgroup_barrier(mem_flags::mem_threadgroup);
-    //}
-    //
-    //if (tpitg.x == 0) {
-    //    dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
-    //}
 }
 
 kernel void kernel_alibi_f32(
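Editor's note (not part of the package diff): all three f16 x f32 kernels above share one reduction idea. Each of the 32 lanes of a simdgroup accumulates a strided partial dot product in a register, and a single simd_sum() combines the lanes, replacing the old threadgroup-memory staging with its barriers. A CPU-side C++ model of the lane-strided accumulation:

    float dot_lane_strided(const float * x, const float * y, int n) {
        const int lanes = 32;          // simdgroup width on Apple GPUs
        float partial[32] = {0.0f};
        for (int lane = 0; lane < lanes; ++lane) {
            for (int i = lane; i < n; i += lanes) {
                partial[lane] += x[i] * y[i]; // per-lane register accumulator
            }
        }
        float sum = 0.0f;              // stand-in for simd_sum(partial)
        for (int lane = 0; lane < lanes; ++lane) sum += partial[lane];
        return sum;
    }
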
@@ -632,25 +757,27 @@ kernel void kernel_rope(
         constant         int & mode,
         constant       float & freq_base,
         constant       float & freq_scale,
-        uint3 tpig[[thread_position_in_grid]]) {
-    const int64_t i3 = tpig[2];
-    const int64_t i2 = tpig[1];
-    const int64_t i1 = tpig[0];
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint3  tptg[[threads_per_threadgroup]],
+        uint3 tgpig[[threadgroup_position_in_grid]]) {
+    const int64_t i3 = tgpig[2];
+    const int64_t i2 = tgpig[1];
+    const int64_t i1 = tgpig[0];
 
     const bool is_neox = mode & 2;
-    const float theta_scale = pow(freq_base, -2.0f/n_dims);
 
     const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
 
-    float theta = freq_scale * (float)p;
+    const float theta_0 = freq_scale * (float)p;
+    const float inv_ndims = -1.f/n_dims;
 
     if (!is_neox) {
-        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+        for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+
+            const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
             const float cos_theta = cos(theta);
             const float sin_theta = sin(theta);
 
-            theta *= theta_scale;
-
             device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
             device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
@@ -662,12 +789,12 @@ kernel void kernel_rope(
         }
     } else {
         for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
-            for (int64_t ic = 0; ic < n_dims; ic += 2) {
+            for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
+
+                const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
                 const float cos_theta = cos(theta);
                 const float sin_theta = sin(theta);
 
-                theta *= theta_scale;
-
                 const int64_t i0 = ib*n_dims + ic/2;
 
                 device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
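Editor's note (not part of the package diff): the old RoPE kernel carried theta across iterations (theta *= theta_scale), which serializes the loop in a single thread. Computing theta directly from the pair index makes every iteration independent, so the loop can be spread over tiitg/tptg.x. The closed form, as a C++ sketch:

    #include <cmath>

    // rotation angle for the pair starting at element i0 of a head of size n_dims
    float rope_theta(float p, float freq_base, float freq_scale, int i0, int n_dims) {
        const float theta_0   = freq_scale * p;  // theta at i0 == 0
        const float inv_ndims = -1.0f / n_dims;
        return theta_0 * std::pow(freq_base, inv_ndims * i0); // theta_0 * base^(-i0/n_dims)
    }
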
@@ -1071,31 +1198,40 @@ kernel void kernel_mul_mat_q3_K_f32(
     device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
     device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
 
-    float yl[16];
+    float yl[32];
 
-    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask1 = 0x3030;
     const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = tiisg/2;
-    const int ix  = tiisg%2;
-    const int ip  = tid/8;        // 0 or 1
-    const int il  = tid/2 - 4*ip; // 0...3
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int ip  = tid/4;         // 0 or 1
+    const int il  = 2*((tid%4)/2); // 0 or 2
     const int ir  = tid%2;
     const int n   = 8;
     const int l0  = n*ir;
 
-    const uint16_t m1 = 1 << (4*ip + il);
-    const uint16_t m2 = m1 << 8;
+    // One would think that the Metal compiler would figure out that ip and il can only have
+    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+    // with these two tales.
+    //
+    // Possible masks for the high bit
+    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
+                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
+                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
+                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+    // Possible masks for the low 2 bits
+    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+    const ushort4 hm = mm[2*ip + il/2];
 
     const int shift = 2*il;
-    const uint16_t qm1 = 0x0003 << shift;
-    const uint16_t qm2 = 0x0300 << shift;
-    const int32_t v1 = 4 << shift;
-    const int32_t v2 = 1024 << shift;
+    const float    v1 = il == 0 ? 4.f : 64.f;
+    const float    v2 = 4.f * v1;
 
     const uint16_t s_shift1 = 4*ip;
-    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
-    const int ik = 4 + (il%2);
+    const uint16_t s_shift2 = s_shift1 + il;
 
     const int q_offset = 32*ip + l0;
     const int y_offset = 128*ip + 32*il + l0;
|
|
1104
1240
|
|
1105
1241
|
device const float * y1 = yy + ix*QK_K + y_offset;
|
1106
1242
|
|
1107
|
-
|
1108
|
-
|
1243
|
+
uint32_t scales32, aux32;
|
1244
|
+
thread uint16_t * scales16 = (thread uint16_t *)&scales32;
|
1245
|
+
thread const int8_t * scales = (thread const int8_t *)&scales32;
|
1246
|
+
|
1247
|
+
float sumf1[2] = {0.f};
|
1248
|
+
float sumf2[2] = {0.f};
|
1249
|
+
for (int i = ix; i < nb; i += 4) {
|
1109
1250
|
|
1110
1251
|
for (int l = 0; l < 8; ++l) {
|
1111
|
-
yl[l+0] = y1[l+ 0];
|
1112
|
-
yl[l+8] = y1[l+16];
|
1252
|
+
yl[l+ 0] = y1[l+ 0];
|
1253
|
+
yl[l+ 8] = y1[l+16];
|
1254
|
+
yl[l+16] = y1[l+32];
|
1255
|
+
yl[l+24] = y1[l+48];
|
1113
1256
|
}
|
1114
1257
|
|
1115
1258
|
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
|
@@ -1120,27 +1263,43 @@ kernel void kernel_mul_mat_q3_K_f32(
         for (int row = 0; row < 2; ++row) {
 
             const float d_all = (float)dh[0];
-            const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
 
-            float s1 = 0, s2 = 0;
+            scales16[0] = a[4];
+            scales16[1] = a[5];
+            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+            scales16[0] = a[il+0];
+            scales16[1] = a[il+1];
+            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2];
-                s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
-                s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2];
+                s1 += yl[l+0] * (qs & qm[il/2][0]);
+                s2 += yl[l+1] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+                s4 += yl[l+16] * (qs & qm[il/2][2]);
+                s5 += yl[l+17] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
             }
 
-            float d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[0];
-            sumf2[row] += d;
+            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[0] - 32);
+            sumf2[row] += d2 * (scales[2] - 32);
 
-            s1 = s2 = 0;
+            s1 = s2 = s3 = s4 = s5 = s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2+8];
-                s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
-                s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2+8];
+                s1 += yl[l+8] * (qs & qm[il/2][0]);
+                s2 += yl[l+9] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+                s4 += yl[l+24] * (qs & qm[il/2][2]);
+                s5 += yl[l+25] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
             }
 
-            d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[1];
-            sumf2[row] += d;
+            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[1] - 32);
+            sumf2[row] += d2 * (scales[3] - 32);
 
             q += step;
             h += step;
@@ -1149,17 +1308,20 @@ kernel void kernel_mul_mat_q3_K_f32(
 
         }
 
-        y1 += 2 * QK_K;
+        y1 += 4 * QK_K;
 
     }
 
     for (int row = 0; row < 2; ++row) {
-        const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
-        const float tot = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+        sumf1[row] = simd_sum(sumf);
+    }
+    if (tiisg == 0) {
+        for (int row = 0; row < 2; ++row) {
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
         }
     }
+
 }
 #else
 kernel void kernel_mul_mat_q3_K_f32(
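Editor's note (not part of the package diff): the qm/mm tables in the q3_K rewrite replace runtime shifts such as (0x0003 << shift) with compile-time constants indexed by il/2; the upstream comment notes the Metal compiler does not enumerate the four (ip, il) states on its own. The correspondence can be checked in plain C++:

    #include <cassert>
    #include <cstdint>

    int main() {
        const int32_t qm[2][4] = {{0x0003, 0x0300, 0x000c, 0x0c00},
                                  {0x0030, 0x3000, 0x00c0, 0xc000}};
        for (int il = 0; il <= 2; il += 2) { // il is 0 or 2 in the kernel
            const int shift = 2*il;
            assert(qm[il/2][0] == (0x0003 << shift)); // low 2 bits, even element
            assert(qm[il/2][1] == (0x0300 << shift)); // low 2 bits, odd element
            assert(qm[il/2][2] == (0x000c << shift)); // next 2-bit pair
            assert(qm[il/2][3] == (0x0c00 << shift));
        }
        return 0;
    }
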
@@ -1262,7 +1424,8 @@ kernel void kernel_mul_mat_q4_K_f32(
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int r2 = tgpig.z;
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = r0 * N_DST;
     const int ib_row = first_row * nb;
     const uint offset0 = r2/gqa*(nb*ne0);
     device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
@@ -1511,17 +1674,25 @@ kernel void kernel_mul_mat_q5_K_f32(
             sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
             sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
 
-            float4 acc = {0.f, 0.f, 0.f, 0.f};
+            float4 acc1 = {0.f};
+            float4 acc2 = {0.f};
             for (int l = 0; l < n; ++l) {
                 uint8_t h = qh[l];
-                acc[0] += yl[l+0] * ((q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
-                acc[1] += yl[l+8] * ((q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
-                acc[2] += yh[l+0] * ((q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
-                acc[3] += yh[l+8] * ((q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+                acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+                acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+                acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+                acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+                acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+                acc2[3] += h & hm4 ? yh[l+8] : 0.f;
             }
             const float dall = dh[0];
             const float dmin = dh[1];
-            sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+            sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
+                                 sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+                                 sc8[4] * (acc1[2] + 16.f*acc2[2]) +
+                                 sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                          dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
             q1 += step;
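Editor's note (not part of the package diff): the q5_K loop now keeps the 4-bit part (acc1) and the fifth-bit part (acc2) in separate accumulators and applies the scale once after the loop, instead of merging the high bit into every product. A C++ sketch of one 8-element strip under that split:

    #include <cstdint>

    float q5_strip(const uint8_t * q, const uint8_t * qh, uint8_t hm,
                   const float * y, float sc) {
        float acc1 = 0.0f, acc2 = 0.0f;
        for (int l = 0; l < 8; ++l) {
            acc1 += y[l] * (q[l] & 0x0F);       // low 4 bits
            acc2 += (qh[l] & hm) ? y[l] : 0.0f; // 5th bit
        }
        return sc * (acc1 + 16.0f * acc2);      // == sc * sum y * ((q & 0xF) + (h ? 16 : 0))
    }
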
@@ -1704,29 +1875,34 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
 
 template <typename type4x4>
 void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
-    const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = il ? (-8.h * 16.h) : -8.h;
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.h * xb->d;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = il ? 0xF000 : 0x0F00;
+    const ushort mask1 = mask0 << 8;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = ((qs[i] & mask0)        + m) * d;
-        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
+        reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
+        reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
     }
+
 }
 
 template <typename type4x4>
 void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+
     device const uint16_t * qs = ((device const uint16_t *)xb + 2);
-    const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = xb->m;
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float m  = xb->m;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = il ? 0xF000 : 0x0F00;
+    const ushort mask1 = mask0 << 8;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = ((qs[i] & mask0)        * d) + m;
-        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
+        reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
+        reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
     }
 }
 
@@ -1762,7 +1938,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
-    const float d_all = (float)(xb->d);
+    const half d_all = xb->d;
     device const uint8_t * q = (device const uint8_t *)xb->qs;
     device const uint8_t * h = (device const uint8_t *)xb->hmask;
     device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1775,17 +1951,20 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
                                  ((il/4)>0 ? 12  : 3);
     uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
     uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
-    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
-                                 (scale_2&kmask2) | ((scale_1&kmask1) << 4);
-    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+                               : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+    half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
+    const half ml = 4.h * dl;
 
-    il = (il/2)%4;
-    float   coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
-    uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    il = (il/2) & 3;
+    const half    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    const uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    dl *= coef;
 
     for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
+        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
     }
+
 #else
     float kcoef = il&1 ? 1.f/16.f : 1.f;
     uint16_t kmask = il&1 ? 0xF0 : 0x0F;
@@ -1799,31 +1978,37 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
 #endif
 }
 
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+    return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+}
+
 template <typename type4x4>
 void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
-    device const uint8_t * q = xb->qs;
+    device const uchar * q = xb->qs;
 
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q = q + (il/4) * 32 + 16 * (il&1);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0]   : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const half d   = il < 2 ? xb->d : xb->d / 16.h;
+    const half min = xb->dmin;
+    const half dl = d * sc[0];
+    const half ml = min * sc[1];
 #else
     q = q + 16 * (il&1);
     device const uint8_t * s = xb->scales;
    device const half2 * dh = (device const half2 *)xb->d;
     const float2 d = (float2)dh[0];
     const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
-    const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[0]>>4);
+    const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
 #endif
     const ushort mask = il<2 ? 0x0F : 0xF0;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * (q[i] & mask) - ml;
     }
+
 }
 
 template <typename type4x4>
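Editor's note (not part of the package diff): get_scale_min_k4_just2 unpacks one of the 8 pairs of 6-bit (scale, min) values that Q4_K/Q5_K pack into 12 bytes; pairs 0-3 sit in the low 6 bits of bytes 0-7, while pairs 4-7 are split between the low nibbles of bytes 8-11 and the top 2 bits of bytes 0-7. A C++ equivalent, with the k offset folded into j:

    #include <cstdint>
    #include <utility>

    std::pair<uint8_t, uint8_t> get_scale_min_k4(int j, const uint8_t * q) {
        if (j < 4) {
            return { uint8_t(q[j] & 63), uint8_t(q[j + 4] & 63) };
        }
        return { uint8_t((q[j + 4] & 0xF) | ((q[j - 4] & 0xc0) >> 2)),
                 uint8_t((q[j + 4] >>  4) | ((q[j    ] & 0xc0) >> 2)) };
    }
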
@@ -1832,19 +2017,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
     device const uint8_t * qh = xb->qh;
 
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q = q + 32 * (il/4) + 16 * (il&1);
     qh = qh + 16 * (il&1);
     uint8_t ul = 1 << (il/2);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0]   : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const half d   = il < 2 ? xb->d : xb->d / 16.h;
+    const half min = xb->dmin;
+    const half dl = d * sc[0];
+    const half ml = min * sc[1];
 
-    const ushort mask  = il<2 ? 0x0F : 0xF0;
-    const float qh_val = il<2 ? 16.f : 256.f;
+    const ushort mask = il<2 ? 0x0F : 0xF0;
+    const half qh_val = il<2 ? 16.h : 256.h;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
     }
@@ -1863,7 +2048,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
-    const float d_all = (float)(xb->d);
+    const half d_all = xb->d;
     device const uint8_t * ql = (device const uint8_t *)xb->ql;
     device const uint8_t * qh = (device const uint8_t *)xb->qh;
     device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1871,19 +2056,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
 #if QK_K == 256
     ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
     qh = qh + 32*(il/8) + 16*(il&1);
-    float sc = scales[(il%2) + 2 * ((il/2))];
-    il = (il/2)%4;
+    half sc = scales[(il%2) + 2 * ((il/2))];
+    il = (il/2) & 3;
 #else
     ql = ql + 16 * (il&1);
-    float sc = scales[il];
+    half sc = scales[il];
 #endif
+    const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+    const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
+    const half     coef   = il>1 ? 1.f/16.h : 1.h;
+    const half ml = d_all * sc * 32.h;
+    const half dl = d_all * sc * coef;
     for (int i = 0; i < 16; ++i) {
-        const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
-        const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
-        const float   coef   = il>1 ? 1.f/16.f : 1.f;
-        float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
-                         ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
-        reg[i/4][i%4] = d_all * sc * q * coef;
+        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+        reg[i/4][i%4] = dl * q - ml;
     }
 }
 
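Editor's note (not part of the package diff): the q6_K change hoists the il-dependent masks and coefficient out of the 16-iteration loop and folds the "- 32" zero point into a precomputed ml, leaving one multiply-subtract per element. The algebra, as a C++ sketch over pre-masked 6-bit values (names are illustrative):

    #include <cstdint>

    void dequant_q6_strip(float d_all, float sc, float coef,
                          const uint8_t * q6, float * out) { // q6: 16 pre-masked values
        const float dl = d_all * sc * coef;  // loop-invariant scale
        const float ml = d_all * sc * 32.0f; // folded zero point
        for (int i = 0; i < 16; ++i) {
            out[i] = dl * q6[i] - ml;        // was: d_all * sc * (q6[i] - 32/coef) * coef
        }
    }
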