llama_cpp 0.3.4 → 0.3.6
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/README.md +18 -2
- data/ext/llama_cpp/extconf.rb +2 -1
- data/ext/llama_cpp/llama_cpp.cpp +315 -8
- data/ext/llama_cpp/src/ggml-alloc.c +541 -0
- data/ext/llama_cpp/src/ggml-alloc.h +22 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +2271 -414
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.h +7 -0
- data/ext/llama_cpp/src/ggml-metal.m +218 -87
- data/ext/llama_cpp/src/ggml-metal.metal +72 -55
- data/ext/llama_cpp/src/ggml.c +754 -996
- data/ext/llama_cpp/src/ggml.h +94 -18
- data/ext/llama_cpp/src/k_quants.c +350 -24
- data/ext/llama_cpp/src/llama.cpp +713 -179
- data/ext/llama_cpp/src/llama.h +61 -5
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +26 -0
- metadata +4 -2
@@ -39,6 +39,8 @@
|
|
39
39
|
/* Smaller of a and b; yields b whenever a < b is false (equal or unordered operands).
 * Function-like macro: the argument that gets selected is evaluated twice, so callers
 * must not pass expressions with side effects. */
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
|
40
40
|
/* Larger of a and b; yields b whenever a > b is false (equal or unordered operands).
 * Function-like macro: the argument that gets selected is evaluated twice, so callers
 * must not pass expressions with side effects. */
