llama_cpp 0.7.0 → 0.8.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,11 +13,26 @@ typedef struct {

  #define QK4_1 32
  typedef struct {
-     half d; // delta
-     half m; // min
+     half d;                 // delta
+     half m;                 // min
      uint8_t qs[QK4_1 / 2];  // nibbles / quants
  } block_q4_1;

+ #define QK5_0 32
+ typedef struct {
+     half d;                 // delta
+     uint8_t qh[4];          // 5-th bit of quants
+     uint8_t qs[QK5_0 / 2];  // nibbles / quants
+ } block_q5_0;
+
+ #define QK5_1 32
+ typedef struct {
+     half d;                 // delta
+     half m;                 // min
+     uint8_t qh[4];          // 5-th bit of quants
+     uint8_t qs[QK5_1 / 2];  // nibbles / quants
+ } block_q5_1;
+
  #define QK8_0 32
  typedef struct {
      half d; // delta
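
The new q5 blocks extend the 4-bit formats above: the low four bits of each quant stay packed two per byte in qs, and the fifth bit of each of the 32 quants is packed into the four qh bytes. That puts a q5_0 block at 2 + 4 + 16 = 22 bytes per 32 weights (5.5 bits/weight) and q5_1 at 24 bytes (6 bits/weight). For orientation only, a plain-C mirror of the layout; this sketch (including the name block_q5_0_c) is not from the package, and half is stood in by uint16_t since both are 2 bytes:

    #include <assert.h>
    #include <stdint.h>

    #define QK5_0 32

    typedef struct {
        uint16_t d;              // fp16 delta, kept as raw bits here
        uint8_t  qh[4];          // 5-th bit of each of the 32 quants
        uint8_t  qs[QK5_0 / 2];  // low 4 bits, two quants per byte
    } block_q5_0_c;

    int main(void) {
        // 2 + 4 + 16 = 22 bytes for 32 weights, i.e. 5.5 bits per weight
        assert(sizeof(block_q5_0_c) == 22);
        return 0;
    }
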
@@ -132,6 +147,13 @@ kernel void kernel_relu(
      dst[tpig] = max(0.0f, src0[tpig]);
  }

+ kernel void kernel_sqr(
+         device const float * src0,
+         device       float * dst,
+         uint tpig[[thread_position_in_grid]]) {
+     dst[tpig] = src0[tpig] * src0[tpig];
+ }
+
  constant float GELU_COEF_A = 0.044715f;
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

@@ -338,10 +360,11 @@ kernel void kernel_rms_norm(
          uint sgitg[[simdgroup_index_in_threadgroup]],
          uint tiisg[[thread_index_in_simdgroup]],
          uint ntg[[threads_per_threadgroup]]) {
-     device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
-     device const float * x_scalar = (device const float *) x;
-     float4 sumf=0;
-     float all_sum=0;
+     device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+     device const float * x_scalar = (device const float *) x;
+
+     float4 sumf = 0;
+     float all_sum = 0;

      // parallel sum
      for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
@@ -354,6 +377,7 @@ kernel void kernel_rms_norm(
      }

      threadgroup_barrier(mem_flags::mem_threadgroup);
+
      // broadcast, simd group number is ntg / 32
      for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
          if (tpitg < i) {
@@ -361,7 +385,9 @@ kernel void kernel_rms_norm(
          }
      }
      if (tpitg == 0) {
-         for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
+         for (int i = 4 * (ne00 / 4); i < ne00; i++) {
+             sum[0] += x_scalar[i];
+         }
          sum[0] /= ne00;
      }

@@ -376,7 +402,9 @@ kernel void kernel_rms_norm(
          y[i00] = x[i00] * scale;
      }
      if (tpitg == 0) {
-         for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
+         for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
+             y_scalar[i00] = x_scalar[i00] * scale;
+         }
      }
  }

@@ -386,8 +414,11 @@ kernel void kernel_rms_norm(
  // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
  inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
      float d = qb_curr->d;
+
      float2 acc = 0.f;
+
      device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+
      for (int i = 0; i < 8; i+=2) {
          acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
                  + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -404,8 +435,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
  inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
      float d = qb_curr->d;
      float m = qb_curr->m;
-     device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
      float2 acc = 0.f;
+
+     device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
      for (int i = 0; i < 8; i+=2) {
          acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
                  + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -415,9 +449,52 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
      return d * (acc[0] + acc[1]) + sumy * m;
  }

+ // function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
+ // il indicates where the q5 quants begin (0 or QK5_0/4)
+ // we assume that the yl's have been multiplied with the appropriate scale factor
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+ inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+     float d = qb_curr->d;
+
+     float2 acc = 0.f;
+
+     device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
+     const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+     for (int i = 0; i < 8; i+=2) {
+         acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                 + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+         acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                 + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+     }
+     return d * (sumy * -16.f + acc[0] + acc[1]);
+ }
+
+ // function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
+ // il indicates where the q5 quants begin (0 or QK5_1/4)
+ // we assume that the yl's have been multiplied with the appropriate scale factor
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+     float d = qb_curr->d;
+     float m = qb_curr->m;
+
+     float2 acc = 0.f;
+
+     device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
+     const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+     for (int i = 0; i < 8; i+=2) {
+         acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                 + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+         acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                 + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+     }
+     return d * (acc[0] + acc[1]) + sumy * m;
+ }
+
  // putting them in the kernel cause a significant performance penalty
- #define N_DST 4 // each SIMD group works on 4 rows
- #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
+ #define N_DST 4        // each SIMD group works on 4 rows
+ #define N_SIMDGROUP 2  // number of SIMD groups in a thread group
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
  //Note: This is a template, but strictly speaking it only applies to
  // quantizations where the block size is 32. It also does not
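
The new q5 dot products follow the q4 pattern: each uint16_t read from qs carries two packed nibbles, and the matching fifth bits are pulled out of qh and OR-ed into bit positions 4, 8, 12, and 16 so they line up with the pre-scaled yl values. q5_0 stores its quants unsigned in [0, 31] with an implicit -16 offset, which the kernel folds into the single sumy * -16.f term rather than subtracting per weight; q5_1 instead adds sumy * m for its affine (d, m) reconstruction. Below is a scalar C reference of the same q5_0 arithmetic with the offset applied directly; it is an illustrative sketch assuming the block layout above, and q5_0_dot is not a package function:

    #include <stdint.h>
    #include <string.h>

    // Dot product of one q5_0 block (32 weights) with 32 floats.
    // d is the block scale already widened from fp16; qh/qs as in block_q5_0.
    static float q5_0_dot(float d, const uint8_t qh[4], const uint8_t qs[16],
                          const float y[32]) {
        uint32_t h;
        memcpy(&h, qh, 4); // the 32 fifth bits, one per weight
        float sum = 0.0f;
        for (int j = 0; j < 16; ++j) {
            // low nibble holds weight j, high nibble holds weight j + 16
            const int q0 = (qs[j] & 0x0F) | (((h >> (j +  0)) & 1) << 4);
            const int q1 = (qs[j] >>   4) | (((h >> (j + 16)) & 1) << 4);
            sum += y[j]      * (float)(q0 - 16)  // quants carry a -16 offset
                 + y[j + 16] * (float)(q1 - 16);
        }
        return d * sum;
    }
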
@@ -428,18 +505,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
          int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
          uint3 tgpig, uint tiisg, uint sgitg) {
      const int nb = ne00/QK4_0;
+
      const int r0 = tgpig.x;
      const int r1 = tgpig.y;
      const int im = tgpig.z;
+
      const int first_row = (r0 * nsg + sgitg) * nr;
+
      const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+
      device const block_q_type * x = (device const block_q_type *) src0 + offset0;
      device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
-     float yl[16]; // src1 vector cache
-     float sumf[nr]={0.f};

-     const int ix = tiisg/2;
-     const int il = 8*(tiisg%2);
+     float yl[16];       // src1 vector cache
+     float sumf[nr] = {0.f};
+
+     const int ix = (tiisg/2);
+     const int il = (tiisg%2)*8;

      device const float * yb = y + ix * QK4_0 + il;

@@ -450,6 +532,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
          sumy += yb[i] + yb[i+1];
          yl[i+0] = yb[i+ 0];
          yl[i+1] = yb[i+ 1]/256.f;
+
          sumy += yb[i+16] + yb[i+17];
          yl[i+8] = yb[i+16]/16.f;
          yl[i+9] = yb[i+17]/4096.f;
@@ -465,12 +548,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
      for (int row = 0; row < nr; ++row) {
          const float tot = simd_sum(sumf[row]);
          if (tiisg == 0 && first_row + row < ne01) {
-             dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+             dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
          }
      }
  }

- kernel void kernel_mul_mat_q4_0_f32(
+ kernel void kernel_mul_mv_q4_0_f32(
          device const void * src0,
          device const float * src1,
          device float * dst,
@@ -483,12 +566,12 @@ kernel void kernel_mul_mat_q4_0_f32(
          constant int64_t & ne1[[buffer(16)]],
          constant uint & gqa[[buffer(17)]],
          uint3 tgpig[[threadgroup_position_in_grid]],
-         uint tiisg[[thread_index_in_simdgroup]],
-         uint sgitg[[simdgroup_index_in_threadgroup]]) {
+         uint tiisg[[thread_index_in_simdgroup]],
+         uint sgitg[[simdgroup_index_in_threadgroup]]) {
      mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
  }

- kernel void kernel_mul_mat_q4_1_f32(
+ kernel void kernel_mul_mv_q4_1_f32(
          device const void * src0,
          device const float * src1,
          device float * dst,
@@ -506,9 +589,46 @@ kernel void kernel_mul_mat_q4_1_f32(
      mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
  }

+ kernel void kernel_mul_mv_q5_0_f32(
+         device const void * src0,
+         device const float * src1,
+         device float * dst,
+         constant int64_t & ne00,
+         constant int64_t & ne01[[buffer(4)]],
+         constant int64_t & ne02[[buffer(5)]],
+         constant int64_t & ne10[[buffer(9)]],
+         constant int64_t & ne12[[buffer(11)]],
+         constant int64_t & ne0[[buffer(15)]],
+         constant int64_t & ne1[[buffer(16)]],
+         constant uint & gqa[[buffer(17)]],
+         uint3 tgpig[[threadgroup_position_in_grid]],
+         uint tiisg[[thread_index_in_simdgroup]],
+         uint sgitg[[simdgroup_index_in_threadgroup]]) {
+     mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+ }
+
+ kernel void kernel_mul_mv_q5_1_f32(
+         device const void * src0,
+         device const float * src1,
+         device float * dst,
+         constant int64_t & ne00,
+         constant int64_t & ne01[[buffer(4)]],
+         constant int64_t & ne02[[buffer(5)]],
+         constant int64_t & ne10[[buffer(9)]],
+         constant int64_t & ne12[[buffer(11)]],
+         constant int64_t & ne0[[buffer(15)]],
+         constant int64_t & ne1[[buffer(16)]],
+         constant uint & gqa[[buffer(17)]],
+         uint3 tgpig[[threadgroup_position_in_grid]],
+         uint tiisg[[thread_index_in_simdgroup]],
+         uint sgitg[[simdgroup_index_in_threadgroup]]) {
+     mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+ }
+
+
  #define NB_Q8_0 8

- kernel void kernel_mul_mat_q8_0_f32(
+ kernel void kernel_mul_mv_q8_0_f32(
          device const void * src0,
          device const float * src1,
          device float * dst,
@@ -572,7 +692,7 @@ kernel void kernel_mul_mat_q8_0_f32(

  #define N_F32_F32 4

- kernel void kernel_mul_mat_f32_f32(
+ kernel void kernel_mul_mv_f32_f32(
          device const char * src0,
          device const char * src1,
          device float * dst,
@@ -643,7 +763,7 @@ kernel void kernel_mul_mat_f32_f32(
      }
  }

- kernel void kernel_mul_mat_f16_f32_1row(
+ kernel void kernel_mul_mv_f16_f32_1row(
          device const char * src0,
          device const char * src1,
          device float * dst,
@@ -662,7 +782,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
          constant int64_t & ne0,
          constant int64_t & ne1,
          uint3 tgpig[[threadgroup_position_in_grid]],
-         uint tiisg[[thread_index_in_simdgroup]]) {
+         uint tiisg[[thread_index_in_simdgroup]]) {

      const int64_t r0 = tgpig.x;
      const int64_t r1 = tgpig.y;
@@ -697,7 +817,7 @@ kernel void kernel_mul_mat_f16_f32_1row(

  #define N_F16_F32 4

- kernel void kernel_mul_mat_f16_f32(
+ kernel void kernel_mul_mv_f16_f32(
          device const char * src0,
          device const char * src1,
          device float * dst,
@@ -769,7 +889,7 @@ kernel void kernel_mul_mat_f16_f32(
  }

  // Assumes row size (ne00) is a multiple of 4
- kernel void kernel_mul_mat_f16_f32_l4(
+ kernel void kernel_mul_mv_f16_f32_l4(
          device const char * src0,
          device const char * src1,
          device float * dst,
@@ -1098,6 +1218,62 @@ kernel void kernel_cpy_f32_f32(
      }
  }

+ kernel void kernel_concat(
+         device const char * src0,
+         device const char * src1,
+         device       char * dst,
+         constant int64_t & ne00,
+         constant int64_t & ne01,
+         constant int64_t & ne02,
+         constant int64_t & ne03,
+         constant uint64_t & nb00,
+         constant uint64_t & nb01,
+         constant uint64_t & nb02,
+         constant uint64_t & nb03,
+         constant int64_t & ne10,
+         constant int64_t & ne11,
+         constant int64_t & ne12,
+         constant int64_t & ne13,
+         constant uint64_t & nb10,
+         constant uint64_t & nb11,
+         constant uint64_t & nb12,
+         constant uint64_t & nb13,
+         constant int64_t & ne0,
+         constant int64_t & ne1,
+         constant int64_t & ne2,
+         constant int64_t & ne3,
+         constant uint64_t & nb0,
+         constant uint64_t & nb1,
+         constant uint64_t & nb2,
+         constant uint64_t & nb3,
+         uint3 tgpig[[threadgroup_position_in_grid]],
+         uint3 tpitg[[thread_position_in_threadgroup]],
+         uint3 ntg[[threads_per_threadgroup]]) {
+
+     const int64_t i03 = tgpig.z;
+     const int64_t i02 = tgpig.y;
+     const int64_t i01 = tgpig.x;
+
+     const int64_t i13 = i03 % ne13;
+     const int64_t i12 = i02 % ne12;
+     const int64_t i11 = i01 % ne11;
+
+     device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
+     device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
+     device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
+
+     for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+         if (i02 < ne02) {
+             ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
+             src0_ptr += ntg.x*nb00;
+         } else {
+             ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
+             src1_ptr += ntg.x*nb10;
+         }
+         dst_ptr += ntg.x*nb0;
+     }
+ }
+
  //============================================ k-quants ======================================================

  #ifndef QK_K
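
kernel_concat appends src1 onto src0 along dimension 2: each threadgroup handles one output row (i01, i02, i03), threads stride across ne0, and rows with i02 >= ne02 read from src1, whose indices wrap via the modulo lines (equivalent to a plain offset when the matching extents are equal). The same logic for contiguous float tensors in plain C follows; this is an illustrative sketch with a hypothetical concat_dim2 helper, not package code, and it indexes src1 with i2 - ne02 instead of the kernel's modulo wrap:

    #include <stdint.h>

    // Concatenate b onto a along dimension 2. a is ne00 x ne01 x ne02 x ne03,
    // b is ne00 x ne01 x ne12 x ne03, dst is ne00 x ne01 x (ne02+ne12) x ne03;
    // all three are contiguous float tensors, fastest-varying index i0.
    static void concat_dim2(const float *a, const float *b, float *dst,
                            int64_t ne00, int64_t ne01, int64_t ne02,
                            int64_t ne03, int64_t ne12) {
        const int64_t ne2 = ne02 + ne12; // concatenated extent
        for (int64_t i3 = 0; i3 < ne03; ++i3)
        for (int64_t i2 = 0; i2 < ne2;  ++i2)
        for (int64_t i1 = 0; i1 < ne01; ++i1)
        for (int64_t i0 = 0; i0 < ne00; ++i0) {
            const float v = (i2 < ne02)
                ? a[((i3*ne02 +  i2        )*ne01 + i1)*ne00 + i0]
                : b[((i3*ne12 + (i2 - ne02))*ne01 + i1)*ne00 + i0];
            dst[((i3*ne2 + i2)*ne01 + i1)*ne00 + i0] = v;
        }
    }
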
@@ -1190,7 +1366,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {

  //====================================== dot products =========================

- kernel void kernel_mul_mat_q2_K_f32(
+ kernel void kernel_mul_mv_q2_K_f32(
          device const void * src0,
          device const float * src1,
          device float * dst,
@@ -1334,7 +1510,7 @@ kernel void kernel_mul_mat_q2_K_f32(
  }

  #if QK_K == 256
- kernel void kernel_mul_mat_q3_K_f32(
+ kernel void kernel_mul_mv_q3_K_f32(
          device const void * src0,
          device const float * src1,
          device float * dst,
@@ -1486,7 +1662,7 @@ kernel void kernel_mul_mat_q3_K_f32(
      }
  }
  #else
- kernel void kernel_mul_mat_q3_K_f32(
+ kernel void kernel_mul_mv_q3_K_f32(
          device const void * src0,
          device const float * src1,
          device float * dst,
@@ -1557,7 +1733,7 @@ kernel void kernel_mul_mat_q3_K_f32(
  #endif

  #if QK_K == 256
- kernel void kernel_mul_mat_q4_K_f32(
+ kernel void kernel_mul_mv_q4_K_f32(
          device const void * src0,
          device const float * src1,
          device float * dst,
@@ -1663,7 +1839,7 @@ kernel void kernel_mul_mat_q4_K_f32(
      }
  }
  #else
- kernel void kernel_mul_mat_q4_K_f32(
+ kernel void kernel_mul_mv_q4_K_f32(
          device const void * src0,
          device const float * src1,
          device float * dst,
@@ -1752,7 +1928,7 @@ kernel void kernel_mul_mat_q4_K_f32(
  }
  #endif

- kernel void kernel_mul_mat_q5_K_f32(
+ kernel void kernel_mul_mv_q5_K_f32(
          device const void * src0,
          device const float * src1,
          device float * dst,
@@ -1925,7 +2101,7 @@ kernel void kernel_mul_mat_q5_K_f32(

  }

- kernel void kernel_mul_mat_q6_K_f32(
+ kernel void kernel_mul_mv_q6_K_f32(
          device const void * src0,
          device const float * src1,
          device float * dst,
@@ -2074,6 +2250,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
      }
  }

+ template <typename type4x4>
+ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+     device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+     const float d = xb->d;
+     const float md = -16.h * xb->d;
+     const ushort mask = il ? 0x00F0 : 0x000F;
+
+     const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+     const int x_mv = il ? 4 : 0;
+
+     const int gh_mv = il ? 12 : 0;
+     const int gh_bk = il ?  0 : 4;
+
+     for (int i = 0; i < 8; i++) {
+         // extract the 5-th bits for x0 and x1
+         const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+         const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+         // combine the 4-bits from qs with the 5th bit
+         const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+         const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+         reg[i/2][2*(i%2)+0] = d * x0 + md;
+         reg[i/2][2*(i%2)+1] = d * x1 + md;
+     }
+ }
+
+ template <typename type4x4>
+ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+     device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+     const float d = xb->d;
+     const float m = xb->m;
+     const ushort mask = il ? 0x00F0 : 0x000F;
+
+     const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+     const int x_mv = il ? 4 : 0;
+
+     const int gh_mv = il ? 12 : 0;
+     const int gh_bk = il ?  0 : 4;
+
+     for (int i = 0; i < 8; i++) {
+         // extract the 5-th bits for x0 and x1
+         const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+         const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+         // combine the 4-bits from qs with the 5th bit
+         const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+         const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+         reg[i/2][2*(i%2)+0] = d * x0 + m;
+         reg[i/2][2*(i%2)+1] = d * x1 + m;
+     }
+ }
+
  template <typename type4x4>
  void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
      device const int8_t * qs = ((device const int8_t *)xb->qs);
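
Both new dequantizers expand half a block (16 weights) per call: il selects the low or high nibbles via mask and x_mv, while gh_mv/gh_bk shift the corresponding fifth bits out of qh into bit position 4. The only difference between the two is the final reconstruction, w = d*q - 16*d for q5_0 versus w = d*q + m for q5_1. A full-block scalar C equivalent of the q5_1 case is sketched below; dequant_q5_1 is not a package function, and d/m are assumed already widened from fp16:

    #include <stdint.h>
    #include <string.h>

    // Reconstruct all 32 weights of one q5_1 block: w = d * q + m, q in [0, 31].
    static void dequant_q5_1(float d, float m, const uint8_t qh[4],
                             const uint8_t qs[16], float out[32]) {
        uint32_t h;
        memcpy(&h, qh, 4); // the 32 fifth bits, one per weight
        for (int j = 0; j < 16; ++j) {
            const int q0 = (qs[j] & 0x0F) | (((h >> (j +  0)) & 1) << 4);
            const int q1 = (qs[j] >>   4) | (((h >> (j + 16)) & 1) << 4);
            out[j]      = d * (float)q0 + m;
            out[j + 16] = d * (float)q1 + m;
        }
    }
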
@@ -2263,7 +2495,7 @@ kernel void kernel_get_rows(
  }

  #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
- #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
+ #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
  #define BLOCK_SIZE_K 32
  #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
  #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
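
These defines describe how kernel_mul_mm tiles its output: each threadgroup produces one 64x32 block, and the store code further down offsets each simdgroup by 32 * (sgitg & 1) rows and 16 * (sgitg >> 1) columns, giving each of the four simdgroups a 32x16 quadrant. A quick C illustration of that split (our sketch, not package code):

    #include <stdio.h>

    // Print the 32x16 quadrant of the 64x32 output block owned by each
    // simdgroup, mirroring the "32 * (sgitg & 1)" / "16 * (sgitg >> 1)"
    // offsets used when kernel_mul_mm stores its results.
    int main(void) {
        for (int sgitg = 0; sgitg < 4; ++sgitg) {
            const int row0 = 32 * (sgitg & 1);  // within BLOCK_SIZE_M = 64
            const int col0 = 16 * (sgitg >> 1); // within BLOCK_SIZE_N = 32
            printf("sgitg %d -> rows [%d, %d) x cols [%d, %d)\n",
                   sgitg, row0, row0 + 32, col0, col0 + 16);
        }
        return 0;
    }
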
@@ -2300,9 +2532,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
      const uint r0 = tgpig.y;
      const uint r1 = tgpig.x;
      const uint im = tgpig.z;
+
      // if this block is of 64x32 shape or smaller
      short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
      short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
      // a thread shouldn't load data outside of the matrix
      short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
      short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
@@ -2326,26 +2560,30 @@ kernel void kernel_mul_mm(device const uchar * src0,
              + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));

      for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
-         //load data and store to threadgroup memory
+         // load data and store to threadgroup memory
          half4x4 temp_a;
          dequantize_func(x, il, temp_a);
          threadgroup_barrier(mem_flags::mem_threadgroup);
+
          #pragma unroll(16)
          for (int i = 0; i < 16; i++) {
              *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
-                 + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
-                 + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+                 + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+                 + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
          }
-         *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
-             = *((device float2x4 *)y);
+
+         *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
          il = (il + 2 < nl) ? il + 2 : il % 2;
          x = (il < 2) ? x + (2+nl-1)/nl : x;
          y += BLOCK_SIZE_K;

          threadgroup_barrier(mem_flags::mem_threadgroup);
-         //load matrices from threadgroup memory and conduct outer products
+
+         // load matrices from threadgroup memory and conduct outer products
          threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
          threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
          #pragma unroll(4)
          for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
              #pragma unroll(4)
@@ -2360,6 +2598,7 @@ kernel void kernel_mul_mm(device const uchar * src0,

              lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
              lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
              #pragma unroll(8)
              for (int i = 0; i < 8; i++){
                  simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
@@ -2368,25 +2607,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
      }

      if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
-         device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
-             + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
+         device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
+             + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
          for (int i = 0; i < 8; i++) {
              simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
          }
      } else {
          // block is smaller than 64x32, we should avoid writing data outside of the matrix
          threadgroup_barrier(mem_flags::mem_threadgroup);
-         threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
+         threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
              + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
          for (int i = 0; i < 8; i++) {
              simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
          }

          threadgroup_barrier(mem_flags::mem_threadgroup);
-         device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
-         if (sgitg==0) {
+
+         device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+         if (sgitg == 0) {
              for (int i = 0; i < n_rows; i++) {
-                 for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
+                 for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                      *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                  }
              }
@@ -2407,6 +2647,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
  template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+ template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
+ template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
  template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
  template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
  template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -2435,6 +2677,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
  template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
  template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
+ template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
  template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
  template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
  template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;