llama_cpp 0.5.0 → 0.5.2

@@ -63,18 +63,18 @@ kernel void kernel_mul_row(
 }
 
 kernel void kernel_scale(
-        device const float * src0,
-        device       float * dst,
+        device const float4 * src0,
+        device       float4 * dst,
         constant     float & scale,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * scale;
 }
 
 kernel void kernel_silu(
-        device const float * src0,
-        device       float * dst,
+        device const float4 * src0,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];
     dst[tpig] = x / (1.0f + exp(-x));
 }
 
@@ -89,10 +89,10 @@ constant float GELU_COEF_A    = 0.044715f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 kernel void kernel_gelu(
-    device const float * src0,
-    device       float * dst,
+    device const float4 * src0,
+    device       float4 * dst,
     uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];
 
     // BEWARE !!!
     // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
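These first hunks vectorize the element-wise kernels: `kernel_scale`, `kernel_silu` and `kernel_gelu` now read and write `float4` instead of `float`, so each GPU thread handles four values (and the host side presumably dispatches a quarter as many threads; the hunks appear to come from the bundled `ggml-metal.metal`). A minimal CPU sketch of the SiLU math over 4-wide chunks — plain C++ for illustration, `silu_f32x4` is a hypothetical name, and it assumes the element count is divisible by 4:

```cpp
#include <cmath>
#include <cstdio>

// Mirrors the vectorized kernel_silu: each outer iteration plays the role of
// one GPU thread and consumes one float4-sized slice of the input.
void silu_f32x4(const float * src, float * dst, int n) {
    for (int tpig = 0; tpig < n/4; ++tpig) {
        for (int k = 0; k < 4; ++k) {                     // the four float4 lanes
            const float x = src[4*tpig + k];
            dst[4*tpig + k] = x / (1.0f + std::exp(-x));  // SiLU(x) = x*sigmoid(x)
        }
    }
}

int main() {
    float in[8] = {-2.f, -1.f, -0.5f, 0.f, 0.5f, 1.f, 2.f, 4.f}, out[8];
    silu_f32x4(in, out, 8);
    for (float v : out) std::printf("%.4f ", v);
    std::printf("\n");
}
```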
@@ -107,7 +107,6 @@ kernel void kernel_soft_max(
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
-        threadgroup float  * buf [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
@@ -119,55 +118,67 @@ kernel void kernel_soft_max(
     device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    buf[tpitg[0]] = -INFINITY;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
+    float lmax = psrc0[tpitg[0]];
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+        lmax = MAX(lmax, psrc0[i00]);
     }
+    const float max = simd_max(lmax);
 
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    // parallel sum
+    float lsum = 0.0f;
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        const float exp_psrc0 = exp(psrc0[i00] - max);
+        lsum += exp_psrc0;
+        // Remember the result of exp here. exp is expensive, so we really do not
+        // wish to compute it twice.
+        pdst[i00] = exp_psrc0;
     }
 
-    // broadcast
-    if (tpitg[0] == 0) {
-        buf[0] = buf[0];
+    const float sum = simd_sum(lsum);
+
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        pdst[i00] /= sum;
     }
+}
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+kernel void kernel_soft_max_4(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
 
-    const float max = buf[0];
+    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
-    // parallel sum
-    buf[tpitg[0]] = 0.0f;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        buf[tpitg[0]] += exp(psrc0[i00] - max);
+    // parallel max
+    float4 lmax4 = psrc4[tpitg[0]];
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        lmax4 = fmax(lmax4, psrc4[i00]);
     }
+    float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] += buf[tpitg[0] + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
+    const float max = simd_max(lmax);
 
-    // broadcast
-    if (tpitg[0] == 0) {
-        buf[0] = buf[0];
+    // parallel sum
+    float4 lsum4 = 0.0f;
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        const float4 exp_psrc4 = exp(psrc4[i00] - max);
+        lsum4 += exp_psrc4;
+        pdst4[i00] = exp_psrc4;
     }
+    float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    const float sum = buf[0];
+    const float sum = simd_sum(lsum);
 
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        pdst[i00] = exp(psrc0[i00] - max) / sum;
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        pdst4[i00] /= sum;
     }
 }
 
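The rewritten softmax drops the `threadgroup` scratch buffer, the barrier-heavy tree reductions and the no-op "broadcast" steps in favour of SIMD-group reductions (`simd_max`, `simd_sum`), and it caches `exp(x - max)` in `dst` so the expensive exponential is evaluated once per element. A scalar C++ reference of the same three phases (`soft_max_row` is a hypothetical helper; on the GPU each loop is striped across the 32 lanes of a simdgroup):

```cpp
#include <algorithm>
#include <cmath>

// Reference for the new kernel_soft_max data flow:
// 1) row max, 2) sum of exp(x - max) with the exp cached in dst, 3) normalize.
void soft_max_row(const float * x, float * dst, int n) {
    float max = x[0];
    for (int i = 1; i < n; ++i) max = std::max(max, x[i]);  // "parallel max" + simd_max

    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        const float e = std::exp(x[i] - max);  // computed once...
        sum    += e;                           // "parallel sum" + simd_sum
        dst[i]  = e;                           // ...and remembered in dst
    }

    for (int i = 0; i < n; ++i) dst[i] /= sum; // final normalization pass
}
```

`kernel_soft_max_4` is the same algorithm over `float4` loads, folding the per-vector partials back to scalars with `MAX(...)` and a horizontal add before the SIMD reductions.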
@@ -186,6 +197,33 @@ kernel void kernel_diag_mask_inf(
         dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
     } else {
         dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+    }
+}
+
+kernel void kernel_diag_mask_inf_8(
+        device const float4 * src0,
+        device       float4 * dst,
+        constant    int64_t & ne00,
+        constant    int64_t & ne01,
+        constant        int & n_past,
+        uint3 tpig[[thread_position_in_grid]]) {
+
+    const int64_t i = 2*tpig[0];
+
+    dst[i+0] = src0[i+0];
+    dst[i+1] = src0[i+1];
+    int64_t i4 = 4*i;
+    const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+    const int64_t i01 = i4/(ne00);      i4 -= i01*ne00;
+    const int64_t i00 = i4;
+    for (int k = 3; k >= 0; --k) {
+        if (i00 + 4 + k <= n_past + i01) {
+            break;
+        }
+        dst[i+1][k] = -INFINITY;
+        if (i00 + k > n_past + i01) {
+            dst[i][k] = -INFINITY;
+        }
     }
 }
 
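The new `kernel_diag_mask_inf_8` copies two `float4` values per thread and then re-applies the scalar causal-mask rule of `kernel_diag_mask_inf` above (column `i00` past `n_past + i01` becomes `-INFINITY`) to the eight lanes it just wrote. That rule, restated as a self-contained C++ loop (`diag_mask_inf` is a hypothetical helper for illustration):

```cpp
#include <cmath>

// The causal-mask rule that the 8-wide kernel reproduces lane by lane:
// entries to the right of column (n_past + row) are masked to -infinity.
void diag_mask_inf(const float * src, float * dst,
                   long ne00, long ne01, long ne02, int n_past) {
    for (long i02 = 0; i02 < ne02; ++i02)
    for (long i01 = 0; i01 < ne01; ++i01)       // row
    for (long i00 = 0; i00 < ne00; ++i00) {     // column
        const long idx = i02*ne01*ne00 + i01*ne00 + i00;
        dst[idx] = (i00 > n_past + i01) ? -INFINITY : src[idx];
    }
}
```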
@@ -214,25 +252,17 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-    // broadcast
-    if (tpitg == 0) {
-        sum[0] /= ne00;
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    const float mean = sum[0];
+    const float mean = sum[0] / ne00;
 
-    // recenter
+    // recenter and VARIANCE
+    threadgroup_barrier(mem_flags::mem_threadgroup);
     device float * y = dst + tgpig*ne00;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        y[i00] = x[i00] - mean;
-    }
-
-    // VARIANCE
-    // parallel sum
     sum[tpitg] = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] - mean;
         sum[tpitg] += y[i00] * y[i00];
     }
+
     // reduce
     threadgroup_barrier(mem_flags::mem_threadgroup);
     for (uint i = ntg/2; i > 0; i /= 2) {
@@ -241,12 +271,7 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-    // broadcast
-    if (tpitg == 0) {
-        sum[0] /= ne00;
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    const float variance = sum[0];
+    const float variance = sum[0] / ne00;
 
     const float scale = 1.0f/sqrt(variance + eps);
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
@@ -254,7 +279,6 @@ kernel void kernel_norm(
     }
 }
 
-
 kernel void kernel_rms_norm(
         device const void * src0,
         device     float * dst,
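The `kernel_norm` hunks above fold the division by `ne00` into the read of `sum[0]` (removing two broadcast/barrier rounds) and fuse the recenter loop with the variance accumulation, so the row is traversed one time less. The equivalent scalar flow as a C++ sketch (`norm_row` is an illustrative name, not the gem's code):

```cpp
#include <cmath>

// Layer-norm flow after the change: one fused pass computes y = x - mean
// while accumulating sum((x - mean)^2); no separate recenter pass.
void norm_row(const float * x, float * y, int n, float eps) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) sum += x[i];
    const float mean = sum / n;

    float var = 0.0f;
    for (int i = 0; i < n; ++i) {   // recenter and VARIANCE, fused
        y[i] = x[i] - mean;
        var += y[i] * y[i];
    }
    var /= n;

    const float scale = 1.0f / std::sqrt(var + eps);
    for (int i = 0; i < n; ++i) y[i] *= scale;
}
```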
@@ -435,6 +459,8 @@ kernel void kernel_mul_mat_q4_1_f32(
     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
+#define NB_Q8_0 8
+
 kernel void kernel_mul_mat_q8_0_f32(
         device const void  * src0,
         device const float * src1,
@@ -463,30 +489,30 @@ kernel void kernel_mul_mat_q8_0_f32(
     device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
     device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
 
-    float yl[16];
+    float yl[NB_Q8_0];
     float sumf[nr]={0.f};
 
-    const int ix = tiisg/2;
-    const int il = tiisg%2;
+    const int ix = tiisg/4;
+    const int il = tiisg%4;
 
-    device const float * yb = y + ix * QK8_0 + 16*il;
+    device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
 
-    // each thread in a SIMD group deals with half a block.
-    for (int ib = ix; ib < nb; ib += nw/2) {
-        for (int i = 0; i < 16; ++i) {
+    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
+    for (int ib = ix; ib < nb; ib += nw/4) {
+        for (int i = 0; i < NB_Q8_0; ++i) {
             yl[i] = yb[i];
         }
 
         for (int row = 0; row < nr; row++) {
-            device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+            device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
             float sumq = 0.f;
-            for (int iq = 0; iq < 16; ++iq) {
+            for (int iq = 0; iq < NB_Q8_0; ++iq) {
                 sumq += qs[iq] * yl[iq];
             }
             sumf[row] += sumq*x[ib+row*nb].d;
         }
 
-        yb += QK8_0 * 16;
+        yb += NB_Q8_0 * nw;
     }
 
     for (int row = 0; row < nr; ++row) {
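With `NB_Q8_0 = 8`, each SIMD lane now handles 8 quants at a time: `tiisg/4` selects one of eight Q8_0 blocks in flight and `tiisg%4` one of four 8-quant slices, where the old split was sixteen blocks in two 16-quant halves. A small C++ printout of the new lane-to-data mapping (a sketch assuming `QK8_0 == 32` and a 32-lane simdgroup, as in ggml):

```cpp
#include <cstdio>

// Lane mapping after the NB_Q8_0 change: ix picks a block, il an 8-quant
// slice, and each step the simdgroup consumes nw/4 = 8 whole blocks.
int main() {
    const int QK8_0 = 32, NB_Q8_0 = 8, nw = 32;  // nw = SIMD width
    for (int tiisg = 0; tiisg < nw; ++tiisg) {
        const int ix = tiisg/4;                  // block within the step
        const int il = tiisg%4;                  // slice within the block
        const int first = ix*QK8_0 + NB_Q8_0*il;
        std::printf("lane %2d -> y[%3d..%3d]\n", tiisg, first, first + NB_Q8_0 - 1);
    }
    // Each iteration then advances by NB_Q8_0 * nw = 256 floats,
    // i.e. the 8 blocks the simdgroup just consumed.
}
```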
@@ -497,7 +523,7 @@ kernel void kernel_mul_mat_q8_0_f32(
     }
 }
 
-kernel void kernel_mul_mat_f16_f32(
+kernel void kernel_mul_mat_f16_f32_1row(
         device const char * src0,
         device const char * src1,
         device      float * dst,
@@ -515,11 +541,8 @@ kernel void kernel_mul_mat_f16_f32(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        threadgroup float  * sum [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3  tpig[[thread_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3  tptg[[threads_per_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]]) {
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
@@ -528,42 +551,144 @@ kernel void kernel_mul_mat_f16_f32(
     device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
-    uint ith = tpitg.x;
-    uint nth = tptg.x;
+    float sumf = 0;
+    if (ne00 < 128) {
+        for (int i = tiisg; i < ne00; i += 32) {
+            sumf += (float) x[i] * (float) y[i];
+        }
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    } else {
+        device const half4  * x4 = (device const half4  *) x;
+        device const float4 * y4 = (device const float4 *) y;
+        for (int i = tiisg; i < ne00/4; i += 32) {
+            for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
+        }
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    }
 
-    sum[ith] = 0.0f;
+}
 
-    for (int i = ith; i < ne00; i += nth) {
-        sum[ith] += (float) x[i] * (float) y[i];
-    }
+#define N_F16_F32 4
 
-    // accumulate the sum from all threads in the threadgroup
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+kernel void kernel_mul_mat_f16_f32(
+        device const char * src0,
+        device const char * src1,
+        device      float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F16_F32;
+    const int64_t im = tgpig.z;
+
+    device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (float) x[i] * (float) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const half4 * x4 = (device const half4 *)x;
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12);
+            device const float4 * y4 = (device const float4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
     }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+}
+
+// Assumes row size (ne00) is a multiple of 4
+kernel void kernel_mul_mat_f16_f32_l4(
+        device const char * src0,
+        device const char * src1,
+        device      float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int nrows = ne11;
+    const int64_t r0 = tgpig.x;
+    const int64_t im = tgpig.z;
+
+    device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    for (int r1 = 0; r1 < nrows; ++r1) {
+        device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+
+        float sumf = 0;
+        for (int i = tiisg; i < ne00/4; i += 32) {
+            for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+        }
+
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
     }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
-    }
-
-    // Original implementation. Left behind commented out for now
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //for (uint i = tptg.x/2; i > 0; i /= 2) {
-    //    if (tpitg.x < i) {
-    //        sum[tpitg.x] += sum[tpitg.x + i];
-    //    }
-    //    threadgroup_barrier(mem_flags::mem_threadgroup);
-    //}
-    //
-    //if (tpitg.x == 0) {
-    //    dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
-    //}
 }
 
 kernel void kernel_alibi_f32(
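The f16×f32 matrix-vector path is now split three ways: `kernel_mul_mat_f16_f32_1row` (the old kernel, reduced to one `simd_sum` per row), the new `kernel_mul_mat_f16_f32` that processes `N_F16_F32 = 4` rows of `src1` per threadgroup, and `kernel_mul_mat_f16_f32_l4` for rows whose length is a multiple of 4. All three share one dot-product pattern; a scalar C++ reference of it, with `half` promoted to `float` the way the kernels do via their `(float)` casts (`dot_row` is an illustrative name):

```cpp
// The shared dot-product pattern: a 4-unrolled main loop over ne00/4 chunks
// plus a scalar tail for the last 0..3 elements. The "_l4" variant assumes
// ne00 % 4 == 0 and skips the tail; on the GPU the outer loop is strided
// across the 32 lanes and finished with simd_sum.
float dot_row(const float * x, const float * y, int ne00) {
    float sumf = 0.0f;
    for (int i = 0; i < ne00/4; ++i)            // the x4/y4 chunks
        for (int k = 0; k < 4; ++k)
            sumf += x[4*i + k] * y[4*i + k];
    for (int i = 4*(ne00/4); i < ne00; ++i)     // tail, done by lane 0
        sumf += x[i] * y[i];
    return sumf;
}
```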
@@ -632,25 +757,27 @@ kernel void kernel_rope(
         constant       int & mode,
         constant     float & freq_base,
         constant     float & freq_scale,
-    uint3 tpig[[thread_position_in_grid]]) {
-    const int64_t i3 = tpig[2];
-    const int64_t i2 = tpig[1];
-    const int64_t i1 = tpig[0];
+    uint  tiitg[[thread_index_in_threadgroup]],
+    uint3  tptg[[threads_per_threadgroup]],
+    uint3 tgpig[[threadgroup_position_in_grid]]) {
+    const int64_t i3 = tgpig[2];
+    const int64_t i2 = tgpig[1];
+    const int64_t i1 = tgpig[0];
 
     const bool is_neox = mode & 2;
-    const float theta_scale = pow(freq_base, -2.0f/n_dims);
 
     const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
 
-    float theta = freq_scale * (float)p;
+    const float theta_0 = freq_scale * (float)p;
+    const float inv_ndims = -1.f/n_dims;
 
     if (!is_neox) {
-        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+        for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+
+            const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
             const float cos_theta = cos(theta);
             const float sin_theta = sin(theta);
 
-            theta *= theta_scale;
-
             device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
             device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
@@ -662,12 +789,12 @@ kernel void kernel_rope(
         }
     } else {
         for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
-            for (int64_t ic = 0; ic < n_dims; ic += 2) {
+            for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
+
+                const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
                 const float cos_theta = cos(theta);
                 const float sin_theta = sin(theta);
 
-                theta *= theta_scale;
-
                 const int64_t i0 = ib*n_dims + ic/2;
 
                 device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
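RoPE previously advanced the angle serially (`theta *= theta_scale`), which kept one thread walking the whole row; each thread now computes its angle in closed form, `theta_0 * freq_base^(-i0/n_dims)`, so the rotation pairs can be striped across the threadgroup. A quick C++ check that the closed form matches the old recurrence (example values chosen arbitrarily):

```cpp
#include <cmath>
#include <cstdio>

// The closed form used by the parallel loop equals the old serial recurrence:
// theta_scale^(i0/2) == freq_base^(-i0/n_dims).
int main() {
    const float freq_base = 10000.0f, freq_scale = 1.0f;
    const int   n_dims = 128, p = 7;              // example position

    const float theta_scale = std::pow(freq_base, -2.0f/n_dims); // old step
    const float theta_0     = freq_scale * (float)p;
    const float inv_ndims   = -1.0f/n_dims;

    float theta = theta_0;
    for (int i0 = 0; i0 < 16; i0 += 2) {
        const float closed = theta_0 * std::pow(freq_base, inv_ndims*i0);
        std::printf("i0=%2d  serial=%.6f  closed=%.6f\n", i0, theta, closed);
        theta *= theta_scale;                                     // old update
    }
}
```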
@@ -1071,31 +1198,40 @@ kernel void kernel_mul_mat_q3_K_f32(
     device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
     device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
 
-    float yl[16];
+    float yl[32];
 
-    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask1 = 0x3030;
     const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = tiisg/2;
-    const int ix  = tiisg%2;
-    const int ip  = tid/8;          // 0 or 1
-    const int il  = tid/2 - 4*ip;   // 0...3
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int ip  = tid/4;          // 0 or 1
+    const int il  = 2*((tid%4)/2);  // 0 or 2
     const int ir  = tid%2;
     const int n   = 8;
     const int l0  = n*ir;
 
-    const uint16_t m1 = 1 << (4*ip + il);
-    const uint16_t m2 = m1 << 8;
+    // One would think that the Metal compiler would figure out that ip and il can only have
+    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+    // with these two tables.
+    //
+    // Possible masks for the high bit
+    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
+                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
+                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
+                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+    // Possible masks for the low 2 bits
+    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+    const ushort4 hm = mm[2*ip + il/2];
 
     const int shift = 2*il;
-    const uint16_t qm1 = 0x0003 << shift;
-    const uint16_t qm2 = 0x0300 << shift;
-    const int32_t v1 = 4 << shift;
-    const int32_t v2 = 1024 << shift;
+    const float   v1 = il == 0 ? 4.f : 64.f;
+    const float   v2 = 4.f * v1;
 
     const uint16_t s_shift1 = 4*ip;
-    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
-    const int ik = 4 + (il%2);
+    const uint16_t s_shift2 = s_shift1 + il;
 
     const int q_offset = 32*ip + l0;
     const int y_offset = 128*ip + 32*il + l0;
@@ -1104,12 +1240,19 @@ kernel void kernel_mul_mat_q3_K_f32(
 
     device const float * y1 = yy + ix*QK_K + y_offset;
 
-    float sumf1[2] = {0.f}, sumf2[2] = {0.f};
-    for (int i = ix; i < nb; i += 2) {
+    uint32_t scales32, aux32;
+    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+    thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+    float sumf1[2] = {0.f};
+    float sumf2[2] = {0.f};
+    for (int i = ix; i < nb; i += 4) {
 
         for (int l = 0; l < 8; ++l) {
-            yl[l+0] = y1[l+ 0];
-            yl[l+8] = y1[l+16];
+            yl[l+ 0] = y1[l+ 0];
+            yl[l+ 8] = y1[l+16];
+            yl[l+16] = y1[l+32];
+            yl[l+24] = y1[l+48];
         }
 
         device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
@@ -1120,27 +1263,43 @@ kernel void kernel_mul_mat_q3_K_f32(
         for (int row = 0; row < 2; ++row) {
 
             const float d_all = (float)dh[0];
-            const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
 
-            float s1 = 0, s2 = 0;
+            scales16[0] = a[4];
+            scales16[1] = a[5];
+            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+            scales16[0] = a[il+0];
+            scales16[1] = a[il+1];
+            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2];
-                s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
-                s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2];
+                s1 += yl[l+0] * (qs & qm[il/2][0]);
+                s2 += yl[l+1] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+                s4 += yl[l+16] * (qs & qm[il/2][2]);
+                s5 += yl[l+17] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
             }
-            float d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[0];
-            sumf2[row] += d;
+            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[0] - 32);
+            sumf2[row] += d2 * (scales[2] - 32);
 
-            s1 = s2 = 0;
+            s1 = s2 = s3 = s4 = s5 = s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2+8];
-                s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
-                s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2+8];
+                s1 += yl[l+8] * (qs & qm[il/2][0]);
+                s2 += yl[l+9] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+                s4 += yl[l+24] * (qs & qm[il/2][2]);
+                s5 += yl[l+25] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
            }
-            d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[1];
-            sumf2[row] += d;
+            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[1] - 32);
+            sumf2[row] += d2 * (scales[3] - 32);
 
             q += step;
             h += step;
@@ -1149,17 +1308,20 @@ kernel void kernel_mul_mat_q3_K_f32(
 
         }
 
-        y1 += 2 * QK_K;
+        y1 += 4 * QK_K;
 
     }
 
     for (int row = 0; row < 2; ++row) {
-        const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
-        const float tot = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+        sumf1[row] = simd_sum(sumf);
+    }
+    if (tiisg == 0) {
+        for (int row = 0; row < 2; ++row) {
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
         }
     }
+
 }
 #else
 kernel void kernel_mul_mat_q3_K_f32(
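Much of the q3_K speedup above is an algebraic refactor: instead of applying the high-bit offset inside every multiply, the kernel accumulates the y-values whose high bit is clear into `s3`/`s6` and subtracts them once, scaled by `v1`/`v2`. The two forms are identical, as this C++ check shows (illustrative values; the q5_K matvec hunk further below applies the same split with its `acc1`/`acc2` pair):

```cpp
#include <cstdio>

// Identity behind the q3_K rewrite: per-element offsets
//     sum_l y[l] * (q[l] - (hbit[l] ? 0 : v))
// equal one masked partial sum subtracted at the end:
//     sum_l y[l]*q[l]  -  v * sum_{l : hbit[l]==0} y[l]
int main() {
    const float y[4]    = {0.5f, -1.0f, 2.0f, 0.25f};
    const int   q[4]    = {1, 3, 0, 2};
    const bool  hbit[4] = {true, false, false, true};
    const float v       = 4.0f;

    float old_form = 0.0f, s1 = 0.0f, s3 = 0.0f;
    for (int l = 0; l < 4; ++l) {
        old_form += y[l] * (q[l] - (hbit[l] ? 0.0f : v));
        s1 += y[l] * q[l];
        s3 += hbit[l] ? 0.0f : y[l];
    }
    std::printf("old=%.4f  new=%.4f\n", old_form, s1 - s3*v);  // identical
}
```

The mask tables (`mm`, `qm`) and the widened `yl[32]` buffer then let one pass cover twice as much data per simdgroup (`i += 4` instead of `i += 2`).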
@@ -1262,7 +1424,8 @@ kernel void kernel_mul_mat_q4_K_f32(
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int r2 = tgpig.z;
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = r0 * N_DST;
     const int ib_row = first_row * nb;
     const uint offset0 = r2/gqa*(nb*ne0);
     device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
@@ -1511,17 +1674,25 @@ kernel void kernel_mul_mat_q5_K_f32(
                 sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
                 sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
 
-                float4 acc = {0.f, 0.f, 0.f, 0.f};
+                float4 acc1 = {0.f};
+                float4 acc2 = {0.f};
                 for (int l = 0; l < n; ++l) {
                     uint8_t h = qh[l];
-                    acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
-                    acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
-                    acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
-                    acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+                    acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+                    acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+                    acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+                    acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+                    acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+                    acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+                    acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+                    acc2[3] += h & hm4 ? yh[l+8] : 0.f;
                 }
                 const float dall = dh[0];
                 const float dmin = dh[1];
-                sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+                sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
+                                     sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+                                     sc8[4] * (acc1[2] + 16.f*acc2[2]) +
+                                     sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                              dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
                 q1 += step;
@@ -1704,29 +1875,34 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
 
 template <typename type4x4>
 void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
-    const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = il ? ( -8.h * 16.h) : -8.h;
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.h * xb->d;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = il ? 0xF000 : 0x0F00;
+    const ushort mask1 = mask0 << 8;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)     ) + m) * d;
-        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
+        reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
+        reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
     }
+
 }
 
 template <typename type4x4>
 void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+
     device const uint16_t * qs = ((device const uint16_t *)xb + 2);
-    const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = xb->m;
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float  m = xb->m;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = il ? 0xF000 : 0x0F00;
+    const ushort mask1 = mask0 << 8;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)     ) * d) + m;
-        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
+        reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
+        reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
     }
 }
1908
 
@@ -1762,7 +1938,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
1762
1938
 
1763
1939
  template <typename type4x4>
1764
1940
  void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
1765
- const float d_all = (float)(xb->d);
1941
+ const half d_all = xb->d;
1766
1942
  device const uint8_t * q = (device const uint8_t *)xb->qs;
1767
1943
  device const uint8_t * h = (device const uint8_t *)xb->hmask;
1768
1944
  device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1775,17 +1951,20 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
1775
1951
  ((il/4)>0 ? 12 : 3);
1776
1952
  uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
1777
1953
  uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
1778
- int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
1779
- (scale_2&kmask2) | ((scale_1&kmask1) << 4);
1780
- float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
1954
+ int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
1955
+ : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
1956
+ half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
1957
+ const half ml = 4.h * dl;
1781
1958
 
1782
- il = (il/2)%4;
1783
- float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
1784
- uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1959
+ il = (il/2) & 3;
1960
+ const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
1961
+ const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1962
+ dl *= coef;
1785
1963
 
1786
1964
  for (int i = 0; i < 16; ++i) {
1787
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
1965
+ reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
1788
1966
  }
1967
+
1789
1968
  #else
1790
1969
  float kcoef = il&1 ? 1.f/16.f : 1.f;
1791
1970
  uint16_t kmask = il&1 ? 0xF0 : 0x0F;
@@ -1799,31 +1978,37 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
 #endif
 }
 
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+    return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+}
+
 template <typename type4x4>
 void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
-    device const uint8_t * q = xb->qs;
+    device const uchar * q = xb->qs;
 
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q = q + (il/4) * 32 + 16 * (il&1);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const half d   = il < 2 ? xb->d : xb->d / 16.h;
+    const half min = xb->dmin;
+    const half dl = d * sc[0];
+    const half ml = min * sc[1];
 #else
     q = q + 16 * (il&1);
     device const uint8_t * s = xb->scales;
     device const half2 * dh = (device const half2 *)xb->d;
     const float2 d = (float2)dh[0];
     const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
-    const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4);
+    const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
 #endif
     const ushort mask = il<2 ? 0x0F : 0xF0;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * (q[i] & mask) - ml;
     }
+
 }
 
 template <typename type4x4>
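The new helper `get_scale_min_k4_just2` extracts only the single 6-bit (scale, min) pair a thread needs from the 12-byte K-quant scale block, where the old `get_scale_min_k4` unpacked four values. A direct C++ port, under the assumption that the layout matches ggml's K-quant packing (the first four pairs live in low 6-bit fields, the rest are split across nibbles plus the top 2 bits of earlier bytes):

```cpp
#include <cstdint>
#include <utility>

// Port of get_scale_min_k4_just2: j selects the (scale, min) pair, k is the
// extra lane offset (il/2 in the callers), q points at the 12 scale bytes.
std::pair<uint8_t, uint8_t> scale_min_k4_just2(int j, int k, const uint8_t * q) {
    if (j < 4) {   // pairs 0..3: plain low-6-bit fields
        return { uint8_t(q[j + 0 + k] & 63), uint8_t(q[j + 4 + k] & 63) };
    }
    // pairs 4..7: a low nibble plus the top 2 bits of an earlier byte
    return { uint8_t((q[j + 4 + k] & 0xF) | ((q[j - 4 + k] & 0xC0) >> 2)),
             uint8_t((q[j + 4 + k] >>  4) | ((q[j - 0 + k] & 0xC0) >> 2)) };
}
```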
@@ -1832,19 +2017,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
     device const uint8_t * qh = xb->qh;
 
 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q  = q + 32 * (il/4) + 16 * (il&1);
     qh = qh + 16 * (il&1);
     uint8_t ul = 1 << (il/2);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const half d   = il < 2 ? xb->d : xb->d / 16.h;
+    const half min = xb->dmin;
+    const half dl = d * sc[0];
+    const half ml = min * sc[1];
 
-    const ushort mask  = il<2 ? 0x0F : 0xF0;
-    const float qh_val = il<2 ? 16.f : 256.f;
+    const ushort mask = il<2 ? 0x0F : 0xF0;
+    const half qh_val = il<2 ? 16.h : 256.h;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
     }
@@ -1863,7 +2048,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
-    const float d_all = (float)(xb->d);
+    const half d_all = xb->d;
     device const uint8_t * ql = (device const uint8_t *)xb->ql;
     device const uint8_t * qh = (device const uint8_t *)xb->qh;
     device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1871,19 +2056,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
 #if QK_K == 256
     ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
     qh = qh + 32*(il/8) + 16*(il&1);
-    float sc = scales[(il%2) + 2 * ((il/2))];
-    il = (il/2)%4;
+    half sc = scales[(il%2) + 2 * ((il/2))];
+    il = (il/2) & 3;
 #else
     ql = ql + 16 * (il&1);
-    float sc = scales[il];
+    half sc = scales[il];
 #endif
+    const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+    const uint16_t kmask2 = il>1 ? 0xF0              : 0x0F;
+    const half     coef   = il>1 ? 1.f/16.h          : 1.h;
+    const half ml = d_all * sc * 32.h;
+    const half dl = d_all * sc * coef;
     for (int i = 0; i < 16; ++i) {
-        uint16_t  kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
-        uint16_t  kmask2 = il>1 ? 0xF0              : 0x0F;
-        const float coef = il>1 ? 1.f/16.f          : 1.f;
-        float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
-                         ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
-        reg[i/4][i%4] = d_all * sc * q * coef;
+        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+        reg[i/4][i%4] = dl * q - ml;
    }
 }