llama_cpp 0.15.2 → 0.15.3
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/ext/llama_cpp/llama_cpp.cpp +49 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -0
- data/vendor/tmp/llama.cpp/Makefile +6 -17
- data/vendor/tmp/llama.cpp/ggml-common.h +0 -54
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +72 -30
- data/vendor/tmp/llama.cpp/ggml-cuda.h +1 -0
- data/vendor/tmp/llama.cpp/ggml-impl.h +40 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +4 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +68 -70
- data/vendor/tmp/llama.cpp/ggml-metal.metal +24 -409
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-quants.c +1879 -2450
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +176 -53
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +40 -500
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +9351 -5627
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +202 -225
- data/vendor/tmp/llama.cpp/ggml.c +376 -758
- data/vendor/tmp/llama.cpp/ggml.h +39 -27
- data/vendor/tmp/llama.cpp/llama.cpp +823 -593
- data/vendor/tmp/llama.cpp/llama.h +10 -3
- metadata +3 -3
data/vendor/tmp/llama.cpp/ggml.c
CHANGED
@@ -406,10 +406,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
     int i = 0;
 #if defined(__AVX512BF16__)
     for (; i + 32 <= n; i += 32) {
-        _mm512_storeu_si512(
-            (__m512i *)(y + i),
-            (__m512i)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
-                                         _mm512_loadu_ps(x + i)));
+        _mm512_storeu_si512(
+            (__m512i *)(y + i),
+            m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
+                                      _mm512_loadu_ps(x + i))));
     }
 #endif
     for (; i < n; i++) {
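
Note: the hunk above routes the packed conversion through the m512i() helper instead of a GNU-style vector cast. As a reference point only, a minimal scalar sketch of the FP32 -> BF16 conversion the loop performs (truncation only; ggml's scalar path additionally handles rounding and NaN; the function name below is illustrative, not from this diff):

    #include <stdint.h>
    #include <string.h>

    /* illustrative scalar fallback: bf16 keeps the high 16 bits of an IEEE-754 float */
    static inline uint16_t fp32_to_bf16_trunc(float f) {
        uint32_t u;
        memcpy(&u, &f, sizeof(u)); /* safe type-pun */
        return (uint16_t)(u >> 16);
    }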
@@ -871,22 +871,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
     },
     [GGML_TYPE_IQ4_XS] = {
         .type_name = "iq4_xs",
-#if QK_K == 64
-        .blck_size = QK4_NL,
-#else
         .blck_size = QK_K,
-#endif
         .type_size = sizeof(block_iq4_xs),
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
         .from_float = quantize_row_iq4_xs,
         .from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference,
         .vec_dot = ggml_vec_dot_iq4_xs_q8_K,
-#if QK_K == 64
-        .vec_dot_type = GGML_TYPE_Q8_0,
-#else
         .vec_dot_type = GGML_TYPE_Q8_K,
-#endif
         .nrows = 1,
     },
     [GGML_TYPE_Q8_K] = {
@@ -1523,6 +1515,195 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
 #define GGML_F16_VEC_MUL    GGML_F32Cx4_MUL
 #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE

+#elif defined(__loongarch_asx)
+
+#define GGML_SIMD
+
+// F32 LASX
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  8
+
+#define GGML_F32x8         __m256
+#define GGML_F32x8_ZERO    (__m256)__lasx_xvldi(0)
+#define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
+#define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
+#define GGML_F32x8_STORE(x,y)   __lasx_xvst((y), (x), 0)
+#define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
+#define GGML_F32x8_ADD     __lasx_xvfadd_s
+#define GGML_F32x8_MUL     __lasx_xvfmul_s
+#define GGML_F32x8_REDUCE(res, x) \
+do { \
+    int offset = GGML_F32_ARR >> 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
+    } \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
+    } \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
+    } \
+    float *tmp_p = (float *)&x[0]; \
+    res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \
+} while (0)
+// TODO: is this optimal ?
+
+#define GGML_F32_VEC        GGML_F32x8
+#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x8_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x8_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x8_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x8_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x8_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
+
+// F16 LASX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  8
+
+// F16 arithmetic is not supported by AVX, so we use F32 instead
+
+#define GGML_F32Cx8          __m256
+#define GGML_F32Cx8_ZERO     (__m256)__lasx_xvldi(0)
+#define GGML_F32Cx8_SET1(x)  (__m256)__lasx_xvreplgr2vr_w((x))
+
+static inline __m256 __lasx_f32cx8_load(ggml_fp16_t *x) {
+    float tmp[8];
+
+    for (int i = 0; i < 8; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+
+    return (__m256)__lasx_xvld(tmp, 0);
+}
+static inline void __lasx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
+    float arr[8];
+
+    __lasx_xvst(y, arr, 0);
+
+    for (int i = 0; i < 8; i++)
+        x[i] = GGML_FP32_TO_FP16(arr[i]);
+}
+#define GGML_F32Cx8_LOAD(x)     __lasx_f32cx8_load(x)
+#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
+
+#define GGML_F32Cx8_FMA         GGML_F32x8_FMA
+#define GGML_F32Cx8_ADD         __lasx_xvfadd_s
+#define GGML_F32Cx8_MUL         __lasx_xvfmul_s
+#define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE
+
+#define GGML_F16_VEC                GGML_F32Cx8
+#define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO
+#define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA
+#define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD
+#define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE
+
+#elif defined(__loongarch_sx)
+
+#define GGML_SIMD
+
+// F32 LSX
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4         __m128
+#define GGML_F32x4_ZERO    __lsx_vldi(0)
+#define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
+#define GGML_F32x4_STORE((x),(y))   __lsx_vst((y), (x), 0)
+#define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
+#define GGML_F32x4_ADD     __lsx_vfadd_s
+#define GGML_F32x4_MUL     __lsx_vfmul_s
+#define GGML_F32x4_REDUCE(res, x) \
+{ \
+    int offset = GGML_F32_ARR >> 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+    } \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+    } \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+    } \
+    __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \
+    tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \
+    tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
+    const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
+    tmp = __lsx_vsrli_d((__m128i)t0, 32); \
+    tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \
+    tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
+    res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 LSX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  4
+
+static inline __m128 __lsx_f16x4_load(ggml_fp16_t *x) {
+    float tmp[4];
+
+    tmp[0] = GGML_FP16_TO_FP32(x[0]);
+    tmp[1] = GGML_FP16_TO_FP32(x[1]);
+    tmp[2] = GGML_FP16_TO_FP32(x[2]);
+    tmp[3] = GGML_FP16_TO_FP32(x[3]);
+
+    return __lsx_vld(tmp, 0);
+}
+
+static inline void __lsx_f16x4_store(ggml_fp16_t *x, __m128 y) {
+    float arr[4];
+
+    __lsx_vst(y, arr, 0);
+
+    x[0] = GGML_FP32_TO_FP16(arr[0]);
+    x[1] = GGML_FP32_TO_FP16(arr[1]);
+    x[2] = GGML_FP32_TO_FP16(arr[2]);
+    x[3] = GGML_FP32_TO_FP16(arr[3]);
+}
+
+#define GGML_F32Cx4          __m128
+#define GGML_F32Cx4_ZERO     __lsx_vldi(0)
+#define GGML_F32Cx4_SET1(x)  __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32Cx4_LOAD(x)  __lsx_f16x4_load(x)
+#define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
+#define GGML_F32Cx4_FMA      GGML_F32x4_FMA
+#define GGML_F32Cx4_ADD      __lsx_vfadd_s
+#define GGML_F32Cx4_MUL      __lsx_vfmul_s
+#define GGML_F32Cx4_REDUCE   GGML_F32x4_REDUCE
+
+#define GGML_F16_VEC                 GGML_F32Cx4
+#define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO
+#define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1
+#define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA
+#define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD
+#define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL
+#define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE
+
 #endif

 // GGML_F32_ARR / GGML_F16_ARR
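
Note: GGML_F32x8_REDUCE above folds the GGML_F32_ARR partial-sum vectors pairwise and then sums the lanes of the surviving vector. A minimal scalar sketch of the same pattern, assuming 4 accumulators of 8 lanes each (matching GGML_F32_STEP 32 / GGML_F32_EPR 8); not LoongArch code, just the shape of the reduction:

    /* scalar model of the pairwise accumulator fold + horizontal lane sum */
    static float reduce_f32x8_sketch(float x[4][8]) {
        for (int offset = 4 >> 1; offset > 0; offset >>= 1) {
            for (int i = 0; i < offset; ++i) {
                for (int l = 0; l < 8; ++l) {
                    x[i][l] += x[offset + i][l]; /* one __lasx_xvfadd_s per i */
                }
            }
        }
        float res = 0.0f;
        for (int l = 0; l < 8; ++l) {
            res += x[0][l]; /* the tmp_p[0] + ... + tmp_p[7] line */
        }
        return res;
    }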
@@ -1666,10 +1847,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
     __m512 c1 = _mm512_setzero_ps();
     __m512 c2 = _mm512_setzero_ps();
     for (; i + 64 <= n; i += 64) {
-        c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_si512((x + i)),
-                                  (__m512bh)_mm512_loadu_si512((y + i)));
-        c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_si512((x + i + 32)),
-                                  (__m512bh)_mm512_loadu_si512((y + i + 32)));
+        c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
+                                  m512bh(_mm512_loadu_si512((y + i))));
+        c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
+                                  m512bh(_mm512_loadu_si512((y + i + 32))));
     }
     sumf += (ggml_float)_mm512_reduce_add_ps(c1);
     sumf += (ggml_float)_mm512_reduce_add_ps(c2);
@@ -2076,7 +2257,7 @@ inline static float ggml_silu_f32(float x) {
     return x/(1.0f + expf(-x));
 }

-#if defined(__ARM_NEON)
+#if defined(__ARM_NEON) && defined(__aarch64__)

 // adapted from arm limited optimized routine
 // the maximum error is 1.45358 plus 0.5 ulps
@@ -2288,7 +2469,7 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
     for (; i + 3 < n; i += 4) {
         _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
     }
-#elif defined(__ARM_NEON)
+#elif defined(__ARM_NEON) && defined(__aarch64__)
     for (; i + 3 < n; i += 4) {
         vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
     }
@@ -2335,7 +2516,7 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x,
 #endif
         sum += (ggml_float)_mm_cvtss_f32(val);
     }
-#elif defined(__ARM_NEON)
+#elif defined(__ARM_NEON) && defined(__aarch64__)
     for (; i + 3 < n; i += 4) {
         float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
                                                 vdupq_n_f32(max)));
@@ -2489,9 +2670,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ARGSORT",
     "LEAKY_RELU",

-    "FLASH_ATTN",
     "FLASH_ATTN_EXT",
-    "FLASH_FF",
     "FLASH_ATTN_BACK",
     "SSM_CONV",
     "SSM_SCAN",
@@ -2517,7 +2696,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };

-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2579,9 +2758,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "argsort(x)",
     "leaky_relu(x)",

-    "flash_attn(x)",
     "flash_attn_ext(x)",
-    "flash_ff(x)",
     "flash_attn_back(x)",
     "ssm_conv(x)",
     "ssm_scan(x)",
@@ -2607,7 +2784,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };

-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -6042,6 +6219,7 @@ static struct ggml_tensor * ggml_rope_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
@@ -6055,10 +6233,17 @@ static struct ggml_tensor * ggml_rope_impl(
         float                 xpos_base,
         bool                  xpos_down,
         bool                  inplace) {
+    GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
+
     GGML_ASSERT(ggml_is_vector(b));
     GGML_ASSERT(b->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);

+    if (c) {
+        GGML_ASSERT(c->type == GGML_TYPE_F32);
+        GGML_ASSERT(c->ne[0] >= n_dims / 2);
+    }
+
     bool is_node = false;

     if (a->grad) {
@@ -6082,6 +6267,7 @@ static struct ggml_tensor * ggml_rope_impl(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
+    result->src[2] = c;

     return result;
 }
@@ -6094,7 +6280,7 @@ struct ggml_tensor * ggml_rope(
         int                   mode,
         int                   n_ctx) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
+        ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
     );
 }

@@ -6106,14 +6292,15 @@ struct ggml_tensor * ggml_rope_inplace(
         int                   mode,
         int                   n_ctx) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+        ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
     );
 }

-struct ggml_tensor * ggml_rope_custom(
+struct ggml_tensor * ggml_rope_ext(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
@@ -6125,15 +6312,16 @@ struct ggml_tensor * ggml_rope_custom(
         float                 beta_fast,
         float                 beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
     );
 }

-struct ggml_tensor * ggml_rope_custom_inplace(
+struct ggml_tensor * ggml_rope_ext_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
@@ -6145,19 +6333,49 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         float                 beta_fast,
         float                 beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
     );
 }

-struct ggml_tensor * ggml_rope_xpos_inplace(
+struct ggml_tensor * ggml_rope_custom(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   n_dims,
+        int                   mode,
+        int                   n_ctx,
+        int                   n_orig_ctx,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+    );
+}
+
+struct ggml_tensor * ggml_rope_custom_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         int                   n_dims,
-        float                 base,
-        bool                  down) {
-    return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
+        int                   mode,
+        int                   n_ctx,
+        int                   n_orig_ctx,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
+    );
 }

 // ggml_rope_back
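
Note: the net effect of the RoPE hunks above is a new optional tensor argument c carrying per-dimension frequency factors, with ggml_rope_custom/_inplace kept as thin wrappers that pass NULL. A hedged usage sketch of the new entry point (the tensor names Qcur, inp_pos and freq_factors are illustrative, not from this diff):

    /* c may be NULL (classic RoPE) or an F32 tensor with ne[0] >= n_dims/2 */
    struct ggml_tensor * rotated = ggml_rope_ext(
            ctx, Qcur, inp_pos, freq_factors,
            n_dims, mode, n_ctx, n_orig_ctx,
            freq_base, freq_scale, ext_factor, attn_factor,
            beta_fast, beta_slow);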
@@ -6166,6 +6384,7 @@ struct ggml_tensor * ggml_rope_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
@@ -6181,6 +6400,7 @@ struct ggml_tensor * ggml_rope_back(
     GGML_ASSERT(ggml_is_vector(b));
     GGML_ASSERT(b->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);
+    GGML_ASSERT(c == NULL && "freq factors not implemented yet");

     GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");

@@ -6724,38 +6944,6 @@ struct ggml_tensor * ggml_top_k(
     return result;
 }

-// ggml_flash_attn
-
-struct ggml_tensor * ggml_flash_attn(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * q,
-        struct ggml_tensor  * k,
-        struct ggml_tensor  * v,
-        bool                  masked) {
-    GGML_ASSERT(ggml_can_mul_mat(k, q));
-    // TODO: check if vT can be multiplied by (k*qT)
-
-    bool is_node = false;
-
-    if (q->grad || k->grad || v->grad) {
-        is_node = true;
-    }
-
-    //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne);
-
-    int32_t t = masked ? 1 : 0;
-    ggml_set_op_params(result, &t, sizeof(t));
-
-    result->op   = GGML_OP_FLASH_ATTN;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = q;
-    result->src[1] = k;
-    result->src[2] = v;
-
-    return result;
-}
-
 // ggml_flash_attn_ext

 struct ggml_tensor * ggml_flash_attn_ext(
@@ -6815,38 +7003,6 @@ void ggml_flash_attn_ext_set_prec(
     ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
 }

-// ggml_flash_ff
-
-struct ggml_tensor * ggml_flash_ff(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b0,
-        struct ggml_tensor  * b1,
-        struct ggml_tensor  * c0,
-        struct ggml_tensor  * c1) {
-    GGML_ASSERT(ggml_can_mul_mat(b0, a));
-    // TODO: more checks
-
-    bool is_node = false;
-
-    if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
-        is_node = true;
-    }
-
-    //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);
-
-    result->op   = GGML_OP_FLASH_FF;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-    result->src[1] = b0;
-    result->src[2] = b1;
-    result->src[3] = c0;
-    result->src[4] = c1;
-
-    return result;
-}
-
 // ggml_flash_attn_back

 struct ggml_tensor * ggml_flash_attn_back(
@@ -6856,6 +7012,8 @@ struct ggml_tensor * ggml_flash_attn_back(
         struct ggml_tensor  * v,
         struct ggml_tensor  * d,
         bool                  masked) {
+    GGML_ASSERT(false && "TODO: adapt to ggml_flash_attn_ext() changes");
+
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)

@@ -14115,6 +14273,7 @@ static void ggml_compute_forward_rope_f32(

     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src2 = dst->src[2];

     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
@@ -14174,6 +14333,17 @@ static void ggml_compute_forward_rope_f32(
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;

+    const float * freq_factors = NULL;
+    if (is_neox) {
+        if (src2 != NULL) {
+            GGML_ASSERT(src2->type == GGML_TYPE_F32);
+            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+            freq_factors = (const float *) src2->data;
+        }
+    } else {
+        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+    }
+
     // backward process uses inverse rotation by cos and sin.
     // cos and sin build a rotation matrix, where the inverse is the transpose.
     // this essentially just switches the sign of sin.
@@ -14250,10 +14420,11 @@ static void ggml_compute_forward_rope_f32(

                     // simplified from `(ib * n_dims + ic) * inv_ndims`
                     float cur_rot = inv_ndims * ic - ib;
+                    float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;

                     float cos_theta, sin_theta;
                     rope_yarn(
-                        theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                        theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
                         &cos_theta, &sin_theta
                     );
                     sin_theta *= sin_sign;
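
Note: with freq_factors present, each rotated pair's base angle is simply divided by its factor before the rope_yarn interpolation. A hedged scalar sketch of the angle math (the power form assumes the usual theta_scale = freq_base^(-2/n_dims) recurrence used by the NeoX path; the function name is illustrative):

    #include <math.h>

    /* angle for pair index ic (even) at position pos; factors > 1 slow the
       rotation of that pair, which is how long-context RoPE variants stretch
       the usable context */
    static float rope_theta_sketch(float pos, int ic, int n_dims, float freq_base,
                                   const float * freq_factors) {
        float theta = pos * powf(freq_base, -(float)ic/(float)n_dims);
        return freq_factors ? theta / freq_factors[ic/2] : theta;
    }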
@@ -14286,6 +14457,7 @@ static void ggml_compute_forward_rope_f32(
     }
 }

+// TODO: deduplicate f16/f32 code
 static void ggml_compute_forward_rope_f16(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst,
@@ -14293,6 +14465,7 @@ static void ggml_compute_forward_rope_f16(

     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src2 = dst->src[2];

     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
@@ -14345,6 +14518,17 @@ static void ggml_compute_forward_rope_f16(
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;

+    const float * freq_factors = NULL;
+    if (is_neox) {
+        if (src2 != NULL) {
+            GGML_ASSERT(src2->type == GGML_TYPE_F32);
+            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+            freq_factors = (const float *) src2->data;
+        }
+    } else {
+        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+    }
+
     // backward process uses inverse rotation by cos and sin.
     // cos and sin build a rotation matrix, where the inverse is the transpose.
     // this essentially just switches the sign of sin.
@@ -14417,10 +14601,11 @@ static void ggml_compute_forward_rope_f16(

                     // simplified from `(ib * n_dims + ic) * inv_ndims`
                     float cur_rot = inv_ndims * ic - ib;
+                    float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;

                     float cos_theta, sin_theta;
                     rope_yarn(
-                        theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                        theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
                         &cos_theta, &sin_theta
                     );
                     sin_theta *= sin_sign;
@@ -15458,17 +15643,15 @@ static void ggml_compute_forward_argsort(
     }
 }

-// ggml_compute_forward_flash_attn
+// ggml_compute_forward_flash_attn_ext

-static void ggml_compute_forward_flash_attn_f32(
+static void ggml_compute_forward_flash_attn_ext_f16(
         const struct ggml_compute_params * params,
-        const bool masked,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-    const struct ggml_tensor * k = dst->src[1];
-    const struct ggml_tensor * v = dst->src[2];
-
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);

@@ -15486,409 +15669,18 @@ static void ggml_compute_forward_flash_attn_f32(

     const int64_t D = neq0;
     const int64_t N = neq1;
-    const int64_t P = nek1 - N;
-    const int64_t M = P + N;
-
-    const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);

     GGML_ASSERT(ne0 == D);
-    GGML_ASSERT(ne1 == N);
-    GGML_ASSERT(P >= 0);
+    GGML_ASSERT(ne2 == N);

-    GGML_ASSERT(nbq0 == sizeof(float));
-    GGML_ASSERT(nbk0 == sizeof(float));
-    GGML_ASSERT(nbv0 == sizeof(float));
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));

     GGML_ASSERT(neq0 == D);
     GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev1 == D);
-
-    GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nek1 == N + P);
-    GGML_ASSERT(nev1 == D);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    // parallelize by q rows using ggml_vec_dot_f32
-
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    const float scale = 1.0f/sqrtf(D);
-
-    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
-
-        float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32);
-
-        for (int i = M; i < Mup; ++i) {
-            S[i] = -INFINITY;
-        }
-
-        const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
-        for (int64_t ic = 0; ic < masked_begin; ++ic) {
-            // k indices
-            const int ik3 = iq3;
-            const int ik2 = iq2 % nek2;
-            const int ik1 = ic;
-
-            // S indices
-            const int i1 = ik1;
-
-            ggml_vec_dot_f32(neq0,
-                    S + i1, 0,
-                    (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                    (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
-        }
-
-        // scale
-        ggml_vec_scale_f32(masked_begin, S, scale);
-
-        for (int64_t i = masked_begin; i < M; i++) {
-            S[i] = -INFINITY;
-        }
-
-        // softmax
-        // exclude known -INF S[..] values from max and loop
-        // dont forget to set their SW values to zero
-        {
-            float max = -INFINITY;
-            ggml_vec_max_f32(masked_begin, &max, S);
-
-            ggml_float sum = 0.0;
-            {
-#ifdef GGML_SOFT_MAX_ACCELERATE
-                max = -max;
-                vDSP_vsadd(S, 1, &max, S, 1, Mup);
-                vvexpf(S, S, &Mup);
-                ggml_vec_sum_f32(Mup, &sum, S);
-#else
-                sum = ggml_vec_soft_max_f32(Mup, S, S, max);
-#endif
-            }
-
-            assert(sum > 0.0);
-
-            sum = 1.0/sum;
-            ggml_vec_scale_f32(masked_begin, S, sum);
-
-#ifndef NDEBUG
-            for (int i = 0; i < masked_begin; ++i) {
-                assert(!isnan(S[i]));
-                assert(!isinf(S[i]));
-            }
-#endif
-        }
-
-        for (int64_t ic = 0; ic < nev1; ++ic) {
-            // dst indices
-            const int i1 = iq1;
-            const int i2 = iq2;
-            const int i3 = iq3;
-
-            // v indices
-            const int iv2 = iq2 % nev2;
-            const int iv3 = iq3;
-
-            ggml_vec_dot_f32(masked_begin,
-                    (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
-                    (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
-                    S, 0, 1);
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_attn_f16(
-        const struct ggml_compute_params * params,
-        const bool masked,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-    const struct ggml_tensor * k = dst->src[1];
-    const struct ggml_tensor * v = dst->src[2];
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
-    GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
-    GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
-    GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
-    GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
-    GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
-    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
-    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = neq0;
-    const int64_t N = neq1;
-    const int64_t P = nek1 - N;
-    const int64_t M = P + N;
-
-    const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
-
-    GGML_ASSERT(ne0 == D);
-    GGML_ASSERT(ne1 == N);
-    GGML_ASSERT(P >= 0);
-
-    GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
-
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev1 == D);
-
-    GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nek1 == N + P);
-    GGML_ASSERT(nev1 == D);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    // parallelize by q rows using ggml_vec_dot_f32
-
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    const float scale = 1.0f/sqrtf(D);
-
-    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
-
-        float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
-
-        for (int i = M; i < Mup; ++i) {
-            S[i] = -INFINITY;
-        }
-
-        if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
-            for (int64_t ic = 0; ic < nek1; ++ic) {
-                // k indices
-                const int ik3 = iq3;
-                const int ik2 = iq2 % nek2;
-                const int ik1 = ic;
-
-                // S indices
-                const int i1 = ik1;
-
-                ggml_vec_dot_f16(neq0,
-                        S + i1, 0,
-                        (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
-            }
-        } else {
-            for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
-                // k indices
-                const int ik3 = iq3;
-                const int ik2 = iq2 % nek2;
-                const int ik1 = ic;
-
-                // S indices
-                const int i1 = ik1;
-
-                ggml_vec_dot_f16_unroll(neq0, nbk1,
-                        S + i1,
-                        ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
-            }
-        }
-
-        // scale
-        ggml_vec_scale_f32(nek1, S, scale);
-
-        if (masked) {
-            for (int64_t i = P; i < M; i++) {
-                if (i > P + iq1) {
-                    S[i] = -INFINITY;
-                }
-            }
-        }
-
-        // softmax
-        // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
-        // dont forget to set their S values to zero
-        {
-            float max = -INFINITY;
-            ggml_vec_max_f32(M, &max, S);
-
-            ggml_float sum = 0.0;
-            {
-#ifdef GGML_SOFT_MAX_ACCELERATE
-                max = -max;
-                vDSP_vsadd(S, 1, &max, S, 1, Mup);
-                vvexpf(S, S, &Mup);
-                ggml_vec_sum_f32(Mup, &sum, S);
-#else
-                sum = ggml_vec_soft_max_f32(Mup, S, S, max);
-#endif
-            }
-
-            assert(sum > 0.0);
-
-            sum = 1.0/sum;
-            ggml_vec_scale_f32(M, S, sum);
-
-#ifndef NDEBUG
-            for (int i = 0; i < M; ++i) {
-                assert(!isnan(S[i]));
-                assert(!isinf(S[i]));
-            }
-#endif
-        }
-
-        ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
-
-        for (int64_t i = 0; i < M; i++) {
-            S16[i] = GGML_FP32_TO_FP16(S[i]);
-        }
-
-        // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
-        if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
-            for (int64_t ic = 0; ic < nev1; ++ic) {
-                // dst indices
-                const int i1 = iq1;
-                const int i2 = iq2;
-                const int i3 = iq3;
-
-                // v indices
-                const int iv2 = iq2 % nev2;
-                const int iv3 = iq3;
-
-                ggml_vec_dot_f16(nev0,
-                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
-                        (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
-                        S16, 0, 1);
-            }
-        } else {
-            for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
-                // dst indices
-                const int i1 = iq1;
-                const int i2 = iq2;
-                const int i3 = iq3;
-
-                // v indices
-                const int iv2 = iq2 % nev2;
-                const int iv3 = iq3;
-
-                ggml_vec_dot_f16_unroll(nev0, nbv1,
-                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
-                        ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
-                        S16);
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_attn(
-        const struct ggml_compute_params * params,
-        const bool masked,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-
-    switch (q->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_flash_attn_f16(params, masked, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_flash_attn_f32(params, masked, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-// ggml_compute_forward_flash_attn_ext
-
-static void ggml_compute_forward_flash_attn_ext_f16(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * q,
-        const struct ggml_tensor * k,
-        const struct ggml_tensor * v,
-        const struct ggml_tensor * mask,
-        struct ggml_tensor * dst) {
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
-    GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
-    GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
-    GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
-    GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
-    GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
-    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
-    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = neq0;
-    const int64_t N = neq1;
-
-    GGML_ASSERT(ne0 == D);
-    GGML_ASSERT(ne2 == N);
-
-    GGML_ASSERT(nbq0 == sizeof(float));
-    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
-
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev0 == D);
+    GGML_ASSERT(nev0 == D);

     GGML_ASSERT(neq1 == N);
     GGML_ASSERT(nev0 == D);
@@ -15938,6 +15730,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

+    enum ggml_type    const k_vec_dot_type = type_traits[k->type].vec_dot_type;
+    ggml_from_float_t const q_to_vec_dot   = type_traits[k_vec_dot_type].from_float;
+    ggml_vec_dot_t    const kq_vec_dot     = type_traits[k->type].vec_dot;
+    ggml_to_float_t   const v_to_float     = type_traits[v->type].to_float;
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
        // q indices
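
Note: these four function pointers are what let flash_attn_ext consume F16 or quantized K/V: Q is converted once into K's vec_dot companion type and every K·Q dot product runs through the per-type kernel. A hedged sketch of the dispatch idea with simplified signatures (the real table is ggml's type_traits[]; the struct below is ours, not ggml API):

    /* simplified stand-ins for ggml_from_float_t / ggml_vec_dot_t / ggml_to_float_t */
    typedef void (*from_float_fn)(const float * x, void * y, int n);
    typedef void (*vec_dot_fn)(int n, float * s, const void * x, const void * y);
    typedef void (*to_float_fn)(const void * x, float * y, int n);

    struct kv_kernels {
        from_float_fn q_convert;  /* FP32 Q row -> K's vec_dot companion type */
        vec_dot_fn    kq_dot;     /* dot(K row, converted Q) in K's storage type */
        to_float_fn   v_dequant;  /* one V row -> FP32 scratch before the MAD */
    };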
@@ -15945,17 +15742,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);

-        const uint32_t h = iq2; // head
+        const uint32_t h = iq2; // head index
         const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;

-        float S = 0.0f;
-        float M = -INFINITY;
+        float S = 0.0f;      // sum
+        float M = -INFINITY; // maximum KQ value

-        float       * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
-        ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
-        ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
+        float       * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   = (VKQ32 + 1*D); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16

-        memset(V16, 0, D*sizeof(ggml_fp16_t));
+        if (v->type == GGML_TYPE_F16) {
+            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+        } else {
+            memset(VKQ32, 0, D*sizeof(float));
+        }

         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;

@@ -15967,6 +15769,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iv3 = iq3 / rv3;
         const int iv2 = iq2 / rv2;

+        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+        q_to_vec_dot(pq, Q_q, D);
+
         // online softmax / attention
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
@@ -15976,52 +15781,67 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 continue;
             }

-            float s;
+            float s; // KQ value

-            // convert Q to F16 in V32
-            {
-                const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
+            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);

-                for (int64_t d = 0; d < D; ++d) {
-                    Q16[d] = GGML_FP32_TO_FP16(pq[d]);
-                }
-            }
+            s = s*scale + mv; // scale KQ value and apply mask

-            ggml_vec_dot_f16(D,
-                    &s, 0,
-                    (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                    Q16, 0, 1);
+            const float Mold = M;

-            s = s*scale + mv;
+            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+            float vs = 1.0f; // post-softmax KQ value, expf(s - M)

-            const float Mold = M;
+            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));

-            float ms = 1.0f;
-            float vs = 1.0f;
+            if (v->type== GGML_TYPE_F16) {
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);

-            if (s > M) {
-                M = s;
-                ms = expf(Mold - M);
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f16(D, VKQ16, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }

-                // V = V*expf(Mold - M)
-                ggml_vec_scale_f16(D, V16, ms);
+                // V += v*expf(s - M)
+                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
             } else {
-                vs = expf(s - M);
-            }
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);

-            const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f32(D, VKQ32, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }

-            // V += v*expf(s - M)
-            ggml_vec_mad_f16(D, V16, v16, vs);
+                v_to_float(v_data, V32, D);

-            S = S*ms + vs;
+                // V += v*expf(s - M)
+                ggml_vec_mad_f32(D, VKQ32, V32, vs);
+            }
+
+            S = S*ms + vs; // scale and increment sum with partial sum
         }

-        // V /= S
-        for (int64_t d = 0; d < D; ++d) {
-            V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
+        if (v->type == GGML_TYPE_F16) {
+            for (int64_t d = 0; d < D; ++d) {
+                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
+            }
         }

+        // V /= S
+        const float S_inv = 1.0f/S;
+        ggml_vec_scale_f32(D, VKQ32, S_inv);
+
         // dst indices
         const int i1 = iq1;
         const int i2 = iq2;
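
Note: the rewritten loop above is the streaming ("online") softmax from the referenced paper: a running max M and running denominator S are maintained, and the V accumulator is rescaled whenever a larger score arrives, so no KV-length score buffer is needed. A hedged, self-contained sketch of one update step (function name ours):

    #include <math.h>

    /* one online-softmax accumulation step: score s, value row v[D] */
    static void online_softmax_step(float s, float * M, float * S,
                                    float * VKQ, const float * v, int D) {
        float ms = 1.0f; /* rescale factor for previous accumulation */
        float vs = 1.0f; /* softmax numerator weight of this row     */
        if (s > *M) {
            ms = expf(*M - s); /* demote everything accumulated so far */
            *M = s;
        } else {
            vs = expf(s - *M);
        }
        for (int d = 0; d < D; ++d) {
            VKQ[d] = VKQ[d]*ms + v[d]*vs;
        }
        *S = (*S)*ms + vs; /* denominator follows the same recurrence */
    }
    /* after the loop, dividing VKQ by S yields softmax(QK^T)·V for this row */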
@@ -16031,7 +15851,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));

         // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
+        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
     }
 }

@@ -16056,165 +15876,6 @@ static void ggml_compute_forward_flash_attn_ext(
     }
 }

-// ggml_compute_forward_flash_ff
-
-static void ggml_compute_forward_flash_ff_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * a  = dst->src[0]; // F16
-    const struct ggml_tensor * b0 = dst->src[1]; // F16 fc_w
-    const struct ggml_tensor * b1 = dst->src[2]; // F32 fc_b
-    const struct ggml_tensor * c0 = dst->src[3]; // F16 proj_w
-    const struct ggml_tensor * c1 = dst->src[4]; // F32 proj_b
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_LOCALS(int64_t, nea, a, ne)
-    GGML_TENSOR_LOCALS(size_t, nba, a, nb)
-    GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne)
-    GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb)
-    GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne)
-    GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb)
-    GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne)
-    GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb)
-    GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne)
-    GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb)
-    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
-    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = nea0;
-    //const int64_t N = nea1;
-    const int64_t M = neb01;
-
-    GGML_ASSERT(ne0 == nea0);
-    GGML_ASSERT(ne1 == nea1);
-    GGML_ASSERT(ne2 == nea2);
-
-    GGML_ASSERT(nba0  == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbb10 == sizeof(float));
-    GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbc10 == sizeof(float));
-
-    GGML_ASSERT(neb00 == D);
-    GGML_ASSERT(neb01 == M);
-    GGML_ASSERT(neb10 == M);
-    GGML_ASSERT(neb11 == 1);
-
-    GGML_ASSERT(nec00 == M);
-    GGML_ASSERT(nec01 == D);
-    GGML_ASSERT(nec10 == D);
-    GGML_ASSERT(nec11 == 1);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    // parallelize by a rows using ggml_vec_dot_f32
-
-    // total rows in a
-    const int nr = nea1*nea2*nea3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // a indices
-        const int ia3 = ir/(nea2*nea1);
-        const int ia2 = (ir - ia3*nea2*nea1)/nea1;
-        const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
-
-        float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
-
-        for (int64_t ic = 0; ic < neb01; ++ic) {
-            // b0 indices
-            const int ib03 = ia3;
-            const int ib02 = ia2;
-            const int ib01 = ic;
-
-            // S indices
-            const int i1 = ib01;
-
-            ggml_vec_dot_f16(nea0,
-                    S + i1, 0,
-                    (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
-                    (ggml_fp16_t *) ((char *)  a->data + ( ia1*nba1  +  ia2*nba2  +  ia3*nba3)), 0, 1);
-        }
-
-        ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
-        //ggml_vec_gelu_f32(neb01, S, S);
-
-        ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
-
-        for (int64_t i = 0; i < M; i++) {
-            S16[i] = GGML_FP32_TO_FP16(S[i]);
-        }
-
-        ggml_vec_gelu_f16(neb01, S16, S16);
-
-        {
-            // dst indices
-            const int i1 = ia1;
-            const int i2 = ia2;
-            const int i3 = ia3;
-
-            for (int64_t ic = 0; ic < nec01; ++ic) {
-
-                ggml_vec_dot_f16(neb01,
-                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
-                        (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
-                        S16, 0, 1);
-            }
-
-            ggml_vec_add_f32(nec01,
-                    (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
-                    (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
-                    (float *) c1->data);
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_ff(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * b0 = dst->src[1];
-
-    switch (b0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_flash_ff_f16(params, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                GGML_ASSERT(false); // TODO
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
 // ggml_compute_forward_flash_attn_back

 static void ggml_compute_forward_flash_attn_back_f32(
@@ -17785,21 +17446,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_leaky_relu(params, tensor);
             } break;
-        case GGML_OP_FLASH_ATTN:
-            {
-                const int32_t t = ggml_get_op_params_i32(tensor, 0);
-                GGML_ASSERT(t == 0 || t == 1);
-                const bool masked = t != 0;
-                ggml_compute_forward_flash_attn(params, masked, tensor);
-            } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
             } break;
-        case GGML_OP_FLASH_FF:
-            {
-                ggml_compute_forward_flash_ff(params, tensor);
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
                 int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -18169,6 +17819,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg
 static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
     struct ggml_tensor * src0 = tensor->src[0];
     struct ggml_tensor * src1 = tensor->src[1];
+    struct ggml_tensor * src2 = tensor->src[2];

     switch (tensor->op) {
         case GGML_OP_DUP:
@@ -18700,6 +18351,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 ggml_rope_back(ctx,
                         tensor->grad,
                         src1,
+                        src2,
                         n_dims,
                         mode,
                         n_ctx,
@@ -18739,6 +18391,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 ggml_rope_impl(ctx,
                         tensor->grad,
                         src1,
+                        src2,
                         n_dims,
                         mode,
                         n_ctx,
@@ -18803,7 +18456,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_FLASH_ATTN:
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 struct ggml_tensor * flash_grad = NULL;
@@ -18820,7 +18472,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             masked);
                 }

-                struct ggml_tensor * src2 = tensor->src[2];
                 const int64_t elem_q = ggml_nelements(src0);
                 const int64_t elem_k = ggml_nelements(src1);
                 const int64_t elem_v = ggml_nelements(src2);
@@ -18858,10 +18509,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         zero_table);
                 }
             } break;
-        case GGML_OP_FLASH_FF:
-            {
-                GGML_ASSERT(false); // not supported
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
                 GGML_ASSERT(false); // not supported
@@ -19548,15 +19195,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
             {
                 n_tasks = n_threads;
             } break;
-        case GGML_OP_FLASH_ATTN:
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 n_tasks = n_threads;
             } break;
-        case GGML_OP_FLASH_FF:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
                 n_tasks = n_threads;
@@ -19953,39 +19595,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
                     cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
                 } break;
-            case GGML_OP_FLASH_ATTN:
-                {
-                    const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
-
-                    if (node->src[1]->type == GGML_TYPE_F32) {
-                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_F16) {
-                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_BF16) {
-                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
-                    }
-                } break;
             case GGML_OP_FLASH_ATTN_EXT:
                 {
                     const int64_t ne00 = node->src[0]->ne[0]; // D

-                    cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size/thread
-                } break;
-            case GGML_OP_FLASH_FF:
-                {
-                    if (node->src[1]->type == GGML_TYPE_F32) {
-                        cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_F16) {
-                        cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_BF16) {
-                        cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
-                    }
+                    cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
                 } break;
             case GGML_OP_FLASH_ATTN_BACK:
                 {
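
Note: the work-size bump from 2x to 3x head size per thread matches the three D-float scratch regions set up earlier (VKQ32 accumulator, the V32/VKQ16 temporary, and the converted-Q buffer Q_q). A hedged sketch of the budget this line computes (function name ours, not ggml API):

    #include <stddef.h>
    #include <stdint.h>

    /* per-op scratch bytes for GGML_OP_FLASH_ATTN_EXT, mirroring the
       "3x head size/thread" line above */
    static size_t flash_attn_ext_wsize(int64_t D /* head size */, int n_tasks) {
        return 3*sizeof(float)*(size_t)D*(size_t)n_tasks;
    }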
@@ -21827,11 +21441,7 @@ size_t ggml_quantize_chunk(
         case GGML_TYPE_IQ1_S:   result = quantize_iq1_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ1_M:   result = quantize_iq1_m  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ4_NL:  result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-#if QK_K == 64
-        case GGML_TYPE_IQ4_XS:  result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-#else
         case GGML_TYPE_IQ4_XS:  result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-#endif
         case GGML_TYPE_F16:
             {
                 size_t elemsize = sizeof(ggml_fp16_t);
@@ -23108,6 +22718,14 @@ int ggml_cpu_has_avx512_vnni(void) {
 #endif
 }

+int ggml_cpu_has_avx512_bf16(void) {
+#if defined(__AVX512BF16__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_fma(void) {
 #if defined(__FMA__)
     return 1;