llama_cpp 0.3.4 → 0.3.6

@@ -39,6 +39,8 @@
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+ #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
  //
  // 2-6 bit quantization in super-blocks
  //
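The new MM256_SET_M128I(a, b) helper builds a 256-bit integer vector with b in the low 128-bit lane and a in the high lane, using only AVX intrinsics (_mm256_castsi128_si256 plus _mm256_insertf128_si256). It is used below as a drop-in for _mm256_set_m128i, apparently so the code also builds on toolchains that do not provide that intrinsic. A minimal standalone sketch that checks the lane placement (the file name, main(), and compile command are illustrative, not part of the gem):

// check_mm256_set_m128i.c -- illustrative only; compile with: gcc -mavx -O2 check_mm256_set_m128i.c
#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

int main(void) {
    const __m128i lo = _mm_set1_epi32(1);        // intended low 128-bit lane
    const __m128i hi = _mm_set1_epi32(2);        // intended high 128-bit lane
    const __m256i v  = MM256_SET_M128I(hi, lo);  // note the argument order: (high, low)

    int32_t out[8];
    _mm256_storeu_si256((__m256i *)out, v);
    for (int i = 0; i < 8; ++i) {
        printf("%d ", out[i]);                   // expected: 1 1 1 1 2 2 2 2
    }
    printf("\n");
    return 0;
}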
@@ -1353,7 +1355,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
  const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
  const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
- const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};
+ const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
 
  __m256i sumi = _mm256_setzero_si256();
 
@@ -1421,7 +1423,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
 
  // sumf += -dmin * summs in 32bits*8
- acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(_mm256_set_m128i(summs_1, summs_0))), acc);
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc);
 
  const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
  const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
@@ -1493,7 +1495,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  }
 
  // sumf += dall * isum - dmin * summs in 32bits
- __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
  acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
  }
 
@@ -1644,8 +1646,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  summs += dmin * smin;
 
  const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
- const __m256i q2_0 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 2), q2bits), m3);
- const __m256i q2_1 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
+ const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3);
+ const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
 
  const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
  const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
@@ -1666,6 +1668,62 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
  *s = hsum_float_8(acc) + summs;
 
+ #elif defined __AVX__
+
+ const __m128i m3 = _mm_set1_epi8(3);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ uint32_t ud, um;
+ const uint8_t * restrict db = (const uint8_t *)&ud;
+ const uint8_t * restrict mb = (const uint8_t *)&um;
+
+ float summs = 0;
+
+ // TODO: optimize this
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+ const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+ const uint8_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+ ud = (sc[0] >> 0) & 0x0f0f0f0f;
+ um = (sc[0] >> 4) & 0x0f0f0f0f;
+
+ int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
+ summs += dmin * smin;
+
+ const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
+ const __m128i q2_0 = _mm_and_si128(q2bits, m3);
+ const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
+ const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
+ const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
+
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+ const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0));
+ const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1));
+ const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0));
+ const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1));
+
+ const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0));
+ const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1));
+ const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2));
+ const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3));
+
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc);
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc);
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc);
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc);
+ }
+
+ *s = hsum_float_8(acc) + summs;
+
  #else
 
  float sumf = 0;
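The new __AVX__ branch above avoids 256-bit integer instructions (which need AVX2) by doing all the integer math in 128-bit registers; the only 256-bit work is in float. The one non-obvious step is widening the eight 16-bit sums produced by _mm_maddubs_epi16 to eight 32-bit lanes before converting to float, which is what the p0 -> p_0 lines do. A hedged, self-contained sketch of just that widening step, checked against a scalar reference (the helper name widen_i16x8_to_i32x8 and the compile command are illustrative, not ggml API):

// widen16to32.c -- sketch of the p0 -> p_0 widening used in the q2_K __AVX__ path
// compile with: gcc -mavx -O2 widen16to32.c   (-mavx also enables the SSE4.1 _mm_cvtepi16_epi32)
#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

// Widen 8 signed 16-bit lanes to 8 signed 32-bit lanes (illustrative helper).
static __m256i widen_i16x8_to_i32x8(__m128i v16) {
    const __m128i lo = _mm_cvtepi16_epi32(v16);                          // lanes 0..3
    const __m128i hi = _mm_cvtepi16_epi32(_mm_unpackhi_epi64(v16, v16)); // lanes 4..7
    return MM256_SET_M128I(hi, lo);
}

int main(void) {
    const int16_t in[8] = { -3, 7, -32768, 32767, 0, 100, -100, 5 };
    const __m128i v = _mm_loadu_si128((const __m128i *)in);

    int32_t out[8];
    _mm256_storeu_si256((__m256i *)out, widen_i16x8_to_i32x8(v));

    for (int i = 0; i < 8; ++i) {
        // each 32-bit lane must be the sign-extended 16-bit input
        printf("%d%s", out[i], (int32_t)in[i] == out[i] ? " ok\n" : " MISMATCH\n");
    }
    return 0;
}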
@@ -1861,7 +1919,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
  const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
  const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
- const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};
+ const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
 
  // high bit
  const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
@@ -2072,7 +2130,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  }
 
  // multiply with block scale and accumulate
- __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
  acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
 
  }
@@ -2247,13 +2305,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  aux16[0] = a & 0x0f0f;
  aux16[1] = (a >> 4) & 0x0f0f;
 
- const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
- const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
+ const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
+ const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
 
  memcpy(&aux64, x[i].hmask, 8);
 
  const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
- __m256i q3h_0 = _mm256_set_m128i(_mm_srli_epi16(haux, 2), haux);
+ __m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux);
  __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
  q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
  q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
@@ -2262,7 +2320,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
 
  // prepare low and high bits
- const __m256i q3aux = _mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits);
+ const __m256i q3aux = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits);
  const __m256i q3l_0 = _mm256_and_si256(q3aux, m3);
  const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);
 
@@ -2295,6 +2353,93 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
  *s = hsum_float_8(acc);
 
+ #elif defined __AVX__
+
+ const __m128i m3 = _mm_set1_epi8(3);
+ const __m128i m1 = _mm_set1_epi8(1);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ uint64_t aux64;
+
+ uint16_t aux16[2];
+ const int8_t * aux8 = (const int8_t *)aux16;
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+ const uint8_t * restrict q3 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const uint16_t a = *(const uint16_t *)x[i].scales;
+ aux16[0] = a & 0x0f0f;
+ aux16[1] = (a >> 4) & 0x0f0f;
+
+ const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8);
+ const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8);
+ const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8);
+ const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8);
+
+ memcpy(&aux64, x[i].hmask, 8);
+
+ __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
+ __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2);
+ __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4);
+ __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6);
+ q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2);
+ q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2);
+ q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2);
+ q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2);
+
+ // load low 2 bits
+ const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
+
+ // prepare low and high bits
+ const __m128i q3l_0 = _mm_and_si128(q3bits, m3);
+ const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3);
+ const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3);
+ const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3);
+
+ // load Q8 quants
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+ // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16,
+ // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
+ // and 2 if the high bit was set)
+ const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0));
+ const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1));
+ const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0));
+ const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1));
+
+ __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0));
+ __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1));
+ __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0));
+ __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1));
+
+ p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+ p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+ p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+ p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+
+ // multiply with scales
+ p16_0 = _mm_madd_epi16(scale_0, p16_0);
+ p16_1 = _mm_madd_epi16(scale_1, p16_1);
+ p16_2 = _mm_madd_epi16(scale_2, p16_2);
+ p16_3 = _mm_madd_epi16(scale_3, p16_3);
+
+ p16_0 = _mm_add_epi32(p16_0, p16_2);
+ p16_1 = _mm_add_epi32(p16_1, p16_3);
+ __m256i p16 = MM256_SET_M128I(p16_1, p16_0);
+
+ // multiply with block scale and accumulate
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc);
+
+ }
+
+ *s = hsum_float_8(acc);
+
  #else
 
  int8_t aux8[QK_K];
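In both the q3_K paths, the decoded weight is signed (the stored 2 low bits plus the high bit, minus 4), but _mm_maddubs_epi16 multiplies an unsigned operand by a signed one. The code therefore splits the weight into two non-negative pieces: the 2 low bits (q3l_*), and a correction that is 4 per byte when the high bit is clear (the andnot-with-m1 shifted left by 2, q3h_*), whose product with q8 is subtracted. A scalar sketch of the identity being exploited, covering all eight weight encodings (the file name and main() are illustrative):

// q3_sign_trick.c -- scalar check of the split used for the q3_K dot product:
//   (low2 - 4*(1 - hbit)) * y  ==  low2*y - hcorr*y,   with hcorr = 4 when hbit == 0, else 0.
// Both low2 and hcorr are non-negative, so each product can go through _mm_maddubs_epi16
// (unsigned * signed); the two 16-bit results are then subtracted, as in the code above.
#include <stdio.h>

int main(void) {
    for (int low2 = 0; low2 < 4; ++low2) {
        for (int hbit = 0; hbit < 2; ++hbit) {
            const int y     = -7;                    // any int8 activation value
            const int value = low2 - 4 * (1 - hbit); // decoded q3_K weight, in [-4, 3]
            const int hcorr = hbit ? 0 : 4;          // what each byte of q3h_* holds
            const int split = low2 * y - hcorr * y;  // maddubs(q3l) - maddubs(q3h), per byte
            printf("low2=%d hbit=%d: %4d %s %4d\n", low2, hbit,
                   value * y, value * y == split ? "==" : "!=", split);
        }
    }
    return 0;
}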
@@ -2477,7 +2622,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
  acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
 
  const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
- const __m256i scales = _mm256_set_m128i(sc128, sc128);
+ const __m256i scales = MM256_SET_M128I(sc128, sc128);
 
  __m256i sumi = _mm256_setzero_si256();
 
@@ -2584,7 +2729,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
  }
 
  __m256 vd = _mm256_set1_ps(d);
- __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
  acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
 
  }
@@ -2781,6 +2926,60 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
  *s = hsum_float_8(acc) - summs;
 
+ #elif defined __AVX__
+
+ const __m128i m4 = _mm_set1_epi8(0xF);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ float summs = 0;
+
+ uint16_t aux16[2];
+ const uint8_t * scales = (const uint8_t *)aux16;
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
+ const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
+ const __m256 vd = _mm256_set1_ps(d);
+
+ const uint16_t * a = (const uint16_t *)x[i].scales;
+ aux16[0] = a[0] & 0x0f0f;
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+ summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+ const uint8_t * restrict q4 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
+ const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0);
+ const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1);
+ const __m128i q4_0 = _mm_and_si128(q4bits_0, m4);
+ const __m128i q4_1 = _mm_and_si128(q4bits_1, m4);
+ const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4);
+ const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4);
+
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+ const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
+ const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
+ const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
+ const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
+
+ const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0);
+ const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1);
+ acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc);
+
+ const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2);
+ const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3);
+ acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc);
+
+ }
+
+ *s = hsum_float_8(acc) - summs;
+
  #else
 
  uint8_t aux8[QK_K];
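The q4_K __AVX__ branch above unpacks two 4-bit weights per byte: _mm_and_si128 with 0xF gives the low nibbles, and _mm_srli_epi16 followed by the same mask gives the high nibbles. The mask after the shift is needed because _mm_srli_epi16 shifts whole 16-bit lanes, so without it the low bits of the neighbouring byte would leak in. A small hedged sketch of that unpacking, verified against a scalar reference (file name and test data are illustrative):

// nibble_unpack.c -- sketch of the low/high nibble extraction used in the q4_K __AVX__ path
// compile with: gcc -msse2 -O2 nibble_unpack.c
#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    uint8_t packed[16];
    for (int i = 0; i < 16; ++i) packed[i] = (uint8_t)(i * 17 + 3);   // arbitrary test bytes

    const __m128i m4   = _mm_set1_epi8(0xF);
    const __m128i bits = _mm_loadu_si128((const __m128i *)packed);
    const __m128i lo   = _mm_and_si128(bits, m4);                     // low 4 bits of each byte
    const __m128i hi   = _mm_and_si128(_mm_srli_epi16(bits, 4), m4);  // high 4 bits; re-mask after the 16-bit shift

    uint8_t lo_out[16], hi_out[16];
    _mm_storeu_si128((__m128i *)lo_out, lo);
    _mm_storeu_si128((__m128i *)hi_out, hi);

    int ok = 1;
    for (int i = 0; i < 16; ++i) {
        ok &= lo_out[i] == (packed[i] & 0xF);
        ok &= hi_out[i] == (packed[i] >> 4);
    }
    printf("%s\n", ok ? "nibble unpack matches scalar" : "MISMATCH");
    return 0;
}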
@@ -2963,7 +3162,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
  summs += dmin * _mm_extract_epi32(hsum, 0);
 
  const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
- const __m256i scales = _mm256_set_m128i(sc128, sc128);
+ const __m256i scales = MM256_SET_M128I(sc128, sc128);
 
  const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
  __m256i hmask = mone;
@@ -3102,7 +3301,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
  }
 
  __m256 vd = _mm256_set1_ps(d);
- __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
  acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
 
  }
@@ -3265,13 +3464,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
  const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
 
- const __m256i scale_l = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
- const __m256i scale_h = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
+ const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
+ const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
 
  int64_t aux64;
  memcpy(&aux64, x[i].qh, 8);
  const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64);
- const __m256i haux256 = _mm256_set_m128i(_mm_srli_epi16(haux128, 2), haux128);
+ const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128);
 
  const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4);
  const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4);
@@ -3295,10 +3494,66 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
  *s = hsum_float_8(acc);
 
- #else
+ #elif defined __AVX__
 
+ const __m128i m4 = _mm_set1_epi8(0xF);
+ const __m128i mone = _mm_set1_epi8(1);
 
- uint8_t aux8[QK_K];
+ __m256 acc = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+
+ const uint8_t * restrict q5 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+ const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
+
+ const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]);
+ const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]);
+ const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]);
+ const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]);
+
+ int64_t aux64;
+ memcpy(&aux64, x[i].qh, 8);
+ const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64);
+ const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2);
+
+ const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4);
+ const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4);
+ const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4);
+ const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4);
+
+ const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4);
+ const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4);
+ const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4);
+ const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4);
+
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+ const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0)));
+ const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1)));
+ const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0)));
+ const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1)));
+ const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0)));
+ const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1)));
+ const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0)));
+ const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1)));
+
+ const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2));
+ const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3));
+
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc);
+
+ }
+
+ *s = hsum_float_8(acc);
+
+ #else
+
+ int8_t aux8[QK_K];
  int16_t aux16[16];
  float sums [8];
  memset(sums, 0, 8*sizeof(float));
@@ -3308,7 +3563,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
  const uint8_t * restrict q4 = x[i].qs;
  const uint8_t * restrict hm = x[i].qh;
  const int8_t * restrict q8 = y[i].qs;
- uint8_t * restrict a = aux8;
+ int8_t * restrict a = aux8;
  for (int l = 0; l < 32; ++l) {
  a[l+ 0] = q4[l] & 0xF;
  a[l+32] = q4[l] >> 4;
@@ -3672,7 +3927,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
  }
 
- __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
  acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
  }
 
@@ -3830,8 +4085,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
  const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
  const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
 
- const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
- const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
+ const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
+ const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
 
  const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
  const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1);
@@ -3858,6 +4113,77 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
  *s = hsum_float_8(acc);
 
+ #elif defined __AVX__
+
+ const __m128i m4 = _mm_set1_epi8(0xF);
+ const __m128i m2 = _mm_set1_epi8(3);
+ const __m128i m32s = _mm_set1_epi8(32);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+ const uint8_t * restrict q4 = x[i].ql;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
+ const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
+ const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
+ const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
+
+ __m128i sumi_0 = _mm_setzero_si128();
+ __m128i sumi_1 = _mm_setzero_si128();
+
+ const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
+ const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
+
+ const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
+ const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
+
+ const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4);
+ const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4);
+ const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4);
+ const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4);
+
+ const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0);
+ const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1);
+ const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2);
+ const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3);
+
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+ __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0));
+ __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1));
+ __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0));
+ __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1));
+
+ __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
+ __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
+ __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
+ __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
+
+ p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+ p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+ p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+ p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+
+ p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
+ p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
+ p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
+ p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc);
+ }
+
+ *s = hsum_float_8(acc);
+
  #else
 
  int8_t aux8[QK_K];
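The q6_K __AVX__ branch uses the same bias trick as the AVX2 one: the stored 6-bit weight q is in [0, 63] while the real weight is q - 32, so the code computes _mm_maddubs_epi16(q, q8) and subtracts _mm_maddubs_epi16(32, q8) (the m32s constant). With 6-bit weights and int8 activations each pair sum stays well under 32767, so the saturating 16-bit intermediate cannot clip. A scalar model of one 16-bit lane (all names and the random test harness are illustrative):

// q6_bias_split.c -- scalar model of one 16-bit lane of the q6_K __AVX__ dot product:
//   p16 = q0*y0 + q1*y1   (weights still biased by +32, via _mm_maddubs_epi16)
//   q8s = 32*y0 + 32*y1   (_mm_maddubs_epi16 with the m32s constant)
// and p16 - q8s equals the unbiased (q0 - 32)*y0 + (q1 - 32)*y1.
#include <stdio.h>
#include <stdlib.h>

int main(void) {
    srand(0);
    int ok = 1;
    for (int t = 0; t < 1000; ++t) {
        const int q0 = rand() % 64, q1 = rand() % 64;               // stored 6-bit weights
        const int y0 = rand() % 256 - 128, y1 = rand() % 256 - 128; // int8 activations
        const int p16  = q0 * y0 + q1 * y1;
        const int q8s  = 32 * y0 + 32 * y1;
        const int real = (q0 - 32) * y0 + (q1 - 32) * y1;
        ok &= (p16 - q8s) == real;
    }
    printf("%s\n", ok ? "bias split matches" : "MISMATCH");
    return 0;
}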