llama_cpp 0.3.4 → 0.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/README.md +18 -2
- data/ext/llama_cpp/extconf.rb +2 -1
- data/ext/llama_cpp/llama_cpp.cpp +315 -8
- data/ext/llama_cpp/src/ggml-alloc.c +541 -0
- data/ext/llama_cpp/src/ggml-alloc.h +22 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +2271 -414
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.h +7 -0
- data/ext/llama_cpp/src/ggml-metal.m +218 -87
- data/ext/llama_cpp/src/ggml-metal.metal +72 -55
- data/ext/llama_cpp/src/ggml.c +754 -996
- data/ext/llama_cpp/src/ggml.h +94 -18
- data/ext/llama_cpp/src/k_quants.c +350 -24
- data/ext/llama_cpp/src/llama.cpp +713 -179
- data/ext/llama_cpp/src/llama.h +61 -5
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +26 -0
- metadata +4 -2
All hunks below are to data/ext/llama_cpp/src/k_quants.c (the +350/-24 entry above); they add the MM256_SET_M128I helper macro and `__AVX__` fallback paths for the k-quant dot products.

@@ -39,6 +39,8 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
 //
 // 2-6 bit quantization in super-blocks
 //
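The added macro is the usual portable replacement for `_mm256_set_m128i(hi, lo)`, which some older compilers lack: it casts the low half up to 256 bits and inserts the high half into the upper lane. A minimal sketch of what it produces (standalone harness, not part of the gem; build with `gcc -mavx`):

```c
#include <immintrin.h>
#include <stdio.h>

// Same definition as in the diff: combine two 128-bit halves into one
// 256-bit integer vector, b in bits 0..127 and a in bits 128..255.
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

int main(void) {
    const __m128i lo = _mm_set1_epi32(1);        // low half
    const __m128i hi = _mm_set1_epi32(2);        // high half
    const __m256i v  = MM256_SET_M128I(hi, lo);

    int out[8];
    _mm256_storeu_si256((__m256i *)out, v);
    for (int i = 0; i < 8; ++i) printf("%d ", out[i]);  // prints: 1 1 1 1 2 2 2 2
    printf("\n");
    return 0;
}
```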
@@ -1353,7 +1355,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
         const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
         const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
-        const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};
+        const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
 
         __m256i sumi = _mm256_setzero_si256();
 
@@ -1421,7 +1423,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
 
         // sumf += -dmin * summs in 32bits*8
-        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(_mm256_set_m128i(summs_1, summs_0))), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc);
 
         const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
         const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
@@ -1493,7 +1495,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         }
 
         // sumf += dall * isum - dmin * summs in 32bits
-        __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
         acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
     }
 
@@ -1644,8 +1646,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         summs += dmin * smin;
 
         const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
-        const __m256i q2_0 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 2), q2bits), m3);
-        const __m256i q2_1 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
+        const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3);
+        const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3);
 
         const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
         const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
@@ -1666,6 +1668,62 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc) + summs;
 
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint32_t ud, um;
+    const uint8_t * restrict db = (const uint8_t *)&ud;
+    const uint8_t * restrict mb = (const uint8_t *)&um;
+
+    float summs = 0;
+
+    // TODO: optimize this
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+
+        const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+        ud = (sc[0] >> 0) & 0x0f0f0f0f;
+        um = (sc[0] >> 4) & 0x0f0f0f0f;
+
+        int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
+        summs += dmin * smin;
+
+        const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
+        const __m128i q2_0 = _mm_and_si128(q2bits, m3);
+        const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
+        const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
+        const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1));
+
+        const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0));
+        const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1));
+        const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2));
+        const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc);
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
 #else
 
     float sumf = 0;
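In the `__AVX__` branch just added, each byte of `q2bits` carries four 2-bit quants, so the shift-by-0/2/4/6-and-mask lines fan one 16-byte load out into four registers of values in 0..3 before the `_mm_maddubs_epi16` products. A scalar sketch of that unpacking for a single byte (the helper name is illustrative):

```c
#include <stdint.h>
#include <stdio.h>

// Extract the four 2-bit values packed in one byte, least-significant
// pair first -- the scalar analogue of the >> 0/2/4/6 then & 3 lines.
static void unpack_q2_byte(uint8_t packed, uint8_t out[4]) {
    for (int k = 0; k < 4; ++k) {
        out[k] = (packed >> (2 * k)) & 3;
    }
}

int main(void) {
    uint8_t vals[4];
    unpack_q2_byte(0xE4, vals);  // 0b11100100
    printf("%d %d %d %d\n", vals[0], vals[1], vals[2], vals[3]);  // 0 1 2 3
    return 0;
}
```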
@@ -1861,7 +1919,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
         const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
         const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
-        const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};
+        const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
 
         // high bit
         const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
@@ -2072,7 +2130,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         }
 
         // multiply with block scale and accumulate
-        __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
         acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
 
     }
@@ -2247,13 +2305,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         aux16[0] = a & 0x0f0f;
         aux16[1] = (a >> 4) & 0x0f0f;
 
-        const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
-        const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
+        const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
+        const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));
 
         memcpy(&aux64, x[i].hmask, 8);
 
         const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
-        __m256i q3h_0 = _mm256_set_m128i(_mm_srli_epi16(haux, 2), haux);
+        __m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux);
         __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
         q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
         q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);
@@ -2262,7 +2320,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
 
         // prepare low and high bits
-        const __m256i q3aux = _mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits);
+        const __m256i q3aux = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits);
         const __m256i q3l_0 = _mm256_and_si256(q3aux, m3);
         const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);
 
@@ -2295,6 +2353,93 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc);
 
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+    const __m128i m1 = _mm_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint64_t aux64;
+
+    uint16_t aux16[2];
+    const int8_t * aux8 = (const int8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+
+        const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8);
+        const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8);
+        const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8);
+        const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8);
+
+        memcpy(&aux64, x[i].hmask, 8);
+
+        __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
+        __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2);
+        __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4);
+        __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6);
+        q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2);
+        q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2);
+        q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2);
+        q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2);
+
+        // load low 2 bits
+        const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
+
+        // prepare low and high bits
+        const __m128i q3l_0 = _mm_and_si128(q3bits, m3);
+        const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3);
+        const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3);
+        const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3);
+
+        // load Q8 quants
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16,
+        // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
+        // and 2 if the high bit was set)
+        const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1));
+
+        __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0));
+        __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1));
+        __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0));
+        __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1));
+
+        p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+        p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+        p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+
+        // multiply with scales
+        p16_0 = _mm_madd_epi16(scale_0, p16_0);
+        p16_1 = _mm_madd_epi16(scale_1, p16_1);
+        p16_2 = _mm_madd_epi16(scale_2, p16_2);
+        p16_3 = _mm_madd_epi16(scale_3, p16_3);
+
+        p16_0 = _mm_add_epi32(p16_0, p16_2);
+        p16_1 = _mm_add_epi32(p16_1, p16_3);
+        __m256i p16 = MM256_SET_M128I(p16_1, p16_0);
+
+        // multiply with block scale and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
 #else
 
     int8_t aux8[QK_K];
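The `q3h_*` lines implement the subtraction trick spelled out in the in-code comment: a q3_K weight is 2 low bits plus 1 high bit with an implicit offset of -4, and `_mm_maddubs_epi16` needs an unsigned first operand, so the kernel multiplies the low bits and a separate high-bit correction term against q8 and subtracts; per the code, the correction is 4 where the high bit is clear (the `_mm_andnot_si128` with `m1`, shifted left by 2). A scalar check of the identity it relies on (helper names are illustrative):

```c
#include <assert.h>

// Decoded q3_K weight: 2 low bits plus 1 high bit, offset by -4 (range -4..3).
static int q3_value(int low2, int hibit) {
    return (low2 | (hibit << 2)) - 4;
}

// The same dot-product contribution the AVX path computes: an unsigned
// low-bits product minus a "4 where the high bit is clear" product.
static int q3_dot_term(int low2, int hibit, int q8) {
    const int q3h = (hibit ? 0 : 1) << 2;
    return low2 * q8 - q3h * q8;
}

int main(void) {
    for (int low2 = 0; low2 < 4; ++low2)
        for (int hibit = 0; hibit < 2; ++hibit)
            for (int q8 = -128; q8 < 128; ++q8)
                assert(q3_value(low2, hibit) * q8 == q3_dot_term(low2, hibit, q8));
    return 0;
}
```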
@@ -2477,7 +2622,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
 
         const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
-        const __m256i scales = _mm256_set_m128i(sc128, sc128);
+        const __m256i scales = MM256_SET_M128I(sc128, sc128);
 
         __m256i sumi = _mm256_setzero_si256();
 
@@ -2584,7 +2729,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         }
 
         __m256 vd = _mm256_set1_ps(d);
-        __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
         acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
 
     }
@@ -2781,6 +2926,60 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc) - summs;
 
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0;
+
+    uint16_t aux16[2];
+    const uint8_t * scales = (const uint8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
+        const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
+        const __m256 vd = _mm256_set1_ps(d);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+
+        const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
+        const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0);
+        const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1);
+        const __m128i q4_0 = _mm_and_si128(q4bits_0, m4);
+        const __m128i q4_1 = _mm_and_si128(q4bits_1, m4);
+        const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4);
+        const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
+
+        const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0);
+        const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc);
+
+        const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2);
+        const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc);
+
+    }
+
+    *s = hsum_float_8(acc) - summs;
+
 #else
 
     uint8_t aux8[QK_K];
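The `aux16` masking in this block unpacks two packed scale bytes at once: `a[0] & 0x0f0f` keeps the low nibble of each byte (consumed as the two scales) and `(a[0] >> 4) & 0x0f0f` the high nibbles (consumed as the two mins via `bsums`). A scalar sketch on a little-endian target, which this code assumes (the sample byte values are made up, and the scale/min layout is inferred from how the diff consumes `scales[0..3]`):

```c
#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    // Two packed scale bytes: low nibble = scale, high nibble = min (assumed).
    const uint8_t packed[2] = { 0xA3, 0xC7 };

    uint16_t a;
    memcpy(&a, packed, sizeof(a));

    uint16_t aux16[2];
    aux16[0] = a & 0x0f0f;          // low nibbles of both bytes
    aux16[1] = (a >> 4) & 0x0f0f;   // high nibbles of both bytes

    const uint8_t * scales = (const uint8_t *)aux16;
    printf("scales: %d %d  mins: %d %d\n",
           scales[0], scales[1], scales[2], scales[3]);  // scales: 3 7  mins: 10 12
    return 0;
}
```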
@@ -2963,7 +3162,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         summs += dmin * _mm_extract_epi32(hsum, 0);
 
         const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
-        const __m256i scales = _mm256_set_m128i(sc128, sc128);
+        const __m256i scales = MM256_SET_M128I(sc128, sc128);
 
         const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
         __m256i hmask = mone;
@@ -3102,7 +3301,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         }
 
         __m256 vd = _mm256_set1_ps(d);
-        __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
         acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
 
     }
@@ -3265,13 +3464,13 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
         const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
 
-        const __m256i scale_l = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
-        const __m256i scale_h = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
+        const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
+        const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
 
         int64_t aux64;
         memcpy(&aux64, x[i].qh, 8);
         const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64);
-        const __m256i haux256 = _mm256_set_m128i(_mm_srli_epi16(haux128, 2), haux128);
+        const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128);
 
         const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4);
         const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4);
@@ -3295,10 +3494,66 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc);
 
-#else
+#elif defined __AVX__
 
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i mone = _mm_set1_epi8(1);
 
-    uint8_t aux8[QK_K];
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
+
+        const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]);
+        const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]);
+        const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]);
+        const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]);
+
+        int64_t aux64;
+        memcpy(&aux64, x[i].qh, 8);
+        const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64);
+        const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2);
+
+        const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4);
+        const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4);
+        const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4);
+        const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4);
+
+        const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4);
+        const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4);
+        const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4);
+        const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0)));
+        const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1)));
+        const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0)));
+        const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1)));
+        const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0)));
+        const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1)));
+        const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0)));
+        const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1)));
+
+        const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2));
+        const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#else
+
+    int8_t aux8[QK_K];
     int16_t aux16[16];
     float   sums [8];
     memset(sums, 0, 8*sizeof(float));
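The `p16_*`/`s16_*` pairs apply the same unsigned-multiply-then-subtract decomposition to the 5-bit case: a weight here is 4 low bits plus 1 high bit with an implicit -16, so instead of signing the inputs the kernel subtracts a "16 where the high bit is clear" product (the `_mm_andnot_si128` with `mone`, shifted left by 4). A scalar check of the identity, assuming that decode layout:

```c
#include <assert.h>

// Decoded weight in this branch: 4 low bits plus 1 high bit, offset by -16.
static int q5_value(int low4, int hibit) {
    return (low4 | (hibit << 4)) - 16;   // range -16..15
}

// Unsigned-friendly form used by the AVX path: low-nibble product minus
// a "16 where the high bit is clear" product.
static int q5_dot_term(int low4, int hibit, int q8) {
    const int q5h = (hibit ? 0 : 1) << 4;
    return low4 * q8 - q5h * q8;
}

int main(void) {
    for (int low4 = 0; low4 < 16; ++low4)
        for (int hibit = 0; hibit < 2; ++hibit)
            for (int q8 = -128; q8 < 128; ++q8)
                assert(q5_value(low4, hibit) * q8 == q5_dot_term(low4, hibit, q8));
    return 0;
}
```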
@@ -3308,7 +3563,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         const uint8_t * restrict q4 = x[i].qs;
         const uint8_t * restrict hm = x[i].qh;
         const int8_t * restrict q8 = y[i].qs;
-        uint8_t * restrict a = aux8;
+        int8_t * restrict a = aux8;
         for (int l = 0; l < 32; ++l) {
             a[l+ 0] = q4[l] & 0xF;
             a[l+32] = q4[l] >> 4;
@@ -3672,7 +3927,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
         }
 
-        __m256i sumi = _mm256_set_m128i(sumi_1, sumi_0);
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
         acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
     }
 
@@ -3830,8 +4085,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
         const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
         const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
 
-        const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
-        const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
+        const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4);
+        const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4);
 
         const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
         const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1);
@@ -3858,6 +4113,77 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc);
 
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i m2 = _mm_set1_epi8(3);
+    const __m128i m32s = _mm_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t * restrict q8 = y[i].qs;
+
+        const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
+        const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
+        const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
+        const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
+        const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
+
+        const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
+        const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
+
+        const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4);
+        const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4);
+        const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4);
+        const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4);
+
+        const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0);
+        const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1);
+        const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2);
+        const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0));
+        __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1));
+        __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0));
+        __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1));
+
+        __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
+        __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
+        __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
+        __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
+
+        p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+        p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+        p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+
+        p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
+        p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
+        p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
+        p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+
+        sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+        sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
 #else
 
     int8_t aux8[QK_K];
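The `m32s` constant handles q6_K's storage bias: each 6-bit weight is kept unsigned in 0..63 and its real value is q - 32. Since `_mm_maddubs_epi16` requires an unsigned first operand, the kernel computes q·q8 and 32·q8 separately (the `q8s_*` lines) and subtracts, rather than biasing q before the multiply. A scalar check of the identity:

```c
#include <assert.h>

// (q - 32) * q8 computed the way the AVX path does it: multiply the
// unsigned stored value and the constant 32 against q8 separately.
static int q6_dot_term(int q, int q8) {
    return q * q8 - 32 * q8;
}

int main(void) {
    for (int q = 0; q < 64; ++q)
        for (int q8 = -128; q8 < 128; ++q8)
            assert(q6_dot_term(q, q8) == (q - 32) * q8);
    return 0;
}
```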