llama_cpp 0.3.4 → 0.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -39,6 +39,8 @@
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+ #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
  //
  // 2-6 bit quantization in super-blocks
  //
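The macro added above builds a 256-bit vector from two 128-bit halves using only AVX-level intrinsics, which matters because some toolchains do not provide _mm256_set_m128i in AVX-only builds. A minimal standalone sketch of what the macro produces, assuming an AVX-capable compiler (e.g. gcc -mavx); the test program itself is illustrative and not part of the package:

    #include <immintrin.h>
    #include <stdint.h>
    #include <stdio.h>

    /* Same definition as in the hunk above. */
    #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

    int main(void) {
        const __m128i lo = _mm_set1_epi32(1);        /* ends up in lanes 0..3 */
        const __m128i hi = _mm_set1_epi32(2);        /* ends up in lanes 4..7 */
        const __m256i v  = MM256_SET_M128I(hi, lo);  /* high half first, matching _mm256_set_m128i */

        int32_t out[8];
        _mm256_storeu_si256((__m256i *)out, v);
        for (int k = 0; k < 8; ++k) printf("%d ", out[k]);  /* prints: 1 1 1 1 2 2 2 2 */
        printf("\n");
        return 0;
    }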
@@ -1353,7 +1355,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
  const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
  const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
- const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};
+ const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
 
  __m256i sumi = _mm256_setzero_si256();
 
@@ -1421,7 +1423,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
 
  // sumf += -dmin * summs in 32bits*8
- acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(_mm256_set_m128i(summs_1, summs_0))), acc);
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc);
 
  const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
  const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
@@ -1493,7 +1495,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  }
 
  // sumf += dall * isum - dmin * summs in 32bits
- __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
  acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
  }
 
@@ -1644,8 +1646,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  summs += dmin * smin;
 
  const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
- const __m256i q2_0 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 2), q2bits), m3);
- const __m256i q2_1 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
+ const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3);
+ const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
 
  const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
  const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
@@ -1666,6 +1668,62 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
  *s = hsum_float_8(acc) + summs;
 
+ #elif defined __AVX__
+
+ const __m128i m3 = _mm_set1_epi8(3);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ uint32_t ud, um;
+ const uint8_t * restrict db = (const uint8_t *)&ud;
+ const uint8_t * restrict mb = (const uint8_t *)&um;
+
+ float summs = 0;
+
+ // TODO: optimize this
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+ const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+ const uint8_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+ ud = (sc[0] >> 0) & 0x0f0f0f0f;
+ um = (sc[0] >> 4) & 0x0f0f0f0f;
+
+ int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
+ summs += dmin * smin;
+
+ const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
+ const __m128i q2_0 = _mm_and_si128(q2bits, m3);
+ const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
+ const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
+ const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
+
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+ const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0));
+ const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1));
+ const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0));
+ const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1));
+
+ const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0));
+ const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1));
+ const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2));
+ const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3));
+
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc);
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc);
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc);
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc);
+ }
+
+ *s = hsum_float_8(acc) + summs;
+
  #else
 
  float sumf = 0;
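The new __AVX__ branch above unpacks four 2-bit quant values from each byte of q2 with the same shift-and-mask pattern as the AVX2 path (_mm_srli_epi16 by 2, 4 and 6, then an AND with m3 = 3). A scalar sketch of that unpacking, using a made-up sample byte for illustration only:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const uint8_t packed = 0xE4;  /* bit pairs 11 10 01 00 -> values 3, 2, 1, 0 */
        for (int shift = 0; shift < 8; shift += 2) {
            const unsigned q = (packed >> shift) & 3u;  /* same mask as m3 = _mm_set1_epi8(3) */
            printf("shift %d -> %u\n", shift, q);       /* prints 0, 1, 2, 3 */
        }
        return 0;
    }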
@@ -1861,7 +1919,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
  const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
  const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
- const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};
+ const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
 
  // high bit
  const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
@@ -2072,7 +2130,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  }
 
  // multiply with block scale and accumulate
- __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
  acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
 
  }
@@ -2247,13 +2305,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  aux16[0] = a & 0x0f0f;
  aux16[1] = (a >> 4) & 0x0f0f;
 
- const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
- const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
+ const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
+ const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
 
  memcpy(&aux64, x[i].hmask, 8);
 
  const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
- __m256i q3h_0 = _mm256_set_m128i(_mm_srli_epi16(haux, 2), haux);
+ __m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux);
  __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
  q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
  q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
@@ -2262,7 +2320,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
 
  // prepare low and high bits
- const __m256i q3aux = _mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits);
+ const __m256i q3aux = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits);
  const __m256i q3l_0 = _mm256_and_si256(q3aux, m3);
  const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);
 
@@ -2295,6 +2353,93 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
  *s = hsum_float_8(acc);
 
+ #elif defined __AVX__
+
+ const __m128i m3 = _mm_set1_epi8(3);
+ const __m128i m1 = _mm_set1_epi8(1);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ uint64_t aux64;
+
+ uint16_t aux16[2];
+ const int8_t * aux8 = (const int8_t *)aux16;
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+ const uint8_t * restrict q3 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const uint16_t a = *(const uint16_t *)x[i].scales;
+ aux16[0] = a & 0x0f0f;
+ aux16[1] = (a >> 4) & 0x0f0f;
+
+ const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8);
+ const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8);
+ const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8);
+ const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8);
+
+ memcpy(&aux64, x[i].hmask, 8);
+
+ __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
+ __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2);
+ __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4);
+ __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6);
+ q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2);
+ q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2);
+ q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2);
+ q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2);
+
+ // load low 2 bits
+ const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
+
+ // prepare low and high bits
+ const __m128i q3l_0 = _mm_and_si128(q3bits, m3);
+ const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3);
+ const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3);
+ const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3);
+
+ // load Q8 quants
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+ // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16,
+ // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
+ // and 2 if the high bit was set)
+ const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0));
+ const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1));
+ const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0));
+ const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1));
+
+ __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0));
+ __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1));
+ __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0));
+ __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1));
+
+ p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+ p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+ p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+ p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+
+ // multiply with scales
+ p16_0 = _mm_madd_epi16(scale_0, p16_0);
+ p16_1 = _mm_madd_epi16(scale_1, p16_1);
+ p16_2 = _mm_madd_epi16(scale_2, p16_2);
+ p16_3 = _mm_madd_epi16(scale_3, p16_3);
+
+ p16_0 = _mm_add_epi32(p16_0, p16_2);
+ p16_1 = _mm_add_epi32(p16_1, p16_3);
+ __m256i p16 = MM256_SET_M128I(p16_1, p16_0);
+
+ // multiply with block scale and accumulate
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc);
+
+ }
+
+ *s = hsum_float_8(acc);
+
  #else
 
  int8_t aux8[QK_K];
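In the new q3_K __AVX__ branch, _mm_maddubs_epi16 needs unsigned inputs, so the kernel multiplies the low 2-bit part and the high-bit correction against Q8 separately and then subtracts: as the q3h_* computation above shows, the correction is 4 when the stored high bit is clear and 0 when it is set. A scalar sketch of that identity, with hypothetical sample values (not taken from the package):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        /* Hypothetical sample values, for illustration only. */
        const uint8_t q3l  = 3;    /* low 2 bits of the quant */
        const uint8_t hbit = 0;    /* stored high bit (from hmask) */
        const int8_t  q8   = -7;   /* one Q8_K quant */

        /* Signed quant value: low bits plus high bit, minus the bias of 4. */
        const int q3 = (q3l | (hbit << 2)) - 4;

        /* What the SIMD code computes: maddubs(q3l, q8) - maddubs(q3h, q8),
         * where q3h = 4 when the high bit is clear and 0 when it is set. */
        const int q3h = (hbit ? 0 : 1) << 2;
        const int dot = q3l * q8 - q3h * q8;

        printf("q3 = %d, q3*q8 = %d, dot = %d\n", q3, q3 * q8, dot);  /* both products are 7 */
        return 0;
    }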
@@ -2477,7 +2622,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
  acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
 
  const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
- const __m256i scales = _mm256_set_m128i(sc128, sc128);
+ const __m256i scales = MM256_SET_M128I(sc128, sc128);
 
  __m256i sumi = _mm256_setzero_si256();
 
@@ -2584,7 +2729,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
  }
 
  __m256 vd = _mm256_set1_ps(d);
- __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
  acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
 
  }
@@ -2781,6 +2926,60 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
  *s = hsum_float_8(acc) - summs;
 
+ #elif defined __AVX__
+
+ const __m128i m4 = _mm_set1_epi8(0xF);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ float summs = 0;
+
+ uint16_t aux16[2];
+ const uint8_t * scales = (const uint8_t *)aux16;
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
+ const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
+ const __m256 vd = _mm256_set1_ps(d);
+
+ const uint16_t * a = (const uint16_t *)x[i].scales;
+ aux16[0] = a[0] & 0x0f0f;
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+ summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+ const uint8_t * restrict q4 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
+ const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0);
+ const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1);
+ const __m128i q4_0 = _mm_and_si128(q4bits_0, m4);
+ const __m128i q4_1 = _mm_and_si128(q4bits_1, m4);
+ const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4);
+ const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4);
+
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+ const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
+ const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
+ const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
+ const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
+
+ const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0);
+ const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1);
+ acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc);
+
+ const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2);
+ const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3);
+ acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc);
+
+ }
+
+ *s = hsum_float_8(acc) - summs;
+
  #else
 
  uint8_t aux8[QK_K];
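Like the AVX2 path it mirrors, the q4_K __AVX__ branch above never subtracts the per-group minimum inside the inner product: it accumulates d * scale * sum(q4 * q8) in acc and folds the minimum term into summs via the precomputed y[i].bsums, subtracting it once at the end. A scalar sketch of why the two forms agree, using hypothetical numbers that are not taken from the package:

    #include <stdio.h>

    int main(void) {
        /* Hypothetical numbers, for illustration only: one small group. */
        const float d = 0.25f, m = 0.5f;     /* super-block scale and min scale */
        const int   sc = 9, mn = 3;          /* 4-bit sub-scale and sub-min     */
        const int   q[4]  = {1, 7, 0, 12};   /* a few q4 quants (0..15)         */
        const int   q8[4] = {-3, 5, 2, -1};  /* matching Q8_K quants            */

        /* Direct dot product with fully dequantized weights: (d*sc*q - m*mn) * q8. */
        float direct = 0.0f;
        for (int k = 0; k < 4; ++k) direct += (d * sc * q[k] - m * mn) * q8[k];

        /* What the kernel does: accumulate d*sc*sum(q*q8) and subtract
         * m*mn*sum(q8) at the end -- sum(q8) comes precomputed in y[i].bsums. */
        int sum_qq8 = 0, bsum = 0;
        for (int k = 0; k < 4; ++k) { sum_qq8 += q[k] * q8[k]; bsum += q8[k]; }
        const float split = d * sc * sum_qq8 - m * mn * bsum;

        printf("direct = %f, split = %f\n", direct, split);  /* identical: 40.5 */
        return 0;
    }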
@@ -2963,7 +3162,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
  summs += dmin * _mm_extract_epi32(hsum, 0);
 
  const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
- const __m256i scales = _mm256_set_m128i(sc128, sc128);
+ const __m256i scales = MM256_SET_M128I(sc128, sc128);
 
  const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
  __m256i hmask = mone;
@@ -3102,7 +3301,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
  }
 
  __m256 vd = _mm256_set1_ps(d);
- __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
  acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
 
  }
@@ -3265,13 +3464,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
  const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
 
- const __m256i scale_l = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
- const __m256i scale_h = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
+ const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
+ const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
 
  int64_t aux64;
  memcpy(&aux64, x[i].qh, 8);
  const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64);
- const __m256i haux256 = _mm256_set_m128i(_mm_srli_epi16(haux128, 2), haux128);
+ const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128);
 
  const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4);
  const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4);
@@ -3295,10 +3494,66 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
  *s = hsum_float_8(acc);
 
- #else
+ #elif defined __AVX__
 
+ const __m128i m4 = _mm_set1_epi8(0xF);
+ const __m128i mone = _mm_set1_epi8(1);
 
- uint8_t aux8[QK_K];
+ __m256 acc = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+
+ const uint8_t * restrict q5 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+ const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
+
+ const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]);
+ const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]);
+ const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]);
+ const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]);
+
+ int64_t aux64;
+ memcpy(&aux64, x[i].qh, 8);
+ const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64);
+ const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2);
+
+ const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4);
+ const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4);
+ const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4);
+ const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4);
+
+ const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4);
+ const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4);
+ const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4);
+ const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4);
+
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+ const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0)));
+ const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1)));
+ const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0)));
+ const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1)));
+ const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0)));
+ const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1)));
+ const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0)));
+ const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1)));
+
+ const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2));
+ const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3));
+
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc);
+
+ }
+
+ *s = hsum_float_8(acc);
+
+ #else
+
+ int8_t aux8[QK_K];
  int16_t aux16[16];
  float sums [8];
  memset(sums, 0, 8*sizeof(float));
@@ -3308,7 +3563,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
  const uint8_t * restrict q4 = x[i].qs;
  const uint8_t * restrict hm = x[i].qh;
  const int8_t * restrict q8 = y[i].qs;
- uint8_t * restrict a = aux8;
+ int8_t * restrict a = aux8;
  for (int l = 0; l < 32; ++l) {
  a[l+ 0] = q4[l] & 0xF;
  a[l+32] = q4[l] >> 4;
@@ -3672,7 +3927,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
  }
 
- __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
  acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
  }
 
@@ -3830,8 +4085,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
  const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
  const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
 
- const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
- const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
+ const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
+ const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
 
  const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
  const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1);
@@ -3858,6 +4113,77 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
  *s = hsum_float_8(acc);
 
+ #elif defined __AVX__
+
+ const __m128i m4 = _mm_set1_epi8(0xF);
+ const __m128i m2 = _mm_set1_epi8(3);
+ const __m128i m32s = _mm_set1_epi8(32);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+ const uint8_t * restrict q4 = x[i].ql;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
+ const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
+ const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
+ const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
+
+ __m128i sumi_0 = _mm_setzero_si128();
+ __m128i sumi_1 = _mm_setzero_si128();
+
+ const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
+ const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
+
+ const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
+ const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
+
+ const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4);
+ const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4);
+ const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4);
+ const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4);
+
+ const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0);
+ const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1);
+ const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2);
+ const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3);
+
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+ __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0));
+ __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1));
+ __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0));
+ __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1));
+
+ __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
+ __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
+ __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
+ __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
+
+ p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+ p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+ p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+ p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+
+ p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
+ p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
+ p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
+ p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc);
+ }
+
+ *s = hsum_float_8(acc);
+
  #else
 
  int8_t aux8[QK_K];
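The q6_K __AVX__ branch above reconstructs each 6-bit quant from its low 4 bits (ql) and high 2 bits (qh), and removes the +32 storage bias by subtracting a separate _mm_maddubs_epi16(m32s, q8) term instead of signing the quants first. A scalar sketch of that correction, with hypothetical sample values (not taken from the package):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        /* Hypothetical sample values, for illustration only. */
        const uint8_t ql = 0x0B;   /* low 4 bits of the quant        */
        const uint8_t qh = 0x02;   /* high 2 bits, already isolated  */
        const int8_t  q8 = -5;     /* one Q8_K quant                 */

        /* True signed value: stored biased by +32, range [-32, 31]. */
        const int q6 = ((qh << 4) | ql) - 32;

        /* What the kernel computes: the product with the biased (unsigned) value,
         * minus 32*q8 supplied by the separate maddubs(m32s, q8) term. */
        const int biased    = ((qh << 4) | ql) * q8;
        const int corrected = biased - 32 * q8;

        printf("q6 = %d, q6*q8 = %d, corrected = %d\n", q6, q6 * q8, corrected);  /* both -55 */
        return 0;
    }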