llama_cpp 0.2.0 → 0.2.2
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/examples/README.md +92 -0
- data/examples/chat.rb +195 -0
- data/examples/embedding.rb +37 -0
- data/ext/llama_cpp/llama_cpp.cpp +52 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +1218 -411
- data/ext/llama_cpp/src/ggml-cuda.h +4 -1
- data/ext/llama_cpp/src/ggml-metal.h +5 -1
- data/ext/llama_cpp/src/ggml-metal.m +703 -514
- data/ext/llama_cpp/src/ggml-metal.metal +574 -122
- data/ext/llama_cpp/src/ggml-opencl.cpp +496 -36
- data/ext/llama_cpp/src/ggml-opencl.h +1 -2
- data/ext/llama_cpp/src/ggml.c +2715 -476
- data/ext/llama_cpp/src/ggml.h +266 -11
- data/ext/llama_cpp/src/llama.cpp +266 -135
- data/ext/llama_cpp/src/llama.h +19 -11
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +3 -0
- metadata +5 -2
data/ext/llama_cpp/src/ggml.c
CHANGED
@@ -35,6 +35,12 @@
 #define static_assert(cond, msg) struct global_scope_noop_trick
 #endif
 
+#if defined(_MSC_VER)
+// disable "possible loss of data" to avoid hundreds of casts
+// we should just be careful :)
+#pragma warning(disable: 4244 4267)
+#endif
+
 #if defined(_WIN32)
 
 #include <windows.h>
@@ -106,6 +112,7 @@ typedef void* thread_ret_t;
 /*#define GGML_PERF*/
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
+#define GGML_GELU_QUICK_FP16
 #define GGML_SILU_FP16
 
 #define GGML_SOFT_MAX_UNROLL 4
@@ -334,6 +341,9 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
 // precomputed gelu table for f16 (128 KB)
 static ggml_fp16_t table_gelu_f16[1 << 16];
 
+// precomputed quick gelu table for f16 (128 KB)
+static ggml_fp16_t table_gelu_quick_f16[1 << 16];
+
 // precomputed silu table for f16 (128 KB)
 static ggml_fp16_t table_silu_f16[1 << 16];
 
@@ -1671,14 +1681,17 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
 #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
 #define GGML_F32x4_REDUCE(res, x)              \
 {                                              \
-    for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
-        x[2*i] = vaddq_f32(x[2*i], x[2*i+1]);  \
+    int offset = GGML_F32_ARR >> 1;            \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vaddq_f32(x[i], x[offset+i]);   \
     }                                          \
-    for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
-        x[4*i] = vaddq_f32(x[4*i], x[4*i+2]);  \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vaddq_f32(x[i], x[offset+i]);   \
     }                                          \
-    for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
-        x[8*i] = vaddq_f32(x[8*i], x[8*i+4]);  \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vaddq_f32(x[i], x[offset+i]);   \
     }                                          \
     res = GGML_F32x4_REDUCE_ONE(x[0]);         \
 }
@@ -1709,14 +1722,17 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
 #define GGML_F16x8_MUL vmulq_f16
 #define GGML_F16x8_REDUCE(res, x)              \
 {                                              \
-    for (int i = 0; i < GGML_F16_ARR/2; ++i) { \
-        x[2*i] = vaddq_f16(x[2*i], x[2*i+1]);  \
+    int offset = GGML_F16_ARR >> 1;            \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vaddq_f16(x[i], x[offset+i]);   \
     }                                          \
-    for (int i = 0; i < GGML_F16_ARR/4; ++i) { \
-        x[4*i] = vaddq_f16(x[4*i], x[4*i+2]);  \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vaddq_f16(x[i], x[offset+i]);   \
     }                                          \
-    for (int i = 0; i < GGML_F16_ARR/8; ++i) { \
-        x[8*i] = vaddq_f16(x[8*i], x[8*i+4]);  \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vaddq_f16(x[i], x[offset+i]);   \
     }                                          \
     const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
     const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
@@ -1783,14 +1799,17 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
 #define GGML_F32x8_MUL _mm256_mul_ps
 #define GGML_F32x8_REDUCE(res, x)                 \
 {                                                 \
-    for (int i = 0; i < GGML_F32_ARR/2; ++i) {    \
-        x[2*i] = _mm256_add_ps(x[2*i], x[2*i+1]); \
+    int offset = GGML_F32_ARR >> 1;               \
+    for (int i = 0; i < offset; ++i) {            \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]);  \
     }                                             \
-    for (int i = 0; i < GGML_F32_ARR/4; ++i) {    \
-        x[4*i] = _mm256_add_ps(x[4*i], x[4*i+2]); \
+    offset >>= 1;                                 \
+    for (int i = 0; i < offset; ++i) {            \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]);  \
     }                                             \
-    for (int i = 0; i < GGML_F32_ARR/8; ++i) {    \
-        x[8*i] = _mm256_add_ps(x[8*i], x[8*i+4]); \
+    offset >>= 1;                                 \
+    for (int i = 0; i < offset; ++i) {            \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]);  \
     }                                             \
     const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \
                                  _mm256_extractf128_ps(x[0], 1)); \
@@ -1880,14 +1899,17 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
 #define GGML_F32x4_MUL vec_mul
 #define GGML_F32x4_REDUCE(res, x)              \
 {                                              \
-    for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
-        x[2*i] = vec_add(x[2*i], x[2*i+1]);    \
+    int offset = GGML_F32_ARR >> 1;            \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vec_add(x[i], x[offset+i]);     \
     }                                          \
-    for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
-        x[4*i] = vec_add(x[4*i], x[4*i+2]);    \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vec_add(x[i], x[offset+i]);     \
     }                                          \
-    for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
-        x[8*i] = vec_add(x[8*i], x[8*i+4]);    \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vec_add(x[i], x[offset+i]);     \
     }                                          \
     res = vec_extract(x[0], 0) +               \
           vec_extract(x[0], 1) +               \
@@ -1943,14 +1965,17 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
 #define GGML_F32x4_MUL wasm_f32x4_mul
 #define GGML_F32x4_REDUCE(res, x)                  \
 {                                                  \
-    for (int i = 0; i < GGML_F32_ARR/2; ++i) {     \
-        x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \
+    int offset = GGML_F32_ARR >> 1;                \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
     }                                              \
-    for (int i = 0; i < GGML_F32_ARR/4; ++i) {     \
-        x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
     }                                              \
-    for (int i = 0; i < GGML_F32_ARR/8; ++i) {     \
-        x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
     }                                              \
     res = wasm_f32x4_extract_lane(x[0], 0) +       \
           wasm_f32x4_extract_lane(x[0], 1) +       \
@@ -2005,14 +2030,17 @@ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
 #define GGML_F16x4_MUL wasm_f32x4_mul
 #define GGML_F16x4_REDUCE(res, x)                  \
 {                                                  \
-    for (int i = 0; i < GGML_F16_ARR/2; ++i) {     \
-        x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \
+    int offset = GGML_F16_ARR >> 1;                \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
     }                                              \
-    for (int i = 0; i < GGML_F16_ARR/4; ++i) {     \
-        x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
     }                                              \
-    for (int i = 0; i < GGML_F16_ARR/8; ++i) {     \
-        x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
     }                                              \
     res = wasm_f32x4_extract_lane(x[0], 0) +       \
           wasm_f32x4_extract_lane(x[0], 1) +       \
@@ -2054,14 +2082,17 @@ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
 #define GGML_F32x4_MUL _mm_mul_ps
 #define GGML_F32x4_REDUCE(res, x)              \
 {                                              \
-    for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
-        x[2*i] = _mm_add_ps(x[2*i], x[2*i+1]); \
+    int offset = GGML_F32_ARR >> 1;            \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = _mm_add_ps(x[i], x[offset+i]);  \
     }                                          \
-    for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
-        x[4*i] = _mm_add_ps(x[4*i], x[4*i+2]); \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = _mm_add_ps(x[i], x[offset+i]);  \
     }                                          \
-    for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
-        x[8*i] = _mm_add_ps(x[8*i], x[8*i+4]); \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = _mm_add_ps(x[i], x[offset+i]);  \
     }                                          \
     const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \
     res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0));  \
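All seven GGML_F32/F16 REDUCE macros above are rewritten the same way: the hand-unrolled chains of pairwise adds become three halving passes over the accumulator registers, so the body no longer hard-codes register indices. A scalar sketch of the pattern, with plain floats standing in for SIMD registers (names hypothetical, not from ggml.c):

```c
// Offset-halving tree reduction, as used by the rewritten REDUCE macros.
#include <stdio.h>

#define ACC_COUNT 8 // stands in for GGML_F32_ARR; a power of two

static float reduce_tree(float x[ACC_COUNT]) {
    // halve the live range each pass: 8 -> 4 -> 2 -> 1
    for (int offset = ACC_COUNT >> 1; offset > 0; offset >>= 1) {
        for (int i = 0; i < offset; ++i) {
            x[i] += x[offset + i]; // vaddq_f32 / _mm256_add_ps in the real macros
        }
    }
    return x[0]; // GGML_F32x4_REDUCE_ONE collapses the final register
}

int main(void) {
    float x[ACC_COUNT] = {1, 2, 3, 4, 5, 6, 7, 8};
    printf("%f\n", reduce_tree(x)); // 36.000000
    return 0;
}
```

The macros unroll exactly three passes rather than looping, which is sufficient for the accumulator counts ggml uses.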
@@ -3350,6 +3381,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
 
 static const float GELU_COEF_A     = 0.044715f;
+static const float GELU_QUICK_COEF = -1.702f;
 static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
 
 inline static float ggml_gelu_f32(float x) {
@@ -3380,6 +3412,34 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
 }
 #endif
 
+inline static float ggml_gelu_quick_f32(float x) {
+    return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
+}
+
+//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+//    const uint16_t * i16 = (const uint16_t *) x;
+//    for (int i = 0; i < n; ++i) {
+//        y[i] = table_gelu_quick_f16[i16[i]];
+//    }
+//}
+
+#ifdef GGML_GELU_QUICK_FP16
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+        memcpy(&t, &fp16, sizeof(uint16_t));
+        y[i] = GGML_FP16_TO_FP32(table_gelu_quick_f16[t]);
+    }
+}
+#else
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_quick_f32(x[i]);
+    }
+}
+#endif
+
 // Sigmoid Linear Unit (SiLU) function
 inline static float ggml_silu_f32(float x) {
     return x/(1.0f + expf(-x));
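The new "quick" GELU is the sigmoid approximation x·σ(1.702x); with GELU_QUICK_COEF = -1.702f, expf(GELU_QUICK_COEF*x) is e^(-1.702x). A standalone comparison against the exact erf form and the tanh form used by ggml_gelu_f32 (sketch only, not part of the diff; constants restated from the file):

```c
// Compare exact GELU, the tanh approximation (GELU_COEF_A, SQRT_2_OVER_PI),
// and the sigmoid "quick" variant added in this release.
#include <math.h>
#include <stdio.h>

static float gelu_exact(float x) { return 0.5f*x*(1.0f + erff(x/sqrtf(2.0f))); }
static float gelu_tanh(float x)  { return 0.5f*x*(1.0f + tanhf(0.79788456f*x*(1.0f + 0.044715f*x*x))); }
static float gelu_quick(float x) { return x*(1.0f/(1.0f + expf(-1.702f*x))); }

int main(void) {
    for (float x = -2.0f; x <= 2.01f; x += 1.0f) {
        printf("x=% .1f exact=% .5f tanh=% .5f quick=% .5f\n",
               x, gelu_exact(x), gelu_tanh(x), gelu_quick(x));
    }
    return 0;
}
```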
@@ -3603,12 +3663,14 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "SUM_ROWS",
     "MEAN",
     "REPEAT",
+    "REPEAT_BACK",
     "ABS",
     "SGN",
     "NEG",
     "STEP",
     "RELU",
     "GELU",
+    "GELU_QUICK",
     "SILU",
     "SILU_BACK",
     "NORM",
@@ -3616,6 +3678,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "RMS_NORM_BACK",
 
     "MUL_MAT",
+    "OUT_PROD",
 
     "SCALE",
     "SET",
@@ -3631,22 +3694,29 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "DIAG_MASK_INF",
     "DIAG_MASK_ZERO",
     "SOFT_MAX",
+    "SOFT_MAX_BACK",
     "ROPE",
     "ROPE_BACK",
     "ALIBI",
     "CLAMP",
-    "CONV_1D_1S",
-    "CONV_1D_2S",
+    "CONV_1D_S1_PH",
+    "CONV_1D_S2_PH",
+    "CONV_2D_SK_P0",
 
     "FLASH_ATTN",
     "FLASH_FF",
+    "FLASH_ATTN_BACK",
+    "WIN_PART",
+    "WIN_UNPART",
 
     "MAP_UNARY",
     "MAP_BINARY",
-};
 
-static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
+    "CROSS_ENTROPY_LOSS",
+    "CROSS_ENTROPY_LOSS_BACK",
+};
 
+static_assert(GGML_OP_COUNT == 61, "GGML_OP_COUNT != 61");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3665,18 +3735,21 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "Σx_k",
     "Σx/n",
     "repeat(x)",
+    "repeat_back(x)",
     "abs(x)",
     "sgn(x)",
     "-x",
     "step(x)",
     "relu(x)",
     "gelu(x)",
+    "gelu_quick(x)",
     "silu(x)",
     "silu_back(x)",
     "norm(x)",
     "rms_norm(x)",
     "rms_norm_back(x)",
 
+    "X*Y",
     "X*Y",
 
     "x*v",
@@ -3693,21 +3766,29 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "diag_mask_inf(x)",
     "diag_mask_zero(x)",
     "soft_max(x)",
+    "soft_max_back(x)",
     "rope(x)",
     "rope_back(x)",
     "alibi(x)",
     "clamp(x)",
-    "conv_1d_1s(x)",
-    "conv_1d_2s(x)",
+    "conv_1d_s1_ph(x)",
+    "conv_1d_s2_ph(x)",
+    "conv_2d_sk_p0(x)",
 
     "flash_attn(x)",
     "flash_ff(x)",
+    "flash_attn_back(x)",
+    "win_part(x)",
+    "win_unpart(x)",
 
     "f(x)",
     "f(x,y)",
+
+    "cross_entropy_loss(x,y)",
+    "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
+static_assert(GGML_OP_COUNT == 61, "GGML_OP_COUNT != 61");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
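Ten ops were added in this release (the count goes from 51 to 61), and both string tables are pinned to the enum by a GGML_OP_COUNT static_assert so that adding an op without updating the names fails at compile time. A minimal sketch of the same guard, with a hypothetical enum rather than ggml's:

```c
// Keep an enum and its name table in lock-step at compile time.
// Hypothetical example of the idiom; not from ggml.c.
enum demo_op { DEMO_OP_ADD, DEMO_OP_MUL, DEMO_OP_COUNT };

static const char * DEMO_OP_NAME[DEMO_OP_COUNT] = {
    "ADD",
    "MUL",
};

// Trips when DEMO_OP_COUNT changes but this line (and the table) is not
// updated, the same way ggml asserts GGML_OP_COUNT == 61 next to each table.
_Static_assert(DEMO_OP_COUNT == 2, "DEMO_OP_NAME out of sync");
```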
@@ -3870,6 +3951,15 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
         (t0->ne[3] == t1->ne[3]);
 }
 
+static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        (t0->ne[1] == t1->ne[1]) &&
+        (t0->ne[2] == t1->ne[2]) &&
+        (t0->ne[3] == t1->ne[3]);
+}
+
 bool ggml_is_quantized(enum ggml_type type) {
     return GGML_IS_QUANTIZED[type];
 }
@@ -3917,6 +4007,12 @@ bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
+bool ggml_is_permuted(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
+}
+
 static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
@@ -3983,7 +4079,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
     // initialize time system (required on Windows)
     ggml_time_init();
 
-    // initialize GELU, SILU and EXP F32 tables
+    // initialize GELU, Quick GELU, SILU and EXP F32 tables
     {
         const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
 
@@ -3993,13 +4089,14 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             memcpy(&ii, &ui, sizeof(ii));
             const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
             table_gelu_f16[i]       = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
+            table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
             table_silu_f16[i]       = GGML_FP32_TO_FP16(ggml_silu_f32(f));
             table_exp_f16[i]        = GGML_FP32_TO_FP16(expf(f));
         }
 
         const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
 
-        GGML_PRINT_DEBUG("%s: GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
+        GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
     }
 
     // initialize g_state
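The table loop above works because fp16 has only 1 << 16 bit patterns, so any elementwise function can be tabulated exhaustively once at init and lookups replace expf at run time. A self-contained sketch of the trick; the conversion helper and table names below are hypothetical stand-ins for ggml's GGML_COMPUTE_FP16_TO_FP32 and table_gelu_quick_f16, and outputs are kept as f32 for brevity:

```c
// Precompute quick GELU over every possible fp16 input. Sketch only.
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

static float fp16_to_f32(uint16_t h) {      // minimal IEEE half -> float
    const uint32_t sign = (uint32_t)(h >> 15) << 31;
    uint32_t exp  = (h >> 10) & 0x1f;
    uint32_t mant = h & 0x3ff;
    uint32_t bits;
    if (exp == 0) {                         // zero / subnormal
        if (mant == 0) { bits = sign; }
        else {
            exp = 127 - 15 + 1;
            while ((mant & 0x400) == 0) { mant <<= 1; exp--; }
            bits = sign | (exp << 23) | ((mant & 0x3ff) << 13);
        }
    } else if (exp == 0x1f) {               // inf / NaN
        bits = sign | 0x7f800000 | (mant << 13);
    } else {                                // normal
        bits = sign | ((exp - 15 + 127) << 23) | (mant << 13);
    }
    float f; memcpy(&f, &bits, sizeof f); return f;
}

static float table_gelu_quick[1 << 16];    // one entry per fp16 bit pattern

static void init_table(void) {
    for (uint32_t i = 0; i < (1u << 16); ++i) {
        const float f = fp16_to_f32((uint16_t) i);
        table_gelu_quick[i] = f * (1.0f/(1.0f + expf(-1.702f*f)));
    }
}

int main(void) {
    init_table();
    printf("%f\n", table_gelu_quick[0x3c00]); // 0x3c00 is fp16 1.0; ~0.8458
    return 0;
}
```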
@@ -4120,14 +4217,34 @@ void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) {
     ctx->no_alloc = no_alloc;
 }
 
-void * ggml_get_mem_buffer(struct ggml_context * ctx) {
+void * ggml_get_mem_buffer(const struct ggml_context * ctx) {
     return ctx->mem_buffer;
 }
 
-size_t ggml_get_mem_size(struct ggml_context * ctx) {
+size_t ggml_get_mem_size(const struct ggml_context * ctx) {
     return ctx->mem_size;
 }
 
+size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) {
+    size_t max_size = 0;
+
+    struct ggml_object * obj = ctx->objects_begin;
+
+    while (obj != NULL) {
+        struct ggml_tensor * tensor = (struct ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs);
+
+        const size_t size = ggml_nbytes(tensor);
+
+        if (max_size < size) {
+            max_size = size;
+        }
+
+        obj = obj->next;
+    }
+
+    return max_size;
+}
+
 // IMPORTANT:
 // when creating "opt" tensors, always save and load the scratch buffer
 // this is an error prone process, but it is necessary to support inplace
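The new ggml_get_max_tensor_size walks the context's object list and reports the largest tensor's byte size, which lets callers size a single reusable staging buffer. A hedged usage sketch (the allocation wrapper is illustrative only, not from the diff; it assumes ggml.h is on the include path):

```c
// Allocate one staging buffer big enough for any tensor in the context.
#include <stdlib.h>
#include "ggml.h"

static void * alloc_staging(const struct ggml_context * ctx, size_t * size_out) {
    const size_t max_size = ggml_get_max_tensor_size(ctx);
    *size_out = max_size;
    return max_size > 0 ? malloc(max_size) : NULL;
}
```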
@@ -4611,9 +4728,10 @@ const char * ggml_get_name(const struct ggml_tensor * tensor) {
     return tensor->name;
 }
 
-void ggml_set_name(struct ggml_tensor * tensor, const char * name) {
+struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name) {
     strncpy(tensor->name, name, sizeof(tensor->name));
     tensor->name[sizeof(tensor->name) - 1] = '\0';
+    return tensor;
 }
 
 struct ggml_tensor * ggml_view_tensor(
@@ -4693,7 +4811,7 @@ struct ggml_tensor * ggml_add_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -4733,7 +4851,7 @@ struct ggml_tensor * ggml_add1_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5159,6 +5277,34 @@ struct ggml_tensor * ggml_repeat(
     return result;
 }
 
+// ggml_repeat_back
+
+struct ggml_tensor * ggml_repeat_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_can_repeat(b, a));
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    if (ggml_are_same_shape(a, b) && !is_node) {
+        return a;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
+
+    result->op   = GGML_OP_REPEAT_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_abs
 
 struct ggml_tensor * ggml_abs_impl(
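ggml_repeat_back is the adjoint of ggml_repeat: where repeat tiles b-shaped data up to a's shape, repeat_back sums an a-shaped gradient back down to b's shape, which is what the accumulation loops in the forward implementation later in this diff do. A 1-D scalar sketch (hypothetical helper names, not ggml API):

```c
// 1-D sketch of repeat (tile) and its backward (sum over tiles).
#include <stdio.h>

static void repeat_fwd(const float * src, int n, float * dst, int nr) {
    for (int r = 0; r < nr; ++r)
        for (int i = 0; i < n; ++i)
            dst[r*n + i] = src[i];
}

// gradient of repeat: accumulate each tile back onto the source shape
static void repeat_back(const float * grad, int n, int nr, float * out) {
    for (int i = 0; i < n; ++i) out[i] = 0.0f;
    for (int r = 0; r < nr; ++r)
        for (int i = 0; i < n; ++i)
            out[i] += grad[r*n + i];
}

int main(void) {
    float src[2] = {10, 20};
    float tiled[6];
    repeat_fwd(src, 2, tiled, 3);        // {10,20,10,20,10,20}
    float g[6] = {1, 2, 3, 4, 5, 6};     // upstream gradient, same shape as tiled
    float gsrc[2];
    repeat_back(g, 2, 3, gsrc);
    printf("%g %g\n", gsrc[0], gsrc[1]); // 9 12  (1+3+5, 2+4+6)
    return 0;
}
```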
@@ -5364,6 +5510,40 @@ struct ggml_tensor * ggml_gelu_inplace(
     return ggml_gelu_impl(ctx, a, true);
 }
 
+// ggml_gelu_quick
+
+struct ggml_tensor * ggml_gelu_quick_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_GELU_QUICK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_gelu_quick(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_gelu_quick_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_gelu_quick_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_gelu_quick_impl(ctx, a, true);
+}
+
 // ggml_silu
 
 struct ggml_tensor * ggml_silu_impl(
@@ -5536,6 +5716,32 @@ struct ggml_tensor * ggml_mul_mat(
     return result;
 }
 
+// ggml_out_prod
+
+struct ggml_tensor * ggml_out_prod(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_can_out_prod(a, b));
+    GGML_ASSERT(!ggml_is_transposed(a));
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+
+    result->op   = GGML_OP_OUT_PROD;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_scale
 
 struct ggml_tensor * ggml_scale_impl(
@@ -5548,7 +5754,7 @@ struct ggml_tensor * ggml_scale_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5591,7 +5797,7 @@ struct ggml_tensor * ggml_set_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5913,10 +6119,6 @@ struct ggml_tensor * ggml_view_1d(
     result->src1 = NULL;
     result->opt[0] = offs;
 
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
-
     return result;
 }
 
@@ -5957,10 +6159,6 @@ struct ggml_tensor * ggml_view_2d(
     result->src1 = NULL;
     result->opt[0] = offs;
 
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
-
     return result;
 }
 
@@ -6003,10 +6201,6 @@ struct ggml_tensor * ggml_view_3d(
     result->src1 = NULL;
     result->opt[0] = offs;
 
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
-
     return result;
 }
 
@@ -6051,10 +6245,6 @@ struct ggml_tensor * ggml_view_4d(
     result->src1 = NULL;
     result->opt[0] = offs;
 
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
-
     return result;
 }
 
@@ -6116,10 +6306,18 @@ struct ggml_tensor * ggml_permute(
     result->src1 = NULL;
 
     if (is_node) {
-        result->padding[0] = axis0;
-        result->padding[1] = axis1;
-        result->padding[2] = axis2;
-        result->padding[3] = axis3;
+        ggml_scratch_save(ctx);
+
+        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
+
+        ((int32_t *) b->data)[0] = axis0;
+        ((int32_t *) b->data)[1] = axis1;
+        ((int32_t *) b->data)[2] = axis2;
+        ((int32_t *) b->data)[3] = axis3;
+
+        ggml_scratch_load(ctx);
+
+        result->opt[0] = b;
     }
 
     return result;
@@ -6359,6 +6557,44 @@ struct ggml_tensor * ggml_soft_max_inplace(
     return ggml_soft_max_impl(ctx, a, true);
 }
 
+
+// ggml_soft_max_back
+
+struct ggml_tensor * ggml_soft_max_back_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        bool inplace) {
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true; // TODO : implement backward pass
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SOFT_MAX_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_soft_max_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_soft_max_back_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, true);
+}
+
 // ggml_rope
 
 struct ggml_tensor * ggml_rope_impl(
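For reference, the rule the corresponding forward implementation later in this diff derives step by step in comments: with y = softmax(x) and incoming gradient dy,

```latex
J_{ij} = \frac{\partial y_i}{\partial x_j} = y_i(\delta_{ij} - y_j)
\qquad\Longrightarrow\qquad
dx_k = \sum_i J_{ki}\,dy_i = y_k\bigl(dy_k - \langle y, dy\rangle\bigr)
```

so the backward pass needs only one dot product and elementwise ops, in linear time and with no extra memory.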
@@ -6371,7 +6607,7 @@ struct ggml_tensor * ggml_rope_impl(
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
 
-    if (!inplace && a->grad) {
+    if (a->grad) {
         is_node = true;
     }
 
@@ -6425,8 +6661,7 @@ struct ggml_tensor * ggml_rope_back(
     bool is_node = false;
 
     if (a->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
+        is_node = false; // TODO: implement backward
     }
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
@@ -6508,7 +6743,7 @@ struct ggml_tensor * ggml_clamp(
 
     ggml_scratch_save(ctx);
 
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2);
 
     ((float *) b->data)[0] = min;
     ((float *) b->data)[1] = max;
@@ -6523,9 +6758,9 @@ struct ggml_tensor * ggml_clamp(
     return result;
 }
 
-// ggml_conv_1d_1s
+// ggml_conv_1d_s1_ph
 
-struct ggml_tensor * ggml_conv_1d_1s(
+struct ggml_tensor * ggml_conv_1d_s1_ph(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b) {
@@ -6542,7 +6777,7 @@ struct ggml_tensor * ggml_conv_1d_1s(
     const int64_t ne[4] = { b->ne[0], a->ne[2], 1, 1, };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
 
-    result->op   = GGML_OP_CONV_1D_1S;
+    result->op   = GGML_OP_CONV_1D_S1_PH;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = b;
@@ -6550,9 +6785,9 @@ struct ggml_tensor * ggml_conv_1d_1s(
     return result;
 }
 
-// ggml_conv_1d_2s
+// ggml_conv_1d_s2_ph
 
-struct ggml_tensor * ggml_conv_1d_2s(
+struct ggml_tensor * ggml_conv_1d_s2_ph(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b) {
@@ -6569,7 +6804,35 @@ struct ggml_tensor * ggml_conv_1d_2s(
     const int64_t ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
 
-    result->op   = GGML_OP_CONV_1D_2S;
+    result->op   = GGML_OP_CONV_1D_S2_PH;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_conv_2d_sk_p0
+
+struct ggml_tensor * ggml_conv_2d_sk_p0(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(b->ne[3] == 1);
+    GGML_ASSERT(a->ne[2] == b->ne[2]);
+    GGML_ASSERT(b->ne[0] % a->ne[0] == 0);
+    GGML_ASSERT(b->ne[1] % a->ne[1] == 0);
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { b->ne[0]/a->ne[0], b->ne[1]/a->ne[1], a->ne[3], 1, };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op   = GGML_OP_CONV_2D_SK_P0;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = b;
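The suffix encodes the configuration: sk means stride equal to the kernel size, p0 means no padding, so the kernel tiles the input exactly and each output cell corresponds to one patch, the patch-embedding style of convolution used by ViT-family models. A worked shape check following the asserts and the ne[] computation above, with hypothetical sizes:

```c
// Shape arithmetic for ggml_conv_2d_sk_p0. Illustrative values only.
#include <assert.h>
#include <stdio.h>

int main(void) {
    // kernel a: 16x16, 3 channels, 768 filters -> ne = {16, 16, 3, 768}
    // image  b: 224x224, 3 channels, 1 image   -> ne = {224, 224, 3, 1}
    const long a_ne[4] = {16, 16, 3, 768};
    const long b_ne[4] = {224, 224, 3, 1};

    assert(b_ne[3] == 1);           // single image
    assert(a_ne[2] == b_ne[2]);     // channel counts match
    assert(b_ne[0] % a_ne[0] == 0); // width divisible by kernel
    assert(b_ne[1] % a_ne[1] == 0); // height divisible by kernel

    const long out[4] = {b_ne[0]/a_ne[0], b_ne[1]/a_ne[1], a_ne[3], 1};
    printf("output: %ld x %ld x %ld x %ld\n", out[0], out[1], out[2], out[3]);
    // output: 14 x 14 x 768 x 1 -- a 14x14 grid of 768-dim patch embeddings
    return 0;
}
```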
@@ -6591,7 +6854,6 @@ struct ggml_tensor * ggml_flash_attn(
     bool is_node = false;
 
     if (q->grad || k->grad || v->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -6623,7 +6885,6 @@ struct ggml_tensor * ggml_flash_ff(
     bool is_node = false;
 
     if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -6641,54 +6902,202 @@ struct ggml_tensor * ggml_flash_ff(
     return result;
 }
 
-// ggml_map_unary
+// ggml_flash_attn_back
+
+struct ggml_tensor * ggml_flash_attn_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * q,
+        struct ggml_tensor * k,
+        struct ggml_tensor * v,
+        struct ggml_tensor * d,
+        bool masked) {
+    GGML_ASSERT(ggml_can_mul_mat(k, q));
+    // TODO: check if vT can be multiplied by (k*qT)
+
+    // d shape [D,N,ne2,ne3]
+    // q shape [D,N,ne2,ne3]
+    // k shape [D,M,ne2,ne3]
+    // v shape [M,D,ne2,ne3]
+
+    const int64_t D   = q->ne[0];
+    const int64_t N   = q->ne[1];
+    const int64_t M   = k->ne[1];
+    const int64_t ne2 = q->ne[2];
+    const int64_t ne3 = q->ne[3];
+
+    GGML_ASSERT(k->ne[0] == D);
+    GGML_ASSERT(v->ne[0] == M);
+    GGML_ASSERT(v->ne[1] == D);
+    GGML_ASSERT(d->ne[0] == D);
+    GGML_ASSERT(d->ne[1] == N);
+    GGML_ASSERT(k->ne[2] == ne2);
+    GGML_ASSERT(k->ne[3] == ne3);
+    GGML_ASSERT(v->ne[2] == ne2);
+    GGML_ASSERT(v->ne[3] == ne3);
+    GGML_ASSERT(d->ne[2] == ne2);
+    GGML_ASSERT(d->ne[3] == ne3);
 
-struct ggml_tensor * ggml_map_unary_impl_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        const ggml_unary_op_f32_t fun,
-        bool inplace) {
     bool is_node = false;
 
-    if (!inplace && a->grad) {
-        is_node = true;
+    if (q->grad || k->grad || v->grad) {
+        // when using this operation (in backwards pass) these grads are set.
+        // we don't want to create (big) grad of our result, so is_node is false.
+        is_node = false;
     }
 
-    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
-    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
-    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    // store gradients of q, k and v as continuous tensors concatenated in result.
+    // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
+    // gradq->data = result->data
+    // gradk->data = result->data + nb0*D*N*ne2*ne3
+    // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
+    // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
+    int64_t ne[4] = {D,M+N+M,ne2,ne3};
 
-    result->op = GGML_OP_MAP_UNARY;
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op   = GGML_OP_FLASH_ATTN_BACK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src0 = a;
-    result->opt[0] = addr_tensor;
+    result->src0 = q;
+    result->src1 = k;
+    result->opt[0] = v;
+    result->opt[1] = d;
+    result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
 
     return result;
 }
 
-struct ggml_tensor * ggml_map_unary_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        const ggml_unary_op_f32_t fun) {
-    return ggml_map_unary_impl_f32(ctx, a, fun, false);
-}
-
-struct ggml_tensor * ggml_map_unary_inplace_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        const ggml_unary_op_f32_t fun) {
-    return ggml_map_unary_impl_f32(ctx, a, fun, true);
-}
-
-// ggml_map_binary
+// ggml_win_part
 
-struct ggml_tensor * ggml_map_binary_impl_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        const ggml_binary_op_f32_t fun,
-        bool inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
+struct ggml_tensor * ggml_win_part(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        int w) {
+    GGML_ASSERT(a->ne[3] == 1);
+    GGML_ASSERT(a->type == GGML_TYPE_F32);
+
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // padding
+    const int px = (w - a->ne[1]%w)%w;
+    const int py = (w - a->ne[2]%w)%w;
+
+    const int npx = (px + a->ne[1])/w;
+    const int npy = (py + a->ne[2])/w;
+    const int np  = npx*npy;
+
+    const int64_t ne[4] = { a->ne[0], w, w, np, };
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+
+    ((int32_t *) b->data)[0] = npx;
+    ((int32_t *) b->data)[1] = npy;
+    ((int32_t *) b->data)[2] = w;
+
+    ggml_scratch_load(ctx);
+
+    result->op   = GGML_OP_WIN_PART;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+    result->opt[0] = b;
+
+    return result;
+}
+
+// ggml_win_unpart
+
+struct ggml_tensor * ggml_win_unpart(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        int w0,
+        int h0,
+        int w) {
+    GGML_ASSERT(a->type == GGML_TYPE_F32);
+
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { a->ne[0], w0, h0, 1, };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
+
+    ((int32_t *) b->data)[0] = w;
+
+    ggml_scratch_load(ctx);
+
+    result->op   = GGML_OP_WIN_UNPART;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+    result->opt[0] = b;
+
+    return result;
+}
+
+// ggml_map_unary
+
+struct ggml_tensor * ggml_map_unary_impl_f32(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        const ggml_unary_op_f32_t fun,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op = GGML_OP_MAP_UNARY;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->opt[0] = addr_tensor;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_unary_f32(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        const ggml_unary_op_f32_t fun) {
+    return ggml_map_unary_impl_f32(ctx, a, fun, false);
+}
+
+struct ggml_tensor * ggml_map_unary_inplace_f32(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        const ggml_unary_op_f32_t fun) {
+    return ggml_map_unary_impl_f32(ctx, a, fun, true);
+}
+
+// ggml_map_binary
+
+struct ggml_tensor * ggml_map_binary_impl_f32(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        const ggml_binary_op_f32_t fun,
+        bool inplace) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
 
     bool is_node = false;
 
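The result of ggml_flash_attn_back packs grad(q), grad(k) and grad(v) into a single [D, M+N+M, ne2, ne3] tensor; the byte offsets are spelled out in the comments above. Restating them as arithmetic, with hypothetical sizes:

```c
// Offsets of the q/k/v gradients inside the packed result of
// ggml_flash_attn_back, restated from the comments in the diff.
// nb0 = element size; D, N, M, ne2, ne3 as defined in the function.
#include <stdio.h>

int main(void) {
    const long D = 64, N = 32, M = 128, ne2 = 8, ne3 = 1, nb0 = 4;
    const long gradq_off = 0;
    const long gradk_off = nb0*D*N*ne2*ne3;
    const long gradv_off = gradk_off + nb0*D*M*ne2*ne3;
    printf("gradq at %ld, gradk at %ld, gradv at %ld bytes\n",
           gradq_off, gradk_off, gradv_off);
    return 0;
}
```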
@@ -6725,6 +7134,50 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
     return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
 }
 
+// ggml_cross_entropy_loss
+
+struct ggml_tensor * ggml_cross_entropy_loss(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
+
+    result->op   = GGML_OP_CROSS_ENTROPY_LOSS;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_cross_entropy_loss_back
+
+struct ggml_tensor * ggml_cross_entropy_loss_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        struct ggml_tensor * c) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_is_scalar(c));
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
+    result->grad = NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = c;
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 void ggml_set_param(
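GGML_OP_CROSS_ENTROPY_LOSS reduces two same-shaped tensors to a single scalar, as the ggml_new_tensor_1d(ctx, a->type, 1) above shows. The forward kernel itself is outside this excerpt, so take the following as the intended semantics rather than the verified reduction: the usual cross-entropy between the softmax of the logits a and the targets b,

```latex
\mathcal{L}(a, b) = -\sum_i b_i \,\log \operatorname{softmax}(a)_i
```

with ggml_cross_entropy_loss_back producing the matching a-shaped gradient, scaled by the scalar c.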
@@ -7674,7 +8127,7 @@ static void ggml_compute_forward_add_q_f32(
 
         void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
         float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
-        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb0));
+        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));
 
         assert(ne00 % 32 == 0);
 
@@ -8875,6 +9328,99 @@ static void ggml_compute_forward_repeat(
     }
 }
 
+// ggml_compute_forward_repeat_back
+
+static void ggml_compute_forward_repeat_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_can_repeat(dst, src0));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne0  = dst->ne[0];
+    const int64_t ne1  = dst->ne[1];
+    const int64_t ne2  = dst->ne[2];
+    const int64_t ne3  = dst->ne[3];
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const size_t nb0  = dst->nb[0];
+    const size_t nb1  = dst->nb[1];
+    const size_t nb2  = dst->nb[2];
+    const size_t nb3  = dst->nb[3];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne00/ne0);
+    const int nr1 = (int)(ne01/ne1);
+    const int nr2 = (int)(ne02/ne2);
+    const int nr3 = (int)(ne03/ne3);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (ggml_is_contiguous(dst)) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+    } else {
+        for (int k3 = 0; k3 < ne3; k3++) {
+            for (int k2 = 0; k2 < ne2; k2++) {
+                for (int k1 = 0; k1 < ne1; k1++) {
+                    ggml_vec_set_f32(ne0,
+                        (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
+                        0);
+                }
+            }
+        }
+    }
+
+    // TODO: maybe this is not optimal?
+    for (int i3 = 0; i3 < nr3; i3++) {
+        for (int k3 = 0; k3 < ne3; k3++) {
+            for (int i2 = 0; i2 < nr2; i2++) {
+                for (int k2 = 0; k2 < ne2; k2++) {
+                    for (int i1 = 0; i1 < nr1; i1++) {
+                        for (int k1 = 0; k1 < ne1; k1++) {
+                            for (int i0 = 0; i0 < nr0; i0++) {
+                                ggml_vec_acc_f32(ne0,
+                                    (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),
+                                    (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_repeat_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_repeat_back_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_abs
 
 static void ggml_compute_forward_abs_f32(
@@ -9142,8 +9688,65 @@ static void ggml_compute_forward_gelu(
             GGML_ASSERT(false);
         } break;
     }
+}
+
+// ggml_compute_forward_gelu_quick
+
+static void ggml_compute_forward_gelu_quick_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_quick_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
 }
 
-
+static void ggml_compute_forward_gelu_quick(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gelu_quick_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
 }
 
 // ggml_compute_forward_silu
@@ -10249,9 +10852,179 @@ static void ggml_compute_forward_mul_mat(
     }
 }
 
-// ggml_compute_forward_scale
+// ggml_compute_forward_out_prod
 
-static void ggml_compute_forward_scale_f32(
+static void ggml_compute_forward_out_prod_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne10 = src1->ne[0];
+    //const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+    // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
+    // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+
+    if (params->type == GGML_TASK_INIT) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by last three dimensions
+
+    // total rows in dst
+    const int64_t nr = ne1*ne2*ne3;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    // dst[:,:,:,:] = 0
+    // for i2,i3:
+    //   for i1:
+    //     for i01:
+    //       for i0:
+    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        // dst indices
+        const int64_t i3 = ir/(ne2*ne1);
+        const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+        const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        const int64_t i02 = i2;
+        const int64_t i03 = i3;
+
+        //const int64_t i10 = i1;
+        const int64_t i12 = i2;
+        const int64_t i13 = i3;
+
+        for (int64_t i01 = 0; i01 < ne01; ++i01) {
+            const int64_t i11 = i01;
+
+            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+            float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
+
+            ggml_vec_mad_f32(ne0, d, s0, *s1);
+            // for (int64_t i0 = 0; i0 < ne0; ++i0) {
+            //     d[i0] += s0[i0] * s1[i1];
+            // }
+        }
+    }
+
+    //int64_t t1 = ggml_perf_time_us();
+    //static int64_t acc = 0;
+    //acc += t1 - t0;
+    //if (t1 - t0 > 10) {
+    //    printf("\n");
+    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
+
+    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+    //}
+}
+
+static void ggml_compute_forward_out_prod(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+            {
+                GGML_ASSERT(false); // todo
+                // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                GGML_ASSERT(false); // todo
+                // ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_out_prod_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_scale
+
+static void ggml_compute_forward_scale_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -10371,7 +11144,7 @@ static void ggml_compute_forward_set_f32(
     const int im2 = (ne12 == 0 ? 0 : ne12-1);
     const int im3 = (ne13 == 0 ? 0 : ne13-1);
 
-    GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <  ggml_nbytes(dst));
+    GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
 
     GGML_ASSERT(nb10 == sizeof(float));
 
@@ -10671,7 +11444,11 @@ static void ggml_compute_forward_get_rows_back_f32(
     GGML_ASSERT(ggml_is_contiguous(opt0));
     GGML_ASSERT(ggml_is_contiguous(dst));
 
-    ggml_compute_forward_dup_same_cont(params, opt0, dst);
+    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+    if (params->type == GGML_TASK_INIT) {
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -10815,8 +11592,8 @@ static void ggml_compute_forward_diag_mask_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst,
         const float value) {
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 2);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_nelements(src1) == 2);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -10824,7 +11601,7 @@ static void ggml_compute_forward_diag_mask_f32(
     const int n_past = ((int32_t *) src1->data)[0];
     const bool inplace = (bool)((int32_t *) src1->data)[1];
 
-    assert(n_past >= 0);
+    GGML_ASSERT(n_past >= 0);
 
     if (!inplace && (params->type == GGML_TASK_INIT)) {
         // memcpy needs to be synchronized across threads to avoid race conditions.
@@ -10848,8 +11625,8 @@ static void ggml_compute_forward_diag_mask_f32(
     const int nr = src0->ne[1];
     const int nz = n/nr;
 
-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     for (int k = 0; k < nz; k++) {
         for (int j = ith; j < nr; j += nth) {
@@ -10985,6 +11762,101 @@ static void ggml_compute_forward_soft_max(
     }
 }

+// ggml_compute_forward_soft_max_back
+
+static void ggml_compute_forward_soft_max_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src1, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // TODO: handle transposed/permuted matrices
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float *y  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float *dx = (float *)((char *) dst->data  + i1*dst->nb[1]);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(dy[i]));
+            assert(!isnan(y[i]));
+        }
+#endif
+        // Jii = yi - yi*yi
+        // Jij = -yi*yj
+        // J = diag(y)-y.T*y
+        // dx = J * dy
+        // dxk = sum_i(Jki * dyi)
+        // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
+        // dxk = sum_i(-yk*yi * dyi) + yk*dyk
+        // dxk = -yk * sum_i(yi * dyi) + yk*dyk
+        // dxk = -yk * dot(y, dy) + yk*dyk
+        // dxk = yk * (- dot(y, dy) + dyk)
+        // dxk = yk * (dyk - dot(y, dy))
+        //
+        // post-order:
+        // dot_y_dy := dot(y, dy)
+        // dx := dy
+        // dx := dx - dot_y_dy
+        // dx := dx * y
+
+        // linear runtime, no additional memory
+        float dot_y_dy = 0;
+        ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
+        ggml_vec_cpy_f32 (nc, dx, dy);
+        ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
+        ggml_vec_mul_f32 (nc, dx, dx, y);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(dx[i]));
+            assert(!isinf(dx[i]));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_soft_max_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_alibi

 static void ggml_compute_forward_alibi_f32(
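The comment block in ggml_compute_forward_soft_max_back_f32 above derives the softmax Jacobian-vector product down to dxk = yk * (dyk - dot(y, dy)), which the four ggml_vec_* calls then evaluate in linear time. A minimal standalone C sketch of the same identity, useful as a sanity check (the helper name and the test values are mine, not part of ggml):

// softmax backward for one row: dx[k] = y[k] * (dy[k] - dot(y, dy))
#include <stdio.h>

static void softmax_back_row(int n, float * dx, const float * y, const float * dy) {
    float dot_y_dy = 0.0f;
    for (int i = 0; i < n; ++i) dot_y_dy += y[i]*dy[i];           // dot(y, dy)
    for (int i = 0; i < n; ++i) dx[i] = y[i]*(dy[i] - dot_y_dy);  // J * dy
}

int main(void) {
    const float y [3] = {0.2f, 0.3f, 0.5f}; // a softmax output (sums to 1)
    const float dy[3] = {1.0f, 0.0f, 0.0f}; // upstream gradient
    float dx[3];
    softmax_back_row(3, dx, y, dy);
    printf("%f %f %f\n", dx[0], dx[1], dx[2]); // 0.16 -0.06 -0.10
    return 0;
}

With dot(y, dy) = 0.2 the result (0.16, -0.06, -0.10) matches what the dot/cpy/acc1/mul sequence in the kernel produces.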
@@ -10993,8 +11865,9 @@ static void ggml_compute_forward_alibi_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
-
-
+
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_nelements(src1) == 3);

     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -11057,8 +11930,9 @@ static void ggml_compute_forward_alibi_f16(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
-
-
+
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_nelements(src1) == 3);

     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -11160,15 +12034,16 @@ static void ggml_compute_forward_clamp_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
-
-
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_nelements(src1) == 2);

     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }

-    const
-    const
+    const float min = ((float *) src1->data)[0];
+    const float max = ((float *) src1->data)[1];

     const int ith = params->ith;
     const int nth = params->nth;
@@ -11726,9 +12601,9 @@ static void ggml_compute_forward_rope_back(
     }
 }

-//
+// ggml_compute_forward_conv_1d_s1_ph

-static void
+static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -11848,7 +12723,7 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
     }
 }

-static void
+static void ggml_compute_forward_conv_1d_s1_ph_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -11968,7 +12843,7 @@ static void ggml_compute_forward_conv_1d_1s_f32(
     }
 }

-static void
+static void ggml_compute_forward_conv_1d_s1_ph(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -11976,11 +12851,11 @@ static void ggml_compute_forward_conv_1d_1s(
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-
+                ggml_compute_forward_conv_1d_s1_ph_f16_f32(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
-
+                ggml_compute_forward_conv_1d_s1_ph_f32(params, src0, src1, dst);
             } break;
         default:
             {
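These hunks only rename the 1-D convolution kernels. Reading the new suffixes as s1/s2 = stride 1 or 2 and ph = half padding is an interpretation of the names, but it is consistent with the kernels padding by half the kernel width and producing full- or half-length output. A small illustrative helper (mine, not ggml API) for the implied output length:

// Output length of a 1-D convolution with kernel size k, stride s and half padding.
static int conv_1d_out_len(int L, int k, int s) {
    const int p = k/2;           // "ph": pad by half the kernel size
    return (L + 2*p - k)/s + 1;  // standard conv output-length formula
}
// For odd k: conv_1d_out_len(L, k, 1) == L    (the _s1_ph variant keeps the length)
//            conv_1d_out_len(L, k, 2) ~= L/2  (the _s2_ph variant halves it)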
@@ -11989,9 +12864,9 @@ static void ggml_compute_forward_conv_1d_1s(
     }
 }

-//
+// ggml_compute_forward_conv_1d_s2_ph

-static void
+static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -12111,7 +12986,7 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
     }
 }

-static void
+static void ggml_compute_forward_conv_1d_s2_ph_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -12231,7 +13106,7 @@ static void ggml_compute_forward_conv_1d_2s_f32(
     }
 }

-static void
+static void ggml_compute_forward_conv_1d_s2_ph(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -12239,11 +13114,11 @@ static void ggml_compute_forward_conv_1d_2s(
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-
+                ggml_compute_forward_conv_1d_s2_ph_f16_f32(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
-
+                ggml_compute_forward_conv_1d_s2_ph_f32(params, src0, src1, dst);
             } break;
         default:
             {
@@ -12252,51 +13127,188 @@ static void ggml_compute_forward_conv_1d_2s(
     }
 }

-//
+// ggml_compute_forward_conv_2d_sk_p0

-static void
+static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor *
-        const struct ggml_tensor *
-
-
-
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);

-    const
-    const
-    const
-    const
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    //const int ne03 = src0->ne[3];

-    const
-    const
-
-    //const
+    const int ne10 = src1->ne[0];
+    //const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    //const int ne13 = src1->ne[3];

-
-    const
-
-    //const
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+    const int ne2 = dst->ne[2];
+    //const int ne3 = dst->ne[3];
+    //const int ne = ne0*ne1*ne2*ne3;

-    const
-    const
-    //const
-
+    const int nb00 = src0->nb[0];
+    //const int nb01 = src0->nb[1];
+    //const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];

-    const int
-    const int
-    const int
-    const int
+    const int nb10 = src1->nb[0];
+    //const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    //const int nb13 = src1->nb[3];

-    const int
-    const int
-    const int
-    const int
+    //const int nb0 = dst->nb[0];
+    //const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    //const int nb3 = dst->nb[3];

-    const int
-    const int
-
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nk0 = ne00;
+    const int nk1 = ne01;
+
+    // size of the convolution row - the kernel size unrolled across all channels
+    // round-up so it is more suitable for SIMD
+    const int ew0 = ggml_up32(nk0*nk1*ne02);
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (params->type == GGML_TASK_INIT) {
+        // TODO: fix this memset (wsize is overestimated)
+        memset(params->wdata, 0, params->wsize);
+
+        // prepare source data (src1)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+            for (int i12 = 0; i12 < ne12; i12++) {
+                const float * const src = (float *)((char *) src1->data + i12*nb12);
+                ggml_fp16_t * dst_data = wdata;
+
+                for (int i1 = 0; i1 < ne1; i1++) {
+                    for (int i0 = 0; i0 < ne0; i0++) {
+                        for (int ik1 = 0; ik1 < nk1; ik1++) {
+                            for (int ik0 = 0; ik0 < nk0; ik0++) {
+                                dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
+                                    GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // total patches in dst
+    const int np = ne2;
+
+    // patches per thread
+    const int dp = (np + nth - 1)/nth;
+
+    // patch range for this thread
+    const int ip0 = dp*ith;
+    const int ip1 = MIN(ip0 + dp, np);
+
+    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+    for (int i2 = ip0; i2 < ip1; i2++) {
+        float * dst_data = (float *)((char *) dst->data + i2*nb2);
+
+        for (int i1 = 0; i1 < ne1; ++i1) {
+            for (int i0 = 0; i0 < ne0; ++i0) {
+                ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
+                        (ggml_fp16_t *) ((char *) src0->data + i2*nb03),
+                        (ggml_fp16_t *) wdata + (i1*ne0 + i0)*ew0);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_conv_2d_sk_p0(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_2d_sk_p0_f16_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                //ggml_compute_forward_conv_2d_sk_p0_f32(params, src0, src1, dst);
+                GGML_ASSERT(false);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_flash_attn
+
+static void ggml_compute_forward_flash_attn_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const bool masked,
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int64_t neq0 = q->ne[0];
+    const int64_t neq1 = q->ne[1];
+    const int64_t neq2 = q->ne[2];
+    const int64_t neq3 = q->ne[3];
+
+    const int64_t nek0 = k->ne[0];
+    const int64_t nek1 = k->ne[1];
+    //const int64_t nek2 = k->ne[2];
+    //const int64_t nek3 = k->ne[3];
+
+    //const int64_t nev0 = v->ne[0];
+    const int64_t nev1 = v->ne[1];
+    //const int64_t nev2 = v->ne[2];
+    //const int64_t nev3 = v->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    //const int64_t ne2 = dst->ne[2];
+    //const int64_t ne3 = dst->ne[3];
+
+    const int nbk0 = k->nb[0];
+    const int nbk1 = k->nb[1];
+    const int nbk2 = k->nb[2];
+    const int nbk3 = k->nb[3];
+
+    const int nbq0 = q->nb[0];
+    const int nbq1 = q->nb[1];
+    const int nbq2 = q->nb[2];
+    const int nbq3 = q->nb[3];
+
+    const int nbv0 = v->nb[0];
+    const int nbv1 = v->nb[1];
+    const int nbv2 = v->nb[2];
     const int nbv3 = v->nb[3];

     const int nb0 = dst->nb[0];
@@ -12938,91 +13950,917 @@ static void ggml_compute_forward_flash_ff(
     }
 }

-//
+// ggml_compute_forward_flash_attn_back

-static void
+static void ggml_compute_forward_flash_attn_back_f32(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor *
-        struct ggml_tensor *
-        const
-
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * d,
+        const bool masked,
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);

-
-
-
+    const int64_t neq0 = q->ne[0];
+    const int64_t neq1 = q->ne[1];
+    const int64_t neq2 = q->ne[2];
+    const int64_t neq3 = q->ne[3];

-    const
-    const
+    const int64_t nek0 = k->ne[0];
+    const int64_t nek1 = k->ne[1];
+    //const int64_t nek2 = k->ne[2];
+    //const int64_t nek3 = k->ne[3];

-
-
+    const int64_t nev0 = v->ne[0];
+    const int64_t nev1 = v->ne[1];
+    //const int64_t nev2 = v->ne[2];
+    //const int64_t nev3 = v->ne[3];

-
-
-
-
-    }
-}
+    const int64_t ned0 = d->ne[0];
+    const int64_t ned1 = d->ne[1];
+    //const int64_t ned2 = d->ne[2];
+    //const int64_t ned3 = d->ne[3];

+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];

-
-
-
-
-        const ggml_unary_op_f32_t fun) {
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
+    const int nbk0 = k->nb[0];
+    const int nbk1 = k->nb[1];
+    const int nbk2 = k->nb[2];
+    const int nbk3 = k->nb[3];

-
+    const int nbq0 = q->nb[0];
+    const int nbq1 = q->nb[1];
+    const int nbq2 = q->nb[2];
+    const int nbq3 = q->nb[3];

-
-
-
-
-        struct ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
-    assert(params->ith == 0);
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    const int nbv0 = v->nb[0];
+    const int nbv1 = v->nb[1];
+    const int nbv2 = v->nb[2];
+    const int nbv3 = v->nb[3];

-
+    const int nbd0 = d->nb[0];
+    const int nbd1 = d->nb[1];
+    const int nbd2 = d->nb[2];
+    const int nbd3 = d->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+    const int64_t P = nek1 - N;
+    const int64_t M = P + N;
+
+    const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
+    const int mxDM = MAX(D, Mup);
+
+    // GGML_ASSERT(ne0 == D);
+    // GGML_ASSERT(ne1 == N);
+    GGML_ASSERT(P >= 0);
+
+    GGML_ASSERT(nbq0 == sizeof(float));
+    GGML_ASSERT(nbk0 == sizeof(float));
+    GGML_ASSERT(nbv0 == sizeof(float));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned0 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nek1 == N + P);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    if (params->type == GGML_TASK_INIT) {
+        if (ith == 0) {
+            memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
+        }
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
         return;
     }

-
-
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int nr = neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float scale = 1.0f/sqrtf(D);
+
+    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2);
+        const int iq2 = ir - iq3*neq2;
+        for ( int iq1 = 0; iq1 < neq1; ++iq1) {
+
+
+            // not sure about CACHE_LINE_SIZE_F32..
+            // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
+            float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
+            float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
+
+            for (int i = M; i < Mup; ++i) {
+                S[i] = -INFINITY;
+            }
+
+            for (int64_t ic = 0; ic < nek1; ++ic) {
+                // k indices
+                const int ik3 = iq3;
+                const int ik2 = iq2;
+                const int ik1 = ic;
+
+                // S indices
+                const int i1 = ik1;
+
+                ggml_vec_dot_f32(neq0,
+                        S + i1,
+                        (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                        (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+            }
+
+            // scale
+            ggml_vec_scale_f32(nek1, S, scale);
+
+            if (masked) {
+                for (int64_t i = P; i < M; i++) {
+                    if (i > P + iq1) {
+                        S[i] = -INFINITY;
+                    }
+                }
+            }
+
+            // softmax
+            {
+                float max = -INFINITY;
+                ggml_vec_max_f32(M, &max, S);
+
+                ggml_float sum = 0.0;
+                {
+#ifdef GGML_SOFT_MAX_ACCELERATE
+                    max = -max;
+                    vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
+                    vvexpf(SM, SM, &Mup);
+                    ggml_vec_sum_f32(Mup, &sum, SM);
+#else
+                    uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
+                    ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
+
+                    for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
+                        float * SR = S  + i;
+                        float * SW = SM + i;
+
+                        for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
+                            if (SR[j] == -INFINITY) {
+                                SW[j] = 0.0f;
+                            } else {
+                                ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
+                                memcpy(&scvt[j], &s, sizeof(uint16_t));
+                                const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+                                sump[j] += (ggml_float)val;
+                                SW[j] = val;
+                            }
+                        }
+                    }
+
+                    for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
+                        sum += sump[i];
+                    }
+#endif
+                }
+
+                assert(sum > 0.0);
+
+                sum = 1.0/sum;
+                ggml_vec_scale_f32(M, SM, sum);
+
+            }
+
+            // step-by-step explanation
+            {
+                // forward-process shape grads from backward process
+                // parallel_for iq2,iq3:
+                // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur]
+                // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
+                // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur]
+                // for iq1:
+                // kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
+                // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
+                // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
+                // S0 = -Inf [D,1,1,1]
+                // ~S1[i] = dot(kcur[:D,i], qcur)
+                // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
+                // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
+                // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
+                // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
+                // ~S5[i] = dot(vcur[:,i], S4)
+                // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
+                // ~dst[i,iq1,iq2,iq3] = S5[i] ^
+                // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
+                // dst backward-/ grad[dst] = d
+                //
+                // output gradients with their dependencies:
+                //
+                // grad[kcur] = grad[S1].T @ qcur
+                // grad[S1] = diag_mask_zero(grad[S3], P) * scale
+                // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
+                // grad[S4] = grad[S5] @ vcur
+                // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
+                // grad[qcur] = grad[S1] @ kcur
+                // grad[vcur] = grad[S5].T @ S4
+                // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
+                //
+                // in post-order:
+                //
+                // S1 = qcur @ kcur.T
+                // S2 = S1 * scale
+                // S3 = diag_mask_inf(S2, P)
+                // S4 = softmax(S3)
+                // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
+                // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
+                // grad[S1] = diag_mask_zero(grad[S3], P) * scale
+                // grad[qcur] = grad[S1] @ kcur
+                // grad[kcur] = grad[S1].T @ qcur
+                // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
+                //
+                // using less variables (SM=S4):
+                //
+                // S = diag_mask_inf(qcur @ kcur.T * scale, P)
+                // SM = softmax(S)
+                // S = d[:D,iq1,iq2,iq3] @ vcur
+                // dot_SM_gradSM = dot(SM, S)
+                // S = SM * (S - dot(SM, S))
+                // S = diag_mask_zero(S, P) * scale
+                //
+                // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+                // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
+                // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
+            }
+
+            // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
+            // S = d[:D,iq1,iq2,iq3] @ vcur
+            // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
+            ggml_vec_set_f32(M, S, 0);
+            for (int64_t ic = 0; ic < D; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                ggml_vec_mad_f32(M,
+                        S,
+                         (float *) ((char *) v->data + (          ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                        *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
+            }
+
+            // S = SM * (S - dot(SM, S))
+            float dot_SM_gradSM = 0;
+            ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
+            ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
+            ggml_vec_mul_f32 (M, S, S, SM);
+
+            // S = diag_mask_zero(S, P) * scale
+            if (masked) {
+                // for (int64_t i = P + iq1 + 1; i < M; i++) {
+                //     S[i] = 0;
+                // }
+                for (int64_t i = P; i < M; i++) {
+                    if (i > P + iq1) {
+                        S[i] = 0;
+                    }
+                }
+            }
+            ggml_vec_scale_f32(M, S, scale);
+
+            void * grad_q = (char *) dst->data;
+            void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
+            void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
+
+            const size_t nbgq1 = nb0*neq0;
+            const size_t nbgq2 = nb0*neq0*neq1;
+            const size_t nbgq3 = nb0*neq0*neq1*neq2;
+
+            const size_t nbgk1 = nb0*nek0;
+            const size_t nbgk2 = nb0*nek0*nek1;
+            const size_t nbgk3 = nb0*nek0*nek1*neq2;
+
+            const size_t nbgv1 = nb0*nev0;
+            const size_t nbgv2 = nb0*nev0*nev1;
+            const size_t nbgv3 = nb0*nev0*nev1*neq2;
+
+            // S shape [M,1]
+            // SM shape [M,1]
+            // kcur shape [D,M]
+            // qcur shape [D,1]
+            // vcur shape [M,D]
+            //
+            // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+            // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
+            // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
+            //
+            //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
+            //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
+            for (int64_t ic = 0; ic < M; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                ggml_vec_mad_f32(D,
+                        (float *) ((char *) grad_q  + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
+                        (float *) ((char *) k->data + (ic*nbk1  + i2*nbk2  + i3*nbk3)),
+                        S[ic]);
+            }
+
+            // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
+            // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
+            // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
+            for (int64_t ic = 0; ic < M; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                // ggml_vec_set_f32(D,
+                //         (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
+                //         0);
+                ggml_vec_mad_f32(D,
+                        (float *) ((char *) grad_k  + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
+                        (float *) ((char *) q->data + (i1*nbq1  + i2*nbq2  + i3*nbq3)),
+                        S[ic]);
+            }
+
+            // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
+            // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
+            // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M]
+            for (int64_t ic = 0; ic < D; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                // ggml_vec_set_f32(M,
+                //         (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
+                //         0);
+                ggml_vec_mad_f32(M,
+                         (float *) ((char *) grad_v + (          ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
+                        SM,
+                        *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_flash_attn_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * d,
+        const bool masked,
+        struct ggml_tensor * dst) {
+    switch (q->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_win_part
+
+static void ggml_compute_forward_win_part_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0]; UNUSED(ne00);
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3]; UNUSED(ne03);
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3]; UNUSED(ne3);
+
+    const int32_t nep0 = ((const int32_t *)(opt0->data))[0];
+    const int32_t nep1 = ((const int32_t *)(opt0->data))[1];
+    const int32_t w    = ((const int32_t *)(opt0->data))[2];
+
+    assert(ne00 == ne0);
+    assert(ne3  == nep0*nep1);
+
+    // TODO: optimize / multi-thread
+    for (int py = 0; py < nep1; ++py) {
+        for (int px = 0; px < nep0; ++px) {
+            const int64_t i3 = py*nep0 + px;
+            for (int64_t i2 = 0; i2 < ne2; ++i2) {
+                for (int64_t i1 = 0; i1 < ne1; ++i1) {
+                    for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                        const int64_t i02 = py*w + i2;
+                        const int64_t i01 = px*w + i1;
+                        const int64_t i00 = i0;
+
+                        const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0;
+                        const int64_t j = i02*ne01*ne00  + i01*ne00   + i00;
+
+                        if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
+                            ((float *) dst->data)[i] = 0.0f;
+                        } else {
+                            ((float *) dst->data)[i] = ((float *) src0->data)[j];
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_win_part(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_win_part_f32(params, src0, opt0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_win_unpart
+
+static void ggml_compute_forward_win_unpart_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    //const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+
+    const int32_t w = ((const int32_t *)(opt0->data))[0];
+
+    // padding
+    const int px = (w - ne1%w)%w;
+    //const int py = (w - ne2%w)%w;
+
+    const int npx = (px + ne1)/w;
+    //const int npy = (py + ne2)/w;
+
+    assert(ne0 == ne00);
+
+    // TODO: optimize / multi-thread
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = 0; i1 < ne1; ++i1) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                const int ip2 = i2/w;
+                const int ip1 = i1/w;
+
+                const int64_t i02 = i2%w;
+                const int64_t i01 = i1%w;
+                const int64_t i00 = i0;
+
+                const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
+                const int64_t j = i2*ne1*ne0 + i1*ne0 + i0;
+
+                ((float *) dst->data)[j] = ((float *) src0->data)[i];
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_win_unpart(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_win_unpart_f32(params, src0, opt0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_unary
+
+static void ggml_compute_forward_map_unary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        const ggml_unary_op_f32_t fun) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_unary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        const ggml_unary_op_f32_t fun) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_binary
+
+static void ggml_compute_forward_map_binary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        const ggml_binary_op_f32_t fun) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_binary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        const ggml_binary_op_f32_t fun) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_cross_entropy_loss
+
+static void ggml_compute_forward_cross_entropy_loss_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    float * sums = (float *) params->wdata;
+
+    // TODO: handle transposed/permuted matrices
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    if (params->type == GGML_TASK_INIT) {
+        if (ith == 0) {
+            memset(sums, 0, sizeof(float) * (nth + nth * nc));
+        }
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        if (ith == 0) {
+            float * dp = (float *) dst->data;
+            ggml_vec_sum_f32(nth, dp, sums);
+            dp[0] *= -1.0f;
+        }
+        return;
+    }
+
+    const double eps = 1e-9;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float * st = (float *) params->wdata + nth + ith*nc;
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+        // soft_max
+        ggml_float sum = 0.0;
+        {
+            float max = -INFINITY;
+            ggml_vec_max_f32(nc, &max, s0);
+
+            uint16_t scvt;
+            for (int i = 0; i < nc; i++) {
+                if (s0[i] == -INFINITY) {
+                    st[i] = 0.0f;
+                } else {
+                    // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
+                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
+                    memcpy(&scvt, &s, sizeof(scvt));
+                    const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+                    sum += (ggml_float)val;
+                    st[i] = val;
+                }
+            }
+
+            assert(sum > 0.0);
+            // sum = 1.0/sum;
+        }
+        // avoid log(0) by rescaling from [0..1] to [eps..1]
+        sum = (1.0 - eps) / sum;
+        ggml_vec_scale_f32(nc, st, sum);
+        ggml_vec_add1_f32(nc, st, st, eps);
+        ggml_vec_log_f32(nc, st, st);
+        ggml_vec_mul_f32(nc, st, st, s1);
+
+        ggml_vec_sum_f32(nc, sums + ith, st);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(st[i]));
+            assert(!isinf(st[i]));
+        }
+#endif
+    }
+
+}
+
+static void ggml_compute_forward_cross_entropy_loss(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_cross_entropy_loss_back
+
+static void ggml_compute_forward_cross_entropy_loss_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(opt0));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int64_t ith = params->ith;
+    const int64_t nth = params->nth;
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const float eps = 1e-9f;
+
+    // TODO: handle transposed/permuted matrices
+    const int64_t nc = src0->ne[0];
+    const int64_t nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    float * d = (float *) opt0->data;
+
+    for (int64_t i1 = ir0; i1 < ir1; i1++) {
+        float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
+        float * s0  = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float * sm  = (float *) params->wdata + ith*nc;
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+        // step by step explanation:
+        {
+            //float * sums = (float *) params->wdata;
+
+            // forward pass with annotated gradients from backward pass
+            // (built by going in reverse operation order, adding to gradients of current operation args)
+            // st0 = exp(s0-max(s0)) grad[st0] = grad[st1]*(1.0 - eps)/sum
+            // from softmax_back: grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
+            // ggml_vec_scale_f32(nc, st, sum); // st1 = st0*/sum = softmax(s0) grad[st1] = grad[st2]*(1.0 - eps)
+            // ggml_vec_scale_f32(nc, st, (1.0f - eps)); // st2 = st1*(1.0 - eps) grad[st2] = grad[st3]
+            // ggml_vec_add1_f32(nc, st, st, eps); // st3 = st2 + eps grad[st3] = grad[st4]/st3
+            // ggml_vec_log_f32(nc, st, st); // st4 = log(st3) grad[st4] = grad[st5] * s1
+            // ggml_vec_mul_f32(nc, st, st, s1); // st5 = st4 * s1 grad[st5] = grad[sums[ith]]
+            // ggml_vec_sum_f32(nc, sums + ith, st); // sums[ith] = st5 grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel]
+
+            // substitute into grad[st1], because we can reuse softmax_back from this point on
+            // grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps))
+            // postorder:
+            // grad[st1] := softmax(s0)
+            // grad[st1] := grad[st1]*(1.0 - eps)
+            // grad[st1] := grad[st1] + eps
+            // grad[st1] := s1 / grad[st1]
+            // grad[st1] := grad[st1]*(1.0-eps)*-grad[cel]
+
+            // src0 gradients by going through softmax_back
+            // grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
+            // from softmax_back:
+            // dxk = yk * (dyk - dot(y, dy))
+            // dot_y_dy := dot(y, dy)
+            // dx := dy
+            // dx := dx - dot_y_dy
+            // dx := dx * y
+            // postorder:
+            // dot_st1_dst1 := dot(st1, grad[st1])
+            // grad[s0] := grad[st1]
+            // grad[s0] := grad[s0] - dot_st1_dst1
+            // grad[s0] := grad[s0] * st1
+
+            // prepend postorder from grad[st1] directly using grad[s0] as memory location, as we will grad[s0] := grad[st1]
+            // sm := softmax(s0)
+            // grad[s0] := sm*(1.0 - eps)
+            // grad[s0] := grad[s0] + eps
+            // grad[s0] := s1 / grad[s0]
+            // grad[s0] := grad[s0]*(1.0-eps)*-grad[cel]
+            // dot_st1_dst1 := dot(sm, grad[s0])
+            // grad[s0] := grad[s0] - dot_st1_dst1
+            // grad[s0] := grad[s0] * sm
+        }
+
+        // soft_max
+        ggml_float sum = 0.0;
+        {
+            float max = -INFINITY;
+            ggml_vec_max_f32(nc, &max, s0);
+
+            uint16_t scvt;
+            for (int i = 0; i < nc; i++) {
+                if (s0[i] == -INFINITY) {
+                    sm[i] = 0.0f;
+                } else {
+                    // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
+                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
+                    memcpy(&scvt, &s, sizeof(scvt));
+                    const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+                    sum += (ggml_float)val;
+                    sm[i] = val;
+                }
+            }
+
+            assert(sum > 0.0);
+            sum = 1.0/sum;
+        }

-
-
-
+        float dot_st1_dst1 = 0;
+        ggml_vec_scale_f32(nc, sm, sum);
+        ggml_vec_cpy_f32  (nc, ds0, sm);
+        ggml_vec_scale_f32(nc, ds0, (1.0f - eps));
+        ggml_vec_add1_f32 (nc, ds0, ds0, eps);
+        ggml_vec_div_f32  (nc, ds0, s1, ds0);
+        ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]);
+        ggml_vec_dot_f32  (nc, &dot_st1_dst1, sm, ds0);
+        ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1);
+        ggml_vec_mul_f32  (nc, ds0, ds0, sm);

-
-
-
-
-
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(sm[i]));
+            assert(!isinf(sm[i]));
+            assert(!isnan(ds0[i]));
+            assert(!isinf(ds0[i]));
+        }
+#endif
     }
 }

-
-static void ggml_compute_forward_map_binary(
+static void ggml_compute_forward_cross_entropy_loss_back(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor *
-
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-
+                ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst);
             } break;
         default:
             {
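The new ggml_compute_forward_cross_entropy_loss_f32 computes, per row, loss = -sum_i s1[i] * log(eps + (1 - eps) * softmax(s0)[i]); rescaling the probabilities from [0..1] to [eps..1] keeps the log finite when a softmax entry underflows to zero, and the FINALIZE pass sums the per-thread partials and negates. A standalone C sketch of the per-row formula (my code; it uses plain exp/log instead of ggml's table_exp_f16 lookup):

#include <math.h>
#include <stdio.h>

static double cross_entropy_row(int n, const float * x, const float * t) {
    double max = -INFINITY, sum = 0.0, loss = 0.0;
    const double eps = 1e-9;
    for (int i = 0; i < n; ++i) if (x[i] > max) max = x[i];
    for (int i = 0; i < n; ++i) sum += exp((double) x[i] - max);
    for (int i = 0; i < n; ++i) {
        const double p = eps + (1.0 - eps)*exp((double) x[i] - max)/sum;
        loss -= t[i]*log(p); // only non-zero targets contribute
    }
    return loss;
}

int main(void) {
    const float x[3] = {1.0f, 2.0f, 3.0f};
    const float t[3] = {0.0f, 0.0f, 1.0f}; // one-hot target
    printf("%f\n", cross_entropy_row(3, x, t)); // ~0.407606
    return 0;
}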
@@ -13031,6 +14869,7 @@ static void ggml_compute_forward_map_binary(
     }
 }

+
 /////////////////////////////////

 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -13102,6 +14941,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_repeat(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_REPEAT_BACK:
+            {
+                ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_ABS:
             {
                 ggml_compute_forward_abs(params, tensor->src0, tensor);
@@ -13126,6 +14969,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gelu(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_GELU_QUICK:
+            {
+                ggml_compute_forward_gelu_quick(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_SILU:
             {
                 ggml_compute_forward_silu(params, tensor->src0, tensor);
@@ -13150,6 +14997,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_OUT_PROD:
+            {
+                ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_SCALE:
             {
                 ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
@@ -13206,6 +15057,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_soft_max(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_SOFT_MAX_BACK:
+            {
+                ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_ROPE:
             {
                 ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
@@ -13222,25 +15077,44 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor);
             } break;
-        case
+        case GGML_OP_CONV_1D_S1_PH:
+            {
+                ggml_compute_forward_conv_1d_s1_ph(params, tensor->src0, tensor->src1, tensor);
+            } break;
+        case GGML_OP_CONV_1D_S2_PH:
             {
-
+                ggml_compute_forward_conv_1d_s2_ph(params, tensor->src0, tensor->src1, tensor);
             } break;
-        case
+        case GGML_OP_CONV_2D_SK_P0:
             {
-
+                ggml_compute_forward_conv_2d_sk_p0(params, tensor->src0, tensor->src1, tensor);
             } break;
         case GGML_OP_FLASH_ATTN:
             {
-                int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
+                const int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
                 GGML_ASSERT(t == 0 || t == 1);
-                bool masked = t != 0;
+                const bool masked = t != 0;
                 ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor);
             } break;
         case GGML_OP_FLASH_FF:
             {
                 ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
             } break;
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
+                GGML_ASSERT(t == 0 || t == 1);
+                bool masked = t != 0;
+                ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
+            } break;
+        case GGML_OP_WIN_PART:
+            {
+                ggml_compute_forward_win_part(params, tensor->src0, tensor->opt[0], tensor);
+            } break;
+        case GGML_OP_WIN_UNPART:
+            {
+                ggml_compute_forward_win_unpart(params, tensor->src0, tensor->opt[0], tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
@@ -13253,6 +15127,16 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
             }
             break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            {
+                ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
+            }
+            break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            {
+                ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -13391,11 +15275,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 src0->grad =
                     ggml_add_impl(ctx,
                             src0->grad,
-
-                            tensor->grad, // this was not catched by test_grad because in test_grad tensor->grad is 1
+                            ggml_scale(ctx,
                                 ggml_div(ctx,
-
-                                    tensor)
+                                    tensor->grad,
+                                    tensor),
+                                ggml_new_f32(ctx, 0.5f)),
                             inplace);
             }
         } break;
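The hunk above fixes the GGML_OP_SQRT backward rule. With y = sqrt(x), the derivative is (standard calculus, not taken from the diff):

\frac{\partial}{\partial x}\sqrt{x} = \frac{1}{2\sqrt{x}} = \frac{1}{2y}
\qquad\Longrightarrow\qquad
\bar{x} \mathrel{+}= 0.5 \cdot \frac{\bar{y}}{y}

so the accumulated gradient is 0.5 * grad[y] / y; the deleted expression divided by y but dropped the 0.5 factor. As a quick check (my arithmetic): at x = 4, y = 2, an upstream gradient of 1 must propagate as 1/(2*2) = 0.25, which is what the new ggml_scale(ggml_div(tensor->grad, tensor), 0.5) chain yields.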
@@ -13441,43 +15325,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    GGML_ASSERT(src0->n_dims == 1 || src0->n_dims == 2);
-                    const int nc  = tensor->ne[0];
-                    const int nr  = tensor->ne[1];
-                    const int nc0 = src0->ne[0];
-                    const int nr0 = src0->ne[1];
-                    const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
-                    const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
-                    // tensor->grad [nc,nr,1,1]
-                    // reshape      [nc0,ncr,nr0,nrr]
-                    // permute      [nc0,nr0,ncr,nrr]
-                    // cont         [nc0,nr0,ncr,nrr]
-                    // reshape      [nc0*nr0,ncr*nrr,1,1]
-                    // transpose    [ncr*nrr,nc0*nr0,1,1]
-                    // sum rows     [1,nc0*nr0,1,1]
-                    // transpose    [nc0*nr0,1,1]
-                    // reshape      [nc0,nr0,1,1] reshape_1d or reshape_2d
-                    // add to src0->grad
-
-                    int64_t ne[4] = {nc0,ncr,nr0,nrr};
-
-                    struct ggml_tensor* F00 = tensor->grad;
-                    struct ggml_tensor* F01 = ggml_reshape   (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne));
-                    struct ggml_tensor* F02 = ggml_permute   (ctx, F01, 0,2,1,3);
-                    struct ggml_tensor* F03 = ggml_cont      (ctx, F02);
-                    struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr);
-                    struct ggml_tensor* F05 = ggml_transpose (ctx, F04);
-                    struct ggml_tensor* F06 = ggml_cont      (ctx, F05);
-                    struct ggml_tensor* F07 = ggml_sum_rows  (ctx, F06);
-                    struct ggml_tensor* F08 = ggml_transpose (ctx, F07);
-                    struct ggml_tensor* F09 = ggml_cont      (ctx, F08);
-                    struct ggml_tensor* F10 = ggml_reshape   (ctx, F09, src0->grad);
-
-                    src0->grad =
-                        ggml_add_impl(ctx,
-                            src0->grad,
-                            F10,
-                            inplace);
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            ggml_repeat_back(ctx, tensor->grad, src0->grad),
+                            inplace);
+                }
+            } break;
+        case GGML_OP_REPEAT_BACK:
+            {
+                if (src0->grad) {
+                    // TODO: test this
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            ggml_repeat(ctx, tensor->grad, src0->grad),
+                            inplace);
                 }
             } break;
         case GGML_OP_ABS:
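For the GGML_OP_REPEAT/REPEAT_BACK pair above, the backward of repeat sums the upstream gradient over every tile that was a copy of the same source element. A minimal 2-D sketch of that reduction, for illustration only (the actual ggml_repeat_back kernel handles the general 4-D case):

    // dy: gradient of the repeated tensor [nc0*ncr, nr0*nrr] (first dim contiguous)
    // dx: gradient of the source tensor  [nc0, nr0], accumulated into
    static void repeat_back_2d_sketch(const float * dy, float * dx,
                                      int nc0, int nr0,   // source dims
                                      int ncr, int nrr) { // repeat factors
        const int nc = nc0*ncr; // row stride of the repeated tensor
        for (int r = 0; r < nr0; ++r) {
            for (int c = 0; c < nc0; ++c) {
                float sum = 0.0f;
                for (int tr = 0; tr < nrr; ++tr) {
                    for (int tc = 0; tc < ncr; ++tc) {
                        sum += dy[(tr*nr0 + r)*nc + tc*nc0 + c];
                    }
                }
                dx[r*nc0 + c] += sum;
            }
        }
    }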
@@ -13525,6 +15386,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_GELU_QUICK:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_ALIBI:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -13584,38 +15449,37 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
 
             // necessary for llama
             if (src0->grad) {
-                // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
                 src0->grad =
                     ggml_add_impl(ctx,
                             src0->grad,
-                            // ds0 = dt.dot(s1.T)
-                            // ggml_out_prod(ctx, // [n,m]
-                            //     src1,          // [n,p]
-                            //     tensor->grad), // [m,p]
-                            // for now just using A*B==(B.T*A.T).T
-                            ggml_cont(ctx,                      // [n,m]
-                                ggml_transpose(ctx,             // [n,m]
-                                    ggml_mul_mat(ctx,           // [m,n]
-                                        ggml_cont(ctx,          // [p,m]
-                                            ggml_transpose(ctx, // [p,m]
-                                                tensor->grad)), // [m,p]
-                                        ggml_cont(ctx,          // [p,n]
-                                            ggml_transpose(ctx, // [p,n]
-                                                src1))))),      // [n,p]
+                            ggml_out_prod(ctx, // [n,m]
+                                src1,          // [n,p]
+                                tensor->grad), // [m,p]
                             inplace);
             }
             if (src1->grad) {
                 src1->grad =
                     ggml_add_impl(ctx,
                             src1->grad,
-                            // ds1 = s0.T.dot(dt):
-                            ggml_mul_mat(ctx,                   // [n,p]
-                                ggml_cont(ctx,                  // [m,n]
-                                    ggml_transpose(ctx, src0)), // [m,n]
-                                tensor->grad),                  // [m,p]
+                            // ggml_mul_mat(ctx,                   // [n,p]
+                            //     ggml_cont(ctx,                  // [m,n]
+                            //         ggml_transpose(ctx, src0)), // [m,n]
+                            //     tensor->grad),                  // [m,p]
+
+                            // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
+                            // // avoid transpose of src0, rather transpose smaller tensor->grad
+                            // // and then use ggml_out_prod
+                            ggml_out_prod(ctx,      // [n,p]
+                                src0,               // [n,m]
+                                ggml_transpose(ctx, // [p,m]
+                                    tensor->grad)), // [m,p]
                             inplace);
             }
         } break;
+        case GGML_OP_OUT_PROD:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_SCALE:
             {
                 // necessary for llama
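The shape comments in the GGML_OP_MUL_MAT hunk encode a standard identity. With ggml's convention (first ne dimension contiguous), `ggml_mul_mat(a, b)` computes $C_{ij} = \sum_k a_{ki}\,b_{kj}$, i.e. $C = A^{\top}B$ for $A\,[n,m]$ and $B\,[n,p]$, so

$$\mathrm{d}A = B\,(\mathrm{d}C)^{\top}, \qquad \mathrm{d}B = A\,\mathrm{d}C.$$

Since `ggml_out_prod(x, y)` materializes $x\,y^{\top}$ in this convention, `out_prod(src1, tensor->grad)` is exactly $\mathrm{d}A$, and `out_prod(src0, transpose(tensor->grad))` is $A\,\mathrm{d}C = \mathrm{d}B$; one op each in place of the old transpose/cont/mul_mat chains, and without transposing the large `src0`.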
@@ -13717,7 +15581,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             // necessary for llama
             if (src0->grad) {
                 size_t offset;
-                memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
+
+                GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0]));
+                memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
 
                 size_t nb1 = tensor->nb[1];
                 size_t nb2 = tensor->nb[2];
@@ -13744,10 +15610,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    int axis0 = ggml_get_i32_1d(tensor->opt[0], 0) & 0x3;
-                    int axis1 = ggml_get_i32_1d(tensor->opt[0], 1) & 0x3;
-                    int axis2 = ggml_get_i32_1d(tensor->opt[0], 2) & 0x3;
-                    int axis3 = ggml_get_i32_1d(tensor->opt[0], 3) & 0x3;
+                    int32_t * axes = (int32_t *) tensor->opt[0]->data;
+                    int axis0 = axes[0] & 0x3;
+                    int axis1 = axes[1] & 0x3;
+                    int axis2 = axes[2] & 0x3;
+                    int axis3 = axes[3] & 0x3;
                     int axes_backward[4] = {0,0,0,0};
                     axes_backward[axis0] = 0;
                     axes_backward[axis1] = 1;
@@ -13831,50 +15698,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    // y = softmax(x)
-                    //
-                    // Jii = yi - yi*yi
-                    // Jij = -yi*yj
-                    // J = diag(y)-y.*y
-                    // dx = J * dy
-                    // dxk = sum(Jkj * dyk)
-
-                    int64_t ne2[4] = {
-                        tensor->ne[0],
-                        1,
-                        tensor->ne[1]*tensor->ne[2],
-                        tensor->ne[3]
-                    };
-                    struct ggml_tensor * tensor2 = ggml_cont(ctx,
-                        ggml_reshape_4d(ctx,
-                            ggml_cont(ctx, tensor),
-                            ne2[0], ne2[1], ne2[2], ne2[3]));
-
-                    struct ggml_tensor * grad2 = ggml_cont(ctx,
-                        ggml_reshape_4d(ctx,
-                            ggml_cont(ctx, tensor->grad),
-                            ne2[0], ne2[1], ne2[2], ne2[3]));
-
-                    struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
-                        ggml_permute(ctx,                           // [1,ne0,ne1*ne2,ne3]
-                            tensor2,                                // [ne0,1,ne1*ne2,ne3]
-                            1, 0, 2, 3));
-
                     src0->grad =
-                        ggml_add_impl(ctx,
-                            src0->grad,
-                            ggml_reshape(ctx,
-                                ggml_mul_mat(ctx,        // [ne0,1,ne1*ne2,ne3]
-                                    ggml_sub(ctx,        // [ne0,ne0,ne1*ne2,ne3]
-                                        ggml_diag(ctx,   // [ne0,ne0,ne1*ne2,ne3]
-                                            tensor2),    // [ne0,1,ne1*ne2,ne3]
-                                        ggml_mul_mat(ctx,    // [ne0,ne0,ne1*ne2,ne3]
-                                            tensor2_t,       // [1,ne0,ne1*ne2,ne3]
-                                            tensor2_t)),     // [1,ne0,ne1*ne2,ne3]
-                                    grad2),                  // [ne0,1,ne1*ne2,ne3]
-                                src0->grad),
-                            inplace);
+                        ggml_add_impl(ctx, src0->grad,
+                            ggml_soft_max_back(ctx, tensor->grad, tensor),
+                            inplace);
                 }
+
+            } break;
+        case GGML_OP_SOFT_MAX_BACK:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_ROPE:
             {
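The removed GGML_OP_SOFT_MAX code and the new `ggml_soft_max_back` agree mathematically. For $y = \operatorname{softmax}(x)$ the Jacobian is $J = \operatorname{diag}(y) - y\,y^{\top}$ (as the deleted comments said), and applying it to $\mathrm{d}y$ collapses to a row-wise expression,

$$\mathrm{d}x_i = y_i\left(\mathrm{d}y_i - \sum_j \mathrm{d}y_j\,y_j\right),$$

one dot product plus elementwise work per row, instead of materializing the $ne_0 \times ne_0$ `ggml_diag`/`ggml_mul_mat` intermediates.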
@@ -13919,27 +15752,206 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // noop
             }
         } break;
-        case GGML_OP_CONV_1D_1S:
+        case GGML_OP_CONV_1D_S1_PH:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_CONV_1D_S2_PH:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_CONV_1D_2S:
+        case GGML_OP_CONV_2D_SK_P0:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_FLASH_ATTN:
             {
-                GGML_ASSERT(false); // not supported
+                struct ggml_tensor * flash_grad = NULL;
+                if (src0->grad || src1->grad || tensor->opt[0]->grad) {
+                    int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
+                    GGML_ASSERT(t == 0 || t == 1);
+                    bool masked = t != 0;
+                    flash_grad =
+                        ggml_flash_attn_back(ctx,
+                            src0,
+                            src1,
+                            tensor->opt[0],
+                            tensor->grad,
+                            masked);
+                }
+
+                if (src0->grad) {
+                    struct ggml_tensor * grad_q = NULL;
+                    const size_t nb0    = flash_grad->nb[0];
+                    const size_t offset = 0;
+                    switch(src0->n_dims) {
+                        case 2:
+                            {
+                                grad_q = ggml_view_2d(ctx,
+                                    flash_grad,
+                                    src0->ne[0],
+                                    src0->ne[1],
+                                    nb0*src0->ne[0],
+                                    offset);
+                            } break;
+                        case 3:
+                            {
+                                grad_q = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    src0->ne[0],
+                                    src0->ne[1],
+                                    src0->ne[2],
+                                    nb0*src0->ne[0],
+                                    nb0*src0->ne[0]*src0->ne[1],
+                                    offset);
+                            } break;
+                        case 4:
+                            {
+                                grad_q = ggml_view_4d(ctx,
+                                    flash_grad,
+                                    src0->ne[0],
+                                    src0->ne[1],
+                                    src0->ne[2],
+                                    src0->ne[3],
+                                    nb0*src0->ne[0],
+                                    nb0*src0->ne[0]*src0->ne[1],
+                                    nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
+                                    offset);
+                            } break;
+                    }
+
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            grad_q,
+                            inplace);
+                }
+
+                if (src1->grad) {
+                    struct ggml_tensor * grad_k = NULL;
+                    const size_t nb0    = flash_grad->nb[0];
+                    const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
+                    switch(src1->n_dims) {
+                        case 2:
+                            {
+                                grad_k = ggml_view_2d(ctx,
+                                    flash_grad,
+                                    src1->ne[0],
+                                    src1->ne[1],
+                                    nb0*src1->ne[0],
+                                    offset);
+                            } break;
+                        case 3:
+                            {
+                                grad_k = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    src1->ne[0],
+                                    src1->ne[1],
+                                    src1->ne[2],
+                                    nb0*src1->ne[0],
+                                    nb0*src1->ne[0]*src1->ne[1],
+                                    offset);
+                            } break;
+                        case 4:
+                            {
+                                grad_k = ggml_view_4d(ctx,
+                                    flash_grad,
+                                    src1->ne[0],
+                                    src1->ne[1],
+                                    src1->ne[2],
+                                    src1->ne[3],
+                                    nb0*src1->ne[0],
+                                    nb0*src1->ne[0]*src1->ne[1],
+                                    nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
+                                    offset);
+                            } break;
+                    }
+
+                    src1->grad = ggml_add_impl(ctx,
+                            src1->grad,
+                            grad_k,
+                            inplace);
+                }
+
+                struct ggml_tensor * opt0 = tensor->opt[0];
+
+                if (opt0->grad) {
+                    struct ggml_tensor * grad_v = NULL;
+                    const size_t nb0    = flash_grad->nb[0];
+                    const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
+                                        + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
+                    switch(opt0->n_dims) {
+                        case 2:
+                            {
+                                grad_v = ggml_view_2d(ctx,
+                                    flash_grad,
+                                    opt0->ne[0],
+                                    opt0->ne[1],
+                                    nb0*opt0->ne[0],
+                                    offset);
+                            } break;
+                        case 3:
+                            {
+                                grad_v = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    opt0->ne[0],
+                                    opt0->ne[1],
+                                    opt0->ne[2],
+                                    nb0*opt0->ne[0],
+                                    nb0*opt0->ne[0]*opt0->ne[1],
+                                    offset);
+                            } break;
+                        case 4:
+                            {
+                                grad_v = ggml_view_4d(ctx,
+                                    flash_grad,
+                                    opt0->ne[0],
+                                    opt0->ne[1],
+                                    opt0->ne[2],
+                                    opt0->ne[3],
+                                    nb0*opt0->ne[0],
+                                    nb0*opt0->ne[0]*opt0->ne[1],
+                                    nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
+                                    offset);
+                            } break;
+                    }
+
+                    opt0->grad = ggml_add_impl(ctx,
+                            opt0->grad,
+                            grad_v,
+                            inplace);
+                }
             } break;
         case GGML_OP_FLASH_FF:
             {
                 GGML_ASSERT(false); // not supported
             } break;
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
+        case GGML_OP_WIN_PART:
+        case GGML_OP_WIN_UNPART:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
             {
                 GGML_ASSERT(false); // not supported
             } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            {
+                if (src0->grad) {
+                    src0->grad = ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_cross_entropy_loss_back(ctx,
+                                    src0,
+                                    src1,
+                                    tensor->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
         case GGML_OP_NONE:
             {
                 // nop
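The three `ggml_view_*` switches in the GGML_OP_FLASH_ATTN hunk all slice one buffer: the offset arithmetic implies `flash_grad` packs the q, k and v gradients back to back. Restated as a hypothetical helper (the name is ours, not from the diff):

    // flash_grad layout implied by the offsets above:
    //   [ dq : bytes(q) | dk : bytes(k) | dv : bytes(v) ]
    static size_t bytes4d(const struct ggml_tensor * t, size_t nb0) {
        return nb0 * t->ne[0] * t->ne[1] * t->ne[2] * t->ne[3];
    }
    // dq starts at 0, dk after dq, dv after dq and dk:
    //   off_dq = 0
    //   off_dk = bytes4d(q, nb0)
    //   off_dv = bytes4d(q, nb0) + bytes4d(k, nb0)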
@@ -14316,6 +16328,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
             case GGML_OP_SUM_ROWS:
             case GGML_OP_MEAN:
             case GGML_OP_REPEAT:
+            case GGML_OP_REPEAT_BACK:
             case GGML_OP_ABS:
             case GGML_OP_SGN:
             case GGML_OP_NEG:
@@ -14326,6 +16339,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 } break;
             case GGML_OP_MUL:
             case GGML_OP_GELU:
+            case GGML_OP_GELU_QUICK:
             case GGML_OP_SILU:
             case GGML_OP_SILU_BACK:
             case GGML_OP_NORM:
@@ -14335,6 +16349,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     node->n_tasks = n_threads;
                 } break;
             case GGML_OP_MUL_MAT:
+            case GGML_OP_OUT_PROD:
                 {
                     node->n_tasks = n_threads;
 
@@ -14417,6 +16432,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 } break;
             case GGML_OP_DIAG_MASK_INF:
             case GGML_OP_SOFT_MAX:
+            case GGML_OP_SOFT_MAX_BACK:
             case GGML_OP_ROPE:
             case GGML_OP_ROPE_BACK:
                 {
@@ -14430,8 +16446,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 {
                     node->n_tasks = 1; //TODO
                 } break;
-            case GGML_OP_CONV_1D_1S:
-            case GGML_OP_CONV_1D_2S:
+            case GGML_OP_CONV_1D_S1_PH:
+            case GGML_OP_CONV_1D_S2_PH:
                 {
                     node->n_tasks = n_threads;
 
@@ -14458,6 +16474,41 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         GGML_ASSERT(false);
                     }
 
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_CONV_2D_SK_P0:
+                {
+                    node->n_tasks = n_threads;
+
+                    GGML_ASSERT(node->src1->ne[3] == 1);
+
+                    const int64_t ne00 = node->src0->ne[0]; // W
+                    const int64_t ne01 = node->src0->ne[1]; // H
+                    const int64_t ne02 = node->src0->ne[2]; // C
+                    const int64_t ne03 = node->src0->ne[3]; // N
+
+                    const int64_t ne10 = node->src1->ne[0]; // W
+                    const int64_t ne11 = node->src1->ne[1]; // H
+                    const int64_t ne12 = node->src1->ne[2]; // C
+
+                    const int64_t nk = ne00*ne01;
+
+                    UNUSED(ne02);
+                    UNUSED(ne03);
+                    UNUSED(nk);
+
+                    size_t cur = 0;
+
+                    if (node->src0->type == GGML_TYPE_F16 &&
+                        node->src1->type == GGML_TYPE_F32) {
+                        cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
+                    } else if (node->src0->type == GGML_TYPE_F32 &&
+                               node->src1->type == GGML_TYPE_F32) {
+                        cur = sizeof(float)* (ne10*ne11*ne12);
+                    } else {
+                        GGML_ASSERT(false);
+                    }
+
                     work_size = MAX(work_size, cur);
                 } break;
             case GGML_OP_FLASH_ATTN:
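For scale, the GGML_OP_CONV_2D_SK_P0 sizing above reserves one converted copy of the source image; e.g. a hypothetical f16 kernel over a 64 x 64 x 3 f32 input:

    // cur = sizeof(ggml_fp16_t) * (ne10*ne11*ne12)
    //     = 2 * (64*64*3)
    //     = 24576 bytes of scratch, independent of the output channel count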
@@ -14498,11 +16549,50 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
                     work_size = MAX(work_size, cur);
                 } break;
+            case GGML_OP_FLASH_ATTN_BACK:
+                {
+                    node->n_tasks = n_threads;
+
+                    size_t cur = 0;
+
+                    const int64_t    D = node->src0->ne[0];
+                    const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+                    const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
+                    if (node->src1->type == GGML_TYPE_F32) {
+                        cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
+                    }
+
+                    if (node->src1->type == GGML_TYPE_F16) {
+                        cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
+                    }
+
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_WIN_PART:
+            case GGML_OP_WIN_UNPART:
             case GGML_OP_MAP_UNARY:
             case GGML_OP_MAP_BINARY:
                 {
                     node->n_tasks = 1;
                 } break;
+            case GGML_OP_CROSS_ENTROPY_LOSS:
+                {
+                    node->n_tasks = n_threads;
+
+                    size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
+
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+                {
+                    node->n_tasks = n_threads;
+
+                    size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
+
+                    work_size = MAX(work_size, cur);
+                } break;
             case GGML_OP_NONE:
                 {
                     node->n_tasks = 1;
@@ -15014,16 +17104,20 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
 
         if (!*ctx_data) {
             fprintf(stderr, "%s: failed to create ggml context\n", __func__);
+            fclose(fin);
             return result;
         }
     }
 
     data = ggml_new_tensor_1d(*ctx_data, GGML_TYPE_I8, fsize);
 
-    const size_t ret = fread(data->data, sizeof(char), fsize, fin);
-    if (ret != fsize) {
-        fprintf(stderr, "%s: failed to read %s\n", __func__, fname);
-        return result;
+    {
+        const size_t ret = fread(data->data, sizeof(char), fsize, fin);
+        if (ret != fsize) {
+            fprintf(stderr, "%s: failed to read %s\n", __func__, fname);
+            fclose(fin);
+            return result;
+        }
     }
 
     fclose(fin);
@@ -15478,6 +17572,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
 
 static enum ggml_opt_result ggml_opt_adam(
         struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
         struct ggml_opt_params params,
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
@@ -15503,25 +17598,29 @@ static enum ggml_opt_result ggml_opt_adam(
         }
     }
 
+    if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) {
+        int iter = opt->iter;
+        ggml_opt_init(opt->ctx, opt, params, nx);
+        opt->iter = iter;
+    }
+
     // constants
-    const float alpha = params.adam.alpha;
+    const float sched = params.adam.sched;
+    const float decay = params.adam.decay * sched;
+    const float alpha = params.adam.alpha * sched;
     const float beta1 = params.adam.beta1;
     const float beta2 = params.adam.beta2;
     const float eps   = params.adam.eps;
 
-    float * x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // view of the parameters
-    float * g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient
-    float * g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient squared
-    float * m  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment
-    float * v  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment
-    float * mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment hat
-    float * vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment hat
+    float * x  = opt->adam.x->data;  // view of the parameters
+    float * g1 = opt->adam.g1->data; // gradient
+    float * g2 = opt->adam.g2->data; // gradient squared
+    float * m  = opt->adam.m->data;  // first moment
+    float * v  = opt->adam.v->data;  // second moment
+    float * mh = opt->adam.mh->data; // first moment hat
+    float * vh = opt->adam.vh->data; // second moment hat
 
-    float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
-
-    // initialize
-    ggml_vec_set_f32(nx, m, 0.0f);
-    ggml_vec_set_f32(nx, v, 0.0f);
+    float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
 
     // update view
     ggml_opt_get_params(np, ps, x);
@@ -15531,16 +17630,27 @@ static enum ggml_opt_result ggml_opt_adam(
     ggml_set_f32      (f->grad, 1.0f);
     ggml_graph_compute(ctx, gb);
 
-    float fx_prev = ggml_get_f32_1d(f, 0);
+    opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
+    opt->adam.fx_best = opt->adam.fx_prev;
     if (pf) {
-        pf[0] = fx_prev;
+        pf[opt->iter % params.past] = opt->adam.fx_prev;
     }
 
-    float fx_best = fx_prev;
-    int n_no_improvement = 0;
+    // initialize
+    if (opt->just_initialized) {
+        opt->adam.n_no_improvement = 0;
+        opt->just_initialized = false;
+    }
+
+    float * fx_best = &opt->adam.fx_best;
+    float * fx_prev = &opt->adam.fx_prev;
+    int * n_no_improvement = &opt->adam.n_no_improvement;
+
+    int iter0 = opt->iter;
 
     // run the optimizer
     for (int t = 0; t < params.adam.n_iter; ++t) {
+        opt->iter = iter0 + t + 1;
        GGML_PRINT_DEBUG  ("=== iter %d ===\n", t);
 
        GGML_PRINT_DEBUG  ("f      = %10.6f\n", ggml_get_f32_1d(f, 0));
|
 
         // m^hat = m_t / (1 - beta1^t)
         // v^hat = v_t / (1 - beta2^t)
-        // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
+        // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
+        // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
+        // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
+        // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
+        // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
         ggml_vec_cpy_f32  (nx, mh, m);
         ggml_vec_cpy_f32  (nx, vh, v);
 
-        ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1)));
-        ggml_vec_scale_f32(nx, vh,  1.0f/(1.0f - powf(beta2, t + 1)));
+        ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter)));
+        ggml_vec_scale_f32(nx, vh,  1.0f/(1.0f - powf(beta2, opt->iter)));
 
         ggml_vec_sqrt_f32 (nx, vh, vh);
         ggml_vec_acc1_f32 (nx, vh, eps);
 
         ggml_vec_div_f32  (nx, mh, mh, vh);
+        ggml_vec_scale_f32(nx, x,  1.0f - decay);
         ggml_vec_sub_f32  (nx, x,  x,  mh);
 
         // update the parameters
|
         const float fx = ggml_get_f32_1d(f, 0);
 
         // check convergence
-        if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) {
+        if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
             GGML_PRINT_DEBUG("converged\n");
 
             return GGML_OPT_OK;
|
         // delta-based convergence test
         if (pf != NULL) {
             // need at least params.past iterations to start checking for convergence
-            if (params.past <= t) {
-                const float rate = (pf[t%params.past] - fx)/fx;
+            if (params.past <= iter0 + t) {
+                const float rate = (pf[(iter0 + t)%params.past] - fx)/fx;
 
                 if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }
 
-            pf[t%params.past] = fx;
+            pf[(iter0 + t)%params.past] = fx;
         }
 
         // check for improvement
         if (params.max_no_improvement > 0) {
-            if (fx_best > fx) {
-                fx_best = fx;
-                n_no_improvement = 0;
+            if (fx_best[0] > fx) {
+                fx_best[0] = fx;
+                n_no_improvement[0] = 0;
             } else {
-                ++n_no_improvement;
+                ++n_no_improvement[0];
 
-                if (n_no_improvement >= params.max_no_improvement) {
+                if (n_no_improvement[0] >= params.max_no_improvement) {
                     return GGML_OPT_OK;
                 }
             }
         }
 
-        fx_prev = fx;
+        fx_prev[0] = fx;
 
         {
             const int64_t t_end_cpu = ggml_cycles();
@@ -15771,6 +17886,7 @@ static enum ggml_opt_result linesearch_backtracking(
 
 static enum ggml_opt_result ggml_opt_lbfgs(
         struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
         struct ggml_opt_params params,
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
@@ -15803,31 +17919,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         }
     }
 
-    float * x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current parameters
-    float * xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous parameters
-    float * g  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current gradient
-    float * gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous gradient
-    float * d  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // search direction
+    if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) {
+        int iter = opt->iter;
+        ggml_opt_init(ctx, opt, params, nx);
+        opt->iter = iter;
+    }
+
+    float * x  = opt->lbfgs.x->data;  // current parameters
+    float * xp = opt->lbfgs.xp->data; // previous parameters
+    float * g  = opt->lbfgs.g->data;  // current gradient
+    float * gp = opt->lbfgs.gp->data; // previous gradient
+    float * d  = opt->lbfgs.d->data;  // search direction
 
-    float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
+    float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
 
     float fx    = 0.0f; // cost function value
     float xnorm = 0.0f; // ||x||
     float gnorm = 0.0f; // ||g||
-    float step  = 0.0f;
 
     // initialize x from the graph nodes
     ggml_opt_get_params(np, ps, x);
 
     // the L-BFGS memory
-    struct ggml_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_lbfgs_iteration_data)*m);
-
-    for (int i = 0; i < m; ++i) {
-        lm[i].alpha = 0.0f;
-        lm[i].ys    = 0.0f;
-        lm[i].s     = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
-        lm[i].y     = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
-    }
+    float * lm_alpha = opt->lbfgs.lmal->data;
+    float * lm_ys    = opt->lbfgs.lmys->data;
+    float * lm_s     = opt->lbfgs.lms->data;
+    float * lm_y     = opt->lbfgs.lmy->data;
 
     // evaluate the function value and its gradient
     {
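With the L-BFGS iteration memory now flattened into [nx, m] tensors, entry j of the s and y histories is a contiguous row at a flat float offset, which is what all the `&lm_s[j[0]*nx]` indexing in the following hunks spells out:

    // Row j of an [nx, m] history tensor starts j*nx floats into the data;
    // a hypothetical accessor equivalent to the inline &lm_s[j*nx] expressions:
    static inline float * lbfgs_row(float * base, int j, int64_t nx) {
        return base + (int64_t) j * nx;
    }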
@@ -15842,12 +17959,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         fx = ggml_get_f32_1d(f, 0);
     }
 
-    if (pf) {
-        pf[0] = fx;
-    }
-
-    float fx_best = fx;
-
     // search direction = -gradient
     ggml_vec_neg_f32(nx, d, g);
 
@@ -15864,26 +17975,43 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         return GGML_OPT_OK;
     }
 
-    // initial step
-    ggml_vec_norm_inv_f32(nx, &step, d);
+    if (opt->just_initialized) {
+        if (pf) {
+            pf[0] = fx;
+        }
+        opt->lbfgs.fx_best = fx;
+
+        // initial step
+        ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d);
+        opt->lbfgs.j                = 0;
+        opt->lbfgs.k                = 1;
+        opt->lbfgs.end              = 0;
+        opt->lbfgs.n_no_improvement = 0;
+        opt->just_initialized = false;
+    }
+
+    float * fx_best        = &opt->lbfgs.fx_best;
+    float * step           = &opt->lbfgs.step;
+    int * j                = &opt->lbfgs.j;
+    int * k                = &opt->lbfgs.k;
+    int * end              = &opt->lbfgs.end;
+    int * n_no_improvement = &opt->lbfgs.n_no_improvement;
 
-    int j                = 0;
-    int k                = 1;
-    int ls               = 0;
-    int end              = 0;
-    int bound            = 0;
-    int n_no_improvement = 0;
+    int ls    = 0;
+    int bound = 0;
 
     float ys   = 0.0f;
     float yy   = 0.0f;
     float beta = 0.0f;
 
+    int it = 0;
+
     while (true) {
         // store the current position and gradient vectors
         ggml_vec_cpy_f32(nx, xp, x);
         ggml_vec_cpy_f32(nx, gp, g);
 
-        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps);
+        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
 
         if (ls < 0) {
             // linesearch failed - go back to the previous point and return
@@ -15909,32 +18037,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         // delta-based convergence test
         if (pf != NULL) {
             // need at least params.past iterations to start checking for convergence
-            if (params.past <= k) {
-                const float rate = (pf[k%params.past] - fx)/fx;
+            if (params.past <= k[0]) {
+                const float rate = (pf[k[0]%params.past] - fx)/fx;
 
                 if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }
 
-            pf[k%params.past] = fx;
+            pf[k[0]%params.past] = fx;
         }
 
         // check for improvement
         if (params.max_no_improvement > 0) {
-            if (fx < fx_best) {
-                fx_best = fx;
-                n_no_improvement = 0;
+            if (fx < fx_best[0]) {
+                fx_best[0] = fx;
+                n_no_improvement[0] = 0;
             } else {
-                n_no_improvement++;
+                n_no_improvement[0]++;
 
-                if (n_no_improvement >= params.max_no_improvement) {
+                if (n_no_improvement[0] >= params.max_no_improvement) {
                     return GGML_OPT_OK;
                 }
             }
         }
 
-        if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) {
+        if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) {
             // reached the maximum number of iterations
             return GGML_OPT_DID_NOT_CONVERGE;
         }
@@ -15943,50 +18071,51 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
         // y_{k+1} = g_{k+1} - g_{k}.
         //
-        ggml_vec_sub_f32(nx, lm[end].s, x, xp);
-        ggml_vec_sub_f32(nx, lm[end].y, g, gp);
+        ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp);
+        ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp);
 
         // compute scalars ys and yy:
         // ys = y^t \cdot s    -> 1 / \rho.
         // yy = y^t \cdot y.
         //
-        ggml_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s);
-        ggml_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y);
+        ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0] *nx]);
+        ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
 
-        lm[end].ys = ys;
+        lm_ys[end[0]] = ys;
 
         // find new search direction
         // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS
 
-        bound = (m <= k) ? m : k;
-        k++;
-        end = (end + 1)%m;
+        bound = (m <= k[0]) ? m : k[0];
+        k[0]++;
+        it++;
+        end[0] = (end[0] + 1)%m;
 
         // initialize search direction with -g
         ggml_vec_neg_f32(nx, d, g);
 
-        j = end;
+        j[0] = end[0];
         for (int i = 0; i < bound; ++i) {
-            j = (j + m - 1) % m;
+            j[0] = (j[0] + m - 1) % m;
             // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
-            ggml_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d);
-            lm[j].alpha /= lm[j].ys;
+            ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d);
+            lm_alpha[j[0]] /= lm_ys[j[0]];
             // q_{i} = q_{i+1} - \alpha_{i} y_{i}
-            ggml_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha);
+            ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
         }
 
         ggml_vec_scale_f32(nx, d, ys/yy);
 
         for (int i = 0; i < bound; ++i) {
             // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
-            ggml_vec_dot_f32(nx, &beta, lm[j].y, d);
-            beta /= lm[j].ys;
+            ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d);
+            beta /= lm_ys[j[0]];
             // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
-            ggml_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta);
-            j = (j + 1)%m;
+            ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
+            j[0] = (j[0] + 1)%m;
         }
 
-        step = 1.0;
+        step[0] = 1.0;
     }
 
     return GGML_OPT_DID_NOT_CONVERGE;
@@ -16011,6 +18140,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
 
             .adam = {
                 .n_iter = 10000,
+                .sched  = 1.000f,
+                .decay  = 0.001f,
                 .alpha  = 0.001f,
                 .beta1  = 0.9f,
                 .beta2  = 0.999f,
@@ -16053,6 +18184,70 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
     return result;
 }
 
+GGML_API void ggml_opt_init(
+        struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
+        struct ggml_opt_params params,
+        int64_t nx) {
+    opt->ctx = ctx;
+    opt->params = params;
+    opt->iter = 0;
+    opt->nx = nx;
+    opt->just_initialized = true;
+    switch (opt->params.type) {
+        case GGML_OPT_ADAM:
+            {
+                opt->adam.x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.m  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.v  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.pf = params.past > 0
+                    ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+                    : NULL;
+                ggml_set_zero(opt->adam.x);
+                ggml_set_zero(opt->adam.g1);
+                ggml_set_zero(opt->adam.g2);
+                ggml_set_zero(opt->adam.m);
+                ggml_set_zero(opt->adam.v);
+                ggml_set_zero(opt->adam.mh);
+                ggml_set_zero(opt->adam.vh);
+                if (opt->adam.pf) {
+                    ggml_set_zero(opt->adam.pf);
+                }
+            } break;
+        case GGML_OPT_LBFGS:
+            {
+                opt->lbfgs.x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.g  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.d  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.pf = params.past > 0
+                    ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+                    : NULL;
+                opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
+                opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
+                opt->lbfgs.lms  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+                opt->lbfgs.lmy  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+                ggml_set_zero(opt->lbfgs.x);
+                ggml_set_zero(opt->lbfgs.xp);
+                ggml_set_zero(opt->lbfgs.g);
+                ggml_set_zero(opt->lbfgs.gp);
+                ggml_set_zero(opt->lbfgs.d);
+                if (opt->lbfgs.pf) {
+                    ggml_set_zero(opt->lbfgs.pf);
+                }
+                ggml_set_zero(opt->lbfgs.lmal);
+                ggml_set_zero(opt->lbfgs.lmys);
+                ggml_set_zero(opt->lbfgs.lms);
+                ggml_set_zero(opt->lbfgs.lmy);
+            } break;
+    }
+}
+
 enum ggml_opt_result ggml_opt(
         struct ggml_context * ctx,
         struct ggml_opt_params params,
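A minimal usage sketch of the resumable-optimizer API introduced above; the wrapper function and its parameters are hypothetical, only the ggml_opt_* calls come from this diff:

    // Run Adam in resumable rounds on a prebuilt scalar loss tensor.
    // ctx/loss/nx/n_rounds are the caller's; opt keeps the moments and the
    // iteration counter alive across ggml_opt_resume calls.
    enum ggml_opt_result train_rounds(struct ggml_context * ctx,
                                      struct ggml_tensor  * loss,
                                      int64_t nx, int n_rounds) {
        struct ggml_opt_context opt;
        struct ggml_opt_params  params = ggml_opt_default_params(GGML_OPT_ADAM);

        ggml_opt_init(ctx, &opt, params, nx); // allocate optimizer state once

        enum ggml_opt_result res = GGML_OPT_DID_NOT_CONVERGE;
        for (int i = 0; i < n_rounds; ++i) {
            res = ggml_opt_resume(ctx, &opt, loss); // continues from saved state
            if (res != GGML_OPT_DID_NOT_CONVERGE) {
                break; // converged (GGML_OPT_OK) or failed
            }
        }
        return res;
    }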
@@ -16075,33 +18270,65 @@ enum ggml_opt_result ggml_opt(
 
     enum ggml_opt_result result = GGML_OPT_OK;
 
+    struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
+
+    ggml_opt_init(ctx, opt, params, 0);
+    result = ggml_opt_resume(ctx, opt, f);
+
+    if (free_ctx) {
+        ggml_free(ctx);
+    }
+
+    return result;
+}
+
+enum ggml_opt_result ggml_opt_resume(
+        struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
+        struct ggml_tensor * f) {
+
+    // build forward + backward compute graphs
+    struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
+    struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
+
+    struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
+    struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
+
+    *gf = ggml_build_forward (f);
+    *gb = ggml_build_backward(ctx, gf, true);
+
+    return ggml_opt_resume_g(ctx, opt, f, gf, gb);
+}
+
+enum ggml_opt_result ggml_opt_resume_g(
+        struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
+        struct ggml_tensor * f,
+        struct ggml_cgraph * gf,
+        struct ggml_cgraph * gb) {
+
     // build forward + backward compute graphs
-    struct ggml_cgraph gf = ggml_build_forward (f);
-    struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, true);
+    enum ggml_opt_result result = GGML_OPT_OK;
 
-    switch (params.type) {
+    switch (opt->params.type) {
         case GGML_OPT_ADAM:
             {
-                result = ggml_opt_adam(ctx, params, f, &gf, &gb);
+                result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
             } break;
         case GGML_OPT_LBFGS:
             {
-                result = ggml_opt_lbfgs(ctx, params, f, &gf, &gb);
+                result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb);
             } break;
     }
 
-    if (params.print_forward_graph) {
-        ggml_graph_print   (&gf);
-        ggml_graph_dump_dot(&gf, NULL, "opt-forward.dot");
-    }
-
-    if (params.print_backward_graph) {
-        ggml_graph_print   (&gb);
-        ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot");
+    if (opt->params.print_forward_graph) {
+        ggml_graph_print   (gf);
+        ggml_graph_dump_dot(gf, NULL, "opt-forward.dot");
     }
 
-    if (free_ctx) {
-        ggml_free(ctx);
+    if (opt->params.print_backward_graph) {
+        ggml_graph_print   (gb);
+        ggml_graph_dump_dot(gb, gf, "opt-backward.dot");
     }
 
     return result;
@@ -16301,6 +18528,18 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 result = ggml_quantize_q6_K(src + start, block, n, n, hist);
             } break;
 #endif
+        case GGML_TYPE_F16:
+            {
+                int elemsize = sizeof(ggml_fp16_t);
+                ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
+                result = n * elemsize;
+            } break;
+        case GGML_TYPE_F32:
+            {
+                int elemsize = sizeof(float);
+                result = n * elemsize;
+                memcpy((uint8_t *)dst + start * elemsize, src + start, result);
+            } break;
         default:
             assert(false);
     }
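The two passthrough branches added to ggml_quantize_chunk above are conversions rather than quantizations, and like the quantized branches they return the byte size of the written chunk. A hypothetical sanity check, assuming the start/n/hist parameters visible in the hunk header:

    #include <assert.h>

    static void check_passthrough_sizes(void) {
        float   src[1024] = {0};
        char    dst[1024 * sizeof(float)];
        int64_t hist[16]  = {0};
        // F16: n elements converted -> n*sizeof(ggml_fp16_t) bytes written
        assert(ggml_quantize_chunk(GGML_TYPE_F16, src, dst, 0, 1024, hist)
               == 1024 * sizeof(ggml_fp16_t));
        // F32: straight copy -> n*sizeof(float) bytes
        assert(ggml_quantize_chunk(GGML_TYPE_F32, src, dst, 0, 1024, hist)
               == 1024 * sizeof(float));
    }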