llama_cpp 0.7.1 → 0.9.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -18,6 +18,21 @@ typedef struct {
18
18
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
19
19
  } block_q4_1;
20
20
 
21
// Q5_0: 32 5-bit quants per block, scaled by a single fp16 delta.
// Layout: d (2 bytes) | qh (4 bytes, the 5th bit of each quant) | qs (16 bytes of nibbles).
#define QK5_0 32
typedef struct {
    half d;                // delta
    uint8_t qh[4];         // 5-th bit of quants
    uint8_t qs[QK5_0 / 2]; // nibbles / quants
} block_q5_0;

// Q5_1: like Q5_0 but affine — adds an fp16 minimum, so quants are non-negative.
// Layout: d (2 bytes) | m (2 bytes) | qh (4 bytes) | qs (16 bytes of nibbles).
#define QK5_1 32
typedef struct {
    half d;                // delta
    half m;                // min
    uint8_t qh[4];         // 5-th bit of quants
    uint8_t qs[QK5_1 / 2]; // nibbles / quants
} block_q5_1;
21
36
  #define QK8_0 32
22
37
  typedef struct {
23
38
  half d; // delta
@@ -110,9 +125,17 @@ kernel void kernel_mul_row(
110
125
  }
111
126
 
112
127
// Elementwise scale: dst[i] = src0[i] * scale, one scalar per thread.
// Launch with a 1D grid covering the element count.
kernel void kernel_scale(
        device const float * src0,
        device       float * dst,
        constant     float & scale,
        uint tpig[[thread_position_in_grid]]) {
    const float v = src0[tpig];
    dst[tpig] = v * scale;
}

// Vectorized variant: each thread scales one float4 (element count must be
// a multiple of 4; the host picks this kernel only in that case).
kernel void kernel_scale_4(
        device const float4 * src0,
        device       float4 * dst,
        constant     float  & scale,
        uint tpig[[thread_position_in_grid]]) {
    const float4 v = src0[tpig];
    dst[tpig] = v * scale;
}
@@ -399,8 +422,11 @@ kernel void kernel_rms_norm(
399
422
  // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
400
423
  inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
401
424
  float d = qb_curr->d;
425
+
402
426
  float2 acc = 0.f;
427
+
403
428
  device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
429
+
404
430
  for (int i = 0; i < 8; i+=2) {
405
431
  acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
406
432
  + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -417,8 +443,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
417
443
  inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
418
444
  float d = qb_curr->d;
419
445
  float m = qb_curr->m;
420
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
446
+
421
447
  float2 acc = 0.f;
448
+
449
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
450
+
422
451
  for (int i = 0; i < 8; i+=2) {
423
452
  acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
424
453
  + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -428,6 +457,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
428
457
  return d * (acc[0] + acc[1]) + sumy * m;
429
458
  }
430
459
 
460
// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
// il indicates where the q5 quants begin (0 or QK5_0/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
    float d = qb_curr->d;

    float2 acc = 0.f;

    // quants start after d (1 uint16) + qh (2 uint16s) = word offset 3
    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);

    for (int i = 0; i < 8; i+=2) {
        // low nibbles (first half of the block); OR in the 5th bit pulled from qh
        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
        // high nibbles (second half of the block); their qh bits sit QK5_0/2 further in
        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
    }
    // Q5_0 quants are offset by -16; sumy * -16 applies that shift to all 16 terms at once
    return d * (sumy * -16.f + acc[0] + acc[1]);
}
480
+
481
// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
// il indicates where the q5 quants begin (0 or QK5_1/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
    float d = qb_curr->d;
    float m = qb_curr->m;

    float2 acc = 0.f;

    // quants start after d (1 uint16) + m (1 uint16) + qh (2 uint16s) = word offset 4
    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);

    for (int i = 0; i < 8; i+=2) {
        // low nibbles (first half of the block); OR in the 5th bit pulled from qh
        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
        // high nibbles: use QK5_1 (this type's block size), not QK5_0 — the original
        // referenced QK5_0 here, a copy-paste slip from the q5_0 variant. Values coincide
        // (both are 32) so behavior is unchanged, but the wrong constant is a latent bug.
        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_1/2) << 8 ) & 0x00100))
                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_1/2) << 16) & 0x10000));
    }
    // affine form: d * sum(quants * y) + m * sum(y)
    return d * (acc[0] + acc[1]) + sumy * m;
}
502
+
431
503
  // putting them in the kernel cause a significant performance penalty
432
504
  #define N_DST 4 // each SIMD group works on 4 rows
433
505
  #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
@@ -525,6 +597,43 @@ kernel void kernel_mul_mv_q4_1_f32(
525
597
  mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
526
598
  }
527
599
 
