llama_cpp 0.15.2 → 0.15.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/ext/llama_cpp/llama_cpp.cpp +49 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -0
- data/vendor/tmp/llama.cpp/Makefile +6 -17
- data/vendor/tmp/llama.cpp/ggml-common.h +0 -54
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +72 -30
- data/vendor/tmp/llama.cpp/ggml-cuda.h +1 -0
- data/vendor/tmp/llama.cpp/ggml-impl.h +40 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +4 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +68 -70
- data/vendor/tmp/llama.cpp/ggml-metal.metal +24 -409
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-quants.c +1879 -2450
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +176 -53
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +40 -500
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +9351 -5627
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +202 -225
- data/vendor/tmp/llama.cpp/ggml.c +376 -758
- data/vendor/tmp/llama.cpp/ggml.h +39 -27
- data/vendor/tmp/llama.cpp/llama.cpp +823 -593
- data/vendor/tmp/llama.cpp/llama.h +10 -3
- metadata +3 -3
data/vendor/tmp/llama.cpp/ggml.c
CHANGED
@@ -406,10 +406,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
     int i = 0;
 #if defined(__AVX512BF16__)
     for (; i + 32 <= n; i += 32) {
-
-        (
-        (
-
+        _mm512_storeu_si512(
+            (__m512i *)(y + i),
+            m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
+                                      _mm512_loadu_ps(x + i))));
     }
 #endif
     for (; i < n; i++) {
@@ -871,22 +871,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
     },
     [GGML_TYPE_IQ4_XS] = {
         .type_name = "iq4_xs",
-#if QK_K == 64
-        .blck_size = QK4_NL,
-#else
         .blck_size = QK_K,
-#endif
         .type_size = sizeof(block_iq4_xs),
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
         .from_float = quantize_row_iq4_xs,
         .from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference,
         .vec_dot = ggml_vec_dot_iq4_xs_q8_K,
-#if QK_K == 64
-        .vec_dot_type = GGML_TYPE_Q8_0,
-#else
         .vec_dot_type = GGML_TYPE_Q8_K,
-#endif
         .nrows = 1,
     },
     [GGML_TYPE_Q8_K] = {
@@ -1523,6 +1515,195 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
 #define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL
 #define GGML_F16_VEC_REDUCE         GGML_F32Cx4_REDUCE
 
+#elif defined(__loongarch_asx)
+
+#define GGML_SIMD
+
+// F32 LASX
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  8
+
+#define GGML_F32x8         __m256
+#define GGML_F32x8_ZERO    (__m256)__lasx_xvldi(0)
+#define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
+#define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
+#define GGML_F32x8_STORE(x,y)   __lasx_xvst((y), (x), 0)
+#define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
+#define GGML_F32x8_ADD     __lasx_xvfadd_s
+#define GGML_F32x8_MUL     __lasx_xvfmul_s
+#define GGML_F32x8_REDUCE(res, x)                        \
+do {                                                     \
+    int offset = GGML_F32_ARR >> 1;                      \
+    for (int i = 0; i < offset; ++i) {                   \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);       \
+    }                                                    \
+    offset >>= 1;                                        \
+    for (int i = 0; i < offset; ++i) {                   \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);       \
+    }                                                    \
+    offset >>= 1;                                        \
+    for (int i = 0; i < offset; ++i) {                   \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);       \
+    }                                                    \
+    float *tmp_p = (float *)&x[0];                       \
+    res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \
+} while (0)
+// TODO: is this optimal ?
+
+#define GGML_F32_VEC        GGML_F32x8
+#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x8_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x8_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x8_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x8_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x8_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
+
+// F16 LASX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  8
+
+// F16 arithmetic is not supported by AVX, so we use F32 instead
+
+#define GGML_F32Cx8          __m256
+#define GGML_F32Cx8_ZERO     (__m256)__lasx_xvldi(0)
+#define GGML_F32Cx8_SET1(x)  (__m256)__lasx_xvreplgr2vr_w((x))
+
+static inline __m256 __lasx_f32cx8_load(ggml_fp16_t *x) {
+    float tmp[8];
+
+    for (int i = 0; i < 8; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+
+    return (__m256)__lasx_xvld(tmp, 0);
+}
+static inline void __lasx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
+    float arr[8];
+
+    __lasx_xvst(y, arr, 0);
+
+    for (int i = 0; i < 8; i++)
+        x[i] = GGML_FP32_TO_FP16(arr[i]);
+}
+#define GGML_F32Cx8_LOAD(x)     __lasx_f32cx8_load(x)
+#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
+
+#define GGML_F32Cx8_FMA         GGML_F32x8_FMA
+#define GGML_F32Cx8_ADD         __lasx_xvfadd_s
+#define GGML_F32Cx8_MUL         __lasx_xvfmul_s
+#define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE
+
+#define GGML_F16_VEC                GGML_F32Cx8
+#define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO
+#define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA
+#define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD
+#define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE
+
+#elif defined(__loongarch_sx)
+
+#define GGML_SIMD
+
+// F32 LSX
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4         __m128
+#define GGML_F32x4_ZERO    __lsx_vldi(0)
+#define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
+#define GGML_F32x4_STORE((x),(y))   __lsx_vst((y), (x), 0)
+#define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
+#define GGML_F32x4_ADD     __lsx_vfadd_s
+#define GGML_F32x4_MUL     __lsx_vfmul_s
+#define GGML_F32x4_REDUCE(res, x)                        \
+{                                                        \
+    int offset = GGML_F32_ARR >> 1;                      \
+    for (int i = 0; i < offset; ++i) {                   \
+        x[i] = __lsx_vfadd_s(x[i], x[offset+i]);         \
+    }                                                    \
+    offset >>= 1;                                        \
+    for (int i = 0; i < offset; ++i) {                   \
+        x[i] = __lsx_vfadd_s(x[i], x[offset+i]);         \
+    }                                                    \
+    offset >>= 1;                                        \
+    for (int i = 0; i < offset; ++i) {                   \
+        x[i] = __lsx_vfadd_s(x[i], x[offset+i]);         \
+    }                                                    \
+    __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32);      \
+    tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]);     \
+    tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp);           \
+    const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88);        \
+    tmp = __lsx_vsrli_d((__m128i)t0, 32);                \
+    tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0);       \
+    tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp);           \
+    res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 LSX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  4
+
+static inline __m128 __lsx_f16x4_load(ggml_fp16_t *x) {
+    float tmp[4];
+
+    tmp[0] = GGML_FP16_TO_FP32(x[0]);
+    tmp[1] = GGML_FP16_TO_FP32(x[1]);
+    tmp[2] = GGML_FP16_TO_FP32(x[2]);
+    tmp[3] = GGML_FP16_TO_FP32(x[3]);
+
+    return __lsx_vld(tmp, 0);
+}
+
+static inline void __lsx_f16x4_store(ggml_fp16_t *x, __m128 y) {
+    float arr[4];
+
+    __lsx_vst(y, arr, 0);
+
+    x[0] = GGML_FP32_TO_FP16(arr[0]);
+    x[1] = GGML_FP32_TO_FP16(arr[1]);
+    x[2] = GGML_FP32_TO_FP16(arr[2]);
+    x[3] = GGML_FP32_TO_FP16(arr[3]);
+}
+
+#define GGML_F32Cx4          __m128
+#define GGML_F32Cx4_ZERO     __lsx_vldi(0)
+#define GGML_F32Cx4_SET1(x)  __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32Cx4_LOAD(x)  __lsx_f16x4_load(x)
+#define GGML_F32Cx4_STORE(x, y)  __lsx_f16x4_store(x, y)
+#define GGML_F32Cx4_FMA GGML_F32x4_FMA
+#define GGML_F32Cx4_ADD __lsx_vfadd_s
+#define GGML_F32Cx4_MUL __lsx_vfmul_s
+#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
+
+#define GGML_F16_VEC                 GGML_F32Cx4
+#define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO
+#define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1
+#define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA
+#define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD
+#define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL
+#define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE
+
 #endif
 
 // GGML_F32_ARR / GGML_F16_ARR
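The new LoongArch paths follow the same shape as the existing x86/ARM blocks: GGML_F32_STEP elements are processed per step across GGML_F32_ARR (= GGML_F32_STEP/GGML_F32_EPR) partial accumulators, and GGML_F32_VEC_REDUCE folds those accumulators pairwise before a final horizontal add. Below is a minimal, portable C sketch of that reduction pattern, with plain arrays standing in for the LASX vectors; the accumulator count of 4 follows from the macros above (STEP 32 / EPR 8), and the function name is illustrative only.

```c
// Portable sketch of the pattern GGML_F32x8_REDUCE implements: fold the partial
// accumulators pairwise, then sum the lanes of acc[0]. Not LASX code.
#define N_ACC   4   // stands in for GGML_F32_ARR
#define N_LANES 8   // stands in for GGML_F32_EPR

static float reduce_accumulators(float acc[N_ACC][N_LANES]) {
    for (int offset = N_ACC >> 1; offset > 0; offset >>= 1) {
        for (int i = 0; i < offset; ++i) {
            for (int l = 0; l < N_LANES; ++l) {
                acc[i][l] += acc[offset + i][l];   // __lasx_xvfadd_s in the real macro
            }
        }
    }
    float res = 0.0f;
    for (int l = 0; l < N_LANES; ++l) {
        res += acc[0][l];                          // final horizontal add
    }
    return res;
}
```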
@@ -1666,10 +1847,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
     __m512 c1 = _mm512_setzero_ps();
     __m512 c2 = _mm512_setzero_ps();
     for (; i + 64 <= n; i += 64) {
-        c1 = _mm512_dpbf16_ps(c1, (
-
-        c2 = _mm512_dpbf16_ps(c2, (
-
+        c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
+                                  m512bh(_mm512_loadu_si512((y + i))));
+        c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
+                                  m512bh(_mm512_loadu_si512((y + i + 32))));
     }
     sumf += (ggml_float)_mm512_reduce_add_ps(c1);
     sumf += (ggml_float)_mm512_reduce_add_ps(c2);
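The rewritten loop feeds two pairs of bf16 registers per iteration into _mm512_dpbf16_ps, which multiplies bf16 lanes and accumulates in FP32. For reference, here is a scalar sketch of the same dot product using the standard bf16-to-float widening (bf16 is the upper 16 bits of an IEEE-754 float); it uses raw uint16_t instead of ggml_bf16_t and is not part of ggml.

```c
#include <stdint.h>
#include <string.h>

// bf16 -> f32 widening: place the 16 payload bits in the upper half of a float.
static float bf16_to_f32(uint16_t h) {
    uint32_t bits = (uint32_t) h << 16;
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}

// Scalar reference of the dot product the AVX512-BF16 loop computes.
static float dot_bf16_scalar(const uint16_t * x, const uint16_t * y, int64_t n) {
    float sum = 0.0f;
    for (int64_t i = 0; i < n; ++i) {
        sum += bf16_to_f32(x[i]) * bf16_to_f32(y[i]);
    }
    return sum;
}
```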
@@ -2076,7 +2257,7 @@ inline static float ggml_silu_f32(float x) {
     return x/(1.0f + expf(-x));
 }
 
-#if defined(__ARM_NEON)
+#if defined(__ARM_NEON) && defined(__aarch64__)
 
 // adapted from arm limited optimized routine
 // the maximum error is 1.45358 plus 0.5 ulps
@@ -2288,7 +2469,7 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
     for (; i + 3 < n; i += 4) {
         _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
     }
-#elif defined(__ARM_NEON)
+#elif defined(__ARM_NEON) && defined(__aarch64__)
     for (; i + 3 < n; i += 4) {
         vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
     }
@@ -2335,7 +2516,7 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x,
 #endif
         sum += (ggml_float)_mm_cvtss_f32(val);
     }
-#elif defined(__ARM_NEON)
+#elif defined(__ARM_NEON) && defined(__aarch64__)
     for (; i + 3 < n; i += 4) {
         float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
                                                 vdupq_n_f32(max)));
@@ -2489,9 +2670,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ARGSORT",
     "LEAKY_RELU",
 
-    "FLASH_ATTN",
     "FLASH_ATTN_EXT",
-    "FLASH_FF",
     "FLASH_ATTN_BACK",
     "SSM_CONV",
     "SSM_SCAN",
@@ -2517,7 +2696,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT ==
+static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2579,9 +2758,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "argsort(x)",
     "leaky_relu(x)",
 
-    "flash_attn(x)",
     "flash_attn_ext(x)",
-    "flash_ff(x)",
     "flash_attn_back(x)",
     "ssm_conv(x)",
     "ssm_scan(x)",
@@ -2607,7 +2784,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT ==
+static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
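With GGML_OP_FLASH_ATTN and GGML_OP_FLASH_FF gone, both string tables shrink by two entries and the paired static_asserts are updated to the new count of 74; the assert is what keeps the tables in lock-step with the enum. ggml hard-codes the expected count; a sizeof-based equivalent of the same idiom is sketched below with hypothetical names.

```c
// Compile-time check that an op-name table matches its enum (illustrative names).
enum my_op { MY_OP_NONE, MY_OP_ADD, MY_OP_MUL, MY_OP_COUNT };

static const char * MY_OP_NAME[MY_OP_COUNT] = { "NONE", "ADD", "MUL" };

_Static_assert(sizeof(MY_OP_NAME)/sizeof(MY_OP_NAME[0]) == MY_OP_COUNT,
               "MY_OP_NAME out of sync with my_op");
```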
@@ -6042,6 +6219,7 @@ static struct ggml_tensor * ggml_rope_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
@@ -6055,10 +6233,17 @@ static struct ggml_tensor * ggml_rope_impl(
         float                 xpos_base,
         bool                  xpos_down,
         bool                  inplace) {
+    GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
+
     GGML_ASSERT(ggml_is_vector(b));
     GGML_ASSERT(b->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);
 
+    if (c) {
+        GGML_ASSERT(c->type == GGML_TYPE_F32);
+        GGML_ASSERT(c->ne[0] >= n_dims / 2);
+    }
+
     bool is_node = false;
 
     if (a->grad) {
@@ -6082,6 +6267,7 @@ static struct ggml_tensor * ggml_rope_impl(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
+    result->src[2] = c;
 
     return result;
 }
@@ -6094,7 +6280,7 @@ struct ggml_tensor * ggml_rope(
         int                   mode,
         int                   n_ctx) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
+        ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
     );
 }
 
@@ -6106,14 +6292,15 @@ struct ggml_tensor * ggml_rope_inplace(
         int                   mode,
         int                   n_ctx) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+        ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
     );
 }
 
-struct ggml_tensor *
+struct ggml_tensor * ggml_rope_ext(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
@@ -6125,15 +6312,16 @@ struct ggml_tensor * ggml_rope_custom(
         float                 beta_fast,
         float                 beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
     );
 }
 
-struct ggml_tensor *
+struct ggml_tensor * ggml_rope_ext_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
@@ -6145,19 +6333,49 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         float                 beta_fast,
         float                 beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
     );
 }
 
-struct ggml_tensor *
+struct ggml_tensor * ggml_rope_custom(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   n_dims,
+        int                   mode,
+        int                   n_ctx,
+        int                   n_orig_ctx,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+    );
+}
+
+struct ggml_tensor * ggml_rope_custom_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         int                   n_dims,
-
-
-
+        int                   mode,
+        int                   n_ctx,
+        int                   n_orig_ctx,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
+    );
 }
 
 // ggml_rope_back
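ggml_rope_ext()/ggml_rope_ext_inplace() are the new frequency-factor-aware entry points; ggml_rope_custom()/ggml_rope_custom_inplace() remain as thin wrappers that pass NULL for the factors. A hedged usage sketch of the new API follows; the tensor shapes and all numeric parameters are placeholders, not recommended values, and only the argument order shown in the hunks above is relied on.

```c
#include "ggml.h"

// Build a RoPE node with optional per-dimension frequency factors (may be NULL).
struct ggml_tensor * rope_with_factors(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,          // activations to rotate
        struct ggml_tensor  * positions,    // I32 vector, one position per row
        struct ggml_tensor  * freq_factors) // F32, ne[0] >= n_dims/2, or NULL
{
    const int n_dims     = 128;   // placeholder head size
    const int mode       = 2;     // NEOX-style rotation; the factors path requires it
    const int n_ctx      = 0;
    const int n_orig_ctx = 4096;  // placeholder original training context

    return ggml_rope_ext(ctx, cur, positions, freq_factors,
                         n_dims, mode, n_ctx, n_orig_ctx,
                         /*freq_base   =*/ 10000.0f,
                         /*freq_scale  =*/ 1.0f,
                         /*ext_factor  =*/ 0.0f,
                         /*attn_factor =*/ 1.0f,
                         /*beta_fast   =*/ 32.0f,
                         /*beta_slow   =*/ 1.0f);
}
```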
@@ -6166,6 +6384,7 @@ struct ggml_tensor * ggml_rope_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
@@ -6181,6 +6400,7 @@ struct ggml_tensor * ggml_rope_back(
     GGML_ASSERT(ggml_is_vector(b));
     GGML_ASSERT(b->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);
+    GGML_ASSERT(c == NULL && "freq factors not implemented yet");
 
     GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
 
@@ -6724,38 +6944,6 @@ struct ggml_tensor * ggml_top_k(
     return result;
 }
 
-// ggml_flash_attn
-
-struct ggml_tensor * ggml_flash_attn(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * q,
-        struct ggml_tensor  * k,
-        struct ggml_tensor  * v,
-        bool                  masked) {
-    GGML_ASSERT(ggml_can_mul_mat(k, q));
-    // TODO: check if vT can be multiplied by (k*qT)
-
-    bool is_node = false;
-
-    if (q->grad || k->grad || v->grad) {
-        is_node = true;
-    }
-
-    //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne);
-
-    int32_t t = masked ? 1 : 0;
-    ggml_set_op_params(result, &t, sizeof(t));
-
-    result->op   = GGML_OP_FLASH_ATTN;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = q;
-    result->src[1] = k;
-    result->src[2] = v;
-
-    return result;
-}
-
 // ggml_flash_attn_ext
 
 struct ggml_tensor * ggml_flash_attn_ext(
@@ -6815,38 +7003,6 @@ void ggml_flash_attn_ext_set_prec(
     ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
 }
 
-// ggml_flash_ff
-
-struct ggml_tensor * ggml_flash_ff(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b0,
-        struct ggml_tensor  * b1,
-        struct ggml_tensor  * c0,
-        struct ggml_tensor  * c1) {
-    GGML_ASSERT(ggml_can_mul_mat(b0, a));
-    // TODO: more checks
-
-    bool is_node = false;
-
-    if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
-        is_node = true;
-    }
-
-    //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);
-
-    result->op   = GGML_OP_FLASH_FF;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-    result->src[1] = b0;
-    result->src[2] = b1;
-    result->src[3] = c0;
-    result->src[4] = c1;
-
-    return result;
-}
-
 // ggml_flash_attn_back
 
 struct ggml_tensor * ggml_flash_attn_back(
@@ -6856,6 +7012,8 @@ struct ggml_tensor * ggml_flash_attn_back(
         struct ggml_tensor  * v,
         struct ggml_tensor  * d,
         bool                  masked) {
+    GGML_ASSERT(false && "TODO: adapt to ggml_flash_attn_ext() changes");
+
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
 
@@ -14115,6 +14273,7 @@ static void ggml_compute_forward_rope_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src2 = dst->src[2];
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
@@ -14174,6 +14333,17 @@ static void ggml_compute_forward_rope_f32(
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    const float * freq_factors = NULL;
+    if (is_neox) {
+        if (src2 != NULL) {
+            GGML_ASSERT(src2->type == GGML_TYPE_F32);
+            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+            freq_factors = (const float *) src2->data;
+        }
+    } else {
+        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+    }
+
     // backward process uses inverse rotation by cos and sin.
     // cos and sin build a rotation matrix, where the inverse is the transpose.
     // this essentially just switches the sign of sin.
@@ -14250,10 +14420,11 @@ static void ggml_compute_forward_rope_f32(
 
                     // simplified from `(ib * n_dims + ic) * inv_ndims`
                     float cur_rot = inv_ndims * ic - ib;
+                    float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
 
                     float cos_theta, sin_theta;
                     rope_yarn(
-                        theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                        theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
                         &cos_theta, &sin_theta
                     );
                     sin_theta *= sin_sign;
@@ -14286,6 +14457,7 @@ static void ggml_compute_forward_rope_f32(
     }
 }
 
+// TODO: deduplicate f16/f32 code
 static void ggml_compute_forward_rope_f16(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst,
@@ -14293,6 +14465,7 @@ static void ggml_compute_forward_rope_f16(
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src2 = dst->src[2];
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
@@ -14345,6 +14518,17 @@ static void ggml_compute_forward_rope_f16(
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    const float * freq_factors = NULL;
+    if (is_neox) {
+        if (src2 != NULL) {
+            GGML_ASSERT(src2->type == GGML_TYPE_F32);
+            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+            freq_factors = (const float *) src2->data;
+        }
+    } else {
+        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+    }
+
     // backward process uses inverse rotation by cos and sin.
     // cos and sin build a rotation matrix, where the inverse is the transpose.
     // this essentially just switches the sign of sin.
@@ -14417,10 +14601,11 @@ static void ggml_compute_forward_rope_f16(
 
                     // simplified from `(ib * n_dims + ic) * inv_ndims`
                     float cur_rot = inv_ndims * ic - ib;
+                    float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
 
                     float cos_theta, sin_theta;
                     rope_yarn(
-                        theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                        theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
                         &cos_theta, &sin_theta
                     );
                     sin_theta *= sin_sign;
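In both the F32 and F16 kernels the optional src2 tensor supplies one factor per rotated dimension pair, and the angle handed to rope_yarn() becomes theta_base/freq_factor. A scalar sketch of that per-pair angle computation follows; theta_scale (the geometric step between pairs implied by freq_base) is an assumption taken from the surrounding RoPE code rather than from the hunks shown here, and the helper name is illustrative.

```c
// Per-pair RoPE angle with optional frequency factors (illustrative helper).
static void compute_pair_angles(float theta_base, float theta_scale,
                                const float * freq_factors, // may be NULL
                                int n_dims, float * theta_out) {
    float theta = theta_base;
    for (int ic = 0; ic < n_dims; ic += 2) {
        const float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
        theta_out[ic/2] = theta / freq_factor;  // value passed on to rope_yarn()
        theta = theta * theta_scale;            // advance to the next pair
    }
}
```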
@@ -15458,17 +15643,15 @@ static void ggml_compute_forward_argsort(
     }
 }
 
-//
+// ggml_compute_forward_flash_attn_ext
 
-static void
+static void ggml_compute_forward_flash_attn_ext_f16(
         const struct ggml_compute_params * params,
-        const
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-    const struct ggml_tensor * k = dst->src[1];
-    const struct ggml_tensor * v = dst->src[2];
-
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
@@ -15486,409 +15669,18 @@ static void ggml_compute_forward_flash_attn_f32(
 
     const int64_t D = neq0;
     const int64_t N = neq1;
-    const int64_t P = nek1 - N;
-    const int64_t M = P + N;
-
-    const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
 
     GGML_ASSERT(ne0 == D);
-    GGML_ASSERT(
-    GGML_ASSERT(P >= 0);
+    GGML_ASSERT(ne2 == N);
 
-
-    GGML_ASSERT(
-    GGML_ASSERT(
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
 
     GGML_ASSERT(neq0 == D);
     GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(
-
-    GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nek1 == N + P);
-    GGML_ASSERT(nev1 == D);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    // parallelize by q rows using ggml_vec_dot_f32
-
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    const float scale = 1.0f/sqrtf(D);
-
-    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
-
-        float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32);
-
-        for (int i = M; i < Mup; ++i) {
-            S[i] = -INFINITY;
-        }
-
-        const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
-        for (int64_t ic = 0; ic < masked_begin; ++ic) {
-            // k indices
-            const int ik3 = iq3;
-            const int ik2 = iq2 % nek2;
-            const int ik1 = ic;
-
-            // S indices
-            const int i1 = ik1;
-
-            ggml_vec_dot_f32(neq0,
-                    S + i1, 0,
-                    (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                    (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
-        }
-
-        // scale
-        ggml_vec_scale_f32(masked_begin, S, scale);
-
-        for (int64_t i = masked_begin; i < M; i++) {
-            S[i] = -INFINITY;
-        }
-
-        // softmax
-        // exclude known -INF S[..] values from max and loop
-        // dont forget to set their SW values to zero
-        {
-            float max = -INFINITY;
-            ggml_vec_max_f32(masked_begin, &max, S);
-
-            ggml_float sum = 0.0;
-            {
-#ifdef GGML_SOFT_MAX_ACCELERATE
-                max = -max;
-                vDSP_vsadd(S, 1, &max, S, 1, Mup);
-                vvexpf(S, S, &Mup);
-                ggml_vec_sum_f32(Mup, &sum, S);
-#else
-                sum = ggml_vec_soft_max_f32(Mup, S, S, max);
-#endif
-            }
-
-            assert(sum > 0.0);
-
-            sum = 1.0/sum;
-            ggml_vec_scale_f32(masked_begin, S, sum);
-
-#ifndef NDEBUG
-            for (int i = 0; i < masked_begin; ++i) {
-                assert(!isnan(S[i]));
-                assert(!isinf(S[i]));
-            }
-#endif
-        }
-
-        for (int64_t ic = 0; ic < nev1; ++ic) {
-            // dst indices
-            const int i1 = iq1;
-            const int i2 = iq2;
-            const int i3 = iq3;
-
-            // v indices
-            const int iv2 = iq2 % nev2;
-            const int iv3 = iq3;
-
-            ggml_vec_dot_f32(masked_begin,
-                    (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
-                    (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
-                    S, 0, 1);
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_attn_f16(
-        const struct ggml_compute_params * params,
-        const bool masked,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-    const struct ggml_tensor * k = dst->src[1];
-    const struct ggml_tensor * v = dst->src[2];
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
-    GGML_TENSOR_LOCALS(size_t,  nbq, q, nb)
-    GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
-    GGML_TENSOR_LOCALS(size_t,  nbk, k, nb)
-    GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
-    GGML_TENSOR_LOCALS(size_t,  nbv, v, nb)
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = neq0;
-    const int64_t N = neq1;
-    const int64_t P = nek1 - N;
-    const int64_t M = P + N;
-
-    const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
-
-    GGML_ASSERT(ne0 == D);
-    GGML_ASSERT(ne1 == N);
-    GGML_ASSERT(P >= 0);
-
-    GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
-
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev1 == D);
-
-    GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nek1 == N + P);
-    GGML_ASSERT(nev1 == D);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    // parallelize by q rows using ggml_vec_dot_f32
-
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    const float scale = 1.0f/sqrtf(D);
-
-    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
-
-        float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
-
-        for (int i = M; i < Mup; ++i) {
-            S[i] = -INFINITY;
-        }
-
-        if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
-            for (int64_t ic = 0; ic < nek1; ++ic) {
-                // k indices
-                const int ik3 = iq3;
-                const int ik2 = iq2 % nek2;
-                const int ik1 = ic;
-
-                // S indices
-                const int i1 = ik1;
-
-                ggml_vec_dot_f16(neq0,
-                        S + i1, 0,
-                        (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
-            }
-        } else {
-            for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
-                // k indices
-                const int ik3 = iq3;
-                const int ik2 = iq2 % nek2;
-                const int ik1 = ic;
-
-                // S indices
-                const int i1 = ik1;
-
-                ggml_vec_dot_f16_unroll(neq0, nbk1,
-                        S + i1,
-                        ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
-            }
-        }
-
-        // scale
-        ggml_vec_scale_f32(nek1, S, scale);
-
-        if (masked) {
-            for (int64_t i = P; i < M; i++) {
-                if (i > P + iq1) {
-                    S[i] = -INFINITY;
-                }
-            }
-        }
-
-        // softmax
-        // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
-        // dont forget to set their S values to zero
-        {
-            float max = -INFINITY;
-            ggml_vec_max_f32(M, &max, S);
-
-            ggml_float sum = 0.0;
-            {
-#ifdef GGML_SOFT_MAX_ACCELERATE
-                max = -max;
-                vDSP_vsadd(S, 1, &max, S, 1, Mup);
-                vvexpf(S, S, &Mup);
-                ggml_vec_sum_f32(Mup, &sum, S);
-#else
-                sum = ggml_vec_soft_max_f32(Mup, S, S, max);
-#endif
-            }
-
-            assert(sum > 0.0);
-
-            sum = 1.0/sum;
-            ggml_vec_scale_f32(M, S, sum);
-
-#ifndef NDEBUG
-            for (int i = 0; i < M; ++i) {
-                assert(!isnan(S[i]));
-                assert(!isinf(S[i]));
-            }
-#endif
-        }
-
-        ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
-
-        for (int64_t i = 0; i < M; i++) {
-            S16[i] = GGML_FP32_TO_FP16(S[i]);
-        }
-
-        // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
-        if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
-            for (int64_t ic = 0; ic < nev1; ++ic) {
-                // dst indices
-                const int i1 = iq1;
-                const int i2 = iq2;
-                const int i3 = iq3;
-
-                // v indices
-                const int iv2 = iq2 % nev2;
-                const int iv3 = iq3;
-
-                ggml_vec_dot_f16(nev0,
-                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
-                        (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
-                        S16, 0, 1);
-            }
-        } else {
-            for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
-                // dst indices
-                const int i1 = iq1;
-                const int i2 = iq2;
-                const int i3 = iq3;
-
-                // v indices
-                const int iv2 = iq2 % nev2;
-                const int iv3 = iq3;
-
-                ggml_vec_dot_f16_unroll(nev0, nbv1,
-                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
-                        ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
-                        S16);
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_attn(
-        const struct ggml_compute_params * params,
-        const bool masked,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-
-    switch (q->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_flash_attn_f16(params, masked, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_flash_attn_f32(params, masked, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-// ggml_compute_forward_flash_attn_ext
-
-static void ggml_compute_forward_flash_attn_ext_f16(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * q,
-        const struct ggml_tensor * k,
-        const struct ggml_tensor * v,
-        const struct ggml_tensor * mask,
-        struct ggml_tensor * dst) {
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
-    GGML_TENSOR_LOCALS(size_t,  nbq, q, nb)
-    GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
-    GGML_TENSOR_LOCALS(size_t,  nbk, k, nb)
-    GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
-    GGML_TENSOR_LOCALS(size_t,  nbv, v, nb)
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = neq0;
-    const int64_t N = neq1;
-
-    GGML_ASSERT(ne0 == D);
-    GGML_ASSERT(ne2 == N);
-
-    GGML_ASSERT(nbq0 == sizeof(float));
-    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
-
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev0 == D);
+    GGML_ASSERT(nev0 == D);
 
     GGML_ASSERT(neq1 == N);
     GGML_ASSERT(nev0 == D);
@@ -15938,6 +15730,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    enum ggml_type    const k_vec_dot_type = type_traits[k->type].vec_dot_type;
+    ggml_from_float_t const q_to_vec_dot   = type_traits[k_vec_dot_type].from_float;
+    ggml_vec_dot_t    const kq_vec_dot     = type_traits[k->type].vec_dot;
+    ggml_to_float_t   const v_to_float     = type_traits[v->type].to_float;
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
@@ -15945,17 +15742,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
 
-        const uint32_t h = iq2; // head
+        const uint32_t h = iq2; // head index
         const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
 
-        float S = 0.0f;
-        float M = -INFINITY;
+        float S = 0.0f;      // sum
+        float M = -INFINITY; // maximum KQ value
 
-        float *
-
-        ggml_fp16_t *
+        float       * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   = (VKQ32 + 1*D); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
 
-
+        if (v->type == GGML_TYPE_F16) {
+            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+        } else {
+            memset(VKQ32, 0, D*sizeof(float));
+        }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
 
@@ -15967,6 +15769,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iv3 = iq3 / rv3;
         const int iv2 = iq2 / rv2;
 
+        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+        q_to_vec_dot(pq, Q_q, D);
+
         // online softmax / attention
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
@@ -15976,52 +15781,67 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 continue;
             }
 
-            float s;
+            float s; // KQ value
 
-
-
-            const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
+            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
 
-
-                    Q16[d] = GGML_FP32_TO_FP16(pq[d]);
-                }
-            }
+            s = s*scale + mv; // scale KQ value and apply mask
 
-
-                    &s, 0,
-                    (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                    Q16, 0, 1);
+            const float Mold = M;
 
-
+            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
 
-            const
+            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
 
-
-
+            if (v->type== GGML_TYPE_F16) {
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
 
-
-
-
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f16(D, VKQ16, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
 
-            // V
-
+                // V += v*expf(s - M)
+                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
             } else {
-
-
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
 
-
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f32(D, VKQ32, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
 
-
-                ggml_vec_mad_f16(D, V16, v16, vs);
+                v_to_float(v_data, V32, D);
 
-
+                // V += v*expf(s - M)
+                ggml_vec_mad_f32(D, VKQ32, V32, vs);
+            }
+
+            S = S*ms + vs; // scale and increment sum with partial sum
         }
 
-
-
-
+        if (v->type == GGML_TYPE_F16) {
+            for (int64_t d = 0; d < D; ++d) {
+                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
+            }
         }
 
+        // V /= S
+        const float S_inv = 1.0f/S;
+        ggml_vec_scale_f32(D, VKQ32, S_inv);
+
         // dst indices
         const int i1 = iq1;
         const int i2 = iq2;
@@ -16031,7 +15851,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
 
         // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1,
+        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
     }
 }
 
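The rewritten kernel keeps a running maximum M and running sum S per query row and rescales the accumulated V·softmax(KQ) whenever a larger score appears, which is the "online softmax" scheme cited in the hunk above. A scalar C sketch of that update rule, independent of ggml's types and vector helpers, is shown below; `attn_row_online` and its arguments are illustrative names, and the scores are assumed to already include the scale and mask terms.

```c
#include <math.h>

// One query row: single pass over the KV positions with a running max M,
// running sum S, and an accumulator that is rescaled when M increases.
static void attn_row_online(const float * scores,  // s = scale*qk + mask, per kv position
                            const float * v,       // V rows, laid out as [n_kv][d]
                            int n_kv, int d, float * out /* [d] */) {
    float M = -INFINITY;  // running maximum score
    float S = 0.0f;       // running sum of expf(s - M)
    for (int j = 0; j < d; ++j) out[j] = 0.0f;

    for (int i = 0; i < n_kv; ++i) {
        const float s  = scores[i];
        float ms = 1.0f;              // rescale factor for the accumulator
        float vs = 1.0f;              // expf(s - M) for this position
        if (s > M) {
            ms = expf(M - s);         // old contributions shrink by expf(Mold - Mnew)
            M  = s;
        } else {
            vs = expf(s - M);
        }
        for (int j = 0; j < d; ++j) {
            out[j] = out[j]*ms + v[i*d + j]*vs;   // V = V*ms + v*vs
        }
        S = S*ms + vs;
    }
    for (int j = 0; j < d; ++j) {
        out[j] /= S;                  // final normalization, V /= S
    }
}
```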
@@ -16056,165 +15876,6 @@ static void ggml_compute_forward_flash_attn_ext(
     }
 }
 
-// ggml_compute_forward_flash_ff
-
-static void ggml_compute_forward_flash_ff_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * a  = dst->src[0]; // F16
-    const struct ggml_tensor * b0 = dst->src[1]; // F16 fc_w
-    const struct ggml_tensor * b1 = dst->src[2]; // F32 fc_b
-    const struct ggml_tensor * c0 = dst->src[3]; // F16 proj_w
-    const struct ggml_tensor * c1 = dst->src[4]; // F32 proj_b
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_LOCALS(int64_t, nea,  a,  ne)
-    GGML_TENSOR_LOCALS(size_t,  nba,  a,  nb)
-    GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne)
-    GGML_TENSOR_LOCALS(size_t,  nbb0, b0, nb)
-    GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne)
-    GGML_TENSOR_LOCALS(size_t,  nbb1, b1, nb)
-    GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne)
-    GGML_TENSOR_LOCALS(size_t,  nbc0, c0, nb)
-    GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne)
-    GGML_TENSOR_LOCALS(size_t,  nbc1, c1, nb)
-    GGML_TENSOR_LOCALS(int64_t, ne,   dst, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb,   dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = nea0;
-    //const int64_t N = nea1;
-    const int64_t M = neb01;
-
-    GGML_ASSERT(ne0 == nea0);
-    GGML_ASSERT(ne1 == nea1);
-    GGML_ASSERT(ne2 == nea2);
-
-    GGML_ASSERT(nba0  == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbb10 == sizeof(float));
-    GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbc10 == sizeof(float));
-
-    GGML_ASSERT(neb00 == D);
-    GGML_ASSERT(neb01 == M);
-    GGML_ASSERT(neb10 == M);
-    GGML_ASSERT(neb11 == 1);
-
-    GGML_ASSERT(nec00 == M);
-    GGML_ASSERT(nec01 == D);
-    GGML_ASSERT(nec10 == D);
-    GGML_ASSERT(nec11 == 1);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    // parallelize by a rows using ggml_vec_dot_f32
-
-    // total rows in a
-    const int nr = nea1*nea2*nea3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // a indices
-        const int ia3 = ir/(nea2*nea1);
-        const int ia2 = (ir - ia3*nea2*nea1)/nea1;
-        const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
-
-        float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
-
-        for (int64_t ic = 0; ic < neb01; ++ic) {
-            // b0 indices
-            const int ib03 = ia3;
-            const int ib02 = ia2;
-            const int ib01 = ic;
-
-            // S indices
-            const int i1 = ib01;
-
-            ggml_vec_dot_f16(nea0,
-                    S + i1, 0,
-                    (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
-                    (ggml_fp16_t *) ((char *) a->data  + ( ia1*nba1  +  ia2*nba2  +  ia3*nba3)), 0, 1);
-        }
-
-        ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
-        //ggml_vec_gelu_f32(neb01, S, S);
-
-        ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
-
-        for (int64_t i = 0; i < M; i++) {
-            S16[i] = GGML_FP32_TO_FP16(S[i]);
-        }
-
-        ggml_vec_gelu_f16(neb01, S16, S16);
-
-        {
-            // dst indices
-            const int i1 = ia1;
-            const int i2 = ia2;
-            const int i3 = ia3;
-
-            for (int64_t ic = 0; ic < nec01; ++ic) {
-
-                ggml_vec_dot_f16(neb01,
-                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
-                        (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
-                        S16, 0, 1);
-            }
-
-            ggml_vec_add_f32(nec01,
-                    (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
-                    (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
-                    (float *) c1->data);
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_ff(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * b0 = dst->src[1];
-
-    switch (b0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_flash_ff_f16(params, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                GGML_ASSERT(false); // TODO
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
 // ggml_compute_forward_flash_attn_back
 
 static void ggml_compute_forward_flash_attn_back_f32(
@@ -17785,21 +17446,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_leaky_relu(params, tensor);
             } break;
-        case GGML_OP_FLASH_ATTN:
-            {
-                const int32_t t = ggml_get_op_params_i32(tensor, 0);
-                GGML_ASSERT(t == 0 || t == 1);
-                const bool masked = t != 0;
-                ggml_compute_forward_flash_attn(params, masked, tensor);
-            } break;
        case GGML_OP_FLASH_ATTN_EXT:
            {
                ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
            } break;
-        case GGML_OP_FLASH_FF:
-            {
-                ggml_compute_forward_flash_ff(params, tensor);
-            } break;
        case GGML_OP_FLASH_ATTN_BACK:
            {
                int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -18169,6 +17819,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg
 static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
     struct ggml_tensor * src0 = tensor->src[0];
     struct ggml_tensor * src1 = tensor->src[1];
+    struct ggml_tensor * src2 = tensor->src[2];
 
     switch (tensor->op) {
         case GGML_OP_DUP:
@@ -18700,6 +18351,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_rope_back(ctx,
                                 tensor->grad,
                                 src1,
+                                src2,
                                 n_dims,
                                 mode,
                                 n_ctx,
@@ -18739,6 +18391,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_rope_impl(ctx,
                                 tensor->grad,
                                 src1,
+                                src2,
                                 n_dims,
                                 mode,
                                 n_ctx,
@@ -18803,7 +18456,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_FLASH_ATTN:
        case GGML_OP_FLASH_ATTN_EXT:
            {
                struct ggml_tensor * flash_grad = NULL;
@@ -18820,7 +18472,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             masked);
                 }
 
-                struct ggml_tensor * src2 = tensor->src[2];
                 const int64_t elem_q = ggml_nelements(src0);
                 const int64_t elem_k = ggml_nelements(src1);
                 const int64_t elem_v = ggml_nelements(src2);
@@ -18858,10 +18509,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             zero_table);
                 }
             } break;
-        case GGML_OP_FLASH_FF:
-            {
-                GGML_ASSERT(false); // not supported
-            } break;
        case GGML_OP_FLASH_ATTN_BACK:
            {
                GGML_ASSERT(false); // not supported
@@ -19548,15 +19195,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
            {
                n_tasks = n_threads;
            } break;
-        case GGML_OP_FLASH_ATTN:
        case GGML_OP_FLASH_ATTN_EXT:
            {
                n_tasks = n_threads;
            } break;
-        case GGML_OP_FLASH_FF:
-            {
-                n_tasks = n_threads;
-            } break;
        case GGML_OP_FLASH_ATTN_BACK:
            {
                n_tasks = n_threads;
@@ -19953,39 +19595,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
                     cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
                 } break;
-            case GGML_OP_FLASH_ATTN:
-                {
-                    const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
-
-                    if (node->src[1]->type == GGML_TYPE_F32) {
-                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_F16) {
-                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_BF16) {
-                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
-                    }
-                } break;
            case GGML_OP_FLASH_ATTN_EXT:
                {
                    const int64_t ne00 = node->src[0]->ne[0]; // D

-                    cur =
-                } break;
-            case GGML_OP_FLASH_FF:
-                {
-                    if (node->src[1]->type == GGML_TYPE_F32) {
-                        cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_F16) {
-                        cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_BF16) {
-                        cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
-                    }
+                    cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
                } break;
            case GGML_OP_FLASH_ATTN_BACK:
                {
@@ -21827,11 +21441,7 @@ size_t ggml_quantize_chunk(
        case GGML_TYPE_IQ1_S:   result = quantize_iq1_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ1_M:   result = quantize_iq1_m  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
        case GGML_TYPE_IQ4_NL:  result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-#if QK_K == 64
-        case GGML_TYPE_IQ4_XS: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-#else
        case GGML_TYPE_IQ4_XS:  result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-#endif
        case GGML_TYPE_F16:
            {
                size_t elemsize = sizeof(ggml_fp16_t);
@@ -23108,6 +22718,14 @@ int ggml_cpu_has_avx512_vnni(void) {
 #endif
 }
 
+int ggml_cpu_has_avx512_bf16(void) {
+#if defined(__AVX512BF16__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_fma(void) {
 #if defined(__FMA__)
     return 1;