llama_cpp 0.5.0 → 0.5.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -63,18 +63,18 @@ kernel void kernel_mul_row(
63
63
  }
64
64
 
65
65
  kernel void kernel_scale(
66
- device const float * src0,
67
- device float * dst,
66
+ device const float4 * src0,
67
+ device float4 * dst,
68
68
  constant float & scale,
69
69
  uint tpig[[thread_position_in_grid]]) {
70
70
  dst[tpig] = src0[tpig] * scale;
71
71
  }
72
72
 
73
73
  kernel void kernel_silu(
74
- device const float * src0,
75
- device float * dst,
74
+ device const float4 * src0,
75
+ device float4 * dst,
76
76
  uint tpig[[thread_position_in_grid]]) {
77
- float x = src0[tpig];
77
+ device const float4 & x = src0[tpig];
78
78
  dst[tpig] = x / (1.0f + exp(-x));
79
79
  }
80
80
 
@@ -89,10 +89,10 @@ constant float GELU_COEF_A = 0.044715f;
89
89
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
90
90
 
91
91
  kernel void kernel_gelu(
92
- device const float * src0,
93
- device float * dst,
92
+ device const float4 * src0,
93
+ device float4 * dst,
94
94
  uint tpig[[thread_position_in_grid]]) {
95
- float x = src0[tpig];
95
+ device const float4 & x = src0[tpig];
96
96
 
97
97
  // BEWARE !!!
98
98
  // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
@@ -107,7 +107,6 @@ kernel void kernel_soft_max(
107
107
  constant int64_t & ne00,
108
108
  constant int64_t & ne01,
109
109
  constant int64_t & ne02,
110
- threadgroup float * buf [[threadgroup(0)]],
111
110
  uint3 tgpig[[threadgroup_position_in_grid]],
112
111
  uint3 tpitg[[thread_position_in_threadgroup]],
113
112
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -119,55 +118,67 @@ kernel void kernel_soft_max(
119
118
  device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
120
119
 
121
120
  // parallel max
122
- buf[tpitg[0]] = -INFINITY;
123
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
124
- buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
121
+ float lmax = psrc0[tpitg[0]];
122
+ for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
123
+ lmax = MAX(lmax, psrc0[i00]);
125
124
  }
125
+ const float max = simd_max(lmax);
126
126
 
127
- // reduce
128
- threadgroup_barrier(mem_flags::mem_threadgroup);
129
- for (uint i = ntg[0]/2; i > 0; i /= 2) {
130
- if (tpitg[0] < i) {
131
- buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
132
- }
133
- threadgroup_barrier(mem_flags::mem_threadgroup);
127
+ // parallel sum
128
+ float lsum = 0.0f;
129
+ for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
130
+ const float exp_psrc0 = exp(psrc0[i00] - max);
131
+ lsum += exp_psrc0;
132
+ // Remember the result of exp here. exp is expensive, so we really do not
133
+ // whish to compute it twice.
134
+ pdst[i00] = exp_psrc0;
134
135
  }
135
136
 
136
- // broadcast
137
- if (tpitg[0] == 0) {
138
- buf[0] = buf[0];
137
+ const float sum = simd_sum(lsum);
138
+
139
+ for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
140
+ pdst[i00] /= sum;
139
141
  }
142
+ }
140
143
 
141
- threadgroup_barrier(mem_flags::mem_threadgroup);
144
+ kernel void kernel_soft_max_4(
145
+ device const float * src0,
146
+ device float * dst,
147
+ constant int64_t & ne00,
148
+ constant int64_t & ne01,
149
+ constant int64_t & ne02,
150
+ uint3 tgpig[[threadgroup_position_in_grid]],
151
+ uint3 tpitg[[thread_position_in_threadgroup]],
152
+ uint3 ntg[[threads_per_threadgroup]]) {
153
+ const int64_t i03 = tgpig[2];
154
+ const int64_t i02 = tgpig[1];
155
+ const int64_t i01 = tgpig[0];
142
156
 
143
- const float max = buf[0];
157
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
158
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
144
159
 
145
- // parallel sum
146
- buf[tpitg[0]] = 0.0f;
147
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
148
- buf[tpitg[0]] += exp(psrc0[i00] - max);
160
+ // parallel max
161
+ float4 lmax4 = psrc4[tpitg[0]];
162
+ for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
163
+ lmax4 = fmax(lmax4, psrc4[i00]);
149
164
  }
165
+ float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
150
166
 
151
- // reduce
152
- threadgroup_barrier(mem_flags::mem_threadgroup);
153
- for (uint i = ntg[0]/2; i > 0; i /= 2) {
154
- if (tpitg[0] < i) {
155
- buf[tpitg[0]] += buf[tpitg[0] + i];
156
- }
157
- threadgroup_barrier(mem_flags::mem_threadgroup);
158
- }
167
+ const float max = simd_max(lmax);
159
168
 
160
- // broadcast
161
- if (tpitg[0] == 0) {
162
- buf[0] = buf[0];
169
+ // parallel sum
170
+ float4 lsum4 = 0.0f;
171
+ for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
172
+ const float4 exp_psrc4 = exp(psrc4[i00] - max);
173
+ lsum4 += exp_psrc4;
174
+ pdst4[i00] = exp_psrc4;
163
175
  }
176
+ float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
164
177
 
165
- threadgroup_barrier(mem_flags::mem_threadgroup);
166
-
167
- const float sum = buf[0];
178
+ const float sum = simd_sum(lsum);
168
179
 
169
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
170
- pdst[i00] = exp(psrc0[i00] - max) / sum;
180
+ for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
181
+ pdst4[i00] /= sum;
171
182
  }
