llama_cpp 0.15.2 → 0.15.4
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/ext/llama_cpp/llama_cpp.cpp +61 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +6 -0
- data/vendor/tmp/llama.cpp/Makefile +8 -16
- data/vendor/tmp/llama.cpp/ggml-common.h +0 -54
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +99 -40
- data/vendor/tmp/llama.cpp/ggml-cuda.h +1 -0
- data/vendor/tmp/llama.cpp/ggml-impl.h +44 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +133 -81
- data/vendor/tmp/llama.cpp/ggml-metal.metal +91 -434
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-quants.c +1962 -2443
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +248 -108
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +375 -657
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +9351 -5627
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +204 -225
- data/vendor/tmp/llama.cpp/ggml.c +498 -836
- data/vendor/tmp/llama.cpp/ggml.h +57 -30
- data/vendor/tmp/llama.cpp/llama.cpp +1477 -859
- data/vendor/tmp/llama.cpp/llama.h +21 -8
- metadata +3 -3
data/vendor/tmp/llama.cpp/ggml.c CHANGED
```diff
@@ -60,6 +60,9 @@
 
 typedef volatile LONG atomic_int;
 typedef atomic_int atomic_bool;
+typedef atomic_int atomic_flag;
+
+#define ATOMIC_FLAG_INIT 0
 
 static void atomic_store(atomic_int * ptr, LONG val) {
     InterlockedExchange(ptr, val);
@@ -73,6 +76,12 @@ static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
 static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
     return atomic_fetch_add(ptr, -(dec));
 }
+static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
+    return InterlockedExchange(ptr, 1);
+}
+static void atomic_flag_clear(atomic_flag * ptr) {
+    InterlockedExchange(ptr, 0);
+}
 
 typedef HANDLE pthread_t;
 
```
```diff
@@ -406,10 +415,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
   int i = 0;
 #if defined(__AVX512BF16__)
   for (; i + 32 <= n; i += 32) {
-        _mm512_storeu_si512(
-            (__m512i *)(y + i),
-            (__m512i)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
-                                _mm512_loadu_ps(x + i)));
+        _mm512_storeu_si512(
+            (__m512i *)(y + i),
+            m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
+                                _mm512_loadu_ps(x + i))));
   }
 #endif
   for (; i < n; i++) {
```
```diff
@@ -871,22 +880,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
     },
     [GGML_TYPE_IQ4_XS] = {
         .type_name                = "iq4_xs",
-#if QK_K == 64
-        .blck_size                = QK4_NL,
-#else
         .blck_size                = QK_K,
-#endif
         .type_size                = sizeof(block_iq4_xs),
         .is_quantized             = true,
         .to_float                 = (ggml_to_float_t) dequantize_row_iq4_xs,
         .from_float               = quantize_row_iq4_xs,
         .from_float_reference     = (ggml_from_float_t)quantize_row_iq4_xs_reference,
         .vec_dot                  = ggml_vec_dot_iq4_xs_q8_K,
-#if QK_K == 64
-        .vec_dot_type             = GGML_TYPE_Q8_0,
-#else
         .vec_dot_type             = GGML_TYPE_Q8_K,
-#endif
         .nrows                    = 1,
     },
     [GGML_TYPE_Q8_K] = {
```
```diff
@@ -1523,6 +1524,196 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
 #define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL
 #define GGML_F16_VEC_REDUCE         GGML_F32Cx4_REDUCE
 
+#elif defined(__loongarch_asx)
+
+#define GGML_SIMD
+
+// F32 LASX
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  8
+
+#define GGML_F32x8         __m256
+#define GGML_F32x8_ZERO    (__m256)__lasx_xvldi(0)
+#define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
+#define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
+#define GGML_F32x8_STORE(x,y)   __lasx_xvst((y), (x), 0)
+#define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
+#define GGML_F32x8_ADD     __lasx_xvfadd_s
+#define GGML_F32x8_MUL     __lasx_xvfmul_s
+#define GGML_F32x8_REDUCE(res, x) \
+do { \
+    int offset = GGML_F32_ARR >> 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
+    } \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
+    } \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
+    } \
+    float *tmp_p = (float *)&x[0]; \
+    res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \
+} while (0)
+// TODO: is this optimal ?
+
+#define GGML_F32_VEC        GGML_F32x8
+#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x8_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x8_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x8_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x8_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x8_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
+
+// F16 LASX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  8
+
+// F16 arithmetic is not supported by AVX, so we use F32 instead
+
+#define GGML_F32Cx8         __m256
+#define GGML_F32Cx8_ZERO    (__m256)__lasx_xvldi(0)
+#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
+
+static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
+    float tmp[8];
+
+    for (int i = 0; i < 8; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+
+    return (__m256)__lasx_xvld(tmp, 0);
+}
+static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
+    float arr[8];
+
+    __lasx_xvst(y, arr, 0);
+
+    for (int i = 0; i < 8; i++) {
+        x[i] = GGML_FP32_TO_FP16(arr[i]);
+    }
+}
+#define GGML_F32Cx8_LOAD(x)     __lasx_f32cx8_load(x)
+#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
+
+#define GGML_F32Cx8_FMA         GGML_F32x8_FMA
+#define GGML_F32Cx8_ADD         __lasx_xvfadd_s
+#define GGML_F32Cx8_MUL         __lasx_xvfmul_s
+#define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE
+
+#define GGML_F16_VEC                GGML_F32Cx8
+#define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO
+#define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA
+#define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD
+#define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE
+
+#elif defined(__loongarch_sx)
+
+#define GGML_SIMD
+
+// F32 LSX
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4         __m128
+#define GGML_F32x4_ZERO    __lsx_vldi(0)
+#define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
+#define GGML_F32x4_STORE((x),(y))   __lsx_vst((y), (x), 0)
+#define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
+#define GGML_F32x4_ADD     __lsx_vfadd_s
+#define GGML_F32x4_MUL     __lsx_vfmul_s
+#define GGML_F32x4_REDUCE(res, x) \
+{ \
+    int offset = GGML_F32_ARR >> 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+    } \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+    } \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+    } \
+    __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \
+    tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \
+    tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
+    const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
+    tmp = __lsx_vsrli_d((__m128i)t0, 32); \
+    tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \
+    tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
+    res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 LSX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  4
+
+static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) {
+    float tmp[4];
+
+    tmp[0] = GGML_FP16_TO_FP32(x[0]);
+    tmp[1] = GGML_FP16_TO_FP32(x[1]);
+    tmp[2] = GGML_FP16_TO_FP32(x[2]);
+    tmp[3] = GGML_FP16_TO_FP32(x[3]);
+
+    return __lsx_vld(tmp, 0);
+}
+
+static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
+    float arr[4];
+
+    __lsx_vst(y, arr, 0);
+
+    x[0] = GGML_FP32_TO_FP16(arr[0]);
+    x[1] = GGML_FP32_TO_FP16(arr[1]);
+    x[2] = GGML_FP32_TO_FP16(arr[2]);
+    x[3] = GGML_FP32_TO_FP16(arr[3]);
+}
+
+#define GGML_F32Cx4          __m128
+#define GGML_F32Cx4_ZERO     __lsx_vldi(0)
+#define GGML_F32Cx4_SET1(x)  __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32Cx4_LOAD(x)  __lsx_f16x4_load(x)
+#define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
+#define GGML_F32Cx4_FMA      GGML_F32x4_FMA
+#define GGML_F32Cx4_ADD      __lsx_vfadd_s
+#define GGML_F32Cx4_MUL      __lsx_vfmul_s
+#define GGML_F32Cx4_REDUCE   GGML_F32x4_REDUCE
+
+#define GGML_F16_VEC                 GGML_F32Cx4
+#define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO
+#define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1
+#define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA
+#define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD
+#define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL
+#define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE
+
 #endif
 
 // GGML_F32_ARR / GGML_F16_ARR
```
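These `GGML_F32_VEC_*` aliases are consumed by the generic vector kernels later in the file. For reference, here is a simplified restatement of ggml's own `ggml_vec_mad_f32` pattern showing how the abstraction is used — a sketch, only meaningful inside ggml.c where the macros above are defined (`GGML_F32_ARR` is `GGML_F32_STEP/GGML_F32_EPR`):

```c
// y[i] += x[i]*v, processing GGML_F32_STEP floats per outer iteration
// as GGML_F32_ARR registers of GGML_F32_EPR lanes each.
inline static void vec_mad_f32_sketch(const int n, float * y, const float * x, const float v) {
#if defined(GGML_SIMD)
    const int np = (n & ~(GGML_F32_STEP - 1)); // largest multiple of STEP <= n

    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);

    GGML_F32_VEC ax[GGML_F32_ARR];
    GGML_F32_VEC ay[GGML_F32_ARR];

    for (int i = 0; i < np; i += GGML_F32_STEP) {
        for (int j = 0; j < GGML_F32_ARR; j++) {
            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
            ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);

            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
        }
    }

    // scalar leftovers
    for (int i = np; i < n; ++i) {
        y[i] += x[i]*v;
    }
#else
    for (int i = 0; i < n; ++i) {
        y[i] += x[i]*v;
    }
#endif
}
```

This is why the new LoongArch backends only have to supply load/store/FMA/reduce primitives: every higher-level kernel is written once against the macro layer.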
```diff
@@ -1666,10 +1857,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
     __m512 c1 = _mm512_setzero_ps();
     __m512 c2 = _mm512_setzero_ps();
     for (; i + 64 <= n; i += 64) {
-        c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps(x + i),
-                              (__m512bh)_mm512_loadu_ps(y + i));
-        c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps(x + i + 32),
-                              (__m512bh)_mm512_loadu_ps(y + i + 32));
+        c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
+                              m512bh(_mm512_loadu_si512((y + i))));
+        c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
+                              m512bh(_mm512_loadu_si512((y + i + 32))));
     }
     sumf += (ggml_float)_mm512_reduce_add_ps(c1);
     sumf += (ggml_float)_mm512_reduce_add_ps(c2);
```
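The `m512bh`/`m512i` wrappers come from this release's `ggml-impl.h` changes: MSVC rejects the C-style casts between AVX-512 vector types that GCC and Clang accept, so the casts are hidden behind macros. A sketch of their likely shape (hedged; the exact definitions are in the `ggml-impl.h` diff listed above):

```c
// ggml-impl.h (sketch): on MSVC the intrinsics already agree on the vector
// type, so the wrappers are identity; elsewhere they perform the cast that
// MSVC would reject.
#if defined(_MSC_VER)
#define m512bh(p) p
#define m512i(p)  p
#else
#define m512bh(p) (__m512bh)(p)
#define m512i(p)  (__m512i)(p)
#endif
```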
```diff
@@ -2076,7 +2267,7 @@ inline static float ggml_silu_f32(float x) {
     return x/(1.0f + expf(-x));
 }
 
-#if defined(__ARM_NEON)
+#if defined(__ARM_NEON) && defined(__aarch64__)
 
 // adapted from arm limited optimized routine
 // the maximum error is 1.45358 plus 0.5 ulps
```
```diff
@@ -2125,32 +2316,27 @@ inline static __m512 ggml_v_expf(__m512 x) {
   const __m512 r = _mm512_set1_ps(0x1.8p23f);
   const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
   const __m512 n = _mm512_sub_ps(z, r);
-  const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
-                                    _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
-  const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
-  const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
-  const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
-  const __m512 u = _mm512_mul_ps(b, b);
-  const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
-                                                                   _mm512_set1_ps(0x1.573e2ep-5f)), u,
-                                                   _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
-                                                                   _mm512_set1_ps(0x1.fffdb6p-2f))),
-                                   u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
-  if (_mm512_kortestz(c, c))
-    return _mm512_fmadd_ps(j, k, k);
-  const __m512i g = _mm512_and_si512(
-      _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
-      _mm512_set1_epi32(0x82000000u));
-  const __m512 s1 =
-      _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
-  const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
+  const __m512 b =
+      _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
+                       _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
   const __mmask16 d =
       _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
-  return _mm512_mask_blend_ps(
-      d,
-      _mm512_fmadd_ps(_mm512_mul_ps(s2, j), s1,
-                      _mm512_mul_ps(s2, s1)),
-      _mm512_mul_ps(s1, s1));
+  const __m512 u = _mm512_mul_ps(b, b);
+  const __m512 j = _mm512_fmadd_ps(
+      _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
+                                      _mm512_set1_ps(0x1.573e2ep-5f)),
+                      u,
+                      _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
+                                      _mm512_set1_ps(0x1.fffdb6p-2f))),
+      u,
+      _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
+  const __m512 res = _mm512_scalef_ps(j, n);
+  if (_mm512_kortestz(d, d))
+    return res;
+  const __m512 zero = _mm512_setzero_ps();
+  const __m512 alt = _mm512_mask_blend_ps(
+      _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
+  return _mm512_mask_blend_ps(d, res, alt);
 }
 
 // computes silu x/(1+exp(-x)) in single precision vector
```
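The rewrite keeps the classic range reduction for vector `expf` but hands the final power-of-two scaling to `_mm512_scalef_ps`, which saturates to $0$/$\infty$ and handles subnormals in hardware, removing the hand-rolled exponent manipulation of the old version. The underlying decomposition, sketched here for reference:

$$n = \operatorname{round}(x \log_2 e), \qquad b = x - n\,(\ln 2)_{\text{hi}} - n\,(\ln 2)_{\text{lo}}, \qquad e^x = 2^{\,n} \cdot P(b),$$

where $(\ln 2)_{\text{hi}} \approx \texttt{0x1.62e4p-1}$ and $(\ln 2)_{\text{lo}} \approx \texttt{0x1.7f7d1cp-20}$ split the constant for extra precision, and $P$ is the polynomial whose coefficients appear above; `_mm512_scalef_ps(j, n)` then computes $P(b)\cdot 2^{\,n}$ in one instruction.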
```diff
@@ -2288,7 +2474,7 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
     for (; i + 3 < n; i += 4) {
         _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
     }
-#elif defined(__ARM_NEON)
+#elif defined(__ARM_NEON) && defined(__aarch64__)
     for (; i + 3 < n; i += 4) {
         vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
     }
```
```diff
@@ -2335,7 +2521,7 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x,
 #endif
         sum += (ggml_float)_mm_cvtss_f32(val);
     }
-#elif defined(__ARM_NEON)
+#elif defined(__ARM_NEON) && defined(__aarch64__)
     for (; i + 3 < n; i += 4) {
         float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
                                                 vdupq_n_f32(max)));
```
```diff
@@ -2489,9 +2675,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ARGSORT",
     "LEAKY_RELU",
 
-    "FLASH_ATTN",
     "FLASH_ATTN_EXT",
-    "FLASH_FF",
     "FLASH_ATTN_BACK",
     "SSM_CONV",
     "SSM_SCAN",
@@ -2517,7 +2701,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2579,9 +2763,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "argsort(x)",
     "leaky_relu(x)",
 
-    "flash_attn(x)",
     "flash_attn_ext(x)",
-    "flash_ff(x)",
     "flash_attn_back(x)",
     "ssm_conv(x)",
     "ssm_scan(x)",
@@ -2607,7 +2789,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
```
```diff
@@ -2706,24 +2888,20 @@ struct ggml_state {
 
 // global state
 static struct ggml_state g_state;
-static atomic_int g_state_barrier = 0;
+static atomic_flag g_state_critical = ATOMIC_FLAG_INIT;
 
 // barrier via spin lock
 inline static void ggml_critical_section_start(void) {
-    int processing = atomic_fetch_add(&g_state_barrier, 1);
-
-    while (processing > 0) {
-        // wait for other threads to finish
-        atomic_fetch_sub(&g_state_barrier, 1);
-        sched_yield(); // TODO: reconsider this
-        processing = atomic_fetch_add(&g_state_barrier, 1);
+    while (atomic_flag_test_and_set(&g_state_critical)) {
+        // spin
+        sched_yield();
     }
 }
 
 // TODO: make this somehow automatically executed
 // some sort of "sentry" mechanism
 inline static void ggml_critical_section_end(void) {
-    atomic_fetch_sub(&g_state_barrier, 1);
+    atomic_flag_clear(&g_state_critical);
 }
 
 #if defined(__gnu_linux__)
```
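The replacement is a standard test-and-set spinlock: `atomic_flag_test_and_set` atomically sets the flag and returns its previous value, so the loop exits exactly when this thread is the one that flipped it from clear to set. A minimal self-contained illustration of the same pattern with C11 atomics (illustrative names, not ggml API):

```c
#include <pthread.h>
#include <sched.h>
#include <stdatomic.h>
#include <stdio.h>

static atomic_flag lock = ATOMIC_FLAG_INIT;
static long counter = 0;

static void * worker(void * arg) {
    (void) arg;
    for (int i = 0; i < 100000; i++) {
        while (atomic_flag_test_and_set(&lock)) {
            sched_yield(); // back off while another thread holds the lock
        }
        counter++;                // critical section
        atomic_flag_clear(&lock); // release
    }
    return NULL;
}

int main(void) {
    pthread_t t[4];
    for (int i = 0; i < 4; i++) pthread_create(&t[i], NULL, worker, NULL);
    for (int i = 0; i < 4; i++) pthread_join(t[i], NULL);
    printf("%ld\n", counter); // prints 400000
    return 0;
}
```

Compared to the old barrier-counter scheme, a flag cannot transiently over-count under contention, and the uncontended acquire path is a single atomic exchange.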
```diff
@@ -3039,7 +3217,11 @@ GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
-static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) {
+GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
+    return ggml_is_contiguous(tensor);
+}
+
+GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
@@ -3048,6 +3230,14 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
+GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        tensor->nb[0] == ggml_type_size(tensor->type) &&
+        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
 GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
```
```diff
@@ -4705,10 +4895,21 @@ struct ggml_tensor * ggml_repeat_back(
 // ggml_concat
 
 struct ggml_tensor * ggml_concat(
-    struct ggml_context* ctx,
-    struct ggml_tensor* a,
-    struct ggml_tensor* b) {
-    GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]);
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    struct ggml_tensor * b,
+    int dim) {
+    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+
+    int64_t ne[GGML_MAX_DIMS];
+    for (int d = 0; d < GGML_MAX_DIMS; ++d) {
+        if (d == dim) {
+            ne[d] = a->ne[d] + b->ne[d];
+            continue;
+        }
+        GGML_ASSERT(a->ne[d] == b->ne[d]);
+        ne[d] = a->ne[d];
+    }
 
     bool is_node = false;
 
@@ -4716,7 +4917,9 @@ struct ggml_tensor * ggml_concat(
         is_node = true;
     }
 
-    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
+
+    ggml_set_op_params_i32(result, 0, dim);
 
     result->op = GGML_OP_CONCAT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
```
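`ggml_concat` now takes the concatenation axis explicitly instead of hard-wiring dim 2 (the old `ggml_new_tensor_4d` call above). A hypothetical call site, shapes invented for illustration:

```c
// Concatenate two F32 matrices along dim 0; all other dims must match.
struct ggml_tensor * a   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 8);
struct ggml_tensor * b   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 32, 8);
struct ggml_tensor * cat = ggml_concat(ctx, a, b, /*dim*/ 0); // ne = [48, 8]
```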
```diff
@@ -4836,6 +5039,7 @@ struct ggml_tensor * ggml_leaky_relu(
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
     ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
 
     result->op = GGML_OP_LEAKY_RELU;
```
```diff
@@ -6042,6 +6246,7 @@ static struct ggml_tensor * ggml_rope_impl(
     struct ggml_context * ctx,
     struct ggml_tensor * a,
     struct ggml_tensor * b,
+    struct ggml_tensor * c,
     int n_dims,
     int mode,
     int n_ctx,
@@ -6055,10 +6260,17 @@ static struct ggml_tensor * ggml_rope_impl(
     float xpos_base,
     bool xpos_down,
     bool inplace) {
+    GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
+
     GGML_ASSERT(ggml_is_vector(b));
     GGML_ASSERT(b->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);
 
+    if (c) {
+        GGML_ASSERT(c->type == GGML_TYPE_F32);
+        GGML_ASSERT(c->ne[0] >= n_dims / 2);
+    }
+
     bool is_node = false;
 
     if (a->grad) {
@@ -6082,6 +6294,7 @@ static struct ggml_tensor * ggml_rope_impl(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
+    result->src[2] = c;
 
     return result;
 }
@@ -6094,7 +6307,7 @@ struct ggml_tensor * ggml_rope(
     int mode,
     int n_ctx) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
+        ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
     );
 }
 
@@ -6106,7 +6319,49 @@ struct ggml_tensor * ggml_rope_inplace(
     int mode,
     int n_ctx) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+        ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+    );
+}
+
+struct ggml_tensor * ggml_rope_ext(
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    struct ggml_tensor * b,
+    struct ggml_tensor * c,
+    int n_dims,
+    int mode,
+    int n_ctx,
+    int n_orig_ctx,
+    float freq_base,
+    float freq_scale,
+    float ext_factor,
+    float attn_factor,
+    float beta_fast,
+    float beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+    );
+}
+
+struct ggml_tensor * ggml_rope_ext_inplace(
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    struct ggml_tensor * b,
+    struct ggml_tensor * c,
+    int n_dims,
+    int mode,
+    int n_ctx,
+    int n_orig_ctx,
+    float freq_base,
+    float freq_scale,
+    float ext_factor,
+    float attn_factor,
+    float beta_fast,
+    float beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
     );
 }
 
@@ -6125,7 +6380,7 @@ struct ggml_tensor * ggml_rope_custom(
     float beta_fast,
     float beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
     );
 }
@@ -6145,7 +6400,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
     float beta_fast,
     float beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
     );
 }
@@ -6157,7 +6412,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
     int n_dims,
     float base,
     bool down) {
-    return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
+    return ggml_rope_impl(ctx, a, b, NULL, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
 }
 
 // ggml_rope_back
```
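`ggml_rope_ext` is the new public entry point that exposes the `c` tensor of per-frequency factors; passing `NULL` reproduces the old behavior, which is exactly what the legacy wrappers above do. A hypothetical call — tensor names and parameter values are illustrative, not prescribed by the API:

```c
// x:   [head_dim, n_head, n_tokens] activations to rotate
// pos: [n_tokens] GGML_TYPE_I32 positions
// freq_factors: [head_dim/2] GGML_TYPE_F32 divisors, or NULL to disable
struct ggml_tensor * x_rot = ggml_rope_ext(
    ctx, x, pos, freq_factors,
    /*n_dims*/      128,
    /*mode*/        2,        // "neox" style; the freq-factor path requires it
    /*n_ctx*/       0,
    /*n_orig_ctx*/  4096,
    /*freq_base*/   10000.0f,
    /*freq_scale*/  1.0f,
    /*ext_factor*/  0.0f,
    /*attn_factor*/ 1.0f,
    /*beta_fast*/   32.0f,
    /*beta_slow*/   1.0f);
```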
```diff
@@ -6166,6 +6421,7 @@ struct ggml_tensor * ggml_rope_back(
     struct ggml_context * ctx,
     struct ggml_tensor * a,
     struct ggml_tensor * b,
+    struct ggml_tensor * c,
     int n_dims,
     int mode,
     int n_ctx,
@@ -6181,6 +6437,7 @@ struct ggml_tensor * ggml_rope_back(
     GGML_ASSERT(ggml_is_vector(b));
     GGML_ASSERT(b->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);
+    GGML_ASSERT(c == NULL && "freq factors not implemented yet");
 
     GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
 
```
```diff
@@ -6724,38 +6981,6 @@ struct ggml_tensor * ggml_top_k(
     return result;
 }
 
-// ggml_flash_attn
-
-struct ggml_tensor * ggml_flash_attn(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * q,
-        struct ggml_tensor  * k,
-        struct ggml_tensor  * v,
-        bool                  masked) {
-    GGML_ASSERT(ggml_can_mul_mat(k, q));
-    // TODO: check if vT can be multiplied by (k*qT)
-
-    bool is_node = false;
-
-    if (q->grad || k->grad || v->grad) {
-        is_node = true;
-    }
-
-    //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne);
-
-    int32_t t = masked ? 1 : 0;
-    ggml_set_op_params(result, &t, sizeof(t));
-
-    result->op   = GGML_OP_FLASH_ATTN;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = q;
-    result->src[1] = k;
-    result->src[2] = v;
-
-    return result;
-}
-
 // ggml_flash_attn_ext
 
 struct ggml_tensor * ggml_flash_attn_ext(
@@ -6815,38 +7040,6 @@ void ggml_flash_attn_ext_set_prec(
     ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
 }
 
-// ggml_flash_ff
-
-struct ggml_tensor * ggml_flash_ff(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b0,
-        struct ggml_tensor  * b1,
-        struct ggml_tensor  * c0,
-        struct ggml_tensor  * c1) {
-    GGML_ASSERT(ggml_can_mul_mat(b0, a));
-    // TODO: more checks
-
-    bool is_node = false;
-
-    if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
-        is_node = true;
-    }
-
-    //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);
-
-    result->op   = GGML_OP_FLASH_FF;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-    result->src[1] = b0;
-    result->src[2] = b1;
-    result->src[3] = c0;
-    result->src[4] = c1;
-
-    return result;
-}
-
 // ggml_flash_attn_back
 
 struct ggml_tensor * ggml_flash_attn_back(
```
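With `GGML_OP_FLASH_ATTN` and `GGML_OP_FLASH_FF` gone, graphs built against the old API move to `ggml_flash_attn_ext`, which fuses masking, scaling, and ALiBi bias. A hypothetical replacement for `ggml_flash_attn(ctx, q, k, v, /*masked*/ true)` — `causal_mask` is an assumed caller-built tensor, and the scale choice is the conventional one rather than anything the API mandates:

```c
// causal_mask: F16 mask with 0 on allowed and -INF on masked positions,
// or NULL for no masking.
const float scale    = 1.0f/sqrtf((float) q->ne[0]); // 1/sqrt(head_dim)
const float max_bias = 0.0f;                         // 0 disables ALiBi

struct ggml_tensor * out =
    ggml_flash_attn_ext(ctx, q, k, v, causal_mask, scale, max_bias);
```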
```diff
@@ -6856,6 +7049,8 @@ struct ggml_tensor * ggml_flash_attn_back(
         struct ggml_tensor  * v,
         struct ggml_tensor  * d,
         bool                  masked) {
+    GGML_ASSERT(false && "TODO: adapt to ggml_flash_attn_ext() changes");
+
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
 
```
```diff
@@ -10809,26 +11004,29 @@ static void ggml_compute_forward_concat_f32(
     GGML_ASSERT(nb00 == sizeof(float));
     GGML_ASSERT(nb10 == sizeof(float));
 
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+    GGML_ASSERT(dim >= 0 && dim < 4);
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = src0->ne[dim];
+
+    const float * x;
+
+    // TODO: smarter multi-theading
     for (int i3 = 0; i3 < ne3; i3++) {
         for (int i2 = ith; i2 < ne2; i2 += nth) {
-            if (i2 < ne02) { // src0
-                for (int i1 = 0; i1 < ne1; i1++) {
-                    for (int i0 = 0; i0 < ne0; i0++) {
-                        const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03);
-
-                        float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
-                        *y = *x;
-                    }
-                }
-            } // src1
-            else {
-                for (int i1 = 0; i1 < ne1; i1++) {
-                    for (int i0 = 0; i0 < ne0; i0++) {
-                        const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13);
-
-                        float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
-                        *y = *x;
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < ne0; i0++) {
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        x = (const float *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
+                    } else {
+                        x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
                     }
+
+                    float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+
+                    *y = *x;
                 }
             }
         }
@@ -10836,8 +11034,8 @@ static void ggml_compute_forward_concat_f32(
 }
 
 static void ggml_compute_forward_concat(
-    const struct ggml_compute_params* params,
-    struct ggml_tensor* dst) {
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
 
     const struct ggml_tensor * src0 = dst->src[0];
 
```
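The index mapping in the rewritten loop is easiest to see with concrete numbers; a worked example as comments (shapes invented for illustration):

```c
// Example: dim = 2, src0->ne[2] = 4, src1->ne[2] = 3  =>  dst->ne[2] = 7
// and o = {0, 0, 4, 0}.
//
//   dst (.., i2, ..) for i2 in [0, 4) reads src0 at (.., i2,     ..)
//   dst (.., i2, ..) for i2 in [4, 7) reads src1 at (.., i2 - 4, ..)
//
// The per-dimension bounds check amounts to "is this coordinate still inside
// src0"; only the concatenated dimension can actually exceed it.
```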
```diff
@@ -11230,8 +11428,8 @@ static void ggml_compute_forward_gelu_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11293,8 +11491,8 @@ static void ggml_compute_forward_gelu_quick_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11356,8 +11554,8 @@ static void ggml_compute_forward_silu_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11468,9 +11666,9 @@ static void ggml_compute_forward_silu_back_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * grad = dst->src[1];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(grad));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src0, grad));
 
```
```diff
@@ -14115,6 +14313,7 @@ static void ggml_compute_forward_rope_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src2 = dst->src[2];
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
@@ -14167,13 +14366,24 @@ static void ggml_compute_forward_rope_f32(
     int ir = 0;
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.f/n_dims;
+
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    const float * freq_factors = NULL;
+    if (is_neox) {
+        if (src2 != NULL) {
+            GGML_ASSERT(src2->type == GGML_TYPE_F32);
+            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+            freq_factors = (const float *) src2->data;
+        }
+    } else {
+        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+    }
+
     // backward process uses inverse rotation by cos and sin.
     // cos and sin build a rotation matrix, where the inverse is the transpose.
     // this essentially just switches the sign of sin.
@@ -14205,7 +14415,7 @@ static void ggml_compute_forward_rope_f32(
             const float cos_block_theta = cosf(block_theta);
             const float sin_block_theta = sinf(block_theta) * sin_sign;
 
-            theta_base  *= theta_scale;
+            theta_base *= theta_scale;
             block_theta *= theta_scale;
 
             const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14240,28 +14450,22 @@ static void ggml_compute_forward_rope_f32(
                     dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
                 }
             } else {
-                // TODO: this might be wrong for ne0 != n_dims - need double check
-                // it seems we have to rope just the first n_dims elements and do nothing with the rest
-                // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
-                theta_base *= freq_scale;
+                // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
                 for (int64_t ic = 0; ic < ne0; ic += 2) {
                     if (ic < n_dims) {
-                        const int64_t ib = 0;
+                        const int64_t i0 = ic/2;
 
-                        // simplified from `(ib * n_dims + ic) * inv_ndims`
-                        float cur_rot = inv_ndims * ic - ib;
+                        const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
 
                         float cos_theta, sin_theta;
                         rope_yarn(
-                            theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                            theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
                             &cos_theta, &sin_theta
                         );
-                        sin_theta *= sin_sign;
 
+                        sin_theta *= sin_sign;
                         theta_base *= theta_scale;
 
-                        const int64_t i0 = ib*n_dims + ic/2;
-
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                         float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
@@ -14286,6 +14490,7 @@ static void ggml_compute_forward_rope_f32(
         }
     }
 
+// TODO: deduplicate f16/f32 code
 static void ggml_compute_forward_rope_f16(
     const struct ggml_compute_params * params,
     struct ggml_tensor * dst,
```
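In effect, for the rotated pair at index $i$ the kernel now computes

$$\theta_i = \frac{p \cdot \text{freq\_base}^{-2i/n_{\text{dims}}}}{f_i},$$

where $p$ is the token position from `src1` and $f_i$ is `freq_factors[i]` (treated as $1$ when `src2` is absent). This is the hook used by models that ship per-dimension rope scaling factors.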
```diff
@@ -14293,6 +14498,7 @@ static void ggml_compute_forward_rope_f16(
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src2 = dst->src[2];
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
@@ -14338,13 +14544,24 @@ static void ggml_compute_forward_rope_f16(
     int ir = 0;
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.f/n_dims;
+
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    const float * freq_factors = NULL;
+    if (is_neox) {
+        if (src2 != NULL) {
+            GGML_ASSERT(src2->type == GGML_TYPE_F32);
+            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+            freq_factors = (const float *) src2->data;
+        }
+    } else {
+        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+    }
+
     // backward process uses inverse rotation by cos and sin.
     // cos and sin build a rotation matrix, where the inverse is the transpose.
     // this essentially just switches the sign of sin.
@@ -14376,7 +14593,7 @@ static void ggml_compute_forward_rope_f16(
             const float cos_block_theta = cosf(block_theta);
             const float sin_block_theta = sinf(block_theta) * sin_sign;
 
-            theta_base  *= theta_scale;
+            theta_base *= theta_scale;
             block_theta *= theta_scale;
 
             const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14407,28 +14624,22 @@ static void ggml_compute_forward_rope_f16(
                     dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                 }
             } else {
-                // TODO: this might be wrong for ne0 != n_dims - need double check
-                // it seems we have to rope just the first n_dims elements and do nothing with the rest
-                // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
-                theta_base *= freq_scale;
+                // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
                 for (int64_t ic = 0; ic < ne0; ic += 2) {
                     if (ic < n_dims) {
-                        const int64_t ib = 0;
+                        const int64_t i0 = ic/2;
 
-                        // simplified from `(ib * n_dims + ic) * inv_ndims`
-                        float cur_rot = inv_ndims * ic - ib;
+                        const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
 
                         float cos_theta, sin_theta;
                         rope_yarn(
-                            theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                            theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
                             &cos_theta, &sin_theta
                         );
-                        sin_theta *= sin_sign;
 
+                        sin_theta *= sin_sign;
                         theta_base *= theta_scale;
 
-                        const int64_t i0 = ib*n_dims + ic/2;
-
                         const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                         ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
```
```diff
@@ -15458,400 +15669,6 @@ static void ggml_compute_forward_argsort(
     }
 }
 
-// ggml_compute_forward_flash_attn
-
-static void ggml_compute_forward_flash_attn_f32(
-        const struct ggml_compute_params * params,
-        const bool masked,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-    const struct ggml_tensor * k = dst->src[1];
-    const struct ggml_tensor * v = dst->src[2];
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = neq0;
-    const int64_t N = neq1;
-    const int64_t P = nek1 - N;
-    const int64_t M = P + N;
-
-    const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
-
-    GGML_ASSERT(ne0 == D);
-    GGML_ASSERT(ne1 == N);
-    GGML_ASSERT(P >= 0);
-
-    GGML_ASSERT(nbq0 == sizeof(float));
-    GGML_ASSERT(nbk0 == sizeof(float));
-    GGML_ASSERT(nbv0 == sizeof(float));
-
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev1 == D);
-
-    GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nek1 == N + P);
-    GGML_ASSERT(nev1 == D);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    // parallelize by q rows using ggml_vec_dot_f32
-
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    const float scale = 1.0f/sqrtf(D);
-
-    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
-
-        float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32);
-
-        for (int i = M; i < Mup; ++i) {
-            S[i] = -INFINITY;
-        }
-
-        const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
-        for (int64_t ic = 0; ic < masked_begin; ++ic) {
-            // k indices
-            const int ik3 = iq3;
-            const int ik2 = iq2 % nek2;
-            const int ik1 = ic;
-
-            // S indices
-            const int i1 = ik1;
-
-            ggml_vec_dot_f32(neq0,
-                    S + i1, 0,
-                    (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                    (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
-        }
-
-        // scale
-        ggml_vec_scale_f32(masked_begin, S, scale);
-
-        for (int64_t i = masked_begin; i < M; i++) {
-            S[i] = -INFINITY;
-        }
-
-        // softmax
-        // exclude known -INF S[..] values from max and loop
-        // dont forget to set their SW values to zero
-        {
-            float max = -INFINITY;
-            ggml_vec_max_f32(masked_begin, &max, S);
-
-            ggml_float sum = 0.0;
-            {
-#ifdef GGML_SOFT_MAX_ACCELERATE
-                max = -max;
-                vDSP_vsadd(S, 1, &max, S, 1, Mup);
-                vvexpf(S, S, &Mup);
-                ggml_vec_sum_f32(Mup, &sum, S);
-#else
-                sum = ggml_vec_soft_max_f32(Mup, S, S, max);
-#endif
-            }
-
-            assert(sum > 0.0);
-
-            sum = 1.0/sum;
-            ggml_vec_scale_f32(masked_begin, S, sum);
-
-#ifndef NDEBUG
-            for (int i = 0; i < masked_begin; ++i) {
-                assert(!isnan(S[i]));
-                assert(!isinf(S[i]));
-            }
-#endif
-        }
-
-        for (int64_t ic = 0; ic < nev1; ++ic) {
-            // dst indices
-            const int i1 = iq1;
-            const int i2 = iq2;
-            const int i3 = iq3;
-
-            // v indices
-            const int iv2 = iq2 % nev2;
-            const int iv3 = iq3;
-
-            ggml_vec_dot_f32(masked_begin,
-                    (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
-                    (float *) ((char *) v->data   + (         ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
-                    S, 0, 1);
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_attn_f16(
-        const struct ggml_compute_params * params,
-        const bool masked,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-    const struct ggml_tensor * k = dst->src[1];
-    const struct ggml_tensor * v = dst->src[2];
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
-    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = neq0;
-    const int64_t N = neq1;
-    const int64_t P = nek1 - N;
-    const int64_t M = P + N;
-
-    const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
-
-    GGML_ASSERT(ne0 == D);
-    GGML_ASSERT(ne1 == N);
-    GGML_ASSERT(P >= 0);
-
-    GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
-
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev1 == D);
-
-    GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nek1 == N + P);
-    GGML_ASSERT(nev1 == D);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    // parallelize by q rows using ggml_vec_dot_f32
-
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    const float scale = 1.0f/sqrtf(D);
-
-    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
-
-        float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
-
-        for (int i = M; i < Mup; ++i) {
-            S[i] = -INFINITY;
-        }
-
-        if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
-            for (int64_t ic = 0; ic < nek1; ++ic) {
-                // k indices
-                const int ik3 = iq3;
-                const int ik2 = iq2 % nek2;
-                const int ik1 = ic;
-
-                // S indices
-                const int i1 = ik1;
-
-                ggml_vec_dot_f16(neq0,
-                        S + i1, 0,
-                        (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
-            }
-        } else {
-            for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
-                // k indices
-                const int ik3 = iq3;
-                const int ik2 = iq2 % nek2;
-                const int ik1 = ic;
-
-                // S indices
-                const int i1 = ik1;
-
-                ggml_vec_dot_f16_unroll(neq0, nbk1,
-                        S + i1,
-                        ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
-            }
-        }
-
-        // scale
-        ggml_vec_scale_f32(nek1, S, scale);
-
-        if (masked) {
-            for (int64_t i = P; i < M; i++) {
-                if (i > P + iq1) {
-                    S[i] = -INFINITY;
-                }
-            }
-        }
-
-        // softmax
-        // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
-        // dont forget to set their S values to zero
-        {
-            float max = -INFINITY;
-            ggml_vec_max_f32(M, &max, S);
-
-            ggml_float sum = 0.0;
-            {
-#ifdef GGML_SOFT_MAX_ACCELERATE
-                max = -max;
-                vDSP_vsadd(S, 1, &max, S, 1, Mup);
-                vvexpf(S, S, &Mup);
-                ggml_vec_sum_f32(Mup, &sum, S);
-#else
-                sum = ggml_vec_soft_max_f32(Mup, S, S, max);
-#endif
-            }
-
-            assert(sum > 0.0);
-
-            sum = 1.0/sum;
-            ggml_vec_scale_f32(M, S, sum);
-
-#ifndef NDEBUG
-            for (int i = 0; i < M; ++i) {
-                assert(!isnan(S[i]));
-                assert(!isinf(S[i]));
-            }
-#endif
-        }
-
-        ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
-
-        for (int64_t i = 0; i < M; i++) {
-            S16[i] = GGML_FP32_TO_FP16(S[i]);
-        }
-
-        // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
-        if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
-            for (int64_t ic = 0; ic < nev1; ++ic) {
-                // dst indices
-                const int i1 = iq1;
-                const int i2 = iq2;
-                const int i3 = iq3;
-
-                // v indices
-                const int iv2 = iq2 % nev2;
-                const int iv3 = iq3;
-
-                ggml_vec_dot_f16(nev0,
-                        (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
-                        (ggml_fp16_t *) ((char *) v->data   + (         ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
-                        S16, 0, 1);
-            }
-        } else {
-            for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
-                // dst indices
-                const int i1 = iq1;
-                const int i2 = iq2;
-                const int i3 = iq3;
-
-                // v indices
-                const int iv2 = iq2 % nev2;
-                const int iv3 = iq3;
-
-                ggml_vec_dot_f16_unroll(nev0, nbv1,
-                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
-                        ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
-                        S16);
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_attn(
-        const struct ggml_compute_params * params,
-        const bool masked,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-
-    switch (q->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_flash_attn_f16(params, masked, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_flash_attn_f32(params, masked, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
 // ggml_compute_forward_flash_attn_ext
 
 static void ggml_compute_forward_flash_attn_ext_f16(
```
```diff
@@ -15882,9 +15699,10 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(ne0 == D);
     GGML_ASSERT(ne2 == N);
 
-    GGML_ASSERT(nbq0 == sizeof(float));
-    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
 
     GGML_ASSERT(neq0 == D);
     GGML_ASSERT(nek0 == D);
```
```diff
@@ -15938,6 +15756,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    enum ggml_type    const k_vec_dot_type = type_traits[k->type].vec_dot_type;
+    ggml_from_float_t const q_to_vec_dot   = type_traits[k_vec_dot_type].from_float;
+    ggml_vec_dot_t    const kq_vec_dot     = type_traits[k->type].vec_dot;
+    ggml_to_float_t   const v_to_float     = type_traits[v->type].to_float;
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
```
```diff
@@ -15945,17 +15768,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
 
-        const uint32_t h = iq2; // head
+        const uint32_t h = iq2; // head index
         const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
 
-        float S = 0.0f;
-        float M = -INFINITY;
+        float S = 0.0f;      // sum
+        float M = -INFINITY; // maximum KQ value
 
-        float       * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
-
-        ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
+        float       * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   = (VKQ32 + 1*D); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
 
-        memset(V16, 0, D*sizeof(ggml_fp16_t));
+        if (v->type == GGML_TYPE_F16) {
+            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+        } else {
+            memset(VKQ32, 0, D*sizeof(float));
+        }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
 
@@ -15967,6 +15795,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iv3 = iq3 / rv3;
         const int iv2 = iq2 / rv2;
 
+        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+        q_to_vec_dot(pq, Q_q, D);
+
         // online softmax / attention
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
```
@@ -15976,52 +15807,67 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 continue;
             }

-            float s;
+            float s; // KQ value

-            // convert Q to F16
-            {
-                const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
+            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);

-                for (int64_t d = 0; d < D; ++d) {
-                    Q16[d] = GGML_FP32_TO_FP16(pq[d]);
-                }
-            }
+            s = s*scale + mv; // scale KQ value and apply mask

-            ggml_vec_dot_f16(D,
-                    &s, 0,
-                    (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                    Q16, 0, 1);
+            const float Mold = M;

-            s = s*scale + mv;
+            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+            float vs = 1.0f; // post-softmax KQ value, expf(s - M)

-            const float Mold = M;
+            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));

-            float ms = 1.0f;
-            float vs = 1.0f;
+            if (v->type == GGML_TYPE_F16) {
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);

-            if (s > M) {
-                M = s;
-                ms = expf(Mold - M);
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f16(D, VKQ16, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }

-                // V = V*expf(Mold - M)
-                ggml_vec_scale_f16(D, V16, ms);
+                // V += v*expf(s - M)
+                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
             } else {
-                vs = expf(s - M);
-            }
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);

-            const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f32(D, VKQ32, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }

-            // V += v*expf(s - M)
-            ggml_vec_mad_f16(D, V16, v16, vs);
+                v_to_float(v_data, V32, D);
+
+                // V += v*expf(s - M)
+                ggml_vec_mad_f32(D, VKQ32, V32, vs);
+            }

-            S = S*ms + vs;
+            S = S*ms + vs; // scale and increment sum with partial sum
         }

-        // V /= S
-        for (int64_t d = 0; d < D; ++d) {
-            V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
+        if (v->type == GGML_TYPE_F16) {
+            for (int64_t d = 0; d < D; ++d) {
+                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
+            }
         }

+        // V /= S
+        const float S_inv = 1.0f/S;
+        ggml_vec_scale_f32(D, VKQ32, S_inv);
+
         // dst indices
         const int i1 = iq1;
         const int i2 = iq2;
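
The rewritten loop above is the streaming ("online") softmax from the referenced paper: the converted query Q_q is produced once per output row, each K/V row is consumed in its stored type, and a running maximum M plus running sum S let the accumulator be rescaled by expf(Mold - M) whenever a larger score appears, instead of requiring a second pass over the scores. A scalar sketch of the identity this relies on (illustrative, head size 1, not ggml code):

    #include <math.h>
    #include <stdio.h>

    // Streaming softmax-weighted average of v[i] with weights expf(s[i]),
    // numerically stabilized by the running maximum M. One pass, no
    // materialized softmax vector.
    int main(void) {
        const float s[4] = {0.5f, 2.0f, -1.0f, 3.0f}; // KQ scores
        const float v[4] = {1.0f, 2.0f,  3.0f, 4.0f}; // values

        float M = -INFINITY; // running maximum
        float S = 0.0f;      // running sum of expf(s[i] - M)
        float V = 0.0f;      // running unnormalized output

        for (int i = 0; i < 4; ++i) {
            float ms = 1.0f; // rescale factor for the old accumulator
            float vs = 1.0f; // weight of the new element, expf(s[i] - M)
            if (s[i] > M) {
                const float Mold = M;
                M  = s[i];
                ms = expf(Mold - M); // fold the max shift into old terms
                V *= ms;
            } else {
                vs = expf(s[i] - M);
            }
            V += vs*v[i];
            S  = S*ms + vs;
        }

        printf("%f\n", V/S); // matches a two-pass softmax average
        return 0;
    }

The result equals the two-pass softmax average because every accumulated term carries the same expf(-M) factor, which cancels in the final V/S division.
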
@@ -16031,7 +15877,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));

         // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
+        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
     }
 }

@@ -16056,165 +15902,6 @@ static void ggml_compute_forward_flash_attn_ext(
     }
 }

-// ggml_compute_forward_flash_ff
-
-static void ggml_compute_forward_flash_ff_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * a  = dst->src[0]; // F16
-    const struct ggml_tensor * b0 = dst->src[1]; // F16 fc_w
-    const struct ggml_tensor * b1 = dst->src[2]; // F32 fc_b
-    const struct ggml_tensor * c0 = dst->src[3]; // F16 proj_w
-    const struct ggml_tensor * c1 = dst->src[4]; // F32 proj_b
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_LOCALS(int64_t, nea,  a,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nba,  a,   nb)
-    GGML_TENSOR_LOCALS(int64_t, neb0, b0,  ne)
-    GGML_TENSOR_LOCALS(size_t,  nbb0, b0,  nb)
-    GGML_TENSOR_LOCALS(int64_t, neb1, b1,  ne)
-    GGML_TENSOR_LOCALS(size_t,  nbb1, b1,  nb)
-    GGML_TENSOR_LOCALS(int64_t, nec0, c0,  ne)
-    GGML_TENSOR_LOCALS(size_t,  nbc0, c0,  nb)
-    GGML_TENSOR_LOCALS(int64_t, nec1, c1,  ne)
-    GGML_TENSOR_LOCALS(size_t,  nbc1, c1,  nb)
-    GGML_TENSOR_LOCALS(int64_t, ne,   dst, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb,   dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = nea0;
-    //const int64_t N = nea1;
-    const int64_t M = neb01;
-
-    GGML_ASSERT(ne0 == nea0);
-    GGML_ASSERT(ne1 == nea1);
-    GGML_ASSERT(ne2 == nea2);
-
-    GGML_ASSERT(nba0  == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbb10 == sizeof(float));
-    GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbc10 == sizeof(float));
-
-    GGML_ASSERT(neb00 == D);
-    GGML_ASSERT(neb01 == M);
-    GGML_ASSERT(neb10 == M);
-    GGML_ASSERT(neb11 == 1);
-
-    GGML_ASSERT(nec00 == M);
-    GGML_ASSERT(nec01 == D);
-    GGML_ASSERT(nec10 == D);
-    GGML_ASSERT(nec11 == 1);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    // parallelize by a rows using ggml_vec_dot_f32
-
-    // total rows in a
-    const int nr = nea1*nea2*nea3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // a indices
-        const int ia3 = ir/(nea2*nea1);
-        const int ia2 = (ir - ia3*nea2*nea1)/nea1;
-        const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
-
-        float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
-
-        for (int64_t ic = 0; ic < neb01; ++ic) {
-            // b0 indices
-            const int ib03 = ia3;
-            const int ib02 = ia2;
-            const int ib01 = ic;
-
-            // S indices
-            const int i1 = ib01;
-
-            ggml_vec_dot_f16(nea0,
-                    S + i1, 0,
-                    (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
-                    (ggml_fp16_t *) ((char *)  a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)), 0, 1);
-        }
-
-        ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
-        //ggml_vec_gelu_f32(neb01, S, S);
-
-        ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
-
-        for (int64_t i = 0; i < M; i++) {
-            S16[i] = GGML_FP32_TO_FP16(S[i]);
-        }
-
-        ggml_vec_gelu_f16(neb01, S16, S16);
-
-        {
-            // dst indices
-            const int i1 = ia1;
-            const int i2 = ia2;
-            const int i3 = ia3;
-
-            for (int64_t ic = 0; ic < nec01; ++ic) {
-
-                ggml_vec_dot_f16(neb01,
-                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
-                        (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
-                        S16, 0, 1);
-            }
-
-            ggml_vec_add_f32(nec01,
-                    (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
-                    (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
-                    (float *) c1->data);
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_ff(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * b0 = dst->src[1];
-
-    switch (b0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_flash_ff_f16(params, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                GGML_ASSERT(false); // TODO
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
 // ggml_compute_forward_flash_attn_back

 static void ggml_compute_forward_flash_attn_back_f32(
@@ -17785,21 +17472,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_leaky_relu(params, tensor);
             } break;
-        case GGML_OP_FLASH_ATTN:
-            {
-                const int32_t t = ggml_get_op_params_i32(tensor, 0);
-                GGML_ASSERT(t == 0 || t == 1);
-                const bool masked = t != 0;
-                ggml_compute_forward_flash_attn(params, masked, tensor);
-            } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
             } break;
-        case GGML_OP_FLASH_FF:
-            {
-                ggml_compute_forward_flash_ff(params, tensor);
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
            {
                int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -18169,6 +17845,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg
 static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
     struct ggml_tensor * src0 = tensor->src[0];
     struct ggml_tensor * src1 = tensor->src[1];
+    struct ggml_tensor * src2 = tensor->src[2];

     switch (tensor->op) {
         case GGML_OP_DUP:
@@ -18700,6 +18377,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_rope_back(ctx,
                                 tensor->grad,
                                 src1,
+                                src2,
                                 n_dims,
                                 mode,
                                 n_ctx,
@@ -18739,6 +18417,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_rope_impl(ctx,
                                 tensor->grad,
                                 src1,
+                                src2,
                                 n_dims,
                                 mode,
                                 n_ctx,
@@ -18803,7 +18482,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_FLASH_ATTN:
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 struct ggml_tensor * flash_grad = NULL;
@@ -18820,7 +18498,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             masked);
                 }

-                struct ggml_tensor * src2 = tensor->src[2];
                 const int64_t elem_q = ggml_nelements(src0);
                 const int64_t elem_k = ggml_nelements(src1);
                 const int64_t elem_v = ggml_nelements(src2);
@@ -18858,10 +18535,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         zero_table);
                 }
             } break;
-        case GGML_OP_FLASH_FF:
-            {
-                GGML_ASSERT(false); // not supported
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
                 GGML_ASSERT(false); // not supported
@@ -19548,15 +19221,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
             {
                 n_tasks = n_threads;
             } break;
-        case GGML_OP_FLASH_ATTN:
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 n_tasks = n_threads;
             } break;
-        case GGML_OP_FLASH_FF:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
                 n_tasks = n_threads;
@@ -19953,39 +19621,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
                     cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
                 } break;
-            case GGML_OP_FLASH_ATTN:
-                {
-                    const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
-
-                    if (node->src[1]->type == GGML_TYPE_F32) {
-                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_F16) {
-                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_BF16) {
-                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
-                    }
-                } break;
             case GGML_OP_FLASH_ATTN_EXT:
                 {
                     const int64_t ne00 = node->src[0]->ne[0]; // D

-                    cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size/thread
-                } break;
-            case GGML_OP_FLASH_FF:
-                {
-                    if (node->src[1]->type == GGML_TYPE_F32) {
-                        cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_F16) {
-                        cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_BF16) {
-                        cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
-                    }
+                    cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
                 } break;
             case GGML_OP_FLASH_ATTN_BACK:
                 {
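
The new FLASH_ATTN_EXT work size, 3 floats per head element per thread, corresponds to the scratch pointers VKQ32/V32/VKQ16/Q_q set up in ggml_compute_forward_flash_attn_ext_f16 earlier in this diff. A sketch of the implied per-thread layout (D = ne00 is the head size; the CACHE_LINE_SIZE_F32 padding between threads is left out here):

    #include <stddef.h>
    #include <stdint.h>

    // Scratch layout per thread, in float-sized slots, as implied by the
    // pointer setup in the kernel above:
    //   [0,  D)  VKQ32 : FP32 V*KQ accumulator
    //   [D, 2D)  V32   : FP32 staging for one dequantized V row; VKQ16
    //                    (D FP16 values) aliases the start of this region,
    //                    and only one of the two is live on a given path
    //   [2D,3D)  Q_q   : Q converted to the K type's vec_dot format
    static size_t flash_attn_ext_scratch_floats(int64_t D) {
        return (size_t) (3*D);
    }
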
@@ -21827,11 +21467,7 @@ size_t ggml_quantize_chunk(
         case GGML_TYPE_IQ1_S:  result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ1_M:  result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-#if QK_K == 64
-        case GGML_TYPE_IQ4_XS: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-#else
         case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-#endif
         case GGML_TYPE_F16:
             {
                 size_t elemsize = sizeof(ggml_fp16_t);
@@ -23108,6 +22744,14 @@ int ggml_cpu_has_avx512_vnni(void) {
 #endif
 }

+int ggml_cpu_has_avx512_bf16(void) {
+#if defined(__AVX512BF16__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_fma(void) {
 #if defined(__FMA__)
     return 1;
@@ -23124,6 +22768,16 @@ int ggml_cpu_has_neon(void) {
 #endif
 }

+int ggml_cpu_has_sve(void) {
+#if defined(__ARM_FEATURE_SVE)
+    // TODO: Currently, SVE 256 bit is only supported.
+    GGML_ASSERT(svcntb() == QK8_0);
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_arm_fma(void) {
 #if defined(__ARM_FEATURE_FMA)
     return 1;
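
ggml_cpu_has_sve() does more than report a compile-time flag: svcntb() returns the vector length in bytes that the hardware actually implements, and the assert pins it to QK8_0 (32, the number of int8 quants in a Q8_0 block), i.e. exactly 256-bit SVE. A small sketch of the same runtime probe (assumes an SVE-enabled toolchain; not part of this diff):

    #if defined(__ARM_FEATURE_SVE)
    #include <arm_sve.h>

    // svcntb() is resolved at run time from the hardware's vector length:
    // 32 bytes <=> 256-bit SVE, which matches one Q8_0 block (32 int8
    // quants) per vector register.
    static int sve_vl_is_256bit(void) {
        return svcntb() == 32;
    }
    #endif
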
@@ -23212,6 +22866,14 @@ int ggml_cpu_has_sycl(void) {
 #endif
 }

+int ggml_cpu_has_rpc(void) {
+#if defined(GGML_USE_RPC)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_gpublas(void) {
     return ggml_cpu_has_cuda() || ggml_cpu_has_clblast() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() ||
            ggml_cpu_has_sycl();