llama_cpp 0.2.0 → 0.2.2

This diff shows the changes between two publicly released versions of the package, as published to the supported public registries. It is provided for informational purposes only.
@@ -35,6 +35,12 @@
35
35
  #define static_assert(cond, msg) struct global_scope_noop_trick
36
36
  #endif
37
37
 
38
+ #if defined(_MSC_VER)
39
+ // disable "possible loss of data" to avoid hundreds of casts
40
+ // we should just be careful :)
41
+ #pragma warning(disable: 4244 4267)
42
+ #endif
43
+
38
44
  #if defined(_WIN32)
39
45
 
40
46
  #include <windows.h>
@@ -106,6 +112,7 @@ typedef void* thread_ret_t;
106
112
  /*#define GGML_PERF*/
107
113
  #define GGML_DEBUG 0
108
114
  #define GGML_GELU_FP16
115
+ #define GGML_GELU_QUICK_FP16
109
116
  #define GGML_SILU_FP16
110
117
 
111
118
  #define GGML_SOFT_MAX_UNROLL 4
@@ -334,6 +341,9 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
334
341
  // precomputed gelu table for f16 (128 KB)
335
342
  static ggml_fp16_t table_gelu_f16[1 << 16];
336
343
 
344
+ // precomputed quick gelu table for f16 (128 KB)
345
+ static ggml_fp16_t table_gelu_quick_f16[1 << 16];
346
+
337
347
  // precomputed silu table for f16 (128 KB)
338
348
  static ggml_fp16_t table_silu_f16[1 << 16];
339
349
 
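Note: the new table_gelu_quick_f16 follows the same pattern as the existing GELU/SILU tables: an fp16 value has only 2^16 possible bit patterns, so the activation can be precomputed once for every possible input and later evaluated with a single table lookup indexed by the raw bits. A minimal standalone sketch of the idea in plain C (the half_to_float helper and float-typed table are illustrative; ggml uses its own GGML_FP32_TO_FP16/GGML_FP16_TO_FP32 macros and stores ggml_fp16_t entries, which is why its tables are 128 KB):

    #include <math.h>
    #include <stdint.h>

    // decode an IEEE-754 binary16 bit pattern to float (illustrative helper, not ggml code)
    static float half_to_float(uint16_t h) {
        const int sign = (h >> 15) & 0x1;
        const int exp  = (h >> 10) & 0x1f;
        const int mant =  h        & 0x3ff;
        float v;
        if (exp == 0)       v = ldexpf((float) mant, -24);               // zero / subnormal
        else if (exp == 31) v = mant ? NAN : INFINITY;                   // NaN / infinity
        else                v = ldexpf(1.0f + mant/1024.0f, exp - 15);   // normal
        return sign ? -v : v;
    }

    // one entry per possible fp16 input (stored as float here -> 256 KB; ggml stores fp16 -> 128 KB)
    static float table_gelu_quick[1 << 16];

    static void init_gelu_quick_table(void) {
        for (uint32_t i = 0; i < (1u << 16); ++i) {
            const float x = half_to_float((uint16_t) i);
            table_gelu_quick[i] = x/(1.0f + expf(-1.702f*x)); // quick GELU: x * sigmoid(1.702 x)
        }
    }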
@@ -1671,14 +1681,17 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
1671
1681
  #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
1672
1682
  #define GGML_F32x4_REDUCE(res, x) \
1673
1683
  { \
1674
- for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
1675
- x[2*i] = vaddq_f32(x[2*i], x[2*i+1]); \
1684
+ int offset = GGML_F32_ARR >> 1; \
1685
+ for (int i = 0; i < offset; ++i) { \
1686
+ x[i] = vaddq_f32(x[i], x[offset+i]); \
1676
1687
  } \
1677
- for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
1678
- x[4*i] = vaddq_f32(x[4*i], x[4*i+2]); \
1688
+ offset >>= 1; \
1689
+ for (int i = 0; i < offset; ++i) { \
1690
+ x[i] = vaddq_f32(x[i], x[offset+i]); \
1679
1691
  } \
1680
- for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
1681
- x[8*i] = vaddq_f32(x[8*i], x[8*i+4]); \
1692
+ offset >>= 1; \
1693
+ for (int i = 0; i < offset; ++i) { \
1694
+ x[i] = vaddq_f32(x[i], x[offset+i]); \
1682
1695
  } \
1683
1696
  res = GGML_F32x4_REDUCE_ONE(x[0]); \
1684
1697
  }
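Note: every GGML_F32*_REDUCE/GGML_F16*_REDUCE macro in this file is rewritten the same way: instead of adding neighbouring accumulator pairs at a stride (x[2*i] += x[2*i+1], then x[4*i] += x[4*i+2], ...), the new code folds the upper half of the accumulator array onto the lower half and halves the offset each pass, so only the low indices of x are ever written. Both variants produce the same sum of the GGML_F32_ARR partial accumulators. A scalar model of the new loop (plain C, with ARR standing in for GGML_F32_ARR):

    #include <stdio.h>

    #define ARR 8   // stands in for GGML_F32_ARR (number of partial-sum "registers")

    int main(void) {
        float x[ARR] = {1, 2, 3, 4, 5, 6, 7, 8}; // partial sums

        // tree reduction: fold the upper half onto the lower half, halving the offset each pass
        for (int offset = ARR >> 1; offset > 0; offset >>= 1) {
            for (int i = 0; i < offset; ++i) {
                x[i] += x[offset + i];
            }
        }

        printf("sum = %f\n", x[0]); // 36.0, the same result the strided variant produced
        return 0;
    }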
@@ -1709,14 +1722,17 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
1709
1722
  #define GGML_F16x8_MUL vmulq_f16
1710
1723
  #define GGML_F16x8_REDUCE(res, x) \
1711
1724
  { \
1712
- for (int i = 0; i < GGML_F16_ARR/2; ++i) { \
1713
- x[2*i] = vaddq_f16(x[2*i], x[2*i+1]); \
1725
+ int offset = GGML_F16_ARR >> 1; \
1726
+ for (int i = 0; i < offset; ++i) { \
1727
+ x[i] = vaddq_f16(x[i], x[offset+i]); \
1714
1728
  } \
1715
- for (int i = 0; i < GGML_F16_ARR/4; ++i) { \
1716
- x[4*i] = vaddq_f16(x[4*i], x[4*i+2]); \
1729
+ offset >>= 1; \
1730
+ for (int i = 0; i < offset; ++i) { \
1731
+ x[i] = vaddq_f16(x[i], x[offset+i]); \
1717
1732
  } \
1718
- for (int i = 0; i < GGML_F16_ARR/8; ++i) { \
1719
- x[8*i] = vaddq_f16(x[8*i], x[8*i+4]); \
1733
+ offset >>= 1; \
1734
+ for (int i = 0; i < offset; ++i) { \
1735
+ x[i] = vaddq_f16(x[i], x[offset+i]); \
1720
1736
  } \
1721
1737
  const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
1722
1738
  const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
@@ -1783,14 +1799,17 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
1783
1799
  #define GGML_F32x8_MUL _mm256_mul_ps
1784
1800
  #define GGML_F32x8_REDUCE(res, x) \
1785
1801
  { \
1786
- for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
1787
- x[2*i] = _mm256_add_ps(x[2*i], x[2*i+1]); \
1802
+ int offset = GGML_F32_ARR >> 1; \
1803
+ for (int i = 0; i < offset; ++i) { \
1804
+ x[i] = _mm256_add_ps(x[i], x[offset+i]); \
1788
1805
  } \
1789
- for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
1790
- x[4*i] = _mm256_add_ps(x[4*i], x[4*i+2]); \
1806
+ offset >>= 1; \
1807
+ for (int i = 0; i < offset; ++i) { \
1808
+ x[i] = _mm256_add_ps(x[i], x[offset+i]); \
1791
1809
  } \
1792
- for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
1793
- x[8*i] = _mm256_add_ps(x[8*i], x[8*i+4]); \
1810
+ offset >>= 1; \
1811
+ for (int i = 0; i < offset; ++i) { \
1812
+ x[i] = _mm256_add_ps(x[i], x[offset+i]); \
1794
1813
  } \
1795
1814
  const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \
1796
1815
  _mm256_extractf128_ps(x[0], 1)); \
@@ -1880,14 +1899,17 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
1880
1899
  #define GGML_F32x4_MUL vec_mul
1881
1900
  #define GGML_F32x4_REDUCE(res, x) \
1882
1901
  { \
1883
- for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
1884
- x[2*i] = vec_add(x[2*i], x[2*i+1]); \
1902
+ int offset = GGML_F32_ARR >> 1; \
1903
+ for (int i = 0; i < offset; ++i) { \
1904
+ x[i] = vec_add(x[i], x[offset+i]); \
1885
1905
  } \
1886
- for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
1887
- x[4*i] = vec_add(x[4*i], x[4*i+2]); \
1906
+ offset >>= 1; \
1907
+ for (int i = 0; i < offset; ++i) { \
1908
+ x[i] = vec_add(x[i], x[offset+i]); \
1888
1909
  } \
1889
- for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
1890
- x[8*i] = vec_add(x[8*i], x[8*i+4]); \
1910
+ offset >>= 1; \
1911
+ for (int i = 0; i < offset; ++i) { \
1912
+ x[i] = vec_add(x[i], x[offset+i]); \
1891
1913
  } \
1892
1914
  res = vec_extract(x[0], 0) + \
1893
1915
  vec_extract(x[0], 1) + \
@@ -1943,14 +1965,17 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
1943
1965
  #define GGML_F32x4_MUL wasm_f32x4_mul
1944
1966
  #define GGML_F32x4_REDUCE(res, x) \
1945
1967
  { \
1946
- for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
1947
- x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \
1968
+ int offset = GGML_F32_ARR >> 1; \
1969
+ for (int i = 0; i < offset; ++i) { \
1970
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
1948
1971
  } \
1949
- for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
1950
- x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \
1972
+ offset >>= 1; \
1973
+ for (int i = 0; i < offset; ++i) { \
1974
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
1951
1975
  } \
1952
- for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
1953
- x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \
1976
+ offset >>= 1; \
1977
+ for (int i = 0; i < offset; ++i) { \
1978
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
1954
1979
  } \
1955
1980
  res = wasm_f32x4_extract_lane(x[0], 0) + \
1956
1981
  wasm_f32x4_extract_lane(x[0], 1) + \
@@ -2005,14 +2030,17 @@ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
2005
2030
  #define GGML_F16x4_MUL wasm_f32x4_mul
2006
2031
  #define GGML_F16x4_REDUCE(res, x) \
2007
2032
  { \
2008
- for (int i = 0; i < GGML_F16_ARR/2; ++i) { \
2009
- x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \
2033
+ int offset = GGML_F16_ARR >> 1; \
2034
+ for (int i = 0; i < offset; ++i) { \
2035
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
2010
2036
  } \
2011
- for (int i = 0; i < GGML_F16_ARR/4; ++i) { \
2012
- x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \
2037
+ offset >>= 1; \
2038
+ for (int i = 0; i < offset; ++i) { \
2039
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
2013
2040
  } \
2014
- for (int i = 0; i < GGML_F16_ARR/8; ++i) { \
2015
- x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \
2041
+ offset >>= 1; \
2042
+ for (int i = 0; i < offset; ++i) { \
2043
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
2016
2044
  } \
2017
2045
  res = wasm_f32x4_extract_lane(x[0], 0) + \
2018
2046
  wasm_f32x4_extract_lane(x[0], 1) + \
@@ -2054,14 +2082,17 @@ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
2054
2082
  #define GGML_F32x4_MUL _mm_mul_ps
2055
2083
  #define GGML_F32x4_REDUCE(res, x) \
2056
2084
  { \
2057
- for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
2058
- x[2*i] = _mm_add_ps(x[2*i], x[2*i+1]); \
2085
+ int offset = GGML_F32_ARR >> 1; \
2086
+ for (int i = 0; i < offset; ++i) { \
2087
+ x[i] = _mm_add_ps(x[i], x[offset+i]); \
2059
2088
  } \
2060
- for (int i = 0; i < GGML_F32_ARR/4; ++i) { \
2061
- x[4*i] = _mm_add_ps(x[4*i], x[4*i+2]); \
2089
+ offset >>= 1; \
2090
+ for (int i = 0; i < offset; ++i) { \
2091
+ x[i] = _mm_add_ps(x[i], x[offset+i]); \
2062
2092
  } \
2063
- for (int i = 0; i < GGML_F32_ARR/8; ++i) { \
2064
- x[8*i] = _mm_add_ps(x[8*i], x[8*i+4]); \
2093
+ offset >>= 1; \
2094
+ for (int i = 0; i < offset; ++i) { \
2095
+ x[i] = _mm_add_ps(x[i], x[offset+i]); \
2065
2096
  } \
2066
2097
  const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \
2067
2098
  res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \
@@ -3350,6 +3381,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
3350
3381
  inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
3351
3382
 
3352
3383
  static const float GELU_COEF_A = 0.044715f;
3384
+ static const float GELU_QUICK_COEF = -1.702f;
3353
3385
  static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
3354
3386
 
3355
3387
  inline static float ggml_gelu_f32(float x) {
@@ -3380,6 +3412,34 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
3380
3412
  }
3381
3413
  #endif
3382
3414
 
3415
+ inline static float ggml_gelu_quick_f32(float x) {
3416
+ return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
3417
+ }
3418
+
3419
+ //inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
3420
+ // const uint16_t * i16 = (const uint16_t *) x;
3421
+ // for (int i = 0; i < n; ++i) {
3422
+ // y[i] = table_gelu_quick_f16[i16[i]];
3423
+ // }
3424
+ //}
3425
+
3426
+ #ifdef GGML_GELU_QUICK_FP16
3427
+ inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
3428
+ uint16_t t;
3429
+ for (int i = 0; i < n; ++i) {
3430
+ ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
3431
+ memcpy(&t, &fp16, sizeof(uint16_t));
3432
+ y[i] = GGML_FP16_TO_FP32(table_gelu_quick_f16[t]);
3433
+ }
3434
+ }
3435
+ #else
3436
+ inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
3437
+ for (int i = 0; i < n; ++i) {
3438
+ y[i] = ggml_gelu_quick_f32(x[i]);
3439
+ }
3440
+ }
3441
+ #endif
3442
+
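Note: ggml_gelu_quick_f32 is the sigmoid approximation of GELU, x * sigmoid(1.702 x); the constant is stored negated as GELU_QUICK_COEF so the sigmoid can be written as 1/(1 + exp(coef * x)). It tracks the tanh-based GELU closely while being cheaper to evaluate. A small standalone comparison (plain C; the tanh form below mirrors the standard ggml GELU that uses the GELU_COEF_A and SQRT_2_OVER_PI constants defined above):

    #include <math.h>
    #include <stdio.h>

    static const float GELU_COEF_A    = 0.044715f;
    static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

    static float gelu_tanh(float x) {  // tanh-based GELU approximation
        return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
    }

    static float gelu_quick(float x) { // same form as ggml_gelu_quick_f32
        return x*(1.0f/(1.0f + expf(-1.702f*x)));
    }

    int main(void) {
        for (float x = -2.0f; x <= 2.01f; x += 1.0f) {
            printf("x = %+.1f   gelu = %+.6f   gelu_quick = %+.6f\n", x, gelu_tanh(x), gelu_quick(x));
        }
        return 0;
    }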
3383
3443
  // Sigmoid Linear Unit (SiLU) function
3384
3444
  inline static float ggml_silu_f32(float x) {
3385
3445
  return x/(1.0f + expf(-x));
@@ -3603,12 +3663,14 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3603
3663
  "SUM_ROWS",
3604
3664
  "MEAN",
3605
3665
  "REPEAT",
3666
+ "REPEAT_BACK",
3606
3667
  "ABS",
3607
3668
  "SGN",
3608
3669
  "NEG",
3609
3670
  "STEP",
3610
3671
  "RELU",
3611
3672
  "GELU",
3673
+ "GELU_QUICK",
3612
3674
  "SILU",
3613
3675
  "SILU_BACK",
3614
3676
  "NORM",
@@ -3616,6 +3678,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3616
3678
  "RMS_NORM_BACK",
3617
3679
 
3618
3680
  "MUL_MAT",
3681
+ "OUT_PROD",
3619
3682
 
3620
3683
  "SCALE",
3621
3684
  "SET",
@@ -3631,22 +3694,29 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3631
3694
  "DIAG_MASK_INF",
3632
3695
  "DIAG_MASK_ZERO",
3633
3696
  "SOFT_MAX",
3697
+ "SOFT_MAX_BACK",
3634
3698
  "ROPE",
3635
3699
  "ROPE_BACK",
3636
3700
  "ALIBI",
3637
3701
  "CLAMP",
3638
- "CONV_1D_1S",
3639
- "CONV_1D_2S",
3702
+ "CONV_1D_S1_PH",
3703
+ "CONV_1D_S2_PH",
3704
+ "CONV_2D_SK_P0",
3640
3705
 
3641
3706
  "FLASH_ATTN",
3642
3707
  "FLASH_FF",
3708
+ "FLASH_ATTN_BACK",
3709
+ "WIN_PART",
3710
+ "WIN_UNPART",
3643
3711
 
3644
3712
  "MAP_UNARY",
3645
3713
  "MAP_BINARY",
3646
- };
3647
3714
 
3648
- static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
3715
+ "CROSS_ENTROPY_LOSS",
3716
+ "CROSS_ENTROPY_LOSS_BACK",
3717
+ };
3649
3718
 
3719
+ static_assert(GGML_OP_COUNT == 61, "GGML_OP_COUNT != 61");
3650
3720
 
3651
3721
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3652
3722
  "none",
@@ -3665,18 +3735,21 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3665
3735
  "Σx_k",
3666
3736
  "Σx/n",
3667
3737
  "repeat(x)",
3738
+ "repeat_back(x)",
3668
3739
  "abs(x)",
3669
3740
  "sgn(x)",
3670
3741
  "-x",
3671
3742
  "step(x)",
3672
3743
  "relu(x)",
3673
3744
  "gelu(x)",
3745
+ "gelu_quick(x)",
3674
3746
  "silu(x)",
3675
3747
  "silu_back(x)",
3676
3748
  "norm(x)",
3677
3749
  "rms_norm(x)",
3678
3750
  "rms_norm_back(x)",
3679
3751
 
3752
+ "X*Y",
3680
3753
  "X*Y",
3681
3754
 
3682
3755
  "x*v",
@@ -3693,21 +3766,29 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3693
3766
  "diag_mask_inf(x)",
3694
3767
  "diag_mask_zero(x)",
3695
3768
  "soft_max(x)",
3769
+ "soft_max_back(x)",
3696
3770
  "rope(x)",
3697
3771
  "rope_back(x)",
3698
3772
  "alibi(x)",
3699
3773
  "clamp(x)",
3700
- "conv_1d_1s(x)",
3701
- "conv_1d_2s(x)",
3774
+ "conv_1d_s1_ph(x)",
3775
+ "conv_1d_s2_ph(x)",
3776
+ "conv_2d_sk_p0(x)",
3702
3777
 
3703
3778
  "flash_attn(x)",
3704
3779
  "flash_ff(x)",
3780
+ "flash_attn_back(x)",
3781
+ "win_part(x)",
3782
+ "win_unpart(x)",
3705
3783
 
3706
3784
  "f(x)",
3707
3785
  "f(x,y)",
3786
+
3787
+ "cross_entropy_loss(x,y)",
3788
+ "cross_entropy_loss_back(x,y)",
3708
3789
  };
3709
3790
 
3710
- static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
3791
+ static_assert(GGML_OP_COUNT == 61, "GGML_OP_COUNT != 61");
3711
3792
 
3712
3793
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
3713
3794
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
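Note: GGML_OP_NAME and GGML_OP_SYMBOL must stay in the same order as the ggml_op enum, which is why every operator added in this release (REPEAT_BACK, GELU_QUICK, OUT_PROD, SOFT_MAX_BACK, CONV_2D_SK_P0, FLASH_ATTN_BACK, WIN_PART, WIN_UNPART, CROSS_ENTROPY_LOSS, CROSS_ENTROPY_LOSS_BACK) is paired with bumping the static_assert from 51 to 61 entries. A minimal sketch of the guard pattern with a hypothetical miniature enum:

    #include <assert.h>

    enum my_op {
        MY_OP_NONE = 0,
        MY_OP_ADD,
        MY_OP_GELU_QUICK,
        MY_OP_COUNT,        // always last
    };

    static const char * MY_OP_NAME[MY_OP_COUNT] = {
        "NONE",
        "ADD",
        "GELU_QUICK",
    };

    // fails to compile if an op is added to the enum without updating the name table and this count
    static_assert(MY_OP_COUNT == 3, "MY_OP_COUNT != 3");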
@@ -3870,6 +3951,15 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
3870
3951
  (t0->ne[3] == t1->ne[3]);
3871
3952
  }
3872
3953
 
3954
+ static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
3955
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3956
+
3957
+ return
3958
+ (t0->ne[1] == t1->ne[1]) &&
3959
+ (t0->ne[2] == t1->ne[2]) &&
3960
+ (t0->ne[3] == t1->ne[3]);
3961
+ }
3962
+
3873
3963
  bool ggml_is_quantized(enum ggml_type type) {
3874
3964
  return GGML_IS_QUANTIZED[type];
3875
3965
  }
@@ -3917,6 +4007,12 @@ bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
3917
4007
  tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
3918
4008
  }
3919
4009
 
4010
+ bool ggml_is_permuted(const struct ggml_tensor * tensor) {
4011
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
4012
+
4013
+ return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
4014
+ }
4015
+
3920
4016
  static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
3921
4017
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3922
4018
 
@@ -3983,7 +4079,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3983
4079
  // initialize time system (required on Windows)
3984
4080
  ggml_time_init();
3985
4081
 
3986
- // initialize GELU, SILU and EXP F32 tables
4082
+ // initialize GELU, Quick GELU, SILU and EXP F32 tables
3987
4083
  {
3988
4084
  const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
3989
4085
 
@@ -3993,13 +4089,14 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3993
4089
  memcpy(&ii, &ui, sizeof(ii));
3994
4090
  const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
3995
4091
  table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
4092
+ table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
3996
4093
  table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
3997
4094
  table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
3998
4095
  }
3999
4096
 
4000
4097
  const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
4001
4098
 
4002
- GGML_PRINT_DEBUG("%s: GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
4099
+ GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
4003
4100
  }
4004
4101
 
4005
4102
  // initialize g_state
@@ -4120,14 +4217,34 @@ void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) {
4120
4217
  ctx->no_alloc = no_alloc;
4121
4218
  }
4122
4219
 
4123
- void * ggml_get_mem_buffer(struct ggml_context * ctx) {
4220
+ void * ggml_get_mem_buffer(const struct ggml_context * ctx) {
4124
4221
  return ctx->mem_buffer;
4125
4222
  }
4126
4223
 
4127
- size_t ggml_get_mem_size(struct ggml_context * ctx) {
4224
+ size_t ggml_get_mem_size(const struct ggml_context * ctx) {
4128
4225
  return ctx->mem_size;
4129
4226
  }
4130
4227
 
4228
+ size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) {
4229
+ size_t max_size = 0;
4230
+
4231
+ struct ggml_object * obj = ctx->objects_begin;
4232
+
4233
+ while (obj != NULL) {
4234
+ struct ggml_tensor * tensor = (struct ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs);
4235
+
4236
+ const size_t size = ggml_nbytes(tensor);
4237
+
4238
+ if (max_size < size) {
4239
+ max_size = size;
4240
+ }
4241
+
4242
+ obj = obj->next;
4243
+ }
4244
+
4245
+ return max_size;
4246
+ }
4247
+
4131
4248
  // IMPORTANT:
4132
4249
  // when creating "opt" tensors, always save and load the scratch buffer
4133
4250
  // this is an error prone process, but it is necessary to support inplace
@@ -4611,9 +4728,10 @@ const char * ggml_get_name(const struct ggml_tensor * tensor) {
4611
4728
  return tensor->name;
4612
4729
  }
4613
4730
 
4614
- void ggml_set_name(struct ggml_tensor * tensor, const char * name) {
4731
+ struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name) {
4615
4732
  strncpy(tensor->name, name, sizeof(tensor->name));
4616
4733
  tensor->name[sizeof(tensor->name) - 1] = '\0';
4734
+ return tensor;
4617
4735
  }
4618
4736
 
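Note: ggml_set_name now returns the tensor it was given, so naming can be folded into the expression that creates the tensor. A short usage sketch (ctx and n_tokens are illustrative):

    struct ggml_tensor * make_named_input(struct ggml_context * ctx, int n_tokens) {
        // before this change, two statements were needed:
        //   struct ggml_tensor * embd = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
        //   ggml_set_name(embd, "embd");
        // now the call chains:
        return ggml_set_name(ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens), "embd");
    }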
4619
4737
  struct ggml_tensor * ggml_view_tensor(
@@ -4693,7 +4811,7 @@ struct ggml_tensor * ggml_add_impl(
4693
4811
 
4694
4812
  bool is_node = false;
4695
4813
 
4696
- if (!inplace && (a->grad || b->grad)) {
4814
+ if (a->grad || b->grad) {
4697
4815
  is_node = true;
4698
4816
  }
4699
4817
 
@@ -4733,7 +4851,7 @@ struct ggml_tensor * ggml_add1_impl(
4733
4851
 
4734
4852
  bool is_node = false;
4735
4853
 
4736
- if (!inplace && (a->grad || b->grad)) {
4854
+ if (a->grad || b->grad) {
4737
4855
  is_node = true;
4738
4856
  }
4739
4857
 
@@ -5159,6 +5277,34 @@ struct ggml_tensor * ggml_repeat(
5159
5277
  return result;
5160
5278
  }
5161
5279
 
5280
+ // ggml_repeat_back
5281
+
5282
+ struct ggml_tensor * ggml_repeat_back(
5283
+ struct ggml_context * ctx,
5284
+ struct ggml_tensor * a,
5285
+ struct ggml_tensor * b) {
5286
+ GGML_ASSERT(ggml_can_repeat(b, a));
5287
+
5288
+ bool is_node = false;
5289
+
5290
+ if (a->grad) {
5291
+ is_node = true;
5292
+ }
5293
+
5294
+ if (ggml_are_same_shape(a, b) && !is_node) {
5295
+ return a;
5296
+ }
5297
+
5298
+ struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
5299
+
5300
+ result->op = GGML_OP_REPEAT_BACK;
5301
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5302
+ result->src0 = a;
5303
+ result->src1 = b;
5304
+
5305
+ return result;
5306
+ }
5307
+
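Note: ggml_repeat_back(a, b) is the adjoint of ggml_repeat: a carries the repeated (larger) shape, b only supplies the target (smaller) shape, and the op sums every repeated copy of a back into a b-shaped result. For example, if a bias of shape [3] was broadcast to [3, 4] in the forward pass, the incoming [3, 4] gradient is reduced back to shape [3] by summing over the 4 repeated rows. In ggml notation (d_repeated and bias are illustrative names):

    // d_repeated has the shape of ggml_repeat(bias, x), e.g. [3, 4]
    // d_bias collapses it back to the bias shape [3] by summation
    struct ggml_tensor * d_bias = ggml_repeat_back(ctx, d_repeated, bias);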
5162
5308
  // ggml_abs
5163
5309
 
5164
5310
  struct ggml_tensor * ggml_abs_impl(
@@ -5364,6 +5510,40 @@ struct ggml_tensor * ggml_gelu_inplace(
5364
5510
  return ggml_gelu_impl(ctx, a, true);
5365
5511
  }
5366
5512
 
5513
+ // ggml_gelu_quick
5514
+
5515
+ struct ggml_tensor * ggml_gelu_quick_impl(
5516
+ struct ggml_context * ctx,
5517
+ struct ggml_tensor * a,
5518
+ bool inplace) {
5519
+ bool is_node = false;
5520
+
5521
+ if (!inplace && (a->grad)) {
5522
+ is_node = true;
5523
+ }
5524
+
5525
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5526
+
5527
+ result->op = GGML_OP_GELU_QUICK;
5528
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5529
+ result->src0 = a;
5530
+ result->src1 = NULL;
5531
+
5532
+ return result;
5533
+ }
5534
+
5535
+ struct ggml_tensor * ggml_gelu_quick(
5536
+ struct ggml_context * ctx,
5537
+ struct ggml_tensor * a) {
5538
+ return ggml_gelu_quick_impl(ctx, a, false);
5539
+ }
5540
+
5541
+ struct ggml_tensor * ggml_gelu_quick_inplace(
5542
+ struct ggml_context * ctx,
5543
+ struct ggml_tensor * a) {
5544
+ return ggml_gelu_quick_impl(ctx, a, true);
5545
+ }
5546
+
5367
5547
  // ggml_silu
5368
5548
 
5369
5549
  struct ggml_tensor * ggml_silu_impl(
@@ -5536,6 +5716,32 @@ struct ggml_tensor * ggml_mul_mat(
5536
5716
  return result;
5537
5717
  }
5538
5718
 
5719
+ // ggml_out_prod
5720
+
5721
+ struct ggml_tensor * ggml_out_prod(
5722
+ struct ggml_context * ctx,
5723
+ struct ggml_tensor * a,
5724
+ struct ggml_tensor * b) {
5725
+ GGML_ASSERT(ggml_can_out_prod(a, b));
5726
+ GGML_ASSERT(!ggml_is_transposed(a));
5727
+
5728
+ bool is_node = false;
5729
+
5730
+ if (a->grad || b->grad) {
5731
+ is_node = true;
5732
+ }
5733
+
5734
+ const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
5735
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
5736
+
5737
+ result->op = GGML_OP_OUT_PROD;
5738
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5739
+ result->src0 = a;
5740
+ result->src1 = b;
5741
+
5742
+ return result;
5743
+ }
5744
+
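Note: ggml_out_prod(a, b) requires a and b to agree on ne[1..3] (see ggml_can_out_prod above) and produces an f32 result of shape { a->ne[0], b->ne[0], a->ne[2], b->ne[3] }: for each shared column index j it accumulates the outer product of column j of a with column j of b, i.e. dst[i0,i1] = sum_j a[i0,j] * b[i1,j]. The backward code uses it as the counterpart of ggml_mul_mat. A shape example with hypothetical 2-D operands:

    // a: [n_embd,  n_tokens]  e.g. 4096  x 32
    // b: [n_vocab, n_tokens]  e.g. 32000 x 32   (ne[1] must match a's ne[1])
    // result: [n_embd, n_vocab] = 4096 x 32000, with dst[i,k] = sum over t of a[i,t] * b[k,t]
    struct ggml_tensor * d_w = ggml_out_prod(ctx, a, b);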
5539
5745
  // ggml_scale
5540
5746
 
5541
5747
  struct ggml_tensor * ggml_scale_impl(
@@ -5548,7 +5754,7 @@ struct ggml_tensor * ggml_scale_impl(
5548
5754
 
5549
5755
  bool is_node = false;
5550
5756
 
5551
- if (!inplace && (a->grad || b->grad)) {
5757
+ if (a->grad || b->grad) {
5552
5758
  is_node = true;
5553
5759
  }
5554
5760
 
@@ -5591,7 +5797,7 @@ struct ggml_tensor * ggml_set_impl(
5591
5797
 
5592
5798
  bool is_node = false;
5593
5799
 
5594
- if (!inplace && (a->grad || b->grad)) {
5800
+ if (a->grad || b->grad) {
5595
5801
  is_node = true;
5596
5802
  }
5597
5803
 
@@ -5913,10 +6119,6 @@ struct ggml_tensor * ggml_view_1d(
5913
6119
  result->src1 = NULL;
5914
6120
  result->opt[0] = offs;
5915
6121
 
5916
- if (is_node) {
5917
- memcpy(result->padding, &offset, sizeof(offset));
5918
- }
5919
-
5920
6122
  return result;
5921
6123
  }
5922
6124
 
@@ -5957,10 +6159,6 @@ struct ggml_tensor * ggml_view_2d(
5957
6159
  result->src1 = NULL;
5958
6160
  result->opt[0] = offs;
5959
6161
 
5960
- if (is_node) {
5961
- memcpy(result->padding, &offset, sizeof(offset));
5962
- }
5963
-
5964
6162
  return result;
5965
6163
  }
5966
6164
 
@@ -6003,10 +6201,6 @@ struct ggml_tensor * ggml_view_3d(
6003
6201
  result->src1 = NULL;
6004
6202
  result->opt[0] = offs;
6005
6203
 
6006
- if (is_node) {
6007
- memcpy(result->padding, &offset, sizeof(offset));
6008
- }
6009
-
6010
6204
  return result;
6011
6205
  }
6012
6206
 
@@ -6051,10 +6245,6 @@ struct ggml_tensor * ggml_view_4d(
6051
6245
  result->src1 = NULL;
6052
6246
  result->opt[0] = offs;
6053
6247
 
6054
- if (is_node) {
6055
- memcpy(result->padding, &offset, sizeof(offset));
6056
- }
6057
-
6058
6248
  return result;
6059
6249
  }
6060
6250
 
@@ -6116,10 +6306,18 @@ struct ggml_tensor * ggml_permute(
6116
6306
  result->src1 = NULL;
6117
6307
 
6118
6308
  if (is_node) {
6119
- result->padding[0] = axis0;
6120
- result->padding[1] = axis1;
6121
- result->padding[2] = axis2;
6122
- result->padding[3] = axis3;
6309
+ ggml_scratch_save(ctx);
6310
+
6311
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
6312
+
6313
+ ((int32_t *) b->data)[0] = axis0;
6314
+ ((int32_t *) b->data)[1] = axis1;
6315
+ ((int32_t *) b->data)[2] = axis2;
6316
+ ((int32_t *) b->data)[3] = axis3;
6317
+
6318
+ ggml_scratch_load(ctx);
6319
+
6320
+ result->opt[0] = b;
6123
6321
  }
6124
6322
 
6125
6323
  return result;
@@ -6359,6 +6557,44 @@ struct ggml_tensor * ggml_soft_max_inplace(
6359
6557
  return ggml_soft_max_impl(ctx, a, true);
6360
6558
  }
6361
6559
 
6560
+
6561
+ // ggml_soft_max_back
6562
+
6563
+ struct ggml_tensor * ggml_soft_max_back_impl(
6564
+ struct ggml_context * ctx,
6565
+ struct ggml_tensor * a,
6566
+ struct ggml_tensor * b,
6567
+ bool inplace) {
6568
+ bool is_node = false;
6569
+
6570
+ if (a->grad || b->grad) {
6571
+ is_node = true; // TODO : implement backward pass
6572
+ }
6573
+
6574
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
6575
+
6576
+ result->op = GGML_OP_SOFT_MAX_BACK;
6577
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6578
+ result->src0 = a;
6579
+ result->src1 = b;
6580
+
6581
+ return result;
6582
+ }
6583
+
6584
+ struct ggml_tensor * ggml_soft_max_back(
6585
+ struct ggml_context * ctx,
6586
+ struct ggml_tensor * a,
6587
+ struct ggml_tensor * b) {
6588
+ return ggml_soft_max_back_impl(ctx, a, b, false);
6589
+ }
6590
+
6591
+ struct ggml_tensor * ggml_soft_max_back_inplace(
6592
+ struct ggml_context * ctx,
6593
+ struct ggml_tensor * a,
6594
+ struct ggml_tensor * b) {
6595
+ return ggml_soft_max_back_impl(ctx, a, b, true);
6596
+ }
6597
+
6362
6598
  // ggml_rope
6363
6599
 
6364
6600
  struct ggml_tensor * ggml_rope_impl(
@@ -6371,7 +6607,7 @@ struct ggml_tensor * ggml_rope_impl(
6371
6607
  GGML_ASSERT(n_past >= 0);
6372
6608
  bool is_node = false;
6373
6609
 
6374
- if (!inplace && a->grad) {
6610
+ if (a->grad) {
6375
6611
  is_node = true;
6376
6612
  }
6377
6613
 
@@ -6425,8 +6661,7 @@ struct ggml_tensor * ggml_rope_back(
6425
6661
  bool is_node = false;
6426
6662
 
6427
6663
  if (a->grad) {
6428
- GGML_ASSERT(false); // TODO: implement backward
6429
- is_node = true;
6664
+ is_node = false; // TODO: implement backward
6430
6665
  }
6431
6666
 
6432
6667
  struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
@@ -6508,7 +6743,7 @@ struct ggml_tensor * ggml_clamp(
6508
6743
 
6509
6744
  ggml_scratch_save(ctx);
6510
6745
 
6511
- struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
6746
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2);
6512
6747
 
6513
6748
  ((float *) b->data)[0] = min;
6514
6749
  ((float *) b->data)[1] = max;
@@ -6523,9 +6758,9 @@ struct ggml_tensor * ggml_clamp(
6523
6758
  return result;
6524
6759
  }
6525
6760
 
6526
- // ggml_conv_1d_1s
6761
+ // ggml_conv_1d_s1_ph
6527
6762
 
6528
- struct ggml_tensor * ggml_conv_1d_1s(
6763
+ struct ggml_tensor * ggml_conv_1d_s1_ph(
6529
6764
  struct ggml_context * ctx,
6530
6765
  struct ggml_tensor * a,
6531
6766
  struct ggml_tensor * b) {
@@ -6542,7 +6777,7 @@ struct ggml_tensor * ggml_conv_1d_1s(
6542
6777
  const int64_t ne[4] = { b->ne[0], a->ne[2], 1, 1, };
6543
6778
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
6544
6779
 
6545
- result->op = GGML_OP_CONV_1D_1S;
6780
+ result->op = GGML_OP_CONV_1D_S1_PH;
6546
6781
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6547
6782
  result->src0 = a;
6548
6783
  result->src1 = b;
@@ -6550,9 +6785,9 @@ struct ggml_tensor * ggml_conv_1d_1s(
6550
6785
  return result;
6551
6786
  }
6552
6787
 
6553
- // ggml_conv_1d_2s
6788
+ // ggml_conv_1d_s2_ph
6554
6789
 
6555
- struct ggml_tensor * ggml_conv_1d_2s(
6790
+ struct ggml_tensor * ggml_conv_1d_s2_ph(
6556
6791
  struct ggml_context * ctx,
6557
6792
  struct ggml_tensor * a,
6558
6793
  struct ggml_tensor * b) {
@@ -6569,7 +6804,35 @@ struct ggml_tensor * ggml_conv_1d_2s(
6569
6804
  const int64_t ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, };
6570
6805
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
6571
6806
 
6572
- result->op = GGML_OP_CONV_1D_2S;
6807
+ result->op = GGML_OP_CONV_1D_S2_PH;
6808
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6809
+ result->src0 = a;
6810
+ result->src1 = b;
6811
+
6812
+ return result;
6813
+ }
6814
+
6815
+ // ggml_conv_2d_sk_p0
6816
+
6817
+ struct ggml_tensor * ggml_conv_2d_sk_p0(
6818
+ struct ggml_context * ctx,
6819
+ struct ggml_tensor * a,
6820
+ struct ggml_tensor * b) {
6821
+ GGML_ASSERT(b->ne[3] == 1);
6822
+ GGML_ASSERT(a->ne[2] == b->ne[2]);
6823
+ GGML_ASSERT(b->ne[0] % a->ne[0] == 0);
6824
+ GGML_ASSERT(b->ne[1] % a->ne[1] == 0);
6825
+ bool is_node = false;
6826
+
6827
+ if (a->grad || b->grad) {
6828
+ GGML_ASSERT(false); // TODO: implement backward
6829
+ is_node = true;
6830
+ }
6831
+
6832
+ const int64_t ne[4] = { b->ne[0]/a->ne[0], b->ne[1]/a->ne[1], a->ne[3], 1, };
6833
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6834
+
6835
+ result->op = GGML_OP_CONV_2D_SK_P0;
6573
6836
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6574
6837
  result->src0 = a;
6575
6838
  result->src1 = b;
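Note: ggml_conv_2d_sk_p0 is a 2-D convolution with stride equal to the kernel size and no padding (hence "sk_p0"), so the spatial dims of b must be exact multiples of the kernel dims of a, and the output is b->ne[0]/a->ne[0] by b->ne[1]/a->ne[1] with a->ne[3] output channels. Shape check with hypothetical ViT-style patch-embedding sizes:

    // kernel a: [16, 16, 3, 768]     (16x16 patches, 3 input channels, 768 filters)
    // image  b: [1024, 1024, 3, 1]   (satisfies b->ne[3] == 1, matching channels, divisible dims)
    // result  : [1024/16, 1024/16, 768, 1] = [64, 64, 768, 1]
    struct ggml_tensor * patches = ggml_conv_2d_sk_p0(ctx, a, b);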
@@ -6591,7 +6854,6 @@ struct ggml_tensor * ggml_flash_attn(
6591
6854
  bool is_node = false;
6592
6855
 
6593
6856
  if (q->grad || k->grad || v->grad) {
6594
- GGML_ASSERT(false); // TODO: implement backward
6595
6857
  is_node = true;
6596
6858
  }
6597
6859
 
@@ -6623,7 +6885,6 @@ struct ggml_tensor * ggml_flash_ff(
6623
6885
  bool is_node = false;
6624
6886
 
6625
6887
  if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
6626
- GGML_ASSERT(false); // TODO: implement backward
6627
6888
  is_node = true;
6628
6889
  }
6629
6890
 
@@ -6641,54 +6902,202 @@ struct ggml_tensor * ggml_flash_ff(
6641
6902
  return result;
6642
6903
  }
6643
6904
 
6644
- // ggml_map_unary
6905
+ // ggml_flash_attn_back
6906
+
6907
+ struct ggml_tensor * ggml_flash_attn_back(
6908
+ struct ggml_context * ctx,
6909
+ struct ggml_tensor * q,
6910
+ struct ggml_tensor * k,
6911
+ struct ggml_tensor * v,
6912
+ struct ggml_tensor * d,
6913
+ bool masked) {
6914
+ GGML_ASSERT(ggml_can_mul_mat(k, q));
6915
+ // TODO: check if vT can be multiplied by (k*qT)
6916
+
6917
+ // d shape [D,N,ne2,ne3]
6918
+ // q shape [D,N,ne2,ne3]
6919
+ // k shape [D,M,ne2,ne3]
6920
+ // v shape [M,D,ne2,ne3]
6921
+
6922
+ const int64_t D = q->ne[0];
6923
+ const int64_t N = q->ne[1];
6924
+ const int64_t M = k->ne[1];
6925
+ const int64_t ne2 = q->ne[2];
6926
+ const int64_t ne3 = q->ne[3];
6927
+
6928
+ GGML_ASSERT(k->ne[0] == D);
6929
+ GGML_ASSERT(v->ne[0] == M);
6930
+ GGML_ASSERT(v->ne[1] == D);
6931
+ GGML_ASSERT(d->ne[0] == D);
6932
+ GGML_ASSERT(d->ne[1] == N);
6933
+ GGML_ASSERT(k->ne[2] == ne2);
6934
+ GGML_ASSERT(k->ne[3] == ne3);
6935
+ GGML_ASSERT(v->ne[2] == ne2);
6936
+ GGML_ASSERT(v->ne[3] == ne3);
6937
+ GGML_ASSERT(d->ne[2] == ne2);
6938
+ GGML_ASSERT(d->ne[3] == ne3);
6645
6939
 
6646
- struct ggml_tensor * ggml_map_unary_impl_f32(
6647
- struct ggml_context * ctx,
6648
- struct ggml_tensor * a,
6649
- const ggml_unary_op_f32_t fun,
6650
- bool inplace) {
6651
6940
  bool is_node = false;
6652
6941
 
6653
- if (!inplace && a->grad) {
6654
- is_node = true;
6942
+ if (q->grad || k->grad || v->grad) {
6943
+ // when using this operation (in backwards pass) these grads are set.
6944
+ // we don't want to create (big) grad of our result, so is_node is false.
6945
+ is_node = false;
6655
6946
  }
6656
6947
 
6657
- struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
6658
- *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
6659
- struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
6948
+ // store gradients of q, k and v as continuous tensors concatenated in result.
6949
+ // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
6950
+ // gradq->data = result->data
6951
+ // gradk->data = result->data + nb0*D*N*ne2*ne3
6952
+ // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
6953
+ // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
6954
+ int64_t ne[4] = {D,M+N+M,ne2,ne3};
6660
6955
 
6661
- result->op = GGML_OP_MAP_UNARY;
6956
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6957
+
6958
+ result->op = GGML_OP_FLASH_ATTN_BACK;
6662
6959
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6663
- result->src0 = a;
6664
- result->opt[0] = addr_tensor;
6960
+ result->src0 = q;
6961
+ result->src1 = k;
6962
+ result->opt[0] = v;
6963
+ result->opt[1] = d;
6964
+ result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
6665
6965
 
6666
6966
  return result;
6667
6967
  }
6668
6968
 
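Note: ggml_flash_attn_back packs dq, dk and dv into one result tensor of shape {D, N+2*M, ne2, ne3} (written M+N+M above); the comment block gives the byte offsets of each gradient inside result->data. Worked offsets with the hypothetical sizes D=64, N=32, M=128, ne2=ne3=1 and nb0=sizeof(float):

    // gradq: offset 0 B,                          D*N = 64*32  =  2048 floats
    // gradk: offset nb0*D*N       =  8192 B,      D*M = 64*128 =  8192 floats
    // gradv: offset nb0*D*(N+M)   = 40960 B,      M*D = 128*64 =  8192 floats
    // total: D*(N+2*M) = 64*288 = 18432 floats    -> ne[1] = N+2*M = 288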
6669
- struct ggml_tensor * ggml_map_unary_f32(
6670
- struct ggml_context * ctx,
6671
- struct ggml_tensor * a,
6672
- const ggml_unary_op_f32_t fun) {
6673
- return ggml_map_unary_impl_f32(ctx, a, fun, false);
6674
- }
6675
-
6676
- struct ggml_tensor * ggml_map_unary_inplace_f32(
6677
- struct ggml_context * ctx,
6678
- struct ggml_tensor * a,
6679
- const ggml_unary_op_f32_t fun) {
6680
- return ggml_map_unary_impl_f32(ctx, a, fun, true);
6681
- }
6682
-
6683
- // ggml_map_binary
6969
+ // ggml_win_part
6684
6970
 
6685
- struct ggml_tensor * ggml_map_binary_impl_f32(
6686
- struct ggml_context * ctx,
6687
- struct ggml_tensor * a,
6688
- struct ggml_tensor * b,
6689
- const ggml_binary_op_f32_t fun,
6690
- bool inplace) {
6691
- GGML_ASSERT(ggml_are_same_shape(a, b));
6971
+ struct ggml_tensor * ggml_win_part(
6972
+ struct ggml_context * ctx,
6973
+ struct ggml_tensor * a,
6974
+ int w) {
6975
+ GGML_ASSERT(a->ne[3] == 1);
6976
+ GGML_ASSERT(a->type == GGML_TYPE_F32);
6977
+
6978
+ bool is_node = false;
6979
+
6980
+ if (a->grad) {
6981
+ GGML_ASSERT(false); // TODO: implement backward
6982
+ is_node = true;
6983
+ }
6984
+
6985
+ // padding
6986
+ const int px = (w - a->ne[1]%w)%w;
6987
+ const int py = (w - a->ne[2]%w)%w;
6988
+
6989
+ const int npx = (px + a->ne[1])/w;
6990
+ const int npy = (py + a->ne[2])/w;
6991
+ const int np = npx*npy;
6992
+
6993
+ const int64_t ne[4] = { a->ne[0], w, w, np, };
6994
+
6995
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6996
+
6997
+ ggml_scratch_save(ctx);
6998
+
6999
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
7000
+
7001
+ ((int32_t *) b->data)[0] = npx;
7002
+ ((int32_t *) b->data)[1] = npy;
7003
+ ((int32_t *) b->data)[2] = w;
7004
+
7005
+ ggml_scratch_load(ctx);
7006
+
7007
+ result->op = GGML_OP_WIN_PART;
7008
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7009
+ result->src0 = a;
7010
+ result->src1 = NULL;
7011
+ result->opt[0] = b;
7012
+
7013
+ return result;
7014
+ }
7015
+
7016
+ // ggml_win_unpart
7017
+
7018
+ struct ggml_tensor * ggml_win_unpart(
7019
+ struct ggml_context * ctx,
7020
+ struct ggml_tensor * a,
7021
+ int w0,
7022
+ int h0,
7023
+ int w) {
7024
+ GGML_ASSERT(a->type == GGML_TYPE_F32);
7025
+
7026
+ bool is_node = false;
7027
+
7028
+ if (a->grad) {
7029
+ GGML_ASSERT(false); // TODO: implement backward
7030
+ is_node = true;
7031
+ }
7032
+
7033
+ const int64_t ne[4] = { a->ne[0], w0, h0, 1, };
7034
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
7035
+
7036
+ ggml_scratch_save(ctx);
7037
+
7038
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
7039
+
7040
+ ((int32_t *) b->data)[0] = w;
7041
+
7042
+ ggml_scratch_load(ctx);
7043
+
7044
+ result->op = GGML_OP_WIN_UNPART;
7045
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7046
+ result->src0 = a;
7047
+ result->src1 = NULL;
7048
+ result->opt[0] = b;
7049
+
7050
+ return result;
7051
+ }
7052
+
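Note: ggml_win_part splits a [C, W, H, 1] tensor into non-overlapping w x w windows, first padding W and H up to the next multiple of w; ggml_win_unpart reverses the operation given the original w0 x h0 size. Worked padding arithmetic for a hypothetical w = 14 window over a 64 x 64 feature map:

    // px  = (w - W%w)%w = (14 - 64%14)%14 = (14 - 8)%14 = 6
    // py  = (w - H%w)%w =                                 6
    // npx = (px + W)/w  = (6 + 64)/14    = 5 windows across
    // npy = (py + H)/w  = (6 + 64)/14    = 5 windows down
    // np  = npx*npy     = 25
    // result shape: [C, 14, 14, 25]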
7053
+ // ggml_map_unary
7054
+
7055
+ struct ggml_tensor * ggml_map_unary_impl_f32(
7056
+ struct ggml_context * ctx,
7057
+ struct ggml_tensor * a,
7058
+ const ggml_unary_op_f32_t fun,
7059
+ bool inplace) {
7060
+ bool is_node = false;
7061
+
7062
+ if (!inplace && a->grad) {
7063
+ is_node = true;
7064
+ }
7065
+
7066
+ struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
7067
+ *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
7068
+ struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
7069
+
7070
+ result->op = GGML_OP_MAP_UNARY;
7071
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7072
+ result->src0 = a;
7073
+ result->opt[0] = addr_tensor;
7074
+
7075
+ return result;
7076
+ }
7077
+
7078
+ struct ggml_tensor * ggml_map_unary_f32(
7079
+ struct ggml_context * ctx,
7080
+ struct ggml_tensor * a,
7081
+ const ggml_unary_op_f32_t fun) {
7082
+ return ggml_map_unary_impl_f32(ctx, a, fun, false);
7083
+ }
7084
+
7085
+ struct ggml_tensor * ggml_map_unary_inplace_f32(
7086
+ struct ggml_context * ctx,
7087
+ struct ggml_tensor * a,
7088
+ const ggml_unary_op_f32_t fun) {
7089
+ return ggml_map_unary_impl_f32(ctx, a, fun, true);
7090
+ }
7091
+
7092
+ // ggml_map_binary
7093
+
7094
+ struct ggml_tensor * ggml_map_binary_impl_f32(
7095
+ struct ggml_context * ctx,
7096
+ struct ggml_tensor * a,
7097
+ struct ggml_tensor * b,
7098
+ const ggml_binary_op_f32_t fun,
7099
+ bool inplace) {
7100
+ GGML_ASSERT(ggml_are_same_shape(a, b));
6692
7101
 
6693
7102
  bool is_node = false;
6694
7103
 
@@ -6725,6 +7134,50 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
6725
7134
  return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
6726
7135
  }
6727
7136
 
7137
+ // ggml_cross_entropy_loss
7138
+
7139
+ struct ggml_tensor * ggml_cross_entropy_loss(
7140
+ struct ggml_context * ctx,
7141
+ struct ggml_tensor * a,
7142
+ struct ggml_tensor * b) {
7143
+ GGML_ASSERT(ggml_are_same_shape(a, b));
7144
+ bool is_node = false;
7145
+
7146
+ if (a->grad || b->grad) {
7147
+ is_node = true;
7148
+ }
7149
+
7150
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
7151
+
7152
+ result->op = GGML_OP_CROSS_ENTROPY_LOSS;
7153
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7154
+ result->src0 = a;
7155
+ result->src1 = b;
7156
+
7157
+ return result;
7158
+ }
7159
+
7160
+ // ggml_cross_entropy_loss_back
7161
+
7162
+ struct ggml_tensor * ggml_cross_entropy_loss_back(
7163
+ struct ggml_context * ctx,
7164
+ struct ggml_tensor * a,
7165
+ struct ggml_tensor * b,
7166
+ struct ggml_tensor * c) {
7167
+ GGML_ASSERT(ggml_are_same_shape(a, b));
7168
+ GGML_ASSERT(ggml_is_scalar(c));
7169
+
7170
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
7171
+
7172
+ result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
7173
+ result->grad = NULL;
7174
+ result->src0 = a;
7175
+ result->src1 = b;
7176
+ result->opt[0] = c;
7177
+
7178
+ return result;
7179
+ }
7180
+
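Note: ggml_cross_entropy_loss reduces two same-shaped tensors to a single scalar (a 1-element tensor), and ggml_cross_entropy_loss_back takes the two inputs plus the scalar upstream gradient c and returns an a-shaped gradient. The forward kernel is not part of this hunk, so the following is stated as an assumption: with the conventional softmax cross-entropy definition and a target distribution b whose rows sum to 1, the per-row loss and gradient are

    L(a, b) = -\sum_i b_i \, \log\big(\mathrm{softmax}(a)_i\big),
    \qquad
    \frac{\partial L}{\partial a_i} = c \cdot \big(\mathrm{softmax}(a)_i - b_i\big).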
6728
7181
  ////////////////////////////////////////////////////////////////////////////////
6729
7182
 
6730
7183
  void ggml_set_param(
@@ -7674,7 +8127,7 @@ static void ggml_compute_forward_add_q_f32(
7674
8127
 
7675
8128
  void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
7676
8129
  float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
7677
- void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb0));
8130
+ void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
7678
8131
 
7679
8132
  assert(ne00 % 32 == 0);
7680
8133
 
@@ -8875,6 +9328,99 @@ static void ggml_compute_forward_repeat(
8875
9328
  }
8876
9329
  }
8877
9330
 
9331
+ // ggml_compute_forward_repeat_back
9332
+
9333
+ static void ggml_compute_forward_repeat_back_f32(
9334
+ const struct ggml_compute_params * params,
9335
+ const struct ggml_tensor * src0,
9336
+ struct ggml_tensor * dst) {
9337
+ GGML_ASSERT(params->ith == 0);
9338
+ GGML_ASSERT(ggml_can_repeat(dst, src0));
9339
+
9340
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
9341
+ return;
9342
+ }
9343
+
9344
+ const int64_t ne0 = dst->ne[0];
9345
+ const int64_t ne1 = dst->ne[1];
9346
+ const int64_t ne2 = dst->ne[2];
9347
+ const int64_t ne3 = dst->ne[3];
9348
+
9349
+ const int64_t ne00 = src0->ne[0];
9350
+ const int64_t ne01 = src0->ne[1];
9351
+ const int64_t ne02 = src0->ne[2];
9352
+ const int64_t ne03 = src0->ne[3];
9353
+
9354
+ const size_t nb0 = dst->nb[0];
9355
+ const size_t nb1 = dst->nb[1];
9356
+ const size_t nb2 = dst->nb[2];
9357
+ const size_t nb3 = dst->nb[3];
9358
+
9359
+ const size_t nb00 = src0->nb[0];
9360
+ const size_t nb01 = src0->nb[1];
9361
+ const size_t nb02 = src0->nb[2];
9362
+ const size_t nb03 = src0->nb[3];
9363
+
9364
+ // guaranteed to be an integer due to the check in ggml_can_repeat
9365
+ const int nr0 = (int)(ne00/ne0);
9366
+ const int nr1 = (int)(ne01/ne1);
9367
+ const int nr2 = (int)(ne02/ne2);
9368
+ const int nr3 = (int)(ne03/ne3);
9369
+
9370
+ // TODO: support for transposed / permuted tensors
9371
+ GGML_ASSERT(nb0 == sizeof(float));
9372
+ GGML_ASSERT(nb00 == sizeof(float));
9373
+
9374
+ if (ggml_is_contiguous(dst)) {
9375
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
9376
+ } else {
9377
+ for (int k3 = 0; k3 < ne3; k3++) {
9378
+ for (int k2 = 0; k2 < ne2; k2++) {
9379
+ for (int k1 = 0; k1 < ne1; k1++) {
9380
+ ggml_vec_set_f32(ne0,
9381
+ (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
9382
+ 0);
9383
+ }
9384
+ }
9385
+ }
9386
+ }
9387
+
9388
+ // TODO: maybe this is not optimal?
9389
+ for (int i3 = 0; i3 < nr3; i3++) {
9390
+ for (int k3 = 0; k3 < ne3; k3++) {
9391
+ for (int i2 = 0; i2 < nr2; i2++) {
9392
+ for (int k2 = 0; k2 < ne2; k2++) {
9393
+ for (int i1 = 0; i1 < nr1; i1++) {
9394
+ for (int k1 = 0; k1 < ne1; k1++) {
9395
+ for (int i0 = 0; i0 < nr0; i0++) {
9396
+ ggml_vec_acc_f32(ne0,
9397
+ (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1),
9398
+ (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
9399
+ }
9400
+ }
9401
+ }
9402
+ }
9403
+ }
9404
+ }
9405
+ }
9406
+ }
9407
+
9408
+ static void ggml_compute_forward_repeat_back(
9409
+ const struct ggml_compute_params * params,
9410
+ const struct ggml_tensor * src0,
9411
+ struct ggml_tensor * dst) {
9412
+ switch (src0->type) {
9413
+ case GGML_TYPE_F32:
9414
+ {
9415
+ ggml_compute_forward_repeat_back_f32(params, src0, dst);
9416
+ } break;
9417
+ default:
9418
+ {
9419
+ GGML_ASSERT(false);
9420
+ } break;
9421
+ }
9422
+ }
9423
+
8878
9424
  // ggml_compute_forward_abs
8879
9425
 
8880
9426
  static void ggml_compute_forward_abs_f32(
@@ -9142,8 +9688,65 @@ static void ggml_compute_forward_gelu(
9142
9688
  GGML_ASSERT(false);
9143
9689
  } break;
9144
9690
  }
9691
+ }
9692
+
9693
+ // ggml_compute_forward_gelu_quick
9694
+
9695
+ static void ggml_compute_forward_gelu_quick_f32(
9696
+ const struct ggml_compute_params * params,
9697
+ const struct ggml_tensor * src0,
9698
+ struct ggml_tensor * dst) {
9699
+ GGML_ASSERT(ggml_is_contiguous(src0));
9700
+ GGML_ASSERT(ggml_is_contiguous(dst));
9701
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
9702
+
9703
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
9704
+ return;
9705
+ }
9706
+
9707
+ const int ith = params->ith;
9708
+ const int nth = params->nth;
9709
+
9710
+ const int nc = src0->ne[0];
9711
+ const int nr = ggml_nrows(src0);
9712
+
9713
+ // rows per thread
9714
+ const int dr = (nr + nth - 1)/nth;
9715
+
9716
+ // row range for this thread
9717
+ const int ir0 = dr*ith;
9718
+ const int ir1 = MIN(ir0 + dr, nr);
9719
+
9720
+ for (int i1 = ir0; i1 < ir1; i1++) {
9721
+ ggml_vec_gelu_quick_f32(nc,
9722
+ (float *) ((char *) dst->data + i1*( dst->nb[1])),
9723
+ (float *) ((char *) src0->data + i1*(src0->nb[1])));
9724
+
9725
+ #ifndef NDEBUG
9726
+ for (int k = 0; k < nc; k++) {
9727
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
9728
+ UNUSED(x);
9729
+ assert(!isnan(x));
9730
+ assert(!isinf(x));
9731
+ }
9732
+ #endif
9733
+ }
9734
+ }
9145
9735
 
9146
- //printf("XXXXXXXX gelu\n");
9736
+ static void ggml_compute_forward_gelu_quick(
9737
+ const struct ggml_compute_params * params,
9738
+ const struct ggml_tensor * src0,
9739
+ struct ggml_tensor * dst) {
9740
+ switch (src0->type) {
9741
+ case GGML_TYPE_F32:
9742
+ {
9743
+ ggml_compute_forward_gelu_quick_f32(params, src0, dst);
9744
+ } break;
9745
+ default:
9746
+ {
9747
+ GGML_ASSERT(false);
9748
+ } break;
9749
+ }
9147
9750
  }
9148
9751
 
9149
9752
  // ggml_compute_forward_silu
@@ -10249,9 +10852,179 @@ static void ggml_compute_forward_mul_mat(
10249
10852
  }
10250
10853
  }
10251
10854
 
10252
- // ggml_compute_forward_scale
10855
+ // ggml_compute_forward_out_prod
10253
10856
 
10254
- static void ggml_compute_forward_scale_f32(
10857
+
10858
+ static void ggml_compute_forward_out_prod_f32(
10859
+ const struct ggml_compute_params * params,
10860
+ const struct ggml_tensor * src0,
10861
+ const struct ggml_tensor * src1,
10862
+ struct ggml_tensor * dst) {
10863
+ int64_t t0 = ggml_perf_time_us();
10864
+ UNUSED(t0);
10865
+
10866
+ const int64_t ne00 = src0->ne[0];
10867
+ const int64_t ne01 = src0->ne[1];
10868
+ const int64_t ne02 = src0->ne[2];
10869
+ const int64_t ne03 = src0->ne[3];
10870
+
10871
+ const int64_t ne10 = src1->ne[0];
10872
+ //const int64_t ne11 = src1->ne[1];
10873
+ const int64_t ne12 = src1->ne[2];
10874
+ const int64_t ne13 = src1->ne[3];
10875
+
10876
+ const int64_t ne0 = dst->ne[0];
10877
+ const int64_t ne1 = dst->ne[1];
10878
+ const int64_t ne2 = dst->ne[2];
10879
+ const int64_t ne3 = dst->ne[3];
10880
+
10881
+ const int nb00 = src0->nb[0];
10882
+ const int nb01 = src0->nb[1];
10883
+ const int nb02 = src0->nb[2];
10884
+ const int nb03 = src0->nb[3];
10885
+
10886
+ const int nb10 = src1->nb[0];
10887
+ const int nb11 = src1->nb[1];
10888
+ const int nb12 = src1->nb[2];
10889
+ const int nb13 = src1->nb[3];
10890
+
10891
+ const int nb0 = dst->nb[0];
10892
+ const int nb1 = dst->nb[1];
10893
+ const int nb2 = dst->nb[2];
10894
+ const int nb3 = dst->nb[3];
10895
+
10896
+ const int ith = params->ith;
10897
+ const int nth = params->nth;
10898
+
10899
+ GGML_ASSERT(ne02 == ne12);
10900
+ GGML_ASSERT(ne03 == ne13);
10901
+ GGML_ASSERT(ne2 == ne12);
10902
+ GGML_ASSERT(ne3 == ne13);
10903
+
10904
+ // we don't support permuted src0 or src1
10905
+ GGML_ASSERT(nb00 == sizeof(float));
10906
+
10907
+ // dst cannot be transposed or permuted
10908
+ GGML_ASSERT(nb0 == sizeof(float));
10909
+ // GGML_ASSERT(nb0 <= nb1);
10910
+ // GGML_ASSERT(nb1 <= nb2);
10911
+ // GGML_ASSERT(nb2 <= nb3);
10912
+
10913
+ GGML_ASSERT(ne0 == ne00);
10914
+ GGML_ASSERT(ne1 == ne10);
10915
+ GGML_ASSERT(ne2 == ne02);
10916
+ GGML_ASSERT(ne3 == ne03);
10917
+
10918
+ // nb01 >= nb00 - src0 is not transposed
10919
+ // compute by src0 rows
10920
+
10921
+ // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
10922
+ // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
10923
+
10924
+ if (params->type == GGML_TASK_INIT) {
10925
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
10926
+ return;
10927
+ }
10928
+
10929
+ if (params->type == GGML_TASK_FINALIZE) {
10930
+ return;
10931
+ }
10932
+
10933
+ // parallelize by last three dimensions
10934
+
10935
+ // total rows in dst
10936
+ const int64_t nr = ne1*ne2*ne3;
10937
+
10938
+ // rows per thread
10939
+ const int64_t dr = (nr + nth - 1)/nth;
10940
+
10941
+ // row range for this thread
10942
+ const int64_t ir0 = dr*ith;
10943
+ const int64_t ir1 = MIN(ir0 + dr, nr);
10944
+
10945
+ // dst[:,:,:,:] = 0
10946
+ // for i2,i3:
10947
+ // for i1:
10948
+ // for i01:
10949
+ // for i0:
10950
+ // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
10951
+
10952
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
10953
+ // dst indices
10954
+ const int64_t i3 = ir/(ne2*ne1);
10955
+ const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
10956
+ const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
10957
+
10958
+ const int64_t i02 = i2;
10959
+ const int64_t i03 = i3;
10960
+
10961
+ //const int64_t i10 = i1;
10962
+ const int64_t i12 = i2;
10963
+ const int64_t i13 = i3;
10964
+
10965
+ for (int64_t i01 = 0; i01 < ne01; ++i01) {
10966
+ const int64_t i11 = i01;
10967
+
10968
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
10969
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
10970
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
10971
+
10972
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
10973
+ // for (int64_t i0 = 0; i0 < ne0; ++i0) {
10974
+ // d[i0] += s0[i0] * s1[i1];
10975
+ // }
10976
+ }
10977
+ }
10978
+
10979
+ //int64_t t1 = ggml_perf_time_us();
10980
+ //static int64_t acc = 0;
10981
+ //acc += t1 - t0;
10982
+ //if (t1 - t0 > 10) {
10983
+ // printf("\n");
10984
+ // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
10985
+ // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
10986
+ // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
10987
+ // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
10988
+
10989
+ // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
10990
+ //}
10991
+ }
10992
+
10993
+ static void ggml_compute_forward_out_prod(
10994
+ const struct ggml_compute_params * params,
10995
+ const struct ggml_tensor * src0,
10996
+ const struct ggml_tensor * src1,
10997
+ struct ggml_tensor * dst) {
10998
+ switch (src0->type) {
10999
+ case GGML_TYPE_Q4_0:
11000
+ case GGML_TYPE_Q4_1:
11001
+ case GGML_TYPE_Q5_0:
11002
+ case GGML_TYPE_Q5_1:
11003
+ case GGML_TYPE_Q8_0:
11004
+ case GGML_TYPE_Q8_1:
11005
+ {
11006
+ GGML_ASSERT(false); // todo
11007
+ // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
11008
+ } break;
11009
+ case GGML_TYPE_F16:
11010
+ {
11011
+ GGML_ASSERT(false); // todo
11012
+ // ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst);
11013
+ } break;
11014
+ case GGML_TYPE_F32:
11015
+ {
11016
+ ggml_compute_forward_out_prod_f32(params, src0, src1, dst);
11017
+ } break;
11018
+ default:
11019
+ {
11020
+ GGML_ASSERT(false);
11021
+ } break;
11022
+ }
11023
+ }
11024
+
11025
+ // ggml_compute_forward_scale
11026
+
11027
+ static void ggml_compute_forward_scale_f32(
10255
11028
  const struct ggml_compute_params * params,
10256
11029
  const struct ggml_tensor * src0,
10257
11030
  const struct ggml_tensor * src1,
@@ -10371,7 +11144,7 @@ static void ggml_compute_forward_set_f32(
10371
11144
  const int im2 = (ne12 == 0 ? 0 : ne12-1);
10372
11145
  const int im3 = (ne13 == 0 ? 0 : ne13-1);
10373
11146
 
10374
- GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 < ggml_nbytes(dst));
11147
+ GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
10375
11148
 
10376
11149
  GGML_ASSERT(nb10 == sizeof(float));
10377
11150
 
@@ -10671,7 +11444,11 @@ static void ggml_compute_forward_get_rows_back_f32(
10671
11444
  GGML_ASSERT(ggml_is_contiguous(opt0));
10672
11445
  GGML_ASSERT(ggml_is_contiguous(dst));
10673
11446
 
10674
- ggml_compute_forward_dup_same_cont(params, opt0, dst);
11447
+ // ggml_compute_forward_dup_same_cont(params, opt0, dst);
11448
+
11449
+ if (params->type == GGML_TASK_INIT) {
11450
+ memset(dst->data, 0, ggml_nbytes(dst));
11451
+ }
10675
11452
 
10676
11453
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
10677
11454
  return;
@@ -10815,8 +11592,8 @@ static void ggml_compute_forward_diag_mask_f32(
10815
11592
  const struct ggml_tensor * src1,
10816
11593
  struct ggml_tensor * dst,
10817
11594
  const float value) {
10818
- assert(src1->type == GGML_TYPE_I32);
10819
- assert(ggml_nelements(src1) == 2);
11595
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
11596
+ GGML_ASSERT(ggml_nelements(src1) == 2);
10820
11597
 
10821
11598
  const int ith = params->ith;
10822
11599
  const int nth = params->nth;
@@ -10824,7 +11601,7 @@ static void ggml_compute_forward_diag_mask_f32(
10824
11601
  const int n_past = ((int32_t *) src1->data)[0];
10825
11602
  const bool inplace = (bool)((int32_t *) src1->data)[1];
10826
11603
 
10827
- assert(n_past >= 0);
11604
+ GGML_ASSERT(n_past >= 0);
10828
11605
 
10829
11606
  if (!inplace && (params->type == GGML_TASK_INIT)) {
10830
11607
  // memcpy needs to be synchronized across threads to avoid race conditions.
@@ -10848,8 +11625,8 @@ static void ggml_compute_forward_diag_mask_f32(
10848
11625
  const int nr = src0->ne[1];
10849
11626
  const int nz = n/nr;
10850
11627
 
10851
- assert( dst->nb[0] == sizeof(float));
10852
- assert(src0->nb[0] == sizeof(float));
11628
+ GGML_ASSERT( dst->nb[0] == sizeof(float));
11629
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
10853
11630
 
10854
11631
  for (int k = 0; k < nz; k++) {
10855
11632
  for (int j = ith; j < nr; j += nth) {
@@ -10985,6 +11762,101 @@ static void ggml_compute_forward_soft_max(
10985
11762
  }
10986
11763
  }
10987
11764
 
11765
+ // ggml_compute_forward_soft_max_back
11766
+
11767
+ static void ggml_compute_forward_soft_max_back_f32(
11768
+ const struct ggml_compute_params * params,
11769
+ const struct ggml_tensor * src0,
11770
+ const struct ggml_tensor * src1,
11771
+ struct ggml_tensor * dst) {
11772
+ GGML_ASSERT(ggml_is_contiguous(src0));
11773
+ GGML_ASSERT(ggml_is_contiguous(src1));
11774
+ GGML_ASSERT(ggml_is_contiguous(dst));
11775
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
11776
+ GGML_ASSERT(ggml_are_same_shape(src1, dst));
11777
+
11778
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
11779
+ return;
11780
+ }
11781
+
11782
+ // TODO: handle transposed/permuted matrices
11783
+
11784
+ const int ith = params->ith;
11785
+ const int nth = params->nth;
11786
+
11787
+ const int nc = src0->ne[0];
11788
+ const int nr = ggml_nrows(src0);
11789
+
11790
+ // rows per thread
11791
+ const int dr = (nr + nth - 1)/nth;
11792
+
11793
+ // row range for this thread
11794
+ const int ir0 = dr*ith;
11795
+ const int ir1 = MIN(ir0 + dr, nr);
11796
+
11797
+ for (int i1 = ir0; i1 < ir1; i1++) {
11798
+ float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
11799
+ float *y = (float *)((char *) src1->data + i1*src1->nb[1]);
11800
+ float *dx = (float *)((char *) dst->data + i1*dst->nb[1]);
11801
+
11802
+ #ifndef NDEBUG
11803
+ for (int i = 0; i < nc; ++i) {
11804
+ //printf("p[%d] = %f\n", i, p[i]);
11805
+ assert(!isnan(dy[i]));
11806
+ assert(!isnan(y[i]));
11807
+ }
11808
+ #endif
11809
+ // Jii = yi - yi*yi
11810
+ // Jij = -yi*yj
11811
+ // J = diag(y)-y.T*y
11812
+ // dx = J * dy
11813
+ // dxk = sum_i(Jki * dyi)
11814
+ // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
11815
+ // dxk = sum_i(-yk*yi * dyi) + yk*dyk
11816
+ // dxk = -yk * sum_i(yi * dyi) + yk*dyk
11817
+ // dxk = -yk * dot(y, dy) + yk*dyk
11818
+ // dxk = yk * (- dot(y, dy) + dyk)
11819
+ // dxk = yk * (dyk - dot(y, dy))
11820
+ //
11821
+ // post-order:
11822
+ // dot_y_dy := dot(y, dy)
11823
+ // dx := dy
11824
+ // dx := dx - dot_y_dy
11825
+ // dx := dx * y
11826
+
11827
+ // linear runtime, no additional memory
11828
+ float dot_y_dy = 0;
11829
+ ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
11830
+ ggml_vec_cpy_f32 (nc, dx, dy);
11831
+ ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
11832
+ ggml_vec_mul_f32 (nc, dx, dx, y);
11833
+
11834
+ #ifndef NDEBUG
11835
+ for (int i = 0; i < nc; ++i) {
11836
+ assert(!isnan(dx[i]));
11837
+ assert(!isinf(dx[i]));
11838
+ }
11839
+ #endif
11840
+ }
11841
+ }
11842
+
11843
+ static void ggml_compute_forward_soft_max_back(
11844
+ const struct ggml_compute_params * params,
11845
+ const struct ggml_tensor * src0,
11846
+ const struct ggml_tensor * src1,
11847
+ struct ggml_tensor * dst) {
11848
+ switch (src0->type) {
11849
+ case GGML_TYPE_F32:
11850
+ {
11851
+ ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst);
11852
+ } break;
11853
+ default:
11854
+ {
11855
+ GGML_ASSERT(false);
11856
+ } break;
11857
+ }
11858
+ }
11859
+
10988
11860
  // ggml_compute_forward_alibi
10989
11861
 
10990
11862
  static void ggml_compute_forward_alibi_f32(
@@ -10993,8 +11865,9 @@ static void ggml_compute_forward_alibi_f32(
10993
11865
  const struct ggml_tensor * src1,
10994
11866
  struct ggml_tensor * dst) {
10995
11867
  assert(params->ith == 0);
10996
- assert(src1->type == GGML_TYPE_I32);
10997
- assert(ggml_nelements(src1) == 3);
11868
+
11869
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
11870
+ GGML_ASSERT(ggml_nelements(src1) == 3);
10998
11871
 
10999
11872
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
11000
11873
  return;
@@ -11057,8 +11930,9 @@ static void ggml_compute_forward_alibi_f16(
11057
11930
  const struct ggml_tensor * src1,
11058
11931
  struct ggml_tensor * dst) {
11059
11932
  assert(params->ith == 0);
11060
- assert(src1->type == GGML_TYPE_I32);
11061
- assert(ggml_nelements(src1) == 3);
11933
+
11934
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
11935
+ GGML_ASSERT(ggml_nelements(src1) == 3);
11062
11936
 
11063
11937
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
11064
11938
  return;
@@ -11160,15 +12034,16 @@ static void ggml_compute_forward_clamp_f32(
11160
12034
  const struct ggml_tensor * src1,
11161
12035
  struct ggml_tensor * dst) {
11162
12036
  assert(params->ith == 0);
11163
- assert(src1->type == GGML_TYPE_I32);
11164
- assert(ggml_nelements(src1) == 2);
12037
+
12038
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
12039
+ GGML_ASSERT(ggml_nelements(src1) == 2);
11165
12040
 
11166
12041
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
11167
12042
  return;
11168
12043
  }
11169
12044
 
11170
- const int min = ((float *) src1->data)[0];
11171
- const int max = ((float *) src1->data)[1];
12045
+ const float min = ((float *) src1->data)[0];
12046
+ const float max = ((float *) src1->data)[1];
11172
12047
 
11173
12048
  const int ith = params->ith;
11174
12049
  const int nth = params->nth;
@@ -11726,9 +12601,9 @@ static void ggml_compute_forward_rope_back(
11726
12601
  }
11727
12602
  }
11728
12603
 
11729
- // ggml_compute_forward_conv_1d_1s
12604
+ // ggml_compute_forward_conv_1d_s1_ph
11730
12605
 
11731
- static void ggml_compute_forward_conv_1d_1s_f16_f32(
12606
+ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
11732
12607
  const struct ggml_compute_params * params,
11733
12608
  const struct ggml_tensor * src0,
11734
12609
  const struct ggml_tensor * src1,
@@ -11848,7 +12723,7 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
11848
12723
  }
11849
12724
  }
11850
12725
 
11851
- static void ggml_compute_forward_conv_1d_1s_f32(
12726
+ static void ggml_compute_forward_conv_1d_s1_ph_f32(
11852
12727
  const struct ggml_compute_params * params,
11853
12728
  const struct ggml_tensor * src0,
11854
12729
  const struct ggml_tensor * src1,
@@ -11968,7 +12843,7 @@ static void ggml_compute_forward_conv_1d_1s_f32(
11968
12843
  }
11969
12844
  }
11970
12845
 
11971
- static void ggml_compute_forward_conv_1d_1s(
12846
+ static void ggml_compute_forward_conv_1d_s1_ph(
11972
12847
  const struct ggml_compute_params * params,
11973
12848
  const struct ggml_tensor * src0,
11974
12849
  const struct ggml_tensor * src1,
@@ -11976,11 +12851,11 @@ static void ggml_compute_forward_conv_1d_1s(
11976
12851
  switch (src0->type) {
11977
12852
  case GGML_TYPE_F16:
11978
12853
  {
11979
- ggml_compute_forward_conv_1d_1s_f16_f32(params, src0, src1, dst);
12854
+ ggml_compute_forward_conv_1d_s1_ph_f16_f32(params, src0, src1, dst);
11980
12855
  } break;
11981
12856
  case GGML_TYPE_F32:
11982
12857
  {
11983
- ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
12858
+ ggml_compute_forward_conv_1d_s1_ph_f32(params, src0, src1, dst);
11984
12859
  } break;
11985
12860
  default:
11986
12861
  {
@@ -11989,9 +12864,9 @@ static void ggml_compute_forward_conv_1d_1s(
11989
12864
  }
11990
12865
  }
11991
12866
 
11992
- // ggml_compute_forward_conv_1d_2s
12867
+ // ggml_compute_forward_conv_1d_s2_ph
11993
12868
 
11994
- static void ggml_compute_forward_conv_1d_2s_f16_f32(
12869
+ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
11995
12870
  const struct ggml_compute_params * params,
11996
12871
  const struct ggml_tensor * src0,
11997
12872
  const struct ggml_tensor * src1,
@@ -12111,7 +12986,7 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
12111
12986
  }
12112
12987
  }
12113
12988
 
12114
- static void ggml_compute_forward_conv_1d_2s_f32(
12989
+ static void ggml_compute_forward_conv_1d_s2_ph_f32(
12115
12990
  const struct ggml_compute_params * params,
12116
12991
  const struct ggml_tensor * src0,
12117
12992
  const struct ggml_tensor * src1,
@@ -12231,7 +13106,7 @@ static void ggml_compute_forward_conv_1d_2s_f32(
12231
13106
  }
12232
13107
  }
12233
13108
 
12234
- static void ggml_compute_forward_conv_1d_2s(
13109
+ static void ggml_compute_forward_conv_1d_s2_ph(
12235
13110
  const struct ggml_compute_params * params,
12236
13111
  const struct ggml_tensor * src0,
12237
13112
  const struct ggml_tensor * src1,
@@ -12239,11 +13114,11 @@ static void ggml_compute_forward_conv_1d_2s(
12239
13114
  switch (src0->type) {
12240
13115
  case GGML_TYPE_F16:
12241
13116
  {
12242
- ggml_compute_forward_conv_1d_2s_f16_f32(params, src0, src1, dst);
13117
+ ggml_compute_forward_conv_1d_s2_ph_f16_f32(params, src0, src1, dst);
12243
13118
  } break;
12244
13119
  case GGML_TYPE_F32:
12245
13120
  {
12246
- ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
13121
+ ggml_compute_forward_conv_1d_s2_ph_f32(params, src0, src1, dst);
12247
13122
  } break;
12248
13123
  default:
12249
13124
  {
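The 1-D convolution kernels are only renamed in the hunks shown here (conv_1d_1s/2s become conv_1d_s1_ph/s2_ph); the bodies themselves are untouched by this diff. Reading "s1"/"s2" as stride 1/2 and "ph" as half padding is my interpretation of the new names, not something the diff states. Under that assumption, a minimal sketch of the output length these variants produce for an input of length n and a kernel of length nk:

    // hypothetical helper, assuming "ph" means half padding (nk/2 per side)
    // and the usual formula out = (n + 2*pad - nk)/s + 1
    static int conv_1d_out_len(int n, int nk, int s) {
        const int pad = nk/2;
        return (n + 2*pad - nk)/s + 1;
    }
    // e.g. n = 8, nk = 3: stride 1 gives 8, stride 2 gives 4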
@@ -12252,51 +13127,188 @@ static void ggml_compute_forward_conv_1d_2s(
12252
13127
  }
12253
13128
  }
12254
13129
 
12255
- // ggml_compute_forward_flash_attn
13130
+ // ggml_compute_forward_conv_2d_sk_p0
12256
13131
 
12257
- static void ggml_compute_forward_flash_attn_f32(
13132
+ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
12258
13133
  const struct ggml_compute_params * params,
12259
- const struct ggml_tensor * q,
12260
- const struct ggml_tensor * k,
12261
- const struct ggml_tensor * v,
12262
- const bool masked,
12263
- struct ggml_tensor * dst) {
13134
+ const struct ggml_tensor * src0,
13135
+ const struct ggml_tensor * src1,
13136
+ struct ggml_tensor * dst) {
13137
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
13138
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
13139
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
13140
+
12264
13141
  int64_t t0 = ggml_perf_time_us();
12265
13142
  UNUSED(t0);
12266
13143
 
12267
- const int64_t neq0 = q->ne[0];
12268
- const int64_t neq1 = q->ne[1];
12269
- const int64_t neq2 = q->ne[2];
12270
- const int64_t neq3 = q->ne[3];
13144
+ const int ne00 = src0->ne[0];
13145
+ const int ne01 = src0->ne[1];
13146
+ const int ne02 = src0->ne[2];
13147
+ //const int ne03 = src0->ne[3];
12271
13148
 
12272
- const int64_t nek0 = k->ne[0];
12273
- const int64_t nek1 = k->ne[1];
12274
- //const int64_t nek2 = k->ne[2];
12275
- //const int64_t nek3 = k->ne[3];
13149
+ const int ne10 = src1->ne[0];
13150
+ //const int ne11 = src1->ne[1];
13151
+ const int ne12 = src1->ne[2];
13152
+ //const int ne13 = src1->ne[3];
12276
13153
 
12277
- //const int64_t nev0 = v->ne[0];
12278
- const int64_t nev1 = v->ne[1];
12279
- //const int64_t nev2 = v->ne[2];
12280
- //const int64_t nev3 = v->ne[3];
13154
+ const int ne0 = dst->ne[0];
13155
+ const int ne1 = dst->ne[1];
13156
+ const int ne2 = dst->ne[2];
13157
+ //const int ne3 = dst->ne[3];
13158
+ //const int ne = ne0*ne1*ne2*ne3;
12281
13159
 
12282
- const int64_t ne0 = dst->ne[0];
12283
- const int64_t ne1 = dst->ne[1];
12284
- //const int64_t ne2 = dst->ne[2];
12285
- //const int64_t ne3 = dst->ne[3];
13160
+ const int nb00 = src0->nb[0];
13161
+ //const int nb01 = src0->nb[1];
13162
+ //const int nb02 = src0->nb[2];
13163
+ const int nb03 = src0->nb[3];
12286
13164
 
12287
- const int nbk0 = k->nb[0];
12288
- const int nbk1 = k->nb[1];
12289
- const int nbk2 = k->nb[2];
12290
- const int nbk3 = k->nb[3];
13165
+ const int nb10 = src1->nb[0];
13166
+ //const int nb11 = src1->nb[1];
13167
+ const int nb12 = src1->nb[2];
13168
+ //const int nb13 = src1->nb[3];
12291
13169
 
12292
- const int nbq0 = q->nb[0];
12293
- const int nbq1 = q->nb[1];
12294
- const int nbq2 = q->nb[2];
12295
- const int nbq3 = q->nb[3];
13170
+ //const int nb0 = dst->nb[0];
13171
+ //const int nb1 = dst->nb[1];
13172
+ const int nb2 = dst->nb[2];
13173
+ //const int nb3 = dst->nb[3];
12296
13174
 
12297
- const int nbv0 = v->nb[0];
12298
- const int nbv1 = v->nb[1];
12299
- const int nbv2 = v->nb[2];
13175
+ const int ith = params->ith;
13176
+ const int nth = params->nth;
13177
+
13178
+ const int nk0 = ne00;
13179
+ const int nk1 = ne01;
13180
+
13181
+ // size of the convolution row - the kernel size unrolled across all channels
13182
+ // round-up so it is more suitable for SIMD
13183
+ const int ew0 = ggml_up32(nk0*nk1*ne02);
13184
+
13185
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
13186
+ GGML_ASSERT(nb10 == sizeof(float));
13187
+
13188
+ if (params->type == GGML_TASK_INIT) {
13189
+ // TODO: fix this memset (wsize is overestimated)
13190
+ memset(params->wdata, 0, params->wsize);
13191
+
13192
+ // prepare source data (src1)
13193
+ {
13194
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13195
+
13196
+ for (int i12 = 0; i12 < ne12; i12++) {
13197
+ const float * const src = (float *)((char *) src1->data + i12*nb12);
13198
+ ggml_fp16_t * dst_data = wdata;
13199
+
13200
+ for (int i1 = 0; i1 < ne1; i1++) {
13201
+ for (int i0 = 0; i0 < ne0; i0++) {
13202
+ for (int ik1 = 0; ik1 < nk1; ik1++) {
13203
+ for (int ik0 = 0; ik0 < nk0; ik0++) {
13204
+ dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
13205
+ GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]);
13206
+ }
13207
+ }
13208
+ }
13209
+ }
13210
+ }
13211
+ }
13212
+
13213
+ return;
13214
+ }
13215
+
13216
+ if (params->type == GGML_TASK_FINALIZE) {
13217
+ return;
13218
+ }
13219
+
13220
+ // total patches in dst
13221
+ const int np = ne2;
13222
+
13223
+ // patches per thread
13224
+ const int dp = (np + nth - 1)/nth;
13225
+
13226
+ // patch range for this thread
13227
+ const int ip0 = dp*ith;
13228
+ const int ip1 = MIN(ip0 + dp, np);
13229
+
13230
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
13231
+
13232
+ for (int i2 = ip0; i2 < ip1; i2++) {
13233
+ float * dst_data = (float *)((char *) dst->data + i2*nb2);
13234
+
13235
+ for (int i1 = 0; i1 < ne1; ++i1) {
13236
+ for (int i0 = 0; i0 < ne0; ++i0) {
13237
+ ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
13238
+ (ggml_fp16_t *) ((char *) src0->data + i2*nb03),
13239
+ (ggml_fp16_t *) wdata + (i1*ne0 + i0)*ew0);
13240
+ }
13241
+ }
13242
+ }
13243
+ }
13244
+
13245
+ static void ggml_compute_forward_conv_2d_sk_p0(
13246
+ const struct ggml_compute_params * params,
13247
+ const struct ggml_tensor * src0,
13248
+ const struct ggml_tensor * src1,
13249
+ struct ggml_tensor * dst) {
13250
+ switch (src0->type) {
13251
+ case GGML_TYPE_F16:
13252
+ {
13253
+ ggml_compute_forward_conv_2d_sk_p0_f16_f32(params, src0, src1, dst);
13254
+ } break;
13255
+ case GGML_TYPE_F32:
13256
+ {
13257
+ //ggml_compute_forward_conv_2d_sk_p0_f32(params, src0, src1, dst);
13258
+ GGML_ASSERT(false);
13259
+ } break;
13260
+ default:
13261
+ {
13262
+ GGML_ASSERT(false);
13263
+ } break;
13264
+ }
13265
+ }
13266
+
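ggml_compute_forward_conv_2d_sk_p0_f16_f32 above handles the stride-equals-kernel-size, zero-padding case: the INIT pass unrolls each input patch into one contiguous row of length ew0 (rounded up to a multiple of 32 for SIMD), and the compute pass then reduces every output element to a single ggml_vec_dot_f16 call. A rough standalone sketch of the same im2col-style idea in plain float32 C, with illustrative shapes, no threading and no fp16 conversion:

    // im2col-style 2-D convolution, stride == kernel size, no padding.
    // A sketch only; layouts and names are illustrative, not ggml's.
    #include <stdlib.h>

    static void conv_2d_sk_p0_sketch(
            const float * src,   // input,   [w, h, c]
            const float * ker,   // kernels, [k, k, c, nout]
            float       * dst,   // output,  [w/k, h/k, nout]
            int w, int h, int c, int k, int nout) {
        const int ow  = w/k;
        const int oh  = h/k;
        const int ew0 = k*k*c;   // one unrolled patch per output location

        float * cols = calloc((size_t) ow*oh*ew0, sizeof(float));
        if (!cols) return;

        // 1) unroll every kxk patch of every channel into one row of `cols`
        for (int ic = 0; ic < c; ic++) {
            for (int oy = 0; oy < oh; oy++) {
                for (int ox = 0; ox < ow; ox++) {
                    for (int ky = 0; ky < k; ky++) {
                        for (int kx = 0; kx < k; kx++) {
                            cols[(oy*ow + ox)*ew0 + ic*k*k + ky*k + kx] =
                                src[(ic*h + oy*k + ky)*w + ox*k + kx];
                        }
                    }
                }
            }
        }

        // 2) each output element is then a single dot product of length ew0
        for (int io = 0; io < nout; io++) {
            for (int p = 0; p < ow*oh; p++) {
                float acc = 0.0f;
                for (int e = 0; e < ew0; e++) {
                    acc += ker[io*ew0 + e]*cols[p*ew0 + e];
                }
                dst[io*ow*oh + p] = acc;
            }
        }

        free(cols);
    }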
13267
+ // ggml_compute_forward_flash_attn
13268
+
13269
+ static void ggml_compute_forward_flash_attn_f32(
13270
+ const struct ggml_compute_params * params,
13271
+ const struct ggml_tensor * q,
13272
+ const struct ggml_tensor * k,
13273
+ const struct ggml_tensor * v,
13274
+ const bool masked,
13275
+ struct ggml_tensor * dst) {
13276
+ int64_t t0 = ggml_perf_time_us();
13277
+ UNUSED(t0);
13278
+
13279
+ const int64_t neq0 = q->ne[0];
13280
+ const int64_t neq1 = q->ne[1];
13281
+ const int64_t neq2 = q->ne[2];
13282
+ const int64_t neq3 = q->ne[3];
13283
+
13284
+ const int64_t nek0 = k->ne[0];
13285
+ const int64_t nek1 = k->ne[1];
13286
+ //const int64_t nek2 = k->ne[2];
13287
+ //const int64_t nek3 = k->ne[3];
13288
+
13289
+ //const int64_t nev0 = v->ne[0];
13290
+ const int64_t nev1 = v->ne[1];
13291
+ //const int64_t nev2 = v->ne[2];
13292
+ //const int64_t nev3 = v->ne[3];
13293
+
13294
+ const int64_t ne0 = dst->ne[0];
13295
+ const int64_t ne1 = dst->ne[1];
13296
+ //const int64_t ne2 = dst->ne[2];
13297
+ //const int64_t ne3 = dst->ne[3];
13298
+
13299
+ const int nbk0 = k->nb[0];
13300
+ const int nbk1 = k->nb[1];
13301
+ const int nbk2 = k->nb[2];
13302
+ const int nbk3 = k->nb[3];
13303
+
13304
+ const int nbq0 = q->nb[0];
13305
+ const int nbq1 = q->nb[1];
13306
+ const int nbq2 = q->nb[2];
13307
+ const int nbq3 = q->nb[3];
13308
+
13309
+ const int nbv0 = v->nb[0];
13310
+ const int nbv1 = v->nb[1];
13311
+ const int nbv2 = v->nb[2];
12300
13312
  const int nbv3 = v->nb[3];
12301
13313
 
12302
13314
  const int nb0 = dst->nb[0];
@@ -12938,91 +13950,917 @@ static void ggml_compute_forward_flash_ff(
12938
13950
  }
12939
13951
  }
12940
13952
 
12941
- // ggml_compute_forward_map_unary
13953
+ // ggml_compute_forward_flash_attn_back
12942
13954
 
12943
- static void ggml_compute_forward_map_unary_f32(
13955
+ static void ggml_compute_forward_flash_attn_back_f32(
12944
13956
  const struct ggml_compute_params * params,
12945
- const struct ggml_tensor * src0,
12946
- struct ggml_tensor * dst,
12947
- const ggml_unary_op_f32_t fun) {
12948
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
13957
+ const struct ggml_tensor * q,
13958
+ const struct ggml_tensor * k,
13959
+ const struct ggml_tensor * v,
13960
+ const struct ggml_tensor * d,
13961
+ const bool masked,
13962
+ struct ggml_tensor * dst) {
13963
+ int64_t t0 = ggml_perf_time_us();
13964
+ UNUSED(t0);
12949
13965
 
12950
- if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12951
- return;
12952
- }
13966
+ const int64_t neq0 = q->ne[0];
13967
+ const int64_t neq1 = q->ne[1];
13968
+ const int64_t neq2 = q->ne[2];
13969
+ const int64_t neq3 = q->ne[3];
12953
13970
 
12954
- const int n = ggml_nrows(src0);
12955
- const int nc = src0->ne[0];
13971
+ const int64_t nek0 = k->ne[0];
13972
+ const int64_t nek1 = k->ne[1];
13973
+ //const int64_t nek2 = k->ne[2];
13974
+ //const int64_t nek3 = k->ne[3];
12956
13975
 
12957
- assert( dst->nb[0] == sizeof(float));
12958
- assert(src0->nb[0] == sizeof(float));
13976
+ const int64_t nev0 = v->ne[0];
13977
+ const int64_t nev1 = v->ne[1];
13978
+ //const int64_t nev2 = v->ne[2];
13979
+ //const int64_t nev3 = v->ne[3];
12959
13980
 
12960
- for (int i = 0; i < n; i++) {
12961
- fun(nc,
12962
- (float *) ((char *) dst->data + i*( dst->nb[1])),
12963
- (float *) ((char *) src0->data + i*(src0->nb[1])));
12964
- }
12965
- }
13981
+ const int64_t ned0 = d->ne[0];
13982
+ const int64_t ned1 = d->ne[1];
13983
+ //const int64_t ned2 = d->ne[2];
13984
+ //const int64_t ned3 = d->ne[3];
12966
13985
 
13986
+ const int64_t ne0 = dst->ne[0];
13987
+ const int64_t ne1 = dst->ne[1];
13988
+ const int64_t ne2 = dst->ne[2];
13989
+ const int64_t ne3 = dst->ne[3];
12967
13990
 
12968
- static void ggml_compute_forward_map_unary(
12969
- const struct ggml_compute_params * params,
12970
- const struct ggml_tensor * src0,
12971
- struct ggml_tensor * dst,
12972
- const ggml_unary_op_f32_t fun) {
12973
- switch (src0->type) {
12974
- case GGML_TYPE_F32:
12975
- {
12976
- ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
12977
- } break;
12978
- default:
12979
- {
12980
- GGML_ASSERT(false);
12981
- } break;
12982
- }
12983
- }
13991
+ const int nbk0 = k->nb[0];
13992
+ const int nbk1 = k->nb[1];
13993
+ const int nbk2 = k->nb[2];
13994
+ const int nbk3 = k->nb[3];
12984
13995
 
12985
- // ggml_compute_forward_map_binary
13996
+ const int nbq0 = q->nb[0];
13997
+ const int nbq1 = q->nb[1];
13998
+ const int nbq2 = q->nb[2];
13999
+ const int nbq3 = q->nb[3];
12986
14000
 
12987
- static void ggml_compute_forward_map_binary_f32(
12988
- const struct ggml_compute_params * params,
12989
- const struct ggml_tensor * src0,
12990
- const struct ggml_tensor * src1,
12991
- struct ggml_tensor * dst,
12992
- const ggml_binary_op_f32_t fun) {
12993
- assert(params->ith == 0);
12994
- assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
14001
+ const int nbv0 = v->nb[0];
14002
+ const int nbv1 = v->nb[1];
14003
+ const int nbv2 = v->nb[2];
14004
+ const int nbv3 = v->nb[3];
12995
14005
 
12996
- if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
14006
+ const int nbd0 = d->nb[0];
14007
+ const int nbd1 = d->nb[1];
14008
+ const int nbd2 = d->nb[2];
14009
+ const int nbd3 = d->nb[3];
14010
+
14011
+ const int nb0 = dst->nb[0];
14012
+ const int nb1 = dst->nb[1];
14013
+ const int nb2 = dst->nb[2];
14014
+ const int nb3 = dst->nb[3];
14015
+
14016
+ const int ith = params->ith;
14017
+ const int nth = params->nth;
14018
+
14019
+ const int64_t D = neq0;
14020
+ const int64_t N = neq1;
14021
+ const int64_t P = nek1 - N;
14022
+ const int64_t M = P + N;
14023
+
14024
+ const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
14025
+ const int mxDM = MAX(D, Mup);
14026
+
14027
+ // GGML_ASSERT(ne0 == D);
14028
+ // GGML_ASSERT(ne1 == N);
14029
+ GGML_ASSERT(P >= 0);
14030
+
14031
+ GGML_ASSERT(nbq0 == sizeof(float));
14032
+ GGML_ASSERT(nbk0 == sizeof(float));
14033
+ GGML_ASSERT(nbv0 == sizeof(float));
14034
+
14035
+ GGML_ASSERT(neq0 == D);
14036
+ GGML_ASSERT(nek0 == D);
14037
+ GGML_ASSERT(nev1 == D);
14038
+ GGML_ASSERT(ned0 == D);
14039
+
14040
+ GGML_ASSERT(neq1 == N);
14041
+ GGML_ASSERT(nek1 == N + P);
14042
+ GGML_ASSERT(nev1 == D);
14043
+ GGML_ASSERT(ned1 == N);
14044
+
14045
+ // dst cannot be transposed or permuted
14046
+ GGML_ASSERT(nb0 == sizeof(float));
14047
+ GGML_ASSERT(nb0 <= nb1);
14048
+ GGML_ASSERT(nb1 <= nb2);
14049
+ GGML_ASSERT(nb2 <= nb3);
14050
+
14051
+ if (params->type == GGML_TASK_INIT) {
14052
+ if (ith == 0) {
14053
+ memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
14054
+ }
14055
+ return;
14056
+ }
14057
+
14058
+ if (params->type == GGML_TASK_FINALIZE) {
12997
14059
  return;
12998
14060
  }
12999
14061
 
13000
- const int n = ggml_nrows(src0);
13001
- const int nc = src0->ne[0];
14062
+ // parallelize by q rows using ggml_vec_dot_f32
14063
+
14064
+ // total rows in q
14065
+ const int nr = neq2*neq3;
14066
+
14067
+ // rows per thread
14068
+ const int dr = (nr + nth - 1)/nth;
14069
+
14070
+ // row range for this thread
14071
+ const int ir0 = dr*ith;
14072
+ const int ir1 = MIN(ir0 + dr, nr);
14073
+
14074
+ const float scale = 1.0f/sqrtf(D);
14075
+
14076
+ //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
14077
+
14078
+ for (int ir = ir0; ir < ir1; ++ir) {
14079
+ // q indices
14080
+ const int iq3 = ir/(neq2);
14081
+ const int iq2 = ir - iq3*neq2;
14082
+ for ( int iq1 = 0; iq1 < neq1; ++iq1) {
14083
+
14084
+
14085
+ // not sure about CACHE_LINE_SIZE_F32..
14086
+ // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
14087
+ float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
14088
+ float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
14089
+
14090
+ for (int i = M; i < Mup; ++i) {
14091
+ S[i] = -INFINITY;
14092
+ }
14093
+
14094
+ for (int64_t ic = 0; ic < nek1; ++ic) {
14095
+ // k indices
14096
+ const int ik3 = iq3;
14097
+ const int ik2 = iq2;
14098
+ const int ik1 = ic;
14099
+
14100
+ // S indices
14101
+ const int i1 = ik1;
14102
+
14103
+ ggml_vec_dot_f32(neq0,
14104
+ S + i1,
14105
+ (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
14106
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
14107
+ }
14108
+
14109
+ // scale
14110
+ ggml_vec_scale_f32(nek1, S, scale);
14111
+
14112
+ if (masked) {
14113
+ for (int64_t i = P; i < M; i++) {
14114
+ if (i > P + iq1) {
14115
+ S[i] = -INFINITY;
14116
+ }
14117
+ }
14118
+ }
14119
+
14120
+ // softmax
14121
+ {
14122
+ float max = -INFINITY;
14123
+ ggml_vec_max_f32(M, &max, S);
14124
+
14125
+ ggml_float sum = 0.0;
14126
+ {
14127
+ #ifdef GGML_SOFT_MAX_ACCELERATE
14128
+ max = -max;
14129
+ vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
14130
+ vvexpf(SM, SM, &Mup);
14131
+ ggml_vec_sum_f32(Mup, &sum, SM);
14132
+ #else
14133
+ uint16_t scvt[GGML_SOFT_MAX_UNROLL];
14134
+ ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
14135
+
14136
+ for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
14137
+ float * SR = S + i;
14138
+ float * SW = SM + i;
14139
+
14140
+ for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
14141
+ if (SR[j] == -INFINITY) {
14142
+ SW[j] = 0.0f;
14143
+ } else {
14144
+ ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
14145
+ memcpy(&scvt[j], &s, sizeof(uint16_t));
14146
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
14147
+ sump[j] += (ggml_float)val;
14148
+ SW[j] = val;
14149
+ }
14150
+ }
14151
+ }
14152
+
14153
+ for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
14154
+ sum += sump[i];
14155
+ }
14156
+ #endif
14157
+ }
14158
+
14159
+ assert(sum > 0.0);
14160
+
14161
+ sum = 1.0/sum;
14162
+ ggml_vec_scale_f32(M, SM, sum);
14163
+
14164
+ }
14165
+
14166
+ // step-by-step explanation
14167
+ {
14168
+ // forward-process shape grads from backward process
14169
+ // parallel_for iq2,iq3:
14170
+ // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur]
14171
+ // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
14172
+ // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur]
14173
+ // for iq1:
14174
+ // kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
14175
+ // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
14176
+ // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
14177
+ // S0 = -Inf [D,1,1,1]
14178
+ // ~S1[i] = dot(kcur[:D,i], qcur)
14179
+ // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
14180
+ // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
14181
+ // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
14182
+ // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
14183
+ // ~S5[i] = dot(vcur[:,i], S4)
14184
+ // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
14185
+ // ~dst[i,iq1,iq2,iq3] = S5[i] ^
14186
+ // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
14187
+ // dst backward-/ grad[dst] = d
14188
+ //
14189
+ // output gradients with their dependencies:
14190
+ //
14191
+ // grad[kcur] = grad[S1].T @ qcur
14192
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
14193
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
14194
+ // grad[S4] = grad[S5] @ vcur
14195
+ // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
14196
+ // grad[qcur] = grad[S1] @ kcur
14197
+ // grad[vcur] = grad[S5].T @ S4
14198
+ // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
14199
+ //
14200
+ // in post-order:
14201
+ //
14202
+ // S1 = qcur @ kcur.T
14203
+ // S2 = S1 * scale
14204
+ // S3 = diag_mask_inf(S2, P)
14205
+ // S4 = softmax(S3)
14206
+ // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
14207
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
14208
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
14209
+ // grad[qcur] = grad[S1] @ kcur
14210
+ // grad[kcur] = grad[S1].T @ qcur
14211
+ // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
14212
+ //
14213
+ // using less variables (SM=S4):
14214
+ //
14215
+ // S = diag_mask_inf(qcur @ kcur.T * scale, P)
14216
+ // SM = softmax(S)
14217
+ // S = d[:D,iq1,iq2,iq3] @ vcur
14218
+ // dot_SM_gradSM = dot(SM, S)
14219
+ // S = SM * (S - dot(SM, S))
14220
+ // S = diag_mask_zero(S, P) * scale
14221
+ //
14222
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
14223
+ // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
14224
+ // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
14225
+ }
14226
+
14227
+ // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
14228
+ // S = d[:D,iq1,iq2,iq3] @ vcur
14229
+ // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
14230
+ ggml_vec_set_f32(M, S, 0);
14231
+ for (int64_t ic = 0; ic < D; ++ic) {
14232
+ // dst indices
14233
+ const int i1 = iq1;
14234
+ const int i2 = iq2;
14235
+ const int i3 = iq3;
14236
+
14237
+ ggml_vec_mad_f32(M,
14238
+ S,
14239
+ (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
14240
+ *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
14241
+ }
14242
+
14243
+ // S = SM * (S - dot(SM, S))
14244
+ float dot_SM_gradSM = 0;
14245
+ ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
14246
+ ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
14247
+ ggml_vec_mul_f32 (M, S, S, SM);
14248
+
14249
+ // S = diag_mask_zero(S, P) * scale
14250
+ if (masked) {
14251
+ // for (int64_t i = P + iq1 + 1; i < M; i++) {
14252
+ // S[i] = 0;
14253
+ // }
14254
+ for (int64_t i = P; i < M; i++) {
14255
+ if (i > P + iq1) {
14256
+ S[i] = 0;
14257
+ }
14258
+ }
14259
+ }
14260
+ ggml_vec_scale_f32(M, S, scale);
14261
+
14262
+ void * grad_q = (char *) dst->data;
14263
+ void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
14264
+ void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
14265
+
14266
+ const size_t nbgq1 = nb0*neq0;
14267
+ const size_t nbgq2 = nb0*neq0*neq1;
14268
+ const size_t nbgq3 = nb0*neq0*neq1*neq2;
14269
+
14270
+ const size_t nbgk1 = nb0*nek0;
14271
+ const size_t nbgk2 = nb0*nek0*nek1;
14272
+ const size_t nbgk3 = nb0*nek0*nek1*neq2;
14273
+
14274
+ const size_t nbgv1 = nb0*nev0;
14275
+ const size_t nbgv2 = nb0*nev0*nev1;
14276
+ const size_t nbgv3 = nb0*nev0*nev1*neq2;
14277
+
14278
+ // S shape [M,1]
14279
+ // SM shape [M,1]
14280
+ // kcur shape [D,M]
14281
+ // qcur shape [D,1]
14282
+ // vcur shape [M,D]
14283
+ //
14284
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
14285
+ // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
14286
+ // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
14287
+ //
14288
+ //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
14289
+ //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
14290
+ for (int64_t ic = 0; ic < M; ++ic) {
14291
+ // dst indices
14292
+ const int i1 = iq1;
14293
+ const int i2 = iq2;
14294
+ const int i3 = iq3;
14295
+
14296
+ ggml_vec_mad_f32(D,
14297
+ (float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
14298
+ (float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)),
14299
+ S[ic]);
14300
+ }
14301
+
14302
+ // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
14303
+ // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
14304
+ // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
14305
+ for (int64_t ic = 0; ic < M; ++ic) {
14306
+ // dst indices
14307
+ const int i1 = iq1;
14308
+ const int i2 = iq2;
14309
+ const int i3 = iq3;
14310
+
14311
+ // ggml_vec_set_f32(D,
14312
+ // (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
14313
+ // 0);
14314
+ ggml_vec_mad_f32(D,
14315
+ (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
14316
+ (float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)),
14317
+ S[ic]);
14318
+ }
14319
+
14320
+ // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
14321
+ // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
14322
+ // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M]
14323
+ for (int64_t ic = 0; ic < D; ++ic) {
14324
+ // dst indices
14325
+ const int i1 = iq1;
14326
+ const int i2 = iq2;
14327
+ const int i3 = iq3;
14328
+
14329
+ // ggml_vec_set_f32(M,
14330
+ // (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
14331
+ // 0);
14332
+ ggml_vec_mad_f32(M,
14333
+ (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
14334
+ SM,
14335
+ *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
14336
+ }
14337
+ }
14338
+ }
14339
+ }
14340
+
14341
+ static void ggml_compute_forward_flash_attn_back(
14342
+ const struct ggml_compute_params * params,
14343
+ const struct ggml_tensor * q,
14344
+ const struct ggml_tensor * k,
14345
+ const struct ggml_tensor * v,
14346
+ const struct ggml_tensor * d,
14347
+ const bool masked,
14348
+ struct ggml_tensor * dst) {
14349
+ switch (q->type) {
14350
+ case GGML_TYPE_F32:
14351
+ {
14352
+ ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst);
14353
+ } break;
14354
+ default:
14355
+ {
14356
+ GGML_ASSERT(false);
14357
+ } break;
14358
+ }
14359
+ }
14360
+
14361
+ // ggml_compute_forward_win_part
14362
+
14363
+ static void ggml_compute_forward_win_part_f32(
14364
+ const struct ggml_compute_params * params,
14365
+ const struct ggml_tensor * src0,
14366
+ const struct ggml_tensor * opt0,
14367
+ struct ggml_tensor * dst) {
14368
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
14369
+ return;
14370
+ }
14371
+
14372
+ const int64_t ne00 = src0->ne[0]; UNUSED(ne00);
14373
+ const int64_t ne01 = src0->ne[1];
14374
+ const int64_t ne02 = src0->ne[2];
14375
+ const int64_t ne03 = src0->ne[3]; UNUSED(ne03);
14376
+
14377
+ const int64_t ne0 = dst->ne[0];
14378
+ const int64_t ne1 = dst->ne[1];
14379
+ const int64_t ne2 = dst->ne[2];
14380
+ const int64_t ne3 = dst->ne[3]; UNUSED(ne3);
14381
+
14382
+ const int32_t nep0 = ((const int32_t *)(opt0->data))[0];
14383
+ const int32_t nep1 = ((const int32_t *)(opt0->data))[1];
14384
+ const int32_t w = ((const int32_t *)(opt0->data))[2];
14385
+
14386
+ assert(ne00 == ne0);
14387
+ assert(ne3 == nep0*nep1);
14388
+
14389
+ // TODO: optimize / multi-thread
14390
+ for (int py = 0; py < nep1; ++py) {
14391
+ for (int px = 0; px < nep0; ++px) {
14392
+ const int64_t i3 = py*nep0 + px;
14393
+ for (int64_t i2 = 0; i2 < ne2; ++i2) {
14394
+ for (int64_t i1 = 0; i1 < ne1; ++i1) {
14395
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
14396
+ const int64_t i02 = py*w + i2;
14397
+ const int64_t i01 = px*w + i1;
14398
+ const int64_t i00 = i0;
14399
+
14400
+ const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0;
14401
+ const int64_t j = i02*ne01*ne00 + i01*ne00 + i00;
14402
+
14403
+ if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
14404
+ ((float *) dst->data)[i] = 0.0f;
14405
+ } else {
14406
+ ((float *) dst->data)[i] = ((float *) src0->data)[j];
14407
+ }
14408
+ }
14409
+ }
14410
+ }
14411
+ }
14412
+ }
14413
+ }
14414
+
14415
+ static void ggml_compute_forward_win_part(
14416
+ const struct ggml_compute_params * params,
14417
+ const struct ggml_tensor * src0,
14418
+ const struct ggml_tensor * opt0,
14419
+ struct ggml_tensor * dst) {
14420
+ switch (src0->type) {
14421
+ case GGML_TYPE_F32:
14422
+ {
14423
+ ggml_compute_forward_win_part_f32(params, src0, opt0, dst);
14424
+ } break;
14425
+ default:
14426
+ {
14427
+ GGML_ASSERT(false);
14428
+ } break;
14429
+ }
14430
+ }
14431
+
14432
+ // ggml_compute_forward_win_unpart
14433
+
14434
+ static void ggml_compute_forward_win_unpart_f32(
14435
+ const struct ggml_compute_params * params,
14436
+ const struct ggml_tensor * src0,
14437
+ const struct ggml_tensor * opt0,
14438
+ struct ggml_tensor * dst) {
14439
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
14440
+ return;
14441
+ }
14442
+
14443
+ const int64_t ne00 = src0->ne[0];
14444
+ const int64_t ne01 = src0->ne[1];
14445
+ const int64_t ne02 = src0->ne[2];
14446
+ //const int64_t ne03 = src0->ne[3];
14447
+
14448
+ const int64_t ne0 = dst->ne[0];
14449
+ const int64_t ne1 = dst->ne[1];
14450
+ const int64_t ne2 = dst->ne[2];
14451
+
14452
+ const int32_t w = ((const int32_t *)(opt0->data))[0];
14453
+
14454
+ // padding
14455
+ const int px = (w - ne1%w)%w;
14456
+ //const int py = (w - ne2%w)%w;
14457
+
14458
+ const int npx = (px + ne1)/w;
14459
+ //const int npy = (py + ne2)/w;
14460
+
14461
+ assert(ne0 == ne00);
14462
+
14463
+ // TODO: optimize / multi-thread
14464
+ for (int64_t i2 = 0; i2 < ne2; ++i2) {
14465
+ for (int64_t i1 = 0; i1 < ne1; ++i1) {
14466
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
14467
+ const int ip2 = i2/w;
14468
+ const int ip1 = i1/w;
14469
+
14470
+ const int64_t i02 = i2%w;
14471
+ const int64_t i01 = i1%w;
14472
+ const int64_t i00 = i0;
14473
+
14474
+ const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
14475
+ const int64_t j = i2*ne1*ne0 + i1*ne0 + i0;
14476
+
14477
+ ((float *) dst->data)[j] = ((float *) src0->data)[i];
14478
+ }
14479
+ }
14480
+ }
14481
+ }
14482
+
14483
+ static void ggml_compute_forward_win_unpart(
14484
+ const struct ggml_compute_params * params,
14485
+ const struct ggml_tensor * src0,
14486
+ const struct ggml_tensor * opt0,
14487
+ struct ggml_tensor * dst) {
14488
+ switch (src0->type) {
14489
+ case GGML_TYPE_F32:
14490
+ {
14491
+ ggml_compute_forward_win_unpart_f32(params, src0, opt0, dst);
14492
+ } break;
14493
+ default:
14494
+ {
14495
+ GGML_ASSERT(false);
14496
+ } break;
14497
+ }
14498
+ }
14499
+
14500
+ // ggml_compute_forward_map_unary
14501
+
14502
+ static void ggml_compute_forward_map_unary_f32(
14503
+ const struct ggml_compute_params * params,
14504
+ const struct ggml_tensor * src0,
14505
+ struct ggml_tensor * dst,
14506
+ const ggml_unary_op_f32_t fun) {
14507
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
14508
+
14509
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
14510
+ return;
14511
+ }
14512
+
14513
+ const int n = ggml_nrows(src0);
14514
+ const int nc = src0->ne[0];
14515
+
14516
+ assert( dst->nb[0] == sizeof(float));
14517
+ assert(src0->nb[0] == sizeof(float));
14518
+
14519
+ for (int i = 0; i < n; i++) {
14520
+ fun(nc,
14521
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
14522
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
14523
+ }
14524
+ }
14525
+
14526
+
14527
+ static void ggml_compute_forward_map_unary(
14528
+ const struct ggml_compute_params * params,
14529
+ const struct ggml_tensor * src0,
14530
+ struct ggml_tensor * dst,
14531
+ const ggml_unary_op_f32_t fun) {
14532
+ switch (src0->type) {
14533
+ case GGML_TYPE_F32:
14534
+ {
14535
+ ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
14536
+ } break;
14537
+ default:
14538
+ {
14539
+ GGML_ASSERT(false);
14540
+ } break;
14541
+ }
14542
+ }
14543
+
14544
+ // ggml_compute_forward_map_binary
14545
+
14546
+ static void ggml_compute_forward_map_binary_f32(
14547
+ const struct ggml_compute_params * params,
14548
+ const struct ggml_tensor * src0,
14549
+ const struct ggml_tensor * src1,
14550
+ struct ggml_tensor * dst,
14551
+ const ggml_binary_op_f32_t fun) {
14552
+ assert(params->ith == 0);
14553
+ assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
14554
+
14555
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
14556
+ return;
14557
+ }
14558
+
14559
+ const int n = ggml_nrows(src0);
14560
+ const int nc = src0->ne[0];
14561
+
14562
+ assert( dst->nb[0] == sizeof(float));
14563
+ assert(src0->nb[0] == sizeof(float));
14564
+ assert(src1->nb[0] == sizeof(float));
14565
+
14566
+ for (int i = 0; i < n; i++) {
14567
+ fun(nc,
14568
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
14569
+ (float *) ((char *) src0->data + i*(src0->nb[1])),
14570
+ (float *) ((char *) src1->data + i*(src1->nb[1])));
14571
+ }
14572
+ }
14573
+
14574
+
14575
+ static void ggml_compute_forward_map_binary(
14576
+ const struct ggml_compute_params * params,
14577
+ const struct ggml_tensor * src0,
14578
+ const struct ggml_tensor * src1,
14579
+ struct ggml_tensor * dst,
14580
+ const ggml_binary_op_f32_t fun) {
14581
+ switch (src0->type) {
14582
+ case GGML_TYPE_F32:
14583
+ {
14584
+ ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
14585
+ } break;
14586
+ default:
14587
+ {
14588
+ GGML_ASSERT(false);
14589
+ } break;
14590
+ }
14591
+ }
14592
+
14593
+ // ggml_compute_forward_cross_entropy_loss
14594
+
14595
+ static void ggml_compute_forward_cross_entropy_loss_f32(
14596
+ const struct ggml_compute_params * params,
14597
+ const struct ggml_tensor * src0,
14598
+ const struct ggml_tensor * src1,
14599
+ struct ggml_tensor * dst) {
14600
+ GGML_ASSERT(ggml_is_contiguous(src0));
14601
+ GGML_ASSERT(ggml_is_contiguous(src1));
14602
+ GGML_ASSERT(ggml_is_scalar(dst));
14603
+ GGML_ASSERT(ggml_are_same_shape(src0, src1));
14604
+
14605
+ const int ith = params->ith;
14606
+ const int nth = params->nth;
14607
+
14608
+ float * sums = (float *) params->wdata;
14609
+
14610
+ // TODO: handle transposed/permuted matrices
14611
+ const int nc = src0->ne[0];
14612
+ const int nr = ggml_nrows(src0);
14613
+
14614
+ if (params->type == GGML_TASK_INIT) {
14615
+ if (ith == 0) {
14616
+ memset(sums, 0, sizeof(float) * (nth + nth * nc));
14617
+ }
14618
+ return;
14619
+ }
14620
+
14621
+ if (params->type == GGML_TASK_FINALIZE) {
14622
+ if (ith == 0) {
14623
+ float * dp = (float *) dst->data;
14624
+ ggml_vec_sum_f32(nth, dp, sums);
14625
+ dp[0] *= -1.0f;
14626
+ }
14627
+ return;
14628
+ }
14629
+
14630
+ const double eps = 1e-9;
14631
+
14632
+ // rows per thread
14633
+ const int dr = (nr + nth - 1)/nth;
14634
+
14635
+ // row range for this thread
14636
+ const int ir0 = dr*ith;
14637
+ const int ir1 = MIN(ir0 + dr, nr);
14638
+
14639
+ for (int i1 = ir0; i1 < ir1; i1++) {
14640
+ float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
14641
+ float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
14642
+ float * st = (float *) params->wdata + nth + ith*nc;
14643
+
14644
+ #ifndef NDEBUG
14645
+ for (int i = 0; i < nc; ++i) {
14646
+ //printf("p[%d] = %f\n", i, p[i]);
14647
+ assert(!isnan(s0[i]));
14648
+ assert(!isnan(s1[i]));
14649
+ }
14650
+ #endif
14651
+ // soft_max
14652
+ ggml_float sum = 0.0;
14653
+ {
14654
+ float max = -INFINITY;
14655
+ ggml_vec_max_f32(nc, &max, s0);
14656
+
14657
+ uint16_t scvt;
14658
+ for (int i = 0; i < nc; i++) {
14659
+ if (s0[i] == -INFINITY) {
14660
+ st[i] = 0.0f;
14661
+ } else {
14662
+ // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
14663
+ ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
14664
+ memcpy(&scvt, &s, sizeof(scvt));
14665
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
14666
+ sum += (ggml_float)val;
14667
+ st[i] = val;
14668
+ }
14669
+ }
14670
+
14671
+ assert(sum > 0.0);
14672
+ // sum = 1.0/sum;
14673
+ }
14674
+ // avoid log(0) by rescaling from [0..1] to [eps..1]
14675
+ sum = (1.0 - eps) / sum;
14676
+ ggml_vec_scale_f32(nc, st, sum);
14677
+ ggml_vec_add1_f32(nc, st, st, eps);
14678
+ ggml_vec_log_f32(nc, st, st);
14679
+ ggml_vec_mul_f32(nc, st, st, s1);
14680
+
14681
+ ggml_vec_sum_f32(nc, sums + ith, st);
14682
+
14683
+ #ifndef NDEBUG
14684
+ for (int i = 0; i < nc; ++i) {
14685
+ assert(!isnan(st[i]));
14686
+ assert(!isinf(st[i]));
14687
+ }
14688
+ #endif
14689
+ }
14690
+
14691
+ }
14692
+
14693
+ static void ggml_compute_forward_cross_entropy_loss(
14694
+ const struct ggml_compute_params * params,
14695
+ const struct ggml_tensor * src0,
14696
+ const struct ggml_tensor * src1,
14697
+ struct ggml_tensor * dst) {
14698
+ switch (src0->type) {
14699
+ case GGML_TYPE_F32:
14700
+ {
14701
+ ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst);
14702
+ } break;
14703
+ default:
14704
+ {
14705
+ GGML_ASSERT(false);
14706
+ } break;
14707
+ }
14708
+ }
14709
+
14710
+ // ggml_compute_forward_cross_entropy_loss_back
14711
+
14712
+ static void ggml_compute_forward_cross_entropy_loss_back_f32(
14713
+ const struct ggml_compute_params * params,
14714
+ const struct ggml_tensor * src0,
14715
+ const struct ggml_tensor * src1,
14716
+ const struct ggml_tensor * opt0,
14717
+ struct ggml_tensor * dst) {
14718
+ GGML_ASSERT(ggml_is_contiguous(dst));
14719
+ GGML_ASSERT(ggml_is_contiguous(src0));
14720
+ GGML_ASSERT(ggml_is_contiguous(src1));
14721
+ GGML_ASSERT(ggml_is_contiguous(opt0));
14722
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
14723
+
14724
+ const int64_t ith = params->ith;
14725
+ const int64_t nth = params->nth;
14726
+
14727
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
14728
+ return;
14729
+ }
14730
+
14731
+ const float eps = 1e-9f;
14732
+
14733
+ // TODO: handle transposed/permuted matrices
14734
+ const int64_t nc = src0->ne[0];
14735
+ const int64_t nr = ggml_nrows(src0);
14736
+
14737
+ // rows per thread
14738
+ const int64_t dr = (nr + nth - 1)/nth;
14739
+
14740
+ // row range for this thread
14741
+ const int64_t ir0 = dr*ith;
14742
+ const int64_t ir1 = MIN(ir0 + dr, nr);
14743
+
14744
+ float * d = (float *) opt0->data;
14745
+
14746
+ for (int64_t i1 = ir0; i1 < ir1; i1++) {
14747
+ float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
14748
+ float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
14749
+ float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
14750
+ float * sm = (float *) params->wdata + ith*nc;
14751
+
14752
+ #ifndef NDEBUG
14753
+ for (int i = 0; i < nc; ++i) {
14754
+ //printf("p[%d] = %f\n", i, p[i]);
14755
+ assert(!isnan(s0[i]));
14756
+ assert(!isnan(s1[i]));
14757
+ }
14758
+ #endif
14759
+ // step by step explanation:
14760
+ {
14761
+ //float * sums = (float *) params->wdata;
14762
+
14763
+ // forward pass with annotated gradients from backward pass
14764
+ // (built by going in reverse operation order, adding to gradients of current operation args)
14765
+ // st0 = exp(s0-max(s0)) grad[st0] = grad[st1]*(1.0 - eps)/sum
14766
+ // from softmax_back: grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
14767
+ // ggml_vec_scale_f32(nc, st, sum); // st1 = st0*/sum = softmax(s0) grad[st1] = grad[st2]*(1.0 - eps)
14768
+ // ggml_vec_scale_f32(nc, st, (1.0f - eps)); // st2 = st1*(1.0 - eps) grad[st2] = grad[st3]
14769
+ // ggml_vec_add1_f32(nc, st, st, eps); // st3 = st2 + eps grad[st3] = grad[st4]/st3
14770
+ // ggml_vec_log_f32(nc, st, st); // st4 = log(st3) grad[st4] = grad[st5] * s1
14771
+ // ggml_vec_mul_f32(nc, st, st, s1); // st5 = st4 * s1 grad[st5] = grad[sums[ith]]
14772
+ // ggml_vec_sum_f32(nc, sums + ith, st); // sums[ith] = st5 grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel]
14773
+
14774
+ // substitute into grad[st1], because we can reuse softmax_back from this point on
14775
+ // grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps))
14776
+ // postorder:
14777
+ // grad[st1] := softmax(s0)
14778
+ // grad[st1] := grad[st1]*(1.0 - eps)
14779
+ // grad[st1] := grad[st1] + eps
14780
+ // grad[st1] := s1 / grad[st1]
14781
+ // grad[st1] := grad[st1]*(1.0-eps)*-grad[cel]
14782
+
14783
+ // src0 gradients by going through softmax_back
14784
+ // grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
14785
+ // from softmax_back:
14786
+ // dxk = yk * (dyk - dot(y, dy))
14787
+ // dot_y_dy := dot(y, dy)
14788
+ // dx := dy
14789
+ // dx := dx - dot_y_dy
14790
+ // dx := dx * y
14791
+ // postorder:
14792
+ // dot_st1_dst1 := dot(st1, grad[st1])
14793
+ // grad[s0] := grad[st1]
14794
+ // grad[s0] := grad[s0] - dot_st1_dst1
14795
+ // grad[s0] := grad[s0] * st1
14796
+
14797
+ // prepend postorder from grad[st1] directly using grad[s0] as memory location, as we will grad[s0] := grad[st1]
14798
+ // sm := softmax(s0)
14799
+ // grad[s0] := sm*(1.0 - eps)
14800
+ // grad[s0] := grad[s0] + eps
14801
+ // grad[s0] := s1 / grad[s0]
14802
+ // grad[s0] := grad[s0]*(1.0-eps)*-grad[cel]
14803
+ // dot_st1_dst1 := dot(sm, grad[s0])
14804
+ // grad[s0] := grad[s0] - dot_st1_dst1
14805
+ // grad[s0] := grad[s0] * sm
14806
+ }
14807
+
14808
+ // soft_max
14809
+ ggml_float sum = 0.0;
14810
+ {
14811
+ float max = -INFINITY;
14812
+ ggml_vec_max_f32(nc, &max, s0);
14813
+
14814
+ uint16_t scvt;
14815
+ for (int i = 0; i < nc; i++) {
14816
+ if (s0[i] == -INFINITY) {
14817
+ sm[i] = 0.0f;
14818
+ } else {
14819
+ // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
14820
+ ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
14821
+ memcpy(&scvt, &s, sizeof(scvt));
14822
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
14823
+ sum += (ggml_float)val;
14824
+ sm[i] = val;
14825
+ }
14826
+ }
14827
+
14828
+ assert(sum > 0.0);
14829
+ sum = 1.0/sum;
14830
+ }
13002
14831
 
13003
- assert( dst->nb[0] == sizeof(float));
13004
- assert(src0->nb[0] == sizeof(float));
13005
- assert(src1->nb[0] == sizeof(float));
14832
+ float dot_st1_dst1 = 0;
14833
+ ggml_vec_scale_f32(nc, sm, sum);
14834
+ ggml_vec_cpy_f32 (nc, ds0, sm);
14835
+ ggml_vec_scale_f32(nc, ds0, (1.0f - eps));
14836
+ ggml_vec_add1_f32 (nc, ds0, ds0, eps);
14837
+ ggml_vec_div_f32 (nc, ds0, s1, ds0);
14838
+ ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]);
14839
+ ggml_vec_dot_f32 (nc, &dot_st1_dst1, sm, ds0);
14840
+ ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1);
14841
+ ggml_vec_mul_f32 (nc, ds0, ds0, sm);
13006
14842
 
13007
- for (int i = 0; i < n; i++) {
13008
- fun(nc,
13009
- (float *) ((char *) dst->data + i*( dst->nb[1])),
13010
- (float *) ((char *) src0->data + i*(src0->nb[1])),
13011
- (float *) ((char *) src1->data + i*(src1->nb[1])));
14843
+ #ifndef NDEBUG
14844
+ for (int i = 0; i < nc; ++i) {
14845
+ assert(!isnan(sm[i]));
14846
+ assert(!isinf(sm[i]));
14847
+ assert(!isnan(ds0[i]));
14848
+ assert(!isinf(ds0[i]));
14849
+ }
14850
+ #endif
13012
14851
  }
13013
14852
  }
13014
14853
 
13015
-
13016
- static void ggml_compute_forward_map_binary(
14854
+ static void ggml_compute_forward_cross_entropy_loss_back(
13017
14855
  const struct ggml_compute_params * params,
13018
14856
  const struct ggml_tensor * src0,
13019
14857
  const struct ggml_tensor * src1,
13020
- struct ggml_tensor * dst,
13021
- const ggml_binary_op_f32_t fun) {
14858
+ const struct ggml_tensor * opt0,
14859
+ struct ggml_tensor * dst) {
13022
14860
  switch (src0->type) {
13023
14861
  case GGML_TYPE_F32:
13024
14862
  {
13025
- ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
14863
+ ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst);
13026
14864
  } break;
13027
14865
  default:
13028
14866
  {
@@ -13031,6 +14869,7 @@ static void ggml_compute_forward_map_binary(
13031
14869
  }
13032
14870
  }
13033
14871
 
14872
+
13034
14873
  /////////////////////////////////
13035
14874
 
13036
14875
  static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -13102,6 +14941,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13102
14941
  {
13103
14942
  ggml_compute_forward_repeat(params, tensor->src0, tensor);
13104
14943
  } break;
14944
+ case GGML_OP_REPEAT_BACK:
14945
+ {
14946
+ ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
14947
+ } break;
13105
14948
  case GGML_OP_ABS:
13106
14949
  {
13107
14950
  ggml_compute_forward_abs(params, tensor->src0, tensor);
@@ -13126,6 +14969,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13126
14969
  {
13127
14970
  ggml_compute_forward_gelu(params, tensor->src0, tensor);
13128
14971
  } break;
14972
+ case GGML_OP_GELU_QUICK:
14973
+ {
14974
+ ggml_compute_forward_gelu_quick(params, tensor->src0, tensor);
14975
+ } break;
13129
14976
  case GGML_OP_SILU:
13130
14977
  {
13131
14978
  ggml_compute_forward_silu(params, tensor->src0, tensor);
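GGML_OP_GELU_QUICK, added to the forward dispatcher in the hunk above, is the cheaper sigmoid-based GELU approximation. As far as I recall it computes x * sigmoid(1.702 * x); the 1.702 coefficient is my assumption and is not visible in this part of the diff. A minimal scalar sketch:

    #include <math.h>

    // "quick" GELU approximation: x * sigmoid(1.702 * x)
    // (the coefficient is assumed, not taken from this diff)
    static float gelu_quick_sketch(float x) {
        return x/(1.0f + expf(-1.702f*x));
    }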
@@ -13150,6 +14997,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13150
14997
  {
13151
14998
  ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
13152
14999
  } break;
15000
+ case GGML_OP_OUT_PROD:
15001
+ {
15002
+ ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
15003
+ } break;
13153
15004
  case GGML_OP_SCALE:
13154
15005
  {
13155
15006
  ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
@@ -13206,6 +15057,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13206
15057
  {
13207
15058
  ggml_compute_forward_soft_max(params, tensor->src0, tensor);
13208
15059
  } break;
15060
+ case GGML_OP_SOFT_MAX_BACK:
15061
+ {
15062
+ ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
15063
+ } break;
13209
15064
  case GGML_OP_ROPE:
13210
15065
  {
13211
15066
  ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
@@ -13222,25 +15077,44 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13222
15077
  {
13223
15078
  ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor);
13224
15079
  } break;
13225
- case GGML_OP_CONV_1D_1S:
15080
+ case GGML_OP_CONV_1D_S1_PH:
15081
+ {
15082
+ ggml_compute_forward_conv_1d_s1_ph(params, tensor->src0, tensor->src1, tensor);
15083
+ } break;
15084
+ case GGML_OP_CONV_1D_S2_PH:
13226
15085
  {
13227
- ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor);
15086
+ ggml_compute_forward_conv_1d_s2_ph(params, tensor->src0, tensor->src1, tensor);
13228
15087
  } break;
13229
- case GGML_OP_CONV_1D_2S:
15088
+ case GGML_OP_CONV_2D_SK_P0:
13230
15089
  {
13231
- ggml_compute_forward_conv_1d_2s(params, tensor->src0, tensor->src1, tensor);
15090
+ ggml_compute_forward_conv_2d_sk_p0(params, tensor->src0, tensor->src1, tensor);
13232
15091
  } break;
13233
15092
  case GGML_OP_FLASH_ATTN:
13234
15093
  {
13235
- int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
15094
+ const int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
13236
15095
  GGML_ASSERT(t == 0 || t == 1);
13237
- bool masked = t != 0;
15096
+ const bool masked = t != 0;
13238
15097
  ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor);
13239
15098
  } break;
13240
15099
  case GGML_OP_FLASH_FF:
13241
15100
  {
13242
15101
  ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
13243
15102
  } break;
15103
+ case GGML_OP_FLASH_ATTN_BACK:
15104
+ {
15105
+ int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
15106
+ GGML_ASSERT(t == 0 || t == 1);
15107
+ bool masked = t != 0;
15108
+ ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
15109
+ } break;
15110
+ case GGML_OP_WIN_PART:
15111
+ {
15112
+ ggml_compute_forward_win_part(params, tensor->src0, tensor->opt[0], tensor);
15113
+ } break;
15114
+ case GGML_OP_WIN_UNPART:
15115
+ {
15116
+ ggml_compute_forward_win_unpart(params, tensor->src0, tensor->opt[0], tensor);
15117
+ } break;
13244
15118
  case GGML_OP_MAP_UNARY:
13245
15119
  {
13246
15120
  const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
@@ -13253,6 +15127,16 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13253
15127
  ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
13254
15128
  }
13255
15129
  break;
15130
+ case GGML_OP_CROSS_ENTROPY_LOSS:
15131
+ {
15132
+ ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
15133
+ }
15134
+ break;
15135
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
15136
+ {
15137
+ ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15138
+ }
15139
+ break;
13256
15140
  case GGML_OP_NONE:
13257
15141
  {
13258
15142
  // nop
@@ -13391,11 +15275,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13391
15275
  src0->grad =
13392
15276
  ggml_add_impl(ctx,
13393
15277
  src0->grad,
13394
- ggml_mul(ctx,
13395
- tensor->grad, // this was not catched by test_grad because in test_grad tensor->grad is 1
15278
+ ggml_scale(ctx,
13396
15279
  ggml_div(ctx,
13397
- ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor),
13398
- tensor)),
15280
+ tensor->grad,
15281
+ tensor),
15282
+ ggml_new_f32(ctx, 0.5f)),
13399
15283
  inplace);
13400
15284
  }
13401
15285
  } break;
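The sqrt backward rewrite above produces the same value with a smaller graph: for y = sqrt(x) the derivative is 1/(2*sqrt(x)) = 0.5/y, so the contribution is 0.5 * grad[y]/y, and the new code expresses that as scale(div(grad, y), 0.5) instead of materializing a repeated 0.5 tensor. In formula form:

\[
y = \sqrt{x}
\quad\Longrightarrow\quad
\frac{\partial L}{\partial x} \;\mathrel{+}=\; \frac{1}{2\sqrt{x}}\,\frac{\partial L}{\partial y}
\;=\; \frac{0.5}{y}\,\frac{\partial L}{\partial y}.
\]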
@@ -13441,43 +15325,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13441
15325
  {
13442
15326
  // necessary for llama
13443
15327
  if (src0->grad) {
13444
- GGML_ASSERT(src0->n_dims == 1 || src0->n_dims == 2);
13445
- const int nc = tensor->ne[0];
13446
- const int nr = tensor->ne[1];
13447
- const int nc0 = src0->ne[0];
13448
- const int nr0 = src0->ne[1];
13449
- const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
13450
- const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
13451
- // tensor->grad [nc,nr,1,1]
13452
- // reshape [nc0,nc/nc0,nr0,nr/nr0]
13453
- // permute [nc0,nr0,nc/nc0,nr/nr0]
13454
- // substitute [nc0,nr0,ncr,nrr]
13455
- // reshape [nc0*nr0,ncr*nrr,1,1]
13456
- // transpose [ncr*nrr,nc0*nr0,1,1]
13457
- // sum rows [1,nc0*nr0,1,1]
13458
- // transpose [nc0*nr0,1,1]
13459
- // reshape [nc0,nr0,1,1] reshape_1d or reshape_2d
13460
- // add to src0->grad
13461
-
13462
- int64_t ne[4] = {nc0,ncr,nr0,nrr};
13463
-
13464
- struct ggml_tensor* F00 = tensor->grad;
13465
- struct ggml_tensor* F01 = ggml_reshape (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne));
13466
- struct ggml_tensor* F02 = ggml_permute (ctx, F01, 0,2,1,3);
13467
- struct ggml_tensor* F03 = ggml_cont (ctx, F02);
13468
- struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr);
13469
- struct ggml_tensor* F05 = ggml_transpose (ctx, F04);
13470
- struct ggml_tensor* F06 = ggml_cont (ctx, F05);
13471
- struct ggml_tensor* F07 = ggml_sum_rows (ctx, F06);
13472
- struct ggml_tensor* F08 = ggml_transpose (ctx, F07);
13473
- struct ggml_tensor* F09 = ggml_cont (ctx, F08);
13474
- struct ggml_tensor* F10 = ggml_reshape (ctx, F09, src0->grad);
13475
-
13476
- src0->grad =
13477
- ggml_add_impl(ctx,
13478
- src0->grad,
13479
- F10,
13480
- inplace);
15328
+ src0->grad = ggml_add_impl(ctx,
15329
+ src0->grad,
15330
+ ggml_repeat_back(ctx, tensor->grad, src0->grad),
15331
+ inplace);
15332
+ }
15333
+ } break;
15334
+ case GGML_OP_REPEAT_BACK:
15335
+ {
15336
+ if (src0->grad) {
15337
+ // TODO: test this
15338
+ src0->grad = ggml_add_impl(ctx,
15339
+ src0->grad,
15340
+ ggml_repeat(ctx, tensor->grad, src0->grad),
15341
+ inplace);
13481
15342
  }
13482
15343
  } break;
13483
15344
  case GGML_OP_ABS:
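The repeat backward above collapses the old reshape/permute/sum-rows chain into a single ggml_repeat_back node; the underlying reduction is just summing every repeated tile of the incoming gradient back onto the smaller source shape (my reading of the op, consistent with the code it replaces). A 1-D sketch of that reduction with illustrative names:

    // sum a gradient of length nc*ncr (the source repeated ncr times)
    // back onto the original length-nc gradient
    static void repeat_back_1d_sketch(const float * grad_rep, float * grad_src,
                                      int nc, int ncr) {
        for (int i = 0; i < nc; i++) {
            grad_src[i] = 0.0f;
        }
        for (int r = 0; r < ncr; r++) {
            for (int i = 0; i < nc; i++) {
                grad_src[i] += grad_rep[r*nc + i];
            }
        }
    }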
@@ -13525,6 +15386,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13525
15386
  {
13526
15387
  GGML_ASSERT(false); // TODO: not implemented
13527
15388
  } break;
15389
+ case GGML_OP_GELU_QUICK:
15390
+ {
15391
+ GGML_ASSERT(false); // TODO: not implemented
15392
+ } break;
13528
15393
  case GGML_OP_ALIBI:
13529
15394
  {
13530
15395
  GGML_ASSERT(false); // TODO: not implemented
@@ -13584,38 +15449,37 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13584
15449
 
13585
15450
  // necessary for llama
13586
15451
  if (src0->grad) {
13587
- // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
13588
15452
  src0->grad =
13589
15453
  ggml_add_impl(ctx,
13590
15454
  src0->grad,
13591
- // ds0 = dt.dot(s1.T)
13592
- // ggml_out_prod(ctx, // [n,m]
13593
- // src1, // [n,p]
13594
- // tensor->grad), // [m,p]
13595
- // for now just using A*B==(B.T*A.T).T
13596
- ggml_cont(ctx, // [n,m]
13597
- ggml_transpose(ctx, // [n,m]
13598
- ggml_mul_mat(ctx, // [m,n]
13599
- ggml_cont(ctx, // [p,m]
13600
- ggml_transpose(ctx, // [p,m]
13601
- tensor->grad)), // [m,p]
13602
- ggml_cont(ctx, // [p,n]
13603
- ggml_transpose(ctx, // [p,n]
13604
- src1))))), // [n,p]
15455
+ ggml_out_prod(ctx, // [n,m]
15456
+ src1, // [n,p]
15457
+ tensor->grad), // [m,p]
13605
15458
  inplace);
13606
15459
  }
13607
15460
  if (src1->grad) {
13608
15461
  src1->grad =
13609
15462
  ggml_add_impl(ctx,
13610
15463
  src1->grad,
13611
- // ds1 = s0.T.dot(dt):
13612
- ggml_mul_mat(ctx, // [n,p]
13613
- ggml_cont(ctx, // [m,n]
13614
- ggml_transpose(ctx, src0)), // [m,n]
13615
- tensor->grad), // [m,p]
15464
+ // ggml_mul_mat(ctx, // [n,p]
15465
+ // ggml_cont(ctx, // [m,n]
15466
+ // ggml_transpose(ctx, src0)), // [m,n]
15467
+ // tensor->grad), // [m,p]
15468
+
15469
+ // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
15470
+ // // avoid transpose of src0, rather transpose smaller tensor->grad
15471
+ // // and then use ggml_out_prod
15472
+ ggml_out_prod(ctx, // [n,p]
15473
+ src0, // [n,m]
15474
+ ggml_transpose(ctx, // [p,m]
15475
+ tensor->grad)), // [m,p]
13616
15476
  inplace);
13617
15477
  }
13618
15478
  } break;
15479
+ case GGML_OP_OUT_PROD:
15480
+ {
15481
+ GGML_ASSERT(false); // TODO: not implemented
15482
+ } break;
13619
15483
  case GGML_OP_SCALE:
13620
15484
  {
13621
15485
  // necessary for llama
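The mul_mat backward above swaps the transpose-heavy formulation for ggml_out_prod, and the kept comment explains the choice for src1: transpose the (usually smaller) tensor->grad rather than src0. In the shape annotations used by the surrounding comments (src0: [n,m], src1: [n,p], tensor = mul_mat(src0, src1): [m,p]) the two gradients are

\[
\nabla src_0 = \mathrm{out\_prod}(src_1,\ \nabla t) \in [n,m],
\qquad
\nabla src_1 = \mathrm{out\_prod}(src_0,\ \nabla t^{\top}) \in [n,p],
\]

which, reading mul_mat as C = A B^T with A of size m x n and B of size p x n contracted over the shared dimension n (my interpretation of the shape comments), is the standard pair dL/dA = (dL/dC) B and dL/dB = (dL/dC)^T A.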
@@ -13717,7 +15581,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13717
15581
  // necessary for llama
13718
15582
  if (src0->grad) {
13719
15583
  size_t offset;
13720
- memcpy(&offset, tensor->padding, sizeof(offset));
15584
+
15585
+ GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0]));
15586
+ memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
13721
15587
 
13722
15588
  size_t nb1 = tensor->nb[1];
13723
15589
  size_t nb2 = tensor->nb[2];
@@ -13744,10 +15610,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13744
15610
  {
13745
15611
  // necessary for llama
13746
15612
  if (src0->grad) {
13747
- int axis0 = tensor->padding[0] & 0x3;
13748
- int axis1 = tensor->padding[1] & 0x3;
13749
- int axis2 = tensor->padding[2] & 0x3;
13750
- int axis3 = tensor->padding[3] & 0x3;
15613
+ int32_t * axes = (int32_t *) tensor->opt[0]->data;
15614
+ int axis0 = axes[0] & 0x3;
15615
+ int axis1 = axes[1] & 0x3;
15616
+ int axis2 = axes[2] & 0x3;
15617
+ int axis3 = axes[3] & 0x3;
13751
15618
  int axes_backward[4] = {0,0,0,0};
13752
15619
  axes_backward[axis0] = 0;
13753
15620
  axes_backward[axis1] = 1;
@@ -13831,50 +15698,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13831
15698
  {
13832
15699
  // necessary for llama
13833
15700
  if (src0->grad) {
13834
- // y = softmax(x)
13835
- //
13836
- // Jii = yi - yi*yi
13837
- // Jij = -yi*yj
13838
- // J = diag(y)-y.*y
13839
- // dx = J * dy
13840
- // dxk = sum(Jkj * dyk)
13841
-
13842
- int64_t ne2[4] = {
13843
- tensor->ne[0],
13844
- 1,
13845
- tensor->ne[1]*tensor->ne[2],
13846
- tensor->ne[3]
13847
- };
13848
- struct ggml_tensor * tensor2 = ggml_cont(ctx,
13849
- ggml_reshape_4d(ctx,
13850
- ggml_cont(ctx, tensor),
13851
- ne2[0], ne2[1], ne2[2], ne2[3]));
13852
-
13853
- struct ggml_tensor * grad2 = ggml_cont(ctx,
13854
- ggml_reshape_4d(ctx,
13855
- ggml_cont(ctx, tensor->grad),
13856
- ne2[0], ne2[1], ne2[2], ne2[3]));
13857
-
13858
- struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
13859
- ggml_permute(ctx, // [1,ne0,ne1*ne2,ne3]
13860
- tensor2, // [ne0,1,ne1*ne2,ne3]
13861
- 1, 0, 2, 3));
13862
-
13863
15701
  src0->grad =
13864
- ggml_add_impl(ctx,
13865
- src0->grad, // [ne0,ne1,ne2,ne3]
13866
- ggml_reshape(ctx, // [ne0,ne1,ne2,ne3]
13867
- ggml_mul_mat(ctx, // [ne0,1,ne1*ne2,ne3]
13868
- ggml_sub(ctx, // [ne0,ne0,ne1*ne2,ne3]
13869
- ggml_diag(ctx, // [ne0,ne0,ne1*ne2,ne3]
13870
- tensor2), // [ne0,1,ne1*ne2,ne3]
13871
- ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3]
13872
- tensor2_t, // [1,ne0,ne1*ne2,ne3]
13873
- tensor2_t)), // [1,ne0,ne1*ne2,ne3]
13874
- grad2), // [ne0,1,ne1*ne2,ne3]
13875
- src0->grad),
13876
- inplace);
15702
+ ggml_add_impl(ctx, src0->grad,
15703
+ ggml_soft_max_back(ctx, tensor->grad, tensor),
15704
+ inplace);
13877
15705
  }
15706
+
15707
+ } break;
15708
+ case GGML_OP_SOFT_MAX_BACK:
15709
+ {
15710
+ GGML_ASSERT(false); // TODO: not implemented
13878
15711
  } break;
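The explicit Jacobian construction removed above (diag(y) − y·yᵀ, reshaped and multiplied into the incoming gradient) collapses to a single dot product per row, which is presumably what `ggml_soft_max_back` implements; the kernel itself is not part of this hunk, so the sketch below is only the underlying identity:

    /* For y = softmax(x):  J_ij = y_i*(delta_ij - y_j),  dx = J*dy, hence
     *   dx_i = y_i * (dy_i - sum_j y_j*dy_j). */
    static void soft_max_back_ref(int n, const float *y, const float *dy, float *dx) {
        float dot = 0.0f;
        for (int j = 0; j < n; ++j) {
            dot += y[j]*dy[j];
        }
        for (int i = 0; i < n; ++i) {
            dx[i] = y[i]*(dy[i] - dot);
        }
    }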
13879
15712
  case GGML_OP_ROPE:
13880
15713
  {
@@ -13919,27 +15752,206 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13919
15752
  // noop
13920
15753
  }
13921
15754
  } break;
13922
- case GGML_OP_CONV_1D_1S:
15755
+ case GGML_OP_CONV_1D_S1_PH:
15756
+ {
15757
+ GGML_ASSERT(false); // TODO: not implemented
15758
+ } break;
15759
+ case GGML_OP_CONV_1D_S2_PH:
13923
15760
  {
13924
15761
  GGML_ASSERT(false); // TODO: not implemented
13925
15762
  } break;
13926
- case GGML_OP_CONV_1D_2S:
15763
+ case GGML_OP_CONV_2D_SK_P0:
13927
15764
  {
13928
15765
  GGML_ASSERT(false); // TODO: not implemented
13929
15766
  } break;
13930
15767
  case GGML_OP_FLASH_ATTN:
13931
15768
  {
13932
- GGML_ASSERT(false); // not supported
15769
+ struct ggml_tensor * flash_grad = NULL;
15770
+ if (src0->grad || src1->grad || tensor->opt[0]->grad) {
15771
+ int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
15772
+ GGML_ASSERT(t == 0 || t == 1);
15773
+ bool masked = t != 0;
15774
+ flash_grad =
15775
+ ggml_flash_attn_back(ctx,
15776
+ src0,
15777
+ src1,
15778
+ tensor->opt[0],
15779
+ tensor->grad,
15780
+ masked);
15781
+ }
15782
+
15783
+ if (src0->grad) {
15784
+ struct ggml_tensor * grad_q = NULL;
15785
+ const size_t nb0 = flash_grad->nb[0];
15786
+ const size_t offset = 0;
15787
+ switch(src0->n_dims) {
15788
+ case 2:
15789
+ {
15790
+ grad_q = ggml_view_2d(ctx,
15791
+ flash_grad,
15792
+ src0->ne[0],
15793
+ src0->ne[1],
15794
+ nb0*src0->ne[0],
15795
+ offset);
15796
+ } break;
15797
+ case 3:
15798
+ {
15799
+ grad_q = ggml_view_3d(ctx,
15800
+ flash_grad,
15801
+ src0->ne[0],
15802
+ src0->ne[1],
15803
+ src0->ne[2],
15804
+ nb0*src0->ne[0],
15805
+ nb0*src0->ne[0]*src0->ne[1],
15806
+ offset);
15807
+ } break;
15808
+ case 4:
15809
+ {
15810
+ grad_q = ggml_view_4d(ctx,
15811
+ flash_grad,
15812
+ src0->ne[0],
15813
+ src0->ne[1],
15814
+ src0->ne[2],
15815
+ src0->ne[3],
15816
+ nb0*src0->ne[0],
15817
+ nb0*src0->ne[0]*src0->ne[1],
15818
+ nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
15819
+ offset);
15820
+ } break;
15821
+ }
15822
+
15823
+ src0->grad = ggml_add_impl(ctx,
15824
+ src0->grad,
15825
+ grad_q,
15826
+ inplace);
15827
+ }
15828
+
15829
+ if (src1->grad) {
15830
+ struct ggml_tensor * grad_k = NULL;
15831
+ const size_t nb0 = flash_grad->nb[0];
15832
+ const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
15833
+ switch(src1->n_dims) {
15834
+ case 2:
15835
+ {
15836
+ grad_k = ggml_view_2d(ctx,
15837
+ flash_grad,
15838
+ src1->ne[0],
15839
+ src1->ne[1],
15840
+ nb0*src1->ne[0],
15841
+ offset);
15842
+ } break;
15843
+ case 3:
15844
+ {
15845
+ grad_k = ggml_view_3d(ctx,
15846
+ flash_grad,
15847
+ src1->ne[0],
15848
+ src1->ne[1],
15849
+ src1->ne[2],
15850
+ nb0*src1->ne[0],
15851
+ nb0*src1->ne[0]*src1->ne[1],
15852
+ offset);
15853
+ } break;
15854
+ case 4:
15855
+ {
15856
+ grad_k = ggml_view_4d(ctx,
15857
+ flash_grad,
15858
+ src1->ne[0],
15859
+ src1->ne[1],
15860
+ src1->ne[2],
15861
+ src1->ne[3],
15862
+ nb0*src1->ne[0],
15863
+ nb0*src1->ne[0]*src1->ne[1],
15864
+ nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
15865
+ offset);
15866
+ } break;
15867
+ }
15868
+
15869
+ src1->grad = ggml_add_impl(ctx,
15870
+ src1->grad,
15871
+ grad_k,
15872
+ inplace);
15873
+ }
15874
+
15875
+ struct ggml_tensor * opt0 = tensor->opt[0];
15876
+
15877
+ if (opt0->grad) {
15878
+ struct ggml_tensor * grad_v = NULL;
15879
+ const size_t nb0 = flash_grad->nb[0];
15880
+ const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
15881
+ + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
15882
+ switch(opt0->n_dims) {
15883
+ case 2:
15884
+ {
15885
+ grad_v = ggml_view_2d(ctx,
15886
+ flash_grad,
15887
+ opt0->ne[0],
15888
+ opt0->ne[1],
15889
+ nb0*opt0->ne[0],
15890
+ offset);
15891
+ } break;
15892
+ case 3:
15893
+ {
15894
+ grad_v = ggml_view_3d(ctx,
15895
+ flash_grad,
15896
+ opt0->ne[0],
15897
+ opt0->ne[1],
15898
+ opt0->ne[2],
15899
+ nb0*opt0->ne[0],
15900
+ nb0*opt0->ne[0]*opt0->ne[1],
15901
+ offset);
15902
+ } break;
15903
+ case 4:
15904
+ {
15905
+ grad_v = ggml_view_4d(ctx,
15906
+ flash_grad,
15907
+ opt0->ne[0],
15908
+ opt0->ne[1],
15909
+ opt0->ne[2],
15910
+ opt0->ne[3],
15911
+ nb0*opt0->ne[0],
15912
+ nb0*opt0->ne[0]*opt0->ne[1],
15913
+ nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
15914
+ offset);
15915
+ } break;
15916
+ }
15917
+
15918
+ opt0->grad = ggml_add_impl(ctx,
15919
+ opt0->grad,
15920
+ grad_v,
15921
+ inplace);
15922
+ }
13933
15923
  } break;
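All three FLASH_ATTN gradients are produced in one `flash_grad` tensor and then carved out with `ggml_view_2d/3d/4d`; the byte offsets above assume the buffer packs grad_q, then grad_k, then grad_v back-to-back. A plain-pointer sketch of that offset arithmetic (element counts nq/nk stand for the products of the respective ne[] dims):

    #include <stddef.h>

    struct flash_grad_views {
        float *grad_q;
        float *grad_k;
        float *grad_v;
    };

    /* flash_grad is treated as [ grad_q | grad_k | grad_v ] in element units;
     * the diff does the same thing in bytes via nb0 * prod(ne). */
    static struct flash_grad_views split_flash_grad(float *flash_grad,
                                                    size_t nq, size_t nk) {
        struct flash_grad_views v;
        v.grad_q = flash_grad;
        v.grad_k = flash_grad + nq;
        v.grad_v = flash_grad + nq + nk;
        return v;
    }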
13934
15924
  case GGML_OP_FLASH_FF:
13935
15925
  {
13936
15926
  GGML_ASSERT(false); // not supported
13937
15927
  } break;
15928
+ case GGML_OP_FLASH_ATTN_BACK:
15929
+ {
15930
+ GGML_ASSERT(false); // not supported
15931
+ } break;
15932
+ case GGML_OP_WIN_PART:
15933
+ case GGML_OP_WIN_UNPART:
13938
15934
  case GGML_OP_MAP_UNARY:
13939
15935
  case GGML_OP_MAP_BINARY:
13940
15936
  {
13941
15937
  GGML_ASSERT(false); // not supported
13942
15938
  } break;
15939
+ case GGML_OP_CROSS_ENTROPY_LOSS:
15940
+ {
15941
+ if (src0->grad) {
15942
+ src0->grad = ggml_add_impl(ctx,
15943
+ src0->grad,
15944
+ ggml_cross_entropy_loss_back(ctx,
15945
+ src0,
15946
+ src1,
15947
+ tensor->grad),
15948
+ inplace);
15949
+ }
15950
+ } break;
15951
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
15952
+ {
15953
+ GGML_ASSERT(false); // not supported
15954
+ } break;
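`ggml_cross_entropy_loss_back(ctx, src0, src1, tensor->grad)` is taken on faith here, since its kernel lies outside this hunk. The textbook gradient of softmax fused with cross-entropy against targets t, scaled by the scalar upstream gradient, looks like the sketch below; whether ggml applies additional normalization is not visible in the diff:

    #include <math.h>

    /* d/dx_i [ -sum_j t_j * log softmax(x)_j ] = softmax(x)_i - t_i */
    static void cross_entropy_softmax_grad(int n, const float *x, const float *t,
                                           float dloss, float *dx) {
        float xmax = x[0];
        for (int i = 1; i < n; ++i) {
            if (x[i] > xmax) xmax = x[i];     /* stabilize the exponentials */
        }
        float sum = 0.0f;
        for (int i = 0; i < n; ++i) {
            dx[i] = expf(x[i] - xmax);
            sum  += dx[i];
        }
        for (int i = 0; i < n; ++i) {
            dx[i] = dloss * (dx[i]/sum - t[i]);
        }
    }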
13943
15955
  case GGML_OP_NONE:
13944
15956
  {
13945
15957
  // nop
@@ -14316,6 +16328,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14316
16328
  case GGML_OP_SUM_ROWS:
14317
16329
  case GGML_OP_MEAN:
14318
16330
  case GGML_OP_REPEAT:
16331
+ case GGML_OP_REPEAT_BACK:
14319
16332
  case GGML_OP_ABS:
14320
16333
  case GGML_OP_SGN:
14321
16334
  case GGML_OP_NEG:
@@ -14326,6 +16339,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14326
16339
  } break;
14327
16340
  case GGML_OP_MUL:
14328
16341
  case GGML_OP_GELU:
16342
+ case GGML_OP_GELU_QUICK:
14329
16343
  case GGML_OP_SILU:
14330
16344
  case GGML_OP_SILU_BACK:
14331
16345
  case GGML_OP_NORM:
@@ -14335,6 +16349,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14335
16349
  node->n_tasks = n_threads;
14336
16350
  } break;
14337
16351
  case GGML_OP_MUL_MAT:
16352
+ case GGML_OP_OUT_PROD:
14338
16353
  {
14339
16354
  node->n_tasks = n_threads;
14340
16355
 
@@ -14417,6 +16432,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14417
16432
  } break;
14418
16433
  case GGML_OP_DIAG_MASK_INF:
14419
16434
  case GGML_OP_SOFT_MAX:
16435
+ case GGML_OP_SOFT_MAX_BACK:
14420
16436
  case GGML_OP_ROPE:
14421
16437
  case GGML_OP_ROPE_BACK:
14422
16438
  {
@@ -14430,8 +16446,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14430
16446
  {
14431
16447
  node->n_tasks = 1; //TODO
14432
16448
  } break;
14433
- case GGML_OP_CONV_1D_1S:
14434
- case GGML_OP_CONV_1D_2S:
16449
+ case GGML_OP_CONV_1D_S1_PH:
16450
+ case GGML_OP_CONV_1D_S2_PH:
14435
16451
  {
14436
16452
  node->n_tasks = n_threads;
14437
16453
 
@@ -14458,6 +16474,41 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14458
16474
  GGML_ASSERT(false);
14459
16475
  }
14460
16476
 
16477
+ work_size = MAX(work_size, cur);
16478
+ } break;
16479
+ case GGML_OP_CONV_2D_SK_P0:
16480
+ {
16481
+ node->n_tasks = n_threads;
16482
+
16483
+ GGML_ASSERT(node->src1->ne[3] == 1);
16484
+
16485
+ const int64_t ne00 = node->src0->ne[0]; // W
16486
+ const int64_t ne01 = node->src0->ne[1]; // H
16487
+ const int64_t ne02 = node->src0->ne[2]; // C
16488
+ const int64_t ne03 = node->src0->ne[3]; // N
16489
+
16490
+ const int64_t ne10 = node->src1->ne[0]; // W
16491
+ const int64_t ne11 = node->src1->ne[1]; // H
16492
+ const int64_t ne12 = node->src1->ne[2]; // C
16493
+
16494
+ const int64_t nk = ne00*ne01;
16495
+
16496
+ UNUSED(ne02);
16497
+ UNUSED(ne03);
16498
+ UNUSED(nk);
16499
+
16500
+ size_t cur = 0;
16501
+
16502
+ if (node->src0->type == GGML_TYPE_F16 &&
16503
+ node->src1->type == GGML_TYPE_F32) {
16504
+ cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
16505
+ } else if (node->src0->type == GGML_TYPE_F32 &&
16506
+ node->src1->type == GGML_TYPE_F32) {
16507
+ cur = sizeof(float)* (ne10*ne11*ne12);
16508
+ } else {
16509
+ GGML_ASSERT(false);
16510
+ }
16511
+
14461
16512
  work_size = MAX(work_size, cur);
14462
16513
  } break;
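For the new CONV_2D_SK_P0 scheduling, the work buffer is sized to one copy of src1 (W·H·C elements) in the kernel's compute type — presumably a staging/conversion buffer. Restated compactly:

    #include <stddef.h>
    #include <stdint.h>

    /* Matches the two branches above: f16 staging when src0 is f16,
     * f32 staging when both operands are f32. */
    static size_t conv2d_sk_p0_work_size(int64_t ne10, int64_t ne11, int64_t ne12,
                                         size_t elem_size /* 2 (f16) or 4 (f32) */) {
        return elem_size * (size_t)(ne10*ne11*ne12);
    }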
14463
16514
  case GGML_OP_FLASH_ATTN:
@@ -14498,11 +16549,50 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14498
16549
 
14499
16550
  work_size = MAX(work_size, cur);
14500
16551
  } break;
16552
+ case GGML_OP_FLASH_ATTN_BACK:
16553
+ {
16554
+ node->n_tasks = n_threads;
16555
+
16556
+ size_t cur = 0;
16557
+
16558
+ const int64_t D = node->src0->ne[0];
16559
+ const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
16560
+ const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
16561
+ if (node->src1->type == GGML_TYPE_F32) {
16562
+ cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
16563
+ cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
16564
+ }
16565
+
16566
+ if (node->src1->type == GGML_TYPE_F16) {
16567
+ cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
16568
+ cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
16569
+ }
16570
+
16571
+ work_size = MAX(work_size, cur);
16572
+ } break;
16573
+ case GGML_OP_WIN_PART:
16574
+ case GGML_OP_WIN_UNPART:
14501
16575
  case GGML_OP_MAP_UNARY:
14502
16576
  case GGML_OP_MAP_BINARY:
14503
16577
  {
14504
16578
  node->n_tasks = 1;
14505
16579
  } break;
16580
+ case GGML_OP_CROSS_ENTROPY_LOSS:
16581
+ {
16582
+ node->n_tasks = n_threads;
16583
+
16584
+ size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
16585
+
16586
+ work_size = MAX(work_size, cur);
16587
+ } break;
16588
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
16589
+ {
16590
+ node->n_tasks = n_threads;
16591
+
16592
+ size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
16593
+
16594
+ work_size = MAX(work_size, cur);
16595
+ } break;
14506
16596
  case GGML_OP_NONE:
14507
16597
  {
14508
16598
  node->n_tasks = 1;
@@ -15014,16 +17104,20 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
15014
17104
 
15015
17105
  if (!*ctx_data) {
15016
17106
  fprintf(stderr, "%s: failed to create ggml context\n", __func__);
17107
+ fclose(fin);
15017
17108
  return result;
15018
17109
  }
15019
17110
  }
15020
17111
 
15021
17112
  data = ggml_new_tensor_1d(*ctx_data, GGML_TYPE_I8, fsize);
15022
17113
 
15023
- const size_t ret = fread(data->data, sizeof(char), fsize, fin);
15024
- if (ret != fsize) {
15025
- fprintf(stderr, "%s: failed to read %s\n", __func__, fname);
15026
- return result;
17114
+ {
17115
+ const size_t ret = fread(data->data, sizeof(char), fsize, fin);
17116
+ if (ret != fsize) {
17117
+ fprintf(stderr, "%s: failed to read %s\n", __func__, fname);
17118
+ fclose(fin);
17119
+ return result;
17120
+ }
15027
17121
  }
15028
17122
 
15029
17123
  fclose(fin);
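The ggml_graph_import changes are purely about not leaking `fin` on the early-return paths: every `return result;` reachable after a successful `fopen` now has a matching `fclose(fin)`. The same pattern in miniature (hypothetical file-loading helper, not the ggml function):

    #include <stdio.h>
    #include <stdlib.h>

    /* Returns a malloc'd buffer of *size bytes, or NULL; the file handle is
     * closed on every path out of the function. */
    static void * read_whole_file(const char * fname, size_t * size) {
        FILE * fin = fopen(fname, "rb");
        if (!fin) {
            return NULL;
        }
        fseek(fin, 0, SEEK_END);
        const long fsize = ftell(fin);
        fseek(fin, 0, SEEK_SET);
        void * data = fsize > 0 ? malloc((size_t) fsize) : NULL;
        if (!data || fread(data, 1, (size_t) fsize, fin) != (size_t) fsize) {
            free(data);          /* free(NULL) is a no-op */
            fclose(fin);         /* the fix in the diff: close before returning */
            return NULL;
        }
        fclose(fin);
        *size = (size_t) fsize;
        return data;
    }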
@@ -15478,6 +17572,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
15478
17572
 
15479
17573
  static enum ggml_opt_result ggml_opt_adam(
15480
17574
  struct ggml_context * ctx,
17575
+ struct ggml_opt_context * opt,
15481
17576
  struct ggml_opt_params params,
15482
17577
  struct ggml_tensor * f,
15483
17578
  struct ggml_cgraph * gf,
@@ -15503,25 +17598,29 @@ static enum ggml_opt_result ggml_opt_adam(
15503
17598
  }
15504
17599
  }
15505
17600
 
17601
+ if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) {
17602
+ int iter = opt->iter;
17603
+ ggml_opt_init(opt->ctx, opt, params, nx);
17604
+ opt->iter = iter;
17605
+ }
17606
+
15506
17607
  // constants
15507
- const float alpha = params.adam.alpha;
17608
+ const float sched = params.adam.sched;
17609
+ const float decay = params.adam.decay * sched;
17610
+ const float alpha = params.adam.alpha * sched;
15508
17611
  const float beta1 = params.adam.beta1;
15509
17612
  const float beta2 = params.adam.beta2;
15510
17613
  const float eps = params.adam.eps;
15511
17614
 
15512
- float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // view of the parameters
15513
- float * g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient
15514
- float * g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient squared
15515
- float * m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment
15516
- float * v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment
15517
- float * mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment hat
15518
- float * vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment hat
17615
+ float * x = opt->adam.x->data; // view of the parameters
17616
+ float * g1 = opt->adam.g1->data; // gradient
17617
+ float * g2 = opt->adam.g2->data; // gradient squared
17618
+ float * m = opt->adam.m->data; // first moment
17619
+ float * v = opt->adam.v->data; // second moment
17620
+ float * mh = opt->adam.mh->data; // first moment hat
17621
+ float * vh = opt->adam.vh->data; // second moment hat
15519
17622
 
15520
- float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
15521
-
15522
- // initialize
15523
- ggml_vec_set_f32(nx, m, 0.0f);
15524
- ggml_vec_set_f32(nx, v, 0.0f);
17623
+ float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
15525
17624
 
15526
17625
  // update view
15527
17626
  ggml_opt_get_params(np, ps, x);
@@ -15531,16 +17630,27 @@ static enum ggml_opt_result ggml_opt_adam(
15531
17630
  ggml_set_f32 (f->grad, 1.0f);
15532
17631
  ggml_graph_compute(ctx, gb);
15533
17632
 
15534
- float fx_prev = ggml_get_f32_1d(f, 0);
17633
+ opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
17634
+ opt->adam.fx_best = opt->adam.fx_prev;
15535
17635
  if (pf) {
15536
- pf[0] = fx_prev;
17636
+ pf[opt->iter % params.past] = opt->adam.fx_prev;
15537
17637
  }
15538
17638
 
15539
- int n_no_improvement = 0;
15540
- float fx_best = fx_prev;
17639
+ // initialize
17640
+ if (opt->just_initialized) {
17641
+ opt->adam.n_no_improvement = 0;
17642
+ opt->just_initialized = false;
17643
+ }
17644
+
17645
+ float * fx_best = &opt->adam.fx_best;
17646
+ float * fx_prev = &opt->adam.fx_prev;
17647
+ int * n_no_improvement = &opt->adam.n_no_improvement;
17648
+
17649
+ int iter0 = opt->iter;
15541
17650
 
15542
17651
  // run the optimizer
15543
17652
  for (int t = 0; t < params.adam.n_iter; ++t) {
17653
+ opt->iter = iter0 + t + 1;
15544
17654
  GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
15545
17655
 
15546
17656
  GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0));
@@ -15574,17 +17684,22 @@ static enum ggml_opt_result ggml_opt_adam(
15574
17684
 
15575
17685
  // m^hat = m_t / (1 - beta1^t)
15576
17686
  // v^hat = v_t / (1 - beta2^t)
15577
- // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
17687
+ // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
17688
+ // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
17689
+ // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
17690
+ // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
17691
+ // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
15578
17692
  ggml_vec_cpy_f32 (nx, mh, m);
15579
17693
  ggml_vec_cpy_f32 (nx, vh, v);
15580
17694
 
15581
- ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1)));
15582
- ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, t + 1)));
17695
+ ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter)));
17696
+ ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, opt->iter)));
15583
17697
 
15584
17698
  ggml_vec_sqrt_f32 (nx, vh, vh);
15585
17699
  ggml_vec_acc1_f32 (nx, vh, eps);
15586
17700
 
15587
17701
  ggml_vec_div_f32 (nx, mh, mh, vh);
17702
+ ggml_vec_scale_f32(nx, x, 1.0f - decay);
15588
17703
  ggml_vec_sub_f32 (nx, x, x, mh);
15589
17704
 
15590
17705
  // update the parameters
@@ -15598,7 +17713,7 @@ static enum ggml_opt_result ggml_opt_adam(
15598
17713
  const float fx = ggml_get_f32_1d(f, 0);
15599
17714
 
15600
17715
  // check convergence
15601
- if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) {
17716
+ if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
15602
17717
  GGML_PRINT_DEBUG("converged\n");
15603
17718
 
15604
17719
  return GGML_OPT_OK;
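With the optimizer state now persistent, the past-function-value buffer `pf` is indexed by the global `opt->iter` rather than the local loop counter, so the delta-based test keeps working across repeated calls. A compact sketch of the two stopping tests used here (relative step-to-step change against eps_f, and the circular `past` buffer against delta):

    #include <math.h>
    #include <stdbool.h>

    /* pf has `past` entries (or is NULL); iter is the global iteration count. */
    static bool converged(float fx, float fx_prev, float eps_f,
                          float *pf, int past, int iter, float delta) {
        if (fabsf(fx - fx_prev)/fx < eps_f) {
            return true;                         /* step-to-step change is tiny */
        }
        if (pf && past > 0) {
            if (past <= iter) {
                const float rate = (pf[iter % past] - fx)/fx;
                if (fabsf(rate) < delta) {
                    return true;                 /* no progress over `past` iters */
                }
            }
            pf[iter % past] = fx;
        }
        return false;
    }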
@@ -15607,32 +17722,32 @@ static enum ggml_opt_result ggml_opt_adam(
15607
17722
  // delta-based convergence test
15608
17723
  if (pf != NULL) {
15609
17724
  // need at least params.past iterations to start checking for convergence
15610
- if (params.past <= t) {
15611
- const float rate = (pf[t%params.past] - fx)/fx;
17725
+ if (params.past <= iter0 + t) {
17726
+ const float rate = (pf[(iter0 + t)%params.past] - fx)/fx;
15612
17727
 
15613
17728
  if (fabsf(rate) < params.delta) {
15614
17729
  return GGML_OPT_OK;
15615
17730
  }
15616
17731
  }
15617
17732
 
15618
- pf[t%params.past] = fx;
17733
+ pf[(iter0 + t)%params.past] = fx;
15619
17734
  }
15620
17735
 
15621
17736
  // check for improvement
15622
17737
  if (params.max_no_improvement > 0) {
15623
- if (fx_best > fx) {
15624
- fx_best = fx;
15625
- n_no_improvement = 0;
17738
+ if (fx_best[0] > fx) {
17739
+ fx_best[0] = fx;
17740
+ n_no_improvement[0] = 0;
15626
17741
  } else {
15627
- ++n_no_improvement;
17742
+ ++n_no_improvement[0];
15628
17743
 
15629
- if (n_no_improvement >= params.max_no_improvement) {
17744
+ if (n_no_improvement[0] >= params.max_no_improvement) {
15630
17745
  return GGML_OPT_OK;
15631
17746
  }
15632
17747
  }
15633
17748
  }
15634
17749
 
15635
- fx_prev = fx;
17750
+ fx_prev[0] = fx;
15636
17751
 
15637
17752
  {
15638
17753
  const int64_t t_end_cpu = ggml_cycles();
@@ -15771,6 +17886,7 @@ static enum ggml_opt_result linesearch_backtracking(
15771
17886
 
15772
17887
  static enum ggml_opt_result ggml_opt_lbfgs(
15773
17888
  struct ggml_context * ctx,
17889
+ struct ggml_opt_context * opt,
15774
17890
  struct ggml_opt_params params,
15775
17891
  struct ggml_tensor * f,
15776
17892
  struct ggml_cgraph * gf,
@@ -15803,31 +17919,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15803
17919
  }
15804
17920
  }
15805
17921
 
15806
- float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current parameters
15807
- float * xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous parameters
15808
- float * g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current gradient
15809
- float * gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous gradient
15810
- float * d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // search direction
17922
+ if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) {
17923
+ int iter = opt->iter;
17924
+ ggml_opt_init(ctx, opt, params, nx);
17925
+ opt->iter = iter;
17926
+ }
17927
+
17928
+ float * x = opt->lbfgs.x->data; // current parameters
17929
+ float * xp = opt->lbfgs.xp->data; // previous parameters
17930
+ float * g = opt->lbfgs.g->data; // current gradient
17931
+ float * gp = opt->lbfgs.gp->data; // previous gradient
17932
+ float * d = opt->lbfgs.d->data; // search direction
15811
17933
 
15812
- float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
17934
+ float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
15813
17935
 
15814
17936
  float fx = 0.0f; // cost function value
15815
17937
  float xnorm = 0.0f; // ||x||
15816
17938
  float gnorm = 0.0f; // ||g||
15817
- float step = 0.0f;
15818
17939
 
15819
17940
  // initialize x from the graph nodes
15820
17941
  ggml_opt_get_params(np, ps, x);
15821
17942
 
15822
17943
  // the L-BFGS memory
15823
- struct ggml_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_lbfgs_iteration_data)*m);
15824
-
15825
- for (int i = 0; i < m; ++i) {
15826
- lm[i].alpha = 0.0f;
15827
- lm[i].ys = 0.0f;
15828
- lm[i].s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
15829
- lm[i].y = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
15830
- }
17944
+ float * lm_alpha = opt->lbfgs.lmal->data;
17945
+ float * lm_ys = opt->lbfgs.lmys->data;
17946
+ float * lm_s = opt->lbfgs.lms->data;
17947
+ float * lm_y = opt->lbfgs.lmy->data;
15831
17948
 
15832
17949
  // evaluate the function value and its gradient
15833
17950
  {
@@ -15842,12 +17959,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15842
17959
  fx = ggml_get_f32_1d(f, 0);
15843
17960
  }
15844
17961
 
15845
- if (pf) {
15846
- pf[0] = fx;
15847
- }
15848
-
15849
- float fx_best = fx;
15850
-
15851
17962
  // search direction = -gradient
15852
17963
  ggml_vec_neg_f32(nx, d, g);
15853
17964
 
@@ -15864,26 +17975,43 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15864
17975
  return GGML_OPT_OK;
15865
17976
  }
15866
17977
 
15867
- // initial step
15868
- ggml_vec_norm_inv_f32(nx, &step, d);
17978
+ if (opt->just_initialized) {
17979
+ if (pf) {
17980
+ pf[0] = fx;
17981
+ }
17982
+ opt->lbfgs.fx_best = fx;
17983
+
17984
+ // initial step
17985
+ ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d);
17986
+ opt->lbfgs.j = 0;
17987
+ opt->lbfgs.k = 1;
17988
+ opt->lbfgs.end = 0;
17989
+ opt->lbfgs.n_no_improvement = 0;
17990
+ opt->just_initialized = false;
17991
+ }
17992
+
17993
+ float * fx_best = &opt->lbfgs.fx_best;
17994
+ float * step = &opt->lbfgs.step;
17995
+ int * j = &opt->lbfgs.j;
17996
+ int * k = &opt->lbfgs.k;
17997
+ int * end = &opt->lbfgs.end;
17998
+ int * n_no_improvement = &opt->lbfgs.n_no_improvement;
15869
17999
 
15870
- int j = 0;
15871
- int k = 1;
15872
- int ls = 0;
15873
- int end = 0;
15874
- int bound = 0;
15875
- int n_no_improvement = 0;
18000
+ int ls = 0;
18001
+ int bound = 0;
15876
18002
 
15877
18003
  float ys = 0.0f;
15878
18004
  float yy = 0.0f;
15879
18005
  float beta = 0.0f;
15880
18006
 
18007
+ int it = 0;
18008
+
15881
18009
  while (true) {
15882
18010
  // store the current position and gradient vectors
15883
18011
  ggml_vec_cpy_f32(nx, xp, x);
15884
18012
  ggml_vec_cpy_f32(nx, gp, g);
15885
18013
 
15886
- ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps);
18014
+ ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
15887
18015
 
15888
18016
  if (ls < 0) {
15889
18017
  // linesearch failed - go back to the previous point and return
@@ -15909,32 +18037,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15909
18037
  // delta-based convergence test
15910
18038
  if (pf != NULL) {
15911
18039
  // need at least params.past iterations to start checking for convergence
15912
- if (params.past <= k) {
15913
- const float rate = (pf[k%params.past] - fx)/fx;
18040
+ if (params.past <= k[0]) {
18041
+ const float rate = (pf[k[0]%params.past] - fx)/fx;
15914
18042
 
15915
18043
  if (fabsf(rate) < params.delta) {
15916
18044
  return GGML_OPT_OK;
15917
18045
  }
15918
18046
  }
15919
18047
 
15920
- pf[k%params.past] = fx;
18048
+ pf[k[0]%params.past] = fx;
15921
18049
  }
15922
18050
 
15923
18051
  // check for improvement
15924
18052
  if (params.max_no_improvement > 0) {
15925
- if (fx < fx_best) {
15926
- fx_best = fx;
15927
- n_no_improvement = 0;
18053
+ if (fx < fx_best[0]) {
18054
+ fx_best[0] = fx;
18055
+ n_no_improvement[0] = 0;
15928
18056
  } else {
15929
- n_no_improvement++;
18057
+ n_no_improvement[0]++;
15930
18058
 
15931
- if (n_no_improvement >= params.max_no_improvement) {
18059
+ if (n_no_improvement[0] >= params.max_no_improvement) {
15932
18060
  return GGML_OPT_OK;
15933
18061
  }
15934
18062
  }
15935
18063
  }
15936
18064
 
15937
- if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) {
18065
+ if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) {
15938
18066
  // reached the maximum number of iterations
15939
18067
  return GGML_OPT_DID_NOT_CONVERGE;
15940
18068
  }
@@ -15943,50 +18071,51 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15943
18071
  // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
15944
18072
  // y_{k+1} = g_{k+1} - g_{k}.
15945
18073
  //
15946
- ggml_vec_sub_f32(nx, lm[end].s, x, xp);
15947
- ggml_vec_sub_f32(nx, lm[end].y, g, gp);
18074
+ ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp);
18075
+ ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp);
15948
18076
 
15949
18077
  // compute scalars ys and yy:
15950
18078
  // ys = y^t \cdot s -> 1 / \rho.
15951
18079
  // yy = y^t \cdot y.
15952
18080
  //
15953
- ggml_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s);
15954
- ggml_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y);
18081
+ ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0] *nx]);
18082
+ ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
15955
18083
 
15956
- lm[end].ys = ys;
18084
+ lm_ys[end[0]] = ys;
15957
18085
 
15958
18086
  // find new search direction
15959
18087
  // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS
15960
18088
 
15961
- bound = (m <= k) ? m : k;
15962
- k++;
15963
- end = (end + 1)%m;
18089
+ bound = (m <= k[0]) ? m : k[0];
18090
+ k[0]++;
18091
+ it++;
18092
+ end[0] = (end[0] + 1)%m;
15964
18093
 
15965
18094
  // initialize search direction with -g
15966
18095
  ggml_vec_neg_f32(nx, d, g);
15967
18096
 
15968
- j = end;
18097
+ j[0] = end[0];
15969
18098
  for (int i = 0; i < bound; ++i) {
15970
- j = (j + m - 1) % m;
18099
+ j[0] = (j[0] + m - 1) % m;
15971
18100
  // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
15972
- ggml_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d);
15973
- lm[j].alpha /= lm[j].ys;
18101
+ ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d);
18102
+ lm_alpha[j[0]] /= lm_ys[j[0]];
15974
18103
  // q_{i} = q_{i+1} - \alpha_{i} y_{i}
15975
- ggml_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha);
18104
+ ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
15976
18105
  }
15977
18106
 
15978
18107
  ggml_vec_scale_f32(nx, d, ys/yy);
15979
18108
 
15980
18109
  for (int i = 0; i < bound; ++i) {
15981
18110
  // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
15982
- ggml_vec_dot_f32(nx, &beta, lm[j].y, d);
15983
- beta /= lm[j].ys;
18111
+ ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d);
18112
+ beta /= lm_ys[j[0]];
15984
18113
  // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
15985
- ggml_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta);
15986
- j = (j + 1)%m;
18114
+ ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
18115
+ j[0] = (j[0] + 1)%m;
15987
18116
  }
15988
18117
 
15989
- step = 1.0;
18118
+ step[0] = 1.0;
15990
18119
  }
15991
18120
 
15992
18121
  return GGML_OPT_DID_NOT_CONVERGE;
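The per-iteration `struct ggml_lbfgs_iteration_data` array is replaced by four persistent tensors (`lmal`, `lmys`, `lms`, `lmy`), so entry j of the L-BFGS history lives at `&lm_s[j*nx]` / `&lm_y[j*nx]`: a ring buffer of m vectors stored row-wise in one contiguous allocation. A sketch of that indexing:

    #include <stddef.h>

    /* Row j of a flat (m x nx) history buffer; `end` advances as
     * end = (end + 1) % m, exactly like end[0] in the code above. */
    static float * lm_row(float * lm, size_t nx, int j) {
        return lm + (size_t) j * nx;
    }

    /* Example: storing s_{k+1} = x - xp into the current slot. */
    static void lm_store_s(float * lm_s, size_t nx, int end,
                           const float * x, const float * xp) {
        float * s = lm_row(lm_s, nx, end);
        for (size_t i = 0; i < nx; ++i) {
            s[i] = x[i] - xp[i];
        }
    }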
@@ -16011,6 +18140,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
16011
18140
 
16012
18141
  .adam = {
16013
18142
  .n_iter = 10000,
18143
+ .sched = 1.000f,
18144
+ .decay = 0.001f,
16014
18145
  .alpha = 0.001f,
16015
18146
  .beta1 = 0.9f,
16016
18147
  .beta2 = 0.999f,
@@ -16053,6 +18184,70 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
16053
18184
  return result;
16054
18185
  }
16055
18186
 
18187
+ GGML_API void ggml_opt_init(
18188
+ struct ggml_context * ctx,
18189
+ struct ggml_opt_context * opt,
18190
+ struct ggml_opt_params params,
18191
+ int64_t nx) {
18192
+ opt->ctx = ctx;
18193
+ opt->params = params;
18194
+ opt->iter = 0;
18195
+ opt->nx = nx;
18196
+ opt->just_initialized = true;
18197
+ switch (opt->params.type) {
18198
+ case GGML_OPT_ADAM:
18199
+ {
18200
+ opt->adam.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18201
+ opt->adam.g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18202
+ opt->adam.g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18203
+ opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18204
+ opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18205
+ opt->adam.mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18206
+ opt->adam.vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18207
+ opt->adam.pf = params.past > 0
18208
+ ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
18209
+ : NULL;
18210
+ ggml_set_zero(opt->adam.x);
18211
+ ggml_set_zero(opt->adam.g1);
18212
+ ggml_set_zero(opt->adam.g2);
18213
+ ggml_set_zero(opt->adam.m);
18214
+ ggml_set_zero(opt->adam.v);
18215
+ ggml_set_zero(opt->adam.mh);
18216
+ ggml_set_zero(opt->adam.vh);
18217
+ if (opt->adam.pf) {
18218
+ ggml_set_zero(opt->adam.pf);
18219
+ }
18220
+ } break;
18221
+ case GGML_OPT_LBFGS:
18222
+ {
18223
+ opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18224
+ opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18225
+ opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18226
+ opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18227
+ opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
18228
+ opt->lbfgs.pf = params.past > 0
18229
+ ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
18230
+ : NULL;
18231
+ opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
18232
+ opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
18233
+ opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
18234
+ opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
18235
+ ggml_set_zero(opt->lbfgs.x);
18236
+ ggml_set_zero(opt->lbfgs.xp);
18237
+ ggml_set_zero(opt->lbfgs.g);
18238
+ ggml_set_zero(opt->lbfgs.gp);
18239
+ ggml_set_zero(opt->lbfgs.d);
18240
+ if (opt->lbfgs.pf) {
18241
+ ggml_set_zero(opt->lbfgs.pf);
18242
+ }
18243
+ ggml_set_zero(opt->lbfgs.lmal);
18244
+ ggml_set_zero(opt->lbfgs.lmys);
18245
+ ggml_set_zero(opt->lbfgs.lms);
18246
+ ggml_set_zero(opt->lbfgs.lmy);
18247
+ } break;
18248
+ }
18249
+ }
18250
+
16056
18251
  enum ggml_opt_result ggml_opt(
16057
18252
  struct ggml_context * ctx,
16058
18253
  struct ggml_opt_params params,
@@ -16075,33 +18270,65 @@ enum ggml_opt_result ggml_opt(
16075
18270
 
16076
18271
  enum ggml_opt_result result = GGML_OPT_OK;
16077
18272
 
18273
+ struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
18274
+
18275
+ ggml_opt_init(ctx, opt, params, 0);
18276
+ result = ggml_opt_resume(ctx, opt, f);
18277
+
18278
+ if (free_ctx) {
18279
+ ggml_free(ctx);
18280
+ }
18281
+
18282
+ return result;
18283
+ }
18284
+
18285
+ enum ggml_opt_result ggml_opt_resume(
18286
+ struct ggml_context * ctx,
18287
+ struct ggml_opt_context * opt,
18288
+ struct ggml_tensor * f) {
18289
+
18290
+ // build forward + backward compute graphs
18291
+ struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
18292
+ struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
18293
+
18294
+ struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
18295
+ struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
18296
+
18297
+ *gf = ggml_build_forward (f);
18298
+ *gb = ggml_build_backward(ctx, gf, true);
18299
+
18300
+ return ggml_opt_resume_g(ctx, opt, f, gf, gb);
18301
+ }
18302
+
18303
+ enum ggml_opt_result ggml_opt_resume_g(
18304
+ struct ggml_context * ctx,
18305
+ struct ggml_opt_context * opt,
18306
+ struct ggml_tensor * f,
18307
+ struct ggml_cgraph * gf,
18308
+ struct ggml_cgraph * gb) {
18309
+
16078
18310
  // build forward + backward compute graphs
16079
- struct ggml_cgraph gf = ggml_build_forward (f);
16080
- struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, true);
18311
+ enum ggml_opt_result result = GGML_OPT_OK;
16081
18312
 
16082
- switch (params.type) {
18313
+ switch (opt->params.type) {
16083
18314
  case GGML_OPT_ADAM:
16084
18315
  {
16085
- result = ggml_opt_adam(ctx, params, f, &gf, &gb);
18316
+ result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
16086
18317
  } break;
16087
18318
  case GGML_OPT_LBFGS:
16088
18319
  {
16089
- result = ggml_opt_lbfgs(ctx, params, f, &gf, &gb);
18320
+ result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb);
16090
18321
  } break;
16091
18322
  }
16092
18323
 
16093
- if (params.print_forward_graph) {
16094
- ggml_graph_print (&gf);
16095
- ggml_graph_dump_dot(&gf, NULL, "opt-forward.dot");
16096
- }
16097
-
16098
- if (params.print_backward_graph) {
16099
- ggml_graph_print (&gb);
16100
- ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot");
18324
+ if (opt->params.print_forward_graph) {
18325
+ ggml_graph_print (gf);
18326
+ ggml_graph_dump_dot(gf, NULL, "opt-forward.dot");
16101
18327
  }
16102
18328
 
16103
- if (free_ctx) {
16104
- ggml_free(ctx);
18329
+ if (opt->params.print_backward_graph) {
18330
+ ggml_graph_print (gb);
18331
+ ggml_graph_dump_dot(gb, gf, "opt-backward.dot");
16105
18332
  }
16106
18333
 
16107
18334
  return result;
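Taken together, `ggml_opt_init`, `ggml_opt_resume` and `ggml_opt_resume_g` let a caller keep Adam/L-BFGS state alive between calls instead of restarting from scratch each time. A rough usage sketch, assuming only the signatures visible in this diff; `ctx`, `f` (the loss tensor) and `nx` come from the caller's model setup, and error handling is omitted:

    #include "ggml.h"

    /* minimal sketch, assuming the API added in this diff */
    static void train_in_steps(struct ggml_context * ctx, struct ggml_tensor * f,
                               int64_t nx, int n_epochs) {
        struct ggml_opt_params params = ggml_opt_default_params(GGML_OPT_ADAM);
        params.adam.n_iter = 16;               /* a few Adam iterations per call */

        struct ggml_opt_context opt;
        ggml_opt_init(ctx, &opt, params, nx);  /* allocates persistent m, v, pf, ... */

        for (int epoch = 0; epoch < n_epochs; ++epoch) {
            /* ... refresh the inputs that f depends on here ... */
            ggml_opt_resume(ctx, &opt, f);     /* opt.iter keeps counting across calls */
        }
    }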
@@ -16301,6 +18528,18 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
16301
18528
  result = ggml_quantize_q6_K(src + start, block, n, n, hist);
16302
18529
  } break;
16303
18530
  #endif
18531
+ case GGML_TYPE_F16:
18532
+ {
18533
+ int elemsize = sizeof(ggml_fp16_t);
18534
+ ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
18535
+ result = n * elemsize;
18536
+ } break;
18537
+ case GGML_TYPE_F32:
18538
+ {
18539
+ int elemsize = sizeof(float);
18540
+ result = n * elemsize;
18541
+ memcpy((uint8_t *)dst + start * elemsize, src + start, result);
18542
+ } break;
16304
18543
  default:
16305
18544
  assert(false);
16306
18545
  }
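The two new cases make `ggml_quantize_chunk` usable for non-quantized destinations as well: F32 is a straight memcpy of the chunk, F16 converts element-wise via `ggml_fp32_to_fp16_row`, and in both cases the return value remains "bytes written for this chunk". A small sketch of that size accounting (the fp16 conversion itself is left to ggml):

    #include <stddef.h>

    /* Bytes produced when a chunk of n floats is written with the given
     * destination element size (2 for f16, 4 for f32); the destination offset
     * of the chunk is start*elem_size, mirroring the code above. */
    static size_t chunk_output_bytes(int n, size_t elem_size) {
        return (size_t) n * elem_size;
    }

    static size_t chunk_output_offset(int start, size_t elem_size) {
        return (size_t) start * elem_size;
    }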