llama_cpp 0.7.0 → 0.8.0

Diff of the gem's bundled Metal shader source (ggml-metal.metal). Version 0.8.0 adds Q5_0 and Q5_1 quantization support (block types, dot-product helpers, dequantizers, and kernels), new kernel_sqr and kernel_concat ops, and renames the kernel_mul_mat_* matrix-vector kernels to kernel_mul_mv_*.

@@ -13,11 +13,26 @@ typedef struct {
 
 #define QK4_1 32
 typedef struct {
-    half d;          // delta
-    half m;          // min
+    half d;                // delta
+    half m;                // min
     uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;
 
+#define QK5_0 32
+typedef struct {
+    half d;                // delta
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+
+#define QK5_1 32
+typedef struct {
+    half d;                // delta
+    half m;                // min
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+
 #define QK8_0 32
 typedef struct {
     half d; // delta
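
The two structs added here define the 5-bit block formats: the low four bits of every quant go in `qs` (two per byte) and the 32 fifth bits are packed into the four bytes of `qh`. q5_0 reconstructs a weight as d * (q - 16); q5_1 as d * q + m (see the dot-product helpers and dequantizers further down). A minimal CPU-side sketch of the layout, with fp16 modeled as raw `uint16_t` so the byte sizes line up (names are illustrative, not part of the gem):

```cpp
#include <cstdint>

// CPU mirrors of the Metal block_q5_0 / block_q5_1 above.
struct cpu_block_q5_0 {
    uint16_t d;      // delta (fp16 bits)
    uint8_t  qh[4];  // 32 fifth bits, one per quant
    uint8_t  qs[16]; // 32 low nibbles, two per byte
};                   // 2 + 4 + 16 = 22 bytes per 32 weights = 5.5 bits/weight

struct cpu_block_q5_1 {
    uint16_t d;      // delta (fp16 bits)
    uint16_t m;      // min   (fp16 bits)
    uint8_t  qh[4];  // 32 fifth bits
    uint8_t  qs[16]; // 32 low nibbles
};                   // 2 + 2 + 4 + 16 = 24 bytes per 32 weights = 6 bits/weight

static_assert(sizeof(cpu_block_q5_0) == 22, "no padding expected");
static_assert(sizeof(cpu_block_q5_1) == 24, "no padding expected");
```
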
@@ -132,6 +147,13 @@ kernel void kernel_relu(
     dst[tpig] = max(0.0f, src0[tpig]);
 }
 
+kernel void kernel_sqr(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src0[tpig];
+}
+
 constant float GELU_COEF_A = 0.044715f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
@@ -338,10 +360,11 @@ kernel void kernel_rms_norm(
         uint sgitg[[simdgroup_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint ntg[[threads_per_threadgroup]]) {
-    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
-    device const float * x_scalar = (device const float *) x;
-    float4 sumf=0;
-    float all_sum=0;
+    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+    device const float  * x_scalar = (device const float *) x;
+
+    float4 sumf = 0;
+    float all_sum = 0;
 
     // parallel sum
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
@@ -354,6 +377,7 @@ kernel void kernel_rms_norm(
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
+
     // broadcast, simd group number is ntg / 32
     for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
         if (tpitg < i) {
@@ -361,7 +385,9 @@ kernel void kernel_rms_norm(
         }
     }
     if (tpitg == 0) {
-        for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
+        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
+            sum[0] += x_scalar[i];
+        }
         sum[0] /= ne00;
     }
 
@@ -376,7 +402,9 @@ kernel void kernel_rms_norm(
         y[i00] = x[i00] * scale;
     }
     if (tpitg == 0) {
-        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
+        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
+            y_scalar[i00] = x_scalar[i00] * scale;
+        }
     }
 }
 
@@ -386,8 +414,11 @@ kernel void kernel_rms_norm(
 // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
 inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
+
     float2 acc = 0.f;
+
     device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+
     for (int i = 0; i < 8; i+=2) {
         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
                 + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -404,8 +435,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
 inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
     float m = qb_curr->m;
-    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
     float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
     for (int i = 0; i < 8; i+=2) {
         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
                 + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -415,9 +449,52 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
     return d * (acc[0] + acc[1]) + sumy * m;
 }
 
+// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
+    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+    return d * (sumy * -16.f + acc[0] + acc[1]);
+}
+
+// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_1/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
+    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+    return d * (acc[0] + acc[1]) + sumy * m;
+}
+
 // putting them in the kernel cause a significant performance penalty
-#define N_DST 4 // each SIMD group works on 4 rows
-#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
+#define N_DST 4        // each SIMD group works on 4 rows
+#define N_SIMDGROUP 2  // number of SIMD groups in a thread group
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 //Note: This is a template, but strictly speaking it only applies to
 //      quantizations where the block size is 32. It also does not
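
The scale factors named in the comments (1, 1/16, 1/256, 1/4096) exist because these helpers never shift the nibbles they mask out of each `uint16_t`: a nibble left at bit position 4, 8, or 12 is worth 16, 256, or 4096 times its value, so the caller pre-divides the matching `yl` entries instead. The `qh` terms OR the fifth bit into the same scaled positions (0x10, 0x1000, 0x100, 0x10000), and the -16 offset of q5_0 is folded into a single `sumy * -16.f` term. A plain scalar reference for what one q5_0 call accumulates, with illustrative helper names and unscaled y values (a sketch, not the gem's code):

```cpp
#include <cstdint>

// Quant k (0..31) of a q5_0 block: low nibble from qs, fifth bit from qh.
static int unpack_q5_0(const uint8_t qs[16], uint32_t qh, int k) {
    const int nib  = (k < 16) ? (qs[k] & 0x0F) : (qs[k - 16] >> 4);
    const int bit5 = (qh >> k) & 1;
    return nib | (bit5 << 4); // 5-bit quant in [0, 31]
}

// Naive version of block_q_n_dot_y(block_q5_0): dot product of 16 y values
// with the 16 quants of one half block (il is 0 or 8). The shader gets the
// same result without shifting nibbles by pre-dividing yl[i+1], yl[i+8],
// yl[i+9] by 256, 16, 4096 and folding the -16 offset into sumy * -16.f.
static float half_block_dot_q5_0_ref(float d, const uint8_t qs[16], uint32_t qh,
                                     const float y[16], int il) {
    float acc = 0.f;
    for (int i = 0; i < 8; ++i) {
        acc += y[i    ] * (unpack_q5_0(qs, qh, il + i     ) - 16);
        acc += y[i + 8] * (unpack_q5_0(qs, qh, il + i + 16) - 16);
    }
    return d * acc;
}
```
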
@@ -428,18 +505,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
                     int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
                     uint3 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
+
     const int first_row = (r0 * nsg + sgitg) * nr;
+
     const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+
     device const block_q_type * x = (device const block_q_type *) src0 + offset0;
     device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
-    float yl[16]; // src1 vector cache
-    float sumf[nr]={0.f};
 
-    const int ix = tiisg/2;
-    const int il = 8*(tiisg%2);
+    float yl[16];       // src1 vector cache
+    float sumf[nr] = {0.f};
+
+    const int ix = (tiisg/2);
+    const int il = (tiisg%2)*8;
 
     device const float * yb = y + ix * QK4_0 + il;
 
@@ -450,6 +532,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
         sumy += yb[i] + yb[i+1];
         yl[i+0] = yb[i+ 0];
         yl[i+1] = yb[i+ 1]/256.f;
+
         sumy += yb[i+16] + yb[i+17];
         yl[i+8] = yb[i+16]/16.f;
         yl[i+9] = yb[i+17]/4096.f;
@@ -465,12 +548,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
     for (int row = 0; row < nr; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0 && first_row + row < ne01) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+            dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
         }
     }
 }
 
-kernel void kernel_mul_mat_q4_0_f32(
+kernel void kernel_mul_mv_q4_0_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -483,12 +566,12 @@ kernel void kernel_mul_mat_q4_0_f32(
         constant int64_t & ne1[[buffer(16)]],
         constant uint & gqa[[buffer(17)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
-kernel void kernel_mul_mat_q4_1_f32(
+kernel void kernel_mul_mv_q4_1_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -506,9 +589,46 @@ kernel void kernel_mul_mat_q4_1_f32(
     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
+kernel void kernel_mul_mv_q5_0_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01[[buffer(4)]],
+        constant int64_t & ne02[[buffer(5)]],
+        constant int64_t & ne10[[buffer(9)]],
+        constant int64_t & ne12[[buffer(11)]],
+        constant int64_t & ne0[[buffer(15)]],
+        constant int64_t & ne1[[buffer(16)]],
+        constant uint & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q5_1_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01[[buffer(4)]],
+        constant int64_t & ne02[[buffer(5)]],
+        constant int64_t & ne10[[buffer(9)]],
+        constant int64_t & ne12[[buffer(11)]],
+        constant int64_t & ne0[[buffer(15)]],
+        constant int64_t & ne1[[buffer(16)]],
+        constant uint & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
+
+
 #define NB_Q8_0 8
 
-kernel void kernel_mul_mat_q8_0_f32(
+kernel void kernel_mul_mv_q8_0_f32(
        device const void * src0,
        device const float * src1,
        device float * dst,
@@ -572,7 +692,7 @@ kernel void kernel_mul_mat_q8_0_f32(
 
 #define N_F32_F32 4
 
-kernel void kernel_mul_mat_f32_f32(
+kernel void kernel_mul_mv_f32_f32(
         device const char * src0,
         device const char * src1,
         device float * dst,
@@ -643,7 +763,7 @@ kernel void kernel_mul_mat_f32_f32(
     }
 }
 
-kernel void kernel_mul_mat_f16_f32_1row(
+kernel void kernel_mul_mv_f16_f32_1row(
         device const char * src0,
         device const char * src1,
         device float * dst,
@@ -662,7 +782,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
         constant int64_t & ne0,
         constant int64_t & ne1,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]]) {
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
@@ -697,7 +817,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
 
 #define N_F16_F32 4
 
-kernel void kernel_mul_mat_f16_f32(
+kernel void kernel_mul_mv_f16_f32(
         device const char * src0,
         device const char * src1,
         device float * dst,
@@ -769,7 +889,7 @@ kernel void kernel_mul_mat_f16_f32(
 }
 
 // Assumes row size (ne00) is a multiple of 4
-kernel void kernel_mul_mat_f16_f32_l4(
+kernel void kernel_mul_mv_f16_f32_l4(
         device const char * src0,
         device const char * src1,
         device float * dst,
@@ -1098,6 +1218,62 @@ kernel void kernel_cpy_f32_f32(
     }
 }
 
+kernel void kernel_concat(
+    device  const char * src0,
+    device  const char * src1,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne10,
+    constant   int64_t & ne11,
+    constant   int64_t & ne12,
+    constant   int64_t & ne13,
+    constant  uint64_t & nb10,
+    constant  uint64_t & nb11,
+    constant  uint64_t & nb12,
+    constant  uint64_t & nb13,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3 ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        if (i02 < ne02) {
+            ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
+            src0_ptr += ntg.x*nb00;
+        } else {
+            ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
+            src1_ptr += ntg.x*nb10;
+        }
+        dst_ptr += ntg.x*nb0;
+    }
+}
+
 //============================================ k-quants ======================================================
 
 #ifndef QK_K
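
kernel_concat stitches two tensors together along dimension 2: threadgroups whose i02 falls inside src0's extent copy from src0, the rest copy from src1. A rough CPU model for packed row-major floats may help; the Metal kernel works on byte strides nb*, so it also handles non-contiguous layouts, and all names below are illustrative:

```cpp
#include <cstddef>
#include <vector>

// dst = src0 and src1 stacked along dim 2. For src1 slices the Metal kernel
// computes i12 = i02 % ne12, which equals the i2 - ne02 used here whenever
// ne02 is a multiple of ne12 (e.g. the usual equal-sized case).
static std::vector<float> concat_dim2(const std::vector<float> & src0,
                                      const std::vector<float> & src1,
                                      size_t ne0, size_t ne1,
                                      size_t ne02, size_t ne12, size_t ne3) {
    const size_t ne2 = ne02 + ne12;
    std::vector<float> dst(ne0 * ne1 * ne2 * ne3);
    for (size_t i3 = 0; i3 < ne3; ++i3)
    for (size_t i2 = 0; i2 < ne2; ++i2)
    for (size_t i1 = 0; i1 < ne1; ++i1)
    for (size_t i0 = 0; i0 < ne0; ++i0) {
        const bool   from0 = i2 < ne02;              // which input owns this slice
        const size_t s2    = from0 ? i2 : i2 - ne02; // slice index within that input
        const size_t sne2  = from0 ? ne02 : ne12;
        const std::vector<float> & src = from0 ? src0 : src1;
        dst[((i3*ne2 + i2)*ne1 + i1)*ne0 + i0] =
            src[((i3*sne2 + s2)*ne1 + i1)*ne0 + i0];
    }
    return dst;
}
```
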
@@ -1190,7 +1366,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
 
 //====================================== dot products =========================
 
-kernel void kernel_mul_mat_q2_K_f32(
+kernel void kernel_mul_mv_q2_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -1334,7 +1510,7 @@ kernel void kernel_mul_mat_q2_K_f32(
 }
 
 #if QK_K == 256
-kernel void kernel_mul_mat_q3_K_f32(
+kernel void kernel_mul_mv_q3_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -1486,7 +1662,7 @@ kernel void kernel_mul_mat_q3_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mat_q3_K_f32(
+kernel void kernel_mul_mv_q3_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -1557,7 +1733,7 @@ kernel void kernel_mul_mat_q3_K_f32(
 #endif
 
 #if QK_K == 256
-kernel void kernel_mul_mat_q4_K_f32(
+kernel void kernel_mul_mv_q4_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -1663,7 +1839,7 @@ kernel void kernel_mul_mat_q4_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mat_q4_K_f32(
+kernel void kernel_mul_mv_q4_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -1752,7 +1928,7 @@ kernel void kernel_mul_mat_q4_K_f32(
 }
 #endif
 
-kernel void kernel_mul_mat_q5_K_f32(
+kernel void kernel_mul_mv_q5_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -1925,7 +2101,7 @@ kernel void kernel_mul_mat_q5_K_f32(
 
 }
 
-kernel void kernel_mul_mat_q6_K_f32(
+kernel void kernel_mul_mv_q6_K_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
@@ -2074,6 +2250,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
     }
 }
 
+template <typename type4x4>
+void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+    const float d = xb->d;
+    const float md = -16.h * xb->d;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ? 0 : 4;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[i/2][2*(i%2)+0] = d * x0 + md;
+        reg[i/2][2*(i%2)+1] = d * x1 + md;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+    const float d = xb->d;
+    const float m = xb->m;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ? 0 : 4;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[i/2][2*(i%2)+0] = d * x0 + m;
+        reg[i/2][2*(i%2)+1] = d * x1 + m;
+    }
+}
+
 template <typename type4x4>
 void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
     device const int8_t * qs = ((device const int8_t *)xb->qs);
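
A CPU rendering of dequantize_q5_0 above may make the bit plumbing easier to follow. In this sketch d is widened from half to float so the code stays self-contained, and the whole 32-quant block is expanded at once instead of the shader's il-selected half; the struct reuses the layout from the earlier sketch:

```cpp
#include <cstdint>
#include <cstring>

struct block_q5_0_f { float d; uint8_t qh[4]; uint8_t qs[16]; }; // d widened to float

// Expand one q5_0 block into 32 floats: value = d * (q - 16), q in [0, 31].
// (q5_1 would compute d * q + m instead.)
static void dequantize_block_q5_0(const block_q5_0_f * b, float out[32]) {
    uint32_t qh;
    std::memcpy(&qh, b->qh, sizeof(qh)); // same reinterpretation the shader does

    for (int i = 0; i < 16; ++i) {
        // move the fifth bits of quants i and i+16 to bit position 4 (0x10)
        const uint8_t xh_0 = ((qh >> (i     )) << 4) & 0x10;
        const uint8_t xh_1 = ((qh >> (i + 16)) << 4) & 0x10;

        // low nibble -> quant i, high nibble -> quant i+16
        const int x0 = (b->qs[i] & 0x0F) | xh_0;
        const int x1 = (b->qs[i] >>   4) | xh_1;

        out[i     ] = b->d * (x0 - 16);
        out[i + 16] = b->d * (x1 - 16);
    }
}
```
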
@@ -2263,7 +2495,7 @@ kernel void kernel_get_rows(
 }
 
 #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
-#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
+#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
 #define BLOCK_SIZE_K 32
 #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
 #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
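
The corrected comment agrees with the tile arithmetic: simdgroup matrices are 8x8, so a 64x32 output block takes 64/8 = 8 of them along the rows of matrix A but only 32/8 = 4 along the columns of matrix B, hence "from matrix B".
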
@@ -2300,9 +2532,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
     const uint r0 = tgpig.y;
     const uint r1 = tgpig.x;
     const uint im = tgpig.z;
+
     // if this block is of 64x32 shape or smaller
     short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
     short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
     // a thread shouldn't load data outside of the matrix
     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
     short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
@@ -2326,26 +2560,30 @@ kernel void kernel_mul_mm(device const uchar * src0,
                  + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
-        //load data and store to threadgroup memory
+        // load data and store to threadgroup memory
         half4x4 temp_a;
         dequantize_func(x, il, temp_a);
         threadgroup_barrier(mem_flags::mem_threadgroup);
+
         #pragma unroll(16)
         for (int i = 0; i < 16; i++) {
             *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
-            + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
-            + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+                                  + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+                                  + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
         }
-        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
-        = *((device float2x4 *)y);
+
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
         il = (il + 2 < nl) ? il + 2 : il % 2;
         x = (il < 2) ? x + (2+nl-1)/nl : x;
         y += BLOCK_SIZE_K;
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        //load matrices from threadgroup memory and conduct outer products
+
+        // load matrices from threadgroup memory and conduct outer products
         threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
         threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
         #pragma unroll(4)
         for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
             #pragma unroll(4)
@@ -2360,6 +2598,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
 
             lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
             lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
             #pragma unroll(8)
             for (int i = 0; i < 8; i++){
                 simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
@@ -2368,25 +2607,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
     }
 
     if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
-        device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
-                        + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
+        device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
+                               + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
         for (int i = 0; i < 8; i++) {
             simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
         }
     } else {
         // block is smaller than 64x32, we should avoid writing data outside of the matrix
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
+        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
                                       + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
         for (int i = 0; i < 8; i++) {
             simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
         }
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
-        if (sgitg==0) {
+
+        device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+        if (sgitg == 0) {
             for (int i = 0; i < n_rows; i++) {
-                for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
+                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                     *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                 }
             }
@@ -2407,6 +2647,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
 template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -2435,6 +2677,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f
 template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
 template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
 template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;