llama_cpp 0.15.2 → 0.15.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/ext/llama_cpp/llama_cpp.cpp +61 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +6 -0
- data/vendor/tmp/llama.cpp/Makefile +8 -16
- data/vendor/tmp/llama.cpp/ggml-common.h +0 -54
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +99 -40
- data/vendor/tmp/llama.cpp/ggml-cuda.h +1 -0
- data/vendor/tmp/llama.cpp/ggml-impl.h +44 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +133 -81
- data/vendor/tmp/llama.cpp/ggml-metal.metal +91 -434
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-quants.c +1962 -2443
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +248 -108
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +375 -657
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +9351 -5627
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +204 -225
- data/vendor/tmp/llama.cpp/ggml.c +498 -836
- data/vendor/tmp/llama.cpp/ggml.h +57 -30
- data/vendor/tmp/llama.cpp/llama.cpp +1477 -859
- data/vendor/tmp/llama.cpp/llama.h +21 -8
- metadata +3 -3
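
The headline change in the vendored llama.cpp sources below is the extended RoPE API: `ggml_rope_impl` gains a fourth tensor argument `c` for frequency factors, exposed through the new `ggml_rope_ext` / `ggml_rope_ext_inplace`, and `ggml_concat` now takes an explicit `dim`. As a minimal sketch (not part of the package), here is how the new entry point might be called, assuming `ctx`, `cur` and `pos` are an already-initialized ggml context, a hidden-state tensor and an I32 position vector; everything except the `ggml_rope_ext` signature itself is illustrative:

```c
#include "ggml.h"

// Illustrative only: apply the extended RoPE added in this update.
// freq_factors may be NULL; when non-NULL, the diff requires the NeoX-style
// rotation (mode bit 2) and an F32 tensor with ne[0] >= n_dims/2.
static struct ggml_tensor * apply_rope_ext(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,          // e.g. [n_embd_head, n_head, n_tokens]
        struct ggml_tensor  * pos,          // GGML_TYPE_I32, ne[0] == cur->ne[2]
        struct ggml_tensor  * freq_factors, // GGML_TYPE_F32 or NULL
        int                   n_dims) {
    const int mode = freq_factors ? 2 : 0; // bit 2 selects the NeoX rotation
    return ggml_rope_ext(ctx, cur, pos, freq_factors,
                         n_dims, mode, /*n_ctx*/ 0, /*n_orig_ctx*/ 0,
                         /*freq_base*/ 10000.0f, /*freq_scale*/ 1.0f,
                         /*ext_factor*/ 0.0f, /*attn_factor*/ 1.0f,
                         /*beta_fast*/ 32.0f, /*beta_slow*/ 1.0f);
}
```

The older `ggml_rope`, `ggml_rope_custom` and `ggml_rope_xpos_inplace` entry points are kept and simply forward a NULL freq-factors tensor, as the hunks around `ggml_rope_impl` below show.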
data/vendor/tmp/llama.cpp/ggml.c
CHANGED
@@ -60,6 +60,9 @@
 
 typedef volatile LONG atomic_int;
 typedef atomic_int atomic_bool;
+typedef atomic_int atomic_flag;
+
+#define ATOMIC_FLAG_INIT 0
 
 static void atomic_store(atomic_int * ptr, LONG val) {
     InterlockedExchange(ptr, val);
@@ -73,6 +76,12 @@ static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
 static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
     return atomic_fetch_add(ptr, -(dec));
 }
+static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
+    return InterlockedExchange(ptr, 1);
+}
+static void atomic_flag_clear(atomic_flag * ptr) {
+    InterlockedExchange(ptr, 0);
+}
 
 typedef HANDLE pthread_t;
 
@@ -406,10 +415,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
   int i = 0;
 #if defined(__AVX512BF16__)
   for (; i + 32 <= n; i += 32) {
-
-        (
-        (
-
+        _mm512_storeu_si512(
+            (__m512i *)(y + i),
+            m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
+                            _mm512_loadu_ps(x + i))));
   }
 #endif
   for (; i < n; i++) {
@@ -871,22 +880,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
     },
     [GGML_TYPE_IQ4_XS] = {
         .type_name = "iq4_xs",
-#if QK_K == 64
-        .blck_size = QK4_NL,
-#else
         .blck_size = QK_K,
-#endif
         .type_size = sizeof(block_iq4_xs),
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
         .from_float = quantize_row_iq4_xs,
         .from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference,
         .vec_dot = ggml_vec_dot_iq4_xs_q8_K,
-#if QK_K == 64
-        .vec_dot_type = GGML_TYPE_Q8_0,
-#else
         .vec_dot_type = GGML_TYPE_Q8_K,
-#endif
         .nrows = 1,
     },
     [GGML_TYPE_Q8_K] = {
@@ -1523,6 +1524,196 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
 #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
 #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
 
+#elif defined(__loongarch_asx)
+
+#define GGML_SIMD
+
+// F32 LASX
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  8
+
+#define GGML_F32x8         __m256
+#define GGML_F32x8_ZERO    (__m256)__lasx_xvldi(0)
+#define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
+#define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
+#define GGML_F32x8_STORE(x,y)   __lasx_xvst((y), (x), 0)
+#define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
+#define GGML_F32x8_ADD     __lasx_xvfadd_s
+#define GGML_F32x8_MUL     __lasx_xvfmul_s
+#define GGML_F32x8_REDUCE(res, x) \
+do { \
+    int offset = GGML_F32_ARR >> 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
+    } \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
+    } \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
+    } \
+    float *tmp_p = (float *)&x[0]; \
+    res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \
+} while (0)
+// TODO: is this optimal ?
+
+#define GGML_F32_VEC        GGML_F32x8
+#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x8_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x8_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x8_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x8_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x8_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
+
+// F16 LASX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  8
+
+// F16 arithmetic is not supported by AVX, so we use F32 instead
+
+#define GGML_F32Cx8         __m256
+#define GGML_F32Cx8_ZERO    (__m256)__lasx_xvldi(0)
+#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
+
+static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
+    float tmp[8];
+
+    for (int i = 0; i < 8; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+
+    return (__m256)__lasx_xvld(tmp, 0);
+}
+static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
+    float arr[8];
+
+    __lasx_xvst(y, arr, 0);
+
+    for (int i = 0; i < 8; i++) {
+        x[i] = GGML_FP32_TO_FP16(arr[i]);
+    }
+}
+#define GGML_F32Cx8_LOAD(x)     __lasx_f32cx8_load(x)
+#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
+
+#define GGML_F32Cx8_FMA         GGML_F32x8_FMA
+#define GGML_F32Cx8_ADD         __lasx_xvfadd_s
+#define GGML_F32Cx8_MUL         __lasx_xvfmul_s
+#define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE
+
+#define GGML_F16_VEC                GGML_F32Cx8
+#define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO
+#define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA
+#define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD
+#define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE
+
+#elif defined(__loongarch_sx)
+
+#define GGML_SIMD
+
+// F32 LSX
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4         __m128
+#define GGML_F32x4_ZERO    __lsx_vldi(0)
+#define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
+#define GGML_F32x4_STORE((x),(y))   __lsx_vst((y), (x), 0)
+#define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
+#define GGML_F32x4_ADD     __lsx_vfadd_s
+#define GGML_F32x4_MUL     __lsx_vfmul_s
+#define GGML_F32x4_REDUCE(res, x) \
+{ \
+    int offset = GGML_F32_ARR >> 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+    } \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+    } \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+    } \
+    __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \
+    tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \
+    tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
+    const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
+    tmp = __lsx_vsrli_d((__m128i)t0, 32); \
+    tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \
+    tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
+    res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 LSX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  4
+
+static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) {
+    float tmp[4];
+
+    tmp[0] = GGML_FP16_TO_FP32(x[0]);
+    tmp[1] = GGML_FP16_TO_FP32(x[1]);
+    tmp[2] = GGML_FP16_TO_FP32(x[2]);
+    tmp[3] = GGML_FP16_TO_FP32(x[3]);
+
+    return __lsx_vld(tmp, 0);
+}
+
+static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
+    float arr[4];
+
+    __lsx_vst(y, arr, 0);
+
+    x[0] = GGML_FP32_TO_FP16(arr[0]);
+    x[1] = GGML_FP32_TO_FP16(arr[1]);
+    x[2] = GGML_FP32_TO_FP16(arr[2]);
+    x[3] = GGML_FP32_TO_FP16(arr[3]);
+}
+
+#define GGML_F32Cx4             __m128
+#define GGML_F32Cx4_ZERO        __lsx_vldi(0)
+#define GGML_F32Cx4_SET1(x)     __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32Cx4_LOAD(x)     __lsx_f16x4_load(x)
+#define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
+#define GGML_F32Cx4_FMA         GGML_F32x4_FMA
+#define GGML_F32Cx4_ADD         __lsx_vfadd_s
+#define GGML_F32Cx4_MUL         __lsx_vfmul_s
+#define GGML_F32Cx4_REDUCE      GGML_F32x4_REDUCE
+
+#define GGML_F16_VEC                GGML_F32Cx4
+#define GGML_F16_VEC_ZERO           GGML_F32Cx4_ZERO
+#define GGML_F16_VEC_SET1           GGML_F32Cx4_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F32Cx4_FMA
+#define GGML_F16_VEC_ADD            GGML_F32Cx4_ADD
+#define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F32Cx4_REDUCE
+
 #endif
 
 // GGML_F32_ARR / GGML_F16_ARR
@@ -1666,10 +1857,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
     __m512 c1 = _mm512_setzero_ps();
     __m512 c2 = _mm512_setzero_ps();
     for (; i + 64 <= n; i += 64) {
-        c1 = _mm512_dpbf16_ps(c1, (
-
-        c2 = _mm512_dpbf16_ps(c2, (
-
+        c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
+                                  m512bh(_mm512_loadu_si512((y + i))));
+        c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
+                                  m512bh(_mm512_loadu_si512((y + i + 32))));
     }
     sumf += (ggml_float)_mm512_reduce_add_ps(c1);
     sumf += (ggml_float)_mm512_reduce_add_ps(c2);
@@ -2076,7 +2267,7 @@ inline static float ggml_silu_f32(float x) {
     return x/(1.0f + expf(-x));
 }
 
-#if defined(__ARM_NEON)
+#if defined(__ARM_NEON) && defined(__aarch64__)
 
 // adapted from arm limited optimized routine
 // the maximum error is 1.45358 plus 0.5 ulps
@@ -2125,32 +2316,27 @@ inline static __m512 ggml_v_expf(__m512 x) {
   const __m512 r = _mm512_set1_ps(0x1.8p23f);
   const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
   const __m512 n = _mm512_sub_ps(z, r);
-  const __m512 b =
-
-
-  const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
-  const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
-  const __m512 u = _mm512_mul_ps(b, b);
-  const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
-                                                   _mm512_set1_ps(0x1.573e2ep-5f)), u,
-                                   _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
-                                                   _mm512_set1_ps(0x1.fffdb6p-2f))),
-                                   u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
-  if (_mm512_kortestz(c, c))
-    return _mm512_fmadd_ps(j, k, k);
-  const __m512i g = _mm512_and_si512(
-      _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
-      _mm512_set1_epi32(0x82000000u));
-  const __m512 s1 =
-      _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
-  const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
+  const __m512 b =
+      _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
+                       _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
   const __mmask16 d =
       _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
-
-
-
-
-
+  const __m512 u = _mm512_mul_ps(b, b);
+  const __m512 j = _mm512_fmadd_ps(
+      _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
+                                      _mm512_set1_ps(0x1.573e2ep-5f)),
+                      u,
+                      _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
+                                      _mm512_set1_ps(0x1.fffdb6p-2f))),
+      u,
+      _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
+  const __m512 res = _mm512_scalef_ps(j, n);
+  if (_mm512_kortestz(d, d))
+    return res;
+  const __m512 zero = _mm512_setzero_ps();
+  const __m512 alt = _mm512_mask_blend_ps(
+      _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
+  return _mm512_mask_blend_ps(d, res, alt);
 }
 
 // computes silu x/(1+exp(-x)) in single precision vector
@@ -2288,7 +2474,7 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
     for (; i + 3 < n; i += 4) {
         _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
     }
-#elif defined(__ARM_NEON)
+#elif defined(__ARM_NEON) && defined(__aarch64__)
     for (; i + 3 < n; i += 4) {
         vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
     }
@@ -2335,7 +2521,7 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x,
 #endif
         sum += (ggml_float)_mm_cvtss_f32(val);
     }
-#elif defined(__ARM_NEON)
+#elif defined(__ARM_NEON) && defined(__aarch64__)
     for (; i + 3 < n; i += 4) {
         float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
                                                 vdupq_n_f32(max)));
@@ -2489,9 +2675,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ARGSORT",
     "LEAKY_RELU",
 
-    "FLASH_ATTN",
     "FLASH_ATTN_EXT",
-    "FLASH_FF",
     "FLASH_ATTN_BACK",
     "SSM_CONV",
     "SSM_SCAN",
@@ -2517,7 +2701,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT ==
+static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2579,9 +2763,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "argsort(x)",
     "leaky_relu(x)",
 
-    "flash_attn(x)",
     "flash_attn_ext(x)",
-    "flash_ff(x)",
     "flash_attn_back(x)",
     "ssm_conv(x)",
     "ssm_scan(x)",
@@ -2607,7 +2789,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT ==
+static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -2706,24 +2888,20 @@ struct ggml_state {
 
 // global state
 static struct ggml_state g_state;
-static
+static atomic_flag g_state_critical = ATOMIC_FLAG_INIT;
 
 // barrier via spin lock
 inline static void ggml_critical_section_start(void) {
-
-
-
-    // wait for other threads to finish
-        atomic_fetch_sub(&g_state_barrier, 1);
-        sched_yield(); // TODO: reconsider this
-        processing = atomic_fetch_add(&g_state_barrier, 1);
+    while (atomic_flag_test_and_set(&g_state_critical)) {
+        // spin
+        sched_yield();
     }
 }
 
 // TODO: make this somehow automatically executed
 // some sort of "sentry" mechanism
 inline static void ggml_critical_section_end(void) {
-
+    atomic_flag_clear(&g_state_critical);
 }
 
 #if defined(__gnu_linux__)
@@ -3039,7 +3217,11 @@ GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
-
+GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
+    return ggml_is_contiguous(tensor);
+}
+
+GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
@@ -3048,6 +3230,14 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * te
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
+GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        tensor->nb[0] == ggml_type_size(tensor->type) &&
+        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
 GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
@@ -4705,10 +4895,21 @@ struct ggml_tensor * ggml_repeat_back(
 // ggml_concat
 
 struct ggml_tensor * ggml_concat(
-    struct ggml_context* ctx,
-    struct ggml_tensor* a,
-    struct ggml_tensor* b
-
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    struct ggml_tensor * b,
+    int dim) {
+    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+
+    int64_t ne[GGML_MAX_DIMS];
+    for (int d = 0; d < GGML_MAX_DIMS; ++d) {
+        if (d == dim) {
+            ne[d] = a->ne[d] + b->ne[d];
+            continue;
+        }
+        GGML_ASSERT(a->ne[d] == b->ne[d]);
+        ne[d] = a->ne[d];
+    }
 
     bool is_node = false;
 
@@ -4716,7 +4917,9 @@ struct ggml_tensor * ggml_concat(
         is_node = true;
     }
 
-    struct ggml_tensor * result =
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
+
+    ggml_set_op_params_i32(result, 0, dim);
 
     result->op = GGML_OP_CONCAT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -4836,6 +5039,7 @@ struct ggml_tensor * ggml_leaky_relu(
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
     ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
 
     result->op = GGML_OP_LEAKY_RELU;
@@ -6042,6 +6246,7 @@ static struct ggml_tensor * ggml_rope_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b,
+        struct ggml_tensor * c,
         int n_dims,
         int mode,
         int n_ctx,
@@ -6055,10 +6260,17 @@ static struct ggml_tensor * ggml_rope_impl(
         float xpos_base,
         bool xpos_down,
         bool inplace) {
+    GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
+
     GGML_ASSERT(ggml_is_vector(b));
     GGML_ASSERT(b->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);
 
+    if (c) {
+        GGML_ASSERT(c->type == GGML_TYPE_F32);
+        GGML_ASSERT(c->ne[0] >= n_dims / 2);
+    }
+
     bool is_node = false;
 
     if (a->grad) {
@@ -6082,6 +6294,7 @@ static struct ggml_tensor * ggml_rope_impl(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
+    result->src[2] = c;
 
     return result;
 }
@@ -6094,7 +6307,7 @@ struct ggml_tensor * ggml_rope(
         int mode,
         int n_ctx) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
+        ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
     );
 }
 
@@ -6106,7 +6319,49 @@ struct ggml_tensor * ggml_rope_inplace(
         int mode,
         int n_ctx) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+        ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+    );
+}
+
+struct ggml_tensor * ggml_rope_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        struct ggml_tensor * c,
+        int n_dims,
+        int mode,
+        int n_ctx,
+        int n_orig_ctx,
+        float freq_base,
+        float freq_scale,
+        float ext_factor,
+        float attn_factor,
+        float beta_fast,
+        float beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+    );
+}
+
+struct ggml_tensor * ggml_rope_ext_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        struct ggml_tensor * c,
+        int n_dims,
+        int mode,
+        int n_ctx,
+        int n_orig_ctx,
+        float freq_base,
+        float freq_scale,
+        float ext_factor,
+        float attn_factor,
+        float beta_fast,
+        float beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
     );
 }
 
@@ -6125,7 +6380,7 @@ struct ggml_tensor * ggml_rope_custom(
         float beta_fast,
         float beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
     );
 }
@@ -6145,7 +6400,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         float beta_fast,
         float beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
     );
 }
@@ -6157,7 +6412,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
         int n_dims,
         float base,
         bool down) {
-    return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
+    return ggml_rope_impl(ctx, a, b, NULL, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
 }
 
 // ggml_rope_back
@@ -6166,6 +6421,7 @@ struct ggml_tensor * ggml_rope_back(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b,
+        struct ggml_tensor * c,
         int n_dims,
         int mode,
         int n_ctx,
@@ -6181,6 +6437,7 @@ struct ggml_tensor * ggml_rope_back(
     GGML_ASSERT(ggml_is_vector(b));
     GGML_ASSERT(b->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);
+    GGML_ASSERT(c == NULL && "freq factors not implemented yet");
 
     GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
 
@@ -6724,38 +6981,6 @@ struct ggml_tensor * ggml_top_k(
     return result;
 }
 
-// ggml_flash_attn
-
-struct ggml_tensor * ggml_flash_attn(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * q,
-        struct ggml_tensor  * k,
-        struct ggml_tensor  * v,
-        bool                  masked) {
-    GGML_ASSERT(ggml_can_mul_mat(k, q));
-    // TODO: check if vT can be multiplied by (k*qT)
-
-    bool is_node = false;
-
-    if (q->grad || k->grad || v->grad) {
-        is_node = true;
-    }
-
-    //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne);
-
-    int32_t t = masked ? 1 : 0;
-    ggml_set_op_params(result, &t, sizeof(t));
-
-    result->op   = GGML_OP_FLASH_ATTN;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = q;
-    result->src[1] = k;
-    result->src[2] = v;
-
-    return result;
-}
-
 // ggml_flash_attn_ext
 
 struct ggml_tensor * ggml_flash_attn_ext(
@@ -6815,38 +7040,6 @@ void ggml_flash_attn_ext_set_prec(
     ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
 }
 
-// ggml_flash_ff
-
-struct ggml_tensor * ggml_flash_ff(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b0,
-        struct ggml_tensor  * b1,
-        struct ggml_tensor  * c0,
-        struct ggml_tensor  * c1) {
-    GGML_ASSERT(ggml_can_mul_mat(b0, a));
-    // TODO: more checks
-
-    bool is_node = false;
-
-    if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
-        is_node = true;
-    }
-
-    //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);
-
-    result->op   = GGML_OP_FLASH_FF;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-    result->src[1] = b0;
-    result->src[2] = b1;
-    result->src[3] = c0;
-    result->src[4] = c1;
-
-    return result;
-}
-
 // ggml_flash_attn_back
 
 struct ggml_tensor * ggml_flash_attn_back(
@@ -6856,6 +7049,8 @@ struct ggml_tensor * ggml_flash_attn_back(
         struct ggml_tensor  * v,
         struct ggml_tensor  * d,
         bool                  masked) {
+    GGML_ASSERT(false && "TODO: adapt to ggml_flash_attn_ext() changes");
+
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
 
@@ -10809,26 +11004,29 @@ static void ggml_compute_forward_concat_f32(
     GGML_ASSERT(nb00 == sizeof(float));
     GGML_ASSERT(nb10 == sizeof(float));
 
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+    GGML_ASSERT(dim >= 0 && dim < 4);
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = src0->ne[dim];
+
+    const float * x;
+
+    // TODO: smarter multi-theading
     for (int i3 = 0; i3 < ne3; i3++) {
         for (int i2 = ith; i2 < ne2; i2 += nth) {
-
-            for (int
-
-
-
-
-                    *y = *x;
-                }
-            }
-        } // src1
-        else {
-            for (int i1 = 0; i1 < ne1; i1++) {
-                for (int i0 = 0; i0 < ne0; i0++) {
-                    const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13);
-
-                    float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
-                    *y = *x;
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < ne0; i0++) {
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        x = (const float *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
+                    } else {
+                        x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
                 }
+
+                    float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+
+                    *y = *x;
             }
         }
     }
@@ -10836,8 +11034,8 @@ static void ggml_compute_forward_concat_f32(
 }
 
 static void ggml_compute_forward_concat(
-    const struct ggml_compute_params* params,
-    struct ggml_tensor* dst) {
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
 
     const struct ggml_tensor * src0 = dst->src[0];
 
@@ -11230,8 +11428,8 @@ static void ggml_compute_forward_gelu_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(
-    GGML_ASSERT(
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11293,8 +11491,8 @@ static void ggml_compute_forward_gelu_quick_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(
-    GGML_ASSERT(
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11356,8 +11554,8 @@ static void ggml_compute_forward_silu_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(
-    GGML_ASSERT(
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11468,9 +11666,9 @@ static void ggml_compute_forward_silu_back_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * grad = dst->src[1];
 
-    GGML_ASSERT(
-    GGML_ASSERT(
-    GGML_ASSERT(
+    GGML_ASSERT(ggml_is_contiguous_1(grad));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src0, grad));
 
@@ -14115,6 +14313,7 @@ static void ggml_compute_forward_rope_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src2 = dst->src[2];
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
@@ -14167,13 +14366,24 @@ static void ggml_compute_forward_rope_f32(
     int ir = 0;
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
+
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    const float * freq_factors = NULL;
+    if (is_neox) {
+        if (src2 != NULL) {
+            GGML_ASSERT(src2->type == GGML_TYPE_F32);
+            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+            freq_factors = (const float *) src2->data;
+        }
+    } else {
+        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+    }
+
     // backward process uses inverse rotation by cos and sin.
     // cos and sin build a rotation matrix, where the inverse is the transpose.
     // this essentially just switches the sign of sin.
@@ -14205,7 +14415,7 @@ static void ggml_compute_forward_rope_f32(
                 const float cos_block_theta = cosf(block_theta);
                 const float sin_block_theta = sinf(block_theta) * sin_sign;
 
-                theta_base
+                theta_base  *= theta_scale;
                 block_theta *= theta_scale;
 
                 const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14240,28 +14450,22 @@ static void ggml_compute_forward_rope_f32(
                     dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
                 }
             } else {
-                //
-                // it seems we have to rope just the first n_dims elements and do nothing with the rest
-                // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
-                theta_base *= freq_scale;
+                // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
                 for (int64_t ic = 0; ic < ne0; ic += 2) {
                     if (ic < n_dims) {
-                        const int64_t
+                        const int64_t i0 = ic/2;
 
-
-                        float cur_rot = inv_ndims * ic - ib;
+                        const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
 
                         float cos_theta, sin_theta;
                         rope_yarn(
-                            theta_base, freq_scale, corr_dims,
+                            theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
                             &cos_theta, &sin_theta
                         );
-                        sin_theta *= sin_sign;
 
+                        sin_theta *= sin_sign;
                         theta_base *= theta_scale;
 
-                        const int64_t i0 = ib*n_dims + ic/2;
-
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                         float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
@@ -14286,6 +14490,7 @@ static void ggml_compute_forward_rope_f32(
         }
     }
 }
+// TODO: deduplicate f16/f32 code
 static void ggml_compute_forward_rope_f16(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst,
@@ -14293,6 +14498,7 @@ static void ggml_compute_forward_rope_f16(
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src2 = dst->src[2];
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
@@ -14338,13 +14544,24 @@ static void ggml_compute_forward_rope_f16(
     int ir = 0;
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
+
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    const float * freq_factors = NULL;
+    if (is_neox) {
+        if (src2 != NULL) {
+            GGML_ASSERT(src2->type == GGML_TYPE_F32);
+            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+            freq_factors = (const float *) src2->data;
+        }
+    } else {
+        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+    }
+
     // backward process uses inverse rotation by cos and sin.
     // cos and sin build a rotation matrix, where the inverse is the transpose.
     // this essentially just switches the sign of sin.
@@ -14376,7 +14593,7 @@ static void ggml_compute_forward_rope_f16(
                 const float cos_block_theta = cosf(block_theta);
                 const float sin_block_theta = sinf(block_theta) * sin_sign;
 
-                theta_base
+                theta_base  *= theta_scale;
                 block_theta *= theta_scale;
 
                 const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14407,28 +14624,22 @@ static void ggml_compute_forward_rope_f16(
                     dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                 }
             } else {
-                //
-                // it seems we have to rope just the first n_dims elements and do nothing with the rest
-                // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
-                theta_base *= freq_scale;
+                // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
                 for (int64_t ic = 0; ic < ne0; ic += 2) {
                     if (ic < n_dims) {
-                        const int64_t
+                        const int64_t i0 = ic/2;
 
-
-                        float cur_rot = inv_ndims * ic - ib;
+                        const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
 
                         float cos_theta, sin_theta;
                         rope_yarn(
-                            theta_base, freq_scale, corr_dims,
+                            theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
                             &cos_theta, &sin_theta
                         );
-                        sin_theta *= sin_sign;
 
+                        sin_theta *= sin_sign;
                         theta_base *= theta_scale;
 
-                        const int64_t i0 = ib*n_dims + ic/2;
-
                         const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                         ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
@@ -15458,400 +15669,6 @@ static void ggml_compute_forward_argsort(
     }
 }
 
-// ggml_compute_forward_flash_attn
-
-static void ggml_compute_forward_flash_attn_f32(
-        const struct ggml_compute_params * params,
-        const bool masked,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-    const struct ggml_tensor * k = dst->src[1];
-    const struct ggml_tensor * v = dst->src[2];
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
-    GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
-    GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
-    GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
-    GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
-    GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
-    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
-    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = neq0;
-    const int64_t N = neq1;
-    const int64_t P = nek1 - N;
-    const int64_t M = P + N;
-
-    const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
-
-    GGML_ASSERT(ne0 == D);
-    GGML_ASSERT(ne1 == N);
-    GGML_ASSERT(P >= 0);
-
-    GGML_ASSERT(nbq0 == sizeof(float));
-    GGML_ASSERT(nbk0 == sizeof(float));
-    GGML_ASSERT(nbv0 == sizeof(float));
-
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev1 == D);
-
-    GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nek1 == N + P);
-    GGML_ASSERT(nev1 == D);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    // parallelize by q rows using ggml_vec_dot_f32
-
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    const float scale = 1.0f/sqrtf(D);
-
-    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
-
-        float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32);
-
-        for (int i = M; i < Mup; ++i) {
-            S[i] = -INFINITY;
-        }
-
-        const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
-        for (int64_t ic = 0; ic < masked_begin; ++ic) {
-            // k indices
-            const int ik3 = iq3;
-            const int ik2 = iq2 % nek2;
-            const int ik1 = ic;
-
-            // S indices
-            const int i1 = ik1;
-
-            ggml_vec_dot_f32(neq0,
-                    S + i1, 0,
-                    (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                    (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
-        }
-
-        // scale
-        ggml_vec_scale_f32(masked_begin, S, scale);
-
-        for (int64_t i = masked_begin; i < M; i++) {
-            S[i] = -INFINITY;
-        }
-
-        // softmax
-        // exclude known -INF S[..] values from max and loop
-        // dont forget to set their SW values to zero
-        {
-            float max = -INFINITY;
-            ggml_vec_max_f32(masked_begin, &max, S);
-
-            ggml_float sum = 0.0;
-            {
-#ifdef GGML_SOFT_MAX_ACCELERATE
-                max = -max;
-                vDSP_vsadd(S, 1, &max, S, 1, Mup);
-                vvexpf(S, S, &Mup);
-                ggml_vec_sum_f32(Mup, &sum, S);
-#else
-                sum = ggml_vec_soft_max_f32(Mup, S, S, max);
-#endif
-            }
-
-            assert(sum > 0.0);
-
-            sum = 1.0/sum;
-            ggml_vec_scale_f32(masked_begin, S, sum);
-
-#ifndef NDEBUG
-            for (int i = 0; i < masked_begin; ++i) {
-                assert(!isnan(S[i]));
-                assert(!isinf(S[i]));
-            }
-#endif
-        }
-
-        for (int64_t ic = 0; ic < nev1; ++ic) {
-            // dst indices
-            const int i1 = iq1;
-            const int i2 = iq2;
-            const int i3 = iq3;
-
-            // v indices
-            const int iv2 = iq2 % nev2;
-            const int iv3 = iq3;
-
-            ggml_vec_dot_f32(masked_begin,
-                    (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
-                    (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
-                    S, 0, 1);
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_attn_f16(
-        const struct ggml_compute_params * params,
-        const bool masked,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-    const struct ggml_tensor * k = dst->src[1];
-    const struct ggml_tensor * v = dst->src[2];
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
-    GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
-    GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
-    GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
-    GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
-    GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
-    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
-    GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = neq0;
-    const int64_t N = neq1;
-    const int64_t P = nek1 - N;
-    const int64_t M = P + N;
-
-    const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
-
-    GGML_ASSERT(ne0 == D);
-    GGML_ASSERT(ne1 == N);
-    GGML_ASSERT(P >= 0);
-
-    GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
-
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev1 == D);
-
-    GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nek1 == N + P);
-    GGML_ASSERT(nev1 == D);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    // parallelize by q rows using ggml_vec_dot_f32
-
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    const float scale = 1.0f/sqrtf(D);
-
-    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
-
-        float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
-
-        for (int i = M; i < Mup; ++i) {
-            S[i] = -INFINITY;
-        }
-
-        if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
-            for (int64_t ic = 0; ic < nek1; ++ic) {
-                // k indices
-                const int ik3 = iq3;
-                const int ik2 = iq2 % nek2;
-                const int ik1 = ic;
-
-                // S indices
-                const int i1 = ik1;
-
-                ggml_vec_dot_f16(neq0,
-                        S + i1, 0,
-                        (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
-            }
-        } else {
-            for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
-                // k indices
-                const int ik3 = iq3;
-                const int ik2 = iq2 % nek2;
-                const int ik1 = ic;
-
-                // S indices
-                const int i1 = ik1;
-
-                ggml_vec_dot_f16_unroll(neq0, nbk1,
-                        S + i1,
-                        ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                        (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
-            }
-        }
-
-        // scale
-        ggml_vec_scale_f32(nek1, S, scale);
-
-        if (masked) {
-            for (int64_t i = P; i < M; i++) {
-                if (i > P + iq1) {
-                    S[i] = -INFINITY;
-                }
-            }
-        }
-
-        // softmax
-        // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
-        // dont forget to set their S values to zero
-        {
-            float max = -INFINITY;
-            ggml_vec_max_f32(M, &max, S);
-
-            ggml_float sum = 0.0;
-            {
-#ifdef GGML_SOFT_MAX_ACCELERATE
-                max = -max;
-                vDSP_vsadd(S, 1, &max, S, 1, Mup);
-                vvexpf(S, S, &Mup);
-                ggml_vec_sum_f32(Mup, &sum, S);
-#else
-                sum = ggml_vec_soft_max_f32(Mup, S, S, max);
-#endif
-            }
-
-            assert(sum > 0.0);
-
-            sum = 1.0/sum;
-            ggml_vec_scale_f32(M, S, sum);
-
-#ifndef NDEBUG
-            for (int i = 0; i < M; ++i) {
-                assert(!isnan(S[i]));
-                assert(!isinf(S[i]));
-            }
-#endif
-        }
-
-        ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
-
-        for (int64_t i = 0; i < M; i++) {
-            S16[i] = GGML_FP32_TO_FP16(S[i]);
-        }
-
-        // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
-        if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
-            for (int64_t ic = 0; ic < nev1; ++ic) {
-                // dst indices
-                const int i1 = iq1;
-                const int i2 = iq2;
-                const int i3 = iq3;
-
-                // v indices
-                const int iv2 = iq2 % nev2;
-                const int iv3 = iq3;
-
-                ggml_vec_dot_f16(nev0,
-                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
-                        (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
-                        S16, 0, 1);
-            }
-        } else {
-            for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
-                // dst indices
-                const int i1 = iq1;
-                const int i2 = iq2;
-                const int i3 = iq3;
-
-                // v indices
-                const int iv2 = iq2 % nev2;
-                const int iv3 = iq3;
-
-                ggml_vec_dot_f16_unroll(nev0, nbv1,
-                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
-                        ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
-                        S16);
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_attn(
-        const struct ggml_compute_params * params,
-        const bool masked,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * q = dst->src[0];
-
-    switch (q->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_flash_attn_f16(params, masked, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_flash_attn_f32(params, masked, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-        }
-}
-
 // ggml_compute_forward_flash_attn_ext
 
 static void ggml_compute_forward_flash_attn_ext_f16(
@@ -15882,9 +15699,10 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|
15882
15699
|
GGML_ASSERT(ne0 == D);
|
15883
15700
|
GGML_ASSERT(ne2 == N);
|
15884
15701
|
|
15885
|
-
|
15886
|
-
GGML_ASSERT(
|
15887
|
-
GGML_ASSERT(
|
15702
|
+
// input tensor rows must be contiguous
|
15703
|
+
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
|
15704
|
+
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
|
15705
|
+
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
|
15888
15706
|
|
15889
15707
|
GGML_ASSERT(neq0 == D);
|
15890
15708
|
GGML_ASSERT(nek0 == D);
|
@@ -15938,6 +15756,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    enum ggml_type    const k_vec_dot_type = type_traits[k->type].vec_dot_type;
+    ggml_from_float_t const q_to_vec_dot   = type_traits[k_vec_dot_type].from_float;
+    ggml_vec_dot_t    const kq_vec_dot     = type_traits[k->type].vec_dot;
+    ggml_to_float_t   const v_to_float     = type_traits[v->type].to_float;
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
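The four type_traits lookups added above resolve, once per call, the conversion and dot-product kernels that match the K and V storage types, so the inner loop below stays type-agnostic. Here is a minimal sketch of the same table-driven dispatch pattern; the traits struct, table, and kernels are illustrative, not the actual ggml type_traits entries.

    #include <stdio.h>

    typedef void (*from_float_t)(const float * x, void * y, int n);
    typedef void (*vec_dot_t)(int n, float * s, const void * x, const void * y);

    /* Per-type kernel bundle, analogous in spirit to ggml's type_traits[]. */
    struct traits {
        from_float_t from_float; /* convert an fp32 row into this type */
        vec_dot_t    vec_dot;    /* dot product of two rows of this type */
    };

    static void from_float_f32(const float * x, void * y, int n) {
        float * dst = (float *) y;
        for (int i = 0; i < n; ++i) dst[i] = x[i];
    }

    static void vec_dot_f32(int n, float * s, const void * x, const void * y) {
        const float * a = (const float *) x;
        const float * b = (const float *) y;
        float sum = 0.0f;
        for (int i = 0; i < n; ++i) sum += a[i]*b[i];
        *s = sum;
    }

    static const struct traits table[1] = {
        { from_float_f32, vec_dot_f32 }, /* index 0: plain fp32 */
    };

    int main(void) {
        const float q[4] = {1, 2, 3, 4};
        const float k[4] = {4, 3, 2, 1};
        float q_conv[4], s;
        const struct traits * tr = &table[0]; /* pick kernels once, reuse per row */
        tr->from_float(q, q_conv, 4);
        tr->vec_dot(4, &s, k, q_conv);
        printf("KQ dot = %f\n", s); /* 20.0 */
        return 0;
    }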
@@ -15945,17 +15768,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
 
-        const uint32_t h = iq2; // head
+        const uint32_t h = iq2; // head index
         const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
 
-        float S = 0.0f;
-        float M = -INFINITY;
+        float S = 0.0f;      // sum
+        float M = -INFINITY; // maximum KQ value
 
-        float *
-
-        ggml_fp16_t *
+        float       * VKQ32 = (float       *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   =                 (VKQ32 + 1*D); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
 
-
+        if (v->type == GGML_TYPE_F16) {
+            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+        } else {
+            memset(VKQ32, 0, D*sizeof(float));
+        }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
 
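The four pointers carved out of params->wdata above share one per-thread block of 3*D + CACHE_LINE_SIZE_F32 floats: VKQ32 takes the first D floats, V32 and VKQ16 alias the next D floats (only one of the two is live, depending on whether V is F16), and Q_q starts at offset 2*D. A small sketch of the same carving, with illustrative names and uint16_t standing in for ggml_fp16_t:

    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    #define CACHE_LINE_SIZE_F32 16 /* illustrative value; the real one depends on the build */

    int main(void) {
        const int D   = 128; /* head size */
        const int ith = 2;   /* this thread's index */

        /* one contiguous work buffer shared by all threads (8 threads here) */
        static float wdata[8*(3*128 + CACHE_LINE_SIZE_F32)];

        float    * VKQ32 = wdata + ith*(3*D + CACHE_LINE_SIZE_F32); /* fp32 accumulator, D floats     */
        float    * V32   = VKQ32 + 1*D;                             /* fp32 V row, D floats           */
        uint16_t * VKQ16 = (uint16_t *) (VKQ32 + 1*D);              /* fp16 accumulator, aliases V32  */
        uint16_t * Q_q   = (uint16_t *) (VKQ32 + 2*D);              /* converted Q row                */

        (void) VKQ16; (void) Q_q;
        memset(VKQ32, 0, D*sizeof(float));
        printf("thread %d scratch starts at float offset %td\n", ith, VKQ32 - wdata);
        printf("V32 offset within the block: %td floats\n", V32 - VKQ32); /* D */
        return 0;
    }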
@@ -15967,6 +15795,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iv3 = iq3 / rv3;
         const int iv2 = iq2 / rv2;
 
+        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+        q_to_vec_dot(pq, Q_q, D);
+
         // online softmax / attention
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
@@ -15976,52 +15807,67 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 continue;
             }
 
-            float s;
+            float s; // KQ value
 
-
-
-            const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
+            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
 
-
-                Q16[d] = GGML_FP32_TO_FP16(pq[d]);
-                }
-            }
+            s = s*scale + mv; // scale KQ value and apply mask
 
-
-                    &s, 0,
-                    (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                    Q16, 0, 1);
+            const float Mold = M;
 
-
+            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
 
-            const
+            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
 
-
-
+            if (v->type == GGML_TYPE_F16) {
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
 
-
-
-
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f16(D, VKQ16, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
 
-                // V
-
+                // V += v*expf(s - M)
+                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
             } else {
-
-
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
 
-
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f32(D, VKQ32, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
 
-
-
+                v_to_float(v_data, V32, D);
+
+                // V += v*expf(s - M)
+                ggml_vec_mad_f32(D, VKQ32, V32, vs);
+            }
 
-            S = S*ms + vs;
+            S = S*ms + vs; // scale and increment sum with partial sum
         }
 
-
-
-
+        if (v->type == GGML_TYPE_F16) {
+            for (int64_t d = 0; d < D; ++d) {
+                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
+            }
         }
 
+        // V /= S
+        const float S_inv = 1.0f/S;
+        ggml_vec_scale_f32(D, VKQ32, S_inv);
+
         // dst indices
         const int i1 = iq1;
         const int i2 = iq2;
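The rewritten loop above is the online-softmax update from the paper referenced earlier: for each new score it tracks the running maximum M, rescales the accumulator and the running sum S by expf(Mold - M) whenever the maximum moves, and never materializes the full score row. A scalar sketch of the same recurrence, independent of ggml:

    #include <math.h>
    #include <stdio.h>

    /* Streaming softmax-weighted average of values v[i] with scores s[i]:
     * numerically equivalent to sum_i softmax(s)_i * v[i], computed in one pass. */
    static float online_softmax_weighted_sum(const float * s, const float * v, int n) {
        float M   = -INFINITY; /* running maximum score               */
        float S   = 0.0f;      /* running sum of expf(s - M)          */
        float acc = 0.0f;      /* running weighted sum, scaled like S */

        for (int i = 0; i < n; ++i) {
            float ms = 1.0f;          /* rescale factor applied when the max moves */
            float vs = 1.0f;          /* expf(s[i] - M) for the current term       */
            if (s[i] > M) {
                const float Mold = M;
                M  = s[i];
                ms = expf(Mold - M);  /* shrink everything accumulated so far */
                acc *= ms;
            } else {
                vs = expf(s[i] - M);
            }
            acc += vs*v[i];
            S    = S*ms + vs;
        }
        return acc/S; /* final normalization, like the V /= S step above */
    }

    int main(void) {
        const float s[4] = {0.1f, 2.0f, -1.0f, 0.5f};
        const float v[4] = {1.0f, 2.0f,  3.0f, 4.0f};
        printf("weighted sum = %f\n", online_softmax_weighted_sum(s, v, 4));
        return 0;
    }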
@@ -16031,7 +15877,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
 
         // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1,
+        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
     }
 }
 
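The fixed memcpy writes each D-sized result row at row index i2 + i1*ne1 + i3*ne1*ne2, i.e. with the head index varying fastest, which is the permute(0, 2, 1, 3) layout the comment refers to. A small worked example of that index arithmetic, with illustrative sizes:

    #include <stdio.h>

    int main(void) {
        /* illustrative sizes */
        const long ne1 = 8;                /* heads                         */
        const long ne2 = 4;                /* query positions               */
        const long nb1 = 64*sizeof(float); /* bytes per output row (D = 64) */

        const long i1 = 3; /* query position */
        const long i2 = 5; /* head           */
        const long i3 = 0; /* batch          */

        /* row index used by the memcpy above: heads are the fastest-varying index */
        const long row    = i3*ne2*ne1 + i2 + i1*ne1;
        const long offset = row*nb1;

        printf("row %ld, byte offset %ld\n", row, offset); /* row 29, offset 7424 */
        return 0;
    }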
@@ -16056,165 +15902,6 @@ static void ggml_compute_forward_flash_attn_ext(
     }
 }
 
-// ggml_compute_forward_flash_ff
-
-static void ggml_compute_forward_flash_ff_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * a  = dst->src[0]; // F16
-    const struct ggml_tensor * b0 = dst->src[1]; // F16 fc_w
-    const struct ggml_tensor * b1 = dst->src[2]; // F32 fc_b
-    const struct ggml_tensor * c0 = dst->src[3]; // F16 proj_w
-    const struct ggml_tensor * c1 = dst->src[4]; // F32 proj_b
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_LOCALS(int64_t, nea,  a,   ne)
-    GGML_TENSOR_LOCALS(size_t,  nba,  a,   nb)
-    GGML_TENSOR_LOCALS(int64_t, neb0, b0,  ne)
-    GGML_TENSOR_LOCALS(size_t,  nbb0, b0,  nb)
-    GGML_TENSOR_LOCALS(int64_t, neb1, b1,  ne)
-    GGML_TENSOR_LOCALS(size_t,  nbb1, b1,  nb)
-    GGML_TENSOR_LOCALS(int64_t, nec0, c0,  ne)
-    GGML_TENSOR_LOCALS(size_t,  nbc0, c0,  nb)
-    GGML_TENSOR_LOCALS(int64_t, nec1, c1,  ne)
-    GGML_TENSOR_LOCALS(size_t,  nbc1, c1,  nb)
-    GGML_TENSOR_LOCALS(int64_t, ne,   dst, ne)
-    GGML_TENSOR_LOCALS(size_t,  nb,   dst, nb)
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int64_t D = nea0;
-    //const int64_t N = nea1;
-    const int64_t M = neb01;
-
-    GGML_ASSERT(ne0 == nea0);
-    GGML_ASSERT(ne1 == nea1);
-    GGML_ASSERT(ne2 == nea2);
-
-    GGML_ASSERT(nba0  == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbb10 == sizeof(float));
-    GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbc10 == sizeof(float));
-
-    GGML_ASSERT(neb00 == D);
-    GGML_ASSERT(neb01 == M);
-    GGML_ASSERT(neb10 == M);
-    GGML_ASSERT(neb11 == 1);
-
-    GGML_ASSERT(nec00 == M);
-    GGML_ASSERT(nec01 == D);
-    GGML_ASSERT(nec10 == D);
-    GGML_ASSERT(nec11 == 1);
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    if (params->type == GGML_TASK_TYPE_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    // parallelize by a rows using ggml_vec_dot_f32
-
-    // total rows in a
-    const int nr = nea1*nea2*nea3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // a indices
-        const int ia3 = ir/(nea2*nea1);
-        const int ia2 = (ir - ia3*nea2*nea1)/nea1;
-        const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
-
-        float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
-
-        for (int64_t ic = 0; ic < neb01; ++ic) {
-            // b0 indices
-            const int ib03 = ia3;
-            const int ib02 = ia2;
-            const int ib01 = ic;
-
-            // S indices
-            const int i1 = ib01;
-
-            ggml_vec_dot_f16(nea0,
-                    S + i1, 0,
-                    (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
-                    (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)), 0, 1);
-        }
-
-        ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
-        //ggml_vec_gelu_f32(neb01, S, S);
-
-        ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
-
-        for (int64_t i = 0; i < M; i++) {
-            S16[i] = GGML_FP32_TO_FP16(S[i]);
-        }
-
-        ggml_vec_gelu_f16(neb01, S16, S16);
-
-        {
-            // dst indices
-            const int i1 = ia1;
-            const int i2 = ia2;
-            const int i3 = ia3;
-
-            for (int64_t ic = 0; ic < nec01; ++ic) {
-
-                ggml_vec_dot_f16(neb01,
-                        (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
-                        (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
-                        S16, 0, 1);
-            }
-
-            ggml_vec_add_f32(nec01,
-                    (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
-                    (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
-                    (float *) c1->data);
-        }
-    }
-}
-
-static void ggml_compute_forward_flash_ff(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * b0 = dst->src[1];
-
-    switch (b0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_flash_ff_f16(params, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                GGML_ASSERT(false); // TODO
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
 // ggml_compute_forward_flash_attn_back
 
 static void ggml_compute_forward_flash_attn_back_f32(
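The removed GGML_OP_FLASH_FF kernel fused a small feed-forward block: x·fc_w + fc_b, GELU, then ·proj_w + proj_b. The same computation can be built from ordinary graph ops, which is presumably why the fused path could be dropped. Below is a hedged sketch using the generic ggml graph API; tensor names and shapes are illustrative, and this is the usual mul_mat/add/gelu composition rather than the removed kernel itself.

    #include "ggml.h"

    /* Sketch only: builds the feed-forward block out of regular ggml ops.
     * Assumes a valid ggml_context and weight/bias tensors with compatible shapes. */
    static struct ggml_tensor * ffn_block(
            struct ggml_context * ctx,
            struct ggml_tensor  * x,        /* input rows                */
            struct ggml_tensor  * fc_w,     /* first projection weight   */
            struct ggml_tensor  * fc_b,     /* first projection bias     */
            struct ggml_tensor  * proj_w,   /* second projection weight  */
            struct ggml_tensor  * proj_b) { /* second projection bias    */
        struct ggml_tensor * cur = ggml_mul_mat(ctx, fc_w, x); /* x * fc_w  */
        cur = ggml_add (ctx, cur, fc_b);                       /* + fc_b    */
        cur = ggml_gelu(ctx, cur);                             /* GELU      */
        cur = ggml_mul_mat(ctx, proj_w, cur);                  /* * proj_w  */
        cur = ggml_add (ctx, cur, proj_b);                     /* + proj_b  */
        return cur;
    }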
@@ -17785,21 +17472,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_leaky_relu(params, tensor);
             } break;
-        case GGML_OP_FLASH_ATTN:
-            {
-                const int32_t t = ggml_get_op_params_i32(tensor, 0);
-                GGML_ASSERT(t == 0 || t == 1);
-                const bool masked = t != 0;
-                ggml_compute_forward_flash_attn(params, masked, tensor);
-            } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
             } break;
-        case GGML_OP_FLASH_FF:
-            {
-                ggml_compute_forward_flash_ff(params, tensor);
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
                 int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -18169,6 +17845,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg
 static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
     struct ggml_tensor * src0 = tensor->src[0];
     struct ggml_tensor * src1 = tensor->src[1];
+    struct ggml_tensor * src2 = tensor->src[2];
 
     switch (tensor->op) {
         case GGML_OP_DUP:
@@ -18700,6 +18377,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_rope_back(ctx,
                                 tensor->grad,
                                 src1,
+                                src2,
                                 n_dims,
                                 mode,
                                 n_ctx,
@@ -18739,6 +18417,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_rope_impl(ctx,
                                 tensor->grad,
                                 src1,
+                                src2,
                                 n_dims,
                                 mode,
                                 n_ctx,
@@ -18803,7 +18482,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_FLASH_ATTN:
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 struct ggml_tensor * flash_grad = NULL;
@@ -18820,7 +18498,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             masked);
                 }
 
-                struct ggml_tensor * src2 = tensor->src[2];
                 const int64_t elem_q = ggml_nelements(src0);
                 const int64_t elem_k = ggml_nelements(src1);
                 const int64_t elem_v = ggml_nelements(src2);
@@ -18858,10 +18535,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             zero_table);
                 }
             } break;
-        case GGML_OP_FLASH_FF:
-            {
-                GGML_ASSERT(false); // not supported
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
                 GGML_ASSERT(false); // not supported
@@ -19548,15 +19221,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
             {
                 n_tasks = n_threads;
             } break;
-        case GGML_OP_FLASH_ATTN:
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 n_tasks = n_threads;
             } break;
-        case GGML_OP_FLASH_FF:
-            {
-                n_tasks = n_threads;
-            } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
                 n_tasks = n_threads;
@@ -19953,39 +19621,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
                     cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
                 } break;
-            case GGML_OP_FLASH_ATTN:
-                {
-                    const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
-
-                    if (node->src[1]->type == GGML_TYPE_F32) {
-                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_F16) {
-                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_BF16) {
-                        cur  = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
-                    }
-                } break;
             case GGML_OP_FLASH_ATTN_EXT:
                 {
                     const int64_t ne00 = node->src[0]->ne[0]; // D
 
-                    cur =
-                } break;
-            case GGML_OP_FLASH_FF:
-                {
-                    if (node->src[1]->type == GGML_TYPE_F32) {
-                        cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_F16) {
-                        cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
-                    } else if (node->src[1]->type == GGML_TYPE_BF16) {
-                        cur  = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
-                        cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
-                    }
+                    cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
                 } break;
             case GGML_OP_FLASH_ATTN_BACK:
                 {
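The new GGML_OP_FLASH_ATTN_EXT work-size formula reserves three head-sized FP32 rows per task, matching the VKQ32 / V32+VKQ16 / Q_q scratch layout used in the kernel earlier in this file (leaving cache-line padding aside). A quick worked example of the arithmetic, with illustrative sizes:

    #include <stdio.h>

    int main(void) {
        const long D       = 128; /* head size (ne00)            */
        const long n_tasks = 8;   /* threads working on the op   */

        const long cur = 3*sizeof(float)*D*n_tasks; /* bytes of scratch reserved */

        printf("flash_attn_ext scratch: %ld bytes (%.1f KiB)\n", cur, cur/1024.0);
        /* 3 * 4 * 128 * 8 = 12288 bytes = 12.0 KiB */
        return 0;
    }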
@@ -21827,11 +21467,7 @@ size_t ggml_quantize_chunk(
         case GGML_TYPE_IQ1_S:  result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ1_M:  result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-#if QK_K == 64
-        case GGML_TYPE_IQ4_XS: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-#else
         case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-#endif
         case GGML_TYPE_F16:
             {
                 size_t elemsize = sizeof(ggml_fp16_t);
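With the QK_K == 64 special case gone, GGML_TYPE_IQ4_XS always goes through quantize_iq4_xs here. A hedged usage sketch of the surrounding entry point follows; the exact ggml_quantize_chunk signature is an assumption inferred from the call sites in this hunk, so check the ggml.h shipped in this package before relying on it.

    #include "ggml.h"

    /* Sketch only: quantize `nrows` rows of fp32 data to IQ4_XS.
     * The parameter order below is assumed, not guaranteed. */
    static size_t quantize_rows_iq4_xs(const float * src, void * dst,
                                       int64_t nrows, int64_t n_per_row) {
        /* no importance matrix: pass NULL; start quantizing at row 0 */
        return ggml_quantize_chunk(GGML_TYPE_IQ4_XS, src, dst,
                                   /*start     =*/ 0,
                                   /*nrows     =*/ nrows,
                                   /*n_per_row =*/ n_per_row,
                                   /*imatrix   =*/ NULL);
    }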
@@ -23108,6 +22744,14 @@ int ggml_cpu_has_avx512_vnni(void) {
 #endif
 }
 
+int ggml_cpu_has_avx512_bf16(void) {
+#if defined(__AVX512BF16__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_fma(void) {
 #if defined(__FMA__)
     return 1;
@@ -23124,6 +22768,16 @@ int ggml_cpu_has_neon(void) {
 #endif
 }
 
+int ggml_cpu_has_sve(void) {
+#if defined(__ARM_FEATURE_SVE)
+    // TODO: Currently, SVE 256 bit is only supported.
+    GGML_ASSERT(svcntb() == QK8_0);
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_arm_fma(void) {
 #if defined(__ARM_FEATURE_FMA)
     return 1;
@@ -23212,6 +22866,14 @@ int ggml_cpu_has_sycl(void) {
 #endif
 }
 
+int ggml_cpu_has_rpc(void) {
+#if defined(GGML_USE_RPC)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_gpublas(void) {
     return ggml_cpu_has_cuda() || ggml_cpu_has_clblast() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() ||
            ggml_cpu_has_sycl();
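Together with ggml_cpu_has_avx512_bf16 and ggml_cpu_has_sve added above, ggml_cpu_has_rpc extends the compile-time capability report; each helper simply reflects the flags the library was built with. A small sketch that prints a few of them, assuming the declarations are visible through the ggml.h shipped in this package:

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        /* each ggml_cpu_has_* call returns 1 or 0 depending on compile-time flags */
        printf("NEON        : %d\n", ggml_cpu_has_neon());
        printf("SVE         : %d\n", ggml_cpu_has_sve());
        printf("AVX512 BF16 : %d\n", ggml_cpu_has_avx512_bf16());
        printf("RPC backend : %d\n", ggml_cpu_has_rpc());
        return 0;
    }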