llama_cpp 0.2.0 → 0.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/examples/README.md +92 -0
- data/examples/chat.rb +195 -0
- data/examples/embedding.rb +37 -0
- data/ext/llama_cpp/llama_cpp.cpp +52 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +1218 -411
- data/ext/llama_cpp/src/ggml-cuda.h +4 -1
- data/ext/llama_cpp/src/ggml-metal.h +5 -1
- data/ext/llama_cpp/src/ggml-metal.m +703 -514
- data/ext/llama_cpp/src/ggml-metal.metal +574 -122
- data/ext/llama_cpp/src/ggml-opencl.cpp +496 -36
- data/ext/llama_cpp/src/ggml-opencl.h +1 -2
- data/ext/llama_cpp/src/ggml.c +2715 -476
- data/ext/llama_cpp/src/ggml.h +266 -11
- data/ext/llama_cpp/src/llama.cpp +266 -135
- data/ext/llama_cpp/src/llama.h +19 -11
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +3 -0
- metadata +5 -2
data/ext/llama_cpp/src/ggml.c
CHANGED
@@ -35,6 +35,12 @@
 #define static_assert(cond, msg) struct global_scope_noop_trick
 #endif
 
+#if defined(_MSC_VER)
+// disable "possible loss of data" to avoid hundreds of casts
+// we should just be careful :)
+#pragma warning(disable: 4244 4267)
+#endif
+
 #if defined(_WIN32)
 
 #include <windows.h>
@@ -106,6 +112,7 @@ typedef void* thread_ret_t;
 /*#define GGML_PERF*/
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
+#define GGML_GELU_QUICK_FP16
 #define GGML_SILU_FP16
 
 #define GGML_SOFT_MAX_UNROLL 4
@@ -334,6 +341,9 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
 // precomputed gelu table for f16 (128 KB)
 static ggml_fp16_t table_gelu_f16[1 << 16];
 
+// precomputed quick gelu table for f16 (128 KB)
+static ggml_fp16_t table_gelu_quick_f16[1 << 16];
+
 // precomputed silu table for f16 (128 KB)
 static ggml_fp16_t table_silu_f16[1 << 16];
 
@@ -1671,14 +1681,17 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
 #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
 #define GGML_F32x4_REDUCE(res, x) \
 { \
-    for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
-        x[2*i] = vaddq_f32(x[2*i], x[2*i+1]); \
+    int offset = GGML_F32_ARR >> 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = vaddq_f32(x[i], x[offset+i]); \
     } \
-    for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
-        x[4*i] = vaddq_f32(x[4*i], x[4*i+2]); \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = vaddq_f32(x[i], x[offset+i]); \
     } \
-    for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
-        x[8*i] = vaddq_f32(x[8*i], x[8*i+4]); \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = vaddq_f32(x[i], x[offset+i]); \
     } \
     res = GGML_F32x4_REDUCE_ONE(x[0]); \
 }
@@ -1709,14 +1722,17 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
 #define GGML_F16x8_MUL vmulq_f16
 #define GGML_F16x8_REDUCE(res, x) \
 { \
-    for (int i = 0; i < GGML_F16_ARR/2; ++i) { \
-        x[2*i] = vaddq_f16(x[2*i], x[2*i+1]); \
+    int offset = GGML_F16_ARR >> 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = vaddq_f16(x[i], x[offset+i]); \
     } \
-    for (int i = 0; i < GGML_F16_ARR/4; ++i) { \
-        x[4*i] = vaddq_f16(x[4*i], x[4*i+2]); \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = vaddq_f16(x[i], x[offset+i]); \
     } \
-    for (int i = 0; i < GGML_F16_ARR/8; ++i) { \
-        x[8*i] = vaddq_f16(x[8*i], x[8*i+4]); \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = vaddq_f16(x[i], x[offset+i]); \
     } \
     const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
     const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
@@ -1783,14 +1799,17 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
 #define GGML_F32x8_MUL _mm256_mul_ps
 #define GGML_F32x8_REDUCE(res, x) \
 { \
-    for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
-        x[2*i] = _mm256_add_ps(x[2*i], x[2*i+1]); \
+    int offset = GGML_F32_ARR >> 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]); \
     } \
-    for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
-        x[4*i] = _mm256_add_ps(x[4*i], x[4*i+2]); \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]); \
     } \
-    for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
-        x[8*i] = _mm256_add_ps(x[8*i], x[8*i+4]); \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]); \
     } \
     const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \
                                  _mm256_extractf128_ps(x[0], 1)); \
@@ -1880,14 +1899,17 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
 #define GGML_F32x4_MUL vec_mul
 #define GGML_F32x4_REDUCE(res, x) \
 { \
-    for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
-        x[2*i] = vec_add(x[2*i], x[2*i+1]); \
+    int offset = GGML_F32_ARR >> 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = vec_add(x[i], x[offset+i]); \
     } \
-    for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
-        x[4*i] = vec_add(x[4*i], x[4*i+2]); \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = vec_add(x[i], x[offset+i]); \
     } \
-    for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
-        x[8*i] = vec_add(x[8*i], x[8*i+4]); \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = vec_add(x[i], x[offset+i]); \
     } \
     res = vec_extract(x[0], 0) + \
           vec_extract(x[0], 1) + \
@@ -1943,14 +1965,17 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
 #define GGML_F32x4_MUL wasm_f32x4_mul
 #define GGML_F32x4_REDUCE(res, x) \
 { \
-    for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
-        x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \
+    int offset = GGML_F32_ARR >> 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
     } \
-    for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
-        x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
     } \
-    for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
-        x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
     } \
     res = wasm_f32x4_extract_lane(x[0], 0) + \
           wasm_f32x4_extract_lane(x[0], 1) + \
@@ -2005,14 +2030,17 @@ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
 #define GGML_F16x4_MUL wasm_f32x4_mul
 #define GGML_F16x4_REDUCE(res, x) \
 { \
-    for (int i = 0; i < GGML_F16_ARR/2; ++i) { \
-        x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \
+    int offset = GGML_F16_ARR >> 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
     } \
-    for (int i = 0; i < GGML_F16_ARR/4; ++i) { \
-        x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
     } \
-    for (int i = 0; i < GGML_F16_ARR/8; ++i) { \
-        x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
     } \
     res = wasm_f32x4_extract_lane(x[0], 0) + \
           wasm_f32x4_extract_lane(x[0], 1) + \
@@ -2054,14 +2082,17 @@ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
 #define GGML_F32x4_MUL _mm_mul_ps
 #define GGML_F32x4_REDUCE(res, x) \
 { \
-    for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
-        x[2*i] = _mm_add_ps(x[2*i], x[2*i+1]); \
+    int offset = GGML_F32_ARR >> 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = _mm_add_ps(x[i], x[offset+i]); \
     } \
-    for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
-        x[4*i] = _mm_add_ps(x[4*i], x[4*i+2]); \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = _mm_add_ps(x[i], x[offset+i]); \
     } \
-    for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
-        x[8*i] = _mm_add_ps(x[8*i], x[8*i+4]); \
+    offset >>= 1; \
+    for (int i = 0; i < offset; ++i) { \
+        x[i] = _mm_add_ps(x[i], x[offset+i]); \
     } \
     const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \
     res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \
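Note: all seven REDUCE hunks above make the same change — each backend's unrolled pairwise sums become a halving loop over the register array. A standalone scalar sketch of that tree reduction (plain C for illustration; `N` and plain floats stand in for `GGML_F32_ARR` and the SIMD registers):

```c
#include <stdio.h>

#define N 8

int main(void) {
    float x[N] = {1, 2, 3, 4, 5, 6, 7, 8};

    // Pairwise tree reduction, mirroring the new GGML_F32x4_REDUCE shape:
    // repeatedly add the top half of the active range into the bottom half.
    int offset = N >> 1;
    while (offset > 0) {
        for (int i = 0; i < offset; ++i) {
            x[i] += x[offset + i]; // vaddq_f32 / _mm256_add_ps in the real macros
        }
        offset >>= 1;
    }

    printf("sum = %f\n", x[0]); // 36.0
    return 0;
}
```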
@@ -3350,6 +3381,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
 
 static const float GELU_COEF_A = 0.044715f;
+static const float GELU_QUICK_COEF = -1.702f;
 static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 inline static float ggml_gelu_f32(float x) {
@@ -3380,6 +3412,34 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
 }
 #endif
 
+inline static float ggml_gelu_quick_f32(float x) {
+    return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
+}
+
+//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+//    const uint16_t * i16 = (const uint16_t *) x;
+//    for (int i = 0; i < n; ++i) {
+//        y[i] = table_gelu_quick_f16[i16[i]];
+//    }
+//}
+
+#ifdef GGML_GELU_QUICK_FP16
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+        memcpy(&t, &fp16, sizeof(uint16_t));
+        y[i] = GGML_FP16_TO_FP32(table_gelu_quick_f16[t]);
+    }
+}
+#else
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_quick_f32(x[i]);
+    }
+}
+#endif
+
 // Sigmoid Linear Unit (SiLU) function
 inline static float ggml_silu_f32(float x) {
     return x/(1.0f + expf(-x));
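Note: the quick variant approximates GELU with a scaled sigmoid, x·σ(1.702x); `GELU_QUICK_COEF` is stored negated (-1.702f) so the `expf()` argument needs no extra sign flip. A small self-contained check (illustrative, not from the package):

```c
#include <math.h>
#include <stdio.h>

// Quick GELU: x * sigmoid(1.702 * x).
static float gelu_quick(float x) {
    return x * (1.0f / (1.0f + expf(-1.702f * x)));
}

int main(void) {
    for (float x = -2.0f; x <= 2.0f; x += 1.0f) {
        printf("gelu_quick(% .1f) = % .6f\n", x, gelu_quick(x));
    }
    return 0;
}
```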
@@ -3603,12 +3663,14 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "SUM_ROWS",
     "MEAN",
     "REPEAT",
+    "REPEAT_BACK",
     "ABS",
     "SGN",
     "NEG",
     "STEP",
     "RELU",
     "GELU",
+    "GELU_QUICK",
     "SILU",
     "SILU_BACK",
     "NORM",
@@ -3616,6 +3678,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "RMS_NORM_BACK",
 
     "MUL_MAT",
+    "OUT_PROD",
 
     "SCALE",
     "SET",
@@ -3631,22 +3694,29 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "DIAG_MASK_INF",
     "DIAG_MASK_ZERO",
     "SOFT_MAX",
+    "SOFT_MAX_BACK",
     "ROPE",
     "ROPE_BACK",
     "ALIBI",
     "CLAMP",
-    "CONV_1D_1S",
-    "CONV_1D_2S",
+    "CONV_1D_S1_PH",
+    "CONV_1D_S2_PH",
+    "CONV_2D_SK_P0",
 
     "FLASH_ATTN",
     "FLASH_FF",
+    "FLASH_ATTN_BACK",
+    "WIN_PART",
+    "WIN_UNPART",
 
     "MAP_UNARY",
     "MAP_BINARY",
-};
 
-static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
+    "CROSS_ENTROPY_LOSS",
+    "CROSS_ENTROPY_LOSS_BACK",
+};
 
+static_assert(GGML_OP_COUNT == 61, "GGML_OP_COUNT != 61");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3665,18 +3735,21 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "Σx_k",
     "Σx/n",
     "repeat(x)",
+    "repeat_back(x)",
     "abs(x)",
     "sgn(x)",
     "-x",
     "step(x)",
     "relu(x)",
     "gelu(x)",
+    "gelu_quick(x)",
     "silu(x)",
     "silu_back(x)",
     "norm(x)",
     "rms_norm(x)",
     "rms_norm_back(x)",
 
+    "X*Y",
     "X*Y",
 
     "x*v",
@@ -3693,21 +3766,29 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "diag_mask_inf(x)",
     "diag_mask_zero(x)",
     "soft_max(x)",
+    "soft_max_back(x)",
     "rope(x)",
     "rope_back(x)",
     "alibi(x)",
     "clamp(x)",
-    "conv_1d_1s(x)",
-    "conv_1d_2s(x)",
+    "conv_1d_s1_ph(x)",
+    "conv_1d_s2_ph(x)",
+    "conv_2d_sk_p0(x)",
 
     "flash_attn(x)",
     "flash_ff(x)",
+    "flash_attn_back(x)",
+    "win_part(x)",
+    "win_unpart(x)",
 
     "f(x)",
     "f(x,y)",
+
+    "cross_entropy_loss(x,y)",
+    "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
+static_assert(GGML_OP_COUNT == 61, "GGML_OP_COUNT != 61");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3870,6 +3951,15 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
         (t0->ne[3] == t1->ne[3]);
 }
 
+static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        (t0->ne[1] == t1->ne[1]) &&
+        (t0->ne[2] == t1->ne[2]) &&
+        (t0->ne[3] == t1->ne[3]);
+}
+
 bool ggml_is_quantized(enum ggml_type type) {
     return GGML_IS_QUANTIZED[type];
 }
@@ -3917,6 +4007,12 @@ bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
+bool ggml_is_permuted(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
+}
+
 static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
@@ -3983,7 +4079,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
     // initialize time system (required on Windows)
     ggml_time_init();
 
-    // initialize GELU, SILU and EXP F32 tables
+    // initialize GELU, Quick GELU, SILU and EXP F32 tables
     {
         const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
 
@@ -3993,13 +4089,14 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             memcpy(&ii, &ui, sizeof(ii));
             const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
             table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
+            table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
             table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
             table_exp_f16[i]  = GGML_FP32_TO_FP16(expf(f));
         }
 
         const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
 
-        GGML_PRINT_DEBUG("%s: GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
+        GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
     }
 
     // initialize g_state
@@ -4120,14 +4217,34 @@ void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) {
     ctx->no_alloc = no_alloc;
 }
 
-void * ggml_get_mem_buffer(struct ggml_context * ctx) {
+void * ggml_get_mem_buffer(const struct ggml_context * ctx) {
     return ctx->mem_buffer;
 }
 
-size_t ggml_get_mem_size(struct ggml_context * ctx) {
+size_t ggml_get_mem_size(const struct ggml_context * ctx) {
     return ctx->mem_size;
 }
 
+size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) {
+    size_t max_size = 0;
+
+    struct ggml_object * obj = ctx->objects_begin;
+
+    while (obj != NULL) {
+        struct ggml_tensor * tensor = (struct ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs);
+
+        const size_t size = ggml_nbytes(tensor);
+
+        if (max_size < size) {
+            max_size = size;
+        }
+
+        obj = obj->next;
+    }
+
+    return max_size;
+}
+
 // IMPORTANT:
 // when creating "opt" tensors, always save and load the scratch buffer
 // this is an error prone process, but it is necessary to support inplace
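Note: a hedged usage sketch for the new accessor — assuming only calls visible in this diff plus the long-standing `ggml_init`/`ggml_free`/`ggml_new_tensor_*` API — for reporting the largest tensor allocated in a context:

```c
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        .mem_size   = 16*1024*1024, // 16 MB arena
        .mem_buffer = NULL,
        .no_alloc   = false,
    };
    struct ggml_context * ctx = ggml_init(params);

    ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
    ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);

    // New in this release: walk the context's object list and report the
    // byte size of the largest tensor (here the 64x64 F32: 16384 bytes).
    printf("max tensor size: %zu bytes\n", ggml_get_max_tensor_size(ctx));

    ggml_free(ctx);
    return 0;
}
```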
@@ -4611,9 +4728,10 @@ const char * ggml_get_name(const struct ggml_tensor * tensor) {
     return tensor->name;
 }
 
-void ggml_set_name(struct ggml_tensor * tensor, const char * name) {
+struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name) {
     strncpy(tensor->name, name, sizeof(tensor->name));
     tensor->name[sizeof(tensor->name) - 1] = '\0';
+    return tensor;
 }
 
 struct ggml_tensor * ggml_view_tensor(
@@ -4693,7 +4811,7 @@ struct ggml_tensor * ggml_add_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -4733,7 +4851,7 @@ struct ggml_tensor * ggml_add1_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5159,6 +5277,34 @@ struct ggml_tensor * ggml_repeat(
     return result;
 }
 
+// ggml_repeat_back
+
+struct ggml_tensor * ggml_repeat_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_can_repeat(b, a));
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    if (ggml_are_same_shape(a, b) && !is_node) {
+        return a;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
+
+    result->op   = GGML_OP_REPEAT_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_abs
 
 struct ggml_tensor * ggml_abs_impl(
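Note: `ggml_repeat_back` is the adjoint of `ggml_repeat` — where repeat tiles a small tensor up to a larger shape, repeat_back sums a large gradient back down to the small shape. A scalar sketch of that reduction (`repeat_back_1d` is a hypothetical helper, not part of the API):

```c
// Sum a gradient of length n*k (k tiled copies) back into a buffer of
// length n — the scalar essence of GGML_OP_REPEAT_BACK.
static void repeat_back_1d(int n, int k, const float * grad, float * out) {
    for (int i = 0; i < n; ++i) {
        out[i] = 0.0f;
    }
    for (int j = 0; j < k; ++j) {
        for (int i = 0; i < n; ++i) {
            out[i] += grad[j*n + i];
        }
    }
}
```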
@@ -5364,6 +5510,40 @@ struct ggml_tensor * ggml_gelu_inplace(
     return ggml_gelu_impl(ctx, a, true);
 }
 
+// ggml_gelu_quick
+
+struct ggml_tensor * ggml_gelu_quick_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_GELU_QUICK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_gelu_quick(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_gelu_quick_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_gelu_quick_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_gelu_quick_impl(ctx, a, true);
+}
+
 // ggml_silu
 
 struct ggml_tensor * ggml_silu_impl(
@@ -5536,6 +5716,32 @@ struct ggml_tensor * ggml_mul_mat(
     return result;
 }
 
+// ggml_out_prod
+
+struct ggml_tensor * ggml_out_prod(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_can_out_prod(a, b));
+    GGML_ASSERT(!ggml_is_transposed(a));
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+
+    result->op   = GGML_OP_OUT_PROD;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_scale
 
 struct ggml_tensor * ggml_scale_impl(
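Note: the result shape `{ a->ne[0], b->ne[0], ... }` makes this a batched outer product over the shared ne[1] dimension (the pattern needed for the backward pass of MUL_MAT). A plain-C sketch of the 2-D f32 case, using ggml's convention that ne[0] is the fastest-varying index:

```c
// dst[i,j] = sum_k a[i,k] * b[j,k], with a: n x k, b: m x k, dst: n x m
// (column index fastest, as in ggml).
static void out_prod_2d(int n, int m, int k,
                        const float * a, const float * b, float * dst) {
    for (int j = 0; j < m; ++j) {
        for (int i = 0; i < n; ++i) {
            float sum = 0.0f;
            for (int p = 0; p < k; ++p) {
                sum += a[i + p*n] * b[j + p*m];
            }
            dst[i + j*n] = sum;
        }
    }
}
```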
@@ -5548,7 +5754,7 @@ struct ggml_tensor * ggml_scale_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5591,7 +5797,7 @@ struct ggml_tensor * ggml_set_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5913,10 +6119,6 @@ struct ggml_tensor * ggml_view_1d(
     result->src1 = NULL;
     result->opt[0] = offs;
 
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
-
     return result;
 }
 
@@ -5957,10 +6159,6 @@ struct ggml_tensor * ggml_view_2d(
     result->src1 = NULL;
     result->opt[0] = offs;
 
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
-
     return result;
 }
 
@@ -6003,10 +6201,6 @@ struct ggml_tensor * ggml_view_3d(
     result->src1 = NULL;
     result->opt[0] = offs;
 
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
-
     return result;
 }
 
@@ -6051,10 +6245,6 @@ struct ggml_tensor * ggml_view_4d(
     result->src1 = NULL;
     result->opt[0] = offs;
 
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
-
     return result;
 }
 
@@ -6116,10 +6306,18 @@ struct ggml_tensor * ggml_permute(
     result->src1 = NULL;
 
     if (is_node) {
-        result->padding[0] = axis0;
-        result->padding[1] = axis1;
-        result->padding[2] = axis2;
-        result->padding[3] = axis3;
+        ggml_scratch_save(ctx);
+
+        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
+
+        ((int32_t *) b->data)[0] = axis0;
+        ((int32_t *) b->data)[1] = axis1;
+        ((int32_t *) b->data)[2] = axis2;
+        ((int32_t *) b->data)[3] = axis3;
+
+        ggml_scratch_load(ctx);
+
+        result->opt[0] = b;
     }
 
     return result;
@@ -6359,6 +6557,44 @@ struct ggml_tensor * ggml_soft_max_inplace(
     return ggml_soft_max_impl(ctx, a, true);
 }
 
+
+// ggml_soft_max_back
+
+struct ggml_tensor * ggml_soft_max_back_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        bool inplace) {
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true; // TODO : implement backward pass
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SOFT_MAX_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_soft_max_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_soft_max_back_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, true);
+}
+
 // ggml_rope
 
 struct ggml_tensor * ggml_rope_impl(
@@ -6371,7 +6607,7 @@ struct ggml_tensor * ggml_rope_impl(
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
 
-    if (!inplace && a->grad) {
+    if (a->grad) {
         is_node = true;
     }
 
@@ -6425,8 +6661,7 @@ struct ggml_tensor * ggml_rope_back(
     bool is_node = false;
 
     if (a->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
+        is_node = false; // TODO: implement backward
     }
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
@@ -6508,7 +6743,7 @@ struct ggml_tensor * ggml_clamp(
 
     ggml_scratch_save(ctx);
 
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2);
 
     ((float *) b->data)[0] = min;
     ((float *) b->data)[1] = max;
@@ -6523,9 +6758,9 @@ struct ggml_tensor * ggml_clamp(
     return result;
 }
 
-// ggml_conv_1d_1s
+// ggml_conv_1d_s1_ph
 
-struct ggml_tensor * ggml_conv_1d_1s(
+struct ggml_tensor * ggml_conv_1d_s1_ph(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b) {
@@ -6542,7 +6777,7 @@ struct ggml_tensor * ggml_conv_1d_1s(
     const int64_t ne[4] = { b->ne[0], a->ne[2], 1, 1, };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
 
-    result->op   = GGML_OP_CONV_1D_1S;
+    result->op   = GGML_OP_CONV_1D_S1_PH;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = b;
@@ -6550,9 +6785,9 @@ struct ggml_tensor * ggml_conv_1d_1s(
     return result;
 }
 
-// ggml_conv_1d_2s
+// ggml_conv_1d_s2_ph
 
-struct ggml_tensor * ggml_conv_1d_2s(
+struct ggml_tensor * ggml_conv_1d_s2_ph(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b) {
@@ -6569,7 +6804,35 @@ struct ggml_tensor * ggml_conv_1d_2s(
     const int64_t ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
 
-    result->op   = GGML_OP_CONV_1D_2S;
+    result->op   = GGML_OP_CONV_1D_S2_PH;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_conv_2d_sk_p0
+
+struct ggml_tensor * ggml_conv_2d_sk_p0(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(b->ne[3] == 1);
+    GGML_ASSERT(a->ne[2] == b->ne[2]);
+    GGML_ASSERT(b->ne[0] % a->ne[0] == 0);
+    GGML_ASSERT(b->ne[1] % a->ne[1] == 0);
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { b->ne[0]/a->ne[0], b->ne[1]/a->ne[1], a->ne[3], 1, };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op   = GGML_OP_CONV_2D_SK_P0;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = b;
@@ -6591,7 +6854,6 @@ struct ggml_tensor * ggml_flash_attn(
     bool is_node = false;
 
     if (q->grad || k->grad || v->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -6623,7 +6885,6 @@ struct ggml_tensor * ggml_flash_ff(
     bool is_node = false;
 
     if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -6641,54 +6902,202 @@ struct ggml_tensor * ggml_flash_ff(
     return result;
 }
 
-// ggml_map_unary
+// ggml_flash_attn_back
+
+struct ggml_tensor * ggml_flash_attn_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * q,
+        struct ggml_tensor * k,
+        struct ggml_tensor * v,
+        struct ggml_tensor * d,
+        bool masked) {
+    GGML_ASSERT(ggml_can_mul_mat(k, q));
+    // TODO: check if vT can be multiplied by (k*qT)
+
+    // d shape [D,N,ne2,ne3]
+    // q shape [D,N,ne2,ne3]
+    // k shape [D,M,ne2,ne3]
+    // v shape [M,D,ne2,ne3]
+
+    const int64_t D = q->ne[0];
+    const int64_t N = q->ne[1];
+    const int64_t M = k->ne[1];
+    const int64_t ne2 = q->ne[2];
+    const int64_t ne3 = q->ne[3];
+
+    GGML_ASSERT(k->ne[0] == D);
+    GGML_ASSERT(v->ne[0] == M);
+    GGML_ASSERT(v->ne[1] == D);
+    GGML_ASSERT(d->ne[0] == D);
+    GGML_ASSERT(d->ne[1] == N);
+    GGML_ASSERT(k->ne[2] == ne2);
+    GGML_ASSERT(k->ne[3] == ne3);
+    GGML_ASSERT(v->ne[2] == ne2);
+    GGML_ASSERT(v->ne[3] == ne3);
+    GGML_ASSERT(d->ne[2] == ne2);
+    GGML_ASSERT(d->ne[3] == ne3);
 
-struct ggml_tensor * ggml_map_unary_impl_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        const ggml_unary_op_f32_t fun,
-        bool inplace) {
     bool is_node = false;
 
-    if (!inplace && a->grad) {
-        is_node = true;
+    if (q->grad || k->grad || v->grad) {
+        // when using this operation (in backwards pass) these grads are set.
+        // we don't want to create (big) grad of our result, so is_node is false.
+        is_node = false;
     }
 
-    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
-    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
-    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    // store gradients of q, k and v as continuous tensors concatenated in result.
+    // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
+    // gradq->data = result->data
+    // gradk->data = result->data + nb0*D*N*ne2*ne3
+    // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
+    // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
+    int64_t ne[4] = {D,M+N+M,ne2,ne3};
 
-    result->op = GGML_OP_MAP_UNARY;
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op   = GGML_OP_FLASH_ATTN_BACK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src0 = a;
-    result->opt[0] = addr_tensor;
+    result->src0 = q;
+    result->src1 = k;
+    result->opt[0] = v;
+    result->opt[1] = d;
+    result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
 
     return result;
 }
 
-struct ggml_tensor * ggml_map_unary_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        const ggml_unary_op_f32_t fun) {
-    return ggml_map_unary_impl_f32(ctx, a, fun, false);
-}
-
-struct ggml_tensor * ggml_map_unary_inplace_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        const ggml_unary_op_f32_t fun) {
-    return ggml_map_unary_impl_f32(ctx, a, fun, true);
-}
-
-// ggml_map_binary
+// ggml_win_part
 
-struct ggml_tensor * ggml_map_binary_impl_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        const ggml_binary_op_f32_t fun,
-        bool inplace) {
+struct ggml_tensor * ggml_win_part(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        int w) {
+    GGML_ASSERT(a->ne[3] == 1);
+    GGML_ASSERT(a->type  == GGML_TYPE_F32);
+
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // padding
+    const int px = (w - a->ne[1]%w)%w;
+    const int py = (w - a->ne[2]%w)%w;
+
+    const int npx = (px + a->ne[1])/w;
+    const int npy = (py + a->ne[2])/w;
+    const int np  = npx*npy;
+
+    const int64_t ne[4] = { a->ne[0], w, w, np, };
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+
+    ((int32_t *) b->data)[0] = npx;
+    ((int32_t *) b->data)[1] = npy;
+    ((int32_t *) b->data)[2] = w;
+
+    ggml_scratch_load(ctx);
+
+    result->op   = GGML_OP_WIN_PART;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+    result->opt[0] = b;
+
+    return result;
+}
+
+// ggml_win_unpart
+
+struct ggml_tensor * ggml_win_unpart(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        int w0,
+        int h0,
+        int w) {
+    GGML_ASSERT(a->type == GGML_TYPE_F32);
+
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { a->ne[0], w0, h0, 1, };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
+
+    ((int32_t *) b->data)[0] = w;
+
+    ggml_scratch_load(ctx);
+
+    result->op   = GGML_OP_WIN_UNPART;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+    result->opt[0] = b;
+
+    return result;
+}
+
+// ggml_map_unary
+
+struct ggml_tensor * ggml_map_unary_impl_f32(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        const ggml_unary_op_f32_t fun,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op = GGML_OP_MAP_UNARY;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->opt[0] = addr_tensor;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_unary_f32(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        const ggml_unary_op_f32_t fun) {
+    return ggml_map_unary_impl_f32(ctx, a, fun, false);
+}
+
+struct ggml_tensor * ggml_map_unary_inplace_f32(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        const ggml_unary_op_f32_t fun) {
+    return ggml_map_unary_impl_f32(ctx, a, fun, true);
+}
+
+// ggml_map_binary
+
+struct ggml_tensor * ggml_map_binary_impl_f32(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        const ggml_binary_op_f32_t fun,
+        bool inplace) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
 
     bool is_node = false;
 
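Note: `ggml_win_part` splits a [C, W, H] feature map into non-overlapping w×w windows (padding W and H up to multiples of w) to give [C, w, w, npx*npy]; `ggml_win_unpart` reverses it. The shape arithmetic from the hunk, as a small standalone check:

```c
#include <stdio.h>

// Shape arithmetic of GGML_OP_WIN_PART: pad each spatial dim up to a
// multiple of the window size w, then count windows per dimension.
int main(void) {
    const int w = 7, W = 64, H = 64;   // window size and spatial dims (example values)

    const int px  = (w - W % w) % w;   // padding along ne[1]
    const int py  = (w - H % w) % w;   // padding along ne[2]
    const int npx = (px + W) / w;
    const int npy = (py + H) / w;

    // 64x64 with w=7 pads to 70x70 -> 10*10 = 100 windows
    printf("padding: %dx%d, windows: %d\n", px, py, npx * npy);
    return 0;
}
```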
@@ -6725,6 +7134,50 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
     return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
 }
 
+// ggml_cross_entropy_loss
+
+struct ggml_tensor * ggml_cross_entropy_loss(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
+
+    result->op   = GGML_OP_CROSS_ENTROPY_LOSS;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_cross_entropy_loss_back
+
+struct ggml_tensor * ggml_cross_entropy_loss_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        struct ggml_tensor * c) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_is_scalar(c));
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
+    result->grad = NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = c;
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 void ggml_set_param(
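Note: the new loss op reduces two same-shaped tensors to a single scalar. For reference, what cross-entropy between logits and a probability target computes per row, as a plain-C sketch (the forward kernel itself is not shown in this hunk, so this is a reference formula, not the package's exact implementation):

```c
#include <math.h>

// CE = -sum_i p_i * log(softmax(logits)_i), computed stably via log-sum-exp.
static float cross_entropy_row(int n, const float * logits, const float * p) {
    float max = logits[0];
    for (int i = 1; i < n; ++i) {
        if (logits[i] > max) max = logits[i]; // stabilize the exponentials
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum += expf(logits[i] - max);
    }
    const float log_z = logf(sum) + max;

    float ce = 0.0f;
    for (int i = 0; i < n; ++i) {
        ce += p[i] * (log_z - logits[i]); // -p_i * log softmax_i
    }
    return ce;
}
```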
@@ -7674,7 +8127,7 @@ static void ggml_compute_forward_add_q_f32(
 
         void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
         float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
-        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb0));
+        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));
 
         assert(ne00 % 32 == 0);
 
@@ -8875,6 +9328,99 @@ static void ggml_compute_forward_repeat(
     }
 }
 
+// ggml_compute_forward_repeat_back
+
+static void ggml_compute_forward_repeat_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_can_repeat(dst, src0));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne00/ne0);
+    const int nr1 = (int)(ne01/ne1);
+    const int nr2 = (int)(ne02/ne2);
+    const int nr3 = (int)(ne03/ne3);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (ggml_is_contiguous(dst)) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+    } else {
+        for (int k3 = 0; k3 < ne3; k3++) {
+            for (int k2 = 0; k2 < ne2; k2++) {
+                for (int k1 = 0; k1 < ne1; k1++) {
+                    ggml_vec_set_f32(ne0,
+                        (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
+                        0);
+                }
+            }
+        }
+    }
+
+    // TODO: maybe this is not optimal?
+    for (int i3 = 0; i3 < nr3; i3++) {
+        for (int k3 = 0; k3 < ne3; k3++) {
+            for (int i2 = 0; i2 < nr2; i2++) {
+                for (int k2 = 0; k2 < ne2; k2++) {
+                    for (int i1 = 0; i1 < nr1; i1++) {
+                        for (int k1 = 0; k1 < ne1; k1++) {
+                            for (int i0 = 0; i0 < nr0; i0++) {
+                                ggml_vec_acc_f32(ne0,
+                                        (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),
+                                        (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_repeat_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_repeat_back_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_abs
 
 static void ggml_compute_forward_abs_f32(
@@ -9142,8 +9688,65 @@ static void ggml_compute_forward_gelu(
             GGML_ASSERT(false);
         } break;
     }
+}
+
+// ggml_compute_forward_gelu_quick
+
+static void ggml_compute_forward_gelu_quick_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_quick_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
 
-
+static void ggml_compute_forward_gelu_quick(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gelu_quick_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
 }
 
 // ggml_compute_forward_silu
@@ -10249,9 +10852,179 @@ static void ggml_compute_forward_mul_mat(
     }
 }
 
-// ggml_compute_forward_scale
+// ggml_compute_forward_out_prod
 
-static void ggml_compute_forward_scale_f32(
+static void ggml_compute_forward_out_prod_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne10 = src1->ne[0];
+    //const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+    // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
+    // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+
+    if (params->type == GGML_TASK_INIT) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by last three dimensions
+
+    // total rows in dst
+    const int64_t nr = ne1*ne2*ne3;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    // dst[:,:,:,:] = 0
+    // for i2,i3:
+    //   for i1:
+    //     for i01:
+    //       for i0:
+    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        // dst indices
+        const int64_t i3 = ir/(ne2*ne1);
+        const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+        const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        const int64_t i02 = i2;
+        const int64_t i03 = i3;
+
+        //const int64_t i10 = i1;
+        const int64_t i12 = i2;
+        const int64_t i13 = i3;
+
+        for (int64_t i01 = 0; i01 < ne01; ++i01) {
+            const int64_t i11 = i01;
+
+            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+            float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
+
+            ggml_vec_mad_f32(ne0, d, s0, *s1);
+            // for (int64_t i0 = 0; i0 < ne0; ++i0) {
+            //     d[i0] += s0[i0] * s1[i1];
+            // }
+        }
+    }
+
+    //int64_t t1 = ggml_perf_time_us();
+    //static int64_t acc = 0;
+    //acc += t1 - t0;
+    //if (t1 - t0 > 10) {
+    //    printf("\n");
+    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
+
+    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+    //}
+}
+
+static void ggml_compute_forward_out_prod(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+            {
+                GGML_ASSERT(false); // todo
+                // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                GGML_ASSERT(false); // todo
+                // ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_out_prod_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_scale
+
+static void ggml_compute_forward_scale_f32(
     const struct ggml_compute_params * params,
     const struct ggml_tensor * src0,
     const struct ggml_tensor * src1,
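Note: the inner loop leans on `ggml_vec_mad_f32`, whose behavior the commented-out scalar loop documents: a whole destination row accumulates one source row scaled by a scalar. A rough equivalent of one such call (illustrative sketch of the "axpy" pattern, not the package's SIMD version):

```c
// What one ggml_vec_mad_f32(ne0, d, s0, v) call does: y[i] += x[i] * v.
static void vec_mad_f32(const int n, float * y, const float * x, const float v) {
    for (int i = 0; i < n; ++i) {
        y[i] += x[i] * v;
    }
}
```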
@@ -10371,7 +11144,7 @@ static void ggml_compute_forward_set_f32(
     const int im2 = (ne12 == 0 ? 0 : ne12-1);
     const int im3 = (ne13 == 0 ? 0 : ne13-1);
 
-    GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 < ggml_nbytes(dst));
+    GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
 
     GGML_ASSERT(nb10 == sizeof(float));
 
@@ -10671,7 +11444,11 @@ static void ggml_compute_forward_get_rows_back_f32(
     GGML_ASSERT(ggml_is_contiguous(opt0));
     GGML_ASSERT(ggml_is_contiguous(dst));
 
-    ggml_compute_forward_dup_same_cont(params, opt0, dst);
+    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+    if (params->type == GGML_TASK_INIT) {
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -10815,8 +11592,8 @@ static void ggml_compute_forward_diag_mask_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst,
         const float value) {
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 2);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_nelements(src1) == 2);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -10824,7 +11601,7 @@ static void ggml_compute_forward_diag_mask_f32(
     const int  n_past  =       ((int32_t *) src1->data)[0];
     const bool inplace = (bool)((int32_t *) src1->data)[1];
 
-    assert(n_past >= 0);
+    GGML_ASSERT(n_past >= 0);
 
     if (!inplace && (params->type == GGML_TASK_INIT)) {
         // memcpy needs to be synchronized across threads to avoid race conditions.
@@ -10848,8 +11625,8 @@ static void ggml_compute_forward_diag_mask_f32(
     const int nr = src0->ne[1];
     const int nz = n/nr;
 
-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     for (int k = 0; k < nz; k++) {
         for (int j = ith; j < nr; j += nth) {
@@ -10985,6 +11762,101 @@ static void ggml_compute_forward_soft_max(
     }
 }
 
+// ggml_compute_forward_soft_max_back
+
+static void ggml_compute_forward_soft_max_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src1, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // TODO: handle transposed/permuted matrices
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float *y  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float *dx = (float *)((char *) dst->data  + i1*dst->nb[1]);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(dy[i]));
+            assert(!isnan(y[i]));
+        }
+#endif
+        // Jii = yi - yi*yi
+        // Jij = -yi*yj
+        // J = diag(y)-y.T*y
+        // dx = J * dy
+        // dxk = sum_i(Jki * dyi)
+        // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
+        // dxk = sum_i(-yk*yi * dyi) + yk*dyk
+        // dxk = -yk * sum_i(yi * dyi) + yk*dyk
+        // dxk = -yk * dot(y, dy) + yk*dyk
+        // dxk = yk * (- dot(y, dy) + dyk)
+        // dxk = yk * (dyk - dot(y, dy))
+        //
+        // post-order:
+        // dot_y_dy := dot(y, dy)
+        // dx := dy
+        // dx := dx - dot_y_dy
+        // dx := dx * y
+
+        // linear runtime, no additional memory
+        float dot_y_dy = 0;
+        ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
+        ggml_vec_cpy_f32 (nc, dx, dy);
+        ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
+        ggml_vec_mul_f32 (nc, dx, dx, y);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(dx[i]));
+            assert(!isinf(dx[i]));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_soft_max_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_alibi
 
 static void ggml_compute_forward_alibi_f32(
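The Jacobian derivation in the comments above collapses the full matrix-vector product J*dy into one dot product and one elementwise pass: dx_k = y_k * (dy_k - dot(y, dy)). A standalone numeric check of that identity (plain C, not the ggml API):

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const int n = 4;
        float x[4]  = {0.5f, -1.0f, 2.0f, 0.0f};
        float dy[4] = {1.0f, 0.0f, -0.5f, 0.25f};
        float y[4], dx[4];

        float sum = 0.0f;
        for (int i = 0; i < n; ++i) { y[i] = expf(x[i]); sum += y[i]; }
        for (int i = 0; i < n; ++i) { y[i] /= sum; }   // y = softmax(x)

        float dot_y_dy = 0.0f;
        for (int i = 0; i < n; ++i) { dot_y_dy += y[i]*dy[i]; }
        for (int i = 0; i < n; ++i) { dx[i] = y[i]*(dy[i] - dot_y_dy); }

        // the gradient of a softmax always sums to zero across the row
        float s = 0.0f;
        for (int i = 0; i < n; ++i) { s += dx[i]; }
        printf("sum(dx) = %g (should be ~0)\n", s);
        return 0;
    }

This is the "linear runtime, no additional memory" claim in the hunk: O(nc) work per row instead of the O(nc^2) a materialized Jacobian would cost.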
@@ -10993,8 +11865,9 @@ static void ggml_compute_forward_alibi_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 3);
+
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_nelements(src1) == 3);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -11057,8 +11930,9 @@ static void ggml_compute_forward_alibi_f16(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 3);
+
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_nelements(src1) == 3);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -11160,15 +12034,16 @@ static void ggml_compute_forward_clamp_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
-    assert(src1->type == GGML_TYPE_F32);
-    assert(ggml_nelements(src1) == 2);
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_nelements(src1) == 2);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int min = ((float *) src1->data)[0];
-    const int max = ((float *) src1->data)[1];
+    const float min = ((float *) src1->data)[0];
+    const float max = ((float *) src1->data)[1];
 
     const int ith = params->ith;
     const int nth = params->nth;
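The clamp bounds travel as two floats in a small parameter tensor (src1), and the fix above reads them into float variables so fractional bounds are no longer truncated on the way in. A scalar reference of what the kernel computes (plain C, not the ggml API):

    #include <stdio.h>

    static float clampf(float x, float min, float max) {
        return x < min ? min : (x > max ? max : x);
    }

    int main(void) {
        const float params[2] = {-1.5f, 1.5f};  // same layout as src1->data above
        const float min = params[0];            // reading through an int, as before
        const float max = params[1];            // the fix, would round these to -1/1
        float v[3] = {-3.0f, 0.25f, 7.0f};
        for (int i = 0; i < 3; ++i) {
            printf("%g -> %g\n", v[i], clampf(v[i], min, max));
        }
        return 0;
    }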
@@ -11726,9 +12601,9 @@ static void ggml_compute_forward_rope_back(
     }
 }
 
-// ggml_compute_forward_conv_1d_1s
+// ggml_compute_forward_conv_1d_s1_ph
 
-static void ggml_compute_forward_conv_1d_1s_f16_f32(
+static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -11848,7 +12723,7 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
     }
 }
 
-static void ggml_compute_forward_conv_1d_1s_f32(
+static void ggml_compute_forward_conv_1d_s1_ph_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -11968,7 +12843,7 @@ static void ggml_compute_forward_conv_1d_1s_f32(
     }
 }
 
-static void ggml_compute_forward_conv_1d_1s(
+static void ggml_compute_forward_conv_1d_s1_ph(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -11976,11 +12851,11 @@ static void ggml_compute_forward_conv_1d_1s(
     switch (src0->type) {
         case GGML_TYPE_F16:
            {
-                ggml_compute_forward_conv_1d_1s_f16_f32(params, src0, src1, dst);
+                ggml_compute_forward_conv_1d_s1_ph_f16_f32(params, src0, src1, dst);
            } break;
        case GGML_TYPE_F32:
            {
-                ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
+                ggml_compute_forward_conv_1d_s1_ph_f32(params, src0, src1, dst);
            } break;
        default:
            {
@@ -11989,9 +12864,9 @@ static void ggml_compute_forward_conv_1d_1s(
     }
 }
 
-// ggml_compute_forward_conv_1d_2s
+// ggml_compute_forward_conv_1d_s2_ph
 
-static void ggml_compute_forward_conv_1d_2s_f16_f32(
+static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -12111,7 +12986,7 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
     }
 }
 
-static void ggml_compute_forward_conv_1d_2s_f32(
+static void ggml_compute_forward_conv_1d_s2_ph_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -12231,7 +13106,7 @@ static void ggml_compute_forward_conv_1d_2s_f32(
     }
 }
 
-static void ggml_compute_forward_conv_1d_2s(
+static void ggml_compute_forward_conv_1d_s2_ph(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -12239,11 +13114,11 @@ static void ggml_compute_forward_conv_1d_2s(
     switch (src0->type) {
         case GGML_TYPE_F16:
            {
-                ggml_compute_forward_conv_1d_2s_f16_f32(params, src0, src1, dst);
+                ggml_compute_forward_conv_1d_s2_ph_f16_f32(params, src0, src1, dst);
            } break;
        case GGML_TYPE_F32:
            {
-                ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
+                ggml_compute_forward_conv_1d_s2_ph_f32(params, src0, src1, dst);
            } break;
        default:
            {
@@ -12252,51 +13127,188 @@ static void ggml_compute_forward_conv_1d_2s(
     }
 }
 
-// ggml_compute_forward_flash_attn
+// ggml_compute_forward_conv_2d_sk_p0
 
-static void ggml_compute_forward_flash_attn_f32(
+static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * q,
-        const struct ggml_tensor * k,
-        const struct ggml_tensor * v,
-        const bool masked,
-        struct ggml_tensor * dst) {
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int neq0 = q->ne[0];
-    const int neq1 = q->ne[1];
-    const int neq2 = q->ne[2];
-    const int neq3 = q->ne[3];
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    //const int ne03 = src0->ne[3];
 
-    const int nek0 = k->ne[0];
-    const int nek1 = k->ne[1];
-    //const int nek2 = k->ne[2];
-    //const int nek3 = k->ne[3];
+    const int ne10 = src1->ne[0];
+    //const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    //const int ne13 = src1->ne[3];
 
-    //const int nev0 = v->ne[0];
-    const int nev1 = v->ne[1];
-    //const int nev2 = v->ne[2];
-    //const int nev3 = v->ne[3];
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+    const int ne2 = dst->ne[2];
+    //const int ne3 = dst->ne[3];
+    //const int ne  = ne0*ne1*ne2*ne3;
 
-    const int ne0 = dst->ne[0];
-    const int ne1 = dst->ne[1];
-    //const int ne2 = dst->ne[2];
-    //const int ne3 = dst->ne[3];
+    const int nb00 = src0->nb[0];
+    //const int nb01 = src0->nb[1];
+    //const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
 
-    const int nbk0 = k->nb[0];
-    const int nbk1 = k->nb[1];
-    const int nbk2 = k->nb[2];
-    const int nbk3 = k->nb[3];
+    const int nb10 = src1->nb[0];
+    //const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    //const int nb13 = src1->nb[3];
 
-    const int nbq0 = q->nb[0];
-    const int nbq1 = q->nb[1];
-    const int nbq2 = q->nb[2];
-    const int nbq3 = q->nb[3];
+    //const int nb0 = dst->nb[0];
+    //const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    //const int nb3 = dst->nb[3];
 
-    const int nbv0 = v->nb[0];
-    const int nbv1 = v->nb[1];
-    const int nbv2 = v->nb[2];
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nk0 = ne00;
+    const int nk1 = ne01;
+
+    // size of the convolution row - the kernel size unrolled across all channels
+    // round-up so it is more suitable for SIMD
+    const int ew0 = ggml_up32(nk0*nk1*ne02);
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (params->type == GGML_TASK_INIT) {
+        // TODO: fix this memset (wsize is overestimated)
+        memset(params->wdata, 0, params->wsize);
+
+        // prepare source data (src1)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+            for (int i12 = 0; i12 < ne12; i12++) {
+                const float * const src = (float *)((char *) src1->data + i12*nb12);
+                ggml_fp16_t * dst_data = wdata;
+
+                for (int i1 = 0; i1 < ne1; i1++) {
+                    for (int i0 = 0; i0 < ne0; i0++) {
+                        for (int ik1 = 0; ik1 < nk1; ik1++) {
+                            for (int ik0 = 0; ik0 < nk0; ik0++) {
+                                dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
+                                    GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // total patches in dst
+    const int np = ne2;
+
+    // patches per thread
+    const int dp = (np + nth - 1)/nth;
+
+    // patch range for this thread
+    const int ip0 = dp*ith;
+    const int ip1 = MIN(ip0 + dp, np);
+
+    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+    for (int i2 = ip0; i2 < ip1; i2++) {
+        float * dst_data = (float *)((char *) dst->data + i2*nb2);
+
+        for (int i1 = 0; i1 < ne1; ++i1) {
+            for (int i0 = 0; i0 < ne0; ++i0) {
+                ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
+                        (ggml_fp16_t *) ((char *) src0->data + i2*nb03),
+                        (ggml_fp16_t *) wdata + (i1*ne0 + i0)*ew0);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_conv_2d_sk_p0(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_2d_sk_p0_f16_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                //ggml_compute_forward_conv_2d_sk_p0_f32(params, src0, src1, dst);
+                GGML_ASSERT(false);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_flash_attn
+
+static void ggml_compute_forward_flash_attn_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const bool masked,
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int64_t neq0 = q->ne[0];
+    const int64_t neq1 = q->ne[1];
+    const int64_t neq2 = q->ne[2];
+    const int64_t neq3 = q->ne[3];
+
+    const int64_t nek0 = k->ne[0];
+    const int64_t nek1 = k->ne[1];
+    //const int64_t nek2 = k->ne[2];
+    //const int64_t nek3 = k->ne[3];
+
+    //const int64_t nev0 = v->ne[0];
+    const int64_t nev1 = v->ne[1];
+    //const int64_t nev2 = v->ne[2];
+    //const int64_t nev3 = v->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    //const int64_t ne2 = dst->ne[2];
+    //const int64_t ne3 = dst->ne[3];
+
+    const int nbk0 = k->nb[0];
+    const int nbk1 = k->nb[1];
+    const int nbk2 = k->nb[2];
+    const int nbk3 = k->nb[3];
+
+    const int nbq0 = q->nb[0];
+    const int nbq1 = q->nb[1];
+    const int nbq2 = q->nb[2];
+    const int nbq3 = q->nb[3];
+
+    const int nbv0 = v->nb[0];
+    const int nbv1 = v->nb[1];
+    const int nbv2 = v->nb[2];
     const int nbv3 = v->nb[3];
 
     const int nb0 = dst->nb[0];
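The new conv_2d_sk_p0 kernel above unrolls each kernel window across all input channels into one contiguous "convolution row" of ew0 = ggml_up32(nk0*nk1*ne02) halfs, so the inner loop becomes a single ggml_vec_dot_f16. The round-up to a multiple of 32 keeps each row SIMD-friendly. A tiny sketch of that rounding (plain C; the semantics of ggml_up32 are assumed from the usage above):

    #include <stdio.h>

    static int up32(int n) {
        return (n + 31) & ~31;   // next multiple of 32, unchanged if aligned
    }

    int main(void) {
        // e.g. a 3x3 kernel over 7 channels: 63 weights, padded to 64
        printf("%d -> %d\n", 3*3*7, up32(3*3*7));
        printf("%d -> %d\n", 64,    up32(64));
        return 0;
    }

The memset in the INIT phase guarantees the padding lanes are zero, so they contribute nothing to the dot product.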
@@ -12938,91 +13950,917 @@ static void ggml_compute_forward_flash_ff(
     }
 }
 
-// ggml_compute_forward_map_unary
+// ggml_compute_forward_flash_attn_back
 
-static void ggml_compute_forward_map_unary_f32(
+static void ggml_compute_forward_flash_attn_back_f32(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        struct ggml_tensor * dst,
-        const ggml_unary_op_f32_t fun) {
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * d,
+        const bool masked,
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
 
-    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
+    const int64_t neq0 = q->ne[0];
+    const int64_t neq1 = q->ne[1];
+    const int64_t neq2 = q->ne[2];
+    const int64_t neq3 = q->ne[3];
 
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    const int64_t nek0 = k->ne[0];
+    const int64_t nek1 = k->ne[1];
+    //const int64_t nek2 = k->ne[2];
+    //const int64_t nek3 = k->ne[3];
 
-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
+    const int64_t nev0 = v->ne[0];
+    const int64_t nev1 = v->ne[1];
+    //const int64_t nev2 = v->ne[2];
+    //const int64_t nev3 = v->ne[3];
 
-    for (int i = 0; i < n; i++) {
-        fun(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
+    const int64_t ned0 = d->ne[0];
+    const int64_t ned1 = d->ne[1];
+    //const int64_t ned2 = d->ne[2];
+    //const int64_t ned3 = d->ne[3];
 
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
 
-static void ggml_compute_forward_map_unary(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        struct ggml_tensor * dst,
-        const ggml_unary_op_f32_t fun) {
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
+    const int nbk0 = k->nb[0];
+    const int nbk1 = k->nb[1];
+    const int nbk2 = k->nb[2];
+    const int nbk3 = k->nb[3];
 
-
+    const int nbq0 = q->nb[0];
+    const int nbq1 = q->nb[1];
+    const int nbq2 = q->nb[2];
+    const int nbq3 = q->nb[3];
 
-static void ggml_compute_forward_map_binary_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-        struct ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
-    assert(params->ith == 0);
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    const int nbv0 = v->nb[0];
+    const int nbv1 = v->nb[1];
+    const int nbv2 = v->nb[2];
+    const int nbv3 = v->nb[3];
 
-    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+    const int nbd0 = d->nb[0];
+    const int nbd1 = d->nb[1];
+    const int nbd2 = d->nb[2];
+    const int nbd3 = d->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+    const int64_t P = nek1 - N;
+    const int64_t M = P + N;
+
+    const int Mup  = ggml_up(M, GGML_SOFT_MAX_UNROLL);
+    const int mxDM = MAX(D, Mup);
+
+    // GGML_ASSERT(ne0 == D);
+    // GGML_ASSERT(ne1 == N);
+    GGML_ASSERT(P >= 0);
+
+    GGML_ASSERT(nbq0 == sizeof(float));
+    GGML_ASSERT(nbk0 == sizeof(float));
+    GGML_ASSERT(nbv0 == sizeof(float));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned0 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nek1 == N + P);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    if (params->type == GGML_TASK_INIT) {
+        if (ith == 0) {
+            memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
+        }
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int nr = neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float scale = 1.0f/sqrtf(D);
+
+    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2);
+        const int iq2 = ir - iq3*neq2;
+        for ( int iq1 = 0; iq1 < neq1; ++iq1) {
+
+
+            // not sure about CACHE_LINE_SIZE_F32..
+            // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
+            float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
+            float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
+
+            for (int i = M; i < Mup; ++i) {
+                S[i] = -INFINITY;
+            }
+
+            for (int64_t ic = 0; ic < nek1; ++ic) {
+                // k indices
+                const int ik3 = iq3;
+                const int ik2 = iq2;
+                const int ik1 = ic;
+
+                // S indices
+                const int i1 = ik1;
+
+                ggml_vec_dot_f32(neq0,
+                        S + i1,
+                        (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                        (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+            }
+
+            // scale
+            ggml_vec_scale_f32(nek1, S, scale);
+
+            if (masked) {
+                for (int64_t i = P; i < M; i++) {
+                    if (i > P + iq1) {
+                        S[i] = -INFINITY;
+                    }
+                }
+            }
+
+            // softmax
+            {
+                float max = -INFINITY;
+                ggml_vec_max_f32(M, &max, S);
+
+                ggml_float sum = 0.0;
+                {
+#ifdef GGML_SOFT_MAX_ACCELERATE
+                    max = -max;
+                    vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
+                    vvexpf(SM, SM, &Mup);
+                    ggml_vec_sum_f32(Mup, &sum, SM);
+#else
+                    uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
+                    ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
+
+                    for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
+                        float * SR = S  + i;
+                        float * SW = SM + i;
+
+                        for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
+                            if (SR[j] == -INFINITY) {
+                                SW[j] = 0.0f;
+                            } else {
+                                ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
+                                memcpy(&scvt[j], &s, sizeof(uint16_t));
+                                const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+                                sump[j] += (ggml_float)val;
+                                SW[j] = val;
+                            }
+                        }
+                    }
+
+                    for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
+                        sum += sump[i];
+                    }
+#endif
+                }
+
+                assert(sum > 0.0);
+
+                sum = 1.0/sum;
+                ggml_vec_scale_f32(M, SM, sum);
+
+            }
+
+            // step-by-step explanation
+            {
+                // forward-process                   shape      grads from backward process
+                // parallel_for iq2,iq3:
+                //  k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,iq2,iq3]  += grad[kcur]
+                //  q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]
+                //  v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iq2,iq3]  += grad[vcur]
+                //  for iq1:
+                //   kcur = k[:D,:M,iq2,iq3]         [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur
+                //   qcur = q[:D,iq1,iq2,iq3]        [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur
+                //   vcur = v[:M,:D,iq2,iq3]         [M,D,1,1]  grad[vcur] = grad[S5].T @ S4
+                //   S0 = -Inf                       [D,1,1,1]
+                //  ~S1[i] = dot(kcur[:D,i], qcur)
+                //   S1 = qcur @ kcur.T              [M,1,1,1]  grad[S1] = grad[S2] * scale
+                //   S2 = S1 * scale                 [M,1,1,1]  grad[S2] = diag_mask_zero(grad[S3], P)
+                //   S3 = diag_mask_inf(S2, P)       [M,1,1,1]  grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
+                //   S4 = softmax(S3)                [M,1,1,1]  grad[S4] = grad[S5] @ vcur
+                //  ~S5[i] = dot(vcur[:,i], S4)
+                //   S5 = S4 @ vcur.T                [D,1,1,1]  grad[S5] = d[:D,iq1,iq2,iq3]
+                //  ~dst[i,iq1,iq2,iq3] = S5[i]              ^
+                //   dst[:D,iq1,iq2,iq3] = S5                | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
+                // dst                               backward-/ grad[dst]                 = d
+                //
+                // output gradients with their dependencies:
+                //
+                // grad[kcur] = grad[S1].T @ qcur
+                // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                // grad[S4]   = grad[S5] @ vcur
+                // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
+                // grad[qcur] = grad[S1] @ kcur
+                // grad[vcur] = grad[S5].T @ S4
+                // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
+                //
+                // in post-order:
+                //
+                // S1         = qcur @ kcur.T
+                // S2         = S1 * scale
+                // S3         = diag_mask_inf(S2, P)
+                // S4         = softmax(S3)
+                // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
+                // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                // grad[qcur] = grad[S1] @ kcur
+                // grad[kcur] = grad[S1].T @ qcur
+                // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
+                //
+                // using less variables (SM=S4):
+                //
+                // S             = diag_mask_inf(qcur @ kcur.T * scale, P)
+                // SM            = softmax(S)
+                // S             = d[:D,iq1,iq2,iq3] @ vcur
+                // dot_SM_gradSM = dot(SM, S)
+                // S             = SM * (S - dot(SM, S))
+                // S             = diag_mask_zero(S, P) * scale
+                //
+                // grad[q][:D,iq1,iq2,iq3] += S  @ kcur
+                // grad[k][:D,:M,iq2,iq3]  += S.T @ qcur
+                // grad[v][:M,:D,iq2,iq3]  += d[:D,iq1,iq2,iq3].T @ SM
+            }
+
+            // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
+            // S = d[:D,iq1,iq2,iq3] @ vcur
+            // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
+            ggml_vec_set_f32(M, S, 0);
+            for (int64_t ic = 0; ic < D; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                ggml_vec_mad_f32(M,
+                        S,
+                         (float *) ((char *) v->data + (          ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                        *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
+            }
+
+            // S = SM * (S - dot(SM, S))
+            float dot_SM_gradSM = 0;
+            ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
+            ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
+            ggml_vec_mul_f32 (M, S, S, SM);
+
+            // S = diag_mask_zero(S, P) * scale
+            if (masked) {
+                // for (int64_t i = P + iq1 + 1; i < M; i++) {
+                //     S[i] = 0;
+                // }
+                for (int64_t i = P; i < M; i++) {
+                    if (i > P + iq1) {
+                        S[i] = 0;
+                    }
+                }
+            }
+            ggml_vec_scale_f32(M, S, scale);
+
+            void * grad_q = (char *) dst->data;
+            void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
+            void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
+
+            const size_t nbgq1 = nb0*neq0;
+            const size_t nbgq2 = nb0*neq0*neq1;
+            const size_t nbgq3 = nb0*neq0*neq1*neq2;
+
+            const size_t nbgk1 = nb0*nek0;
+            const size_t nbgk2 = nb0*nek0*nek1;
+            const size_t nbgk3 = nb0*nek0*nek1*neq2;
+
+            const size_t nbgv1 = nb0*nev0;
+            const size_t nbgv2 = nb0*nev0*nev1;
+            const size_t nbgv3 = nb0*nev0*nev1*neq2;
+
+            // S    shape [M,1]
+            // SM   shape [M,1]
+            // kcur shape [D,M]
+            // qcur shape [D,1]
+            // vcur shape [M,D]
+            //
+            // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+            // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
+            // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
+            //
+            //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
+            //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
+            for (int64_t ic = 0; ic < M; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                ggml_vec_mad_f32(D,
+                        (float *) ((char *) grad_q  + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
+                        (float *) ((char *) k->data + (ic*nbk1  + i2*nbk2  + i3*nbk3)),
+                        S[ic]);
+            }
+
+            // grad[k][:D,:M,iq2,iq3] += S.T       @ qcur
+            // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
+            // grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]
+            for (int64_t ic = 0; ic < M; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                // ggml_vec_set_f32(D,
+                //         (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
+                //         0);
+                ggml_vec_mad_f32(D,
+                        (float *) ((char *) grad_k  + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
+                        (float *) ((char *) q->data + (i1*nbq1  + i2*nbq2  + i3*nbq3)),
+                        S[ic]);
+            }
+
+            // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T       @ SM
+            // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
+            // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3]         * SM[:M]
+            for (int64_t ic = 0; ic < D; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                // ggml_vec_set_f32(M,
+                //         (float *) ((char *) grad_v + (ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
+                //         0);
+                ggml_vec_mad_f32(M,
+                        (float *) ((char *) grad_v  + (          ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
+                        SM,
+                        *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_flash_attn_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * d,
+        const bool masked,
+        struct ggml_tensor * dst) {
+    switch (q->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_win_part
+
+static void ggml_compute_forward_win_part_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0]; UNUSED(ne00);
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3]; UNUSED(ne03);
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3]; UNUSED(ne3);
+
+    const int32_t nep0 = ((const int32_t *)(opt0->data))[0];
+    const int32_t nep1 = ((const int32_t *)(opt0->data))[1];
+    const int32_t w    = ((const int32_t *)(opt0->data))[2];
+
+    assert(ne00 == ne0);
+    assert(ne3  == nep0*nep1);
+
+    // TODO: optimize / multi-thread
+    for (int py = 0; py < nep1; ++py) {
+        for (int px = 0; px < nep0; ++px) {
+            const int64_t i3 = py*nep0 + px;
+            for (int64_t i2 = 0; i2 < ne2; ++i2) {
+                for (int64_t i1 = 0; i1 < ne1; ++i1) {
+                    for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                        const int64_t i02 = py*w + i2;
+                        const int64_t i01 = px*w + i1;
+                        const int64_t i00 = i0;
+
+                        const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0    + i1*ne0   + i0;
+                        const int64_t j =                  i02*ne01*ne00 + i01*ne00 + i00;
+
+                        if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
+                            ((float *) dst->data)[i] = 0.0f;
+                        } else {
+                            ((float *) dst->data)[i] = ((float *) src0->data)[j];
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_win_part(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_win_part_f32(params, src0, opt0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_win_unpart
+
+static void ggml_compute_forward_win_unpart_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    //const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+
+    const int32_t w = ((const int32_t *)(opt0->data))[0];
+
+    // padding
+    const int px = (w - ne1%w)%w;
+    //const int py = (w - ne2%w)%w;
+
+    const int npx = (px + ne1)/w;
+    //const int npy = (py + ne2)/w;
+
+    assert(ne0 == ne00);
+
+    // TODO: optimize / multi-thread
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = 0; i1 < ne1; ++i1) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                const int ip2 = i2/w;
+                const int ip1 = i1/w;
+
+                const int64_t i02 = i2%w;
+                const int64_t i01 = i1%w;
+                const int64_t i00 = i0;
+
+                const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
+                const int64_t j = i2*ne1*ne0 + i1*ne0 + i0;
+
+                ((float *) dst->data)[j] = ((float *) src0->data)[i];
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_win_unpart(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_win_unpart_f32(params, src0, opt0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_unary
+
+static void ggml_compute_forward_map_unary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        const ggml_unary_op_f32_t fun) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_unary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        const ggml_unary_op_f32_t fun) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_binary
+
+static void ggml_compute_forward_map_binary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        const ggml_binary_op_f32_t fun) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_binary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        const ggml_binary_op_f32_t fun) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_cross_entropy_loss
+
+static void ggml_compute_forward_cross_entropy_loss_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    float * sums = (float *) params->wdata;
+
+    // TODO: handle transposed/permuted matrices
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    if (params->type == GGML_TASK_INIT) {
+        if (ith == 0) {
+            memset(sums, 0, sizeof(float) * (nth + nth * nc));
+        }
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        if (ith == 0) {
+            float * dp = (float *) dst->data;
+            ggml_vec_sum_f32(nth, dp, sums);
+            dp[0] *= -1.0f;
+        }
+        return;
+    }
+
+    const double eps = 1e-9;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float * st = (float *) params->wdata + nth + ith*nc;
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+        // soft_max
+        ggml_float sum = 0.0;
+        {
+            float max = -INFINITY;
+            ggml_vec_max_f32(nc, &max, s0);
+
+            uint16_t scvt;
+            for (int i = 0; i < nc; i++) {
+                if (s0[i] == -INFINITY) {
+                    st[i] = 0.0f;
+                } else {
+                    // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
+                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
+                    memcpy(&scvt, &s, sizeof(scvt));
+                    const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+                    sum += (ggml_float)val;
+                    st[i] = val;
+                }
+            }
+
+            assert(sum > 0.0);
+            // sum = 1.0/sum;
+        }
+        // avoid log(0) by rescaling from [0..1] to [eps..1]
+        sum = (1.0 - eps) / sum;
+        ggml_vec_scale_f32(nc, st, sum);
+        ggml_vec_add1_f32(nc, st, st, eps);
+        ggml_vec_log_f32(nc, st, st);
+        ggml_vec_mul_f32(nc, st, st, s1);
+
+        ggml_vec_sum_f32(nc, sums + ith, st);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(st[i]));
+            assert(!isinf(st[i]));
+        }
+#endif
+    }
+
+}
+
+static void ggml_compute_forward_cross_entropy_loss(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_cross_entropy_loss_back
+
+static void ggml_compute_forward_cross_entropy_loss_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(opt0));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int64_t ith = params->ith;
+    const int64_t nth = params->nth;
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const float eps = 1e-9f;
+
+    // TODO: handle transposed/permuted matrices
+    const int64_t nc = src0->ne[0];
+    const int64_t nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    float * d = (float *) opt0->data;
+
+    for (int64_t i1 = ir0; i1 < ir1; i1++) {
+        float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
+        float * s0  = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float * sm  = (float *) params->wdata + ith*nc;
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+        // step by step explanation:
+        {
+            //float * sums = (float *) params->wdata;
+
+            // forward pass with annotated gradients from backward pass
+            // (built by going in reverse operation order, adding to gradients of current operation args)
+            // st0 = exp(s0-max(s0))                                                       grad[st0] = grad[st1]*(1.0 - eps)/sum
+            //                                                       from softmax_back:    grad[s0]  = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
+            // ggml_vec_scale_f32(nc, st, sum);           // st1 = st0*/sum = softmax(s0)  grad[st1] = grad[st2]*(1.0 - eps)
+            // ggml_vec_scale_f32(nc, st, (1.0f - eps));  // st2 = st1*(1.0 - eps)         grad[st2] = grad[st3]
+            // ggml_vec_add1_f32(nc, st, st, eps);        // st3 = st2 + eps               grad[st3] = grad[st4]/st3
+            // ggml_vec_log_f32(nc, st, st);              // st4 = log(st3)                grad[st4] = grad[st5] * s1
+            // ggml_vec_mul_f32(nc, st, st, s1);          // st5 = st4 * s1                grad[st5] = grad[sums[ith]]
+            // ggml_vec_sum_f32(nc, sums + ith, st);      // sums[ith] = st5               grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel]
+
+            // substitute into grad[st1], because we can reuse softmax_back from this point on
+            // grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps))
+            // postorder:
+            // grad[st1] := softmax(s0)
+            // grad[st1] := grad[st1]*(1.0 - eps)
+            // grad[st1] := grad[st1] + eps
+            // grad[st1] := s1 / grad[st1]
+            // grad[st1] := grad[st1]*(1.0-eps)*-grad[cel]
+
+            // src0 gradients by going through softmax_back
+            // grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
+            // from softmax_back:
+            // dxk = yk * (dyk - dot(y, dy))
+            // dot_y_dy := dot(y, dy)
+            // dx := dy
+            // dx := dx - dot_y_dy
+            // dx := dx * y
+            // postorder:
+            // dot_st1_dst1 := dot(st1, grad[st1])
+            // grad[s0] := grad[st1]
+            // grad[s0] := grad[s0] - dot_st1_dst1
+            // grad[s0] := grad[s0] * st1
+
+            // prepend postorder from grad[st1] directly using grad[s0] as memory location, as we will grad[s0] := grad[st1]
+            // sm           := softmax(s0)
+            // grad[s0]     := sm*(1.0 - eps)
+            // grad[s0]     := grad[s0] + eps
+            // grad[s0]     := s1 / grad[s0]
+            // grad[s0]     := grad[s0]*(1.0-eps)*-grad[cel]
+            // dot_st1_dst1 := dot(sm, grad[s0])
+            // grad[s0]     := grad[s0] - dot_st1_dst1
+            // grad[s0]     := grad[s0] * sm
+        }
+
+        // soft_max
+        ggml_float sum = 0.0;
+        {
+            float max = -INFINITY;
+            ggml_vec_max_f32(nc, &max, s0);
+
+            uint16_t scvt;
+            for (int i = 0; i < nc; i++) {
+                if (s0[i] == -INFINITY) {
+                    sm[i] = 0.0f;
+                } else {
+                    // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
+                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
+                    memcpy(&scvt, &s, sizeof(scvt));
+                    const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+                    sum += (ggml_float)val;
+                    sm[i] = val;
+                }
+            }
+
+            assert(sum > 0.0);
+            sum = 1.0/sum;
+        }
 
-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
-    assert(src1->nb[0] == sizeof(float));
+        float dot_st1_dst1 = 0;
+        ggml_vec_scale_f32(nc, sm, sum);
+        ggml_vec_cpy_f32  (nc, ds0, sm);
+        ggml_vec_scale_f32(nc, ds0, (1.0f - eps));
+        ggml_vec_add1_f32 (nc, ds0, ds0, eps);
+        ggml_vec_div_f32  (nc, ds0, s1, ds0);
+        ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]);
+        ggml_vec_dot_f32  (nc, &dot_st1_dst1, sm, ds0);
+        ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1);
+        ggml_vec_mul_f32  (nc, ds0, ds0, sm);
 
-    for (int i = 0; i < n; i++) {
-        fun(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])),
-                (float *) ((char *) src1->data + i*(src1->nb[1])));
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(sm[i]));
+            assert(!isinf(sm[i]));
+            assert(!isnan(ds0[i]));
+            assert(!isinf(ds0[i]));
+        }
+#endif
     }
 }
 
-
-static void ggml_compute_forward_map_binary(
+static void ggml_compute_forward_cross_entropy_loss_back(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+                ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst);
             } break;
         default:
             {
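The cross-entropy kernels above avoid log(0) by rescaling the softmax output from [0..1] to [eps..1] before taking the logarithm, exactly as the "avoid log(0)" comment says. A standalone sketch of that computation (plain C, not the ggml API; values are illustrative):

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const int n = 3;
        const double eps = 1e-9;
        double logits[3] = {2.0, -1e9, 0.5};  // one entry is effectively masked out
        double labels[3] = {1.0, 0.0, 0.0};

        double max = logits[0];
        for (int i = 1; i < n; ++i) if (logits[i] > max) max = logits[i];

        double p[3], sum = 0.0;
        for (int i = 0; i < n; ++i) { p[i] = exp(logits[i] - max); sum += p[i]; }

        // rescale so each probability lands in [eps..1] instead of [0..1],
        // mirroring sum = (1.0 - eps)/sum; scale; add1(eps) in the diff
        const double k = (1.0 - eps)/sum;
        double loss = 0.0;
        for (int i = 0; i < n; ++i) {
            loss -= labels[i]*log(p[i]*k + eps);
        }
        printf("loss = %g\n", loss);
        return 0;
    }

The backward kernel then reuses the softmax_back identity on the rescaled probabilities, which is why the same dot-subtract-multiply sequence appears again there.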
@@ -13031,6 +14869,7 @@ static void ggml_compute_forward_map_binary(
     }
 }
 
+
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -13102,6 +14941,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_repeat(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_REPEAT_BACK:
+            {
+                ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_ABS:
             {
                 ggml_compute_forward_abs(params, tensor->src0, tensor);
@@ -13126,6 +14969,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gelu(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_GELU_QUICK:
+            {
+                ggml_compute_forward_gelu_quick(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_SILU:
             {
                 ggml_compute_forward_silu(params, tensor->src0, tensor);
@@ -13150,6 +14997,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_OUT_PROD:
+            {
+                ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_SCALE:
             {
                 ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
@@ -13206,6 +15057,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_soft_max(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_SOFT_MAX_BACK:
+            {
+                ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_ROPE:
             {
                 ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
@@ -13222,25 +15077,44 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor);
             } break;
-        case GGML_OP_CONV_1D_1S:
+        case GGML_OP_CONV_1D_S1_PH:
+            {
+                ggml_compute_forward_conv_1d_s1_ph(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_CONV_1D_S2_PH:
             {
-                ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor);
+                ggml_compute_forward_conv_1d_s2_ph(params, tensor->src0, tensor->src1, tensor);
             } break;
-        case GGML_OP_CONV_1D_2S:
+        case GGML_OP_CONV_2D_SK_P0:
             {
-                ggml_compute_forward_conv_1d_2s(params, tensor->src0, tensor->src1, tensor);
+                ggml_compute_forward_conv_2d_sk_p0(params, tensor->src0, tensor->src1, tensor);
             } break;
         case GGML_OP_FLASH_ATTN:
             {
-                int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
+                const int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
                 GGML_ASSERT(t == 0 || t == 1);
-                bool masked = t != 0;
+                const bool masked = t != 0;
                 ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor);
             } break;
         case GGML_OP_FLASH_FF:
             {
                 ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
             } break;
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
+                GGML_ASSERT(t == 0 || t == 1);
+                bool masked = t != 0;
+                ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
+            } break;
+        case GGML_OP_WIN_PART:
+            {
+                ggml_compute_forward_win_part(params, tensor->src0, tensor->opt[0], tensor);
+            } break;
+        case GGML_OP_WIN_UNPART:
+            {
+                ggml_compute_forward_win_unpart(params, tensor->src0, tensor->opt[0], tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
@@ -13253,6 +15127,16 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
             }
             break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            {
+                ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
+            }
+            break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            {
+                ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -13391,11 +15275,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad =
                         ggml_add_impl(ctx,
                             src0->grad,
-                            ggml_mul(ctx,
-                                tensor->grad, // this was not catched by test_grad because in test_grad tensor->grad is 1
+                            ggml_scale(ctx,
                                 ggml_div(ctx,
-                                    ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor),
-                                    tensor)),
+                                    tensor->grad,
+                                    tensor),
+                                ggml_new_f32(ctx, 0.5f)),
                             inplace);
                 }
             } break;
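The reworked SQRT backward above follows directly from the chain rule: with y = sqrt(x), dy/dx = 1/(2*sqrt(x)) = 0.5/y, so the gradient contribution is 0.5 * tensor->grad / tensor, which is what scale(div(grad, y), 0.5) builds. A quick numeric check of the rule (plain C, not the ggml API):

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const double x = 9.0, dLdy = 2.0;
        const double y = sqrt(x);
        const double analytic = 0.5*dLdy/y;   // the diff's scale(div(...), 0.5)
        const double h = 1e-6;                // central finite difference
        const double numeric = dLdy*(sqrt(x + h) - sqrt(x - h))/(2*h);
        printf("analytic = %.9f, numeric = %.9f\n", analytic, numeric);
        return 0;
    }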
@@ -13441,43 +15325,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
            {
                // necessary for llama
                if (src0->grad) {
-
-
-
-
-
-
-
-
-
-                   //
-
-
-
-
-                   // transpose [nc0*nr0,1,1]
-                   // reshape   [nc0,nr0,1,1] reshape_1d or reshape_2d
-                   // add to src0->grad
-
-                   int64_t ne[4] = {nc0,ncr,nr0,nrr};
-
-                   struct ggml_tensor* F00 = tensor->grad;
-                   struct ggml_tensor* F01 = ggml_reshape   (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne));
-                   struct ggml_tensor* F02 = ggml_permute   (ctx, F01, 0,2,1,3);
-                   struct ggml_tensor* F03 = ggml_cont      (ctx, F02);
-                   struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr);
-                   struct ggml_tensor* F05 = ggml_transpose (ctx, F04);
-                   struct ggml_tensor* F06 = ggml_cont      (ctx, F05);
-                   struct ggml_tensor* F07 = ggml_sum_rows  (ctx, F06);
-                   struct ggml_tensor* F08 = ggml_transpose (ctx, F07);
-                   struct ggml_tensor* F09 = ggml_cont      (ctx, F08);
-                   struct ggml_tensor* F10 = ggml_reshape   (ctx, F09, src0->grad);
-
-                   src0->grad =
-                       ggml_add_impl(ctx,
-                           src0->grad,
-                           F10,
-                           inplace);
+                   src0->grad = ggml_add_impl(ctx,
+                           src0->grad,
+                           ggml_repeat_back(ctx, tensor->grad, src0->grad),
+                           inplace);
+               }
+           } break;
+       case GGML_OP_REPEAT_BACK:
+           {
+               if (src0->grad) {
+                   // TODO: test this
+                   src0->grad = ggml_add_impl(ctx,
+                           src0->grad,
+                           ggml_repeat(ctx, tensor->grad, src0->grad),
+                           inplace);
                }
            } break;
        case GGML_OP_ABS:
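`ggml_repeat_back` is the adjoint of `ggml_repeat`: the forward op broadcasts by copying, so the backward op must sum the gradient of every copy back into the smaller tensor — which is what the long hand-rolled reshape/transpose/sum chain above was doing before being collapsed into one op. A self-contained 1-D sketch of the idea (not the actual kernel):

```c
// Forward repeat copies src into n_tiles tiles; the adjoint sums them back.
// Illustrative only -- the real ggml_repeat_back handles all four dimensions.
void repeat_back_1d(float * dst, const float * grad, int n_src, int n_tiles) {
    for (int i = 0; i < n_src; ++i) {
        float sum = 0.0f;
        for (int t = 0; t < n_tiles; ++t) {
            sum += grad[t*n_src + i];   // gradient of every copy of src[i]
        }
        dst[i] = sum;
    }
}
```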
@@ -13525,6 +15386,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
            {
                GGML_ASSERT(false); // TODO: not implemented
            } break;
+       case GGML_OP_GELU_QUICK:
+           {
+               GGML_ASSERT(false); // TODO: not implemented
+           } break;
        case GGML_OP_ALIBI:
            {
                GGML_ASSERT(false); // TODO: not implemented
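The backward pass for the new `GGML_OP_GELU_QUICK` is stubbed out above. For reference — stated here as the usual formula, not as code from this diff — the quick-GELU approximation that the forward tables precompute (cf. `GGML_GELU_QUICK_FP16`) is $\mathrm{gelu\_quick}(x) = x\,\sigma(1.702\,x)$, so an implementation of this case would need its derivative $\sigma(1.702x) + 1.702\,x\,\sigma(1.702x)\bigl(1-\sigma(1.702x)\bigr)$.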
@@ -13584,38 +15449,37 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor

            // necessary for llama
            if (src0->grad) {
-               // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
                src0->grad =
                    ggml_add_impl(ctx,
                        src0->grad,
-                       // ds0 = dt.dot(s1.T)
-                       // ggml_out_prod(ctx, // [n,m]
-                       //     src1,          // [n,p]
-                       //     tensor->grad), // [m,p]
-                       // for now just using A*B==(B.T*A.T).T
-                       ggml_cont(ctx,                      // [n,m]
-                           ggml_transpose(ctx,             // [n,m]
-                               ggml_mul_mat(ctx,           // [m,n]
-                                   ggml_cont(ctx,          // [p,m]
-                                       ggml_transpose(ctx, // [p,m]
-                                           tensor->grad)), // [m,p]
-                                   ggml_cont(ctx,          // [p,n]
-                                       ggml_transpose(ctx, // [p,n]
-                                           src1))))),      // [n,p]
+                       ggml_out_prod(ctx, // [n,m]
+                           src1,          // [n,p]
+                           tensor->grad), // [m,p]
                        inplace);
            }
            if (src1->grad) {
                src1->grad =
                    ggml_add_impl(ctx,
                        src1->grad,
-                       // ds1 = s0.T.dot(dt):
-                       ggml_mul_mat(ctx,                   // [n,p]
-                           ggml_cont(ctx,                  // [m,n]
-                               ggml_transpose(ctx, src0)), // [m,n]
-                           tensor->grad),                  // [m,p]
+                       // ggml_mul_mat(ctx,                   // [n,p]
+                       //     ggml_cont(ctx,                  // [m,n]
+                       //         ggml_transpose(ctx, src0)), // [m,n]
+                       //     tensor->grad),                  // [m,p]
+
+                       // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
+                       // // avoid transpose of src0, rather transpose smaller tensor->grad
+                       // // and then use ggml_out_prod
+                       ggml_out_prod(ctx,      // [n,p]
+                           src0,               // [n,m]
+                           ggml_transpose(ctx, // [p,m]
+                               tensor->grad)), // [m,p]
                        inplace);
            }
        } break;
+       case GGML_OP_OUT_PROD:
+           {
+               GGML_ASSERT(false); // TODO: not implemented
+           } break;
        case GGML_OP_SCALE:
            {
                // necessary for llama
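Reading the shape annotations above in ggml's ne-ordering (src0 $=A$ is [n,m], src1 $=B$ is [n,p], tensor $=C$ is [m,p], i.e. $C = A^\top B$ in column terms), the replacements are the textbook matmul adjoints: $\partial L/\partial A = B\,(\partial L/\partial C)^\top$, computed by `ggml_out_prod(src1, tensor->grad)`, and $\partial L/\partial B = A\,(\partial L/\partial C)$, computed by `ggml_out_prod(src0, transpose(tensor->grad))`. The removed code produced the same values via $A\cdot B = (B^\top A^\top)^\top$, at the cost of extra `cont`/`transpose` copies.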
@@ -13717,7 +15581,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
            // necessary for llama
            if (src0->grad) {
                size_t offset;
-               memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
+
+               GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0]));
+               memcpy(&offset, tensor->opt[0]->data, sizeof(offset));

                size_t nb1 = tensor->nb[1];
                size_t nb2 = tensor->nb[2];
@@ -13744,10 +15610,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
            {
                // necessary for llama
                if (src0->grad) {
-                   int axis0 = tensor->opt[0]->ne[0] & 0x3;
-                   int axis1 = tensor->opt[0]->ne[1] & 0x3;
-                   int axis2 = tensor->opt[0]->ne[2] & 0x3;
-                   int axis3 = tensor->opt[0]->ne[3] & 0x3;
+                   int32_t * axes = (int32_t *) tensor->opt[0]->data;
+                   int axis0 = axes[0] & 0x3;
+                   int axis1 = axes[1] & 0x3;
+                   int axis2 = axes[2] & 0x3;
+                   int axis3 = axes[3] & 0x3;
                    int axes_backward[4] = {0,0,0,0};
                    axes_backward[axis0] = 0;
                    axes_backward[axis1] = 1;
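The `axes_backward` construction inverts the forward permutation: `axes_backward[axes[i]] = i`, so permuting the gradient by `axes_backward` undoes the forward permute. A self-contained worked example (illustrative values):

```c
#include <stdio.h>

int main(void) {
    const int axes[4]    = {2, 0, 1, 3};   // forward permutation
    int axes_backward[4] = {0, 0, 0, 0};
    for (int i = 0; i < 4; ++i) {
        axes_backward[axes[i]] = i;        // invert it
    }
    // prints 1 2 0 3 -- the permutation that maps the output layout back
    printf("%d %d %d %d\n", axes_backward[0], axes_backward[1],
                            axes_backward[2], axes_backward[3]);
    return 0;
}
```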
@@ -13831,50 +15698,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
            {
                // necessary for llama
                if (src0->grad) {
-                   // y = softmax(x)
-                   //
-                   // Jii = yi - yi*yi
-                   // Jij = -yi*yj
-                   // J = diag(y)-y.*y
-                   // dx = J * dy
-                   // dxk = sum(Jkj * dyk)
-
-                   int64_t ne2[4] = {
-                       tensor->ne[0],
-                       1,
-                       tensor->ne[1]*tensor->ne[2],
-                       tensor->ne[3]
-                   };
-                   struct ggml_tensor * tensor2 = ggml_cont(ctx,
-                       ggml_reshape_4d(ctx,
-                           ggml_cont(ctx, tensor),
-                           ne2[0], ne2[1], ne2[2], ne2[3]));
-
-                   struct ggml_tensor * grad2 = ggml_cont(ctx,
-                       ggml_reshape_4d(ctx,
-                           ggml_cont(ctx, tensor->grad),
-                           ne2[0], ne2[1], ne2[2], ne2[3]));
-
-                   struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
-                       ggml_permute(ctx,                           // [1,ne0,ne1*ne2,ne3]
-                           tensor2,                                // [ne0,1,ne1*ne2,ne3]
-                           1, 0, 2, 3));
-
                    src0->grad =
-                       ggml_add_impl(ctx,
-
-
-                           ggml_mul_mat(ctx,           // [ne0,1,ne1*ne2,ne3]
-                               ggml_sub(ctx,           // [ne0,ne0,ne1*ne2,ne3]
-                                   ggml_diag(ctx,      // [ne0,ne0,ne1*ne2,ne3]
-                                       tensor2),       // [ne0,1,ne1*ne2,ne3]
-                                   ggml_mul_mat(ctx,   // [ne0,ne0,ne1*ne2,ne3]
-                                       tensor2_t,      // [1,ne0,ne1*ne2,ne3]
-                                       tensor2_t)),    // [1,ne0,ne1*ne2,ne3]
-                               grad2),                 // [ne0,1,ne1*ne2,ne3]
-                           src0->grad),
-                           inplace);
+                       ggml_add_impl(ctx, src0->grad,
+                           ggml_soft_max_back(ctx, tensor->grad, tensor),
+                           inplace);
                }
+
+           } break;
+       case GGML_OP_SOFT_MAX_BACK:
+           {
+               GGML_ASSERT(false); // TODO: not implemented
            } break;
        case GGML_OP_ROPE:
            {
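The deleted code materialized the full softmax Jacobian $J = \mathrm{diag}(y) - y\,y^\top$ and multiplied it by $dy$. The replacement relies on the standard simplification $dx = J\,dy = y \odot dy - (y^\top dy)\,y$, which a fused `ggml_soft_max_back` kernel can evaluate in $O(n)$ per row without ever forming the $n\times n$ matrix.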
@@ -13919,27 +15752,206 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                // noop
            }
        } break;
-       case GGML_OP_CONV_1D_1S:
+       case GGML_OP_CONV_1D_S1_PH:
+           {
+               GGML_ASSERT(false); // TODO: not implemented
+           } break;
+       case GGML_OP_CONV_1D_S2_PH:
            {
                GGML_ASSERT(false); // TODO: not implemented
            } break;
-       case GGML_OP_CONV_1D_2S:
+       case GGML_OP_CONV_2D_SK_P0:
            {
                GGML_ASSERT(false); // TODO: not implemented
            } break;
        case GGML_OP_FLASH_ATTN:
            {
-               GGML_ASSERT(false); // not supported
+               struct ggml_tensor * flash_grad = NULL;
+               if (src0->grad || src1->grad || tensor->opt[0]->grad) {
+                   int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
+                   GGML_ASSERT(t == 0 || t == 1);
+                   bool masked = t != 0;
+                   flash_grad =
+                       ggml_flash_attn_back(ctx,
+                           src0,
+                           src1,
+                           tensor->opt[0],
+                           tensor->grad,
+                           masked);
+               }
+
+               if (src0->grad) {
+                   struct ggml_tensor * grad_q = NULL;
+                   const size_t nb0    = flash_grad->nb[0];
+                   const size_t offset = 0;
+                   switch(src0->n_dims) {
+                       case 2:
+                           {
+                               grad_q = ggml_view_2d(ctx,
+                                   flash_grad,
+                                   src0->ne[0],
+                                   src0->ne[1],
+                                   nb0*src0->ne[0],
+                                   offset);
+                           } break;
+                       case 3:
+                           {
+                               grad_q = ggml_view_3d(ctx,
+                                   flash_grad,
+                                   src0->ne[0],
+                                   src0->ne[1],
+                                   src0->ne[2],
+                                   nb0*src0->ne[0],
+                                   nb0*src0->ne[0]*src0->ne[1],
+                                   offset);
+                           } break;
+                       case 4:
+                           {
+                               grad_q = ggml_view_4d(ctx,
+                                   flash_grad,
+                                   src0->ne[0],
+                                   src0->ne[1],
+                                   src0->ne[2],
+                                   src0->ne[3],
+                                   nb0*src0->ne[0],
+                                   nb0*src0->ne[0]*src0->ne[1],
+                                   nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
+                                   offset);
+                           } break;
+                   }
+
+                   src0->grad = ggml_add_impl(ctx,
+                       src0->grad,
+                       grad_q,
+                       inplace);
+               }
+
+               if (src1->grad) {
+                   struct ggml_tensor * grad_k = NULL;
+                   const size_t nb0    = flash_grad->nb[0];
+                   const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
+                   switch(src1->n_dims) {
+                       case 2:
+                           {
+                               grad_k = ggml_view_2d(ctx,
+                                   flash_grad,
+                                   src1->ne[0],
+                                   src1->ne[1],
+                                   nb0*src1->ne[0],
+                                   offset);
+                           } break;
+                       case 3:
+                           {
+                               grad_k = ggml_view_3d(ctx,
+                                   flash_grad,
+                                   src1->ne[0],
+                                   src1->ne[1],
+                                   src1->ne[2],
+                                   nb0*src1->ne[0],
+                                   nb0*src1->ne[0]*src1->ne[1],
+                                   offset);
+                           } break;
+                       case 4:
+                           {
+                               grad_k = ggml_view_4d(ctx,
+                                   flash_grad,
+                                   src1->ne[0],
+                                   src1->ne[1],
+                                   src1->ne[2],
+                                   src1->ne[3],
+                                   nb0*src1->ne[0],
+                                   nb0*src1->ne[0]*src1->ne[1],
+                                   nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
+                                   offset);
+                           } break;
+                   }
+
+                   src1->grad = ggml_add_impl(ctx,
+                       src1->grad,
+                       grad_k,
+                       inplace);
+               }
+
+               struct ggml_tensor * opt0 = tensor->opt[0];
+
+               if (opt0->grad) {
+                   struct ggml_tensor * grad_v = NULL;
+                   const size_t nb0    = flash_grad->nb[0];
+                   const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
+                                       + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
+                   switch(opt0->n_dims) {
+                       case 2:
+                           {
+                               grad_v = ggml_view_2d(ctx,
+                                   flash_grad,
+                                   opt0->ne[0],
+                                   opt0->ne[1],
+                                   nb0*opt0->ne[0],
+                                   offset);
+                           } break;
+                       case 3:
+                           {
+                               grad_v = ggml_view_3d(ctx,
+                                   flash_grad,
+                                   opt0->ne[0],
+                                   opt0->ne[1],
+                                   opt0->ne[2],
+                                   nb0*opt0->ne[0],
+                                   nb0*opt0->ne[0]*opt0->ne[1],
+                                   offset);
+                           } break;
+                       case 4:
+                           {
+                               grad_v = ggml_view_4d(ctx,
+                                   flash_grad,
+                                   opt0->ne[0],
+                                   opt0->ne[1],
+                                   opt0->ne[2],
+                                   opt0->ne[3],
+                                   nb0*opt0->ne[0],
+                                   nb0*opt0->ne[0]*opt0->ne[1],
+                                   nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
+                                   offset);
+                           } break;
+                   }
+
+                   opt0->grad = ggml_add_impl(ctx,
+                       opt0->grad,
+                       grad_v,
+                       inplace);
+               }
            } break;
        case GGML_OP_FLASH_FF:
            {
                GGML_ASSERT(false); // not supported
            } break;
+       case GGML_OP_FLASH_ATTN_BACK:
+           {
+               GGML_ASSERT(false); // not supported
+           } break;
+       case GGML_OP_WIN_PART:
+       case GGML_OP_WIN_UNPART:
        case GGML_OP_MAP_UNARY:
        case GGML_OP_MAP_BINARY:
            {
                GGML_ASSERT(false); // not supported
            } break;
+       case GGML_OP_CROSS_ENTROPY_LOSS:
+           {
+               if (src0->grad) {
+                   src0->grad = ggml_add_impl(ctx,
+                       src0->grad,
+                       ggml_cross_entropy_loss_back(ctx,
+                           src0,
+                           src1,
+                           tensor->grad),
+                       inplace);
+               }
+           } break;
+       case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+           {
+               GGML_ASSERT(false); // not supported
+           } break;
        case GGML_OP_NONE:
            {
                // nop
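The layout the views above assume is that `ggml_flash_attn_back` returns the gradients of q, k and v concatenated into one flat tensor, so each one is recovered at a cumulative offset. A sketch of just the offset arithmetic (the names `q`/`k` are mine, matching src0/src1; `offs_v` corresponds to `tensor->opt[0]`):

```c
#include <stddef.h>
#include "ggml.h"

// Hypothetical helper mirroring the view offsets in the hunk above:
// grad_q first, then grad_k, then grad_v.
static void flash_grad_offsets(const struct ggml_tensor * flash_grad,
                               const struct ggml_tensor * q,
                               const struct ggml_tensor * k,
                               size_t * offs_q, size_t * offs_k, size_t * offs_v) {
    const size_t nb0     = flash_grad->nb[0]; // bytes per element
    const size_t nelem_q = (size_t)(q->ne[0]*q->ne[1]*q->ne[2]*q->ne[3]);
    const size_t nelem_k = (size_t)(k->ne[0]*k->ne[1]*k->ne[2]*k->ne[3]);

    *offs_q = 0;                        // grad_q starts the buffer
    *offs_k = nb0*nelem_q;              // grad_k follows grad_q
    *offs_v = nb0*(nelem_q + nelem_k);  // grad_v follows grad_k
}
```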
@@ -14316,6 +16328,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                case GGML_OP_SUM_ROWS:
                case GGML_OP_MEAN:
                case GGML_OP_REPEAT:
+               case GGML_OP_REPEAT_BACK:
                case GGML_OP_ABS:
                case GGML_OP_SGN:
                case GGML_OP_NEG:
@@ -14326,6 +16339,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                    } break;
                case GGML_OP_MUL:
                case GGML_OP_GELU:
+               case GGML_OP_GELU_QUICK:
                case GGML_OP_SILU:
                case GGML_OP_SILU_BACK:
                case GGML_OP_NORM:
@@ -14335,6 +16349,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                        node->n_tasks = n_threads;
                    } break;
                case GGML_OP_MUL_MAT:
+               case GGML_OP_OUT_PROD:
                    {
                        node->n_tasks = n_threads;

@@ -14417,6 +16432,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                    } break;
                case GGML_OP_DIAG_MASK_INF:
                case GGML_OP_SOFT_MAX:
+               case GGML_OP_SOFT_MAX_BACK:
                case GGML_OP_ROPE:
                case GGML_OP_ROPE_BACK:
                    {
@@ -14430,8 +16446,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                    {
                        node->n_tasks = 1; //TODO
                    } break;
-               case GGML_OP_CONV_1D_1S:
-               case GGML_OP_CONV_1D_2S:
+               case GGML_OP_CONV_1D_S1_PH:
+               case GGML_OP_CONV_1D_S2_PH:
                    {
                        node->n_tasks = n_threads;

@@ -14458,6 +16474,41 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                            GGML_ASSERT(false);
                        }

+                       work_size = MAX(work_size, cur);
+                   } break;
+               case GGML_OP_CONV_2D_SK_P0:
+                   {
+                       node->n_tasks = n_threads;
+
+                       GGML_ASSERT(node->src1->ne[3] == 1);
+
+                       const int64_t ne00 = node->src0->ne[0]; // W
+                       const int64_t ne01 = node->src0->ne[1]; // H
+                       const int64_t ne02 = node->src0->ne[2]; // C
+                       const int64_t ne03 = node->src0->ne[3]; // N
+
+                       const int64_t ne10 = node->src1->ne[0]; // W
+                       const int64_t ne11 = node->src1->ne[1]; // H
+                       const int64_t ne12 = node->src1->ne[2]; // C
+
+                       const int64_t nk = ne00*ne01;
+
+                       UNUSED(ne02);
+                       UNUSED(ne03);
+                       UNUSED(nk);
+
+                       size_t cur = 0;
+
+                       if (node->src0->type == GGML_TYPE_F16 &&
+                           node->src1->type == GGML_TYPE_F32) {
+                           cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
+                       } else if (node->src0->type == GGML_TYPE_F32 &&
+                                  node->src1->type == GGML_TYPE_F32) {
+                           cur = sizeof(float)* (ne10*ne11*ne12);
+                       } else {
+                           GGML_ASSERT(false);
+                       }
+
                        work_size = MAX(work_size, cur);
                    } break;
                case GGML_OP_FLASH_ATTN:
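The scratch reserved here is one staging copy of the input image (`ne10*ne11*ne12` elements, in the kernel's element type). A self-contained arithmetic example with made-up dimensions:

```c
#include <stdint.h>
#include <stdio.h>

int main(void) {
    // Hypothetical 640x480 RGB input with an F16 kernel (2 bytes/element):
    const int64_t ne10 = 640, ne11 = 480, ne12 = 3;       // W, H, C of src1
    const size_t  cur  = (size_t)2 * (ne10 * ne11 * ne12); // sizeof(ggml_fp16_t) == 2
    printf("conv2d scratch: %zu bytes (~%.2f MiB)\n", cur, cur/(1024.0*1024.0));
    return 0; // prints 1843200 bytes (~1.76 MiB)
}
```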
@@ -14498,11 +16549,50 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)

                        work_size = MAX(work_size, cur);
                    } break;
+               case GGML_OP_FLASH_ATTN_BACK:
+                   {
+                       node->n_tasks = n_threads;
+
+                       size_t cur = 0;
+
+                       const int64_t    D = node->src0->ne[0];
+                       const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+                       const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
+                       if (node->src1->type == GGML_TYPE_F32) {
+                           cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
+                           cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
+                       }
+
+                       if (node->src1->type == GGML_TYPE_F16) {
+                           cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
+                           cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
+                       }
+
+                       work_size = MAX(work_size, cur);
+                   } break;
+               case GGML_OP_WIN_PART:
+               case GGML_OP_WIN_UNPART:
                case GGML_OP_MAP_UNARY:
                case GGML_OP_MAP_BINARY:
                    {
                        node->n_tasks = 1;
                    } break;
+               case GGML_OP_CROSS_ENTROPY_LOSS:
+                   {
+                       node->n_tasks = n_threads;
+
+                       size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
+
+                       work_size = MAX(work_size, cur);
+                   } break;
+               case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+                   {
+                       node->n_tasks = n_threads;
+
+                       size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
+
+                       work_size = MAX(work_size, cur);
+                   } break;
                case GGML_OP_NONE:
                    {
                        node->n_tasks = 1;
@@ -15014,16 +17104,20 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **

        if (!*ctx_data) {
            fprintf(stderr, "%s: failed to create ggml context\n", __func__);
+           fclose(fin);
            return result;
        }
    }

    data = ggml_new_tensor_1d(*ctx_data, GGML_TYPE_I8, fsize);

-
-
-
-
+   {
+       const size_t ret = fread(data->data, sizeof(char), fsize, fin);
+       if (ret != fsize) {
+           fprintf(stderr, "%s: failed to read %s\n", __func__, fname);
+           fclose(fin);
+           return result;
+       }
    }

    fclose(fin);
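The fix above closes a file-handle leak: every early `return` after `fopen()` now releases `fin` first, and the `fread` result is checked instead of assumed. A minimal sketch of the same idiom factored into a helper (my illustration, not code from ggml):

```c
#include <stdio.h>
#include <stdbool.h>

// Reads exactly n bytes or reports failure; the caller still owns fin and
// decides how to clean up, which is what ggml_graph_import now does inline.
static bool read_exact(FILE * fin, void * buf, size_t n) {
    return fread(buf, 1, n, fin) == n;
}
```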
@@ -15478,6 +17572,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g

 static enum ggml_opt_result ggml_opt_adam(
        struct ggml_context * ctx,
+       struct ggml_opt_context * opt,
        struct ggml_opt_params params,
        struct ggml_tensor * f,
        struct ggml_cgraph * gf,
@@ -15503,25 +17598,29 @@ static enum ggml_opt_result ggml_opt_adam(
        }
    }

+   if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) {
+       int iter = opt->iter;
+       ggml_opt_init(opt->ctx, opt, params, nx);
+       opt->iter = iter;
+   }
+
    // constants
-   const float alpha = params.adam.alpha;
+   const float sched = params.adam.sched;
+   const float decay = params.adam.decay * sched;
+   const float alpha = params.adam.alpha * sched;
    const float beta1 = params.adam.beta1;
    const float beta2 = params.adam.beta2;
    const float eps   = params.adam.eps;

-   float * x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // view of the parameters
-   float * g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient
-   float * g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient squared
-   float * m  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment
-   float * v  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment
-   float * mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment hat
-   float * vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment hat
+   float * x  = opt->adam.x->data;  // view of the parameters
+   float * g1 = opt->adam.g1->data; // gradient
+   float * g2 = opt->adam.g2->data; // gradient squared
+   float * m  = opt->adam.m->data;  // first moment
+   float * v  = opt->adam.v->data;  // second moment
+   float * mh = opt->adam.mh->data; // first moment hat
+   float * vh = opt->adam.vh->data; // second moment hat

-   float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
-
-   // initialize
-   ggml_vec_set_f32(nx, m, 0.0f);
-   ggml_vec_set_f32(nx, v, 0.0f);
+   float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values

    // update view
    ggml_opt_get_params(np, ps, x);
@@ -15531,16 +17630,27 @@ static enum ggml_opt_result ggml_opt_adam(
    ggml_set_f32      (f->grad, 1.0f);
    ggml_graph_compute(ctx, gb);

-   float fx_prev = ggml_get_f32_1d(f, 0);
+   opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
+   opt->adam.fx_best = opt->adam.fx_prev;
    if (pf) {
-       pf[0] = fx_prev;
+       pf[opt->iter % params.past] = opt->adam.fx_prev;
    }

-   float fx_best = fx_prev;
-   int n_no_improvement = 0;
+   // initialize
+   if (opt->just_initialized) {
+       opt->adam.n_no_improvement = 0;
+       opt->just_initialized = false;
+   }
+
+   float * fx_best = &opt->adam.fx_best;
+   float * fx_prev = &opt->adam.fx_prev;
+   int * n_no_improvement = &opt->adam.n_no_improvement;
+
+   int iter0 = opt->iter;

    // run the optimizer
    for (int t = 0; t < params.adam.n_iter; ++t) {
+       opt->iter = iter0 + t + 1;
        GGML_PRINT_DEBUG  ("=== iter %d ===\n", t);

        GGML_PRINT_DEBUG  ("f      = %10.6f\n", ggml_get_f32_1d(f, 0));
@@ -15574,17 +17684,22 @@ static enum ggml_opt_result ggml_opt_adam(

        // m^hat = m_t / (1 - beta1^t)
        // v^hat = v_t / (1 - beta2^t)
-       // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
+       // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
+       // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
+       // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
+       // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
+       // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
        ggml_vec_cpy_f32  (nx, mh, m);
        ggml_vec_cpy_f32  (nx, vh, v);

-       ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1)));
-       ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, t + 1)));
+       ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter)));
+       ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, opt->iter)));

        ggml_vec_sqrt_f32 (nx, vh, vh);
        ggml_vec_acc1_f32 (nx, vh, eps);

        ggml_vec_div_f32  (nx, mh, mh, vh);
+       ggml_vec_scale_f32(nx, x,  1.0f - decay);
        ggml_vec_sub_f32  (nx, x,  x,  mh);

        // update the parameters
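The comment ladder above is plain algebra. In standard AdamW notation, with $\hat m = m_t/(1-\beta_1^t)$ and $\hat v = v_t/(1-\beta_2^t)$, the update being implemented is

$x_t = (1 - \eta\lambda)\,x_{t-1} - \eta\,\alpha_0\,\dfrac{\hat m}{\sqrt{\hat v} + \epsilon}$

where $\eta$ is `sched` and $\lambda$ is `decay`. The code pre-folds $\eta$ into both `alpha` and `decay` at the top of the function, so the new `ggml_vec_scale_f32(nx, x, 1.0f - decay)` line applies the decoupled weight-decay factor and the following `ggml_vec_sub_f32` applies the Adam step.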
@@ -15598,7 +17713,7 @@ static enum ggml_opt_result ggml_opt_adam(
        const float fx = ggml_get_f32_1d(f, 0);

        // check convergence
-       if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) {
+       if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
            GGML_PRINT_DEBUG("converged\n");

            return GGML_OPT_OK;
@@ -15607,32 +17722,32 @@ static enum ggml_opt_result ggml_opt_adam(
        // delta-based convergence test
        if (pf != NULL) {
            // need at least params.past iterations to start checking for convergence
-           if (params.past <= t) {
-               const float rate = (pf[t%params.past] - fx)/fx;
+           if (params.past <= iter0 + t) {
+               const float rate = (pf[(iter0 + t)%params.past] - fx)/fx;

                if (fabsf(rate) < params.delta) {
                    return GGML_OPT_OK;
                }
            }

-           pf[t%params.past] = fx;
+           pf[(iter0 + t)%params.past] = fx;
        }

        // check for improvement
        if (params.max_no_improvement > 0) {
-           if (fx_best > fx) {
-               fx_best = fx;
-               n_no_improvement = 0;
+           if (fx_best[0] > fx) {
+               fx_best[0] = fx;
+               n_no_improvement[0] = 0;
            } else {
-               ++n_no_improvement;
+               ++n_no_improvement[0];

-               if (n_no_improvement >= params.max_no_improvement) {
+               if (n_no_improvement[0] >= params.max_no_improvement) {
                    return GGML_OPT_OK;
                }
            }
        }

-       fx_prev = fx;
+       fx_prev[0] = fx;

        {
            const int64_t t_end_cpu = ggml_cycles();
@@ -15771,6 +17886,7 @@ static enum ggml_opt_result linesearch_backtracking(

 static enum ggml_opt_result ggml_opt_lbfgs(
        struct ggml_context * ctx,
+       struct ggml_opt_context * opt,
        struct ggml_opt_params params,
        struct ggml_tensor * f,
        struct ggml_cgraph * gf,
@@ -15803,31 +17919,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
        }
    }

-   float * x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current parameters
-   float * xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous parameters
-   float * g  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current gradient
-   float * gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous gradient
-   float * d  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // search direction
+   if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) {
+       int iter = opt->iter;
+       ggml_opt_init(ctx, opt, params, nx);
+       opt->iter = iter;
+   }
+
+   float * x  = opt->lbfgs.x->data;  // current parameters
+   float * xp = opt->lbfgs.xp->data; // previous parameters
+   float * g  = opt->lbfgs.g->data;  // current gradient
+   float * gp = opt->lbfgs.gp->data; // previous gradient
+   float * d  = opt->lbfgs.d->data;  // search direction

-   float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
+   float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values

    float fx    = 0.0f; // cost function value
    float xnorm = 0.0f; // ||x||
    float gnorm = 0.0f; // ||g||
-   float step  = 0.0f;

    // initialize x from the graph nodes
    ggml_opt_get_params(np, ps, x);

    // the L-BFGS memory
-   struct ggml_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_lbfgs_iteration_data)*m);
-
-   for (int i = 0; i < m; ++i) {
-       lm[i].alpha = 0.0f;
-       lm[i].ys    = 0.0f;
-       lm[i].s     = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
-       lm[i].y     = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
-   }
+   float * lm_alpha = opt->lbfgs.lmal->data;
+   float * lm_ys    = opt->lbfgs.lmys->data;
+   float * lm_s     = opt->lbfgs.lms->data;
+   float * lm_y     = opt->lbfgs.lmy->data;
@@ -15842,12 +17959,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
        fx = ggml_get_f32_1d(f, 0);
    }

-   if (pf) {
-       pf[0] = fx;
-   }
-
-   float fx_best = fx;
-
    // search direction = -gradient
    ggml_vec_neg_f32(nx, d, g);

@@ -15864,26 +17975,43 @@ static enum ggml_opt_result ggml_opt_lbfgs(
        return GGML_OPT_OK;
    }

-   // initial step
-   ggml_vec_norm_inv_f32(nx, &step, d);
+   if (opt->just_initialized) {
+       if (pf) {
+           pf[0] = fx;
+       }
+       opt->lbfgs.fx_best = fx;
+
+       // initial step
+       ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d);
+       opt->lbfgs.j                = 0;
+       opt->lbfgs.k                = 1;
+       opt->lbfgs.end              = 0;
+       opt->lbfgs.n_no_improvement = 0;
+       opt->just_initialized       = false;
+   }
+
+   float * fx_best        = &opt->lbfgs.fx_best;
+   float * step           = &opt->lbfgs.step;
+   int * j                = &opt->lbfgs.j;
+   int * k                = &opt->lbfgs.k;
+   int * end              = &opt->lbfgs.end;
+   int * n_no_improvement = &opt->lbfgs.n_no_improvement;

-   int j                = 0;
-   int k                = 1;
-   int ls               = 0;
-   int end              = 0;
-   int bound            = 0;
-   int n_no_improvement = 0;
+   int ls    = 0;
+   int bound = 0;

    float ys   = 0.0f;
    float yy   = 0.0f;
    float beta = 0.0f;

+   int it = 0;
+
    while (true) {
        // store the current position and gradient vectors
        ggml_vec_cpy_f32(nx, xp, x);
        ggml_vec_cpy_f32(nx, gp, g);

-       ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps);
+       ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);

        if (ls < 0) {
            // linesearch failed - go back to the previous point and return
@@ -15909,32 +18037,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
        // delta-based convergence test
        if (pf != NULL) {
            // need at least params.past iterations to start checking for convergence
-           if (params.past <= k) {
-               const float rate = (pf[k%params.past] - fx)/fx;
+           if (params.past <= k[0]) {
+               const float rate = (pf[k[0]%params.past] - fx)/fx;

                if (fabsf(rate) < params.delta) {
                    return GGML_OPT_OK;
                }
            }

-           pf[k%params.past] = fx;
+           pf[k[0]%params.past] = fx;
        }

        // check for improvement
        if (params.max_no_improvement > 0) {
-           if (fx < fx_best) {
-               fx_best = fx;
-               n_no_improvement = 0;
+           if (fx < fx_best[0]) {
+               fx_best[0] = fx;
+               n_no_improvement[0] = 0;
            } else {
-               n_no_improvement++;
+               n_no_improvement[0]++;

-               if (n_no_improvement >= params.max_no_improvement) {
+               if (n_no_improvement[0] >= params.max_no_improvement) {
                    return GGML_OPT_OK;
                }
            }
        }

-       if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) {
+       if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) {
            // reached the maximum number of iterations
            return GGML_OPT_DID_NOT_CONVERGE;
        }
@@ -15943,50 +18071,51 @@ static enum ggml_opt_result ggml_opt_lbfgs(
        // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
        // y_{k+1} = g_{k+1} - g_{k}.
        //
-       ggml_vec_sub_f32(nx, lm[end].s, x, xp);
-       ggml_vec_sub_f32(nx, lm[end].y, g, gp);
+       ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp);
+       ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp);

        // compute scalars ys and yy:
        //     ys = y^t \cdot s    -> 1 / \rho.
        //     yy = y^t \cdot y.
        //
-       ggml_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s);
-       ggml_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y);
+       ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0] *nx]);
+       ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);

-       lm[end].ys = ys;
+       lm_ys[end[0]] = ys;

        // find new search direction
        //   ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS

-       bound = (m <= k) ? m : k;
-       k++;
-       end = (end + 1)%m;
+       bound = (m <= k[0]) ? m : k[0];
+       k[0]++;
+       it++;
+       end[0] = (end[0] + 1)%m;

        // initialize search direction with -g
        ggml_vec_neg_f32(nx, d, g);

-       j = end;
+       j[0] = end[0];
        for (int i = 0; i < bound; ++i) {
-           j = (j + m - 1) % m;
+           j[0] = (j[0] + m - 1) % m;
            // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
-           ggml_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d);
-           lm[j].alpha /= lm[j].ys;
+           ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d);
+           lm_alpha[j[0]] /= lm_ys[j[0]];
            // q_{i} = q_{i+1} - \alpha_{i} y_{i}
-           ggml_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha);
+           ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
        }

        ggml_vec_scale_f32(nx, d, ys/yy);

        for (int i = 0; i < bound; ++i) {
            // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
-           ggml_vec_dot_f32(nx, &beta, lm[j].y, d);
-           beta /= lm[j].ys;
+           ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d);
+           beta /= lm_ys[j[0]];
            // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
-           ggml_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta);
-           j = (j + 1)%m;
+           ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
+           j[0] = (j[0] + 1)%m;
        }

-       step = 1.0;
+       step[0] = 1.0;
    }

    return GGML_OPT_DID_NOT_CONVERGE;
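The loop pair above is the standard L-BFGS two-loop recursion (the reference in the code is Wikipedia's Limited-memory BFGS article): starting from $q \leftarrow -g$, the first loop walks the history newest-to-oldest computing $\alpha_i = \rho_i\, s_i^\top q$ and $q \leftarrow q - \alpha_i y_i$, with $\rho_i = 1/(y_i^\top s_i)$ (the stored `lm_ys`); $q$ is then scaled by $\frac{s_k^\top y_k}{y_k^\top y_k}$ (the `ys/yy` line) as the usual $H_0$ initialization; the second loop walks oldest-to-newest computing $\beta = \rho_i\, y_i^\top q$ and $q \leftarrow q + (\alpha_i - \beta)\, s_i$. The result approximates the inverse-Hessian-times-gradient product and becomes the search direction $d$. The refactor only changes storage: the per-entry `lm[i]` structs become flat arrays inside the persistent opt context, indexed as `lm_s[end*nx]` etc.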
@@ -16011,6 +18140,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {

            .adam = {
                .n_iter = 10000,
+               .sched  = 1.000f,
+               .decay  = 0.001f,
                .alpha  = 0.001f,
                .beta1  = 0.9f,
                .beta2  = 0.999f,
@@ -16053,6 +18184,70 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
    return result;
}

+GGML_API void ggml_opt_init(
+       struct ggml_context * ctx,
+       struct ggml_opt_context * opt,
+       struct ggml_opt_params params,
+       int64_t nx) {
+   opt->ctx = ctx;
+   opt->params = params;
+   opt->iter = 0;
+   opt->nx = nx;
+   opt->just_initialized = true;
+   switch (opt->params.type) {
+       case GGML_OPT_ADAM:
+           {
+               opt->adam.x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+               opt->adam.g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+               opt->adam.g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+               opt->adam.m  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+               opt->adam.v  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+               opt->adam.mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+               opt->adam.vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+               opt->adam.pf = params.past > 0
+                   ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+                   : NULL;
+               ggml_set_zero(opt->adam.x);
+               ggml_set_zero(opt->adam.g1);
+               ggml_set_zero(opt->adam.g2);
+               ggml_set_zero(opt->adam.m);
+               ggml_set_zero(opt->adam.v);
+               ggml_set_zero(opt->adam.mh);
+               ggml_set_zero(opt->adam.vh);
+               if (opt->adam.pf) {
+                   ggml_set_zero(opt->adam.pf);
+               }
+           } break;
+       case GGML_OPT_LBFGS:
+           {
+               opt->lbfgs.x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+               opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+               opt->lbfgs.g  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+               opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+               opt->lbfgs.d  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+               opt->lbfgs.pf = params.past > 0
+                   ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+                   : NULL;
+               opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
+               opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
+               opt->lbfgs.lms  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+               opt->lbfgs.lmy  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+               ggml_set_zero(opt->lbfgs.x);
+               ggml_set_zero(opt->lbfgs.xp);
+               ggml_set_zero(opt->lbfgs.g);
+               ggml_set_zero(opt->lbfgs.gp);
+               ggml_set_zero(opt->lbfgs.d);
+               if (opt->lbfgs.pf) {
+                   ggml_set_zero(opt->lbfgs.pf);
+               }
+               ggml_set_zero(opt->lbfgs.lmal);
+               ggml_set_zero(opt->lbfgs.lmys);
+               ggml_set_zero(opt->lbfgs.lms);
+               ggml_set_zero(opt->lbfgs.lmy);
+           } break;
+   }
+}
+
@@ -16075,33 +18270,65 @@ enum ggml_opt_result ggml_opt(

    enum ggml_opt_result result = GGML_OPT_OK;

+   struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
+
+   ggml_opt_init(ctx, opt, params, 0);
+   result = ggml_opt_resume(ctx, opt, f);
+
+   if (free_ctx) {
+       ggml_free(ctx);
+   }
+
+   return result;
+}
+
+enum ggml_opt_result ggml_opt_resume(
+       struct ggml_context * ctx,
+       struct ggml_opt_context * opt,
+       struct ggml_tensor * f) {
+
+   // build forward + backward compute graphs
+   struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
+   struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
+
+   struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
+   struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
+
+   *gf = ggml_build_forward (f);
+   *gb = ggml_build_backward(ctx, gf, true);
+
+   return ggml_opt_resume_g(ctx, opt, f, gf, gb);
+}
+
+enum ggml_opt_result ggml_opt_resume_g(
+       struct ggml_context * ctx,
+       struct ggml_opt_context * opt,
+       struct ggml_tensor * f,
+       struct ggml_cgraph * gf,
+       struct ggml_cgraph * gb) {
+
    // build forward + backward compute graphs
-   struct ggml_cgraph gf = ggml_build_forward (f);
-   struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, true);
+   enum ggml_opt_result result = GGML_OPT_OK;

-   switch (params.type) {
+   switch (opt->params.type) {
        case GGML_OPT_ADAM:
            {
-               result = ggml_opt_adam(ctx, params, f, &gf, &gb);
+               result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
            } break;
        case GGML_OPT_LBFGS:
            {
-               result = ggml_opt_lbfgs(ctx, params, f, &gf, &gb);
+               result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb);
            } break;
    }

-   if (params.print_forward_graph) {
-       ggml_graph_print   (&gf);
-       ggml_graph_dump_dot(&gf, NULL, "opt-forward.dot");
-   }
-
-   if (params.print_backward_graph) {
-       ggml_graph_print   (&gb);
-       ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot");
+   if (opt->params.print_forward_graph) {
+       ggml_graph_print   (gf);
+       ggml_graph_dump_dot(gf, NULL, "opt-forward.dot");
    }

-   if (free_ctx) {
-       ggml_free(ctx);
+   if (opt->params.print_backward_graph) {
+       ggml_graph_print   (gb);
+       ggml_graph_dump_dot(gb, gf, "opt-backward.dot");
+   }

    return result;
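The split into `ggml_opt_init` / `ggml_opt_resume` / `ggml_opt_resume_g` makes the optimizer state persistent across calls. A hedged usage sketch (`build_loss` and `nx` are hypothetical; the ggml calls are the ones defined in this diff):

```c
#include "ggml.h"

// Drive the resumable optimizer across several calls: Adam moments, the
// iteration counter, and the past-loss history carry over instead of being
// reset on every invocation, which ggml_opt() by itself cannot do.
void train(struct ggml_context * ctx, int n_epochs, int64_t nx) {
    struct ggml_opt_context opt_ctx;
    struct ggml_opt_params  params = ggml_opt_default_params(GGML_OPT_ADAM);

    ggml_opt_init(ctx, &opt_ctx, params, nx); // nx = number of parameter elements

    for (int epoch = 0; epoch < n_epochs; ++epoch) {
        struct ggml_tensor * loss = build_loss(ctx, epoch); // hypothetical
        ggml_opt_resume(ctx, &opt_ctx, loss);
    }
}
```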
@@ -16301,6 +18528,18 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                result = ggml_quantize_q6_K(src + start, block, n, n, hist);
            } break;
 #endif
+       case GGML_TYPE_F16:
+           {
+               int elemsize = sizeof(ggml_fp16_t);
+               ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
+               result = n * elemsize;
+           } break;
+       case GGML_TYPE_F32:
+           {
+               int elemsize = sizeof(float);
+               result = n * elemsize;
+               memcpy((uint8_t *)dst + start * elemsize, src + start, result);
+           } break;
        default:
            assert(false);
    }