172
183
  }
173
184
 
@@ -186,6 +197,33 @@ kernel void kernel_diag_mask_inf(
186
197
  dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
187
198
  } else {
188
199
  dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
200
+ }
201
+ }
202
+
203
+ kernel void kernel_diag_mask_inf_8(
204
+ device const float4 * src0,
205
+ device float4 * dst,
206
+ constant int64_t & ne00,
207
+ constant int64_t & ne01,
208
+ constant int & n_past,
209
+ uint3 tpig[[thread_position_in_grid]]) {
210
+
211
+ const int64_t i = 2*tpig[0];
212
+
213
+ dst[i+0] = src0[i+0];
214
+ dst[i+1] = src0[i+1];
215
+ int64_t i4 = 4*i;
216
+ const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
217
+ const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
218
+ const int64_t i00 = i4;
219
+ for (int k = 3; k >= 0; --k) {
220
+ if (i00 + 4 + k <= n_past + i01) {
221
+ break;
222
+ }
223
+ dst[i+1][k] = -INFINITY;
224
+ if (i00 + k > n_past + i01) {
225
+ dst[i][k] = -INFINITY;
226
+ }
189
227
  }
190
228
  }
191
229
 
@@ -214,25 +252,17 @@ kernel void kernel_norm(
214
252
  }
215
253
  threadgroup_barrier(mem_flags::mem_threadgroup);
216
254
  }
217
- // broadcast
218
- if (tpitg == 0) {
219
- sum[0] /= ne00;
220
- }
221
- threadgroup_barrier(mem_flags::mem_threadgroup);
222
- const float mean = sum[0];
255
+ const float mean = sum[0] / ne00;
223
256
 
224
- // recenter
257
+ // recenter and VARIANCE
258
+ threadgroup_barrier(mem_flags::mem_threadgroup);
225
259
  device float * y = dst + tgpig*ne00;
226
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
227
- y[i00] = x[i00] - mean;
228
- }
229
-
230
- // VARIANCE
231
- // parallel sum
232
260
  sum[tpitg] = 0.0f;
233
261
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
262
+ y[i00] = x[i00] - mean;
234
263
  sum[tpitg] += y[i00] * y[i00];
235
264
  }
265
+
236
266
  // reduce
237
267
  threadgroup_barrier(mem_flags::mem_threadgroup);
238
268
  for (uint i = ntg/2; i > 0; i /= 2) {
@@ -241,12 +271,7 @@ kernel void kernel_norm(
241
271
  }
242
272
  threadgroup_barrier(mem_flags::mem_threadgroup);
243
273
  }
244
- // broadcast
245
- if (tpitg == 0) {
246
- sum[0] /= ne00;
247
- }
248
- threadgroup_barrier(mem_flags::mem_threadgroup);
249
- const float variance = sum[0];
274
+ const float variance = sum[0] / ne00;
250
275
 
251
276
  const float scale = 1.0f/sqrt(variance + eps);
252
277
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
@@ -254,7 +279,6 @@ kernel void kernel_norm(
254
279
  }
255
280
  }
256
281
 
