llama_cpp 0.2.0 → 0.2.2

@@ -35,6 +35,12 @@
35
35
  #define static_assert(cond, msg) struct global_scope_noop_trick
36
36
  #endif
37
37
 
38
+ #if defined(_MSC_VER)
39
+ // disable "possible loss of data" to avoid hundreds of casts
40
+ // we should just be careful :)
41
+ #pragma warning(disable: 4244 4267)
42
+ #endif
43
+
38
44
  #if defined(_WIN32)
39
45
 
40
46
  #include <windows.h>
@@ -106,6 +112,7 @@ typedef void* thread_ret_t;
106
112
  /*#define GGML_PERF*/
107
113
  #define GGML_DEBUG 0
108
114
  #define GGML_GELU_FP16
115
+ #define GGML_GELU_QUICK_FP16
109
116
  #define GGML_SILU_FP16
110
117
 
111
118
  #define GGML_SOFT_MAX_UNROLL 4
@@ -334,6 +341,9 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
334
341
  // precomputed gelu table for f16 (128 KB)
335
342
  static ggml_fp16_t table_gelu_f16[1 << 16];
336
343
 
344
+ // precomputed quick gelu table for f16 (128 KB)
345
+ static ggml_fp16_t table_gelu_quick_f16[1 << 16];
346
+
337
347
  // precomputed silu table for f16 (128 KB)
338
348
  static ggml_fp16_t table_silu_f16[1 << 16];
339
349
 
@@ -1671,14 +1681,17 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
1671
1681
  #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
1672
1682
  #define GGML_F32x4_REDUCE(res, x) \
1673
1683
  { \
1674
- for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
1675
- x[2*i] = vaddq_f32(x[2*i], x[2*i+1]); \
1684
+ int offset = GGML_F32_ARR >> 1; \
1685
+ for (int i = 0; i < offset; ++i) { \
1686
+ x[i] = vaddq_f32(x[i], x[offset+i]); \
1676
1687
  } \
1677
- for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
1678
- x[4*i] = vaddq_f32(x[4*i], x[4*i+2]); \
1688
+ offset >>= 1; \
1689
+ for (int i = 0; i < offset; ++i) { \
1690
+ x[i] = vaddq_f32(x[i], x[offset+i]); \
1679
1691
  } \
1680
- for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
1681
- x[8*i] = vaddq_f32(x[8*i], x[8*i+4]); \
1692
+ offset >>= 1; \
1693
+ for (int i = 0; i < offset; ++i) { \
1694
+ x[i] = vaddq_f32(x[i], x[offset+i]); \
1682
1695
  } \
1683
1696
  res = GGML_F32x4_REDUCE_ONE(x[0]); \
1684
1697
  }
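
The hunk above (and the matching F16/AVX/POWER/WASM/SSE hunks that follow) replaces the strided tree reduction with a halving-offset reduction: each pass adds the upper half of the accumulator array onto the lower half, which sums the same values with simpler, contiguous indexing. A minimal scalar sketch of the pattern, assuming a power-of-two accumulator count (illustrative plain C, not part of the patch; reduce_pairwise is a made-up name):

    #include <stdio.h>

    /* Scalar analogue of the new GGML_F32x4_REDUCE pattern: repeatedly fold the
     * upper half of the accumulator array onto the lower half, then read x[0].
     * In the real macros the += below is vaddq_f32 / _mm256_add_ps / vec_add /
     * wasm_f32x4_add / _mm_add_ps on SIMD registers. */
    static float reduce_pairwise(float x[], int n) {  /* n must be a power of two */
        for (int offset = n >> 1; offset > 0; offset >>= 1) {
            for (int i = 0; i < offset; ++i) {
                x[i] += x[offset + i];                /* fold upper half into lower */
            }
        }
        return x[0];
    }

    int main(void) {
        float acc[8] = {1, 2, 3, 4, 5, 6, 7, 8};
        printf("%f\n", reduce_pairwise(acc, 8));      /* prints 36.000000 */
        return 0;
    }
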
@@ -1709,14 +1722,17 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
1709
1722
  #define GGML_F16x8_MUL vmulq_f16
1710
1723
  #define GGML_F16x8_REDUCE(res, x) \
1711
1724
  { \
1712
- for (int i = 0; i < GGML_F16_ARR/2; ++i) { \
1713
- x[2*i] = vaddq_f16(x[2*i], x[2*i+1]); \
1725
+ int offset = GGML_F16_ARR >> 1; \
1726
+ for (int i = 0; i < offset; ++i) { \
1727
+ x[i] = vaddq_f16(x[i], x[offset+i]); \
1714
1728
  } \
1715
- for (int i = 0; i < GGML_F16_ARR/4; ++i) { \
1716
- x[4*i] = vaddq_f16(x[4*i], x[4*i+2]); \
1729
+ offset >>= 1; \
1730
+ for (int i = 0; i < offset; ++i) { \
1731
+ x[i] = vaddq_f16(x[i], x[offset+i]); \
1717
1732
  } \
1718
- for (int i = 0; i < GGML_F16_ARR/8; ++i) { \
1719
- x[8*i] = vaddq_f16(x[8*i], x[8*i+4]); \
1733
+ offset >>= 1; \
1734
+ for (int i = 0; i < offset; ++i) { \
1735
+ x[i] = vaddq_f16(x[i], x[offset+i]); \
1720
1736
  } \
1721
1737
  const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
1722
1738
  const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
@@ -1783,14 +1799,17 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
1783
1799
  #define GGML_F32x8_MUL _mm256_mul_ps
1784
1800
  #define GGML_F32x8_REDUCE(res, x) \
1785
1801
  { \
1786
- for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
1787
- x[2*i] = _mm256_add_ps(x[2*i], x[2*i+1]); \
1802
+ int offset = GGML_F32_ARR >> 1; \
1803
+ for (int i = 0; i < offset; ++i) { \
1804
+ x[i] = _mm256_add_ps(x[i], x[offset+i]); \
1788
1805
  } \
1789
- for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
1790
- x[4*i] = _mm256_add_ps(x[4*i], x[4*i+2]); \
1806
+ offset >>= 1; \
1807
+ for (int i = 0; i < offset; ++i) { \
1808
+ x[i] = _mm256_add_ps(x[i], x[offset+i]); \
1791
1809
  } \
1792
- for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
1793
- x[8*i] = _mm256_add_ps(x[8*i], x[8*i+4]); \
1810
+ offset >>= 1; \
1811
+ for (int i = 0; i < offset; ++i) { \
1812
+ x[i] = _mm256_add_ps(x[i], x[offset+i]); \
1794
1813
  } \
1795
1814
  const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \
1796
1815
  _mm256_extractf128_ps(x[0], 1)); \
@@ -1880,14 +1899,17 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
1880
1899
  #define GGML_F32x4_MUL vec_mul
1881
1900
  #define GGML_F32x4_REDUCE(res, x) \
1882
1901
  { \
1883
- for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
1884
- x[2*i] = vec_add(x[2*i], x[2*i+1]); \
1902
+ int offset = GGML_F32_ARR >> 1; \
1903
+ for (int i = 0; i < offset; ++i) { \
1904
+ x[i] = vec_add(x[i], x[offset+i]); \
1885
1905
  } \
1886
- for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
1887
- x[4*i] = vec_add(x[4*i], x[4*i+2]); \
1906
+ offset >>= 1; \
1907
+ for (int i = 0; i < offset; ++i) { \
1908
+ x[i] = vec_add(x[i], x[offset+i]); \
1888
1909
  } \
1889
- for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
1890
- x[8*i] = vec_add(x[8*i], x[8*i+4]); \
1910
+ offset >>= 1; \
1911
+ for (int i = 0; i < offset; ++i) { \
1912
+ x[i] = vec_add(x[i], x[offset+i]); \
1891
1913
  } \
1892
1914
  res = vec_extract(x[0], 0) + \
1893
1915
  vec_extract(x[0], 1) + \
@@ -1943,14 +1965,17 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
1943
1965
  #define GGML_F32x4_MUL wasm_f32x4_mul
1944
1966
  #define GGML_F32x4_REDUCE(res, x) \
1945
1967
  { \
1946
- for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
1947
- x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \
1968
+ int offset = GGML_F32_ARR >> 1; \
1969
+ for (int i = 0; i < offset; ++i) { \
1970
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
1948
1971
  } \
1949
- for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
1950
- x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \
1972
+ offset >>= 1; \
1973
+ for (int i = 0; i < offset; ++i) { \
1974
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
1951
1975
  } \
1952
- for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
1953
- x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \
1976
+ offset >>= 1; \
1977
+ for (int i = 0; i < offset; ++i) { \
1978
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
1954
1979
  } \
1955
1980
  res = wasm_f32x4_extract_lane(x[0], 0) + \
1956
1981
  wasm_f32x4_extract_lane(x[0], 1) + \
@@ -2005,14 +2030,17 @@ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
2005
2030
  #define GGML_F16x4_MUL wasm_f32x4_mul
2006
2031
  #define GGML_F16x4_REDUCE(res, x) \
2007
2032
  { \
2008
- for (int i = 0; i < GGML_F16_ARR/2; ++i) { \
2009
- x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \
2033
+ int offset = GGML_F16_ARR >> 1; \
2034
+ for (int i = 0; i < offset; ++i) { \
2035
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
2010
2036
  } \
2011
- for (int i = 0; i < GGML_F16_ARR/4; ++i) { \
2012
- x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \
2037
+ offset >>= 1; \
2038
+ for (int i = 0; i < offset; ++i) { \
2039
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
2013
2040
  } \
2014
- for (int i = 0; i < GGML_F16_ARR/8; ++i) { \
2015
- x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \
2041
+ offset >>= 1; \
2042
+ for (int i = 0; i < offset; ++i) { \
2043
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
2016
2044
  } \
2017
2045
  res = wasm_f32x4_extract_lane(x[0], 0) + \
2018
2046
  wasm_f32x4_extract_lane(x[0], 1) + \
@@ -2054,14 +2082,17 @@ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
2054
2082
  #define GGML_F32x4_MUL _mm_mul_ps
2055
2083
  #define GGML_F32x4_REDUCE(res, x) \
2056
2084
  { \
2057
- for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
2058
- x[2*i] = _mm_add_ps(x[2*i], x[2*i+1]); \
2085
+ int offset = GGML_F32_ARR >> 1; \
2086
+ for (int i = 0; i < offset; ++i) { \
2087
+ x[i] = _mm_add_ps(x[i], x[offset+i]); \
2059
2088
  } \
2060
- for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
2061
- x[4*i] = _mm_add_ps(x[4*i], x[4*i+2]); \
2089
+ offset >>= 1; \
2090
+ for (int i = 0; i < offset; ++i) { \
2091
+ x[i] = _mm_add_ps(x[i], x[offset+i]); \
2062
2092
  } \
2063
- for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
2064
- x[8*i] = _mm_add_ps(x[8*i], x[8*i+4]); \
2093
+ offset >>= 1; \
2094
+ for (int i = 0; i < offset; ++i) { \
2095
+ x[i] = _mm_add_ps(x[i], x[offset+i]); \
2065
2096
  } \
2066
2097
  const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \
2067
2098
  res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \
@@ -3350,6 +3381,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
3350
3381
  inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
3351
3382
 
3352
3383
  static const float GELU_COEF_A = 0.044715f;
3384
+ static const float GELU_QUICK_COEF = -1.702f;
3353
3385
  static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
3354
3386
 
3355
3387
  inline static float ggml_gelu_f32(float x) {
@@ -3380,6 +3412,34 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
3380
3412
  }
3381
3413
  #endif
3382
3414
 
3415
+ inline static float ggml_gelu_quick_f32(float x) {
3416
+ return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
3417
+ }
3418
+
3419
+ //inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
3420
+ // const uint16_t * i16 = (const uint16_t *) x;
3421
+ // for (int i = 0; i < n; ++i) {
3422
+ // y[i] = table_gelu_quick_f16[i16[i]];
3423
+ // }
3424
+ //}
3425
+
3426
+ #ifdef GGML_GELU_QUICK_FP16
3427
+ inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
3428
+ uint16_t t;
3429
+ for (int i = 0; i < n; ++i) {
3430
+ ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
3431
+ memcpy(&t, &fp16, sizeof(uint16_t));
3432
+ y[i] = GGML_FP16_TO_FP32(table_gelu_quick_f16[t]);
3433
+ }
3434
+ }
3435
+ #else
3436
+ inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
3437
+ for (int i = 0; i < n; ++i) {
3438
+ y[i] = ggml_gelu_quick_f32(x[i]);
3439
+ }
3440
+ }
3441
+ #endif
3442
+
3383
3443
  // Sigmoid Linear Unit (SiLU) function
3384
3444
  inline static float ggml_silu_f32(float x) {
3385
3445
  return x/(1.0f + expf(-x));
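
The added "quick" GELU is the common sigmoid approximation of GELU; with GELU_QUICK_COEF = -1.702, the scalar function above evaluates

    \mathrm{gelu\_quick}(x) \;=\; x \cdot \sigma(1.702\,x) \;=\; \frac{x}{1 + e^{-1.702\,x}}

and, mirroring the existing GELU/SILU paths, GGML_GELU_QUICK_FP16 routes the vectorized ggml_vec_gelu_quick_f32 through the new 2^16-entry lookup table table_gelu_quick_f16 instead of calling expf per element.
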
@@ -3603,12 +3663,14 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3603
3663
  "SUM_ROWS",
3604
3664
  "MEAN",
3605
3665
  "REPEAT",
3666
+ "REPEAT_BACK",
3606
3667
  "ABS",
3607
3668
  "SGN",
3608
3669
  "NEG",
3609
3670
  "STEP",
3610
3671
  "RELU",
3611
3672
  "GELU",
3673
+ "GELU_QUICK",
3612
3674
  "SILU",
3613
3675
  "SILU_BACK",
3614
3676
  "NORM",
@@ -3616,6 +3678,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3616
3678
  "RMS_NORM_BACK",
3617
3679
 
3618
3680
  "MUL_MAT",
3681
+ "OUT_PROD",
3619
3682
 
3620
3683
  "SCALE",
3621
3684
  "SET",
@@ -3631,22 +3694,29 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3631
3694
  "DIAG_MASK_INF",
3632
3695
  "DIAG_MASK_ZERO",
3633
3696
  "SOFT_MAX",
3697
+ "SOFT_MAX_BACK",
3634
3698
  "ROPE",
3635
3699
  "ROPE_BACK",
3636
3700
  "ALIBI",
3637
3701
  "CLAMP",
3638
- "CONV_1D_1S",
3639
- "CONV_1D_2S",
3702
+ "CONV_1D_S1_PH",
3703
+ "CONV_1D_S2_PH",
3704
+ "CONV_2D_SK_P0",
3640
3705
 
3641
3706
  "FLASH_ATTN",
3642
3707
  "FLASH_FF",
3708
+ "FLASH_ATTN_BACK",
3709
+ "WIN_PART",
3710
+ "WIN_UNPART",
3643
3711
 
3644
3712
  "MAP_UNARY",
3645
3713
  "MAP_BINARY",
3646
- };
3647
3714
 
3648
- static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
3715
+ "CROSS_ENTROPY_LOSS",
3716
+ "CROSS_ENTROPY_LOSS_BACK",
3717
+ };
3649
3718
 
3719
+ static_assert(GGML_OP_COUNT == 61, "GGML_OP_COUNT != 61");
3650
3720
 
3651
3721
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3652
3722
  "none",
@@ -3665,18 +3735,21 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3665
3735
  "Σx_k",
3666
3736
  "Σx/n",
3667
3737
  "repeat(x)",
3738
+ "repeat_back(x)",
3668
3739
  "abs(x)",
3669
3740
  "sgn(x)",
3670
3741
  "-x",
3671
3742
  "step(x)",
3672
3743
  "relu(x)",
3673
3744
  "gelu(x)",
3745
+ "gelu_quick(x)",
3674
3746
  "silu(x)",
3675
3747
  "silu_back(x)",
3676
3748
  "norm(x)",
3677
3749
  "rms_norm(x)",
3678
3750
  "rms_norm_back(x)",
3679
3751
 
3752
+ "X*Y",
3680
3753
  "X*Y",
3681
3754
 
3682
3755
  "x*v",
@@ -3693,21 +3766,29 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3693
3766
  "diag_mask_inf(x)",
3694
3767
  "diag_mask_zero(x)",
3695
3768
  "soft_max(x)",
3769
+ "soft_max_back(x)",
3696
3770
  "rope(x)",
3697
3771
  "rope_back(x)",
3698
3772
  "alibi(x)",
3699
3773
  "clamp(x)",
3700
- "conv_1d_1s(x)",
3701
- "conv_1d_2s(x)",
3774
+ "conv_1d_s1_ph(x)",
3775
+ "conv_1d_s2_ph(x)",
3776
+ "conv_2d_sk_p0(x)",
3702
3777
 
3703
3778
  "flash_attn(x)",
3704
3779
  "flash_ff(x)",
3780
+ "flash_attn_back(x)",
3781
+ "win_part(x)",
3782
+ "win_unpart(x)",
3705
3783
 
3706
3784
  "f(x)",
3707
3785
  "f(x,y)",
3786
+
3787
+ "cross_entropy_loss(x,y)",
3788
+ "cross_entropy_loss_back(x,y)",
3708
3789
  };
3709
3790
 
3710
- static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
3791
+ static_assert(GGML_OP_COUNT == 61, "GGML_OP_COUNT != 61");
3711
3792
 
3712
3793
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
3713
3794
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3870,6 +3951,15 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
3870
3951
  (t0->ne[3] == t1->ne[3]);
3871
3952
  }
3872
3953
 
3954
+ static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
3955
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3956
+
3957
+ return
3958
+ (t0->ne[1] == t1->ne[1]) &&
3959
+ (t0->ne[2] == t1->ne[2]) &&
3960
+ (t0->ne[3] == t1->ne[3]);
3961
+ }
3962
+
3873
3963
  bool ggml_is_quantized(enum ggml_type type) {
3874
3964
  return GGML_IS_QUANTIZED[type];
3875
3965
  }
@@ -3917,6 +4007,12 @@ bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
3917
4007
  tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
3918
4008
  }
3919
4009
 
4010
+ bool ggml_is_permuted(const struct ggml_tensor * tensor) {
4011
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
4012
+
4013
+ return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
4014
+ }
4015
+
3920
4016
  static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
3921
4017
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3922
4018
 
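
The new ggml_is_permuted helper reports whether a tensor's strides are out of their natural non-decreasing order, i.e. whether it is a permuted/transposed view rather than a plainly laid-out tensor. A tiny illustrative fragment (assumes an existing ctx; the shapes are arbitrary):

    /* a fresh 2-D tensor has non-decreasing strides, a permuted view does not */
    struct ggml_tensor * t  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4);
    struct ggml_tensor * tt = ggml_permute(ctx, t, 1, 0, 2, 3); /* swap first two axes */

    /* ggml_is_permuted(t)  -> false  (nb[0] <= nb[1] <= nb[2] <= nb[3]) */
    /* ggml_is_permuted(tt) -> true   (nb[0] > nb[1] after the swap)     */
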
@@ -3983,7 +4079,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3983
4079
  // initialize time system (required on Windows)
3984
4080
  ggml_time_init();
3985
4081
 
3986
- // initialize GELU, SILU and EXP F32 tables
4082
+ // initialize GELU, Quick GELU, SILU and EXP F32 tables
3987
4083
  {
3988
4084
  const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
3989
4085
 
@@ -3993,13 +4089,14 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3993
4089
  memcpy(&ii, &ui, sizeof(ii));
3994
4090
  const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
3995
4091
  table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
4092
+ table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
3996
4093
  table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
3997
4094
  table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
3998
4095
  }
3999
4096
 
4000
4097
  const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
4001
4098
 
4002
- GGML_PRINT_DEBUG("%s: GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
4099
+ GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
4003
4100
  }
4004
4101
 
4005
4102
  // initialize g_state
@@ -4120,14 +4217,34 @@ void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) {
4120
4217
  ctx->no_alloc = no_alloc;
4121
4218
  }
4122
4219
 
4123
- void * ggml_get_mem_buffer(struct ggml_context * ctx) {
4220
+ void * ggml_get_mem_buffer(const struct ggml_context * ctx) {
4124
4221
  return ctx->mem_buffer;
4125
4222
  }
4126
4223
 
4127
- size_t ggml_get_mem_size(struct ggml_context * ctx) {
4224
+ size_t ggml_get_mem_size(const struct ggml_context * ctx) {
4128
4225
  return ctx->mem_size;
4129
4226
  }
4130
4227
 
4228
+ size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) {
4229
+ size_t max_size = 0;
4230
+
4231
+ struct ggml_object * obj = ctx->objects_begin;
4232
+
4233
+ while (obj != NULL) {
4234
+ struct ggml_tensor * tensor = (struct ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs);
4235
+
4236
+ const size_t size = ggml_nbytes(tensor);
4237
+
4238
+ if (max_size < size) {
4239
+ max_size = size;
4240
+ }
4241
+
4242
+ obj = obj->next;
4243
+ }
4244
+
4245
+ return max_size;
4246
+ }
4247
+
4131
4248
  // IMPORTANT:
4132
4249
  // when creating "opt" tensors, always save and load the scratch buffer
4133
4250
  // this is an error prone process, but it is necessary to support inplace
@@ -4611,9 +4728,10 @@ const char * ggml_get_name(const struct ggml_tensor * tensor) {
4611
4728
  return tensor->name;
4612
4729
  }
4613
4730
 
4614
- void ggml_set_name(struct ggml_tensor * tensor, const char * name) {
4731
+ struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name) {
4615
4732
  strncpy(tensor->name, name, sizeof(tensor->name));
4616
4733
  tensor->name[sizeof(tensor->name) - 1] = '\0';
4734
+ return tensor;
4617
4735
  }
4618
4736
 
4619
4737
  struct ggml_tensor * ggml_view_tensor(
@@ -4693,7 +4811,7 @@ struct ggml_tensor * ggml_add_impl(
4693
4811
 
4694
4812
  bool is_node = false;
4695
4813
 
4696
- if (!inplace && (a->grad || b->grad)) {
4814
+ if (a->grad || b->grad) {
4697
4815
  is_node = true;
4698
4816
  }
4699
4817
 
@@ -4733,7 +4851,7 @@ struct ggml_tensor * ggml_add1_impl(
4733
4851
 
4734
4852
  bool is_node = false;
4735
4853
 
4736
- if (!inplace && (a->grad || b->grad)) {
4854
+ if (a->grad || b->grad) {
4737
4855
  is_node = true;
4738
4856
  }
4739
4857
 
@@ -5159,6 +5277,34 @@ struct ggml_tensor * ggml_repeat(
5159
5277
  return result;
5160
5278
  }
5161
5279
 
5280
+ // ggml_repeat_back
5281
+
5282
+ struct ggml_tensor * ggml_repeat_back(
5283
+ struct ggml_context * ctx,
5284
+ struct ggml_tensor * a,
5285
+ struct ggml_tensor * b) {
5286
+ GGML_ASSERT(ggml_can_repeat(b, a));
5287
+
5288
+ bool is_node = false;
5289
+
5290
+ if (a->grad) {
5291
+ is_node = true;
5292
+ }
5293
+
5294
+ if (ggml_are_same_shape(a, b) && !is_node) {
5295
+ return a;
5296
+ }
5297
+
5298
+ struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
5299
+
5300
+ result->op = GGML_OP_REPEAT_BACK;
5301
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5302
+ result->src0 = a;
5303
+ result->src1 = b;
5304
+
5305
+ return result;
5306
+ }
5307
+
5162
5308
  // ggml_abs
5163
5309
 
5164
5310
  struct ggml_tensor * ggml_abs_impl(
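
The ggml_repeat_back added above is the reduction adjoint of ggml_repeat: where repeat tiles a up to b's shape, repeat_back sums all tiled copies back down to b's (smaller) shape, which is what a gradient of repeat needs. In index form (eliding strides; nr0..nr3 are the per-dimension repeat counts, as in ggml_compute_forward_repeat_back_f32 further down):

    \mathrm{dst}[k_0,k_1,k_2,k_3] \;=\; \sum_{i_0=0}^{nr_0-1}\cdots\sum_{i_3=0}^{nr_3-1}
        \mathrm{src}[k_0+i_0\,ne_0,\; k_1+i_1\,ne_1,\; k_2+i_2\,ne_2,\; k_3+i_3\,ne_3]
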
@@ -5364,6 +5510,40 @@ struct ggml_tensor * ggml_gelu_inplace(
5364
5510
  return ggml_gelu_impl(ctx, a, true);
5365
5511
  }
5366
5512
 
5513
+ // ggml_gelu_quick
5514
+
5515
+ struct ggml_tensor * ggml_gelu_quick_impl(
5516
+ struct ggml_context * ctx,
5517
+ struct ggml_tensor * a,
5518
+ bool inplace) {
5519
+ bool is_node = false;
5520
+
5521
+ if (!inplace && (a->grad)) {
5522
+ is_node = true;
5523
+ }
5524
+
5525
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5526
+
5527
+ result->op = GGML_OP_GELU_QUICK;
5528
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5529
+ result->src0 = a;
5530
+ result->src1 = NULL;
5531
+
5532
+ return result;
5533
+ }
5534
+
5535
+ struct ggml_tensor * ggml_gelu_quick(
5536
+ struct ggml_context * ctx,
5537
+ struct ggml_tensor * a) {
5538
+ return ggml_gelu_quick_impl(ctx, a, false);
5539
+ }
5540
+
5541
+ struct ggml_tensor * ggml_gelu_quick_inplace(
5542
+ struct ggml_context * ctx,
5543
+ struct ggml_tensor * a) {
5544
+ return ggml_gelu_quick_impl(ctx, a, true);
5545
+ }
5546
+
5367
5547
  // ggml_silu
5368
5548
 
5369
5549
  struct ggml_tensor * ggml_silu_impl(
@@ -5536,6 +5716,32 @@ struct ggml_tensor * ggml_mul_mat(
5536
5716
  return result;
5537
5717
  }
5538
5718
 
5719
+ // ggml_out_prod
5720
+
5721
+ struct ggml_tensor * ggml_out_prod(
5722
+ struct ggml_context * ctx,
5723
+ struct ggml_tensor * a,
5724
+ struct ggml_tensor * b) {
5725
+ GGML_ASSERT(ggml_can_out_prod(a, b));
5726
+ GGML_ASSERT(!ggml_is_transposed(a));
5727
+
5728
+ bool is_node = false;
5729
+
5730
+ if (a->grad || b->grad) {
5731
+ is_node = true;
5732
+ }
5733
+
5734
+ const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
5735
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
5736
+
5737
+ result->op = GGML_OP_OUT_PROD;
5738
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5739
+ result->src0 = a;
5740
+ result->src1 = b;
5741
+
5742
+ return result;
5743
+ }
5744
+
5539
5745
  // ggml_scale
5540
5746
 
5541
5747
  struct ggml_tensor * ggml_scale_impl(
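
The new ggml_out_prod builds a result of shape { a->ne[0], b->ne[0], a->ne[2], b->ne[3] } by accumulating an outer product over the shared second dimension; in index form (matching the dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] loop in ggml_compute_forward_out_prod_f32 further down):

    \mathrm{dst}_{i_0,i_1} \;=\; \sum_{k} a_{i_0,k}\; b_{i_1,k}

per batch slice, i.e. the sum over rows k of the outer product of row k of a with row k of b, which is why ggml_can_out_prod only checks that ne[1] (the summed dimension) and the trailing batch dimensions agree.
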
@@ -5548,7 +5754,7 @@ struct ggml_tensor * ggml_scale_impl(
5548
5754
 
5549
5755
  bool is_node = false;
5550
5756
 
5551
- if (!inplace && (a->grad || b->grad)) {
5757
+ if (a->grad || b->grad) {
5552
5758
  is_node = true;
5553
5759
  }
5554
5760
 
@@ -5591,7 +5797,7 @@ struct ggml_tensor * ggml_set_impl(
5591
5797
 
5592
5798
  bool is_node = false;
5593
5799
 
5594
- if (!inplace && (a->grad || b->grad)) {
5800
+ if (a->grad || b->grad) {
5595
5801
  is_node = true;
5596
5802
  }
5597
5803
 
@@ -5913,10 +6119,6 @@ struct ggml_tensor * ggml_view_1d(
5913
6119
  result->src1 = NULL;
5914
6120
  result->opt[0] = offs;
5915
6121
 
5916
- if (is_node) {
5917
- memcpy(result->padding, &offset, sizeof(offset));
5918
- }
5919
-
5920
6122
  return result;
5921
6123
  }
5922
6124
 
@@ -5957,10 +6159,6 @@ struct ggml_tensor * ggml_view_2d(
5957
6159
  result->src1 = NULL;
5958
6160
  result->opt[0] = offs;
5959
6161
 
5960
- if (is_node) {
5961
- memcpy(result->padding, &offset, sizeof(offset));
5962
- }
5963
-
5964
6162
  return result;
5965
6163
  }
5966
6164
 
@@ -6003,10 +6201,6 @@ struct ggml_tensor * ggml_view_3d(
6003
6201
  result->src1 = NULL;
6004
6202
  result->opt[0] = offs;
6005
6203
 
6006
- if (is_node) {
6007
- memcpy(result->padding, &offset, sizeof(offset));
6008
- }
6009
-
6010
6204
  return result;
6011
6205
  }
6012
6206
 
@@ -6051,10 +6245,6 @@ struct ggml_tensor * ggml_view_4d(
6051
6245
  result->src1 = NULL;
6052
6246
  result->opt[0] = offs;
6053
6247
 
6054
- if (is_node) {
6055
- memcpy(result->padding, &offset, sizeof(offset));
6056
- }
6057
-
6058
6248
  return result;
6059
6249
  }
6060
6250
 
@@ -6116,10 +6306,18 @@ struct ggml_tensor * ggml_permute(
6116
6306
  result->src1 = NULL;
6117
6307
 
6118
6308
  if (is_node) {
6119
- result->padding[0] = axis0;
6120
- result->padding[1] = axis1;
6121
- result->padding[2] = axis2;
6122
- result->padding[3] = axis3;
6309
+ ggml_scratch_save(ctx);
6310
+
6311
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
6312
+
6313
+ ((int32_t *) b->data)[0] = axis0;
6314
+ ((int32_t *) b->data)[1] = axis1;
6315
+ ((int32_t *) b->data)[2] = axis2;
6316
+ ((int32_t *) b->data)[3] = axis3;
6317
+
6318
+ ggml_scratch_load(ctx);
6319
+
6320
+ result->opt[0] = b;
6123
6321
  }
6124
6322
 
6125
6323
  return result;
@@ -6359,6 +6557,44 @@ struct ggml_tensor * ggml_soft_max_inplace(
6359
6557
  return ggml_soft_max_impl(ctx, a, true);
6360
6558
  }
6361
6559
 
6560
+
6561
+ // ggml_soft_max_back
6562
+
6563
+ struct ggml_tensor * ggml_soft_max_back_impl(
6564
+ struct ggml_context * ctx,
6565
+ struct ggml_tensor * a,
6566
+ struct ggml_tensor * b,
6567
+ bool inplace) {
6568
+ bool is_node = false;
6569
+
6570
+ if (a->grad || b->grad) {
6571
+ is_node = true; // TODO : implement backward pass
6572
+ }
6573
+
6574
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
6575
+
6576
+ result->op = GGML_OP_SOFT_MAX_BACK;
6577
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6578
+ result->src0 = a;
6579
+ result->src1 = b;
6580
+
6581
+ return result;
6582
+ }
6583
+
6584
+ struct ggml_tensor * ggml_soft_max_back(
6585
+ struct ggml_context * ctx,
6586
+ struct ggml_tensor * a,
6587
+ struct ggml_tensor * b) {
6588
+ return ggml_soft_max_back_impl(ctx, a, b, false);
6589
+ }
6590
+
6591
+ struct ggml_tensor * ggml_soft_max_back_inplace(
6592
+ struct ggml_context * ctx,
6593
+ struct ggml_tensor * a,
6594
+ struct ggml_tensor * b) {
6595
+ return ggml_soft_max_back_impl(ctx, a, b, true);
6596
+ }
6597
+
6362
6598
  // ggml_rope
6363
6599
 
6364
6600
  struct ggml_tensor * ggml_rope_impl(
@@ -6371,7 +6607,7 @@ struct ggml_tensor * ggml_rope_impl(
6371
6607
  GGML_ASSERT(n_past >= 0);
6372
6608
  bool is_node = false;
6373
6609
 
6374
- if (!inplace && a->grad) {
6610
+ if (a->grad) {
6375
6611
  is_node = true;
6376
6612
  }
6377
6613
 
@@ -6425,8 +6661,7 @@ struct ggml_tensor * ggml_rope_back(
6425
6661
  bool is_node = false;
6426
6662
 
6427
6663
  if (a->grad) {
6428
- GGML_ASSERT(false); // TODO: implement backward
6429
- is_node = true;
6664
+ is_node = false; // TODO: implement backward
6430
6665
  }
6431
6666
 
6432
6667
  struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
@@ -6508,7 +6743,7 @@ struct ggml_tensor * ggml_clamp(
6508
6743
 
6509
6744
  ggml_scratch_save(ctx);
6510
6745
 
6511
- struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
6746
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2);
6512
6747
 
6513
6748
  ((float *) b->data)[0] = min;
6514
6749
  ((float *) b->data)[1] = max;
@@ -6523,9 +6758,9 @@ struct ggml_tensor * ggml_clamp(
6523
6758
  return result;
6524
6759
  }
6525
6760
 
6526
- // ggml_conv_1d_1s
6761
+ // ggml_conv_1d_s1_ph
6527
6762
 
6528
- struct ggml_tensor * ggml_conv_1d_1s(
6763
+ struct ggml_tensor * ggml_conv_1d_s1_ph(
6529
6764
  struct ggml_context * ctx,
6530
6765
  struct ggml_tensor * a,
6531
6766
  struct ggml_tensor * b) {
@@ -6542,7 +6777,7 @@ struct ggml_tensor * ggml_conv_1d_1s(
6542
6777
  const int64_t ne[4] = { b->ne[0], a->ne[2], 1, 1, };
6543
6778
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
6544
6779
 
6545
- result->op = GGML_OP_CONV_1D_1S;
6780
+ result->op = GGML_OP_CONV_1D_S1_PH;
6546
6781
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6547
6782
  result->src0 = a;
6548
6783
  result->src1 = b;
@@ -6550,9 +6785,9 @@ struct ggml_tensor * ggml_conv_1d_1s(
6550
6785
  return result;
6551
6786
  }
6552
6787
 
6553
- // ggml_conv_1d_2s
6788
+ // ggml_conv_1d_s2_ph
6554
6789
 
6555
- struct ggml_tensor * ggml_conv_1d_2s(
6790
+ struct ggml_tensor * ggml_conv_1d_s2_ph(
6556
6791
  struct ggml_context * ctx,
6557
6792
  struct ggml_tensor * a,
6558
6793
  struct ggml_tensor * b) {
@@ -6569,7 +6804,35 @@ struct ggml_tensor * ggml_conv_1d_2s(
6569
6804
  const int64_t ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, };
6570
6805
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
6571
6806
 
6572
- result->op = GGML_OP_CONV_1D_2S;
6807
+ result->op = GGML_OP_CONV_1D_S2_PH;
6808
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6809
+ result->src0 = a;
6810
+ result->src1 = b;
6811
+
6812
+ return result;
6813
+ }
6814
+
6815
+ // ggml_conv_2d_sk_p0
6816
+
6817
+ struct ggml_tensor * ggml_conv_2d_sk_p0(
6818
+ struct ggml_context * ctx,
6819
+ struct ggml_tensor * a,
6820
+ struct ggml_tensor * b) {
6821
+ GGML_ASSERT(b->ne[3] == 1);
6822
+ GGML_ASSERT(a->ne[2] == b->ne[2]);
6823
+ GGML_ASSERT(b->ne[0] % a->ne[0] == 0);
6824
+ GGML_ASSERT(b->ne[1] % a->ne[1] == 0);
6825
+ bool is_node = false;
6826
+
6827
+ if (a->grad || b->grad) {
6828
+ GGML_ASSERT(false); // TODO: implement backward
6829
+ is_node = true;
6830
+ }
6831
+
6832
+ const int64_t ne[4] = { b->ne[0]/a->ne[0], b->ne[1]/a->ne[1], a->ne[3], 1, };
6833
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6834
+
6835
+ result->op = GGML_OP_CONV_2D_SK_P0;
6573
6836
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6574
6837
  result->src0 = a;
6575
6838
  result->src1 = b;
@@ -6591,7 +6854,6 @@ struct ggml_tensor * ggml_flash_attn(
6591
6854
  bool is_node = false;
6592
6855
 
6593
6856
  if (q->grad || k->grad || v->grad) {
6594
- GGML_ASSERT(false); // TODO: implement backward
6595
6857
  is_node = true;
6596
6858
  }
6597
6859
 
@@ -6623,7 +6885,6 @@ struct ggml_tensor * ggml_flash_ff(
6623
6885
  bool is_node = false;
6624
6886
 
6625
6887
  if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
6626
- GGML_ASSERT(false); // TODO: implement backward
6627
6888
  is_node = true;
6628
6889
  }
6629
6890
 
@@ -6641,54 +6902,202 @@ struct ggml_tensor * ggml_flash_ff(
6641
6902
  return result;
6642
6903
  }
6643
6904
 
6644
- // ggml_map_unary
6905
+ // ggml_flash_attn_back
6906
+
6907
+ struct ggml_tensor * ggml_flash_attn_back(
6908
+ struct ggml_context * ctx,
6909
+ struct ggml_tensor * q,
6910
+ struct ggml_tensor * k,
6911
+ struct ggml_tensor * v,
6912
+ struct ggml_tensor * d,
6913
+ bool masked) {
6914
+ GGML_ASSERT(ggml_can_mul_mat(k, q));
6915
+ // TODO: check if vT can be multiplied by (k*qT)
6916
+
6917
+ // d shape [D,N,ne2,ne3]
6918
+ // q shape [D,N,ne2,ne3]
6919
+ // k shape [D,M,ne2,ne3]
6920
+ // v shape [M,D,ne2,ne3]
6921
+
6922
+ const int64_t D = q->ne[0];
6923
+ const int64_t N = q->ne[1];
6924
+ const int64_t M = k->ne[1];
6925
+ const int64_t ne2 = q->ne[2];
6926
+ const int64_t ne3 = q->ne[3];
6927
+
6928
+ GGML_ASSERT(k->ne[0] == D);
6929
+ GGML_ASSERT(v->ne[0] == M);
6930
+ GGML_ASSERT(v->ne[1] == D);
6931
+ GGML_ASSERT(d->ne[0] == D);
6932
+ GGML_ASSERT(d->ne[1] == N);
6933
+ GGML_ASSERT(k->ne[2] == ne2);
6934
+ GGML_ASSERT(k->ne[3] == ne3);
6935
+ GGML_ASSERT(v->ne[2] == ne2);
6936
+ GGML_ASSERT(v->ne[3] == ne3);
6937
+ GGML_ASSERT(d->ne[2] == ne2);
6938
+ GGML_ASSERT(d->ne[3] == ne3);
6645
6939
 
6646
- struct ggml_tensor * ggml_map_unary_impl_f32(
6647
- struct ggml_context * ctx,
6648
- struct ggml_tensor * a,
6649
- const ggml_unary_op_f32_t fun,
6650
- bool inplace) {
6651
6940
  bool is_node = false;
6652
6941
 
6653
- if (!inplace && a->grad) {
6654
- is_node = true;
6942
+ if (q->grad || k->grad || v->grad) {
6943
+ // when using this operation (in backwards pass) these grads are set.
6944
+ // we don't want to create (big) grad of our result, so is_node is false.
6945
+ is_node = false;
6655
6946
  }
6656
6947
 
6657
- struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
6658
- *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
6659
- struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
6948
+ // store gradients of q, k and v as continuous tensors concatenated in result.
6949
+ // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
6950
+ // gradq->data = result->data
6951
+ // gradk->data = result->data + nb0*D*N*ne2*ne3
6952
+ // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
6953
+ // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
6954
+ int64_t ne[4] = {D,M+N+M,ne2,ne3};
6660
6955
 
6661
- result->op = GGML_OP_MAP_UNARY;
6956
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6957
+
6958
+ result->op = GGML_OP_FLASH_ATTN_BACK;
6662
6959
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6663
- result->src0 = a;
6664
- result->opt[0] = addr_tensor;
6960
+ result->src0 = q;
6961
+ result->src1 = k;
6962
+ result->opt[0] = v;
6963
+ result->opt[1] = d;
6964
+ result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
6665
6965
 
6666
6966
  return result;
6667
6967
  }
6668
6968
 
6669
- struct ggml_tensor * ggml_map_unary_f32(
6670
- struct ggml_context * ctx,
6671
- struct ggml_tensor * a,
6672
- const ggml_unary_op_f32_t fun) {
6673
- return ggml_map_unary_impl_f32(ctx, a, fun, false);
6674
- }
6675
-
6676
- struct ggml_tensor * ggml_map_unary_inplace_f32(
6677
- struct ggml_context * ctx,
6678
- struct ggml_tensor * a,
6679
- const ggml_unary_op_f32_t fun) {
6680
- return ggml_map_unary_impl_f32(ctx, a, fun, true);
6681
- }
6682
-
6683
- // ggml_map_binary
6969
+ // ggml_win_part
6684
6970
 
6685
- struct ggml_tensor * ggml_map_binary_impl_f32(
6686
- struct ggml_context * ctx,
6687
- struct ggml_tensor * a,
6688
- struct ggml_tensor * b,
6689
- const ggml_binary_op_f32_t fun,
6690
- bool inplace) {
6691
- GGML_ASSERT(ggml_are_same_shape(a, b));
6971
+ struct ggml_tensor * ggml_win_part(
6972
+ struct ggml_context * ctx,
6973
+ struct ggml_tensor * a,
6974
+ int w) {
6975
+ GGML_ASSERT(a->ne[3] == 1);
6976
+ GGML_ASSERT(a->type == GGML_TYPE_F32);
6977
+
6978
+ bool is_node = false;
6979
+
6980
+ if (a->grad) {
6981
+ GGML_ASSERT(false); // TODO: implement backward
6982
+ is_node = true;
6983
+ }
6984
+
6985
+ // padding
6986
+ const int px = (w - a->ne[1]%w)%w;
6987
+ const int py = (w - a->ne[2]%w)%w;
6988
+
6989
+ const int npx = (px + a->ne[1])/w;
6990
+ const int npy = (py + a->ne[2])/w;
6991
+ const int np = npx*npy;
6992
+
6993
+ const int64_t ne[4] = { a->ne[0], w, w, np, };
6994
+
6995
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6996
+
6997
+ ggml_scratch_save(ctx);
6998
+
6999
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
7000
+
7001
+ ((int32_t *) b->data)[0] = npx;
7002
+ ((int32_t *) b->data)[1] = npy;
7003
+ ((int32_t *) b->data)[2] = w;
7004
+
7005
+ ggml_scratch_load(ctx);
7006
+
7007
+ result->op = GGML_OP_WIN_PART;
7008
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7009
+ result->src0 = a;
7010
+ result->src1 = NULL;
7011
+ result->opt[0] = b;
7012
+
7013
+ return result;
7014
+ }
7015
+
7016
+ // ggml_win_unpart
7017
+
7018
+ struct ggml_tensor * ggml_win_unpart(
7019
+ struct ggml_context * ctx,
7020
+ struct ggml_tensor * a,
7021
+ int w0,
7022
+ int h0,
7023
+ int w) {
7024
+ GGML_ASSERT(a->type == GGML_TYPE_F32);
7025
+
7026
+ bool is_node = false;
7027
+
7028
+ if (a->grad) {
7029
+ GGML_ASSERT(false); // TODO: implement backward
7030
+ is_node = true;
7031
+ }
7032
+
7033
+ const int64_t ne[4] = { a->ne[0], w0, h0, 1, };
7034
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
7035
+
7036
+ ggml_scratch_save(ctx);
7037
+
7038
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
7039
+
7040
+ ((int32_t *) b->data)[0] = w;
7041
+
7042
+ ggml_scratch_load(ctx);
7043
+
7044
+ result->op = GGML_OP_WIN_UNPART;
7045
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7046
+ result->src0 = a;
7047
+ result->src1 = NULL;
7048
+ result->opt[0] = b;
7049
+
7050
+ return result;
7051
+ }
7052
+
7053
+ // ggml_map_unary
7054
+
7055
+ struct ggml_tensor * ggml_map_unary_impl_f32(
7056
+ struct ggml_context * ctx,
7057
+ struct ggml_tensor * a,
7058
+ const ggml_unary_op_f32_t fun,
7059
+ bool inplace) {
7060
+ bool is_node = false;
7061
+
7062
+ if (!inplace && a->grad) {
7063
+ is_node = true;
7064
+ }
7065
+
7066
+ struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
7067
+ *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
7068
+ struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
7069
+
7070
+ result->op = GGML_OP_MAP_UNARY;
7071
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7072
+ result->src0 = a;
7073
+ result->opt[0] = addr_tensor;
7074
+
7075
+ return result;
7076
+ }
7077
+
7078
+ struct ggml_tensor * ggml_map_unary_f32(
7079
+ struct ggml_context * ctx,
7080
+ struct ggml_tensor * a,
7081
+ const ggml_unary_op_f32_t fun) {
7082
+ return ggml_map_unary_impl_f32(ctx, a, fun, false);
7083
+ }
7084
+
7085
+ struct ggml_tensor * ggml_map_unary_inplace_f32(
7086
+ struct ggml_context * ctx,
7087
+ struct ggml_tensor * a,
7088
+ const ggml_unary_op_f32_t fun) {
7089
+ return ggml_map_unary_impl_f32(ctx, a, fun, true);
7090
+ }
7091
+
7092
+ // ggml_map_binary
7093
+
7094
+ struct ggml_tensor * ggml_map_binary_impl_f32(
7095
+ struct ggml_context * ctx,
7096
+ struct ggml_tensor * a,
7097
+ struct ggml_tensor * b,
7098
+ const ggml_binary_op_f32_t fun,
7099
+ bool inplace) {
7100
+ GGML_ASSERT(ggml_are_same_shape(a, b));
6692
7101
 
6693
7102
  bool is_node = false;
6694
7103
 
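
ggml_flash_attn_back packs the three input gradients into one contiguous F32 result of shape {D, M+N+M, ne2, ne3}; per the comments above, the sub-blocks can be located with plain pointer arithmetic. A hedged sketch using the same names as those comments (result, D, N, M, ne2, ne3 assumed in scope; this is not an existing ggml helper):

    /* layout of the packed gradient buffer, as described in the patch comments:
     *   grad(q): D*N*ne2*ne3 floats at offset 0
     *   grad(k): D*M*ne2*ne3 floats immediately after grad(q)
     *   grad(v): M*D*ne2*ne3 floats immediately after grad(k) (stored transposed) */
    const size_t nb0 = sizeof(float);
    float * gradq = (float *)  result->data;
    float * gradk = (float *) ((char *) result->data + nb0*D*N*ne2*ne3);
    float * gradv = (float *) ((char *) result->data + nb0*(D*N + D*M)*ne2*ne3);
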
@@ -6725,6 +7134,50 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
6725
7134
  return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
6726
7135
  }
6727
7136
 
7137
+ // ggml_cross_entropy_loss
7138
+
7139
+ struct ggml_tensor * ggml_cross_entropy_loss(
7140
+ struct ggml_context * ctx,
7141
+ struct ggml_tensor * a,
7142
+ struct ggml_tensor * b) {
7143
+ GGML_ASSERT(ggml_are_same_shape(a, b));
7144
+ bool is_node = false;
7145
+
7146
+ if (a->grad || b->grad) {
7147
+ is_node = true;
7148
+ }
7149
+
7150
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
7151
+
7152
+ result->op = GGML_OP_CROSS_ENTROPY_LOSS;
7153
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7154
+ result->src0 = a;
7155
+ result->src1 = b;
7156
+
7157
+ return result;
7158
+ }
7159
+
7160
+ // ggml_cross_entropy_loss_back
7161
+
7162
+ struct ggml_tensor * ggml_cross_entropy_loss_back(
7163
+ struct ggml_context * ctx,
7164
+ struct ggml_tensor * a,
7165
+ struct ggml_tensor * b,
7166
+ struct ggml_tensor * c) {
7167
+ GGML_ASSERT(ggml_are_same_shape(a, b));
7168
+ GGML_ASSERT(ggml_is_scalar(c));
7169
+
7170
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
7171
+
7172
+ result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
7173
+ result->grad = NULL;
7174
+ result->src0 = a;
7175
+ result->src1 = b;
7176
+ result->opt[0] = c;
7177
+
7178
+ return result;
7179
+ }
7180
+
6728
7181
  ////////////////////////////////////////////////////////////////////////////////
6729
7182
 
6730
7183
  void ggml_set_param(
@@ -7674,7 +8127,7 @@ static void ggml_compute_forward_add_q_f32(
7674
8127
 
7675
8128
  void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
7676
8129
  float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
7677
- void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb0));
8130
+ void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
7678
8131
 
7679
8132
  assert(ne00 % 32 == 0);
7680
8133
 
@@ -8875,6 +9328,99 @@ static void ggml_compute_forward_repeat(
8875
9328
  }
8876
9329
  }
8877
9330
 
9331
+ // ggml_compute_forward_repeat_back
9332
+
9333
+ static void ggml_compute_forward_repeat_back_f32(
9334
+ const struct ggml_compute_params * params,
9335
+ const struct ggml_tensor * src0,
9336
+ struct ggml_tensor * dst) {
9337
+ GGML_ASSERT(params->ith == 0);
9338
+ GGML_ASSERT(ggml_can_repeat(dst, src0));
9339
+
9340
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
9341
+ return;
9342
+ }
9343
+
9344
+ const int64_t ne0 = dst->ne[0];
9345
+ const int64_t ne1 = dst->ne[1];
9346
+ const int64_t ne2 = dst->ne[2];
9347
+ const int64_t ne3 = dst->ne[3];
9348
+
9349
+ const int64_t ne00 = src0->ne[0];
9350
+ const int64_t ne01 = src0->ne[1];
9351
+ const int64_t ne02 = src0->ne[2];
9352
+ const int64_t ne03 = src0->ne[3];
9353
+
9354
+ const size_t nb0 = dst->nb[0];
9355
+ const size_t nb1 = dst->nb[1];
9356
+ const size_t nb2 = dst->nb[2];
9357
+ const size_t nb3 = dst->nb[3];
9358
+
9359
+ const size_t nb00 = src0->nb[0];
9360
+ const size_t nb01 = src0->nb[1];
9361
+ const size_t nb02 = src0->nb[2];
9362
+ const size_t nb03 = src0->nb[3];
9363
+
9364
+ // guaranteed to be an integer due to the check in ggml_can_repeat
9365
+ const int nr0 = (int)(ne00/ne0);
9366
+ const int nr1 = (int)(ne01/ne1);
9367
+ const int nr2 = (int)(ne02/ne2);
9368
+ const int nr3 = (int)(ne03/ne3);
9369
+
9370
+ // TODO: support for transposed / permuted tensors
9371
+ GGML_ASSERT(nb0 == sizeof(float));
9372
+ GGML_ASSERT(nb00 == sizeof(float));
9373
+
9374
+ if (ggml_is_contiguous(dst)) {
9375
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
9376
+ } else {
9377
+ for (int k3 = 0; k3 < ne3; k3++) {
9378
+ for (int k2 = 0; k2 < ne2; k2++) {
9379
+ for (int k1 = 0; k1 < ne1; k1++) {
9380
+ ggml_vec_set_f32(ne0,
9381
+ (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
9382
+ 0);
9383
+ }
9384
+ }
9385
+ }
9386
+ }
9387
+
9388
+ // TODO: maybe this is not optimal?
9389
+ for (int i3 = 0; i3 < nr3; i3++) {
9390
+ for (int k3 = 0; k3 < ne3; k3++) {
9391
+ for (int i2 = 0; i2 < nr2; i2++) {
9392
+ for (int k2 = 0; k2 < ne2; k2++) {
9393
+ for (int i1 = 0; i1 < nr1; i1++) {
9394
+ for (int k1 = 0; k1 < ne1; k1++) {
9395
+ for (int i0 = 0; i0 < nr0; i0++) {
9396
+ ggml_vec_acc_f32(ne0,
9397
+ (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1),
9398
+ (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
9399
+ }
9400
+ }
9401
+ }
9402
+ }
9403
+ }
9404
+ }
9405
+ }
9406
+ }
9407
+
9408
+ static void ggml_compute_forward_repeat_back(
9409
+ const struct ggml_compute_params * params,
9410
+ const struct ggml_tensor * src0,
9411
+ struct ggml_tensor * dst) {
9412
+ switch (src0->type) {
9413
+ case GGML_TYPE_F32:
9414
+ {
9415
+ ggml_compute_forward_repeat_back_f32(params, src0, dst);
9416
+ } break;
9417
+ default:
9418
+ {
9419
+ GGML_ASSERT(false);
9420
+ } break;
9421
+ }
9422
+ }
9423
+
8878
9424
  // ggml_compute_forward_abs
8879
9425
 
8880
9426
  static void ggml_compute_forward_abs_f32(
@@ -9142,8 +9688,65 @@ static void ggml_compute_forward_gelu(
9142
9688
  GGML_ASSERT(false);
9143
9689
  } break;
9144
9690
  }
9691
+ }
9692
+
9693
+ // ggml_compute_forward_gelu_quick
9694
+
9695
+ static void ggml_compute_forward_gelu_quick_f32(
9696
+ const struct ggml_compute_params * params,
9697
+ const struct ggml_tensor * src0,
9698
+ struct ggml_tensor * dst) {
9699
+ GGML_ASSERT(ggml_is_contiguous(src0));
9700
+ GGML_ASSERT(ggml_is_contiguous(dst));
9701
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
9702
+
9703
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
9704
+ return;
9705
+ }
9706
+
9707
+ const int ith = params->ith;
9708
+ const int nth = params->nth;
9709
+
9710
+ const int nc = src0->ne[0];
9711
+ const int nr = ggml_nrows(src0);
9712
+
9713
+ // rows per thread
9714
+ const int dr = (nr + nth - 1)/nth;
9715
+
9716
+ // row range for this thread
9717
+ const int ir0 = dr*ith;
9718
+ const int ir1 = MIN(ir0 + dr, nr);
9719
+
9720
+ for (int i1 = ir0; i1 < ir1; i1++) {
9721
+ ggml_vec_gelu_quick_f32(nc,
9722
+ (float *) ((char *) dst->data + i1*( dst->nb[1])),
9723
+ (float *) ((char *) src0->data + i1*(src0->nb[1])));
9724
+
9725
+ #ifndef NDEBUG
9726
+ for (int k = 0; k < nc; k++) {
9727
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
9728
+ UNUSED(x);
9729
+ assert(!isnan(x));
9730
+ assert(!isinf(x));
9731
+ }
9732
+ #endif
9733
+ }
9734
+ }
9145
9735
 
9146
- //printf("XXXXXXXX gelu\n");
9736
+ static void ggml_compute_forward_gelu_quick(
9737
+ const struct ggml_compute_params * params,
9738
+ const struct ggml_tensor * src0,
9739
+ struct ggml_tensor * dst) {
9740
+ switch (src0->type) {
9741
+ case GGML_TYPE_F32:
9742
+ {
9743
+ ggml_compute_forward_gelu_quick_f32(params, src0, dst);
9744
+ } break;
9745
+ default:
9746
+ {
9747
+ GGML_ASSERT(false);
9748
+ } break;
9749
+ }
9147
9750
  }
9148
9751
 
9149
9752
  // ggml_compute_forward_silu
@@ -10249,9 +10852,179 @@ static void ggml_compute_forward_mul_mat(
10249
10852
  }
10250
10853
  }
10251
10854
 
10252
- // ggml_compute_forward_scale
10855
+ // ggml_compute_forward_out_prod
10253
10856
 
10254
- static void ggml_compute_forward_scale_f32(
10857
+
10858
+ static void ggml_compute_forward_out_prod_f32(
10859
+ const struct ggml_compute_params * params,
10860
+ const struct ggml_tensor * src0,
10861
+ const struct ggml_tensor * src1,
10862
+ struct ggml_tensor * dst) {
10863
+ int64_t t0 = ggml_perf_time_us();
10864
+ UNUSED(t0);
10865
+
10866
+ const int64_t ne00 = src0->ne[0];
10867
+ const int64_t ne01 = src0->ne[1];
10868
+ const int64_t ne02 = src0->ne[2];
10869
+ const int64_t ne03 = src0->ne[3];
10870
+
10871
+ const int64_t ne10 = src1->ne[0];
10872
+ //const int64_t ne11 = src1->ne[1];
10873
+ const int64_t ne12 = src1->ne[2];
10874
+ const int64_t ne13 = src1->ne[3];
10875
+
10876
+ const int64_t ne0 = dst->ne[0];
10877
+ const int64_t ne1 = dst->ne[1];
10878
+ const int64_t ne2 = dst->ne[2];
10879
+ const int64_t ne3 = dst->ne[3];
10880
+
10881
+ const int nb00 = src0->nb[0];
10882
+ const int nb01 = src0->nb[1];
10883
+ const int nb02 = src0->nb[2];
10884
+ const int nb03 = src0->nb[3];
10885
+
10886
+ const int nb10 = src1->nb[0];
10887
+ const int nb11 = src1->nb[1];
10888
+ const int nb12 = src1->nb[2];
10889
+ const int nb13 = src1->nb[3];
10890
+
10891
+ const int nb0 = dst->nb[0];
10892
+ const int nb1 = dst->nb[1];
10893
+ const int nb2 = dst->nb[2];
10894
+ const int nb3 = dst->nb[3];
10895
+
10896
+ const int ith = params->ith;
10897
+ const int nth = params->nth;
10898
+
10899
+ GGML_ASSERT(ne02 == ne12);
10900
+ GGML_ASSERT(ne03 == ne13);
10901
+ GGML_ASSERT(ne2 == ne12);
10902
+ GGML_ASSERT(ne3 == ne13);
10903
+
10904
+ // we don't support permuted src0 or src1
10905
+ GGML_ASSERT(nb00 == sizeof(float));
10906
+
10907
+ // dst cannot be transposed or permuted
10908
+ GGML_ASSERT(nb0 == sizeof(float));
10909
+ // GGML_ASSERT(nb0 <= nb1);
10910
+ // GGML_ASSERT(nb1 <= nb2);
10911
+ // GGML_ASSERT(nb2 <= nb3);
10912
+
10913
+ GGML_ASSERT(ne0 == ne00);
10914
+ GGML_ASSERT(ne1 == ne10);
10915
+ GGML_ASSERT(ne2 == ne02);
10916
+ GGML_ASSERT(ne3 == ne03);
10917
+
10918
+ // nb01 >= nb00 - src0 is not transposed
10919
+ // compute by src0 rows
10920
+
10921
+ // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
10922
+ // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
10923
+
10924
+ if (params->type == GGML_TASK_INIT) {
10925
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
10926
+ return;
10927
+ }
10928
+
10929
+ if (params->type == GGML_TASK_FINALIZE) {
10930
+ return;
10931
+ }
10932
+
10933
+ // parallelize by last three dimensions
10934
+
10935
+ // total rows in dst
10936
+ const int64_t nr = ne1*ne2*ne3;
10937
+
10938
+ // rows per thread
10939
+ const int64_t dr = (nr + nth - 1)/nth;
10940
+
10941
+ // row range for this thread
10942
+ const int64_t ir0 = dr*ith;
10943
+ const int64_t ir1 = MIN(ir0 + dr, nr);
10944
+
10945
+ // dst[:,:,:,:] = 0
10946
+ // for i2,i3:
10947
+ // for i1:
10948
+ // for i01:
10949
+ // for i0:
10950
+ // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
10951
+
10952
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
10953
+ // dst indices
10954
+ const int64_t i3 = ir/(ne2*ne1);
10955
+ const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
10956
+ const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
10957
+
10958
+ const int64_t i02 = i2;
10959
+ const int64_t i03 = i3;
10960
+
10961
+ //const int64_t i10 = i1;
10962
+ const int64_t i12 = i2;
10963
+ const int64_t i13 = i3;
10964
+
10965
+ for (int64_t i01 = 0; i01 < ne01; ++i01) {
10966
+ const int64_t i11 = i01;
10967
+
10968
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
10969
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
10970
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
10971
+
10972
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
10973
+ // for (int64_t i0 = 0; i0 < ne0; ++i0) {
10974
+ // d[i0] += s0[i0] * s1[i1];
10975
+ // }
10976
+ }
10977
+ }
10978
+
10979
+ //int64_t t1 = ggml_perf_time_us();
10980
+ //static int64_t acc = 0;
10981
+ //acc += t1 - t0;
10982
+ //if (t1 - t0 > 10) {
10983
+ // printf("\n");
10984
+ // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
10985
+ // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
10986
+ // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
10987
+ // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
10988
+
10989
+ // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
10990
+ //}
10991
+ }
10992
+
10993
+ static void ggml_compute_forward_out_prod(
10994
+ const struct ggml_compute_params * params,
10995
+ const struct ggml_tensor * src0,
10996
+ const struct ggml_tensor * src1,
10997
+ struct ggml_tensor * dst) {
10998
+ switch (src0->type) {
10999
+ case GGML_TYPE_Q4_0:
11000
+ case GGML_TYPE_Q4_1:
11001
+ case GGML_TYPE_Q5_0:
11002
+ case GGML_TYPE_Q5_1:
11003
+ case GGML_TYPE_Q8_0:
11004
+ case GGML_TYPE_Q8_1:
11005
+ {
11006
+ GGML_ASSERT(false); // todo
11007
+ // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
11008
+ } break;
11009
+ case GGML_TYPE_F16:
11010
+ {
11011
+ GGML_ASSERT(false); // todo
11012
+ // ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst);
11013
+ } break;
11014
+ case GGML_TYPE_F32:
11015
+ {
11016
+ ggml_compute_forward_out_prod_f32(params, src0, src1, dst);
11017
+ } break;
11018
+ default:
11019
+ {
11020
+ GGML_ASSERT(false);
11021
+ } break;
11022
+ }
11023
+ }
11024
+
11025
+ // ggml_compute_forward_scale
11026
+
11027
+ static void ggml_compute_forward_scale_f32(
10255
11028
  const struct ggml_compute_params * params,
10256
11029
  const struct ggml_tensor * src0,
10257
11030
  const struct ggml_tensor * src1,
@@ -10371,7 +11144,7 @@ static void ggml_compute_forward_set_f32(
10371
11144
  const int im2 = (ne12 == 0 ? 0 : ne12-1);
10372
11145
  const int im3 = (ne13 == 0 ? 0 : ne13-1);
10373
11146
 
10374
- GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 < ggml_nbytes(dst));
11147
+ GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
10375
11148
 
10376
11149
  GGML_ASSERT(nb10 == sizeof(float));
10377
11150
 
@@ -10671,7 +11444,11 @@ static void ggml_compute_forward_get_rows_back_f32(
10671
11444
  GGML_ASSERT(ggml_is_contiguous(opt0));
10672
11445
  GGML_ASSERT(ggml_is_contiguous(dst));
10673
11446
 
10674
- ggml_compute_forward_dup_same_cont(params, opt0, dst);
11447
+ // ggml_compute_forward_dup_same_cont(params, opt0, dst);
11448
+
11449
+ if (params->type == GGML_TASK_INIT) {
11450
+ memset(dst->data, 0, ggml_nbytes(dst));
11451
+ }
10675
11452
 
10676
11453
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
10677
11454
  return;
@@ -10815,8 +11592,8 @@ static void ggml_compute_forward_diag_mask_f32(
10815
11592
  const struct ggml_tensor * src1,
10816
11593
  struct ggml_tensor * dst,
10817
11594
  const float value) {
10818
- assert(src1->type == GGML_TYPE_I32);
10819
- assert(ggml_nelements(src1) == 2);
11595
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
11596
+ GGML_ASSERT(ggml_nelements(src1) == 2);
10820
11597
 
10821
11598
  const int ith = params->ith;
10822
11599
  const int nth = params->nth;
@@ -10824,7 +11601,7 @@ static void ggml_compute_forward_diag_mask_f32(
10824
11601
  const int n_past = ((int32_t *) src1->data)[0];
10825
11602
  const bool inplace = (bool)((int32_t *) src1->data)[1];
10826
11603
 
10827
- assert(n_past >= 0);
11604
+ GGML_ASSERT(n_past >= 0);
10828
11605
 
10829
11606
  if (!inplace && (params->type == GGML_TASK_INIT)) {
10830
11607
  // memcpy needs to be synchronized across threads to avoid race conditions.
@@ -10848,8 +11625,8 @@ static void ggml_compute_forward_diag_mask_f32(
10848
11625
  const int nr = src0->ne[1];
10849
11626
  const int nz = n/nr;
10850
11627
 
10851
- assert( dst->nb[0] == sizeof(float));
10852
- assert(src0->nb[0] == sizeof(float));
11628
+ GGML_ASSERT( dst->nb[0] == sizeof(float));
11629
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
10853
11630
 
10854
11631
  for (int k = 0; k < nz; k++) {
10855
11632
  for (int j = ith; j < nr; j += nth) {
@@ -10985,6 +11762,101 @@ static void ggml_compute_forward_soft_max(
10985
11762
  }
10986
11763
  }
10987
11764
 
11765
+ // ggml_compute_forward_soft_max_back
11766
+
11767
+ static void ggml_compute_forward_soft_max_back_f32(
11768
+ const struct ggml_compute_params * params,
11769
+ const struct ggml_tensor * src0,
11770
+ const struct ggml_tensor * src1,
11771
+ struct ggml_tensor * dst) {
11772
+ GGML_ASSERT(ggml_is_contiguous(src0));
11773
+ GGML_ASSERT(ggml_is_contiguous(src1));
11774
+ GGML_ASSERT(ggml_is_contiguous(dst));
11775
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
11776
+ GGML_ASSERT(ggml_are_same_shape(src1, dst));
11777
+
11778
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
11779
+ return;
11780
+ }
11781
+
11782
+ // TODO: handle transposed/permuted matrices
11783
+
11784
+ const int ith = params->ith;
11785
+ const int nth = params->nth;
11786
+
11787
+ const int nc = src0->ne[0];
11788
+ const int nr = ggml_nrows(src0);
11789
+
11790
+ // rows per thread
11791
+ const int dr = (nr + nth - 1)/nth;
11792
+
11793
+ // row range for this thread
11794
+ const int ir0 = dr*ith;
11795
+ const int ir1 = MIN(ir0 + dr, nr);
11796
+
11797
+ for (int i1 = ir0; i1 < ir1; i1++) {
11798
+ float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
11799
+ float *y = (float *)((char *) src1->data + i1*src1->nb[1]);
11800
+ float *dx = (float *)((char *) dst->data + i1*dst->nb[1]);
11801
+
11802
+ #ifndef NDEBUG
11803
+ for (int i = 0; i < nc; ++i) {
11804
+ //printf("p[%d] = %f\n", i, p[i]);
11805
+ assert(!isnan(dy[i]));
11806
+ assert(!isnan(y[i]));
11807
+ }
11808
+ #endif
11809
+ // Jii = yi - yi*yi
11810
+ // Jij = -yi*yj
11811
+ // J = diag(y)-y.T*y
11812
+ // dx = J * dy
11813
+ // dxk = sum_i(Jki * dyi)
11814
+ // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
11815
+ // dxk = sum_i(-yk*yi * dyi) + yk*dyk
11816
+ // dxk = -yk * sum_i(yi * dyi) + yk*dyk
11817
+ // dxk = -yk * dot(y, dy) + yk*dyk
11818
+ // dxk = yk * (- dot(y, dy) + dyk)
11819
+ // dxk = yk * (dyk - dot(y, dy))
11820
+ //
11821
+ // post-order:
11822
+ // dot_y_dy := dot(y, dy)
11823
+ // dx := dy
11824
+ // dx := dx - dot_y_dy
11825
+ // dx := dx * y
11826
+
11827
+ // linear runtime, no additional memory
11828
+ float dot_y_dy = 0;
11829
+ ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
11830
+ ggml_vec_cpy_f32 (nc, dx, dy);
11831
+ ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
11832
+ ggml_vec_mul_f32 (nc, dx, dx, y);
11833
+
11834
+ #ifndef NDEBUG
11835
+ for (int i = 0; i < nc; ++i) {
11836
+ assert(!isnan(dx[i]));
11837
+ assert(!isinf(dx[i]));
11838
+ }
11839
+ #endif
11840
+ }
11841
+ }
11842
+
11843
+ static void ggml_compute_forward_soft_max_back(
11844
+ const struct ggml_compute_params * params,
11845
+ const struct ggml_tensor * src0,
11846
+ const struct ggml_tensor * src1,
11847
+ struct ggml_tensor * dst) {
11848
+ switch (src0->type) {
11849
+ case GGML_TYPE_F32:
11850
+ {
11851
+ ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst);
11852
+ } break;
11853
+ default:
11854
+ {
11855
+ GGML_ASSERT(false);
11856
+ } break;
11857
+ }
11858
+ }
11859
+
10988
11860
  // ggml_compute_forward_alibi
10989
11861
 
10990
11862
  static void ggml_compute_forward_alibi_f32(
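
Written out, the Jacobian derivation compressed into the comments of ggml_compute_forward_soft_max_back_f32 above: for y = softmax(x) and incoming gradient dy,

    J_{ij} \;=\; \frac{\partial y_i}{\partial x_j} \;=\; y_i\,(\delta_{ij} - y_j),
    \qquad
    dx_k \;=\; \sum_i J_{ki}\, dy_i
          \;=\; y_k\, dy_k \;-\; y_k \sum_i y_i\, dy_i
          \;=\; y_k\,\bigl(dy_k - y\cdot dy\bigr)

which is exactly the four vector calls in the kernel: one dot product dot(y,dy), a copy of dy, a scalar subtraction of that dot product, and an element-wise multiply by y.
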
@@ -10993,8 +11865,9 @@ static void ggml_compute_forward_alibi_f32(
10993
11865
  const struct ggml_tensor * src1,
10994
11866
  struct ggml_tensor * dst) {
10995
11867
  assert(params->ith == 0);
10996
- assert(src1->type == GGML_TYPE_I32);
10997
- assert(ggml_nelements(src1) == 3);
11868
+
11869
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
11870
+ GGML_ASSERT(ggml_nelements(src1) == 3);
10998
11871
 
10999
11872
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
11000
11873
  return;
@@ -11057,8 +11930,9 @@ static void ggml_compute_forward_alibi_f16(
11057
11930
  const struct ggml_tensor * src1,
11058
11931
  struct ggml_tensor * dst) {
11059
11932
  assert(params->ith == 0);
11060
- assert(src1->type == GGML_TYPE_I32);
11061
- assert(ggml_nelements(src1) == 3);
11933
+
11934
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
11935
+ GGML_ASSERT(ggml_nelements(src1) == 3);
11062
11936
 
11063
11937
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
11064
11938
  return;
@@ -11160,15 +12034,16 @@ static void ggml_compute_forward_clamp_f32(
11160
12034
  const struct ggml_tensor * src1,
11161
12035
  struct ggml_tensor * dst) {
11162
12036
  assert(params->ith == 0);
11163
- assert(src1->type == GGML_TYPE_I32);
11164
- assert(ggml_nelements(src1) == 2);
12037
+
12038
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
12039
+ GGML_ASSERT(ggml_nelements(src1) == 2);
11165
12040
 
11166
12041
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
11167
12042
  return;
11168
12043
  }
11169
12044
 
11170
- const int min = ((float *) src1->data)[0];
11171
- const int max = ((float *) src1->data)[1];
12045
+ const float min = ((float *) src1->data)[0];
12046
+ const float max = ((float *) src1->data)[1];
11172
12047
 
11173
12048
  const int ith = params->ith;
11174
12049
  const int nth = params->nth;
@@ -11726,9 +12601,9 @@ static void ggml_compute_forward_rope_back(
11726
12601
  }
11727
12602
  }
11728
12603
 
11729
- // ggml_compute_forward_conv_1d_1s
12604
+ // ggml_compute_forward_conv_1d_s1_ph
11730
12605
 
11731
- static void ggml_compute_forward_conv_1d_1s_f16_f32(
12606
+ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
11732
12607
  const struct ggml_compute_params * params,
11733
12608
  const struct ggml_tensor * src0,
11734
12609
  const struct ggml_tensor * src1,
@@ -11848,7 +12723,7 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
11848
12723
  }
11849
12724
  }
11850
12725
 
11851
- static void ggml_compute_forward_conv_1d_1s_f32(
12726
+ static void ggml_compute_forward_conv_1d_s1_ph_f32(
11852
12727
  const struct ggml_compute_params * params,
11853
12728
  const struct ggml_tensor * src0,
11854
12729
  const struct ggml_tensor * src1,
@@ -11968,7 +12843,7 @@ static void ggml_compute_forward_conv_1d_1s_f32(
11968
12843
  }
11969
12844
  }
11970
12845
 
11971
- static void ggml_compute_forward_conv_1d_1s(
12846
+ static void ggml_compute_forward_conv_1d_s1_ph(
11972
12847
  const struct ggml_compute_params * params,
11973
12848
  const struct ggml_tensor * src0,
11974
12849
  const struct ggml_tensor * src1,
@@ -11976,11 +12851,11 @@ static void ggml_compute_forward_conv_1d_1s(
11976
12851
  switch (src0->type) {
11977
12852
  case GGML_TYPE_F16:
11978
12853
  {
11979
- ggml_compute_forward_conv_1d_1s_f16_f32(params, src0, src1, dst);
12854
+ ggml_compute_forward_conv_1d_s1_ph_f16_f32(params, src0, src1, dst);
11980
12855
  } break;
11981
12856
  case GGML_TYPE_F32:
11982
12857
  {
11983
- ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
12858
+ ggml_compute_forward_conv_1d_s1_ph_f32(params, src0, src1, dst);
11984
12859
  } break;
11985
12860
  default:
11986
12861
  {
@@ -11989,9 +12864,9 @@ static void ggml_compute_forward_conv_1d_1s(
11989
12864
  }
11990
12865
  }
11991
12866
 
11992
- // ggml_compute_forward_conv_1d_2s
12867
+ // ggml_compute_forward_conv_1d_s2_ph
11993
12868
 
11994
- static void ggml_compute_forward_conv_1d_2s_f16_f32(
12869
+ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
11995
12870
  const struct ggml_compute_params * params,
11996
12871
  const struct ggml_tensor * src0,
11997
12872
  const struct ggml_tensor * src1,
@@ -12111,7 +12986,7 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
12111
12986
  }
12112
12987
  }
12113
12988
 
12114
- static void ggml_compute_forward_conv_1d_2s_f32(
12989
+ static void ggml_compute_forward_conv_1d_s2_ph_f32(
12115
12990
  const struct ggml_compute_params * params,
12116
12991
  const struct ggml_tensor * src0,
12117
12992
  const struct ggml_tensor * src1,
@@ -12231,7 +13106,7 @@ static void ggml_compute_forward_conv_1d_2s_f32(
12231
13106
  }
12232
13107
  }
12233
13108
 
12234
- static void ggml_compute_forward_conv_1d_2s(
13109
+ static void ggml_compute_forward_conv_1d_s2_ph(
12235
13110
  const struct ggml_compute_params * params,
12236
13111
  const struct ggml_tensor * src0,
12237
13112
  const struct ggml_tensor * src1,
@@ -12239,11 +13114,11 @@ static void ggml_compute_forward_conv_1d_2s(
12239
13114
  switch (src0->type) {
12240
13115
  case GGML_TYPE_F16:
12241
13116
  {
12242
- ggml_compute_forward_conv_1d_2s_f16_f32(params, src0, src1, dst);
13117
+ ggml_compute_forward_conv_1d_s2_ph_f16_f32(params, src0, src1, dst);
12243
13118
  } break;
12244
13119
  case GGML_TYPE_F32:
12245
13120
  {
12246
- ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
13121
+ ggml_compute_forward_conv_1d_s2_ph_f32(params, src0, src1, dst);
12247
13122
  } break;
12248
13123
  default:
12249
13124
  {
@@ -12252,51 +13127,188 @@ static void ggml_compute_forward_conv_1d_2s(
12252
13127
  }
12253
13128
  }
12254
13129
 
12255
- // ggml_compute_forward_flash_attn
13130
+ // ggml_compute_forward_conv_2d_sk_p0
12256
13131
 
12257
- static void ggml_compute_forward_flash_attn_f32(
13132
+ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
12258
13133
  const struct ggml_compute_params * params,
12259
- const struct ggml_tensor * q,
12260
- const struct ggml_tensor * k,
12261
- const struct ggml_tensor * v,
12262
- const bool masked,
12263
- struct ggml_tensor * dst) {
13134
+ const struct ggml_tensor * src0,
13135
+ const struct ggml_tensor * src1,
13136
+ struct ggml_tensor * dst) {
13137
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
13138
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
13139
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
13140
+
12264
13141
  int64_t t0 = ggml_perf_time_us();
12265
13142
  UNUSED(t0);
12266
13143
 
12267
- const int64_t neq0 = q->ne[0];
12268
- const int64_t neq1 = q->ne[1];
12269
- const int64_t neq2 = q->ne[2];
12270
- const int64_t neq3 = q->ne[3];
13144
+ const int ne00 = src0->ne[0];
13145
+ const int ne01 = src0->ne[1];
13146
+ const int ne02 = src0->ne[2];
13147
+ //const int ne03 = src0->ne[3];
12271
13148
 
12272
- const int64_t nek0 = k->ne[0];
12273
- const int64_t nek1 = k->ne[1];
12274
- //const int64_t nek2 = k->ne[2];
12275
- //const int64_t nek3 = k->ne[3];
13149
+ const int ne10 = src1->ne[0];
13150
+ //const int ne11 = src1->ne[1];
13151
+ const int ne12 = src1->ne[2];
13152
+ //const int ne13 = src1->ne[3];
12276
13153
 
12277
- //const int64_t nev0 = v->ne[0];
12278
- const int64_t nev1 = v->ne[1];
12279
- //const int64_t nev2 = v->ne[2];
12280
- //const int64_t nev3 = v->ne[3];
13154
+ const int ne0 = dst->ne[0];
13155
+ const int ne1 = dst->ne[1];
13156
+ const int ne2 = dst->ne[2];
13157
+ //const int ne3 = dst->ne[3];
13158
+ //const int ne = ne0*ne1*ne2*ne3;
12281
13159
 
12282
- const int64_t ne0 = dst->ne[0];
12283
- const int64_t ne1 = dst->ne[1];
12284
- //const int64_t ne2 = dst->ne[2];
12285
- //const int64_t ne3 = dst->ne[3];
13160
+ const int nb00 = src0->nb[0];
13161
+ //const int nb01 = src0->nb[1];
13162
+ //const int nb02 = src0->nb[2];
13163
+ const int nb03 = src0->nb[3];
12286
13164
 
12287
- const int nbk0 = k->nb[0];
12288
- const int nbk1 = k->nb[1];
12289
- const int nbk2 = k->nb[2];
12290
- const int nbk3 = k->nb[3];
13165
+ const int nb10 = src1->nb[0];
13166
+ //const int nb11 = src1->nb[1];
13167
+ const int nb12 = src1->nb[2];
13168
+ //const int nb13 = src1->nb[3];
12291
13169
 
12292
- const int nbq0 = q->nb[0];
12293
- const int nbq1 = q->nb[1];
12294
- const int nbq2 = q->nb[2];
12295
- const int nbq3 = q->nb[3];
13170
+ //const int nb0 = dst->nb[0];
13171
+ //const int nb1 = dst->nb[1];
13172
+ const int nb2 = dst->nb[2];
13173
+ //const int nb3 = dst->nb[3];
12296
13174
 
12297
- const int nbv0 = v->nb[0];
12298
- const int nbv1 = v->nb[1];
12299
- const int nbv2 = v->nb[2];
13175
+ const int ith = params->ith;
13176
+ const int nth = params->nth;
13177
+
13178
+ const int nk0 = ne00;
13179
+ const int nk1 = ne01;
13180
+
13181
+ // size of the convolution row - the kernel size unrolled across all channels
13182
+ // round-up so it is more suitable for SIMD
13183
+ const int ew0 = ggml_up32(nk0*nk1*ne02);
13184
+
13185
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
13186
+ GGML_ASSERT(nb10 == sizeof(float));
13187
+
13188
+ if (params->type == GGML_TASK_INIT) {
13189
+ // TODO: fix this memset (wsize is overestimated)
13190
+ memset(params->wdata, 0, params->wsize);
13191
+
13192
+ // prepare source data (src1)
13193
+ {
13194
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13195
+
13196
+ for (int i12 = 0; i12 < ne12; i12++) {
13197
+ const float * const src = (float *)((char *) src1->data + i12*nb12);
13198
+ ggml_fp16_t * dst_data = wdata;
13199
+
13200
+ for (int i1 = 0; i1 < ne1; i1++) {
13201
+ for (int i0 = 0; i0 < ne0; i0++) {
13202
+ for (int ik1 = 0; ik1 < nk1; ik1++) {
13203
+ for (int ik0 = 0; ik0 < nk0; ik0++) {
13204
+ dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
13205
+ GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]);
13206
+ }
13207
+ }
13208
+ }
13209
+ }
13210
+ }
13211
+ }
13212
+
13213
+ return;
13214
+ }
13215
+
13216
+ if (params->type == GGML_TASK_FINALIZE) {
13217
+ return;
13218
+ }
13219
+
13220
+ // total patches in dst
13221
+ const int np = ne2;
13222
+
13223
+ // patches per thread
13224
+ const int dp = (np + nth - 1)/nth;
13225
+
13226
+ // patch range for this thread
13227
+ const int ip0 = dp*ith;
13228
+ const int ip1 = MIN(ip0 + dp, np);
13229
+
13230
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13231
+
13232
+ for (int i2 = ip0; i2 < ip1; i2++) {
13233
+ float * dst_data = (float *)((char *) dst->data + i2*nb2);
13234
+
13235
+ for (int i1 = 0; i1 < ne1; ++i1) {
13236
+ for (int i0 = 0; i0 < ne0; ++i0) {
13237
+ ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
13238
+ (ggml_fp16_t *) ((char *) src0->data + i2*nb03),
13239
+ (ggml_fp16_t *) wdata + (i1*ne0 + i0)*ew0);
13240
+ }
13241
+ }
13242
+ }
13243
+ }
13244
+
13245
+ static void ggml_compute_forward_conv_2d_sk_p0(
13246
+ const struct ggml_compute_params * params,
13247
+ const struct ggml_tensor * src0,
13248
+ const struct ggml_tensor * src1,
13249
+ struct ggml_tensor * dst) {
13250
+ switch (src0->type) {
13251
+ case GGML_TYPE_F16:
13252
+ {
13253
+ ggml_compute_forward_conv_2d_sk_p0_f16_f32(params, src0, src1, dst);
13254
+ } break;
13255
+ case GGML_TYPE_F32:
13256
+ {
13257
+ //ggml_compute_forward_conv_2d_sk_p0_f32(params, src0, src1, dst);
13258
+ GGML_ASSERT(false);
13259
+ } break;
13260
+ default:
13261
+ {
13262
+ GGML_ASSERT(false);
13263
+ } break;
13264
+ }
13265
+ }
13266
+
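conv_2d_sk_p0 above is the stride-equals-kernel, zero-padding case: the INIT phase unrolls each non-overlapping K x K patch across all input channels into one contiguous f16 row of length ew0, rounded up so the dot product vectorizes, and the compute phase then reduces every output value to a single ggml_vec_dot_f16 against the flattened kernel. A scalar sketch of the same computation, assuming ggml_up32 rounds up to a multiple of 32; all other names are illustrative:

#include <stdio.h>

// assumed behaviour of ggml_up32: round n up to a multiple of 32
static int up32(int n) {
    return (n + 31) & ~31;
}

// one output value = dot product of one non-overlapping KxK patch (all C channels)
// with one flattened kernel; src is C planes of WxW floats, ker is C planes of KxK floats
static float patch_dot(const float * src, int W, int C,
                       const float * ker, int K,
                       int ox, int oy) {
    float acc = 0.0f;
    for (int c = 0; c < C; ++c) {
        for (int ky = 0; ky < K; ++ky) {
            for (int kx = 0; kx < K; ++kx) {
                acc += src[(c*W + oy*K + ky)*W + ox*K + kx]
                     * ker[(c*K + ky)*K + kx];
            }
        }
    }
    return acc;
}

int main(void) {
    const int W = 4, K = 2, C = 1;
    float src[16];
    const float ker[4] = {1, 0, 0, 1}; // toy 2x2 kernel
    for (int i = 0; i < 16; ++i) src[i] = (float) i;

    printf("unrolled row length ew0 = %d (raw %d)\n", up32(K*K*C), K*K*C);
    for (int oy = 0; oy < W/K; ++oy) {
        for (int ox = 0; ox < W/K; ++ox) {
            printf("%6.1f ", patch_dot(src, W, C, ker, K, ox, oy));
        }
        printf("\n");
    }
    return 0;
}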
13267
+ // ggml_compute_forward_flash_attn
13268
+
13269
+ static void ggml_compute_forward_flash_attn_f32(
13270
+ const struct ggml_compute_params * params,
13271
+ const struct ggml_tensor * q,
13272
+ const struct ggml_tensor * k,
13273
+ const struct ggml_tensor * v,
13274
+ const bool masked,
13275
+ struct ggml_tensor * dst) {
13276
+ int64_t t0 = ggml_perf_time_us();
13277
+ UNUSED(t0);
13278
+
13279
+ const int64_t neq0 = q->ne[0];
13280
+ const int64_t neq1 = q->ne[1];
13281
+ const int64_t neq2 = q->ne[2];
13282
+ const int64_t neq3 = q->ne[3];
13283
+
13284
+ const int64_t nek0 = k->ne[0];
13285
+ const int64_t nek1 = k->ne[1];
13286
+ //const int64_t nek2 = k->ne[2];
13287
+ //const int64_t nek3 = k->ne[3];
13288
+
13289
+ //const int64_t nev0 = v->ne[0];
13290
+ const int64_t nev1 = v->ne[1];
13291
+ //const int64_t nev2 = v->ne[2];
13292
+ //const int64_t nev3 = v->ne[3];
13293
+
13294
+ const int64_t ne0 = dst->ne[0];
13295
+ const int64_t ne1 = dst->ne[1];
13296
+ //const int64_t ne2 = dst->ne[2];
13297
+ //const int64_t ne3 = dst->ne[3];
13298
+
13299
+ const int nbk0 = k->nb[0];
13300
+ const int nbk1 = k->nb[1];
13301
+ const int nbk2 = k->nb[2];
13302
+ const int nbk3 = k->nb[3];
13303
+
13304
+ const int nbq0 = q->nb[0];
13305
+ const int nbq1 = q->nb[1];
13306
+ const int nbq2 = q->nb[2];
13307
+ const int nbq3 = q->nb[3];
13308
+
13309
+ const int nbv0 = v->nb[0];
13310
+ const int nbv1 = v->nb[1];
13311
+ const int nbv2 = v->nb[2];
12300
13312
  const int nbv3 = v->nb[3];
12301
13313
 
12302
13314
  const int nb0 = dst->nb[0];
@@ -12938,91 +13950,917 @@ static void ggml_compute_forward_flash_ff(
12938
13950
  }
12939
13951
  }
12940
13952
 
12941
- // ggml_compute_forward_map_unary
13953
+ // ggml_compute_forward_flash_attn_back
12942
13954
 
12943
- static void ggml_compute_forward_map_unary_f32(
13955
+ static void ggml_compute_forward_flash_attn_back_f32(
12944
13956
  const struct ggml_compute_params * params,
12945
- const struct ggml_tensor * src0,
12946
- struct ggml_tensor * dst,
12947
- const ggml_unary_op_f32_t fun) {
12948
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
13957
+ const struct ggml_tensor * q,
13958
+ const struct ggml_tensor * k,
13959
+ const struct ggml_tensor * v,
13960
+ const struct ggml_tensor * d,
13961
+ const bool masked,
13962
+ struct ggml_tensor * dst) {
13963
+ int64_t t0 = ggml_perf_time_us();
13964
+ UNUSED(t0);
12949
13965
 
12950
- if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12951
- return;
12952
- }
13966
+ const int64_t neq0 = q->ne[0];
13967
+ const int64_t neq1 = q->ne[1];
13968
+ const int64_t neq2 = q->ne[2];
13969
+ const int64_t neq3 = q->ne[3];
12953
13970
 
12954
- const int n = ggml_nrows(src0);
12955
- const int nc = src0->ne[0];
13971
+ const int64_t nek0 = k->ne[0];
13972
+ const int64_t nek1 = k->ne[1];
13973
+ //const int64_t nek2 = k->ne[2];
13974
+ //const int64_t nek3 = k->ne[3];
12956
13975
 
12957
- assert( dst->nb[0] == sizeof(float));
12958
- assert(src0->nb[0] == sizeof(float));
13976
+ const int64_t nev0 = v->ne[0];
13977
+ const int64_t nev1 = v->ne[1];
13978
+ //const int64_t nev2 = v->ne[2];
13979
+ //const int64_t nev3 = v->ne[3];
12959
13980
 
12960
- for (int i = 0; i < n; i++) {
12961
- fun(nc,
12962
- (float *) ((char *) dst->data + i*( dst->nb[1])),
12963
- (float *) ((char *) src0->data + i*(src0->nb[1])));
12964
- }
12965
- }
13981
+ const int64_t ned0 = d->ne[0];
13982
+ const int64_t ned1 = d->ne[1];
13983
+ //const int64_t ned2 = d->ne[2];
13984
+ //const int64_t ned3 = d->ne[3];
12966
13985
 
13986
+ const int64_t ne0 = dst->ne[0];
13987
+ const int64_t ne1 = dst->ne[1];
13988
+ const int64_t ne2 = dst->ne[2];
13989
+ const int64_t ne3 = dst->ne[3];
12967
13990
 
12968
- static void ggml_compute_forward_map_unary(
12969
- const struct ggml_compute_params * params,
12970
- const struct ggml_tensor * src0,
12971
- struct ggml_tensor * dst,
12972
- const ggml_unary_op_f32_t fun) {
12973
- switch (src0->type) {
12974
- case GGML_TYPE_F32:
12975
- {
12976
- ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
12977
- } break;
12978
- default:
12979
- {
12980
- GGML_ASSERT(false);
12981
- } break;
12982
- }
12983
- }
13991
+ const int nbk0 = k->nb[0];
13992
+ const int nbk1 = k->nb[1];
13993
+ const int nbk2 = k->nb[2];
13994
+ const int nbk3 = k->nb[3];
12984
13995
 
12985
- // ggml_compute_forward_map_binary
13996
+ const int nbq0 = q->nb[0];
13997
+ const int nbq1 = q->nb[1];
13998
+ const int nbq2 = q->nb[2];
13999
+ const int nbq3 = q->nb[3];
12986
14000
 
12987
- static void ggml_compute_forward_map_binary_f32(
12988
- const struct ggml_compute_params * params,
12989
- const struct ggml_tensor * src0,
12990
- const struct ggml_tensor * src1,
12991
- struct ggml_tensor * dst,
12992
- const ggml_binary_op_f32_t fun) {
12993
- assert(params->ith == 0);
12994
- assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
14001
+ const int nbv0 = v->nb[0];
14002
+ const int nbv1 = v->nb[1];
14003
+ const int nbv2 = v->nb[2];
14004
+ const int nbv3 = v->nb[3];
12995
14005
 
12996
- if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
14006
+ const int nbd0 = d->nb[0];
14007
+ const int nbd1 = d->nb[1];
14008
+ const int nbd2 = d->nb[2];
14009
+ const int nbd3 = d->nb[3];
14010
+
14011
+ const int nb0 = dst->nb[0];
14012
+ const int nb1 = dst->nb[1];
14013
+ const int nb2 = dst->nb[2];
14014
+ const int nb3 = dst->nb[3];
14015
+
14016
+ const int ith = params->ith;
14017
+ const int nth = params->nth;
14018
+
14019
+ const int64_t D = neq0;
14020
+ const int64_t N = neq1;
14021
+ const int64_t P = nek1 - N;
14022
+ const int64_t M = P + N;
14023
+
14024
+ const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
14025
+ const int mxDM = MAX(D, Mup);
14026
+
14027
+ // GGML_ASSERT(ne0 == D);
14028
+ // GGML_ASSERT(ne1 == N);
14029
+ GGML_ASSERT(P >= 0);
14030
+
14031
+ GGML_ASSERT(nbq0 == sizeof(float));
14032
+ GGML_ASSERT(nbk0 == sizeof(float));
14033
+ GGML_ASSERT(nbv0 == sizeof(float));
14034
+
14035
+ GGML_ASSERT(neq0 == D);
14036
+ GGML_ASSERT(nek0 == D);
14037
+ GGML_ASSERT(nev1 == D);
14038
+ GGML_ASSERT(ned0 == D);
14039
+
14040
+ GGML_ASSERT(neq1 == N);
14041
+ GGML_ASSERT(nek1 == N + P);
14042
+ GGML_ASSERT(nev1 == D);
14043
+ GGML_ASSERT(ned1 == N);
14044
+
14045
+ // dst cannot be transposed or permuted
14046
+ GGML_ASSERT(nb0 == sizeof(float));
14047
+ GGML_ASSERT(nb0 <= nb1);
14048
+ GGML_ASSERT(nb1 <= nb2);
14049
+ GGML_ASSERT(nb2 <= nb3);
14050
+
14051
+ if (params->type == GGML_TASK_INIT) {
14052
+ if (ith == 0) {
14053
+ memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
14054
+ }
14055
+ return;
14056
+ }
14057
+
14058
+ if (params->type == GGML_TASK_FINALIZE) {
12997
14059
  return;
12998
14060
  }
12999
14061
 
13000
- const int n = ggml_nrows(src0);
13001
- const int nc = src0->ne[0];
14062
+ // parallelize by q rows using ggml_vec_dot_f32
14063
+
14064
+ // total rows in q
14065
+ const int nr = neq2*neq3;
14066
+
14067
+ // rows per thread
14068
+ const int dr = (nr + nth - 1)/nth;
14069
+
14070
+ // row range for this thread
14071
+ const int ir0 = dr*ith;
14072
+ const int ir1 = MIN(ir0 + dr, nr);
14073
+
14074
+ const float scale = 1.0f/sqrtf(D);
14075
+
14076
+ //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
14077
+
14078
+ for (int ir = ir0; ir < ir1; ++ir) {
14079
+ // q indices
14080
+ const int iq3 = ir/(neq2);
14081
+ const int iq2 = ir - iq3*neq2;
14082
+ for (int iq1 = 0; iq1 < neq1; ++iq1) {
14083
+
14084
+
14085
+ // not sure about CACHE_LINE_SIZE_F32..
14086
+ // - maybe it should not be multiplied by 2, and should be excluded from the 1*(..) term of the SM offset?
14087
+ float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
14088
+ float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
14089
+
14090
+ for (int i = M; i < Mup; ++i) {
14091
+ S[i] = -INFINITY;
14092
+ }
14093
+
14094
+ for (int64_t ic = 0; ic < nek1; ++ic) {
14095
+ // k indices
14096
+ const int ik3 = iq3;
14097
+ const int ik2 = iq2;
14098
+ const int ik1 = ic;
14099
+
14100
+ // S indices
14101
+ const int i1 = ik1;
14102
+
14103
+ ggml_vec_dot_f32(neq0,
14104
+ S + i1,
14105
+ (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
14106
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
14107
+ }
14108
+
14109
+ // scale
14110
+ ggml_vec_scale_f32(nek1, S, scale);
14111
+
14112
+ if (masked) {
14113
+ for (int64_t i = P; i < M; i++) {
14114
+ if (i > P + iq1) {
14115
+ S[i] = -INFINITY;
14116
+ }
14117
+ }
14118
+ }
14119
+
14120
+ // softmax
14121
+ {
14122
+ float max = -INFINITY;
14123
+ ggml_vec_max_f32(M, &max, S);
14124
+
14125
+ ggml_float sum = 0.0;
14126
+ {
14127
+ #ifdef GGML_SOFT_MAX_ACCELERATE
14128
+ max = -max;
14129
+ vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
14130
+ vvexpf(SM, SM, &Mup);
14131
+ ggml_vec_sum_f32(Mup, &sum, SM);
14132
+ #else
14133
+ uint16_t scvt[GGML_SOFT_MAX_UNROLL];
14134
+ ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
14135
+
14136
+ for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
14137
+ float * SR = S + i;
14138
+ float * SW = SM + i;
14139
+
14140
+ for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
14141
+ if (SR[j] == -INFINITY) {
14142
+ SW[j] = 0.0f;
14143
+ } else {
14144
+ ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
14145
+ memcpy(&scvt[j], &s, sizeof(uint16_t));
14146
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
14147
+ sump[j] += (ggml_float)val;
14148
+ SW[j] = val;
14149
+ }
14150
+ }
14151
+ }
14152
+
14153
+ for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
14154
+ sum += sump[i];
14155
+ }
14156
+ #endif
14157
+ }
14158
+
14159
+ assert(sum > 0.0);
14160
+
14161
+ sum = 1.0/sum;
14162
+ ggml_vec_scale_f32(M, SM, sum);
14163
+
14164
+ }
14165
+
14166
+ // step-by-step explanation
14167
+ {
14168
+ // forward-process shape grads from backward process
14169
+ // parallel_for iq2,iq3:
14170
+ // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur]
14171
+ // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
14172
+ // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur]
14173
+ // for iq1:
14174
+ // kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
14175
+ // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
14176
+ // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
14177
+ // S0 = -Inf [D,1,1,1]
14178
+ // ~S1[i] = dot(kcur[:D,i], qcur)
14179
+ // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
14180
+ // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
14181
+ // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
14182
+ // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
14183
+ // ~S5[i] = dot(vcur[:,i], S4)
14184
+ // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
14185
+ // ~dst[i,iq1,iq2,iq3] = S5[i] ^
14186
+ // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
14187
+ // dst backward-/ grad[dst] = d
14188
+ //
14189
+ // output gradients with their dependencies:
14190
+ //
14191
+ // grad[kcur] = grad[S1].T @ qcur
14192
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
14193
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
14194
+ // grad[S4] = grad[S5] @ vcur
14195
+ // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
14196
+ // grad[qcur] = grad[S1] @ kcur
14197
+ // grad[vcur] = grad[S5].T @ S4
14198
+ // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
14199
+ //
14200
+ // in post-order:
14201
+ //
14202
+ // S1 = qcur @ kcur.T
14203
+ // S2 = S1 * scale
14204
+ // S3 = diag_mask_inf(S2, P)
14205
+ // S4 = softmax(S3)
14206
+ // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
14207
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
14208
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
14209
+ // grad[qcur] = grad[S1] @ kcur
14210
+ // grad[kcur] = grad[S1].T @ qcur
14211
+ // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
14212
+ //
14213
+ // using less variables (SM=S4):
14214
+ //
14215
+ // S = diag_mask_inf(qcur @ kcur.T * scale, P)
14216
+ // SM = softmax(S)
14217
+ // S = d[:D,iq1,iq2,iq3] @ vcur
14218
+ // dot_SM_gradSM = dot(SM, S)
14219
+ // S = SM * (S - dot(SM, S))
14220
+ // S = diag_mask_zero(S, P) * scale
14221
+ //
14222
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
14223
+ // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
14224
+ // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
14225
+ }
14226
+
14227
+ // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
14228
+ // S = d[:D,iq1,iq2,iq3] @ vcur
14229
+ // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
14230
+ ggml_vec_set_f32(M, S, 0);
14231
+ for (int64_t ic = 0; ic < D; ++ic) {
14232
+ // dst indices
14233
+ const int i1 = iq1;
14234
+ const int i2 = iq2;
14235
+ const int i3 = iq3;
14236
+
14237
+ ggml_vec_mad_f32(M,
14238
+ S,
14239
+ (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
14240
+ *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
14241
+ }
14242
+
14243
+ // S = SM * (S - dot(SM, S))
14244
+ float dot_SM_gradSM = 0;
14245
+ ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
14246
+ ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
14247
+ ggml_vec_mul_f32 (M, S, S, SM);
14248
+
14249
+ // S = diag_mask_zero(S, P) * scale
14250
+ if (masked) {
14251
+ // for (int64_t i = P + iq1 + 1; i < M; i++) {
14252
+ // S[i] = 0;
14253
+ // }
14254
+ for (int64_t i = P; i < M; i++) {
14255
+ if (i > P + iq1) {
14256
+ S[i] = 0;
14257
+ }
14258
+ }
14259
+ }
14260
+ ggml_vec_scale_f32(M, S, scale);
14261
+
14262
+ void * grad_q = (char *) dst->data;
14263
+ void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
14264
+ void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
14265
+
14266
+ const size_t nbgq1 = nb0*neq0;
14267
+ const size_t nbgq2 = nb0*neq0*neq1;
14268
+ const size_t nbgq3 = nb0*neq0*neq1*neq2;
14269
+
14270
+ const size_t nbgk1 = nb0*nek0;
14271
+ const size_t nbgk2 = nb0*nek0*nek1;
14272
+ const size_t nbgk3 = nb0*nek0*nek1*neq2;
14273
+
14274
+ const size_t nbgv1 = nb0*nev0;
14275
+ const size_t nbgv2 = nb0*nev0*nev1;
14276
+ const size_t nbgv3 = nb0*nev0*nev1*neq2;
14277
+
14278
+ // S shape [M,1]
14279
+ // SM shape [M,1]
14280
+ // kcur shape [D,M]
14281
+ // qcur shape [D,1]
14282
+ // vcur shape [M,D]
14283
+ //
14284
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
14285
+ // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
14286
+ // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
14287
+ //
14288
+ //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
14289
+ //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
14290
+ for (int64_t ic = 0; ic < M; ++ic) {
14291
+ // dst indices
14292
+ const int i1 = iq1;
14293
+ const int i2 = iq2;
14294
+ const int i3 = iq3;
14295
+
14296
+ ggml_vec_mad_f32(D,
14297
+ (float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
14298
+ (float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)),
14299
+ S[ic]);
14300
+ }
14301
+
14302
+ // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
14303
+ // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
14304
+ // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
14305
+ for (int64_t ic = 0; ic < M; ++ic) {
14306
+ // dst indices
14307
+ const int i1 = iq1;
14308
+ const int i2 = iq2;
14309
+ const int i3 = iq3;
14310
+
14311
+ // ggml_vec_set_f32(D,
14312
+ // (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
14313
+ // 0);
14314
+ ggml_vec_mad_f32(D,
14315
+ (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
14316
+ (float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)),
14317
+ S[ic]);
14318
+ }
14319
+
14320
+ // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
14321
+ // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
14322
+ // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M]
14323
+ for (int64_t ic = 0; ic < D; ++ic) {
14324
+ // dst indices
14325
+ const int i1 = iq1;
14326
+ const int i2 = iq2;
14327
+ const int i3 = iq3;
14328
+
14329
+ // ggml_vec_set_f32(M,
14330
+ // (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
14331
+ // 0);
14332
+ ggml_vec_mad_f32(M,
14333
+ (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
14334
+ SM,
14335
+ *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
14336
+ }
14337
+ }
14338
+ }
14339
+ }
14340
+
14341
+ static void ggml_compute_forward_flash_attn_back(
14342
+ const struct ggml_compute_params * params,
14343
+ const struct ggml_tensor * q,
14344
+ const struct ggml_tensor * k,
14345
+ const struct ggml_tensor * v,
14346
+ const struct ggml_tensor * d,
14347
+ const bool masked,
14348
+ struct ggml_tensor * dst) {
14349
+ switch (q->type) {
14350
+ case GGML_TYPE_F32:
14351
+ {
14352
+ ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst);
14353
+ } break;
14354
+ default:
14355
+ {
14356
+ GGML_ASSERT(false);
14357
+ } break;
14358
+ }
14359
+ }
14360
+
14361
+ // ggml_compute_forward_win_part
14362
+
14363
+ static void ggml_compute_forward_win_part_f32(
14364
+ const struct ggml_compute_params * params,
14365
+ const struct ggml_tensor * src0,
14366
+ const struct ggml_tensor * opt0,
14367
+ struct ggml_tensor * dst) {
14368
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
14369
+ return;
14370
+ }
14371
+
14372
+ const int64_t ne00 = src0->ne[0]; UNUSED(ne00);
14373
+ const int64_t ne01 = src0->ne[1];
14374
+ const int64_t ne02 = src0->ne[2];
14375
+ const int64_t ne03 = src0->ne[3]; UNUSED(ne03);
14376
+
14377
+ const int64_t ne0 = dst->ne[0];
14378
+ const int64_t ne1 = dst->ne[1];
14379
+ const int64_t ne2 = dst->ne[2];
14380
+ const int64_t ne3 = dst->ne[3]; UNUSED(ne3);
14381
+
14382
+ const int32_t nep0 = ((const int32_t *)(opt0->data))[0];
14383
+ const int32_t nep1 = ((const int32_t *)(opt0->data))[1];
14384
+ const int32_t w = ((const int32_t *)(opt0->data))[2];
14385
+
14386
+ assert(ne00 == ne0);
14387
+ assert(ne3 == nep0*nep1);
14388
+
14389
+ // TODO: optimize / multi-thread
14390
+ for (int py = 0; py < nep1; ++py) {
14391
+ for (int px = 0; px < nep0; ++px) {
14392
+ const int64_t i3 = py*nep0 + px;
14393
+ for (int64_t i2 = 0; i2 < ne2; ++i2) {
14394
+ for (int64_t i1 = 0; i1 < ne1; ++i1) {
14395
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
14396
+ const int64_t i02 = py*w + i2;
14397
+ const int64_t i01 = px*w + i1;
14398
+ const int64_t i00 = i0;
14399
+
14400
+ const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0;
14401
+ const int64_t j = i02*ne01*ne00 + i01*ne00 + i00;
14402
+
14403
+ if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
14404
+ ((float *) dst->data)[i] = 0.0f;
14405
+ } else {
14406
+ ((float *) dst->data)[i] = ((float *) src0->data)[j];
14407
+ }
14408
+ }
14409
+ }
14410
+ }
14411
+ }
14412
+ }
14413
+ }
14414
+
14415
+ static void ggml_compute_forward_win_part(
14416
+ const struct ggml_compute_params * params,
14417
+ const struct ggml_tensor * src0,
14418
+ const struct ggml_tensor * opt0,
14419
+ struct ggml_tensor * dst) {
14420
+ switch (src0->type) {
14421
+ case GGML_TYPE_F32:
14422
+ {
14423
+ ggml_compute_forward_win_part_f32(params, src0, opt0, dst);
14424
+ } break;
14425
+ default:
14426
+ {
14427
+ GGML_ASSERT(false);
14428
+ } break;
14429
+ }
14430
+ }
14431
+
14432
+ // ggml_compute_forward_win_unpart
14433
+
14434
+ static void ggml_compute_forward_win_unpart_f32(
14435
+ const struct ggml_compute_params * params,
14436
+ const struct ggml_tensor * src0,
14437
+ const struct ggml_tensor * opt0,
14438
+ struct ggml_tensor * dst) {
14439
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
14440
+ return;
14441
+ }
14442
+
14443
+ const int64_t ne00 = src0->ne[0];
14444
+ const int64_t ne01 = src0->ne[1];
14445
+ const int64_t ne02 = src0->ne[2];
14446
+ //const int64_t ne03 = src0->ne[3];
14447
+
14448
+ const int64_t ne0 = dst->ne[0];
14449
+ const int64_t ne1 = dst->ne[1];
14450
+ const int64_t ne2 = dst->ne[2];
14451
+
14452
+ const int32_t w = ((const int32_t *)(opt0->data))[0];
14453
+
14454
+ // padding
14455
+ const int px = (w - ne1%w)%w;
14456
+ //const int py = (w - ne2%w)%w;
14457
+
14458
+ const int npx = (px + ne1)/w;
14459
+ //const int npy = (py + ne2)/w;
14460
+
14461
+ assert(ne0 == ne00);
14462
+
14463
+ // TODO: optimize / multi-thread
14464
+ for (int64_t i2 = 0; i2 < ne2; ++i2) {
14465
+ for (int64_t i1 = 0; i1 < ne1; ++i1) {
14466
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
14467
+ const int ip2 = i2/w;
14468
+ const int ip1 = i1/w;
14469
+
14470
+ const int64_t i02 = i2%w;
14471
+ const int64_t i01 = i1%w;
14472
+ const int64_t i00 = i0;
14473
+
14474
+ const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
14475
+ const int64_t j = i2*ne1*ne0 + i1*ne0 + i0;
14476
+
14477
+ ((float *) dst->data)[j] = ((float *) src0->data)[i];
14478
+ }
14479
+ }
14480
+ }
14481
+ }
14482
+
14483
+ static void ggml_compute_forward_win_unpart(
14484
+ const struct ggml_compute_params * params,
14485
+ const struct ggml_tensor * src0,
14486
+ const struct ggml_tensor * opt0,
14487
+ struct ggml_tensor * dst) {
14488
+ switch (src0->type) {
14489
+ case GGML_TYPE_F32:
14490
+ {
14491
+ ggml_compute_forward_win_unpart_f32(params, src0, opt0, dst);
14492
+ } break;
14493
+ default:
14494
+ {
14495
+ GGML_ASSERT(false);
14496
+ } break;
14497
+ }
14498
+ }
14499
+
14500
+ // ggml_compute_forward_map_unary
14501
+
14502
+ static void ggml_compute_forward_map_unary_f32(
14503
+ const struct ggml_compute_params * params,
14504
+ const struct ggml_tensor * src0,
14505
+ struct ggml_tensor * dst,
14506
+ const ggml_unary_op_f32_t fun) {
14507
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
14508
+
14509
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
14510
+ return;
14511
+ }
14512
+
14513
+ const int n = ggml_nrows(src0);
14514
+ const int nc = src0->ne[0];
14515
+
14516
+ assert( dst->nb[0] == sizeof(float));
14517
+ assert(src0->nb[0] == sizeof(float));
14518
+
14519
+ for (int i = 0; i < n; i++) {
14520
+ fun(nc,
14521
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
14522
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
14523
+ }
14524
+ }
14525
+
14526
+
14527
+ static void ggml_compute_forward_map_unary(
14528
+ const struct ggml_compute_params * params,
14529
+ const struct ggml_tensor * src0,
14530
+ struct ggml_tensor * dst,
14531
+ const ggml_unary_op_f32_t fun) {
14532
+ switch (src0->type) {
14533
+ case GGML_TYPE_F32:
14534
+ {
14535
+ ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
14536
+ } break;
14537
+ default:
14538
+ {
14539
+ GGML_ASSERT(false);
14540
+ } break;
14541
+ }
14542
+ }
14543
+
14544
+ // ggml_compute_forward_map_binary
14545
+
14546
+ static void ggml_compute_forward_map_binary_f32(
14547
+ const struct ggml_compute_params * params,
14548
+ const struct ggml_tensor * src0,
14549
+ const struct ggml_tensor * src1,
14550
+ struct ggml_tensor * dst,
14551
+ const ggml_binary_op_f32_t fun) {
14552
+ assert(params->ith == 0);
14553
+ assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
14554
+
14555
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
14556
+ return;
14557
+ }
14558
+
14559
+ const int n = ggml_nrows(src0);
14560
+ const int nc = src0->ne[0];
14561
+
14562
+ assert( dst->nb[0] == sizeof(float));
14563
+ assert(src0->nb[0] == sizeof(float));
14564
+ assert(src1->nb[0] == sizeof(float));
14565
+
14566
+ for (int i = 0; i < n; i++) {
14567
+ fun(nc,
14568
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
14569
+ (float *) ((char *) src0->data + i*(src0->nb[1])),
14570
+ (float *) ((char *) src1->data + i*(src1->nb[1])));
14571
+ }
14572
+ }
14573
+
14574
+
14575
+ static void ggml_compute_forward_map_binary(
14576
+ const struct ggml_compute_params * params,
14577
+ const struct ggml_tensor * src0,
14578
+ const struct ggml_tensor * src1,
14579
+ struct ggml_tensor * dst,
14580
+ const ggml_binary_op_f32_t fun) {
14581
+ switch (src0->type) {
14582
+ case GGML_TYPE_F32:
14583
+ {
14584
+ ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
14585
+ } break;
14586
+ default:
14587
+ {
14588
+ GGML_ASSERT(false);
14589
+ } break;
14590
+ }
14591
+ }
14592
+
14593
+ // ggml_compute_forward_cross_entropy_loss
14594
+
14595
+ static void ggml_compute_forward_cross_entropy_loss_f32(
14596
+ const struct ggml_compute_params * params,
14597
+ const struct ggml_tensor * src0,
14598
+ const struct ggml_tensor * src1,
14599
+ struct ggml_tensor * dst) {
14600
+ GGML_ASSERT(ggml_is_contiguous(src0));
14601
+ GGML_ASSERT(ggml_is_contiguous(src1));
14602
+ GGML_ASSERT(ggml_is_scalar(dst));
14603
+ GGML_ASSERT(ggml_are_same_shape(src0, src1));
14604
+
14605
+ const int ith = params->ith;
14606
+ const int nth = params->nth;
14607
+
14608
+ float * sums = (float *) params->wdata;
14609
+
14610
+ // TODO: handle transposed/permuted matrices
14611
+ const int nc = src0->ne[0];
14612
+ const int nr = ggml_nrows(src0);
14613
+
14614
+ if (params->type == GGML_TASK_INIT) {
14615
+ if (ith == 0) {
14616
+ memset(sums, 0, sizeof(float) * (nth + nth * nc));
14617
+ }
14618
+ return;
14619
+ }
14620
+
14621
+ if (params->type == GGML_TASK_FINALIZE) {
14622
+ if (ith == 0) {
14623
+ float * dp = (float *) dst->data;
14624
+ ggml_vec_sum_f32(nth, dp, sums);
14625
+ dp[0] *= -1.0f;
14626
+ }
14627
+ return;
14628
+ }
14629
+
14630
+ const double eps = 1e-9;
14631
+
14632
+ // rows per thread
14633
+ const int dr = (nr + nth - 1)/nth;
14634
+
14635
+ // row range for this thread
14636
+ const int ir0 = dr*ith;
14637
+ const int ir1 = MIN(ir0 + dr, nr);
14638
+
14639
+ for (int i1 = ir0; i1 < ir1; i1++) {
14640
+ float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
14641
+ float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
14642
+ float * st = (float *) params->wdata + nth + ith*nc;
14643
+
14644
+ #ifndef NDEBUG
14645
+ for (int i = 0; i < nc; ++i) {
14646
+ //printf("p[%d] = %f\n", i, p[i]);
14647
+ assert(!isnan(s0[i]));
14648
+ assert(!isnan(s1[i]));
14649
+ }
14650
+ #endif
14651
+ // soft_max
14652
+ ggml_float sum = 0.0;
14653
+ {
14654
+ float max = -INFINITY;
14655
+ ggml_vec_max_f32(nc, &max, s0);
14656
+
14657
+ uint16_t scvt;
14658
+ for (int i = 0; i < nc; i++) {
14659
+ if (s0[i] == -INFINITY) {
14660
+ st[i] = 0.0f;
14661
+ } else {
14662
+ // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
14663
+ ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
14664
+ memcpy(&scvt, &s, sizeof(scvt));
14665
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
14666
+ sum += (ggml_float)val;
14667
+ st[i] = val;
14668
+ }
14669
+ }
14670
+
14671
+ assert(sum > 0.0);
14672
+ // sum = 1.0/sum;
14673
+ }
14674
+ // avoid log(0) by rescaling from [0..1] to [eps..1]
14675
+ sum = (1.0 - eps) / sum;
14676
+ ggml_vec_scale_f32(nc, st, sum);
14677
+ ggml_vec_add1_f32(nc, st, st, eps);
14678
+ ggml_vec_log_f32(nc, st, st);
14679
+ ggml_vec_mul_f32(nc, st, st, s1);
14680
+
14681
+ ggml_vec_sum_f32(nc, sums + ith, st);
14682
+
14683
+ #ifndef NDEBUG
14684
+ for (int i = 0; i < nc; ++i) {
14685
+ assert(!isnan(st[i]));
14686
+ assert(!isinf(st[i]));
14687
+ }
14688
+ #endif
14689
+ }
14690
+
14691
+ }
14692
+
14693
+ static void ggml_compute_forward_cross_entropy_loss(
14694
+ const struct ggml_compute_params * params,
14695
+ const struct ggml_tensor * src0,
14696
+ const struct ggml_tensor * src1,
14697
+ struct ggml_tensor * dst) {
14698
+ switch (src0->type) {
14699
+ case GGML_TYPE_F32:
14700
+ {
14701
+ ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst);
14702
+ } break;
14703
+ default:
14704
+ {
14705
+ GGML_ASSERT(false);
14706
+ } break;
14707
+ }
14708
+ }
14709
+
14710
+ // ggml_compute_forward_cross_entropy_loss_back
14711
+
14712
+ static void ggml_compute_forward_cross_entropy_loss_back_f32(
14713
+ const struct ggml_compute_params * params,
14714
+ const struct ggml_tensor * src0,
14715
+ const struct ggml_tensor * src1,
14716
+ const struct ggml_tensor * opt0,
14717
+ struct ggml_tensor * dst) {
14718
+ GGML_ASSERT(ggml_is_contiguous(dst));
14719
+ GGML_ASSERT(ggml_is_contiguous(src0));
14720
+ GGML_ASSERT(ggml_is_contiguous(src1));
14721
+ GGML_ASSERT(ggml_is_contiguous(opt0));
14722
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
14723
+
14724
+ const int64_t ith = params->ith;
14725
+ const int64_t nth = params->nth;
14726
+
14727
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
14728
+ return;
14729
+ }
14730
+
14731
+ const float eps = 1e-9f;
14732
+
14733
+ // TODO: handle transposed/permuted matrices
14734
+ const int64_t nc = src0->ne[0];
14735
+ const int64_t nr = ggml_nrows(src0);
14736
+
14737
+ // rows per thread
14738
+ const int64_t dr = (nr + nth - 1)/nth;
14739
+
14740
+ // row range for this thread
14741
+ const int64_t ir0 = dr*ith;
14742
+ const int64_t ir1 = MIN(ir0 + dr, nr);
14743
+
14744
+ float * d = (float *) opt0->data;
14745
+
14746
+ for (int64_t i1 = ir0; i1 < ir1; i1++) {
14747
+ float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
14748
+ float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
14749
+ float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
14750
+ float * sm = (float *) params->wdata + ith*nc;
14751
+
14752
+ #ifndef NDEBUG
14753
+ for (int i = 0; i < nc; ++i) {
14754
+ //printf("p[%d] = %f\n", i, p[i]);
14755
+ assert(!isnan(s0[i]));
14756
+ assert(!isnan(s1[i]));
14757
+ }
14758
+ #endif
14759
+ // step by step explanation:
14760
+ {
14761
+ //float * sums = (float *) params->wdata;
14762
+
14763
+ // forward pass with annotated gradients from backward pass
14764
+ // (built by going in reverse operation order, adding to gradients of current operation args)
14765
+ // st0 = exp(s0-max(s0)) grad[st0] = grad[st1]*(1.0 - eps)/sum
14766
+ // from softmax_back: grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
14767
+ // ggml_vec_scale_f32(nc, st, sum); // st1 = st0/sum = softmax(s0) grad[st1] = grad[st2]*(1.0 - eps)
14768
+ // ggml_vec_scale_f32(nc, st, (1.0f - eps)); // st2 = st1*(1.0 - eps) grad[st2] = grad[st3]
14769
+ // ggml_vec_add1_f32(nc, st, st, eps); // st3 = st2 + eps grad[st3] = grad[st4]/st3
14770
+ // ggml_vec_log_f32(nc, st, st); // st4 = log(st3) grad[st4] = grad[st5] * s1
14771
+ // ggml_vec_mul_f32(nc, st, st, s1); // st5 = st4 * s1 grad[st5] = grad[sums[ith]]
14772
+ // ggml_vec_sum_f32(nc, sums + ith, st); // sums[ith] = st5 grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel]
14773
+
14774
+ // substitute into grad[st1], because we can reuse softmax_back from this point on
14775
+ // grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps))
14776
+ // postorder:
14777
+ // grad[st1] := softmax(s0)
14778
+ // grad[st1] := grad[st1]*(1.0 - eps)
14779
+ // grad[st1] := grad[st1] + eps
14780
+ // grad[st1] := s1 / grad[st1]
14781
+ // grad[st1] := grad[st1]*(1.0-eps)*-grad[cel]
14782
+
14783
+ // src0 gradients by going through softmax_back
14784
+ // grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
14785
+ // from softmax_back:
14786
+ // dxk = yk * (dyk - dot(y, dy))
14787
+ // dot_y_dy := dot(y, dy)
14788
+ // dx := dy
14789
+ // dx := dx - dot_y_dy
14790
+ // dx := dx * y
14791
+ // postorder:
14792
+ // dot_st1_dst1 := dot(st1, grad[st1])
14793
+ // grad[s0] := grad[st1]
14794
+ // grad[s0] := grad[s0] - dot_st1_dst1
14795
+ // grad[s0] := grad[s0] * st1
14796
+
14797
+ // prepend the grad[st1] postorder, writing directly into grad[s0]'s memory, since the first softmax_back step would be grad[s0] := grad[st1] anyway
14798
+ // sm := softmax(s0)
14799
+ // grad[s0] := sm*(1.0 - eps)
14800
+ // grad[s0] := grad[s0] + eps
14801
+ // grad[s0] := s1 / grad[s0]
14802
+ // grad[s0] := grad[s0]*(1.0-eps)*-grad[cel]
14803
+ // dot_st1_dst1 := dot(sm, grad[s0])
14804
+ // grad[s0] := grad[s0] - dot_st1_dst1
14805
+ // grad[s0] := grad[s0] * sm
14806
+ }
14807
+
14808
+ // soft_max
14809
+ ggml_float sum = 0.0;
14810
+ {
14811
+ float max = -INFINITY;
14812
+ ggml_vec_max_f32(nc, &max, s0);
14813
+
14814
+ uint16_t scvt;
14815
+ for (int i = 0; i < nc; i++) {
14816
+ if (s0[i] == -INFINITY) {
14817
+ sm[i] = 0.0f;
14818
+ } else {
14819
+ // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
14820
+ ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
14821
+ memcpy(&scvt, &s, sizeof(scvt));
14822
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
14823
+ sum += (ggml_float)val;
14824
+ sm[i] = val;
14825
+ }
14826
+ }
14827
+
14828
+ assert(sum > 0.0);
14829
+ sum = 1.0/sum;
14830
+ }
13002
14831
 
13003
- assert( dst->nb[0] == sizeof(float));
13004
- assert(src0->nb[0] == sizeof(float));
13005
- assert(src1->nb[0] == sizeof(float));
14832
+ float dot_st1_dst1 = 0;
14833
+ ggml_vec_scale_f32(nc, sm, sum);
14834
+ ggml_vec_cpy_f32 (nc, ds0, sm);
14835
+ ggml_vec_scale_f32(nc, ds0, (1.0f - eps));
14836
+ ggml_vec_add1_f32 (nc, ds0, ds0, eps);
14837
+ ggml_vec_div_f32 (nc, ds0, s1, ds0);
14838
+ ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]);
14839
+ ggml_vec_dot_f32 (nc, &dot_st1_dst1, sm, ds0);
14840
+ ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1);
14841
+ ggml_vec_mul_f32 (nc, ds0, ds0, sm);
13006
14842
 
13007
- for (int i = 0; i < n; i++) {
13008
- fun(nc,
13009
- (float *) ((char *) dst->data + i*( dst->nb[1])),
13010
- (float *) ((char *) src0->data + i*(src0->nb[1])),
13011
- (float *) ((char *) src1->data + i*(src1->nb[1])));
14843
+ #ifndef NDEBUG
14844
+ for (int i = 0; i < nc; ++i) {
14845
+ assert(!isnan(sm[i]));
14846
+ assert(!isinf(sm[i]));
14847
+ assert(!isnan(ds0[i]));
14848
+ assert(!isinf(ds0[i]));
14849
+ }
14850
+ #endif
13012
14851
  }
13013
14852
  }
13014
14853
 
13015
-
13016
- static void ggml_compute_forward_map_binary(
14854
+ static void ggml_compute_forward_cross_entropy_loss_back(
13017
14855
  const struct ggml_compute_params * params,
13018
14856
  const struct ggml_tensor * src0,
13019
14857
  const struct ggml_tensor * src1,
13020
- struct ggml_tensor * dst,
13021
- const ggml_binary_op_f32_t fun) {
14858
+ const struct ggml_tensor * opt0,
14859
+ struct ggml_tensor * dst) {
13022
14860
  switch (src0->type) {
13023
14861
  case GGML_TYPE_F32:
13024
14862
  {
13025
- ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
14863
+ ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst);
13026
14864
  } break;
13027
14865
  default:
13028
14866
  {
@@ -13031,6 +14869,7 @@ static void ggml_compute_forward_map_binary(
13031
14869
  }
13032
14870
  }
13033
14871
 
14872
+
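Collapsing the postorder of cross_entropy_loss_back above into a single expression, with p the softmax of the row s0 and d_0 = d[0] the incoming scalar gradient of the loss (a restatement of the comments, not text from the source):

p_i = \mathrm{softmax}(s_0)_i, \qquad
g_i = -\,\frac{(1-\varepsilon)\,s_{1,i}\,d_0}{\varepsilon + (1-\varepsilon)\,p_i}, \qquad
\frac{\partial L}{\partial s_{0,i}} = p_i\Big(g_i - \sum_j p_j\,g_j\Big),

which is the scale/add1/div/scale sequence followed by the usual softmax_back step in the code.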
13034
14873
  /////////////////////////////////
13035
14874
 
13036
14875
  static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -13102,6 +14941,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13102
14941
  {
13103
14942
  ggml_compute_forward_repeat(params, tensor->src0, tensor);
13104
14943
  } break;
14944
+ case GGML_OP_REPEAT_BACK:
14945
+ {
14946
+ ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
14947
+ } break;
13105
14948
  case GGML_OP_ABS:
13106
14949
  {
13107
14950
  ggml_compute_forward_abs(params, tensor->src0, tensor);
@@ -13126,6 +14969,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13126
14969
  {
13127
14970
  ggml_compute_forward_gelu(params, tensor->src0, tensor);
13128
14971
  } break;
14972
+ case GGML_OP_GELU_QUICK:
14973
+ {
14974
+ ggml_compute_forward_gelu_quick(params, tensor->src0, tensor);
14975
+ } break;
13129
14976
  case GGML_OP_SILU:
13130
14977
  {
13131
14978
  ggml_compute_forward_silu(params, tensor->src0, tensor);
@@ -13150,6 +14997,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13150
14997
  {
13151
14998
  ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
13152
14999
  } break;
15000
+ case GGML_OP_OUT_PROD:
15001
+ {
15002
+ ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
15003
+ } break;
13153
15004
  case GGML_OP_SCALE:
13154
15005
  {
13155
15006
  ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
@@ -13206,6 +15057,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13206
15057
  {
13207
15058
  ggml_compute_forward_soft_max(params, tensor->src0, tensor);
13208
15059
  } break;
15060
+ case GGML_OP_SOFT_MAX_BACK:
15061
+ {
15062
+ ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
15063
+ } break;
13209
15064
  case GGML_OP_ROPE:
13210
15065
  {
13211
15066
  ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
@@ -13222,25 +15077,44 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13222
15077
  {
13223
15078
  ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor);
13224
15079
  } break;
13225
- case GGML_OP_CONV_1D_1S:
15080
+ case GGML_OP_CONV_1D_S1_PH:
15081
+ {
15082
+ ggml_compute_forward_conv_1d_s1_ph(params, tensor->src0, tensor->src1, tensor);
15083
+ } break;
15084
+ case GGML_OP_CONV_1D_S2_PH:
13226
15085
  {
13227
- ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor);
15086
+ ggml_compute_forward_conv_1d_s2_ph(params, tensor->src0, tensor->src1, tensor);
13228
15087
  } break;
13229
- case GGML_OP_CONV_1D_2S:
15088
+ case GGML_OP_CONV_2D_SK_P0:
13230
15089
  {
13231
- ggml_compute_forward_conv_1d_2s(params, tensor->src0, tensor->src1, tensor);
15090
+ ggml_compute_forward_conv_2d_sk_p0(params, tensor->src0, tensor->src1, tensor);
13232
15091
  } break;
13233
15092
  case GGML_OP_FLASH_ATTN:
13234
15093
  {
13235
- int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
15094
+ const int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
13236
15095
  GGML_ASSERT(t == 0 || t == 1);
13237
- bool masked = t != 0;
15096
+ const bool masked = t != 0;
13238
15097
  ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor);
13239
15098
  } break;
13240
15099
  case GGML_OP_FLASH_FF:
13241
15100
  {
13242
15101
  ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
13243
15102
  } break;
15103
+ case GGML_OP_FLASH_ATTN_BACK:
15104
+ {
15105
+ int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
15106
+ GGML_ASSERT(t == 0 || t == 1);
15107
+ bool masked = t != 0;
15108
+ ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
15109
+ } break;
15110
+ case GGML_OP_WIN_PART:
15111
+ {
15112
+ ggml_compute_forward_win_part(params, tensor->src0, tensor->opt[0], tensor);
15113
+ } break;
15114
+ case GGML_OP_WIN_UNPART:
15115
+ {
15116
+ ggml_compute_forward_win_unpart(params, tensor->src0, tensor->opt[0], tensor);
15117
+ } break;
13244
15118
  case GGML_OP_MAP_UNARY:
13245
15119
  {
13246
15120
  const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
@@ -13253,6 +15127,16 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13253
15127
  ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
13254
15128
  }
13255
15129
  break;
15130
+ case GGML_OP_CROSS_ENTROPY_LOSS:
15131
+ {
15132
+ ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
15133
+ }
15134
+ break;
15135
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
15136
+ {
15137
+ ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15138
+ }
15139
+ break;
13256
15140
  case GGML_OP_NONE:
13257
15141
  {
13258
15142
  // nop
@@ -13391,11 +15275,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13391
15275
  src0->grad =
13392
15276
  ggml_add_impl(ctx,
13393
15277
  src0->grad,
13394
- ggml_mul(ctx,
13395
- tensor->grad, // this was not catched by test_grad because in test_grad tensor->grad is 1
15278
+ ggml_scale(ctx,
13396
15279
  ggml_div(ctx,
13397
- ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor),
13398
- tensor)),
15280
+ tensor->grad,
15281
+ tensor),
15282
+ ggml_new_f32(ctx, 0.5f)),
13399
15283
  inplace);
13400
15284
  }
13401
15285
  } break;
@@ -13441,43 +15325,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13441
15325
  {
13442
15326
  // necessary for llama
13443
15327
  if (src0->grad) {
13444
- GGML_ASSERT(src0->n_dims == 1 || src0->n_dims == 2);
13445
- const int nc = tensor->ne[0];
13446
- const int nr = tensor->ne[1];
13447
- const int nc0 = src0->ne[0];
13448
- const int nr0 = src0->ne[1];
13449
- const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
13450
- const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
13451
- // tensor->grad [nc,nr,1,1]
13452
- // reshape [nc0,nc/nc0,nr0,nr/nr0]
13453
- // permute [nc0,nr0,nc/nc0,nr/nr0]
13454
- // substitute [nc0,nr0,ncr,nrr]
13455
- // reshape [nc0*nr0,ncr*nrr,1,1]
13456
- // transpose [ncr*nrr,nc0*nr0,1,1]
13457
- // sum rows [1,nc0*nr0,1,1]
13458
- // transpose [nc0*nr0,1,1]
13459
- // reshape [nc0,nr0,1,1] reshape_1d or reshape_2d
13460
- // add to src0->grad
13461
-
13462
- int64_t ne[4] = {nc0,ncr,nr0,nrr};
13463
-
13464
- struct ggml_tensor* F00 = tensor->grad;
13465
- struct ggml_tensor* F01 = ggml_reshape (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne));
13466
- struct ggml_tensor* F02 = ggml_permute (ctx, F01, 0,2,1,3);
13467
- struct ggml_tensor* F03 = ggml_cont (ctx, F02);
13468
- struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr);
13469
- struct ggml_tensor* F05 = ggml_transpose (ctx, F04);
13470
- struct ggml_tensor* F06 = ggml_cont (ctx, F05);
13471
- struct ggml_tensor* F07 = ggml_sum_rows (ctx, F06);
13472
- struct ggml_tensor* F08 = ggml_transpose (ctx, F07);
13473
- struct ggml_tensor* F09 = ggml_cont (ctx, F08);
13474
- struct ggml_tensor* F10 = ggml_reshape (ctx, F09, src0->grad);
13475
-
13476
- src0->grad =
13477
- ggml_add_impl(ctx,
13478
- src0->grad,
13479
- F10,
13480
- inplace);
15328
+ src0->grad = ggml_add_impl(ctx,
15329
+ src0->grad,
15330
+ ggml_repeat_back(ctx, tensor->grad, src0->grad),
15331
+ inplace);
15332
+ }
15333
+ } break;
15334
+ case GGML_OP_REPEAT_BACK:
15335
+ {
15336
+ if (src0->grad) {
15337
+ // TODO: test this
15338
+ src0->grad = ggml_add_impl(ctx,
15339
+ src0->grad,
15340
+ ggml_repeat(ctx, tensor->grad, src0->grad),
15341
+ inplace);
13481
15342
  }
13482
15343
  } break;
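ggml_repeat_back is used here as the adjoint of ggml_repeat: if the forward tiles a small tensor r times, every element of the small tensor feeds r outputs, so its gradient is the sum of the incoming gradient over all tiles; conversely, the backward of GGML_OP_REPEAT_BACK above is a plain ggml_repeat. A 1-D sketch of that adjoint relationship (illustrative, not the ggml implementation):

#include <stdio.h>

// forward: tile x (length n) r times into y (length n*r)
static void repeat_f32(int n, int r, const float * x, float * y) {
    for (int t = 0; t < r; ++t)
        for (int i = 0; i < n; ++i)
            y[t*n + i] = x[i];
}

// backward: each x[i] contributed to r outputs, so sum dy over the tiles
static void repeat_back_f32(int n, int r, const float * dy, float * dx) {
    for (int i = 0; i < n; ++i) dx[i] = 0.0f;
    for (int t = 0; t < r; ++t)
        for (int i = 0; i < n; ++i)
            dx[i] += dy[t*n + i];
}

int main(void) {
    const float x[2]  = {1.0f, 2.0f};
    const float dy[6] = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};
    float y[6], dx[2];

    repeat_f32     (2, 3, x,  y);   // y  = [1,2,1,2,1,2]
    repeat_back_f32(2, 3, dy, dx);  // dx = [0.1+0.3+0.5, 0.2+0.4+0.6]

    printf("y  = [%.0f %.0f %.0f %.0f %.0f %.0f]\n", y[0], y[1], y[2], y[3], y[4], y[5]);
    printf("dx = [%.1f, %.1f]\n", dx[0], dx[1]);
    return 0;
}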
13483
15344
  case GGML_OP_ABS:
@@ -13525,6 +15386,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13525
15386
  {
13526
15387
  GGML_ASSERT(false); // TODO: not implemented
13527
15388
  } break;
15389
+ case GGML_OP_GELU_QUICK:
15390
+ {
15391
+ GGML_ASSERT(false); // TODO: not implemented
15392
+ } break;
13528
15393
  case GGML_OP_ALIBI:
13529
15394
  {
13530
15395
  GGML_ASSERT(false); // TODO: not implemented
@@ -13584,38 +15449,37 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13584
15449
 
13585
15450
  // necessary for llama
13586
15451
  if (src0->grad) {
13587
- // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
13588
15452
  src0->grad =
13589
15453
  ggml_add_impl(ctx,
13590
15454
  src0->grad,
13591
- // ds0 = dt.dot(s1.T)
13592
- // ggml_out_prod(ctx, // [n,m]
13593
- // src1, // [n,p]
13594
- // tensor->grad), // [m,p]
13595
- // for now just using A*B==(B.T*A.T).T
13596
- ggml_cont(ctx, // [n,m]
13597
- ggml_transpose(ctx, // [n,m]
13598
- ggml_mul_mat(ctx, // [m,n]
13599
- ggml_cont(ctx, // [p,m]
13600
- ggml_transpose(ctx, // [p,m]
13601
- tensor->grad)), // [m,p]
13602
- ggml_cont(ctx, // [p,n]
13603
- ggml_transpose(ctx, // [p,n]
13604
- src1))))), // [n,p]
15455
+ ggml_out_prod(ctx, // [n,m]
15456
+ src1, // [n,p]
15457
+ tensor->grad), // [m,p]
13605
15458
  inplace);
13606
15459
  }
13607
15460
  if (src1->grad) {
13608
15461
  src1->grad =
13609
15462
  ggml_add_impl(ctx,
13610
15463
  src1->grad,
13611
- // ds1 = s0.T.dot(dt):
13612
- ggml_mul_mat(ctx, // [n,p]
13613
- ggml_cont(ctx, // [m,n]
13614
- ggml_transpose(ctx, src0)), // [m,n]
13615
- tensor->grad), // [m,p]
15464
+ // ggml_mul_mat(ctx, // [n,p]
15465
+ // ggml_cont(ctx, // [m,n]
15466
+ // ggml_transpose(ctx, src0)), // [m,n]
15467
+ // tensor->grad), // [m,p]
15468
+
15469
+ // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
15470
+ // // avoid transpose of src0, rather transpose smaller tensor->grad
15471
+ // // and then use ggml_out_prod
15472
+ ggml_out_prod(ctx, // [n,p]
15473
+ src0, // [n,m]
15474
+ ggml_transpose(ctx, // [p,m]
15475
+ tensor->grad)), // [m,p]
13616
15476
  inplace);
13617
15477
  }
13618
15478
  } break;
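Reading the [cols, rows] shape annotations in the comment above (src0 as m rows of length n, src1 as p rows of length n, dst and its gradient G as p rows of length m), both input gradients of mul_mat are accumulations over the shared row index:

D_{rc} = \sum_{k} S^{0}_{ck}\,S^{1}_{rk}
\;\Longrightarrow\;
\frac{\partial L}{\partial S^{0}_{ck}} = \sum_{r} G_{rc}\,S^{1}_{rk},
\qquad
\frac{\partial L}{\partial S^{1}_{rk}} = \sum_{c} G_{rc}\,S^{0}_{ck}.

Per the shape comments, this is what the two ggml_out_prod calls compute directly, so neither the large src0 nor src1 has to be transposed and copied; only the smaller tensor->grad is viewed transposed.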
15479
+ case GGML_OP_OUT_PROD:
15480
+ {
15481
+ GGML_ASSERT(false); // TODO: not implemented
15482
+ } break;
13619
15483
  case GGML_OP_SCALE:
13620
15484
  {
13621
15485
  // necessary for llama
@@ -13717,7 +15581,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13717
15581
  // necessary for llama
13718
15582
  if (src0->grad) {
13719
15583
  size_t offset;
13720
- memcpy(&offset, tensor->padding, sizeof(offset));
15584
+
15585
+ GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0]));
15586
+ memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
13721
15587
 
13722
15588
  size_t nb1 = tensor->nb[1];
13723
15589
  size_t nb2 = tensor->nb[2];
@@ -13744,10 +15610,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13744
15610
  {
13745
15611
  // necessary for llama
13746
15612
  if (src0->grad) {
13747
- int axis0 = tensor->padding[0] & 0x3;
13748
- int axis1 = tensor->padding[1] & 0x3;
13749
- int axis2 = tensor->padding[2] & 0x3;
13750
- int axis3 = tensor->padding[3] & 0x3;
15613
+ int32_t * axes = (int32_t *) tensor->opt[0]->data;
15614
+ int axis0 = axes[0] & 0x3;
15615
+ int axis1 = axes[1] & 0x3;
15616
+ int axis2 = axes[2] & 0x3;
15617
+ int axis3 = axes[3] & 0x3;
13751
15618
  int axes_backward[4] = {0,0,0,0};
13752
15619
  axes_backward[axis0] = 0;
13753
15620
  axes_backward[axis1] = 1;
@@ -13831,50 +15698,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13831
15698
  {
13832
15699
  // necessary for llama
13833
15700
  if (src0->grad) {
13834
- // y = softmax(x)
13835
- //
13836
- // Jii = yi - yi*yi
13837
- // Jij = -yi*yj
13838
- // J = diag(y)-y.*y
13839
- // dx = J * dy
13840
- // dxk = sum(Jkj * dyk)
13841
-
13842
- int64_t ne2[4] = {
13843
- tensor->ne[0],
13844
- 1,
13845
- tensor->ne[1]*tensor->ne[2],
13846
- tensor->ne[3]
13847
- };
13848
- struct ggml_tensor * tensor2 = ggml_cont(ctx,
13849
- ggml_reshape_4d(ctx,
13850
- ggml_cont(ctx, tensor),
13851
- ne2[0], ne2[1], ne2[2], ne2[3]));
13852
-
13853
- struct ggml_tensor * grad2 = ggml_cont(ctx,
13854
- ggml_reshape_4d(ctx,
13855
- ggml_cont(ctx, tensor->grad),
13856
- ne2[0], ne2[1], ne2[2], ne2[3]));
13857
-
13858
- struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
13859
- ggml_permute(ctx, // [1,ne0,ne1*ne2,ne3]
13860
- tensor2, // [ne0,1,ne1*ne2,ne3]
13861
- 1, 0, 2, 3));
13862
-
13863
15701
  src0->grad =
13864
- ggml_add_impl(ctx,
13865
- src0->grad, // [ne0,ne1,ne2,ne3]
13866
- ggml_reshape(ctx, // [ne0,ne1,ne2,ne3]
13867
- ggml_mul_mat(ctx, // [ne0,1,ne1*ne2,ne3]
13868
- ggml_sub(ctx, // [ne0,ne0,ne1*ne2,ne3]
13869
- ggml_diag(ctx, // [ne0,ne0,ne1*ne2,ne3]
13870
- tensor2), // [ne0,1,ne1*ne2,ne3]
13871
- ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3]
13872
- tensor2_t, // [1,ne0,ne1*ne2,ne3]
13873
- tensor2_t)), // [1,ne0,ne1*ne2,ne3]
13874
- grad2), // [ne0,1,ne1*ne2,ne3]
13875
- src0->grad),
13876
- inplace);
15702
+ ggml_add_impl(ctx, src0->grad,
15703
+ ggml_soft_max_back(ctx, tensor->grad, tensor),
15704
+ inplace);
13877
15705
  }
15706
+
15707
+ } break;
15708
+ case GGML_OP_SOFT_MAX_BACK:
15709
+ {
15710
+ GGML_ASSERT(false); // TODO: not implemented
13878
15711
  } break;
13879
15712
  case GGML_OP_ROPE:
13880
15713
  {
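
The soft max change above drops the removed code's dense Jacobian (diag(y) - y y^T) times gradient and delegates to ggml_soft_max_back. The underlying identity is that the Jacobian product collapses to dx_i = y_i * (dy_i - sum_j y_j dy_j), which a fused kernel can evaluate in one pass without the ne0-by-ne0 intermediate. A standalone numerical check of that identity with hypothetical values:

#include <math.h>
#include <stdio.h>

#define N 4

int main(void) {
    const float x[N]  = {0.5f, -1.0f, 2.0f, 0.0f};
    const float dy[N] = {1.0f,  0.0f, -0.5f, 2.0f};
    float y[N];

    // forward softmax
    float sum = 0.0f;
    for (int i = 0; i < N; ++i) { y[i] = expf(x[i]); sum += y[i]; }
    for (int i = 0; i < N; ++i) { y[i] /= sum; }

    // dense form: dx = (diag(y) - y y^T) dy, as the removed code constructed
    float dx_dense[N] = {0};
    for (int i = 0; i < N; ++i) {
        for (int j = 0; j < N; ++j) {
            const float Jij = (i == j ? y[i] : 0.0f) - y[i]*y[j];
            dx_dense[i] += Jij * dy[j];
        }
    }

    // fused form: dx_i = y_i * (dy_i - dot(y, dy))
    float dot = 0.0f;
    for (int i = 0; i < N; ++i) { dot += y[i]*dy[i]; }
    for (int i = 0; i < N; ++i) {
        printf("i=%d dense=%.6f fused=%.6f\n", i, dx_dense[i], y[i]*(dy[i] - dot));
    }
    return 0;
}
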
@@ -13919,27 +15752,206 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13919
15752
  // noop
13920
15753
  }
13921
15754
  } break;
13922
- case GGML_OP_CONV_1D_1S:
15755
+ case GGML_OP_CONV_1D_S1_PH:
15756
+ {
15757
+ GGML_ASSERT(false); // TODO: not implemented
15758
+ } break;
15759
+ case GGML_OP_CONV_1D_S2_PH:
13923
15760
  {
13924
15761
  GGML_ASSERT(false); // TODO: not implemented
13925
15762
  } break;
13926
- case GGML_OP_CONV_1D_2S:
15763
+ case GGML_OP_CONV_2D_SK_P0:
13927
15764
  {
13928
15765
  GGML_ASSERT(false); // TODO: not implemented
13929
15766
  } break;
13930
15767
  case GGML_OP_FLASH_ATTN:
13931
15768
  {
13932
- GGML_ASSERT(false); // not supported
15769
+ struct ggml_tensor * flash_grad = NULL;
15770
+ if (src0->grad || src1->grad || tensor->opt[0]->grad) {
15771
+ int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
15772
+ GGML_ASSERT(t == 0 || t == 1);
15773
+ bool masked = t != 0;
15774
+ flash_grad =
15775
+ ggml_flash_attn_back(ctx,
15776
+ src0,
15777
+ src1,
15778
+ tensor->opt[0],
15779
+ tensor->grad,
15780
+ masked);
15781
+ }
15782
+
15783
+ if (src0->grad) {
15784
+ struct ggml_tensor * grad_q = NULL;
15785
+ const size_t nb0 = flash_grad->nb[0];
15786
+ const size_t offset = 0;
15787
+ switch(src0->n_dims) {
15788
+ case 2:
15789
+ {
15790
+ grad_q = ggml_view_2d(ctx,
15791
+ flash_grad,
15792
+ src0->ne[0],
15793
+ src0->ne[1],
15794
+ nb0*src0->ne[0],
15795
+ offset);
15796
+ } break;
15797
+ case 3:
15798
+ {
15799
+ grad_q = ggml_view_3d(ctx,
15800
+ flash_grad,
15801
+ src0->ne[0],
15802
+ src0->ne[1],
15803
+ src0->ne[2],
15804
+ nb0*src0->ne[0],
15805
+ nb0*src0->ne[0]*src0->ne[1],
15806
+ offset);
15807
+ } break;
15808
+ case 4:
15809
+ {
15810
+ grad_q = ggml_view_4d(ctx,
15811
+ flash_grad,
15812
+ src0->ne[0],
15813
+ src0->ne[1],
15814
+ src0->ne[2],
15815
+ src0->ne[3],
15816
+ nb0*src0->ne[0],
15817
+ nb0*src0->ne[0]*src0->ne[1],
15818
+ nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
15819
+ offset);
15820
+ } break;
15821
+ }
15822
+
15823
+ src0->grad = ggml_add_impl(ctx,
15824
+ src0->grad,
15825
+ grad_q,
15826
+ inplace);
15827
+ }
15828
+
15829
+ if (src1->grad) {
15830
+ struct ggml_tensor * grad_k = NULL;
15831
+ const size_t nb0 = flash_grad->nb[0];
15832
+ const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
15833
+ switch(src1->n_dims) {
15834
+ case 2:
15835
+ {
15836
+ grad_k = ggml_view_2d(ctx,
15837
+ flash_grad,
15838
+ src1->ne[0],
15839
+ src1->ne[1],
15840
+ nb0*src1->ne[0],
15841
+ offset);
15842
+ } break;
15843
+ case 3:
15844
+ {
15845
+ grad_k = ggml_view_3d(ctx,
15846
+ flash_grad,
15847
+ src1->ne[0],
15848
+ src1->ne[1],
15849
+ src1->ne[2],
15850
+ nb0*src1->ne[0],
15851
+ nb0*src1->ne[0]*src1->ne[1],
15852
+ offset);
15853
+ } break;
15854
+ case 4:
15855
+ {
15856
+ grad_k = ggml_view_4d(ctx,
15857
+ flash_grad,
15858
+ src1->ne[0],
15859
+ src1->ne[1],
15860
+ src1->ne[2],
15861
+ src1->ne[3],
15862
+ nb0*src1->ne[0],
15863
+ nb0*src1->ne[0]*src1->ne[1],
15864
+ nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
15865
+ offset);
15866
+ } break;
15867
+ }
15868
+
15869
+ src1->grad = ggml_add_impl(ctx,
15870
+ src1->grad,
15871
+ grad_k,
15872
+ inplace);
15873
+ }
15874
+
15875
+ struct ggml_tensor * opt0 = tensor->opt[0];
15876
+
15877
+ if (opt0->grad) {
15878
+ struct ggml_tensor * grad_v = NULL;
15879
+ const size_t nb0 = flash_grad->nb[0];
15880
+ const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
15881
+ + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
15882
+ switch(opt0->n_dims) {
15883
+ case 2:
15884
+ {
15885
+ grad_v = ggml_view_2d(ctx,
15886
+ flash_grad,
15887
+ opt0->ne[0],
15888
+ opt0->ne[1],
15889
+ nb0*opt0->ne[0],
15890
+ offset);
15891
+ } break;
15892
+ case 3:
15893
+ {
15894
+ grad_v = ggml_view_3d(ctx,
15895
+ flash_grad,
15896
+ opt0->ne[0],
15897
+ opt0->ne[1],
15898
+ opt0->ne[2],
15899
+ nb0*opt0->ne[0],
15900
+ nb0*opt0->ne[0]*opt0->ne[1],
15901
+ offset);
15902
+ } break;
15903
+ case 4:
15904
+ {
15905
+ grad_v = ggml_view_4d(ctx,
15906
+ flash_grad,
15907
+ opt0->ne[0],
15908
+ opt0->ne[1],
15909
+ opt0->ne[2],
15910
+ opt0->ne[3],
15911
+ nb0*opt0->ne[0],
15912
+ nb0*opt0->ne[0]*opt0->ne[1],
15913
+ nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
15914
+ offset);
15915
+ } break;
15916
+ }
15917
+
15918
+ opt0->grad = ggml_add_impl(ctx,
15919
+ opt0->grad,
15920
+ grad_v,
15921
+ inplace);
15922
+ }
13933
15923
  } break;
13934
15924
  case GGML_OP_FLASH_FF:
13935
15925
  {
13936
15926
  GGML_ASSERT(false); // not supported
13937
15927
  } break;
15928
+ case GGML_OP_FLASH_ATTN_BACK:
15929
+ {
15930
+ GGML_ASSERT(false); // not supported
15931
+ } break;
15932
+ case GGML_OP_WIN_PART:
15933
+ case GGML_OP_WIN_UNPART:
13938
15934
  case GGML_OP_MAP_UNARY:
13939
15935
  case GGML_OP_MAP_BINARY:
13940
15936
  {
13941
15937
  GGML_ASSERT(false); // not supported
13942
15938
  } break;
15939
+ case GGML_OP_CROSS_ENTROPY_LOSS:
15940
+ {
15941
+ if (src0->grad) {
15942
+ src0->grad = ggml_add_impl(ctx,
15943
+ src0->grad,
15944
+ ggml_cross_entropy_loss_back(ctx,
15945
+ src0,
15946
+ src1,
15947
+ tensor->grad),
15948
+ inplace);
15949
+ }
15950
+ } break;
15951
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
15952
+ {
15953
+ GGML_ASSERT(false); // not supported
15954
+ } break;
13943
15955
  case GGML_OP_NONE:
13944
15956
  {
13945
15957
  // nop
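
In the flash attention backward above, ggml_flash_attn_back packs the gradients for q, k and v into one result tensor, and the code carves per-input views out of it at byte offsets 0, nb0*nelem(q) and nb0*(nelem(q)+nelem(k)), choosing the view rank to match each source's n_dims. A standalone sketch of that offset arithmetic over a flat buffer, with hypothetical element counts:

#include <stdio.h>
#include <stdlib.h>

int main(void) {
    const size_t nb0 = sizeof(float);
    const size_t nq = 8*16, nk = 8*16, nv = 16*8;   // element counts of q, k, v (hypothetical)

    float *flash_grad = calloc(nq + nk + nv, nb0);  // stands in for the packed result
    if (!flash_grad) return 1;

    const float *grad_q = (const float *)((const char *)flash_grad + 0);
    const float *grad_k = (const float *)((const char *)flash_grad + nb0*nq);
    const float *grad_v = (const float *)((const char *)flash_grad + nb0*nq + nb0*nk);

    printf("dq at +%zu bytes, dk at +%zu, dv at +%zu\n",
           (size_t)((const char *)grad_q - (const char *)flash_grad),
           (size_t)((const char *)grad_k - (const char *)flash_grad),
           (size_t)((const char *)grad_v - (const char *)flash_grad));

    free(flash_grad);
    return 0;
}
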
@@ -14316,6 +16328,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14316
16328
  case GGML_OP_SUM_ROWS:
14317
16329
  case GGML_OP_MEAN:
14318
16330
  case GGML_OP_REPEAT:
16331
+ case GGML_OP_REPEAT_BACK:
14319
16332
  case GGML_OP_ABS:
14320
16333
  case GGML_OP_SGN:
14321
16334
  case GGML_OP_NEG:
@@ -14326,6 +16339,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14326
16339
  } break;
14327
16340
  case GGML_OP_MUL:
14328
16341
  case GGML_OP_GELU:
16342
+ case GGML_OP_GELU_QUICK:
14329
16343
  case GGML_OP_SILU:
14330
16344
  case GGML_OP_SILU_BACK:
14331
16345
  case GGML_OP_NORM:
@@ -14335,6 +16349,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14335
16349
  node->n_tasks = n_threads;
14336
16350
  } break;
14337
16351
  case GGML_OP_MUL_MAT:
16352
+ case GGML_OP_OUT_PROD:
14338
16353
  {
14339
16354
  node->n_tasks = n_threads;
14340
16355
 
@@ -14417,6 +16432,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14417
16432
  } break;
14418
16433
  case GGML_OP_DIAG_MASK_INF:
14419
16434
  case GGML_OP_SOFT_MAX:
16435
+ case GGML_OP_SOFT_MAX_BACK:
14420
16436
  case GGML_OP_ROPE:
14421
16437
  case GGML_OP_ROPE_BACK:
14422
16438
  {
@@ -14430,8 +16446,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14430
16446
  {
14431
16447
  node->n_tasks = 1; //TODO
14432
16448
  } break;
14433
- case GGML_OP_CONV_1D_1S:
14434
- case GGML_OP_CONV_1D_2S:
16449
+ case GGML_OP_CONV_1D_S1_PH:
16450
+ case GGML_OP_CONV_1D_S2_PH:
14435
16451
  {
14436
16452
  node->n_tasks = n_threads;
14437
16453
 
@@ -14458,6 +16474,41 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14458
16474
  GGML_ASSERT(false);
14459
16475
  }
14460
16476
 
16477
+ work_size = MAX(work_size, cur);
16478
+ } break;
16479
+ case GGML_OP_CONV_2D_SK_P0:
16480
+ {
16481
+ node->n_tasks = n_threads;
16482
+
16483
+ GGML_ASSERT(node->src1->ne[3] == 1);
16484
+
16485
+ const int64_t ne00 = node->src0->ne[0]; // W
16486
+ const int64_t ne01 = node->src0->ne[1]; // H
16487
+ const int64_t ne02 = node->src0->ne[2]; // C
16488
+ const int64_t ne03 = node->src0->ne[3]; // N
16489
+
16490
+ const int64_t ne10 = node->src1->ne[0]; // W
16491
+ const int64_t ne11 = node->src1->ne[1]; // H
16492
+ const int64_t ne12 = node->src1->ne[2]; // C
16493
+
16494
+ const int64_t nk = ne00*ne01;
16495
+
16496
+ UNUSED(ne02);
16497
+ UNUSED(ne03);
16498
+ UNUSED(nk);
16499
+
16500
+ size_t cur = 0;
16501
+
16502
+ if (node->src0->type == GGML_TYPE_F16 &&
16503
+ node->src1->type == GGML_TYPE_F32) {
16504
+ cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
16505
+ } else if (node->src0->type == GGML_TYPE_F32 &&
16506
+ node->src1->type == GGML_TYPE_F32) {
16507
+ cur = sizeof(float)* (ne10*ne11*ne12);
16508
+ } else {
16509
+ GGML_ASSERT(false);
16510
+ }
16511
+
14461
16512
  work_size = MAX(work_size, cur);
14462
16513
  } break;
14463
16514
  case GGML_OP_FLASH_ATTN:
@@ -14498,11 +16549,50 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14498
16549
 
14499
16550
  work_size = MAX(work_size, cur);
14500
16551
  } break;
16552
+ case GGML_OP_FLASH_ATTN_BACK:
16553
+ {
16554
+ node->n_tasks = n_threads;
16555
+
16556
+ size_t cur = 0;
16557
+
16558
+ const int64_t D = node->src0->ne[0];
16559
+ const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
16560
+ const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
16561
+ if (node->src1->type == GGML_TYPE_F32) {
16562
+ cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
16563
+ cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
16564
+ }
16565
+
16566
+ if (node->src1->type == GGML_TYPE_F16) {
16567
+ cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
16568
+ cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
16569
+ }
16570
+
16571
+ work_size = MAX(work_size, cur);
16572
+ } break;
16573
+ case GGML_OP_WIN_PART:
16574
+ case GGML_OP_WIN_UNPART:
14501
16575
  case GGML_OP_MAP_UNARY:
14502
16576
  case GGML_OP_MAP_BINARY:
14503
16577
  {
14504
16578
  node->n_tasks = 1;
14505
16579
  } break;
16580
+ case GGML_OP_CROSS_ENTROPY_LOSS:
16581
+ {
16582
+ node->n_tasks = n_threads;
16583
+
16584
+ size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
16585
+
16586
+ work_size = MAX(work_size, cur);
16587
+ } break;
16588
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
16589
+ {
16590
+ node->n_tasks = n_threads;
16591
+
16592
+ size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
16593
+
16594
+ work_size = MAX(work_size, cur);
16595
+ } break;
14506
16596
  case GGML_OP_NONE:
14507
16597
  {
14508
16598
  node->n_tasks = 1;
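
On the GGML_OP_FLASH_ATTN_BACK scratch sizing above: mxDn is the larger of D and ne11 rounded up to GGML_SOFT_MAX_UNROLL, already doubled for the S and SM rows, and the work buffer then sums two such blocks per task, which the code's own comments flag as an overestimate. A standalone sketch of the arithmetic with hypothetical dimensions:

#include <stdint.h>
#include <stdio.h>

static int64_t up_to_multiple(int64_t n, int64_t m) {
    // round n up to a multiple of m (the role ggml_up plays above)
    return ((n + m - 1)/m)*m;
}

int main(void) {
    const int64_t D = 128, ne11 = 35, unroll = 4, n_tasks = 8;   // hypothetical
    const int64_t ne11_up = up_to_multiple(ne11, unroll);
    const int64_t mxDn = (D > ne11_up ? D : ne11_up) * 2;        // *2 for the S and SM rows

    const int64_t cur = (int64_t)sizeof(float)*mxDn*n_tasks      // first scratch block
                      + (int64_t)sizeof(float)*mxDn*n_tasks;     // second block (the noted x2 overestimate)

    printf("work buffer: %lld bytes\n", (long long)cur);
    return 0;
}
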
@@ -15014,16 +17104,20 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
15014
17104
 
15015
17105
  if (!*ctx_data) {
15016
17106
  fprintf(stderr, "%s: failed to create ggml context\n", __func__);
17107
+ fclose(fin);
15017
17108
  return result;
15018
17109
  }
15019
17110
  }
15020
17111
 
15021
17112
  data = ggml_new_tensor_1d(*ctx_data, GGML_TYPE_I8, fsize);
15022
17113
 
15023
- const size_t ret = fread(data->data, sizeof(char), fsize, fin);
15024
- if (ret != fsize) {
15025
- fprintf(stderr, "%s: failed to read %s\n", __func__, fname);
15026
- return result;
17114
+ {
17115
+ const size_t ret = fread(data->data, sizeof(char), fsize, fin);
17116
+ if (ret != fsize) {
17117
+ fprintf(stderr, "%s: failed to read %s\n", __func__, fname);
17118
+ fclose(fin);
17119
+ return result;
17120
+ }
15027
17121
  }
15028
17122
 
15029
17123
  fclose(fin);
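
The graph import fix above closes the file on both early-return paths, so a failed context creation or short fread no longer leaks the descriptor. The general pattern, sketched standalone with a hypothetical file name:

#include <stdio.h>

int main(void) {
    FILE *fin = fopen("graph.ggml", "rb");   // hypothetical input file
    if (!fin) {
        return 1;
    }

    char buf[16];
    const size_t ret = fread(buf, 1, sizeof(buf), fin);
    if (ret != sizeof(buf)) {
        fprintf(stderr, "short read\n");
        fclose(fin);                         // close before every early return
        return 1;
    }

    fclose(fin);
    return 0;
}
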
@@ -15478,6 +17572,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
15478
17572
 
15479
17573
  static enum ggml_opt_result ggml_opt_adam(
15480
17574
  struct ggml_context * ctx,
17575
+ struct ggml_opt_context * opt,
15481
17576
  struct ggml_opt_params params,
15482
17577
  struct ggml_tensor * f,
15483
17578
  struct ggml_cgraph * gf,
@@ -15503,25 +17598,29 @@ static enum ggml_opt_result ggml_opt_adam(
15503
17598
  }
15504
17599
  }
15505
17600
 
17601
+ if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) {
17602
+ int iter = opt->iter;
17603
+ ggml_opt_init(opt->ctx, opt, params, nx);
17604
+ opt->iter = iter;
17605
+ }
17606
+
15506
17607
  // constants
15507
- const float alpha = params.adam.alpha;
17608
+ const float sched = params.adam.sched;
17609
+ const float decay = params.adam.decay * sched;
17610
+ const float alpha = params.adam.alpha * sched;
15508
17611
  const float beta1 = params.adam.beta1;
15509
17612
  const float beta2 = params.adam.beta2;
15510
17613
  const float eps = params.adam.eps;
15511
17614
 
15512
- float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // view of the parameters
15513
- float * g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient
15514
- float * g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient squared
15515
- float * m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment
15516
- float * v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment
15517
- float * mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment hat
15518
- float * vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment hat
17615
+ float * x = opt->adam.x->data; // view of the parameters
17616
+ float * g1 = opt->adam.g1->data; // gradient
17617
+ float * g2 = opt->adam.g2->data; // gradient squared
17618
+ float * m = opt->adam.m->data; // first moment
17619
+ float * v = opt->adam.v->data; // second moment
17620
+ float * mh = opt->adam.mh->data; // first moment hat
17621
+ float * vh = opt->adam.vh->data; // second moment hat
15519
17622
 
15520
- float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
15521
-
15522
- // initialize
15523
- ggml_vec_set_f32(nx, m, 0.0f);
15524
- ggml_vec_set_f32(nx, v, 0.0f);
17623
+ float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
15525
17624
 
15526
17625
  // update view
15527
17626
  ggml_opt_get_params(np, ps, x);
@@ -15531,16 +17630,27 @@ static enum ggml_opt_result ggml_opt_adam(
15531
17630
  ggml_set_f32 (f->grad, 1.0f);
15532
17631
  ggml_graph_compute(ctx, gb);
15533
17632
 
15534
- float fx_prev = ggml_get_f32_1d(f, 0);
17633
+ opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
17634
+ opt->adam.fx_best = opt->adam.fx_prev;
15535
17635
  if (pf) {
15536
- pf[0] = fx_prev;
17636
+ pf[opt->iter % params.past] = opt->adam.fx_prev;
15537
17637
  }
15538
17638
 
15539
- int n_no_improvement = 0;
15540
- float fx_best = fx_prev;
17639
+ // initialize
17640
+ if (opt->just_initialized) {
17641
+ opt->adam.n_no_improvement = 0;
17642
+ opt->just_initialized = false;
17643
+ }
17644
+
17645
+ float * fx_best = &opt->adam.fx_best;
17646
+ float * fx_prev = &opt->adam.fx_prev;
17647
+ int * n_no_improvement = &opt->adam.n_no_improvement;
17648
+
17649
+ int iter0 = opt->iter;
15541
17650
 
15542
17651
  // run the optimizer
15543
17652
  for (int t = 0; t < params.adam.n_iter; ++t) {
17653
+ opt->iter = iter0 + t + 1;
15544
17654
  GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
15545
17655
 
15546
17656
  GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0));
@@ -15574,17 +17684,22 @@ static enum ggml_opt_result ggml_opt_adam(
15574
17684
 
15575
17685
  // m^hat = m_t / (1 - beta1^t)
15576
17686
  // v^hat = v_t / (1 - beta2^t)
15577
- // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
17687
+ // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
17688
+ // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
17689
+ // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
17690
+ // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
17691
+ // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
15578
17692
  ggml_vec_cpy_f32 (nx, mh, m);
15579
17693
  ggml_vec_cpy_f32 (nx, vh, v);
15580
17694
 
15581
- ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1)));
15582
- ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, t + 1)));
17695
+ ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter)));
17696
+ ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, opt->iter)));
15583
17697
 
15584
17698
  ggml_vec_sqrt_f32 (nx, vh, vh);
15585
17699
  ggml_vec_acc1_f32 (nx, vh, eps);
15586
17700
 
15587
17701
  ggml_vec_div_f32 (nx, mh, mh, vh);
17702
+ ggml_vec_scale_f32(nx, x, 1.0f - decay);
15588
17703
  ggml_vec_sub_f32 (nx, x, x, mh);
15589
17704
 
15590
17705
  // update the parameters
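
With decay = params.adam.decay*sched and alpha = params.adam.alpha*sched folded in earlier, one parameter update in the hunk above is x = x*(1 - decay) - alpha*m^hat/(sqrt(v^hat) + eps), which is exactly the last line of the derivation in the comments. A scalar sketch of a single step with hypothetical hyper-parameters:

#include <math.h>
#include <stdio.h>

int main(void) {
    const float sched = 1.0f, decay0 = 0.001f, alpha0 = 0.001f;   // hypothetical
    const float beta1 = 0.9f, beta2 = 0.999f, eps = 1e-8f;
    const int   iter  = 1;

    const float decay = decay0*sched;
    const float alpha = alpha0*sched;

    float x = 0.5f, g = 0.2f, m = 0.0f, v = 0.0f;

    // moment updates
    m = beta1*m + (1.0f - beta1)*g;
    v = beta2*v + (1.0f - beta2)*g*g;

    // bias-corrected step, matching the vector code above
    const float mh = alpha*m/(1.0f - powf(beta1, (float)iter));
    const float vh = sqrtf(v/(1.0f - powf(beta2, (float)iter))) + eps;

    x = x*(1.0f - decay) - mh/vh;

    printf("x after one step: %f\n", x);
    return 0;
}
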
@@ -15598,7 +17713,7 @@ static enum ggml_opt_result ggml_opt_adam(
15598
17713
  const float fx = ggml_get_f32_1d(f, 0);
15599
17714
 
15600
17715
  // check convergence
15601
- if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) {
17716
+ if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
15602
17717
  GGML_PRINT_DEBUG("converged\n");
15603
17718
 
15604
17719
  return GGML_OPT_OK;
@@ -15607,32 +17722,32 @@ static enum ggml_opt_result ggml_opt_adam(
15607
17722
  // delta-based convergence test
15608
17723
  if (pf != NULL) {
15609
17724
  // need at least params.past iterations to start checking for convergence
15610
- if (params.past <= t) {
15611
- const float rate = (pf[t%params.past] - fx)/fx;
17725
+ if (params.past <= iter0 + t) {
17726
+ const float rate = (pf[(iter0 + t)%params.past] - fx)/fx;
15612
17727
 
15613
17728
  if (fabsf(rate) < params.delta) {
15614
17729
  return GGML_OPT_OK;
15615
17730
  }
15616
17731
  }
15617
17732
 
15618
- pf[t%params.past] = fx;
17733
+ pf[(iter0 + t)%params.past] = fx;
15619
17734
  }
15620
17735
 
15621
17736
  // check for improvement
15622
17737
  if (params.max_no_improvement > 0) {
15623
- if (fx_best > fx) {
15624
- fx_best = fx;
15625
- n_no_improvement = 0;
17738
+ if (fx_best[0] > fx) {
17739
+ fx_best[0] = fx;
17740
+ n_no_improvement[0] = 0;
15626
17741
  } else {
15627
- ++n_no_improvement;
17742
+ ++n_no_improvement[0];
15628
17743
 
15629
- if (n_no_improvement >= params.max_no_improvement) {
17744
+ if (n_no_improvement[0] >= params.max_no_improvement) {
15630
17745
  return GGML_OPT_OK;
15631
17746
  }
15632
17747
  }
15633
17748
  }
15634
17749
 
15635
- fx_prev = fx;
17750
+ fx_prev[0] = fx;
15636
17751
 
15637
17752
  {
15638
17753
  const int64_t t_end_cpu = ggml_cycles();
@@ -15771,6 +17886,7 @@ static enum ggml_opt_result linesearch_backtracking(
15771
17886
 
15772
17887
  static enum ggml_opt_result ggml_opt_lbfgs(
15773
17888
  struct ggml_context * ctx,
17889
+ struct ggml_opt_context * opt,
15774
17890
  struct ggml_opt_params params,
15775
17891
  struct ggml_tensor * f,
15776
17892
  struct ggml_cgraph * gf,
@@ -15803,31 +17919,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15803
17919
  }
15804
17920
  }
15805
17921
 
15806
- float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current parameters
15807
- float * xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous parameters
15808
- float * g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current gradient
15809
- float * gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous gradient
15810
- float * d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // search direction
17922
+ if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) {
17923
+ int iter = opt->iter;
17924
+ ggml_opt_init(ctx, opt, params, nx);
17925
+ opt->iter = iter;
17926
+ }
17927
+
17928
+ float * x = opt->lbfgs.x->data; // current parameters
17929
+ float * xp = opt->lbfgs.xp->data; // previous parameters
17930
+ float * g = opt->lbfgs.g->data; // current gradient
17931
+ float * gp = opt->lbfgs.gp->data; // previous gradient
17932
+ float * d = opt->lbfgs.d->data; // search direction
15811
17933
 
15812
- float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
17934
+ float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
15813
17935
 
15814
17936
  float fx = 0.0f; // cost function value
15815
17937
  float xnorm = 0.0f; // ||x||
15816
17938
  float gnorm = 0.0f; // ||g||
15817
- float step = 0.0f;
15818
17939
 
15819
17940
  // initialize x from the graph nodes
15820
17941
  ggml_opt_get_params(np, ps, x);
15821
17942
 
15822
17943
  // the L-BFGS memory
15823
- struct ggml_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_lbfgs_iteration_data)*m);
15824
-
15825
- for (int i = 0; i < m; ++i) {
15826
- lm[i].alpha = 0.0f;
15827
- lm[i].ys = 0.0f;
15828
- lm[i].s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
15829
- lm[i].y = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
15830
- }
17944
+ float * lm_alpha = opt->lbfgs.lmal->data;
17945
+ float * lm_ys = opt->lbfgs.lmys->data;
17946
+ float * lm_s = opt->lbfgs.lms->data;
17947
+ float * lm_y = opt->lbfgs.lmy->data;
15831
17948
 
15832
17949
  // evaluate the function value and its gradient
15833
17950
  {
@@ -15842,12 +17959,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15842
17959
  fx = ggml_get_f32_1d(f, 0);
15843
17960
  }
15844
17961
 
15845
- if (pf) {
15846
- pf[0] = fx;
15847
- }
15848
-
15849
- float fx_best = fx;
15850
-
15851
17962
  // search direction = -gradient
15852
17963
  ggml_vec_neg_f32(nx, d, g);
15853
17964
 
@@ -15864,26 +17975,43 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15864
17975
  return GGML_OPT_OK;
15865
17976
  }
15866
17977
 
15867
- // initial step
15868
- ggml_vec_norm_inv_f32(nx, &step, d);
17978
+ if (opt->just_initialized) {
17979
+ if (pf) {
17980
+ pf[0] = fx;
17981
+ }
17982
+ opt->lbfgs.fx_best = fx;
17983
+
17984
+ // initial step
17985
+ ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d);
17986
+ opt->lbfgs.j = 0;
17987
+ opt->lbfgs.k = 1;
17988
+ opt->lbfgs.end = 0;
17989
+ opt->lbfgs.n_no_improvement = 0;
17990
+ opt->just_initialized = false;
17991
+ }
17992
+
17993
+ float * fx_best = &opt->lbfgs.fx_best;
17994
+ float * step = &opt->lbfgs.step;
17995
+ int * j = &opt->lbfgs.j;
17996
+ int * k = &opt->lbfgs.k;
17997
+ int * end = &opt->lbfgs.end;
17998
+ int * n_no_improvement = &opt->lbfgs.n_no_improvement;
15869
17999
 
15870
- int j = 0;
15871
- int k = 1;
15872
- int ls = 0;
15873
- int end = 0;
15874
- int bound = 0;
15875
- int n_no_improvement = 0;
18000
+ int ls = 0;
18001
+ int bound = 0;
15876
18002
 
15877
18003
  float ys = 0.0f;
15878
18004
  float yy = 0.0f;
15879
18005
  float beta = 0.0f;
15880
18006
 
18007
+ int it = 0;
18008
+
15881
18009
  while (true) {
15882
18010
  // store the current position and gradient vectors
15883
18011
  ggml_vec_cpy_f32(nx, xp, x);
15884
18012
  ggml_vec_cpy_f32(nx, gp, g);
15885
18013
 
15886
- ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps);
18014
+ ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
15887
18015
 
15888
18016
  if (ls < 0) {
15889
18017
  // linesearch failed - go back to the previous point and return
@@ -15909,32 +18037,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15909
18037
  // delta-based convergence test
15910
18038
  if (pf != NULL) {
15911
18039
  // need at least params.past iterations to start checking for convergence
15912
- if (params.past <= k) {
15913
- const float rate = (pf[k%params.past] - fx)/fx;
18040
+ if (params.past <= k[0]) {
18041
+ const float rate = (pf[k[0]%params.past] - fx)/fx;
15914
18042
 
15915
18043
  if (fabsf(rate) < params.delta) {
15916
18044
  return GGML_OPT_OK;
15917
18045
  }
15918
18046
  }
15919
18047
 
15920
- pf[k%params.past] = fx;
18048
+ pf[k[0]%params.past] = fx;
15921
18049
  }
15922
18050
 
15923
18051
  // check for improvement
15924
18052
  if (params.max_no_improvement > 0) {
15925
- if (fx < fx_best) {
15926
- fx_best = fx;
15927
- n_no_improvement = 0;
18053
+ if (fx < fx_best[0]) {
18054
+ fx_best[0] = fx;
18055
+ n_no_improvement[0] = 0;
15928
18056
  } else {
15929
- n_no_improvement++;
18057
+ n_no_improvement[0]++;
15930
18058
 
15931
- if (n_no_improvement >= params.max_no_improvement) {
18059
+ if (n_no_improvement[0] >= params.max_no_improvement) {
15932
18060
  return GGML_OPT_OK;
15933
18061
  }
15934
18062
  }
15935
18063
  }
15936
18064
 
15937
- if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) {
18065
+ if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) {
15938
18066
  // reached the maximum number of iterations
15939
18067
  return GGML_OPT_DID_NOT_CONVERGE;
15940
18068
  }
@@ -15943,50 +18071,51 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15943
18071
  // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
15944
18072
  // y_{k+1} = g_{k+1} - g_{k}.
15945
18073
  //
15946
- ggml_vec_sub_f32(nx, lm[end].s, x, xp);
15947
- ggml_vec_sub_f32(nx, lm[end].y, g, gp);
18074
+ ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp);
18075
+ ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp);
15948
18076
 
15949
18077
  // compute scalars ys and yy:
15950
18078
  // ys = y^t \cdot s -> 1 / \rho.
15951
18079
  // yy = y^t \cdot y.
15952
18080
  //
15953
- ggml_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s);
15954
- ggml_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y);
18081
+ ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]);
18082
+ ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
15955
18083
 
15956
- lm[end].ys = ys;
18084
+ lm_ys[end[0]] = ys;
15957
18085
 
15958
18086
  // find new search direction
15959
18087
  // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS
15960
18088
 
15961
- bound = (m <= k) ? m : k;
15962
- k++;
15963
- end = (end + 1)%m;
18089
+ bound = (m <= k[0]) ? m : k[0];
18090
+ k[0]++;
18091
+ it++;
18092
+ end[0] = (end[0] + 1)%m;
15964
18093
 
15965
18094
  // initialize search direction with -g
15966
18095
  ggml_vec_neg_f32(nx, d, g);
15967
18096
 
15968
- j = end;
18097
+ j[0] = end[0];
15969
18098
  for (int i = 0; i < bound; ++i) {
15970
- j = (j + m - 1) % m;
18099
+ j[0] = (j[0] + m - 1) % m;
15971
18100
  // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
15972
- ggml_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d);
15973
- lm[j].alpha /= lm[j].ys;
18101
+ ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d);
18102
+ lm_alpha[j[0]] /= lm_ys[j[0]];
15974
18103
  // q_{i} = q_{i+1} - \alpha_{i} y_{i}
15975
- ggml_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha);
18104
+ ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
15976
18105
  }
15977
18106
 
15978
18107
  ggml_vec_scale_f32(nx, d, ys/yy);
15979
18108
 
15980
18109
  for (int i = 0; i < bound; ++i) {
15981
18110
  // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
15982
- ggml_vec_dot_f32(nx, &beta, lm[j].y, d);
15983
- beta /= lm[j].ys;
18111
+ ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d);
18112
+ beta /= lm_ys[j[0]];
15984
18113
  // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
15985
- ggml_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta);
15986
- j = (j + 1)%m;
18114
+ ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
18115
+ j[0] = (j[0] + 1)%m;
15987
18116
  }
15988
18117
 
15989
- step = 1.0;
18118
+ step[0] = 1.0;
15990
18119
  }
15991
18120
 
15992
18121
  return GGML_OPT_DID_NOT_CONVERGE;
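
The L-BFGS change above moves the per-iteration s and y vectors out of an alloca'd array of structs into two flat nx-by-m tensors held in the optimizer context, so slot j is addressed as &lm_s[j*nx], and the slots are reused cyclically through end[0] = (end[0] + 1) % m. A standalone sketch of that ring-buffer addressing with hypothetical sizes:

#include <stdio.h>
#include <stdlib.h>

int main(void) {
    const int m = 4, nx = 8;                 // history size and parameter count (hypothetical)
    float *lm_s = calloc((size_t)m*nx, sizeof(float));
    if (!lm_s) return 1;

    int end = 0;
    for (int k = 0; k < 10; ++k) {
        float *s_end = &lm_s[end*nx];        // same addressing as &lm_s[end[0]*nx] above
        for (int i = 0; i < nx; ++i) {
            s_end[i] = (float)k;
        }
        end = (end + 1) % m;                 // cyclic reuse of the m history slots
    }

    printf("slot 0 last written by iteration %g\n", lm_s[0]);
    free(lm_s);
    return 0;
}
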
@@ -16011,6 +18140,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
16011
18140
 
16012
18141
  .adam = {
16013
18142
  .n_iter = 10000,
18143
+ .sched = 1.000f,
18144
+ .decay = 0.001f,
16014
18145
  .alpha = 0.001f,
16015
18146
  .beta1 = 0.9f,
16016
18147
  .beta2 = 0.999f,
@@ -16053,6 +18184,70 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
16053
18184
  return result;
16054
18185
  }
16055
18186
 
18187
+ GGML_API void ggml_opt_init(
18188
+ struct ggml_context * ctx,
18189
+ struct ggml_opt_context * opt,
18190
+ struct ggml_opt_params params,
18191
+ int64_t nx) {
18192
+ opt->ctx = ctx;
18193
+ opt->params = params;
18194
+ opt->iter = 0;
18195
+ opt->nx = nx;
18196
+ opt->just_initialized = true;
18197
+ switch (opt->params.type) {
18198
+ case GGML_OPT_ADAM:
18199
+ {
18200
+ opt->adam.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18201
+ opt->adam.g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18202
+ opt->adam.g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18203
+ opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18204
+ opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18205
+ opt->adam.mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18206
+ opt->adam.vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18207
+ opt->adam.pf = params.past > 0
18208
+ ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
18209
+ : NULL;
18210
+ ggml_set_zero(opt->adam.x);
18211
+ ggml_set_zero(opt->adam.g1);
18212
+ ggml_set_zero(opt->adam.g2);
18213
+ ggml_set_zero(opt->adam.m);
18214
+ ggml_set_zero(opt->adam.v);
18215
+ ggml_set_zero(opt->adam.mh);
18216
+ ggml_set_zero(opt->adam.vh);
18217
+ if (opt->adam.pf) {
18218
+ ggml_set_zero(opt->adam.pf);
18219
+ }
18220
+ } break;
18221
+ case GGML_OPT_LBFGS:
18222
+ {
18223
+ opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18224
+ opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18225
+ opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18226
+ opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18227
+ opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18228
+ opt->lbfgs.pf = params.past > 0
18229
+ ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
18230
+ : NULL;
18231
+ opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
18232
+ opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
18233
+ opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
18234
+ opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
18235
+ ggml_set_zero(opt->lbfgs.x);
18236
+ ggml_set_zero(opt->lbfgs.xp);
18237
+ ggml_set_zero(opt->lbfgs.g);
18238
+ ggml_set_zero(opt->lbfgs.gp);
18239
+ ggml_set_zero(opt->lbfgs.d);
18240
+ if (opt->lbfgs.pf) {
18241
+ ggml_set_zero(opt->lbfgs.pf);
18242
+ }
18243
+ ggml_set_zero(opt->lbfgs.lmal);
18244
+ ggml_set_zero(opt->lbfgs.lmys);
18245
+ ggml_set_zero(opt->lbfgs.lms);
18246
+ ggml_set_zero(opt->lbfgs.lmy);
18247
+ } break;
18248
+ }
18249
+ }
18250
+
16056
18251
  enum ggml_opt_result ggml_opt(
16057
18252
  struct ggml_context * ctx,
16058
18253
  struct ggml_opt_params params,
@@ -16075,33 +18270,65 @@ enum ggml_opt_result ggml_opt(
16075
18270
 
16076
18271
  enum ggml_opt_result result = GGML_OPT_OK;
16077
18272
 
18273
+ struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
18274
+
18275
+ ggml_opt_init(ctx, opt, params, 0);
18276
+ result = ggml_opt_resume(ctx, opt, f);
18277
+
18278
+ if (free_ctx) {
18279
+ ggml_free(ctx);
18280
+ }
18281
+
18282
+ return result;
18283
+ }
18284
+
18285
+ enum ggml_opt_result ggml_opt_resume(
18286
+ struct ggml_context * ctx,
18287
+ struct ggml_opt_context * opt,
18288
+ struct ggml_tensor * f) {
18289
+
18290
+ // build forward + backward compute graphs
18291
+ struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
18292
+ struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
18293
+
18294
+ struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
18295
+ struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
18296
+
18297
+ *gf = ggml_build_forward (f);
18298
+ *gb = ggml_build_backward(ctx, gf, true);
18299
+
18300
+ return ggml_opt_resume_g(ctx, opt, f, gf, gb);
18301
+ }
18302
+
18303
+ enum ggml_opt_result ggml_opt_resume_g(
18304
+ struct ggml_context * ctx,
18305
+ struct ggml_opt_context * opt,
18306
+ struct ggml_tensor * f,
18307
+ struct ggml_cgraph * gf,
18308
+ struct ggml_cgraph * gb) {
18309
+
16078
18310
  // build forward + backward compute graphs
16079
- struct ggml_cgraph gf = ggml_build_forward (f);
16080
- struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, true);
18311
+ enum ggml_opt_result result = GGML_OPT_OK;
16081
18312
 
16082
- switch (params.type) {
18313
+ switch (opt->params.type) {
16083
18314
  case GGML_OPT_ADAM:
16084
18315
  {
16085
- result = ggml_opt_adam(ctx, params, f, &gf, &gb);
18316
+ result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
16086
18317
  } break;
16087
18318
  case GGML_OPT_LBFGS:
16088
18319
  {
16089
- result = ggml_opt_lbfgs(ctx, params, f, &gf, &gb);
18320
+ result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb);
16090
18321
  } break;
16091
18322
  }
16092
18323
 
16093
- if (params.print_forward_graph) {
16094
- ggml_graph_print (&gf);
16095
- ggml_graph_dump_dot(&gf, NULL, "opt-forward.dot");
16096
- }
16097
-
16098
- if (params.print_backward_graph) {
16099
- ggml_graph_print (&gb);
16100
- ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot");
18324
+ if (opt->params.print_forward_graph) {
18325
+ ggml_graph_print (gf);
18326
+ ggml_graph_dump_dot(gf, NULL, "opt-forward.dot");
16101
18327
  }
16102
18328
 
16103
- if (free_ctx) {
16104
- ggml_free(ctx);
18329
+ if (opt->params.print_backward_graph) {
18330
+ ggml_graph_print (gb);
18331
+ ggml_graph_dump_dot(gb, gf, "opt-backward.dot");
16105
18332
  }
16106
18333
 
16107
18334
  return result;
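
ggml_opt_resume above stores the forward and backward graphs inside I32 tensors whose element count is sizeof(struct ggml_cgraph) divided by the element size and rounded up with the usual n = size/elem + (size % elem ? 1 : 0) idiom, so the tensor's data area is large enough to hold the struct. A standalone sketch of that sizing with a hypothetical struct standing in for ggml_cgraph:

#include <stdio.h>

struct fake_cgraph { char payload[1001]; };   // hypothetical stand-in; the real struct differs

int main(void) {
    const size_t elem = sizeof(int);          // plays the role of GGML_TYPE_SIZE[GGML_TYPE_I32]
    const size_t need = sizeof(struct fake_cgraph);
    const size_t n    = need/elem + (need % elem ? 1 : 0);

    printf("need %zu bytes -> %zu I32 elements -> %zu bytes reserved\n", need, n, n*elem);
    return 0;
}
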
@@ -16301,6 +18528,18 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
16301
18528
  result = ggml_quantize_q6_K(src + start, block, n, n, hist);
16302
18529
  } break;
16303
18530
  #endif
18531
+ case GGML_TYPE_F16:
18532
+ {
18533
+ int elemsize = sizeof(ggml_fp16_t);
18534
+ ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
18535
+ result = n * elemsize;
18536
+ } break;
18537
+ case GGML_TYPE_F32:
18538
+ {
18539
+ int elemsize = sizeof(float);
18540
+ result = n * elemsize;
18541
+ memcpy((uint8_t *)dst + start * elemsize, src + start, result);
18542
+ } break;
16304
18543
  default:
16305
18544
  assert(false);
16306
18545
  }