#define MAX(a, b) (((a) > (b)) ? (a) : (b))
|
41
41
|
|
42
|
+
// Combine two 128-bit integer vectors into one 256-bit vector: `a` becomes the
// upper 128 bits and `b` the lower 128 bits (same operand order as _mm256_set_m128i).
// _mm256_castsi128_si256(b) places b in the low half (upper half undefined), and
// _mm256_insertf128_si256(..., a, 1) then overwrites the upper half with a.
// Both intrinsics are AVX (not AVX2), so this is usable in the `__AVX__` code paths.
// NOTE(review): presumably spelled out instead of calling _mm256_set_m128i because
// some older compilers do not provide that intrinsic — confirm.
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
|
43
|
+
|
42
44
|
//
|
43
45
|
// 2-6 bit quantization in super-blocks
|
44
46
|
//
|
@@ -1353,7 +1355,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1353
1355
|
const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
|
1354
1356
|
const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
|
1355
1357
|
const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
|
1356
|
-
const __m256i scales[2] = {
|
1358
|
+
const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
|
1357
1359
|
|
1358
1360
|
__m256i sumi = _mm256_setzero_si256();
|
1359
1361
|
|
@@ -1421,7 +1423,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1421
1423
|
const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
|
1422
1424
|
|
1423
1425
|
// sumf += -dmin * summs in 32bits*8
|
1424
|
-
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(
|
1426
|
+
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc);
|
1425
1427
|
|
1426
1428
|
const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
|
1427
1429
|
const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
|
@@ -1493,7 +1495,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1493
1495
|
}
|
1494
1496
|
|
1495
1497
|
// sumf += dall * isum - dmin * summs in 32bits
|
1496
|
-
__m256i sumi =
|
1498
|
+
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
|
1497
1499
|
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
|
1498
1500
|
}
|
1499
1501
|
|
@@ -1644,8 +1646,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1644
1646
|
summs += dmin * smin;
|
1645
1647
|
|
1646
1648
|
const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
|
1647
|
-
const __m256i q2_0 = _mm256_and_si256(
|
1648
|
-
const __m256i q2_1 = _mm256_and_si256(
|
1649
|
+
const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3);
|
1650
|
+
const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
|
1649
1651
|
|
1650
1652
|
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
1651
1653
|
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
@@ -1666,6 +1668,62 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1666
1668
|
|
1667
1669
|
*s = hsum_float_8(acc) + summs;
|
1668
1670
|
|
1671
|
+
#elif defined __AVX__
|
1672
|
+
|
1673
|
+
const __m128i m3 = _mm_set1_epi8(3);
|
1674
|
+
|
1675
|
+
__m256 acc = _mm256_setzero_ps();
|
1676
|
+
|
1677
|
+
uint32_t ud, um;
|
1678
|
+
const uint8_t * restrict db = (const uint8_t *)&ud;
|
1679
|
+
const uint8_t * restrict mb = (const uint8_t *)&um;
|
1680
|
+
|
1681
|
+
float summs = 0;
|
1682
|
+
|
1683
|
+
// TODO: optimize this
|
1684
|
+
|
1685
|
+
for (int i = 0; i < nb; ++i) {
|
1686
|
+
|
1687
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
1688
|
+
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
|
1689
|
+
|
1690
|
+
const uint8_t * restrict q2 = x[i].qs;
|
1691
|
+
const int8_t * restrict q8 = y[i].qs;
|
1692
|
+
|
1693
|
+
const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
|
1694
|
+
ud = (sc[0] >> 0) & 0x0f0f0f0f;
|
1695
|
+
um = (sc[0] >> 4) & 0x0f0f0f0f;
|
1696
|
+
|
1697
|
+
int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
|
1698
|
+
summs += dmin * smin;
|
1699
|
+
|
1700
|
+
const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
|
1701
|
+
const __m128i q2_0 = _mm_and_si128(q2bits, m3);
|
1702
|
+
const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
|
1703
|
+
const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
|
1704
|
+
const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
|
1705
|
+
|
1706
|
+
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
1707
|
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
1708
|
+
|
1709
|
+
const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0));
|
1710
|
+
const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1));
|
1711
|
+
const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0));
|
1712
|
+
const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1));
|
1713
|
+
|
1714
|
+
const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0));
|
1715
|
+
const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1));
|
1716
|
+
const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2));
|
1717
|
+
const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3));
|
1718
|
+
|
1719
|
+
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc);
|
1720
|
+
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc);
|
1721
|
+
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc);
|
1722
|
+
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc);
|
1723
|
+
}
|
1724
|
+
|
1725
|
+
*s = hsum_float_8(acc) + summs;
|
1726
|
+
|
1669
1727
|
#else
|
1670
1728
|
|
1671
1729
|
float sumf = 0;
|
@@ -1861,7 +1919,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|
1861
1919
|
const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
|
1862
1920
|
const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
|
1863
1921
|
const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
|
1864
|
-
const __m256i scales[2] = {
|
1922
|
+
const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
|
1865
1923
|
|
1866
1924
|
// high bit
|
1867
1925
|
const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
|
@@ -2072,7 +2130,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|
2072
2130
|
}
|
2073
2131
|
|
2074
2132
|
// multiply with block scale and accumulate
|
2075
|
-
__m256i sumi =
|
2133
|
+
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
|
2076
2134
|
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
|
2077
2135
|
|
2078
2136
|
}
|
@@ -2247,13 +2305,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|
2247
2305
|
aux16[0] = a & 0x0f0f;
|
2248
2306
|
aux16[1] = (a >> 4) & 0x0f0f;
|
2249
2307
|
|
2250
|
-
const __m256i scale_0 =
|
2251
|
-
const __m256i scale_1 =
|
2308
|
+
const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
|
2309
|
+
const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
|
2252
2310
|
|
2253
2311
|
memcpy(&aux64, x[i].hmask, 8);
|
2254
2312
|
|
2255
2313
|
const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
|
2256
|
-
__m256i q3h_0 =
|
2314
|
+
__m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux);
|
2257
2315
|
__m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
|
2258
2316
|
q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
|
2259
2317
|
q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
|
@@ -2262,7 +2320,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|
2262
2320
|
const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
|
2263
2321
|
|
2264
2322
|
// prepare low and high bits
|
2265
|
-
const __m256i q3aux =
|
2323
|
+
const __m256i q3aux = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits);
|
2266
2324
|
const __m256i q3l_0 = _mm256_and_si256(q3aux, m3);
|
2267
2325
|
const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);
|
2268
2326
|
|
@@ -2295,6 +2353,93 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|
2295
2353
|
|
2296
2354
|
*s = hsum_float_8(acc);
|
2297
2355
|
|
2356
|
+
#elif defined __AVX__
|
2357
|
+
|
2358
|
+
const __m128i m3 = _mm_set1_epi8(3);
|
2359
|
+
const __m128i m1 = _mm_set1_epi8(1);
|
2360
|
+
|
2361
|
+
__m256 acc = _mm256_setzero_ps();
|
2362
|
+
|
2363
|
+
uint64_t aux64;
|
2364
|
+
|
2365
|
+
uint16_t aux16[2];
|
2366
|
+
const int8_t * aux8 = (const int8_t *)aux16;
|
2367
|
+
|
2368
|
+
for (int i = 0; i < nb; ++i) {
|
2369
|
+
|
2370
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
2371
|
+
|
2372
|
+
const uint8_t * restrict q3 = x[i].qs;
|
2373
|
+
const int8_t * restrict q8 = y[i].qs;
|
2374
|
+
|
2375
|
+
const uint16_t a = *(const uint16_t *)x[i].scales;
|
2376
|
+
aux16[0] = a & 0x0f0f;
|
2377
|
+
aux16[1] = (a >> 4) & 0x0f0f;
|
2378
|
+
|
2379
|
+
const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8);
|
2380
|
+
const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8);
|
2381
|
+
const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8);
|
2382
|
+
const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8);
|
2383
|
+
|
2384
|
+
memcpy(&aux64, x[i].hmask, 8);
|
2385
|
+
|
2386
|
+
__m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
|
2387
|
+
__m128i q3h_1 = _mm_srli_epi16(q3h_0, 2);
|
2388
|
+
__m128i q3h_2 = _mm_srli_epi16(q3h_0, 4);
|
2389
|
+
__m128i q3h_3 = _mm_srli_epi16(q3h_0, 6);
|
2390
|
+
q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2);
|
2391
|
+
q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2);
|
2392
|
+
q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2);
|
2393
|
+
q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2);
|
2394
|
+
|
2395
|
+
// load low 2 bits
|
2396
|
+
const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
|
2397
|
+
|
2398
|
+
// prepare low and high bits
|
2399
|
+
const __m128i q3l_0 = _mm_and_si128(q3bits, m3);
|
2400
|
+
const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3);
|
2401
|
+
const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3);
|
2402
|
+
const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3);
|
2403
|
+
|
2404
|
+
// load Q8 quants
|
2405
|
+
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
2406
|
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
2407
|
+
|
2408
|
+
// Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16,
|
2409
|
+
// and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
|
2410
|
+
// and 2 if the high bit was set)
|
2411
|
+
const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0));
|
2412
|
+
const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1));
|
2413
|
+
const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0));
|
2414
|
+
const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1));
|
2415
|
+
|
2416
|
+
__m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0));
|
2417
|
+
__m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1));
|
2418
|
+
__m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0));
|
2419
|
+
__m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1));
|
2420
|
+
|
2421
|
+
p16_0 = _mm_sub_epi16(p16_0, q8s_0);
|
2422
|
+
p16_1 = _mm_sub_epi16(p16_1, q8s_1);
|
2423
|
+
p16_2 = _mm_sub_epi16(p16_2, q8s_2);
|
2424
|
+
p16_3 = _mm_sub_epi16(p16_3, q8s_3);
|
2425
|
+
|
2426
|
+
// multiply with scales
|
2427
|
+
p16_0 = _mm_madd_epi16(scale_0, p16_0);
|
2428
|
+
p16_1 = _mm_madd_epi16(scale_1, p16_1);
|
2429
|
+
p16_2 = _mm_madd_epi16(scale_2, p16_2);
|
2430
|
+
p16_3 = _mm_madd_epi16(scale_3, p16_3);
|
2431
|
+
|
2432
|
+
p16_0 = _mm_add_epi32(p16_0, p16_2);
|
2433
|
+
p16_1 = _mm_add_epi32(p16_1, p16_3);
|
2434
|
+
__m256i p16 = MM256_SET_M128I(p16_1, p16_0);
|
2435
|
+
|
2436
|
+
// multiply with block scale and accumulate
|
2437
|
+
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc);
|
2438
|
+
|
2439
|
+
}
|
2440
|
+
|
2441
|
+
*s = hsum_float_8(acc);
|
2442
|
+
|
2298
2443
|
#else
|
2299
2444
|
|
2300
2445
|
int8_t aux8[QK_K];
|
@@ -2477,7 +2622,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|
2477
2622
|
acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
|
2478
2623
|
|
2479
2624
|
const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
|
2480
|
-
const __m256i scales =
|
2625
|
+
const __m256i scales = MM256_SET_M128I(sc128, sc128);
|
2481
2626
|
|
2482
2627
|
__m256i sumi = _mm256_setzero_si256();
|
2483
2628
|
|
@@ -2584,7 +2729,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|
2584
2729
|
}
|
2585
2730
|
|
2586
2731
|
__m256 vd = _mm256_set1_ps(d);
|
2587
|
-
__m256i sumi =
|
2732
|
+
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
|
2588
2733
|
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
|
2589
2734
|
|
2590
2735
|
}
|
@@ -2781,6 +2926,60 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
|
|
2781
2926
|
|
2782
2927
|
*s = hsum_float_8(acc) - summs;
|
2783
2928
|
|
2929
|
+
#elif defined __AVX__
|
2930
|
+
|
2931
|
+
const __m128i m4 = _mm_set1_epi8(0xF);
|
2932
|
+
|
2933
|
+
__m256 acc = _mm256_setzero_ps();
|
2934
|
+
|
2935
|
+
float summs = 0;
|
2936
|
+
|
2937
|
+
uint16_t aux16[2];
|
2938
|
+
const uint8_t * scales = (const uint8_t *)aux16;
|
2939
|
+
|
2940
|
+
for (int i = 0; i < nb; ++i) {
|
2941
|
+
|
2942
|
+
const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
|
2943
|
+
const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
|
2944
|
+
const __m256 vd = _mm256_set1_ps(d);
|
2945
|
+
|
2946
|
+
const uint16_t * a = (const uint16_t *)x[i].scales;
|
2947
|
+
aux16[0] = a[0] & 0x0f0f;
|
2948
|
+
aux16[1] = (a[0] >> 4) & 0x0f0f;
|
2949
|
+
|
2950
|
+
summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
|
2951
|
+
|
2952
|
+
const uint8_t * restrict q4 = x[i].qs;
|
2953
|
+
const int8_t * restrict q8 = y[i].qs;
|
2954
|
+
|
2955
|
+
const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
|
2956
|
+
const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0);
|
2957
|
+
const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1);
|
2958
|
+
const __m128i q4_0 = _mm_and_si128(q4bits_0, m4);
|
2959
|
+
const __m128i q4_1 = _mm_and_si128(q4bits_1, m4);
|
2960
|
+
const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4);
|
2961
|
+
const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4);
|
2962
|
+
|
2963
|
+
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
2964
|
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
2965
|
+
|
2966
|
+
const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
|
2967
|
+
const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
|
2968
|
+
const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
|
2969
|
+
const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
|
2970
|
+
|
2971
|
+
const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0);
|
2972
|
+
const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1);
|
2973
|
+
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc);
|
2974
|
+
|
2975
|
+
const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2);
|
2976
|
+
const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3);
|
2977
|
+
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc);
|
2978
|
+
|
2979
|
+
}
|
2980
|
+
|
2981
|
+
*s = hsum_float_8(acc) - summs;
|
2982
|
+
|
2784
2983
|
#else
|
2785
2984
|
|
2786
2985
|
uint8_t aux8[QK_K];
|
@@ -2963,7 +3162,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|
2963
3162
|
summs += dmin * _mm_extract_epi32(hsum, 0);
|
2964
3163
|
|
2965
3164
|
const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
|
2966
|
-
const __m256i scales =
|
3165
|
+
const __m256i scales = MM256_SET_M128I(sc128, sc128);
|
2967
3166
|
|
2968
3167
|
const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
|
2969
3168
|
__m256i hmask = mone;
|
@@ -3102,7 +3301,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|
3102
3301
|
}
|
3103
3302
|
|
3104
3303
|
__m256 vd = _mm256_set1_ps(d);
|
3105
|
-
__m256i sumi =
|
3304
|
+
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
|
3106
3305
|
acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
|
3107
3306
|
|
3108
3307
|
}
|
@@ -3265,13 +3464,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|
3265
3464
|
|
3266
3465
|
const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
|
3267
3466
|
|
3268
|
-
const __m256i scale_l =
|
3269
|
-
const __m256i scale_h =
|
3467
|
+
const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
|
3468
|
+
const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
|
3270
3469
|
|
3271
3470
|
int64_t aux64;
|
3272
3471
|
memcpy(&aux64, x[i].qh, 8);
|
3273
3472
|
const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64);
|
3274
|
-
const __m256i haux256 =
|
3473
|
+
const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128);
|
3275
3474
|
|
3276
3475
|
const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4);
|
3277
3476
|
const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4);
|
@@ -3295,10 +3494,66 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|
3295
3494
|
|
3296
3495
|
*s = hsum_float_8(acc);
|
3297
3496
|
|
3298
|
-
#
|
3497
|
+
#elif defined __AVX__
|
3299
3498
|
|
3499
|
+
const __m128i m4 = _mm_set1_epi8(0xF);
|
3500
|
+
const __m128i mone = _mm_set1_epi8(1);
|
3300
3501
|
|
3301
|
-
|
3502
|
+
__m256 acc = _mm256_setzero_ps();
|
3503
|
+
|
3504
|
+
for (int i = 0; i < nb; ++i) {
|
3505
|
+
|
3506
|
+
const uint8_t * restrict q5 = x[i].qs;
|
3507
|
+
const int8_t * restrict q8 = y[i].qs;
|
3508
|
+
|
3509
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
3510
|
+
|
3511
|
+
const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
|
3512
|
+
|
3513
|
+
const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]);
|
3514
|
+
const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]);
|
3515
|
+
const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]);
|
3516
|
+
const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]);
|
3517
|
+
|
3518
|
+
int64_t aux64;
|
3519
|
+
memcpy(&aux64, x[i].qh, 8);
|
3520
|
+
const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64);
|
3521
|
+
const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2);
|
3522
|
+
|
3523
|
+
const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4);
|
3524
|
+
const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4);
|
3525
|
+
const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4);
|
3526
|
+
const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4);
|
3527
|
+
|
3528
|
+
const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4);
|
3529
|
+
const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4);
|
3530
|
+
const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4);
|
3531
|
+
const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4);
|
3532
|
+
|
3533
|
+
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
3534
|
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
3535
|
+
|
3536
|
+
const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0)));
|
3537
|
+
const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1)));
|
3538
|
+
const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0)));
|
3539
|
+
const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1)));
|
3540
|
+
const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0)));
|
3541
|
+
const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1)));
|
3542
|
+
const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0)));
|
3543
|
+
const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1)));
|
3544
|
+
|
3545
|
+
const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2));
|
3546
|
+
const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3));
|
3547
|
+
|
3548
|
+
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc);
|
3549
|
+
|
3550
|
+
}
|
3551
|
+
|
3552
|
+
*s = hsum_float_8(acc);
|
3553
|
+
|
3554
|
+
#else
|
3555
|
+
|
3556
|
+
int8_t aux8[QK_K];
|
3302
3557
|
int16_t aux16[16];
|
3303
3558
|
float sums [8];
|
3304
3559
|
memset(sums, 0, 8*sizeof(float));
|
@@ -3308,7 +3563,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|
3308
3563
|
const uint8_t * restrict q4 = x[i].qs;
|
3309
3564
|
const uint8_t * restrict hm = x[i].qh;
|
3310
3565
|
const int8_t * restrict q8 = y[i].qs;
|
3311
|
-
|
3566
|
+
int8_t * restrict a = aux8;
|
3312
3567
|
for (int l = 0; l < 32; ++l) {
|
3313
3568
|
a[l+ 0] = q4[l] & 0xF;
|
3314
3569
|
a[l+32] = q4[l] >> 4;
|
@@ -3672,7 +3927,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|
3672
3927
|
|
3673
3928
|
}
|
3674
3929
|
|
3675
|
-
__m256i sumi =
|
3930
|
+
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
|
3676
3931
|
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
|
3677
3932
|
}
|
3678
3933
|
|
@@ -3830,8 +4085,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|
3830
4085
|
const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
|
3831
4086
|
const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
|
3832
4087
|
|
3833
|
-
const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(
|
3834
|
-
const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(
|
4088
|
+
const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
|
4089
|
+
const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
|
3835
4090
|
|
3836
4091
|
const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
|
3837
4092
|
const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1);
|
@@ -3858,6 +4113,77 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
|
|
3858
4113
|
|
3859
4114
|
*s = hsum_float_8(acc);
|
3860
4115
|
|
4116
|
+
#elif defined __AVX__
|
4117
|
+
|
4118
|
+
const __m128i m4 = _mm_set1_epi8(0xF);
|
4119
|
+
const __m128i m2 = _mm_set1_epi8(3);
|
4120
|
+
const __m128i m32s = _mm_set1_epi8(32);
|
4121
|
+
|
4122
|
+
__m256 acc = _mm256_setzero_ps();
|
4123
|
+
|
4124
|
+
for (int i = 0; i < nb; ++i) {
|
4125
|
+
|
4126
|
+
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
4127
|
+
|
4128
|
+
const uint8_t * restrict q4 = x[i].ql;
|
4129
|
+
const uint8_t * restrict qh = x[i].qh;
|
4130
|
+
const int8_t * restrict q8 = y[i].qs;
|
4131
|
+
|
4132
|
+
const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
|
4133
|
+
const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
|
4134
|
+
const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
|
4135
|
+
const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
|
4136
|
+
|
4137
|
+
__m128i sumi_0 = _mm_setzero_si128();
|
4138
|
+
__m128i sumi_1 = _mm_setzero_si128();
|
4139
|
+
|
4140
|
+
const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
|
4141
|
+
const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
|
4142
|
+
|
4143
|
+
const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
|
4144
|
+
const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
|
4145
|
+
|
4146
|
+
const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4);
|
4147
|
+
const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4);
|
4148
|
+
const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4);
|
4149
|
+
const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4);
|
4150
|
+
|
4151
|
+
const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0);
|
4152
|
+
const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1);
|
4153
|
+
const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2);
|
4154
|
+
const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3);
|
4155
|
+
|
4156
|
+
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
|
4157
|
+
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
|
4158
|
+
|
4159
|
+
__m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0));
|
4160
|
+
__m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1));
|
4161
|
+
__m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0));
|
4162
|
+
__m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1));
|
4163
|
+
|
4164
|
+
__m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
|
4165
|
+
__m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
|
4166
|
+
__m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
|
4167
|
+
__m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
|
4168
|
+
|
4169
|
+
p16_0 = _mm_sub_epi16(p16_0, q8s_0);
|
4170
|
+
p16_1 = _mm_sub_epi16(p16_1, q8s_1);
|
4171
|
+
p16_2 = _mm_sub_epi16(p16_2, q8s_2);
|
4172
|
+
p16_3 = _mm_sub_epi16(p16_3, q8s_3);
|
4173
|
+
|
4174
|
+
p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
|
4175
|
+
p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
|
4176
|
+
p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
|
4177
|
+
p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
|
4178
|
+
|
4179
|
+
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
|
4180
|
+
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
|
4181
|
+
|
4182
|
+
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc);
|
4183
|
+
}
|
4184
|
+
|
4185
|
+
*s = hsum_float_8(acc);
|
4186
|
+
|
3861
4187
|
#else
|
3862
4188
|
|
3863
4189
|
int8_t aux8[QK_K];
|