257
-
258
282
  kernel void kernel_rms_norm(
259
283
  device const void * src0,
260
284
  device float * dst,
@@ -435,6 +459,8 @@ kernel void kernel_mul_mat_q4_1_f32(
435
459
  mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
436
460
  }
437
461
 
462
+ #define NB_Q8_0 8
463
+
438
464
  kernel void kernel_mul_mat_q8_0_f32(
439
465
  device const void * src0,
440
466
  device const float * src1,
@@ -463,30 +489,30 @@ kernel void kernel_mul_mat_q8_0_f32(
463
489
  device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
464
490
  device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
465
491
 
466
- float yl[16];
492
+ float yl[NB_Q8_0];
467
493
  float sumf[nr]={0.f};
468
494
 
469
- const int ix = tiisg/2;
470
- const int il = tiisg%2;
495
+ const int ix = tiisg/4;
496
+ const int il = tiisg%4;
471
497
 
472
- device const float * yb = y + ix * QK8_0 + 16*il;
498
+ device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
473
499
 
474
- // each thread in a SIMD group deals with half a block.
475
- for (int ib = ix; ib < nb; ib += nw/2) {
476
- for (int i = 0; i < 16; ++i) {
500
+ // each thread in a SIMD group deals with NB_Q8_0 quants at a time
501
+ for (int ib = ix; ib < nb; ib += nw/4) {
502
+ for (int i = 0; i < NB_Q8_0; ++i) {
477
503
  yl[i] = yb[i];
478
504
  }
479
505
 
480
506
  for (int row = 0; row < nr; row++) {
481
- device const int8_t * qs = x[ib+row*nb].qs + 16*il;
507
+ device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
482
508
  float sumq = 0.f;
483
- for (int iq = 0; iq < 16; ++iq) {
509
+ for (int iq = 0; iq < NB_Q8_0; ++iq) {
484
510
  sumq += qs[iq] * yl[iq];
485
511
  }
486
512
  sumf[row] += sumq*x[ib+row*nb].d;
487
513
  }
488
514
 
489
- yb += QK8_0 * 16;
515
+ yb += NB_Q8_0 * nw;
490
516
  }
491
517
 
492
518
  for (int row = 0; row < nr; ++row) {
@@ -497,7 +523,7 @@ kernel void kernel_mul_mat_q8_0_f32(
497
523
  }
498
524
  }
499
525
 
500
- kernel void kernel_mul_mat_f16_f32(
526
+ kernel void kernel_mul_mat_f16_f32_1row(
501
527
  device const char * src0,
502
528
  device const char * src1,
503
529
  device float * dst,
@@ -515,11 +541,8 @@ kernel void kernel_mul_mat_f16_f32(
515
541
  constant uint64_t & nb12,
516
542
  constant int64_t & ne0,
517
543
  constant int64_t & ne1,
518
- threadgroup float * sum [[threadgroup(0)]],
519
544
  uint3 tgpig[[threadgroup_position_in_grid]],
520
- uint3 tpig[[thread_position_in_grid]],
521
- uint3 tpitg[[thread_position_in_threadgroup]],
522
- uint3 tptg[[threads_per_threadgroup]]) {
545
+ uint tiisg[[thread_index_in_simdgroup]]) {
523
546
 
524
547
  const int64_t r0 = tgpig.x;
525
548
  const int64_t r1 = tgpig.y;
@@ -528,42 +551,144 @@ kernel void kernel_mul_mat_f16_f32(
528
551
  device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
529
552
  device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
530
553
 
531
- uint ith = tpitg.x;
532
- uint nth = tptg.x;
554
+ float sumf = 0;
555
+ if (ne00 < 128) {
556
+ for (int i = tiisg; i < ne00; i += 32) {
557
+ sumf += (float) x[i] * (float) y[i];
558
+ }
559
+ float all_sum = simd_sum(sumf);
560
+ if (tiisg == 0) {
561
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
562
+ }
563
+ } else {
564
+ device const half4 * x4 = (device const half4 *) x;
565
+ device const float4 * y4 = (device const float4 *) y;
566
+ for (int i = tiisg; i < ne00/4; i += 32) {
567
+ for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
568
+ }
569
+ float all_sum = simd_sum(sumf);
570
+ if (tiisg == 0) {
571
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
572
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
573
+ }
574
+ }
533
575
 
534
- sum[ith] = 0.0f;
576
+ }
535
577
 
536
- for (int i = ith; i < ne00; i += nth) {
537
- sum[ith] += (float) x[i] * (float) y[i];
538
- }
578
+ #define N_F16_F32 4
539
579
 
540
- // accumulate the sum from all threads in the threadgroup
541
- threadgroup_barrier(mem_flags::mem_threadgroup);
542
- if (ith%4 == 0) {
543
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
580
+ kernel void kernel_mul_mat_f16_f32(
581
+ device const char * src0,
582
+ device const char * src1,
583
+ device float * dst,
584
+ constant int64_t & ne00,
585
+ constant int64_t & ne01,
586
+ constant int64_t & ne02,
587
+ constant uint64_t & nb00,
588
+ constant uint64_t & nb01,
589
+ constant uint64_t & nb02,
590
+ constant int64_t & ne10,
591
+ constant int64_t & ne11,
592
+ constant int64_t & ne12,
593
+ constant uint64_t & nb10,
594
+ constant uint64_t & nb11,
595
+ constant uint64_t & nb12,
596
+ constant int64_t & ne0,
597
+ constant int64_t & ne1,
598
+ uint3 tgpig[[threadgroup_position_in_grid]],
599
+ uint tiisg[[thread_index_in_simdgroup]]) {
600
+
601
+ const int64_t r0 = tgpig.x;
602
+ const int64_t rb = tgpig.y*N_F16_F32;
603
+ const int64_t im = tgpig.z;
604
+
605
+ device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
606
+
607
+ if (ne00 < 128) {
608
+ for (int row = 0; row < N_F16_F32; ++row) {
609
+ int r1 = rb + row;
610
+ if (r1 >= ne11) {
611
+ break;
612
+ }
613
+
614
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
615
+
616
+ float sumf = 0;
617
+ for (int i = tiisg; i < ne00; i += 32) {
618
+ sumf += (float) x[i] * (float) y[i];
619
+ }
620
+
621
+ float all_sum = simd_sum(sumf);
622
+ if (tiisg == 0) {
623
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
624
+ }
625
+ }
626
+ } else {
627
+ device const half4 * x4 = (device const half4 *)x;
628
+ for (int row = 0; row < N_F16_F32; ++row) {
629
+ int r1 = rb + row;
630
+ if (r1 >= ne11) {
631
+ break;
632
+ }
633
+
634
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
635
+ device const float4 * y4 = (device const float4 *) y;
636
+
637
+ float sumf = 0;
638
+ for (int i = tiisg; i < ne00/4; i += 32) {
639
+ for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
640
+ }
641
+
642
+ float all_sum = simd_sum(sumf);
643
+ if (tiisg == 0) {
644
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
645
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
646
+ }
647
+ }
544
648
  }
545
- threadgroup_barrier(mem_flags::mem_threadgroup);
546
- if (ith%16 == 0) {
547
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
649
+ }
650
+
651
+ // Assumes row size (ne00) is a multiple of 4
652
+ kernel void kernel_mul_mat_f16_f32_l4(
653
+ device const char * src0,
654
+ device const char * src1,
655
+ device float * dst,
656
+ constant int64_t & ne00,
657
+ constant int64_t & ne01,
658
+ constant int64_t & ne02,
659
+ constant uint64_t & nb00,
660
+ constant uint64_t & nb01,
661
+ constant uint64_t & nb02,
662
+ constant int64_t & ne10,
663
+ constant int64_t & ne11,
664
+ constant int64_t & ne12,
665
+ constant uint64_t & nb10,
666
+ constant uint64_t & nb11,
667
+ constant uint64_t & nb12,
668
+ constant int64_t & ne0,
669
+ constant int64_t & ne1,
670
+ uint3 tgpig[[threadgroup_position_in_grid]],
671
+ uint tiisg[[thread_index_in_simdgroup]]) {
672
+
673
+ const int nrows = ne11;
674
+ const int64_t r0 = tgpig.x;
675
+ const int64_t im = tgpig.z;
676
+
677
+ device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
678
+
679
+ for (int r1 = 0; r1 < nrows; ++r1) {
680
+ device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
681
+
682
+ float sumf = 0;
683
+ for (int i = tiisg; i < ne00/4; i += 32) {
684
+ for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
685
+ }
686
+
687
+ float all_sum = simd_sum(sumf);
688
+ if (tiisg == 0) {
689
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
690
+ }
548
691
  }
549
- threadgroup_barrier(mem_flags::mem_threadgroup);
550
- if (ith == 0) {
551
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
552
- dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
553
- }
554
-
555
- // Original implementation. Left behind commented out for now
556
- //threadgroup_barrier(mem_flags::mem_threadgroup);
557
- //for (uint i = tptg.x/2; i > 0; i /= 2) {
558
- // if (tpitg.x < i) {
559
- // sum[tpitg.x] += sum[tpitg.x + i];
560
- // }
561
- // threadgroup_barrier(mem_flags::mem_threadgroup);
562
- //}
563
- //
564
- //if (tpitg.x == 0) {
565
- // dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
566
- //}
567
692
  }
568
693
 
569
694
  kernel void kernel_alibi_f32(
@@ -632,25 +757,27 @@ kernel void kernel_rope(
632
757
  constant int & mode,
633
758
  constant float & freq_base,
634
759
  constant float & freq_scale,
635
- uint3 tpig[[thread_position_in_grid]]) {
636
- const int64_t i3 = tpig[2];
637
- const int64_t i2 = tpig[1];
638
- const int64_t i1 = tpig[0];
760
+ uint tiitg[[thread_index_in_threadgroup]],
761
+ uint3 tptg[[threads_per_threadgroup]],
762
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
763
+ const int64_t i3 = tgpig[2];
764
+ const int64_t i2 = tgpig[1];
765
+ const int64_t i1 = tgpig[0];
639
766
 
640
767
  const bool is_neox = mode & 2;
641
- const float theta_scale = pow(freq_base, -2.0f/n_dims);
642
768
 
643
769
  const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
644
770
 
645
- float theta = freq_scale * (float)p;
771
+ const float theta_0 = freq_scale * (float)p;
772
+ const float inv_ndims = -1.f/n_dims;
646
773
 
647
774
  if (!is_neox) {
648
- for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
775
+ for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
776
+
777
+ const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
649
778
  const float cos_theta = cos(theta);
650
779
  const float sin_theta = sin(theta);
651
780
 
652
- theta *= theta_scale;
653
-
654
781
  device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
655
782
  device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
656
783
 
@@ -662,12 +789,12 @@ kernel void kernel_rope(
662
789
  }
663
790
  } else {
664
791
  for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
665
- for (int64_t ic = 0; ic < n_dims; ic += 2) {
792
+ for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
793
+
794
+ const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
666
795
  const float cos_theta = cos(theta);
667
796
  const float sin_theta = sin(theta);
668
797
 
669
- theta *= theta_scale;
670
-
671
798
  const int64_t i0 = ib*n_dims + ic/2;
672
799
 
673
800
  device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -1071,31 +1198,40 @@ kernel void kernel_mul_mat_q3_K_f32(
1071
1198
  device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
1072
1199
  device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1073
1200
 
1074
- float yl[16];
1201
+ float yl[32];
1075
1202
 
1076
- const uint16_t kmask1 = 0x0303;
1203
+ const uint16_t kmask1 = 0x3030;
1077
1204
  const uint16_t kmask2 = 0x0f0f;
1078
1205
 
1079
- const int tid = tiisg/2;
1080
- const int ix = tiisg%2;
1081
- const int ip = tid/8; // 0 or 1
1082
- const int il = tid/2 - 4*ip; // 0...3
1206
+ const int tid = tiisg/4;
1207
+ const int ix = tiisg%4;
1208
+ const int ip = tid/4; // 0 or 1
1209
+ const int il = 2*((tid%4)/2); // 0 or 2
1083
1210
  const int ir = tid%2;
1084
1211
  const int n = 8;
1085
1212
  const int l0 = n*ir;
1086
1213
 
1087
- const uint16_t m1 = 1 << (4*ip + il);
1088
- const uint16_t m2 = m1 << 8;
1214
+ // One would think that the Metal compiler would figure out that ip and il can only have
1215
+ // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
1216
+ // with these two tables.
1217
+ //
1218
+ // Possible masks for the high bit
1219
+ const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0
1220
+ {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2
1221
+ {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0
1222
+ {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
1223
+
1224
+ // Possible masks for the low 2 bits
1225
+ const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
1226
+
1227
+ const ushort4 hm = mm[2*ip + il/2];
1089
1228
 
1090
1229
  const int shift = 2*il;
1091
- const uint16_t qm1 = 0x0003 << shift;
1092
- const uint16_t qm2 = 0x0300 << shift;
1093
- const int32_t v1 = 4 << shift;
1094
- const int32_t v2 = 1024 << shift;
1230
+ const float v1 = il == 0 ? 4.f : 64.f;
1231
+ const float v2 = 4.f * v1;
1095
1232
 
1096
1233
  const uint16_t s_shift1 = 4*ip;
1097
- const uint16_t s_shift2 = s_shift1 + 2*(il/2);
1098
- const int ik = 4 + (il%2);
1234
+ const uint16_t s_shift2 = s_shift1 + il;
1099
1235
 
1100
1236
  const int q_offset = 32*ip + l0;
1101
1237
  const int y_offset = 128*ip + 32*il + l0;
@@ -1104,12 +1240,19 @@ kernel void kernel_mul_mat_q3_K_f32(
1104
1240
 
1105
1241
  device const float * y1 = yy + ix*QK_K + y_offset;
1106
1242
 
1107
- float sumf1[2] = {0.f}, sumf2[2] = {0.f};
1108
- for (int i = ix; i < nb; i += 2) {
1243
+ uint32_t scales32, aux32;
1244
+ thread uint16_t * scales16 = (thread uint16_t *)&scales32;
1245
+ thread const int8_t * scales = (thread const int8_t *)&scales32;
1246
+
1247
+ float sumf1[2] = {0.f};
1248
+ float sumf2[2] = {0.f};
1249
+ for (int i = ix; i < nb; i += 4) {
1109
1250
 
1110
1251
  for (int l = 0; l < 8; ++l) {
1111
- yl[l+0] = y1[l+ 0];
1112
- yl[l+8] = y1[l+16];
1252
+ yl[l+ 0] = y1[l+ 0];
1253
+ yl[l+ 8] = y1[l+16];
1254
+ yl[l+16] = y1[l+32];
1255
+ yl[l+24] = y1[l+48];
1113
1256
  }
1114
1257
 
1115
1258
  device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
@@ -1120,27 +1263,43 @@ kernel void kernel_mul_mat_q3_K_f32(
1120
1263
  for (int row = 0; row < 2; ++row) {
1121
1264
 
1122
1265
  const float d_all = (float)dh[0];
1123
- const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
1124
1266
 
1125
- float s1 = 0, s2 = 0;
1267
+ scales16[0] = a[4];
1268
+ scales16[1] = a[5];
1269
+ aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
1270
+ scales16[0] = a[il+0];
1271
+ scales16[1] = a[il+1];
1272
+ scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
1273
+
1274
+ float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
1126
1275
  for (int l = 0; l < n; l += 2) {
1127
- const uint16_t qs = q[l/2];
1128
- s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
1129
- s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
1276
+ const int32_t qs = q[l/2];
1277
+ s1 += yl[l+0] * (qs & qm[il/2][0]);
1278
+ s2 += yl[l+1] * (qs & qm[il/2][1]);
1279
+ s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
1280
+ s4 += yl[l+16] * (qs & qm[il/2][2]);
1281
+ s5 += yl[l+17] * (qs & qm[il/2][3]);
1282
+ s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
1130
1283
  }
1131
- float d = d_all * (s1 + 1.f/256.f * s2);
1132
- sumf1[row] += d * scales[0];
1133
- sumf2[row] += d;
1284
+ float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
1285
+ float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
1286
+ sumf1[row] += d1 * (scales[0] - 32);
1287
+ sumf2[row] += d2 * (scales[2] - 32);
1134
1288
 
1135
- s1 = s2 = 0;
1289
+ s1 = s2 = s3 = s4 = s5 = s6 = 0;
1136
1290
  for (int l = 0; l < n; l += 2) {
1137
- const uint16_t qs = q[l/2+8];
1138
- s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
1139
- s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
1291
+ const int32_t qs = q[l/2+8];
1292
+ s1 += yl[l+8] * (qs & qm[il/2][0]);
1293
+ s2 += yl[l+9] * (qs & qm[il/2][1]);
1294
+ s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
1295
+ s4 += yl[l+24] * (qs & qm[il/2][2]);
1296
+ s5 += yl[l+25] * (qs & qm[il/2][3]);
1297
+ s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
1140
1298
  }
1141
- d = d_all * (s1 + 1.f/256.f * s2);
1142
- sumf1[row] += d * scales[1];
1143
- sumf2[row] += d;
1299
+ d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
1300
+ d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
1301
+ sumf1[row] += d1 * (scales[1] - 32);
1302
+ sumf2[row] += d2 * (scales[3] - 32);
1144
1303
 
1145
1304
  q += step;
1146
1305
  h += step;
@@ -1149,17 +1308,20 @@ kernel void kernel_mul_mat_q3_K_f32(
1149
1308
 
1150
1309
  }
1151
1310
 
1152
- y1 += 2 * QK_K;
1311
+ y1 += 4 * QK_K;
1153
1312
 
1154
1313
  }
1155
1314
 
1156
1315
  for (int row = 0; row < 2; ++row) {
1157
- const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
1158
- const float tot = simd_sum(sumf);
1159
- if (tiisg == 0) {
1160
- dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
1316
+ const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
1317
+ sumf1[row] = simd_sum(sumf);
1318
+ }
1319
+ if (tiisg == 0) {
1320
+ for (int row = 0; row < 2; ++row) {
1321
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
1161
1322
  }
1162
1323
  }
1324
+
1163
1325
  }
1164
1326
  #else
1165
1327
  kernel void kernel_mul_mat_q3_K_f32(
@@ -1262,7 +1424,8 @@ kernel void kernel_mul_mat_q4_K_f32(
1262
1424
  const int r0 = tgpig.x;
1263
1425
  const int r1 = tgpig.y;
1264
1426
  const int r2 = tgpig.z;
1265
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1427
+ //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1428
+ const int first_row = r0 * N_DST;
1266
1429
  const int ib_row = first_row * nb;
1267
1430
  const uint offset0 = r2/gqa*(nb*ne0);
1268
1431
  device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
@@ -1511,17 +1674,25 @@ kernel void kernel_mul_mat_q5_K_f32(
1511
1674
  sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
1512
1675
  sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
1513
1676
 
1514
- float4 acc = {0.f, 0.f, 0.f, 0.f};
1677
+ float4 acc1 = {0.f};
1678
+ float4 acc2 = {0.f};
1515
1679
  for (int l = 0; l < n; ++l) {
1516
1680
  uint8_t h = qh[l];
1517
- acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
1518
- acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
1519
- acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
1520
- acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
1681
+ acc1[0] += yl[l+0] * (q1[l] & 0x0F);
1682
+ acc1[1] += yl[l+8] * (q1[l] & 0xF0);
1683
+ acc1[2] += yh[l+0] * (q2[l] & 0x0F);
1684
+ acc1[3] += yh[l+8] * (q2[l] & 0xF0);
1685
+ acc2[0] += h & hm1 ? yl[l+0] : 0.f;
1686
+ acc2[1] += h & hm2 ? yl[l+8] : 0.f;
1687
+ acc2[2] += h & hm3 ? yh[l+0] : 0.f;
1688
+ acc2[3] += h & hm4 ? yh[l+8] : 0.f;
1521
1689
  }
1522
1690
  const float dall = dh[0];
1523
1691
  const float dmin = dh[1];
1524
- sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
1692
+ sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
1693
+ sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
1694
+ sc8[4] * (acc1[2] + 16.f*acc2[2]) +
1695
+ sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
1525
1696
  dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
1526
1697
 
1527
1698
  q1 += step;
@@ -1704,29 +1875,34 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
1704
1875
 
1705
1876
  template <typename type4x4>
1706
1877
  void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
1878
+
1707
1879
  device const uint16_t * qs = ((device const uint16_t *)xb + 1);
1708
- const half d = il ? (xb->d / 16.h) : xb->d;
1709
- const half m = il ? ( -8.h * 16.h) : -8.h;
1880
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
1881
+ const float d2 = d1 / 256.f;
1882
+ const float md = -8.h * xb->d;
1710
1883
  const ushort mask0 = il ? 0x00F0 : 0x000F;
1711
- const ushort mask1 = il ? 0xF000 : 0x0F00;
1884
+ const ushort mask1 = mask0 << 8;
1712
1885
 
1713
1886
  for (int i=0;i<8;i++) {
1714
- reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
1715
- reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
1887
+ reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
1888
+ reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
1716
1889
  }
1890
+
1717
1891
  }
1718
1892
 
1719
1893
  template <typename type4x4>
1720
1894
  void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
1895
+
1721
1896
  device const uint16_t * qs = ((device const uint16_t *)xb + 2);
1722
- const half d = il ? (xb->d / 16.h) : xb->d;
1723
- const half m = xb->m;
1897
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
1898
+ const float d2 = d1 / 256.f;
1899
+ const float m = xb->m;
1724
1900
  const ushort mask0 = il ? 0x00F0 : 0x000F;
1725
- const ushort mask1 = il ? 0xF000 : 0x0F00;
1901
+ const ushort mask1 = mask0 << 8;
1726
1902
 
1727
1903
  for (int i=0;i<8;i++) {
1728
- reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
1729
- reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
1904
+ reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
1905
+ reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
1730
1906
  }
1731
1907
  }
1732
1908
 
@@ -1762,7 +1938,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
1762
1938
 
1763
1939
  template <typename type4x4>
1764
1940
  void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
1765
- const float d_all = (float)(xb->d);
1941
+ const half d_all = xb->d;
1766
1942
  device const uint8_t * q = (device const uint8_t *)xb->qs;
1767
1943
  device const uint8_t * h = (device const uint8_t *)xb->hmask;
1768
1944
  device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1775,17 +1951,20 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
1775
1951
  ((il/4)>0 ? 12 : 3);
1776
1952
  uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
1777
1953
  uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
1778
- int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
1779
- (scale_2&kmask2) | ((scale_1&kmask1) << 4);
1780
- float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
1954
+ int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
1955
+ : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
1956
+ half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
1957
+ const half ml = 4.h * dl;
1781
1958
 
1782
- il = (il/2)%4;
1783
- float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
1784
- uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1959
+ il = (il/2) & 3;
1960
+ const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
1961
+ const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1962
+ dl *= coef;
1785
1963
 
1786
1964
  for (int i = 0; i < 16; ++i) {
1787
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
1965
+ reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
1788
1966
  }
1967
+
1789
1968
  #else
1790
1969
  float kcoef = il&1 ? 1.f/16.f : 1.f;
1791
1970
  uint16_t kmask = il&1 ? 0xF0 : 0x0F;
@@ -1799,31 +1978,37 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
1799
1978
  #endif
1800
1979
  }
1801
1980
 
1981
+ static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
1982
+ return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
1983
+ : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
1984
+ }
1985
+
1802
1986
  template <typename type4x4>
1803
1987
  void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
1804
- device const uint8_t * q = xb->qs;
1988
+ device const uchar * q = xb->qs;
1805
1989
 
1806
1990
  #if QK_K == 256
1807
- const float d = (float)(xb->d);
1808
- const float min = (float)(xb->dmin);
1809
1991
  short is = (il/4) * 2;
1810
1992
  q = q + (il/4) * 32 + 16 * (il&1);
1811
- il = il%4;
1812
- const uchar4 sc = get_scale_min_k4(is, xb->scales);
1813
- const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
1814
- const float ml = il<2 ? min * sc[1] : min * sc[3];
1993
+ il = il & 3;
1994
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
1995
+ const half d = il < 2 ? xb->d : xb->d / 16.h;
1996
+ const half min = xb->dmin;
1997
+ const half dl = d * sc[0];
1998
+ const half ml = min * sc[1];
1815
1999
  #else
1816
2000
  q = q + 16 * (il&1);
1817
2001
  device const uint8_t * s = xb->scales;
1818
2002
  device const half2 * dh = (device const half2 *)xb->d;
1819
2003
  const float2 d = (float2)dh[0];
1820
2004
  const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
1821
- const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4);
2005
+ const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
1822
2006
  #endif
1823
2007
  const ushort mask = il<2 ? 0x0F : 0xF0;
1824
2008
  for (int i = 0; i < 16; ++i) {
1825
2009
  reg[i/4][i%4] = dl * (q[i] & mask) - ml;
1826
2010
  }
2011
+
1827
2012
  }
1828
2013
 
1829
2014
  template <typename type4x4>
@@ -1832,19 +2017,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
1832
2017
  device const uint8_t * qh = xb->qh;
1833
2018
 
1834
2019
  #if QK_K == 256
1835
- const float d = (float)(xb->d);
1836
- const float min = (float)(xb->dmin);
1837
2020
  short is = (il/4) * 2;
1838
2021
  q = q + 32 * (il/4) + 16 * (il&1);
1839
2022
  qh = qh + 16 * (il&1);
1840
2023
  uint8_t ul = 1 << (il/2);
1841
- il = il%4;
1842
- const uchar4 sc = get_scale_min_k4(is, xb->scales);
1843
- const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
1844
- const float ml = il<2 ? min * sc[1] : min * sc[3];
2024
+ il = il & 3;
2025
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
2026
+ const half d = il < 2 ? xb->d : xb->d / 16.h;
2027
+ const half min = xb->dmin;
2028
+ const half dl = d * sc[0];
2029
+ const half ml = min * sc[1];
1845
2030
 
1846
- const ushort mask = il<2 ? 0x0F : 0xF0;
1847
- const float qh_val = il<2 ? 16.f : 256.f;
2031
+ const ushort mask = il<2 ? 0x0F : 0xF0;
2032
+ const half qh_val = il<2 ? 16.h : 256.h;
1848
2033
  for (int i = 0; i < 16; ++i) {
1849
2034
  reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
1850
2035
  }
@@ -1863,7 +2048,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
1863
2048
 
1864
2049
  template <typename type4x4>
1865
2050
  void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
1866
- const float d_all = (float)(xb->d);
2051
+ const half d_all = xb->d;
1867
2052
  device const uint8_t * ql = (device const uint8_t *)xb->ql;
1868
2053
  device const uint8_t * qh = (device const uint8_t *)xb->qh;
1869
2054
  device const int8_t * scales = (device const int8_t *)xb->scales;
@@ -1871,19 +2056,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
1871
2056
  #if QK_K == 256
1872
2057
  ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
1873
2058
  qh = qh + 32*(il/8) + 16*(il&1);
1874
- float sc = scales[(il%2) + 2 * ((il/2))];
1875
- il = (il/2)%4;
2059
+ half sc = scales[(il%2) + 2 * ((il/2))];
2060
+ il = (il/2) & 3;
1876
2061
  #else
1877
2062
  ql = ql + 16 * (il&1);
1878
- float sc = scales[il];
2063
+ half sc = scales[il];
1879
2064
  #endif
2065
+ const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
2066
+ const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
2067
+ const half coef = il>1 ? 1.f/16.h : 1.h;
2068
+ const half ml = d_all * sc * 32.h;
2069
+ const half dl = d_all * sc * coef;
1880
2070
  for (int i = 0; i < 16; ++i) {
1881
- uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1882
- uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
1883
- const float coef = il>1 ? 1.f/16.f : 1.f;
1884
- float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
1885
- ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
1886
- reg[i/4][i%4] = d_all * sc * q * coef;
2071
+ const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
2072
+ : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
2073
+ reg[i/4][i%4] = dl * q - ml;
1887
2074
  }
1888
2075
  }
1889
2076