600
// mat-vec product: q5_0 weights (src0) x f32 activations (src1) -> f32 (dst).
// Thin wrapper dispatching to the shared quantized mat-vec template; buffer
// indices must match the host-side encoder (ggml-metal).
kernel void kernel_mul_mv_q5_0_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01[[buffer(4)]],
        constant   int64_t & ne02[[buffer(5)]],
        constant   int64_t & ne10[[buffer(9)]],
        constant   int64_t & ne12[[buffer(11)]],
        constant   int64_t & ne0[[buffer(15)]],
        constant   int64_t & ne1[[buffer(16)]],
        constant   uint    & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}
617
+
618
// mat-vec product: q5_1 weights (src0) x f32 activations (src1) -> f32 (dst).
// Same wrapper shape as kernel_mul_mv_q5_0_f32; only the block type differs.
kernel void kernel_mul_mv_q5_1_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01[[buffer(4)]],
        constant   int64_t & ne02[[buffer(5)]],
        constant   int64_t & ne10[[buffer(9)]],
        constant   int64_t & ne12[[buffer(11)]],
        constant   int64_t & ne0[[buffer(15)]],
        constant   int64_t & ne1[[buffer(16)]],
        constant   uint    & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}
635
+
636
+
528
637
  #define NB_Q8_0 8
529
638
 
530
639
  kernel void kernel_mul_mv_q8_0_f32(
@@ -2149,6 +2258,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
2149
2258
  }
2150
2259
  }
2151
2260
 
2261
// Dequantize half a q5_0 block (16 values) into a 4x4 register tile.
// il selects which half: 0 -> low nibbles, nonzero -> high nibbles.
template <typename type4x4>
void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
    // quants start after d (1 uint16) + qh (2 uint16s) = word offset 3
    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
    const float d = xb->d;
    const float md = -16.h * xb->d;          // Q5_0 zero-point: quants are offset by -16
    const ushort mask = il ? 0x00F0 : 0x000F;

    const uint32_t qh = *((device const uint32_t *)xb->qh);

    const int x_mv = il ? 4 : 0;             // shift to bring the high nibble down

    const int gh_mv = il ? 12 : 0;           // qh bit offset for this half (NOTE(review): 12, not 16 — confirm against host packing)
    const int gh_bk = il ? 0 : 4;            // shift to place the 5th bit at position 4

    for (int i = 0; i < 8; i++) {
        // extract the 5-th bits for x0 and x1
        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;

        // combine the 4-bits from qs with the 5th bit
        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);

        // two values per uint16 word -> interleave into the 4x4 tile
        reg[i/2][2*(i%2)+0] = d * x0 + md;
        reg[i/2][2*(i%2)+1] = d * x1 + md;
    }
}
2288
+
2289
// Dequantize half a q5_1 block (16 values) into a 4x4 register tile.
// Identical bit extraction to dequantize_q5_0, but affine: value = d*q + m.
template <typename type4x4>
void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
    // quants start after d (1 uint16) + m (1 uint16) + qh (2 uint16s) = word offset 4
    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
    const float d = xb->d;
    const float m = xb->m;
    const ushort mask = il ? 0x00F0 : 0x000F;

    const uint32_t qh = *((device const uint32_t *)xb->qh);

    const int x_mv = il ? 4 : 0;             // shift to bring the high nibble down

    const int gh_mv = il ? 12 : 0;           // qh bit offset for this half (NOTE(review): mirrors dequantize_q5_0 — confirm against host packing)
    const int gh_bk = il ? 0 : 4;            // shift to place the 5th bit at position 4

    for (int i = 0; i < 8; i++) {
        // extract the 5-th bits for x0 and x1
        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;

        // combine the 4-bits from qs with the 5th bit
        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);

        // two values per uint16 word -> interleave into the 4x4 tile
        reg[i/2][2*(i%2)+0] = d * x0 + m;
        reg[i/2][2*(i%2)+1] = d * x1 + m;
    }
}
2316
+
2152
2317
  template <typename type4x4>
2153
2318
  void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
2154
2319
  device const int8_t * qs = ((device const int8_t *)xb->qs);
@@ -2490,6 +2655,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
2490
2655
  template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
2491
2656
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
2492
2657
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
2658
// get_rows instantiations for the new 5-bit quant types (nl = 2 blocks per row chunk,
// matching the q4_0/q4_1/q8_0 instantiations around them)
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
2493
2660
  template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
2494
2661
  template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
2495
2662
  template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -2518,6 +2685,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f
2518
2685
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
2519
2686
  template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
2520
2687
  template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
2688
// mat-mat instantiations for the new 5-bit quant types, parallel to the
// get_rows instantiations and the surrounding q4/q8 entries
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
2521
2690
  template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
2522
2691
  template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
2523
2692
  template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;