cui-llama.rn 1.1.2 → 1.1.5

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
package/cpp/ggml.c CHANGED
@@ -69,23 +69,42 @@ int lm_ggml_sve_cnt_b = 0;
  #endif
  #include <windows.h>

+ #if !defined(__clang__)
  typedef volatile LONG atomic_int;
  typedef atomic_int atomic_bool;
  typedef atomic_int atomic_flag;

  #define ATOMIC_FLAG_INIT 0

+ typedef enum {
+ memory_order_relaxed,
+ memory_order_consume,
+ memory_order_acquire,
+ memory_order_release,
+ memory_order_acq_rel,
+ memory_order_seq_cst
+ } memory_order;
+
  static void atomic_store(atomic_int * ptr, LONG val) {
  InterlockedExchange(ptr, val);
  }
+ static void atomic_store_explicit(atomic_int * ptr, LONG val, memory_order mo) {
+ // TODO: add support for explicit memory order
+ InterlockedExchange(ptr, val);
+ }
  static LONG atomic_load(atomic_int * ptr) {
  return InterlockedCompareExchange(ptr, 0, 0);
  }
+ static LONG atomic_load_explicit(atomic_int * ptr, memory_order mo) {
+ // TODO: add support for explicit memory order
+ return InterlockedCompareExchange(ptr, 0, 0);
+ }
  static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
  return InterlockedExchangeAdd(ptr, inc);
  }
- static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
- return atomic_fetch_add(ptr, -(dec));
+ static LONG atomic_fetch_add_explicit(atomic_int * ptr, LONG inc, memory_order mo) {
+ // TODO: add support for explicit memory order
+ return InterlockedExchangeAdd(ptr, inc);
  }
  static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
  return InterlockedExchange(ptr, 1);
@@ -93,6 +112,9 @@ static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
  static void atomic_flag_clear(atomic_flag * ptr) {
  InterlockedExchange(ptr, 0);
  }
+ #else // clang
+ #include <stdatomic.h>
+ #endif

  typedef HANDLE pthread_t;

@@ -121,8 +143,13 @@ static int sched_yield (void) {
  return 0;
  }
  #else
+
  #include <pthread.h>
  #include <stdatomic.h>
+ #include <sched.h>
+ #if defined(__FreeBSD__)
+ #include <pthread_np.h>
+ #endif

  typedef void * thread_ret_t;

@@ -260,6 +287,7 @@ void lm_ggml_abort(const char * file, int line, const char * fmt, ...) {
  #define LM_GGML_DEBUG 0
  #define LM_GGML_GELU_FP16
  #define LM_GGML_GELU_QUICK_FP16
+ #define LM_GGML_N_TASKS_MAX (-1)

  #define LM_GGML_SOFT_MAX_UNROLL 4
  #define LM_GGML_VEC_DOT_UNROLL 2
@@ -1027,7 +1055,31 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = {
  .ncols = 8,
  .gemv = lm_ggml_gemv_q4_0_8x8_q8_0,
  .gemm = lm_ggml_gemm_q4_0_8x8_q8_0,
- }
+ },
+ [LM_GGML_TYPE_TQ1_0] = {
+ .type_name = "tq1_0",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_tq1_0),
+ .is_quantized = true,
+ .to_float = (lm_ggml_to_float_t) dequantize_row_tq1_0,
+ .from_float = quantize_row_tq1_0,
+ .from_float_ref = (lm_ggml_from_float_t) quantize_row_tq1_0_ref,
+ .vec_dot = lm_ggml_vec_dot_tq1_0_q8_K,
+ .vec_dot_type = LM_GGML_TYPE_Q8_K,
+ .nrows = 1,
+ },
+ [LM_GGML_TYPE_TQ2_0] = {
+ .type_name = "tq2_0",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_tq2_0),
+ .is_quantized = true,
+ .to_float = (lm_ggml_to_float_t) dequantize_row_tq2_0,
+ .from_float = quantize_row_tq2_0,
+ .from_float_ref = (lm_ggml_from_float_t) quantize_row_tq2_0_ref,
+ .vec_dot = lm_ggml_vec_dot_tq2_0_q8_K,
+ .vec_dot_type = LM_GGML_TYPE_Q8_K,
+ .nrows = 1,
+ },
  };

  // For internal test use
@@ -1069,21 +1121,21 @@ lm_ggml_type_traits_t lm_ggml_internal_get_type_traits(enum lm_ggml_type type) {
  #define LM_GGML_F32x4_ADD vaddq_f32
  #define LM_GGML_F32x4_MUL vmulq_f32
  #define LM_GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
- #define LM_GGML_F32x4_REDUCE(res, x) \
- { \
- int offset = LM_GGML_F32_ARR >> 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = vaddq_f32(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = vaddq_f32(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = vaddq_f32(x[i], x[offset+i]); \
- } \
- res = LM_GGML_F32x4_REDUCE_ONE(x[0]); \
+ #define LM_GGML_F32x4_REDUCE(res, x) \
+ { \
+ int offset = LM_GGML_F32_ARR >> 1; \
+ for (int i = 0; i < offset; ++i) { \
+ (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
+ } \
+ (res) = LM_GGML_F32x4_REDUCE_ONE((x)[0]); \
  }

  #define LM_GGML_F32_VEC LM_GGML_F32x4
@@ -1110,30 +1162,30 @@ lm_ggml_type_traits_t lm_ggml_internal_get_type_traits(enum lm_ggml_type type) {
  #define LM_GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
  #define LM_GGML_F16x8_ADD vaddq_f16
  #define LM_GGML_F16x8_MUL vmulq_f16
- #define LM_GGML_F16x8_REDUCE(res, x) \
- do { \
- int offset = LM_GGML_F16_ARR >> 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = vaddq_f16(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = vaddq_f16(x[i], x[offset+i]); \
- } \
- offset >>= 1; \
- for (int i = 0; i < offset; ++i) { \
- x[i] = vaddq_f16(x[i], x[offset+i]); \
- } \
- const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
- const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
- res = (lm_ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \
+ #define LM_GGML_F16x8_REDUCE(res, x) \
+ do { \
+ int offset = LM_GGML_F16_ARR >> 1; \
+ for (int i = 0; i < offset; ++i) { \
+ (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \
+ } \
+ const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \
+ const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \
+ (res) = (lm_ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \
  } while (0)

  #define LM_GGML_F16_VEC LM_GGML_F16x8
  #define LM_GGML_F16_VEC_ZERO LM_GGML_F16x8_ZERO
  #define LM_GGML_F16_VEC_SET1 LM_GGML_F16x8_SET1
  #define LM_GGML_F16_VEC_LOAD(p, i) LM_GGML_F16x8_LOAD(p)
- #define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F16x8_STORE((lm_ggml_fp16_internal_t *)(p), r[i])
+ #define LM_GGML_F16_VEC_STORE(p, r, i) LM_GGML_F16x8_STORE((lm_ggml_fp16_internal_t *)(p), (r)[i])
  #define LM_GGML_F16_VEC_FMA LM_GGML_F16x8_FMA
  #define LM_GGML_F16_VEC_ADD LM_GGML_F16x8_ADD
  #define LM_GGML_F16_VEC_MUL LM_GGML_F16x8_MUL
@@ -1842,6 +1894,23 @@ static inline void __lsx_f16x4_store(lm_ggml_fp16_t * x, __m128 y) {
  #define LM_GGML_F16_ARR (LM_GGML_F16_STEP/LM_GGML_F16_EPR)
  #endif

+ //
+ // ggml object
+ //
+
+ struct lm_ggml_object {
+ size_t offs;
+ size_t size;
+
+ struct lm_ggml_object * next;
+
+ enum lm_ggml_object_type type;
+
+ char padding[4];
+ };
+
+ static const size_t LM_GGML_OBJECT_SIZE = sizeof(struct lm_ggml_object);
+
  //
  // ggml context
  //
@@ -1868,28 +1937,102 @@ struct lm_ggml_context_container {
  struct lm_ggml_context context;
  };

- struct lm_ggml_compute_state_shared {
- const struct lm_ggml_cgraph * cgraph;
- const struct lm_ggml_cplan * cplan;
+ //
+ // Threading defs
+ //
+
+ typedef pthread_t lm_ggml_thread_t;
+
+ #if defined(_WIN32)
+
+ typedef CONDITION_VARIABLE lm_ggml_cond_t;
+ typedef SRWLOCK lm_ggml_mutex_t;
+
+ #define lm_ggml_mutex_init(m) InitializeSRWLock(m)
+ #define lm_ggml_mutex_destroy(m)
+ #define lm_ggml_mutex_lock(m) AcquireSRWLockExclusive(m)
+ #define lm_ggml_mutex_unlock(m) ReleaseSRWLockExclusive(m)
+ #define lm_ggml_mutex_lock_shared(m) AcquireSRWLockShared(m)
+ #define lm_ggml_mutex_unlock_shared(m) ReleaseSRWLockShared(m)
+
+ #define lm_ggml_cond_init(c) InitializeConditionVariable(c)
+ #define lm_ggml_cond_destroy(c)
+ #define lm_ggml_cond_wait(c, m) SleepConditionVariableSRW(c, m, INFINITE, CONDITION_VARIABLE_LOCKMODE_SHARED)
+ #define lm_ggml_cond_broadcast(c) WakeAllConditionVariable(c)
+
+ #define lm_ggml_thread_create pthread_create
+ #define lm_ggml_thread_join pthread_join
+
+ #else
+
+ typedef pthread_cond_t lm_ggml_cond_t;
+ typedef pthread_mutex_t lm_ggml_mutex_t;
+
+ #define lm_ggml_mutex_init(m) pthread_mutex_init(m, NULL)
+ #define lm_ggml_mutex_destroy(m) pthread_mutex_destroy(m)
+ #define lm_ggml_mutex_lock(m) pthread_mutex_lock(m)
+ #define lm_ggml_mutex_unlock(m) pthread_mutex_unlock(m)
+ #define lm_ggml_mutex_lock_shared(m) pthread_mutex_lock(m)
+ #define lm_ggml_mutex_unlock_shared(m) pthread_mutex_unlock(m)
+
+ #define lm_ggml_lock_init(x) UNUSED(x)
+ #define lm_ggml_lock_destroy(x) UNUSED(x)
+ #if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64))
+ #define lm_ggml_lock_lock(x) _mm_pause()
+ #else
+ #define lm_ggml_lock_lock(x) UNUSED(x)
+ #endif
+ #define lm_ggml_lock_unlock(x) UNUSED(x)
+
+ #define LM_GGML_LOCK_INITIALIZER 0
+ #define lm_ggml_cond_init(c) pthread_cond_init(c, NULL)
+ #define lm_ggml_cond_destroy(c) pthread_cond_destroy(c)
+ #define lm_ggml_cond_wait(c, m) pthread_cond_wait(c, m)
+ #define lm_ggml_cond_broadcast(c) pthread_cond_broadcast(c)
+
+ #define lm_ggml_thread_create pthread_create
+ #define lm_ggml_thread_join pthread_join
+
+ #endif
+
+ // Threadpool def
+ struct lm_ggml_threadpool {
+ lm_ggml_mutex_t mutex; // mutex for cond.var
+ lm_ggml_cond_t cond; // cond.var for waiting for new work

- int n_threads;
+ struct lm_ggml_cgraph * cgraph;
+ struct lm_ggml_cplan * cplan;

  // synchronization primitives
+ atomic_int n_graph; // incremented when there is work to be done (i.e each graph)
  atomic_int n_barrier;
  atomic_int n_barrier_passed;
+ atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.

- lm_ggml_abort_callback abort_callback; // abort lm_ggml_graph_compute when true
- void * abort_callback_data;
+ // these are atomic as an annotation for thread-sanitizer
+ atomic_bool stop; // Used for stopping the threadpool altogether
+ atomic_bool pause; // Used for pausing the threadpool or individual threads

- atomic_int current_chunk; // currently processing chunk during mul_mat, shared between all the threads
+ struct lm_ggml_compute_state * workers; // per thread state
+ int n_threads_max; // number of threads in the pool
+ int n_threads_cur; // number of threads used in the current graph
+
+ int32_t prio; // Scheduling priority
+ uint32_t poll; // Polling level (0 - no polling)

  enum lm_ggml_status ec;
  };

+ // Per-thread state
  struct lm_ggml_compute_state {
+ #ifndef LM_GGML_USE_OPENMP
  lm_ggml_thread_t thrd;
+ bool cpumask[LM_GGML_MAX_N_THREADS];
+ int last_graph;
+ bool pending;
+ #endif
+ struct lm_ggml_threadpool * threadpool;
  int ith;
- struct lm_ggml_compute_state_shared * shared;
  };

  struct lm_ggml_compute_params {
@@ -1900,7 +2043,7 @@ struct lm_ggml_compute_params {
  size_t wsize;
  void * wdata;

- struct lm_ggml_compute_state_shared * shared;
+ struct lm_ggml_threadpool * threadpool;
  };

  //
@@ -2310,7 +2453,9 @@ inline static void lm_ggml_vec_scale_f16(const int n, lm_ggml_fp16_t * y, const
  inline static void lm_ggml_vec_norm_f32 (const int n, float * s, const float * x) { lm_ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
  inline static void lm_ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
  inline static void lm_ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
- inline static void lm_ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
+ inline static void lm_ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
+ inline static void lm_ggml_vec_sin_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]); }
+ inline static void lm_ggml_vec_cos_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]); }
  inline static void lm_ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
  inline static void lm_ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
  inline static void lm_ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
@@ -2322,6 +2467,7 @@ inline static void lm_ggml_vec_sigmoid_f32 (const int n, float * y, const float
  // TODO: optimize performance
  inline static void lm_ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
  inline static void lm_ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
+ inline static void lm_ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }

  static const float GELU_COEF_A = 0.044715f;
  static const float GELU_QUICK_COEF = -1.702f;
@@ -2669,6 +2815,19 @@ static lm_ggml_float lm_ggml_vec_soft_max_f32(const int n, float * y, const floa
  return sum;
  }

+ static lm_ggml_float lm_ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
+ // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_i)
+
+ int i = 0;
+ lm_ggml_float sum = 0;
+ for (; i < n; ++i) {
+ float val = x[i] - max;
+ y[i] = val;
+ sum += (lm_ggml_float)expf(val);
+ }
+ return sum = (lm_ggml_float)logf(sum);
+ }
+
  inline static float lm_ggml_silu_backward_f32(float x, float dy) {
  const float s = 1.0f/(1.0f + expf(-x));
  return dy*s*(1.0f + x*(1.0f - s));
@@ -2760,6 +2919,8 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = {
  "SQR",
  "SQRT",
  "LOG",
+ "SIN",
+ "COS",
  "SUM",
  "SUM_ROWS",
  "MEAN",
@@ -2797,9 +2958,11 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = {
  "CLAMP",
  "CONV_TRANSPOSE_1D",
  "IM2COL",
+ "IM2COL_BACK",
  "CONV_TRANSPOSE_2D",
  "POOL_1D",
  "POOL_2D",
+ "POOL_2D_BACK",
  "UPSCALE",
  "PAD",
  "ARANGE",
@@ -2815,6 +2978,7 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = {
  "WIN_UNPART",
  "GET_REL_POS",
  "ADD_REL_POS",
+ "RWKV_WKV",

  "UNARY",

@@ -2833,7 +2997,7 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = {
  "CROSS_ENTROPY_LOSS_BACK",
  };

- static_assert(LM_GGML_OP_COUNT == 74, "LM_GGML_OP_COUNT != 74");
+ static_assert(LM_GGML_OP_COUNT == 79, "LM_GGML_OP_COUNT != 79");

  static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
  "none",
@@ -2848,6 +3012,8 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
  "x^2",
  "√x",
  "log(x)",
+ "sin(x)",
+ "cos(x)",
  "Σx",
  "Σx_k",
  "Σx/n",
@@ -2885,9 +3051,11 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
  "clamp(x)",
  "conv_transpose_1d(x)",
  "im2col(x)",
+ "im2col_back(x)",
  "conv_transpose_2d(x)",
  "pool_1d(x)",
  "pool_2d(x)",
+ "pool_2d_back(x)",
  "upscale(x)",
  "pad(x)",
  "arange(start, stop, step)",
@@ -2903,6 +3071,7 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
  "win_unpart(x)",
  "get_rel_pos(x)",
  "add_rel_pos(x)",
+ "rwkv_wkv(k, v, r, tf, td, s)",

  "unary(x)",

@@ -2921,7 +3090,7 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
  "cross_entropy_loss_back(x,y)",
  };

- static_assert(LM_GGML_OP_COUNT == 74, "LM_GGML_OP_COUNT != 74");
+ static_assert(LM_GGML_OP_COUNT == 79, "LM_GGML_OP_COUNT != 79");

  static_assert(LM_GGML_OP_POOL_COUNT == 2, "LM_GGML_OP_POOL_COUNT != 2");

@@ -2940,14 +3109,28 @@ static const char * LM_GGML_UNARY_OP_NAME[LM_GGML_UNARY_OP_COUNT] = {
  "SILU",
  "HARDSWISH",
  "HARDSIGMOID",
+ "EXP",
  };

- static_assert(LM_GGML_UNARY_OP_COUNT == 13, "LM_GGML_UNARY_OP_COUNT != 13");
+ static_assert(LM_GGML_UNARY_OP_COUNT == 14, "LM_GGML_UNARY_OP_COUNT != 14");


  static_assert(sizeof(struct lm_ggml_object)%LM_GGML_MEM_ALIGN == 0, "lm_ggml_object size must be a multiple of LM_GGML_MEM_ALIGN");
  static_assert(sizeof(struct lm_ggml_tensor)%LM_GGML_MEM_ALIGN == 0, "lm_ggml_tensor size must be a multiple of LM_GGML_MEM_ALIGN");

+ // Helpers for polling loops
+ #if defined(__aarch64__) && ( defined(__clang__) || defined(__GNUC__) )
+ static inline void lm_ggml_thread_cpu_relax(void) {
+ __asm__ volatile("yield" ::: "memory");
+ }
+ #elif defined(__x86_64__)
+ static inline void lm_ggml_thread_cpu_relax(void) {
+ _mm_pause();
+ }
+ #else
+ static inline void lm_ggml_thread_cpu_relax(void) {;}
+ #endif
+
  //
  // NUMA support
  //
@@ -2995,42 +3178,36 @@ inline static void lm_ggml_critical_section_start(void) {
  }

  #ifdef LM_GGML_USE_OPENMP
- static void lm_ggml_barrier(struct lm_ggml_compute_state_shared * shared) {
- if (shared->n_threads == 1) {
+ static void lm_ggml_barrier(struct lm_ggml_threadpool * threadpool) {
+ if (threadpool->n_threads_cur == 1) {
  return;
  }

  #pragma omp barrier
  }
  #else
- static void lm_ggml_barrier(struct lm_ggml_compute_state_shared * shared) {
- if (shared->n_threads == 1) {
+ static void lm_ggml_barrier(struct lm_ggml_threadpool * threadpool) {
+ if (threadpool->n_threads_cur == 1) {
  return;
  }

- atomic_int * n_barrier = &shared->n_barrier;
- atomic_int * n_barrier_passed = &shared->n_barrier_passed;
+ atomic_int * n_barrier = &threadpool->n_barrier;
+ atomic_int * n_barrier_passed = &threadpool->n_barrier_passed;

- int n_threads = shared->n_threads;
- int passed_old = atomic_load(n_barrier_passed);
+ int n_threads = threadpool->n_threads_cur;
+ int passed_old = atomic_load_explicit(n_barrier_passed, memory_order_relaxed);

  if (atomic_fetch_add(n_barrier, 1) == n_threads - 1) {
  // last thread
  atomic_store(n_barrier, 0);
- atomic_fetch_add(n_barrier_passed, 1);
+ atomic_fetch_add_explicit(n_barrier_passed, 1, memory_order_relaxed);
  } else {
  // wait for other threads
- const int n_spin_before_sleep = 100000;
  while (true) {
- for (int i = 0; i < n_spin_before_sleep; i++) {
- if (atomic_load(n_barrier_passed) != passed_old) {
- return;
- }
- #if defined(__SSE3__)
- _mm_pause();
- #endif
+ if (atomic_load_explicit(n_barrier_passed, memory_order_relaxed) != passed_old) {
+ return;
  }
- sched_yield();
+ lm_ggml_thread_cpu_relax();
  }
  }
  }
@@ -3222,7 +3399,7 @@ double lm_ggml_type_sizef(enum lm_ggml_type type) {
  }

  LM_GGML_CALL const char * lm_ggml_type_name(enum lm_ggml_type type) {
- return type_traits[type].type_name;
+ return type < LM_GGML_TYPE_COUNT ? type_traits[type].type_name : "NONE";
  }

  LM_GGML_CALL bool lm_ggml_is_quantized(enum lm_ggml_type type) {
@@ -3688,7 +3865,7 @@ static struct lm_ggml_object * lm_ggml_new_object(struct lm_ggml_context * ctx,

  if (cur_end + size_needed + LM_GGML_OBJECT_SIZE > ctx->mem_size) {
  LM_GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
- __func__, cur_end + size_needed, ctx->mem_size);
+ __func__, cur_end + size_needed + LM_GGML_OBJECT_SIZE, ctx->mem_size);
  assert(false);
  return NULL;
  }
@@ -3767,6 +3944,7 @@ static struct lm_ggml_tensor * lm_ggml_new_tensor_impl(
  }

  struct lm_ggml_object * const obj_new = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_TENSOR, LM_GGML_TENSOR_SIZE + obj_alloc_size);
+ LM_GGML_ASSERT(obj_new);

  // TODO: for recoverable errors, we would need to free the data allocated from the scratch buffer here

@@ -4486,8 +4664,6 @@ static struct lm_ggml_tensor * lm_ggml_add_impl(
  bool is_node = false;

  if (!inplace && (a->grad || b->grad)) {
- // TODO: support backward pass for broadcasting
- LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b));
  is_node = true;
  }

@@ -4661,11 +4837,13 @@ static struct lm_ggml_tensor * lm_ggml_sub_impl(
  struct lm_ggml_tensor * a,
  struct lm_ggml_tensor * b,
  bool inplace) {
- LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b));
+ LM_GGML_ASSERT(lm_ggml_can_repeat(b, a));

  bool is_node = false;

  if (!inplace && (a->grad || b->grad)) {
+ // TODO: support backward pass for broadcasting
+ LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b));
  is_node = true;
  }

@@ -4880,6 +5058,72 @@ struct lm_ggml_tensor * lm_ggml_log_inplace(
  return lm_ggml_log_impl(ctx, a, true);
  }

+ // lm_ggml_sin
+
+ static struct lm_ggml_tensor * lm_ggml_sin_impl(
+ struct lm_ggml_context * ctx,
+ struct lm_ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
+
+ result->op = LM_GGML_OP_SIN;
+ result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+ }
+
+ struct lm_ggml_tensor * lm_ggml_sin(
+ struct lm_ggml_context * ctx,
+ struct lm_ggml_tensor * a) {
+ return lm_ggml_sin_impl(ctx, a, false);
+ }
+
+ struct lm_ggml_tensor * lm_ggml_sin_inplace(
+ struct lm_ggml_context * ctx,
+ struct lm_ggml_tensor * a) {
+ return lm_ggml_sin_impl(ctx, a, true);
+ }
+
+ // lm_ggml_cos
+
+ static struct lm_ggml_tensor * lm_ggml_cos_impl(
+ struct lm_ggml_context * ctx,
+ struct lm_ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
+
+ result->op = LM_GGML_OP_COS;
+ result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+ }
+
+ struct lm_ggml_tensor * lm_ggml_cos(
+ struct lm_ggml_context * ctx,
+ struct lm_ggml_tensor * a) {
+ return lm_ggml_cos_impl(ctx, a, false);
+ }
+
+ struct lm_ggml_tensor * lm_ggml_cos_inplace(
+ struct lm_ggml_context * ctx,
+ struct lm_ggml_tensor * a) {
+ return lm_ggml_cos_impl(ctx, a, true);
+ }
+
  // lm_ggml_sum

  struct lm_ggml_tensor * lm_ggml_sum(
@@ -5041,6 +5285,7 @@ struct lm_ggml_tensor * lm_ggml_concat(
  bool is_node = false;

  if (a->grad || b->grad) {
+ LM_GGML_ABORT("fatal error"); // TODO: implement
  is_node = true;
  }

@@ -5162,6 +5407,7 @@ struct lm_ggml_tensor * lm_ggml_leaky_relu(
  bool is_node = false;

  if (!inplace && (a->grad)) {
+ LM_GGML_ABORT("fatal error"); // TODO: not implemented
  is_node = true;
  }

@@ -5269,6 +5515,19 @@ struct lm_ggml_tensor * lm_ggml_hardsigmoid(
  return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_HARDSIGMOID);
  }

+ // ggml exp
+ struct lm_ggml_tensor * lm_ggml_exp(
+ struct lm_ggml_context * ctx,
+ struct lm_ggml_tensor * a) {
+ return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_EXP);
+ }
+
+ struct lm_ggml_tensor * lm_ggml_exp_inplace(
+ struct lm_ggml_context * ctx,
+ struct lm_ggml_tensor * a) {
+ return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_EXP);
+ }
+
  // lm_ggml_norm

  static struct lm_ggml_tensor * lm_ggml_norm_impl(
@@ -5587,6 +5846,7 @@ static struct lm_ggml_tensor * lm_ggml_set_impl(
  // make a view of the destination
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);

+ LM_GGML_ASSERT(offset < (size_t)(1 << 30));
  int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
  lm_ggml_set_op_params(result, params, sizeof(params));

@@ -6544,14 +6804,12 @@ struct lm_ggml_tensor * lm_ggml_rope_back(
  LM_GGML_ASSERT(lm_ggml_is_vector(b));
  LM_GGML_ASSERT(b->type == LM_GGML_TYPE_I32);
  LM_GGML_ASSERT(a->ne[2] == b->ne[0]);
- LM_GGML_ASSERT(c == NULL && "freq factors not implemented yet");
-
- LM_GGML_ASSERT((mode & 4) == 0 && "lm_ggml_rope_back() for ChatGLM not implemented yet");

  bool is_node = false;

  if (a->grad) {
- is_node = false; // TODO: implement backward
+ LM_GGML_ASSERT(false && "backwards pass not implemented");
+ is_node = false;
  }

  struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a);
@@ -6569,6 +6827,7 @@ struct lm_ggml_tensor * lm_ggml_rope_back(
  result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
  result->src[0] = a;
  result->src[1] = b;
+ result->src[2] = c;

  return result;
  }
@@ -6727,17 +6986,20 @@ struct lm_ggml_tensor * lm_ggml_im2col(
  LM_GGML_ASSERT(a->ne[2] == b->ne[2]);
  } else {
  LM_GGML_ASSERT(a->ne[1] == b->ne[1]);
+ LM_GGML_ASSERT(b->ne[3] == 1);
  }
  bool is_node = false;

- if (a->grad || b->grad) {
- LM_GGML_ABORT("fatal error"); // TODO: implement backward
+ if (/*a->grad ||*/ b->grad) { // a is only used for its shape, not its data
  is_node = true;
  }

  const int64_t OH = is_2D ? lm_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
  const int64_t OW = lm_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);

+ LM_GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a");
+ LM_GGML_ASSERT((OW > 0) && "b too small compared to a");
+
  const int64_t ne[4] = {
  is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
  OW,
@@ -6757,6 +7019,37 @@ struct lm_ggml_tensor * lm_ggml_im2col(
  return result;
  }

+ struct lm_ggml_tensor * lm_ggml_im2col_back(
+ struct lm_ggml_context * ctx,
+ struct lm_ggml_tensor * a,
+ struct lm_ggml_tensor * b,
+ int64_t * ne,
+ int s0,
+ int s1,
+ int p0,
+ int p1,
+ int d0,
+ int d1,
+ bool is_2D) {
+
+ bool is_node = false;
+
+ if (/*a->grad ||*/ b->grad) { // a is only used for its shape, not its data
+ is_node = true;
+ }
+
+ struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne);
+ int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
+ lm_ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = LM_GGML_OP_IM2COL_BACK;
+ result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+ }
+
  // a: [OC,IC, KH, KW]
  // b: [N, IC, IH, IW]
  // result: [N, OC, OH, OW]
@@ -6770,7 +7063,7 @@ struct lm_ggml_tensor * lm_ggml_conv_2d(
  int p1,
  int d0,
  int d1) {
- struct lm_ggml_tensor * im2col = lm_ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, LM_GGML_TYPE_F16); // [N, OH, OW, IC * KH * KW]
+ struct lm_ggml_tensor * im2col = lm_ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW]

  struct lm_ggml_tensor * result =
  lm_ggml_mul_mat(ctx,
@@ -6896,17 +7189,17 @@ struct lm_ggml_tensor * lm_ggml_pool_2d(
  bool is_node = false;

  if (a->grad) {
- LM_GGML_ABORT("fatal error"); // TODO: implement backward
  is_node = true;
  }

  struct lm_ggml_tensor * result;
- const int64_t ne[3] = {
+ const int64_t ne[4] = {
  lm_ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
  lm_ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
  a->ne[2],
+ a->ne[3],
  };
- result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 3, ne);
+ result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne);

  int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
  lm_ggml_set_op_params(result, params, sizeof(params));
@@ -6917,6 +7210,37 @@ struct lm_ggml_tensor * lm_ggml_pool_2d(
  return result;
  }

+ struct lm_ggml_tensor * lm_ggml_pool_2d_back(
+ struct lm_ggml_context * ctx,
+ struct lm_ggml_tensor * a,
+ struct lm_ggml_tensor * af,
+ enum lm_ggml_op_pool op,
+ int k0,
+ int k1,
+ int s0,
+ int s1,
+ float p0,
+ float p1) {
+
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ struct lm_ggml_tensor * result;
+ result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, af->ne);
+
+ int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
+ lm_ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = LM_GGML_OP_POOL_2D_BACK;
+ result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = af;
+ return result;
+ }
+
  // lm_ggml_upscale

  static struct lm_ggml_tensor * lm_ggml_upscale_impl(
@@ -7057,6 +7381,11 @@ struct lm_ggml_tensor * lm_ggml_argsort(
  enum lm_ggml_sort_order order) {
  bool is_node = false;

+ if (a->grad) {
+ LM_GGML_ABORT("fatal error"); // TODO: not implemented
+ is_node = true;
+ }
+
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_I32, LM_GGML_MAX_DIMS, a->ne);

  lm_ggml_set_op_params_i32(result, 0, (int32_t) order);
@@ -7467,6 +7796,59 @@ struct lm_ggml_tensor * lm_ggml_add_rel_pos_inplace(
  return lm_ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
  }

+ // lm_ggml_rwkv_wkv
+
+ struct lm_ggml_tensor * lm_ggml_rwkv_wkv(
+ struct lm_ggml_context * ctx,
+ struct lm_ggml_tensor * k,
+ struct lm_ggml_tensor * v,
+ struct lm_ggml_tensor * r,
+ struct lm_ggml_tensor * tf,
+ struct lm_ggml_tensor * td,
+ struct lm_ggml_tensor * state) {
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(k));
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(v));
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(r));
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(tf));
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(td));
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(state));
+
+ const int64_t S = k->ne[0];
+ const int64_t H = k->ne[2];
+ const int64_t n_tokens = k->ne[3];
+ const int64_t n_seqs = state->ne[1];
+ {
+ LM_GGML_ASSERT(k->ne[1] == 1);
+ LM_GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
+ LM_GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
+ // TODO: RWKV v4 and v5
+ LM_GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+ LM_GGML_ASSERT(lm_ggml_nelements(state) == S * S * H * n_seqs);
+ }
+
+ bool is_node = false;
+
+ if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
+ LM_GGML_ABORT("fatal error"); // TODO: implement backward
+ is_node = true;
+ }
+
+ // concat output and new_state
+ const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+ struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne);
+
+ result->op = LM_GGML_OP_RWKV_WKV;
+ result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = k;
+ result->src[1] = v;
+ result->src[2] = r;
+ result->src[3] = tf;
+ result->src[4] = td;
+ result->src[5] = state;
+
+ return result;
+ }
+
  // lm_ggml_unary

  static struct lm_ggml_tensor * lm_ggml_unary_impl(
@@ -7965,8 +8347,7 @@ static void lm_ggml_compute_forward_dup_same_cont(
  LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0));
  LM_GGML_ASSERT(src0->type == dst->type);

- const size_t nb00 = src0->nb[0];
- const size_t nb0 = dst->nb[0];
+ const size_t nb0 = lm_ggml_type_size(src0->type);

  const int ith = params->ith; // thread index
  const int nth = params->nth; // number of threads
@@ -7980,8 +8361,8 @@ static void lm_ggml_compute_forward_dup_same_cont(
  if (ie0 < ie1) {
  memcpy(
  ((char *) dst->data + ie0*nb0),
- ((char *) src0->data + ie0*nb00),
- (ie1 - ie0) * lm_ggml_type_size(src0->type));
+ ((char *) src0->data + ie0*nb0),
+ (ie1 - ie0) * nb0);
  }
  }

@@ -7998,11 +8379,6 @@ static void lm_ggml_compute_forward_dup_f16(
  const int ith = params->ith; // thread index
  const int nth = params->nth; // number of threads

- if (lm_ggml_is_contiguous(src0) && lm_ggml_is_contiguous(dst) && src0->type == dst->type) {
- lm_ggml_compute_forward_dup_same_cont(params, dst);
- return;
- }
-
  // parallelize by rows
  const int nr = ne01;
  // number of rows per thread
@@ -8267,11 +8643,6 @@ static void lm_ggml_compute_forward_dup_bf16(
  const int ith = params->ith; // thread index
  const int nth = params->nth; // number of threads

- if (lm_ggml_is_contiguous(src0) && lm_ggml_is_contiguous(dst) && src0->type == dst->type) {
- lm_ggml_compute_forward_dup_same_cont(params, dst);
- return;
- }
-
  // parallelize by rows
  const int nr = ne01;
  // number of rows per thread
@@ -8623,11 +8994,6 @@ static void lm_ggml_compute_forward_dup_f32(
  const int ith = params->ith; // thread index
  const int nth = params->nth; // number of threads

- if (lm_ggml_is_contiguous(src0) && lm_ggml_is_contiguous(dst) && src0->type == dst->type) {
- lm_ggml_compute_forward_dup_same_cont(params, dst);
- return;
- }
-
  // parallelize by rows
  const int nr = ne01;
  // number of rows per thread
@@ -8937,13 +9303,13 @@ static void lm_ggml_compute_forward_dup_bytes(
  LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0));
  LM_GGML_ASSERT(src0->type == dst->type);

+ LM_GGML_TENSOR_UNARY_OP_LOCALS;
+
  if (lm_ggml_is_contiguous(src0) && lm_ggml_is_contiguous(dst)) {
  lm_ggml_compute_forward_dup_same_cont(params, dst);
  return;
  }

- LM_GGML_TENSOR_UNARY_OP_LOCALS;
-
  const size_t type_size = lm_ggml_type_size(src0->type);
  const int ith = params->ith; // thread index
  const int nth = params->nth; // number of threads
@@ -9564,6 +9930,8 @@ static void lm_ggml_compute_forward_add(
  case LM_GGML_TYPE_Q4_K:
  case LM_GGML_TYPE_Q5_K:
  case LM_GGML_TYPE_Q6_K:
+ case LM_GGML_TYPE_TQ1_0:
+ case LM_GGML_TYPE_TQ2_0:
  case LM_GGML_TYPE_IQ2_XXS:
  case LM_GGML_TYPE_IQ2_XS:
  case LM_GGML_TYPE_IQ3_XXS:
@@ -9942,6 +10310,8 @@ static void lm_ggml_compute_forward_add1(
  case LM_GGML_TYPE_Q4_K:
  case LM_GGML_TYPE_Q5_K:
  case LM_GGML_TYPE_Q6_K:
+ case LM_GGML_TYPE_TQ1_0:
+ case LM_GGML_TYPE_TQ2_0:
  case LM_GGML_TYPE_IQ2_XXS:
  case LM_GGML_TYPE_IQ2_XS:
  case LM_GGML_TYPE_IQ3_XXS:
@@ -9993,7 +10363,7 @@ static void lm_ggml_compute_forward_acc_f32(
  ((char *) src0->data),
  lm_ggml_nbytes(dst));
  }
- lm_ggml_barrier(params->shared);
+ lm_ggml_barrier(params->threadpool);
  }

  const int ith = params->ith;
@@ -10070,6 +10440,8 @@ static void lm_ggml_compute_forward_acc(
  case LM_GGML_TYPE_Q4_K:
  case LM_GGML_TYPE_Q5_K:
  case LM_GGML_TYPE_Q6_K:
+ case LM_GGML_TYPE_TQ1_0:
+ case LM_GGML_TYPE_TQ2_0:
  case LM_GGML_TYPE_IQ2_XXS:
  case LM_GGML_TYPE_IQ2_XS:
  case LM_GGML_TYPE_IQ3_XXS:
@@ -10098,11 +10470,10 @@ static void lm_ggml_compute_forward_sub_f32(
  const struct lm_ggml_tensor * src0 = dst->src[0];
  const struct lm_ggml_tensor * src1 = dst->src[1];

- if (params->ith != 0) {
- return;
- }
+ assert(lm_ggml_can_repeat(src1, src0) && lm_ggml_are_same_shape(src0, dst));

- assert(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst));
+ const int ith = params->ith;
+ const int nth = params->nth;

  const int nr = lm_ggml_nrows(src0);

@@ -10111,40 +10482,55 @@ static void lm_ggml_compute_forward_sub_f32(
  LM_GGML_ASSERT( nb0 == sizeof(float));
  LM_GGML_ASSERT(nb00 == sizeof(float));

+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
  if (nb10 == sizeof(float)) {
- for (int ir = 0; ir < nr; ++ir) {
- // src0, src1 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+ const int64_t nr0 = ne00 / ne10;
+
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);

+ for (int64_t r = 0; r < nr0; ++r) {
  #ifdef LM_GGML_USE_ACCELERATE
- vDSP_vsub(
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
- ne0);
+ vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
  #else
- lm_ggml_vec_sub_f32(ne0,
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+ lm_ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
  #endif
- // }
- // }
+ }
  }
  } else {
  // src1 is not contiguous
- for (int ir = 0; ir < nr; ++ir) {
- // src0, src1 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);

- float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
- float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
- for (int i0 = 0; i0 < ne0; i0++) {
- float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
+ const int64_t i10 = i0 % ne10;
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);

  dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
  }
@@ -10490,9 +10876,9 @@ static void lm_ggml_compute_forward_log(
  }
  }

- // lm_ggml_compute_forward_sum
+ // lm_ggml_compute_forward_sin

- static void lm_ggml_compute_forward_sum_f32(
+ static void lm_ggml_compute_forward_sin_f32(
  const struct lm_ggml_compute_params * params,
  struct lm_ggml_tensor * dst) {

@@ -10502,8 +10888,95 @@ static void lm_ggml_compute_forward_sum_f32(
  return;
  }

- assert(lm_ggml_is_scalar(dst));
+ LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst));
+
+ const int n = lm_ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ LM_GGML_ASSERT( dst->nb[0] == sizeof(float));
+ LM_GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ lm_ggml_vec_sin_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+ }
+
+ static void lm_ggml_compute_forward_sin(
+ const struct lm_ggml_compute_params * params,
+ struct lm_ggml_tensor * dst) {
+
+ const struct lm_ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case LM_GGML_TYPE_F32:
+ {
+ lm_ggml_compute_forward_sin_f32(params, dst);
+ } break;
+ default:
+ {
+ LM_GGML_ABORT("fatal error");
+ }
+ }
+ }
+
+ // lm_ggml_compute_forward_cos
+
+ static void lm_ggml_compute_forward_cos_f32(
+ const struct lm_ggml_compute_params * params,
+ struct lm_ggml_tensor * dst) {
+
+ const struct lm_ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst));
+
+ const int n = lm_ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ LM_GGML_ASSERT( dst->nb[0] == sizeof(float));
+ LM_GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ lm_ggml_vec_cos_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+ }
+
+ static void lm_ggml_compute_forward_cos(
+ const struct lm_ggml_compute_params * params,
+ struct lm_ggml_tensor * dst) {
+
+ const struct lm_ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case LM_GGML_TYPE_F32:
+ {
+ lm_ggml_compute_forward_cos_f32(params, dst);
+ } break;
+ default:
+ {
+ LM_GGML_ABORT("fatal error");
+ }
+ }
+ }
+
+ // lm_ggml_compute_forward_sum
+
+ static void lm_ggml_compute_forward_sum_f32(
+ const struct lm_ggml_compute_params * params,
+ struct lm_ggml_tensor * dst) {
+
+ const struct lm_ggml_tensor * src0 = dst->src[0];

+ if (params->ith != 0) {
+ return;
+ }

  assert(lm_ggml_is_scalar(dst));
  assert(src0->nb[0] == sizeof(float));
@@ -11762,6 +12235,48 @@ static void lm_ggml_compute_forward_hardsigmoid(
  }
  }

+ static void lm_ggml_compute_forward_exp_f32(
+ const struct lm_ggml_compute_params * params,
+ struct lm_ggml_tensor * dst) {
+
+ const struct lm_ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(lm_ggml_is_contiguous_1(src0));
+ assert(lm_ggml_is_contiguous_1(dst));
+ assert(lm_ggml_are_same_shape(src0, dst));
+
+ const int n = lm_ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ for (int i = 0; i < n; i++) {
+ lm_ggml_vec_exp_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+ }
+
+ static void lm_ggml_compute_forward_exp(
+ const struct lm_ggml_compute_params * params,
+ struct lm_ggml_tensor * dst) {
+
+ const struct lm_ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case LM_GGML_TYPE_F32:
+ {
+ lm_ggml_compute_forward_exp_f32(params, dst);
+ } break;
+ default:
+ {
+ LM_GGML_ABORT("fatal error");
+ }
+ }
+ }
+

  // lm_ggml_compute_forward_norm

@@ -12363,10 +12878,10 @@ UseGgmlGemm1:;

  if (ith == 0) {
  // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
- atomic_store(&params->shared->current_chunk, nth);
+ atomic_store_explicit(&params->threadpool->current_chunk, nth, memory_order_relaxed);
  }

- lm_ggml_barrier(params->shared);
+ lm_ggml_barrier(params->threadpool);

  #if LM_GGML_USE_LLAMAFILE
  if (src1->type != vec_dot_type) {
@@ -12474,7 +12989,7 @@ UseGgmlGemm2:;
  break;
  }

- current_chunk = atomic_fetch_add(&params->shared->current_chunk, 1);
+ current_chunk = atomic_fetch_add_explicit(&params->threadpool->current_chunk, 1, memory_order_relaxed);
  }
  }

@@ -12569,7 +13084,7 @@ static void lm_ggml_compute_forward_mul_mat_id(
  }
  }

- lm_ggml_barrier(params->shared);
+ lm_ggml_barrier(params->threadpool);

  // compute each matrix multiplication in sequence
  for (int cur_a = 0; cur_a < n_as; ++cur_a) {
@@ -12723,7 +13238,7 @@ static void lm_ggml_compute_forward_out_prod_f32(
  if (ith == 0) {
  lm_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
  }
- lm_ggml_barrier(params->shared);
+ lm_ggml_barrier(params->threadpool);

  // dst[:,:,:,:] = 0
  // for i2,i3:
@@ -12841,7 +13356,7 @@ static void lm_ggml_compute_forward_out_prod_q_f32(
  if (ith == 0) {
  lm_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
  }
- lm_ggml_barrier(params->shared);
+ lm_ggml_barrier(params->threadpool);

  // parallelize by last three dimensions

@@ -12907,6 +13422,8 @@ static void lm_ggml_compute_forward_out_prod(
  case LM_GGML_TYPE_Q4_K:
  case LM_GGML_TYPE_Q5_K:
  case LM_GGML_TYPE_Q6_K:
+ case LM_GGML_TYPE_TQ1_0:
+ case LM_GGML_TYPE_TQ2_0:
  case LM_GGML_TYPE_IQ2_XXS:
  case LM_GGML_TYPE_IQ2_XS:
  case LM_GGML_TYPE_IQ3_XXS:
@@ -13027,7 +13544,7 @@ static void lm_ggml_compute_forward_set_f32(
  ((char *) src0->data),
  lm_ggml_nbytes(dst));
  }
- lm_ggml_barrier(params->shared);
+ lm_ggml_barrier(params->threadpool);
  }

  const int ith = params->ith;
@@ -13095,6 +13612,8 @@ static void lm_ggml_compute_forward_set(
  case LM_GGML_TYPE_Q4_K:
  case LM_GGML_TYPE_Q5_K:
  case LM_GGML_TYPE_Q6_K:
+ case LM_GGML_TYPE_TQ1_0:
+ case LM_GGML_TYPE_TQ2_0:
  case LM_GGML_TYPE_IQ2_XXS:
  case LM_GGML_TYPE_IQ2_XS:
  case LM_GGML_TYPE_IQ3_XXS:
@@ -13208,7 +13727,7 @@ static void lm_ggml_compute_forward_get_rows_q(
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

- assert(i01 >= 0 && i01 < ne01);
+ LM_GGML_ASSERT(i01 >= 0 && i01 < ne01);

  dequantize_row_q(
  (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
@@ -13249,7 +13768,7 @@ static void lm_ggml_compute_forward_get_rows_f16(
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

- assert(i01 >= 0 && i01 < ne01);
+ LM_GGML_ASSERT(i01 >= 0 && i01 < ne01);

  lm_ggml_fp16_to_fp32_row(
  (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
@@ -13290,7 +13809,7 @@ static void lm_ggml_compute_forward_get_rows_bf16(
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

- assert(i01 >= 0 && i01 < ne01);
+ LM_GGML_ASSERT(i01 >= 0 && i01 < ne01);

  lm_ggml_bf16_to_fp32_row(
  (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
@@ -13331,7 +13850,7 @@ static void lm_ggml_compute_forward_get_rows_f32(
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

- assert(i01 >= 0 && i01 < ne01);
+ LM_GGML_ASSERT(i01 >= 0 && i01 < ne01);

  lm_ggml_vec_cpy_f32(nc,
  (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
@@ -13357,6 +13876,8 @@ static void lm_ggml_compute_forward_get_rows(
  case LM_GGML_TYPE_Q4_K:
  case LM_GGML_TYPE_Q5_K:
  case LM_GGML_TYPE_Q6_K:
+ case LM_GGML_TYPE_TQ1_0:
+ case LM_GGML_TYPE_TQ2_0:
  case LM_GGML_TYPE_IQ2_XXS:
  case LM_GGML_TYPE_IQ2_XS:
  case LM_GGML_TYPE_IQ3_XXS:
@@ -13606,7 +14127,7 @@ static void lm_ggml_compute_forward_diag_mask_f32(
  ((char *) src0->data),
  lm_ggml_nbytes(dst));
  }
- lm_ggml_barrier(params->shared);
+ lm_ggml_barrier(params->threadpool);
  }

  // TODO: handle transposed/permuted matrices
@@ -13946,6 +14467,8 @@ static void lm_ggml_compute_forward_clamp(
  case LM_GGML_TYPE_Q4_K:
  case LM_GGML_TYPE_Q5_K:
  case LM_GGML_TYPE_Q6_K:
+ case LM_GGML_TYPE_TQ1_0:
+ case LM_GGML_TYPE_TQ2_0:
  case LM_GGML_TYPE_IQ2_XXS:
  case LM_GGML_TYPE_IQ2_XS:
  case LM_GGML_TYPE_IQ3_XXS:
@@ -14382,7 +14905,7 @@ static void lm_ggml_compute_forward_conv_transpose_1d_f16_f32(
  // need to zero dst since we are accumulating into it
  memset(dst->data, 0, lm_ggml_nbytes(dst));
  }
- lm_ggml_barrier(params->shared);
+ lm_ggml_barrier(params->threadpool);

  const int32_t s0 = ((const int32_t*)(dst->op_params))[0];

@@ -14470,7 +14993,7 @@ static void lm_ggml_compute_forward_conv_transpose_1d_f32(
  // need to zero dst since we are accumulating into it
  memset(dst->data, 0, lm_ggml_nbytes(dst));
  }
- lm_ggml_barrier(params->shared);
+ lm_ggml_barrier(params->threadpool);

  const int32_t s0 = ((const int32_t*)(dst->op_params))[0];

@@ -14525,6 +15048,7 @@ static void lm_ggml_compute_forward_conv_transpose_1d(
  }
  }

+ // lm_ggml_compute_forward_im2col_f32
  // src0: kernel [OC, IC, KH, KW]
  // src1: image [N, IC, IH, IW]
  // dst: result [N, OH, OW, IC*KH*KW]
@@ -14535,7 +15059,6 @@ static void lm_ggml_compute_forward_im2col_f32(
  const struct lm_ggml_tensor * src0 = dst->src[0];
  const struct lm_ggml_tensor * src1 = dst->src[1];

- LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16);
  LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
  LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32);

@@ -14566,7 +15089,6 @@ static void lm_ggml_compute_forward_im2col_f32(
  int ofs0 = is_2D ? nb13 : nb12;
  int ofs1 = is_2D ? nb12 : nb11;

- LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t));
  LM_GGML_ASSERT(nb10 == sizeof(float));

  // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
@@ -14602,6 +15124,7 @@ static void lm_ggml_compute_forward_im2col_f32(
  }


+ // lm_ggml_compute_forward_im2col_f16
  // src0: kernel [OC, IC, KH, KW]
  // src1: image [N, IC, IH, IW]
  // dst: result [N, OH, OW, IC*KH*KW]
@@ -14697,6 +15220,99 @@ static void lm_ggml_compute_forward_im2col(
  }
  }

+ // lm_ggml_compute_forward_im2col_back_f32
+
+ static void lm_ggml_compute_forward_im2col_back_f32(
+ const struct lm_ggml_compute_params * params,
+ struct lm_ggml_tensor * dst) {
+
+ const struct lm_ggml_tensor * src0 = dst->src[0];
+ const struct lm_ggml_tensor * src1 = dst->src[1];
+
+ LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
+ LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32);
+
+ LM_GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t N = is_2D ? ne3 : ne2;
+ const int64_t IC = is_2D ? ne2 : ne1;
+ const int64_t IH = is_2D ? ne1 : 1;
+ const int64_t IW = ne0;
+
+ const int64_t KH = is_2D ? ne01 : 1;
+ const int64_t KW = ne00;
+
+ const int64_t OH = is_2D ? ne12 : 1;
+ const int64_t OW = ne11;
+
+ int ofs0 = is_2D ? nb3 : nb2;
+ int ofs1 = is_2D ? nb2 : nb1;
+
+ LM_GGML_ASSERT(nb0 == sizeof(float));
+
+ // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+ {
+ float * const wdata = (float *) dst->data;
+
+ for (int64_t in = 0; in < N; in++) {
+ for (int64_t iic = ith; iic < IC; iic += nth) {
+ for (int64_t iih = 0; iih < IH; iih++) {
+ for (int64_t iiw = 0; iiw < IW; iiw++) {
+
+ // micro kernel
+ float grad = 0.0f;
+ for (int64_t ikh = 0; ikh < KH; ikh++) {
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
+ // For s0 > 1 some values were skipped over in the forward pass.
+ // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
+ const int64_t tmpw = (iiw + p0 - ikw*d0);
+ if (tmpw % s0 != 0) {
15281
+ continue;
15282
+ }
15283
+ const int64_t iow = tmpw / s0;
15284
+
15285
+ // Equivalent logic as above except for s1.
15286
+ int64_t ioh;
15287
+ if (is_2D) {
15288
+ const int64_t tmph = iih + p1 - ikh*d1;
15289
+
15290
+ if (tmph % s1 != 0) {
15291
+ continue;
15292
+ }
15293
+
15294
+ ioh = tmph / s1;
15295
+ } else {
15296
+ ioh = 0;
15297
+ }
15298
+
15299
+ if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
15300
+ continue;
15301
+ }
15302
+
15303
+ const float * const src_data = (const float *) src1->data
15304
+ + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
15305
+ grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
15306
+ }
15307
+ }
15308
+ float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
15309
+ dst_data[iih*IW + iiw] = grad;
15310
+ }
15311
+ }
15312
+ }
15313
+ }
15314
+ }
15315
+ }
14700
15316
 
14701
15317
  // lm_ggml_compute_forward_conv_transpose_2d
14702
15318
 
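For reference, the tmpw % s0 and tmph % s1 checks in the new lm_ggml_compute_forward_im2col_back_f32 above come from inverting the forward im2col mapping. Assuming the forward kernel reads input column iiw = iow*s0 + ikw*d0 - p0 (the usual im2col indexing; the forward loop body is not part of this hunk), the backward pass recovers, for each kernel tap,

\[ i_{ow} = \frac{i_{iw} + p_0 - i_{kw}\,d_0}{s_0}, \qquad \text{valid only when } (i_{iw} + p_0 - i_{kw}\,d_0) \bmod s_0 = 0, \]

so taps where the division is not exact were never read in the forward pass and contribute no gradient, and out-of-range iow/ioh are skipped for the same reason.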
@@ -14757,7 +15373,7 @@ static void lm_ggml_compute_forward_conv_transpose_2d(
14757
15373
 
14758
15374
  memset(dst->data, 0, lm_ggml_nbytes(dst));
14759
15375
  }
14760
- lm_ggml_barrier(params->shared);
15376
+ lm_ggml_barrier(params->threadpool);
14761
15377
 
14762
15378
  const int32_t stride = lm_ggml_get_op_params_i32(dst, 0);
14763
15379
 
@@ -14939,20 +15555,142 @@ static void lm_ggml_compute_forward_pool_2d(
14939
15555
  }
14940
15556
  }
14941
15557
 
14942
- // lm_ggml_compute_forward_upscale
14943
-
14944
- static void lm_ggml_compute_forward_upscale_f32(
14945
- const struct lm_ggml_compute_params * params,
14946
- struct lm_ggml_tensor * dst) {
15558
+ // lm_ggml_compute_forward_pool_2d_back
14947
15559
 
14948
- const struct lm_ggml_tensor * src0 = dst->src[0];
15560
+ static void lm_ggml_compute_forward_pool_2d_back(
15561
+ const struct lm_ggml_compute_params * params,
15562
+ struct lm_ggml_tensor * dst) {
14949
15563
 
14950
- LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
15564
+ const struct lm_ggml_tensor * src = dst->src[0];
15565
+ const struct lm_ggml_tensor * dstf = dst->src[1]; // forward tensor of dst
14951
15566
 
14952
- const int ith = params->ith;
14953
- const int nth = params->nth;
15567
+ assert(dst->type == LM_GGML_TYPE_F32 || dst->type == LM_GGML_TYPE_F16);
14954
15568
 
14955
- LM_GGML_TENSOR_UNARY_OP_LOCALS
15569
+ if (params->ith != 0) {
15570
+ return;
15571
+ }
15572
+
15573
+ const int32_t * opts = (const int32_t *)dst->op_params;
15574
+ enum lm_ggml_op_pool op = opts[0];
15575
+ const int k0 = opts[1];
15576
+ const int k1 = opts[2];
15577
+ const int s0 = opts[3];
15578
+ const int s1 = opts[4];
15579
+ const int p0 = opts[5];
15580
+ const int p1 = opts[6];
15581
+
15582
+ char * cdata = (char *) dst->data;
15583
+ const char * cdataf = (const char *) dstf->data;
15584
+ const char * const data_end = cdata + lm_ggml_nbytes(dst);
15585
+
15586
+ LM_GGML_ASSERT(params->ith == 0);
15587
+ memset(cdata, 0, lm_ggml_nbytes(dst));
15588
+
15589
+ const int64_t px = src->ne[0];
15590
+ const int64_t py = src->ne[1];
15591
+ const int64_t pa = px * py;
15592
+
15593
+ const float * splane = (const float *) src->data;
15594
+
15595
+ const int ka = k0 * k1;
15596
+ const int offset0 = -p0;
15597
+ const int offset1 = -p1;
15598
+
15599
+ while (cdata < data_end) {
15600
+ for (int oy = 0; oy < py; ++oy) {
15601
+ const float * const srow = splane + oy * px;
15602
+ for (int ox = 0; ox < px; ++ox) {
15603
+ const float grad0 = srow[ox];
15604
+
15605
+ const int ix = offset0 + ox * s0;
15606
+ const int iy = offset1 + oy * s1;
15607
+
15608
+ if (op == LM_GGML_OP_POOL_MAX) {
15609
+ float maxval = -FLT_MAX;
15610
+ int kxmax = -1;
15611
+ int kymax = -1;
15612
+
15613
+ for (int ky = 0; ky < k1; ++ky) {
15614
+ if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
15615
+ continue;
15616
+ }
15617
+ const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));
15618
+ for (int kx = 0; kx < k0; ++kx) {
15619
+ int j = ix + kx;
15620
+ if (j < 0 || j >= dst->ne[0]) {
15621
+ continue;
15622
+ }
15623
+
15624
+ const float val = dst->type == LM_GGML_TYPE_F32 ?
15625
+ ((const float *) drowf)[j] : LM_GGML_FP16_TO_FP32(((const lm_ggml_fp16_t *) drowf)[j]);
15626
+ if (val <= maxval) {
15627
+ continue;
15628
+ }
15629
+
15630
+ maxval = val;
15631
+ kxmax = kx;
15632
+ kymax = ky;
15633
+ }
15634
+ }
15635
+
15636
+ if (kxmax == -1 || kymax == -1) {
15637
+ continue;
15638
+ }
15639
+
15640
+ void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));
15641
+ const int j = ix + kxmax;
15642
+ if (dst->type == LM_GGML_TYPE_F32) {
15643
+ ((float *) drow)[j] += grad0;
15644
+ } else {
15645
+ ((lm_ggml_fp16_t *) drow)[j] = LM_GGML_FP32_TO_FP16(grad0 + LM_GGML_FP16_TO_FP32(((const lm_ggml_fp16_t *) drow)[j]));
15646
+ }
15647
+ } else if (op == LM_GGML_OP_POOL_AVG) {
15648
+ const float grad = grad0 / ka;
15649
+
15650
+ for (int ky = 0; ky < k1; ++ky) {
15651
+ if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
15652
+ continue;
15653
+ }
15654
+ void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));
15655
+ for (int kx = 0; kx < k0; ++kx) {
15656
+ int j = ix + kx;
15657
+ if (j < 0 || j >= dst->ne[0]) {
15658
+ continue;
15659
+ }
15660
+
15661
+ if (dst->type == LM_GGML_TYPE_F32) {
15662
+ ((float *) drow)[j] += grad;
15663
+ } else {
15664
+ ((lm_ggml_fp16_t *) drow)[j] += LM_GGML_FP32_TO_FP16(grad);
15665
+ }
15666
+ }
15667
+ }
15668
+ } else {
15669
+ LM_GGML_ASSERT(false);
15670
+ }
15671
+ }
15672
+ }
15673
+
15674
+ cdata += dst->nb[2];
15675
+ cdataf += dst->nb[2];
15676
+ splane += pa;
15677
+ }
15678
+ }
15679
+
15680
+ // lm_ggml_compute_forward_upscale
15681
+
15682
+ static void lm_ggml_compute_forward_upscale_f32(
15683
+ const struct lm_ggml_compute_params * params,
15684
+ struct lm_ggml_tensor * dst) {
15685
+
15686
+ const struct lm_ggml_tensor * src0 = dst->src[0];
15687
+
15688
+ LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
15689
+
15690
+ const int ith = params->ith;
15691
+ const int nth = params->nth;
15692
+
15693
+ LM_GGML_TENSOR_UNARY_OP_LOCALS
14956
15694
 
14957
15695
  const float sf0 = (float)ne0/src0->ne[0];
14958
15696
  const float sf1 = (float)ne1/src0->ne[1];
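In the new lm_ggml_compute_forward_pool_2d_back above, src carries the incoming gradient per pooled output cell, dstf (the forward-pass counterpart of dst) is used only to re-locate the arg-max inside each window, and dst accumulates the gradient of the pooling input. As a sketch, with g_{o_x o_y} the incoming gradient and W(o_x,o_y) the k_0 x k_1 window of output cell (o_x,o_y):

\[ \frac{\partial L}{\partial x_{ij}} =
\begin{cases}
\sum_{(o_x,o_y)\,:\,(i,j)=\arg\max_{W(o_x,o_y)} x} \; g_{o_x o_y} & \text{max pooling} \\
\sum_{(o_x,o_y)\,:\,(i,j)\in W(o_x,o_y)} \; g_{o_x o_y} / (k_0 k_1) & \text{average pooling}
\end{cases} \]

Max pooling routes each gradient to a single input cell while average pooling spreads it uniformly over the window, which is exactly the split between the two branches of the loop.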
@@ -15503,7 +16241,7 @@ static void lm_ggml_compute_forward_flash_attn_back_f32(
15503
16241
  if (ith == 0) {
15504
16242
  memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
15505
16243
  }
15506
- lm_ggml_barrier(params->shared);
16244
+ lm_ggml_barrier(params->threadpool);
15507
16245
 
15508
16246
  const int64_t elem_q = lm_ggml_nelements(q);
15509
16247
  const int64_t elem_k = lm_ggml_nelements(k);
@@ -16125,6 +16863,10 @@ static void lm_ggml_compute_forward_unary(
16125
16863
  {
16126
16864
  lm_ggml_compute_forward_hardsigmoid(params, dst);
16127
16865
  } break;
16866
+ case LM_GGML_UNARY_OP_EXP:
16867
+ {
16868
+ lm_ggml_compute_forward_exp(params, dst);
16869
+ } break;
16128
16870
  default:
16129
16871
  {
16130
16872
  LM_GGML_ABORT("fatal error");
@@ -16194,7 +16936,7 @@ static void lm_ggml_compute_forward_add_rel_pos_f32(
16194
16936
  if (params->ith == 0) {
16195
16937
  memcpy((char *) dst->data, (char *) src0->data, lm_ggml_nbytes(dst));
16196
16938
  }
16197
- lm_ggml_barrier(params->shared);
16939
+ lm_ggml_barrier(params->threadpool);
16198
16940
  }
16199
16941
  // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
16200
16942
 
@@ -16260,6 +17002,96 @@ static void lm_ggml_compute_forward_add_rel_pos(
16260
17002
  }
16261
17003
  }
16262
17004
 
17005
+ // lm_ggml_compute_forward_rwkv_wkv
17006
+
17007
+ static void lm_ggml_compute_forward_rwkv_wkv_f32(
17008
+ const struct lm_ggml_compute_params * params,
17009
+ struct lm_ggml_tensor * dst) {
17010
+ const size_t T = dst->src[1]->ne[3];
17011
+ const size_t C = dst->ne[0];
17012
+ const size_t H = dst->src[1]->ne[2];
17013
+ const size_t n_seqs = dst->src[5]->ne[1];
17014
+
17015
+ float * dst_data = (float *) dst->data;
17016
+ float * state = ((float *) dst->data) + C * T;
17017
+
17018
+ if (params->ith != 0) {
17019
+ return;
17020
+ }
17021
+
17022
+ memset(dst_data, 0, T * C * sizeof(float));
17023
+
17024
+ float * k = (float *) dst->src[0]->data;
17025
+ float * v = (float *) dst->src[1]->data;
17026
+ float * r = (float *) dst->src[2]->data;
17027
+ float * time_faaaa = (float *) dst->src[3]->data;
17028
+ float * time_decay = (float *) dst->src[4]->data;
17029
+
17030
+ size_t t_stride = H * (C / H);
17031
+
17032
+ size_t h_stride = C / H;
17033
+ size_t h_stride_2d = (C / H) * (C / H);
17034
+
17035
+ // basically fused operations:
17036
+ // dst = r @ (time_faaaa * (k @ v) + state),
17037
+ // state = time_decay * state + (k @ v),
17038
+ // recursive through each token
17039
+ for (size_t t = 0; t < T; t++) {
17040
+ size_t t_offset = t * t_stride;
17041
+ size_t state_offset = (C / H) * C * (t / (T / n_seqs));
17042
+ float * state_cur = state + state_offset;
17043
+ float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
17044
+
17045
+ for (size_t h = 0; h < H; h++) {
17046
+ size_t h_offset = h * h_stride;
17047
+ size_t t_h_offset = t_offset + h_offset;
17048
+ size_t h_2d_offset = h * h_stride_2d;
17049
+
17050
+ for (size_t i = 0; i < C / H; i++) {
17051
+ size_t t_h_i_offset = t_h_offset + i;
17052
+ size_t h_i_offset = h_offset + i;
17053
+ size_t h_2d_i_offset = h_2d_offset + i * h_stride;
17054
+
17055
+ float k_val = k[t_h_i_offset];
17056
+ float r_val = r[t_h_i_offset];
17057
+ float time_faaaa_val = time_faaaa[h_i_offset];
17058
+ // RWKV v6: different time_decay for each token.
17059
+ float time_decay_val = time_decay[t_h_i_offset];
17060
+
17061
+ for (size_t j = 0; j < C / H; j ++) {
17062
+ size_t t_h_j_offset = t_h_offset + j;
17063
+ size_t h_2d_i_j_offset = h_2d_i_offset + j;
17064
+
17065
+ float v_val = v[t_h_j_offset];
17066
+ float kv_val = v_val * k_val;
17067
+ float prev_state_val = state_prev[h_2d_i_j_offset];
17068
+ float temp_val = kv_val * time_faaaa_val + prev_state_val;
17069
+ dst_data[t_h_j_offset] += temp_val * r_val;
17070
+ state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
17071
+ }
17072
+ }
17073
+ }
17074
+ }
17075
+ }
17076
+
17077
+ static void lm_ggml_compute_forward_rwkv_wkv(
17078
+ const struct lm_ggml_compute_params * params,
17079
+ struct lm_ggml_tensor * dst) {
17080
+
17081
+ const struct lm_ggml_tensor * src0 = dst->src[0];
17082
+
17083
+ switch (src0->type) {
17084
+ case LM_GGML_TYPE_F32:
17085
+ {
17086
+ lm_ggml_compute_forward_rwkv_wkv_f32(params, dst);
17087
+ } break;
17088
+ default:
17089
+ {
17090
+ LM_GGML_ABORT("fatal error");
17091
+ }
17092
+ }
17093
+ }
17094
+
16263
17095
  // lm_ggml_compute_forward_map_unary
16264
17096
 
16265
17097
  static void lm_ggml_compute_forward_map_unary_f32(
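The fused-operation comment in lm_ggml_compute_forward_rwkv_wkv_f32 above can be written out element-wise. Per head, with u = time_faaaa, w_t = time_decay (per token in RWKV v6) and S the (C/H) x (C/H) state matrix, the loop nest computes for every token t:

\[ y_t[j] = \sum_i r_t[i]\,\bigl(u[i]\,k_t[i]\,v_t[j] + S_{t-1}[i,j]\bigr), \qquad S_t[i,j] = w_t[i]\,S_{t-1}[i,j] + k_t[i]\,v_t[j], \]

with the initial state of each sequence taken from dst->src[5] (the t % (T / n_seqs) check) and the running states stored right after the T*C output block in dst->data.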
@@ -16479,9 +17311,7 @@ static void lm_ggml_compute_forward_cross_entropy_loss_f32(
16479
17311
  if (ith == 0) {
16480
17312
  memset(sums, 0, sizeof(float) * (nth + nth * nc));
16481
17313
  }
16482
- lm_ggml_barrier(params->shared);
16483
-
16484
- const double eps = 1e-9;
17314
+ lm_ggml_barrier(params->threadpool);
16485
17315
 
16486
17316
  // rows per thread
16487
17317
  const int dr = (nr + nth - 1)/nth;
@@ -16503,20 +17333,15 @@ static void lm_ggml_compute_forward_cross_entropy_loss_f32(
16503
17333
  }
16504
17334
  #endif
16505
17335
 
16506
- // soft_max
16507
17336
  float max = -INFINITY;
16508
17337
  lm_ggml_vec_max_f32(nc, &max, s0);
16509
- lm_ggml_float sum = lm_ggml_vec_soft_max_f32(nc, st, s0, max);
16510
- assert(sum > 0.0);
16511
- sum = (1.0 - eps) / sum;
17338
+ lm_ggml_float sum = lm_ggml_vec_log_soft_max_f32(nc, st, s0, max);
17339
+ assert(sum >= 0.0);
16512
17340
 
16513
- // avoid log(0) by rescaling from [0..1] to [eps..1]
16514
- lm_ggml_vec_scale_f32(nc, st, sum);
16515
- lm_ggml_vec_add1_f32(nc, st, st, eps);
16516
- lm_ggml_vec_log_f32(nc, st, st);
17341
+ lm_ggml_vec_add1_f32(nc, st, st, -sum);
16517
17342
  lm_ggml_vec_mul_f32(nc, st, st, s1);
16518
17343
 
16519
- float st_sum = 0;
17344
+ float st_sum = 0.0f;
16520
17345
  lm_ggml_vec_sum_f32(nc, &st_sum, st);
16521
17346
  sums[ith] += st_sum;
16522
17347
 
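The reworked loss body above moves the computation into log-space. Assuming lm_ggml_vec_log_soft_max_f32 stores st[i] = s0[i] - max and returns log \sum_j e^{s0[j]-max} (which the assert(sum >= 0.0) and the following add1 with -sum suggest), each row now accumulates

\[ \sum_i p_i \Bigl[ (s_i - \max_j s_j) - \log \sum_j e^{\,s_j - \max_j s_j} \Bigr] = \sum_i p_i \log \operatorname{softmax}(s)_i, \]

i.e. the exact log-softmax, so the old eps rescaling that guarded against log(0) is no longer needed.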
@@ -16527,7 +17352,7 @@ static void lm_ggml_compute_forward_cross_entropy_loss_f32(
16527
17352
  }
16528
17353
  #endif
16529
17354
  }
16530
- lm_ggml_barrier(params->shared);
17355
+ lm_ggml_barrier(params->threadpool);
16531
17356
 
16532
17357
  if (ith == 0) {
16533
17358
  float * dp = (float *) dst->data;
@@ -16573,8 +17398,6 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back_f32(
16573
17398
  const int64_t ith = params->ith;
16574
17399
  const int64_t nth = params->nth;
16575
17400
 
16576
- const double eps = 1e-9;
16577
-
16578
17401
  // TODO: handle transposed/permuted matrices
16579
17402
  const int64_t nc = src0->ne[0];
16580
17403
  const int64_t nr = lm_ggml_nrows(src0);
@@ -16606,11 +17429,9 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back_f32(
16606
17429
  lm_ggml_vec_max_f32(nc, &max, s0);
16607
17430
  lm_ggml_float sum = lm_ggml_vec_soft_max_f32(nc, ds0, s0, max);
16608
17431
  assert(sum > 0.0);
16609
- sum = (1.0 - eps) / sum;
17432
+ lm_ggml_vec_scale_f32(nc, ds0, 1.0/sum);
16610
17433
 
16611
17434
  // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
16612
- lm_ggml_vec_scale_f32(nc, ds0, sum);
16613
- lm_ggml_vec_add1_f32(nc, ds0, ds0, eps);
16614
17435
  lm_ggml_vec_sub_f32(nc, ds0, ds0, s1);
16615
17436
  lm_ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr);
16616
17437
 
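With eps gone from the backward pass as well, the softmax is now normalized exactly: lm_ggml_vec_soft_max_f32 fills ds0 with the shifted exponentials and returns their sum (as the subsequent scale by 1.0/sum implies), and the remaining lines implement the in-code comment literally, per element:

\[ \frac{\partial L}{\partial s_i} = \Bigl( \frac{e^{\,s_i - \max}}{\sum_j e^{\,s_j - \max}} - p_i \Bigr) \frac{d}{n_r}, \]

where p is the target distribution in src1, d = d[0] is the incoming gradient of the scalar loss, and n_r is the number of rows.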
@@ -16691,6 +17512,14 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru
16691
17512
  {
16692
17513
  lm_ggml_compute_forward_log(params, tensor);
16693
17514
  } break;
17515
+ case LM_GGML_OP_SIN:
17516
+ {
17517
+ lm_ggml_compute_forward_sin(params, tensor);
17518
+ } break;
17519
+ case LM_GGML_OP_COS:
17520
+ {
17521
+ lm_ggml_compute_forward_cos(params, tensor);
17522
+ } break;
16694
17523
  case LM_GGML_OP_SUM:
16695
17524
  {
16696
17525
  lm_ggml_compute_forward_sum(params, tensor);
@@ -16831,6 +17660,10 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru
16831
17660
  {
16832
17661
  lm_ggml_compute_forward_im2col(params, tensor);
16833
17662
  } break;
17663
+ case LM_GGML_OP_IM2COL_BACK:
17664
+ {
17665
+ lm_ggml_compute_forward_im2col_back_f32(params, tensor);
17666
+ } break;
16834
17667
  case LM_GGML_OP_CONV_TRANSPOSE_2D:
16835
17668
  {
16836
17669
  lm_ggml_compute_forward_conv_transpose_2d(params, tensor);
@@ -16843,6 +17676,10 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru
16843
17676
  {
16844
17677
  lm_ggml_compute_forward_pool_2d(params, tensor);
16845
17678
  } break;
17679
+ case LM_GGML_OP_POOL_2D_BACK:
17680
+ {
17681
+ lm_ggml_compute_forward_pool_2d_back(params, tensor);
17682
+ } break;
16846
17683
  case LM_GGML_OP_UPSCALE:
16847
17684
  {
16848
17685
  lm_ggml_compute_forward_upscale(params, tensor);
@@ -16906,6 +17743,10 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru
16906
17743
  {
16907
17744
  lm_ggml_compute_forward_add_rel_pos(params, tensor);
16908
17745
  } break;
17746
+ case LM_GGML_OP_RWKV_WKV:
17747
+ {
17748
+ lm_ggml_compute_forward_rwkv_wkv(params, tensor);
17749
+ } break;
16909
17750
  case LM_GGML_OP_MAP_UNARY:
16910
17751
  {
16911
17752
  lm_ggml_unary_op_f32_t fun;
@@ -17211,7 +18052,11 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
17211
18052
  src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
17212
18053
  }
17213
18054
  if (src1->grad) {
17214
- src1->grad = lm_ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
18055
+ if (lm_ggml_are_same_shape(src0, src1)) {
18056
+ src1->grad = lm_ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
18057
+ } else {
18058
+ src1->grad = lm_ggml_add_or_set(ctx, src1->grad, lm_ggml_repeat_back(ctx, tensor->grad, src1), zero_table);
18059
+ }
17215
18060
  }
17216
18061
  } break;
17217
18062
  case LM_GGML_OP_ADD1:
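The new branch in the ADD backward above covers broadcasting, y = a + repeat(b): each element of b feeds every position it was replicated to, so its gradient is the incoming gradient summed over those positions,

\[ \frac{\partial L}{\partial b_k} = \sum_{m \,:\, y_m \text{ uses a replica of } b_k} \frac{\partial L}{\partial y_m}, \]

which is what the lm_ggml_repeat_back(ctx, tensor->grad, src1) call provides here (reducing tensor->grad back to src1's shape); when the shapes already match, the old pass-through path is kept.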
@@ -17337,6 +18182,30 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
17337
18182
  zero_table);
17338
18183
  }
17339
18184
  } break;
18185
+ case LM_GGML_OP_SIN:
18186
+ {
18187
+ if (src0->grad) {
18188
+ src0->grad =
18189
+ lm_ggml_add_or_set(ctx,
18190
+ src0->grad,
18191
+ lm_ggml_mul(ctx,
18192
+ tensor->grad,
18193
+ lm_ggml_cos(ctx, src0)),
18194
+ zero_table);
18195
+ }
18196
+ } break;
18197
+ case LM_GGML_OP_COS:
18198
+ {
18199
+ if (src0->grad) {
18200
+ src0->grad =
18201
+ lm_ggml_sub_or_set(ctx,
18202
+ src0->grad,
18203
+ lm_ggml_mul(ctx,
18204
+ tensor->grad,
18205
+ lm_ggml_sin(ctx, src0)),
18206
+ zero_table);
18207
+ }
18208
+ } break;
17340
18209
  case LM_GGML_OP_SUM:
17341
18210
  {
17342
18211
  if (src0->grad) {
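The two new cases above are the plain chain rule; the sign of the cosine derivative is why the COS branch uses lm_ggml_sub_or_set where SIN uses lm_ggml_add_or_set:

\[ \frac{\partial}{\partial x} \sin x = \cos x, \qquad \frac{\partial}{\partial x} \cos x = -\sin x . \]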
@@ -17509,14 +18378,10 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
17509
18378
  if (src0->grad || src1->grad) {
17510
18379
  LM_GGML_ASSERT(src0->type == tensor->type);
17511
18380
  LM_GGML_ASSERT(tensor->grad->type == tensor->type);
17512
- LM_GGML_ASSERT(tensor->grad->type == src1->grad->type);
18381
+ LM_GGML_ASSERT(!src1->grad || src1->grad->type == tensor->grad->type);
17513
18382
 
17514
18383
  tensor_grad_view = lm_ggml_view_4d(ctx,
17515
- tensor->grad,
17516
- src1->grad->ne[0],
17517
- src1->grad->ne[1],
17518
- src1->grad->ne[2],
17519
- src1->grad->ne[3],
18384
+ tensor->grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
17520
18385
  nb1, nb2, nb3, offset);
17521
18386
  }
17522
18387
 
@@ -17585,9 +18450,9 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
17585
18450
 
17586
18451
  memcpy(&offset, tensor->op_params, sizeof(offset));
17587
18452
 
17588
- size_t nb1 = tensor->nb[1];
17589
- size_t nb2 = tensor->nb[2];
17590
- size_t nb3 = tensor->nb[3];
18453
+ size_t nb1 = tensor->nb[1];
18454
+ size_t nb2 = tensor->nb[2];
18455
+ size_t nb3 = tensor->nb[3];
17591
18456
 
17592
18457
  if (src0->type != src0->grad->type) {
17593
18458
  // gradient is typically F32, but src0 could be other type
@@ -17784,6 +18649,23 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
17784
18649
  LM_GGML_ABORT("fatal error"); // TODO: not implemented
17785
18650
  }
17786
18651
  case LM_GGML_OP_IM2COL:
18652
+ {
18653
+ if (src1->grad) {
18654
+ const int32_t s0 = lm_ggml_get_op_params_i32(tensor, 0);
18655
+ const int32_t s1 = lm_ggml_get_op_params_i32(tensor, 1);
18656
+ const int32_t p0 = lm_ggml_get_op_params_i32(tensor, 2);
18657
+ const int32_t p1 = lm_ggml_get_op_params_i32(tensor, 3);
18658
+ const int32_t d0 = lm_ggml_get_op_params_i32(tensor, 4);
18659
+ const int32_t d1 = lm_ggml_get_op_params_i32(tensor, 5);
18660
+ const bool is_2D = lm_ggml_get_op_params_i32(tensor, 6) == 1;
18661
+
18662
+ src1->grad = lm_ggml_add_or_set(ctx,
18663
+ src1->grad,
18664
+ lm_ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D),
18665
+ zero_table);
18666
+ }
18667
+ } break;
18668
+ case LM_GGML_OP_IM2COL_BACK:
17787
18669
  {
17788
18670
  LM_GGML_ABORT("fatal error"); // TODO: not implemented
17789
18671
  }
@@ -17796,6 +18678,23 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
17796
18678
  LM_GGML_ABORT("fatal error"); // TODO: not implemented
17797
18679
  }
17798
18680
  case LM_GGML_OP_POOL_2D:
18681
+ {
18682
+ if (src0->grad) {
18683
+ const enum lm_ggml_op_pool op = lm_ggml_get_op_params_i32(tensor, 0);
18684
+ const int32_t k0 = lm_ggml_get_op_params_i32(tensor, 1);
18685
+ const int32_t k1 = lm_ggml_get_op_params_i32(tensor, 2);
18686
+ const int32_t s0 = lm_ggml_get_op_params_i32(tensor, 3);
18687
+ const int32_t s1 = lm_ggml_get_op_params_i32(tensor, 4);
18688
+ const int32_t p0 = lm_ggml_get_op_params_i32(tensor, 5);
18689
+ const int32_t p1 = lm_ggml_get_op_params_i32(tensor, 6);
18690
+
18691
+ src0->grad = lm_ggml_add_or_set(ctx,
18692
+ src0->grad,
18693
+ lm_ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1),
18694
+ zero_table);
18695
+ }
18696
+ } break;
18697
+ case LM_GGML_OP_POOL_2D_BACK:
17799
18698
  {
17800
18699
  LM_GGML_ABORT("fatal error"); // TODO: not implemented
17801
18700
  }
@@ -17961,12 +18860,22 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
17961
18860
  zero_table);
17962
18861
  }
17963
18862
  } break;
18863
+ case LM_GGML_UNARY_OP_EXP:
18864
+ {
18865
+ if (src0->grad) {
18866
+ src0->grad = lm_ggml_add_or_set(ctx,
18867
+ src0->grad,
18868
+ lm_ggml_mul(ctx, tensor, tensor->grad),
18869
+ zero_table);
18870
+ }
18871
+ } break;
17964
18872
  default:
17965
18873
  LM_GGML_ABORT("fatal error");
17966
18874
  }
17967
18875
  } break;
17968
18876
  case LM_GGML_OP_GET_REL_POS:
17969
18877
  case LM_GGML_OP_ADD_REL_POS:
18878
+ case LM_GGML_OP_RWKV_WKV:
17970
18879
  case LM_GGML_OP_MAP_UNARY:
17971
18880
  case LM_GGML_OP_MAP_BINARY:
17972
18881
  case LM_GGML_OP_MAP_CUSTOM1_F32:
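For the new EXP case above, the forward output is its own derivative, which is why the gradient is built from tensor (the already computed forward result) rather than recomputing the exponential; RWKV_WKV, by contrast, is only added to the list of ops with no backward pass yet:

\[ y = e^{x} \;\Longrightarrow\; \frac{\partial L}{\partial x} = y \cdot \frac{\partial L}{\partial y} . \]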
@@ -18085,6 +18994,7 @@ void lm_ggml_build_forward_expand(struct lm_ggml_cgraph * cgraph, struct lm_ggml
18085
18994
 
18086
18995
  void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * gf, struct lm_ggml_cgraph * gb, bool keep) {
18087
18996
  LM_GGML_ASSERT(gf->n_nodes > 0);
18997
+ LM_GGML_ASSERT(gf->grads);
18088
18998
 
18089
18999
  // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph
18090
19000
  if (keep) {
@@ -18238,7 +19148,8 @@ void lm_ggml_graph_cpy(struct lm_ggml_cgraph * src, struct lm_ggml_cgraph * dst)
18238
19148
  }
18239
19149
 
18240
19150
  for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
18241
- if (src->visited_hash_set.keys[i]) {
19151
+ // copy all hashset keys (tensors) that are in use
19152
+ if (lm_ggml_bitset_get(src->visited_hash_set.used, i)) {
18242
19153
  lm_ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
18243
19154
  }
18244
19155
  }
@@ -18268,64 +19179,33 @@ void lm_ggml_graph_clear(struct lm_ggml_cgraph * cgraph) {
18268
19179
  lm_ggml_hash_set_reset(&cgraph->visited_hash_set);
18269
19180
  }
18270
19181
 
18271
- //
18272
- // thread data
18273
- //
18274
- // synchronization is done via busy loops
18275
- // I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops
18276
- //
18277
-
18278
- #ifdef __APPLE__
18279
-
18280
- //#include <os/lock.h>
18281
- //
18282
- //typedef os_unfair_lock lm_ggml_lock_t;
18283
- //
18284
- //#define lm_ggml_lock_init(x) UNUSED(x)
18285
- //#define lm_ggml_lock_destroy(x) UNUSED(x)
18286
- //#define lm_ggml_lock_lock os_unfair_lock_lock
18287
- //#define lm_ggml_lock_unlock os_unfair_lock_unlock
18288
- //
18289
- //#define LM_GGML_LOCK_INITIALIZER OS_UNFAIR_LOCK_INIT
18290
-
18291
- typedef int lm_ggml_lock_t;
18292
-
18293
- #define lm_ggml_lock_init(x) UNUSED(x)
18294
- #define lm_ggml_lock_destroy(x) UNUSED(x)
18295
- #define lm_ggml_lock_lock(x) UNUSED(x)
18296
- #define lm_ggml_lock_unlock(x) UNUSED(x)
18297
-
18298
- #define LM_GGML_LOCK_INITIALIZER 0
18299
-
18300
- #define lm_ggml_thread_create pthread_create
18301
- #define lm_ggml_thread_join pthread_join
18302
-
18303
- #else
18304
-
18305
- //typedef pthread_spinlock_t lm_ggml_lock_t;
18306
-
18307
- //#define lm_ggml_lock_init(x) pthread_spin_init(x, PTHREAD_PROCESS_PRIVATE)
18308
- //#define lm_ggml_lock_destroy pthread_spin_destroy
18309
- //#define lm_ggml_lock_lock pthread_spin_lock
18310
- //#define lm_ggml_lock_unlock pthread_spin_unlock
19182
+ int lm_ggml_graph_size(struct lm_ggml_cgraph * cgraph) {
19183
+ return cgraph->size;
19184
+ }
18311
19185
 
18312
- typedef int lm_ggml_lock_t;
19186
+ struct lm_ggml_tensor * lm_ggml_graph_node(struct lm_ggml_cgraph * cgraph, int i) {
19187
+ if (i < 0) {
19188
+ LM_GGML_ASSERT(cgraph->n_nodes + i >= 0);
19189
+ return cgraph->nodes[cgraph->n_nodes + i];
19190
+ }
18313
19191
 
18314
- #define lm_ggml_lock_init(x) UNUSED(x)
18315
- #define lm_ggml_lock_destroy(x) UNUSED(x)
18316
- #if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64))
18317
- #define lm_ggml_lock_lock(x) _mm_pause()
18318
- #else
18319
- #define lm_ggml_lock_lock(x) UNUSED(x)
18320
- #endif
18321
- #define lm_ggml_lock_unlock(x) UNUSED(x)
19192
+ LM_GGML_ASSERT(i < cgraph->n_nodes);
19193
+ return cgraph->nodes[i];
19194
+ }
18322
19195
 
18323
- #define LM_GGML_LOCK_INITIALIZER 0
19196
+ struct lm_ggml_tensor ** lm_ggml_graph_nodes(struct lm_ggml_cgraph * cgraph) {
19197
+ return cgraph->nodes;
19198
+ }
18324
19199
 
18325
- #define lm_ggml_thread_create pthread_create
18326
- #define lm_ggml_thread_join pthread_join
19200
+ int lm_ggml_graph_n_nodes(struct lm_ggml_cgraph * cgraph) {
19201
+ return cgraph->n_nodes;
19202
+ }
18327
19203
 
18328
- #endif
19204
+ void lm_ggml_graph_add_node(struct lm_ggml_cgraph * cgraph, struct lm_ggml_tensor * tensor) {
19205
+ LM_GGML_ASSERT(cgraph->size > cgraph->n_nodes);
19206
+ cgraph->nodes[cgraph->n_nodes] = tensor;
19207
+ cgraph->n_nodes++;
19208
+ }
18329
19209
 
18330
19210
  // Android's libc implementation "bionic" does not support setting affinity
18331
19211
  #if defined(__gnu_linux__)
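The deleted lock macros above were dead code once synchronization moved into the threadpool further down, and the new accessors give callers a supported way to walk a graph without reaching into struct lm_ggml_cgraph. A minimal sketch of how they might be used (the helper name and printf formatting are illustrative, and the package's ggml header is assumed to be included):

#include <stdio.h>

// dump a built graph through the new accessor functions;
// gf is assumed to come from e.g. lm_ggml_build_forward_expand()
static void dump_graph(struct lm_ggml_cgraph * gf) {
    // negative indices count from the end, so -1 is the final (output) node
    struct lm_ggml_tensor * last = lm_ggml_graph_node(gf, -1);

    printf("graph: %d of %d node slots used, last node: %s\n",
           lm_ggml_graph_n_nodes(gf), lm_ggml_graph_size(gf), last->name);

    for (int i = 0; i < lm_ggml_graph_n_nodes(gf); ++i) {
        printf("  node %3d: %s\n", i, lm_ggml_graph_node(gf, i)->name);
    }
}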
@@ -18424,6 +19304,8 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
18424
19304
  case LM_GGML_OP_SQR:
18425
19305
  case LM_GGML_OP_SQRT:
18426
19306
  case LM_GGML_OP_LOG:
19307
+ case LM_GGML_OP_SIN:
19308
+ case LM_GGML_OP_COS:
18427
19309
  case LM_GGML_OP_SUM:
18428
19310
  case LM_GGML_OP_SUM_ROWS:
18429
19311
  case LM_GGML_OP_MEAN:
@@ -18446,6 +19328,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
18446
19328
  case LM_GGML_UNARY_OP_SIGMOID:
18447
19329
  case LM_GGML_UNARY_OP_HARDSWISH:
18448
19330
  case LM_GGML_UNARY_OP_HARDSIGMOID:
19331
+ case LM_GGML_UNARY_OP_EXP:
18449
19332
  {
18450
19333
  n_tasks = 1;
18451
19334
  } break;
@@ -18510,6 +19393,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
18510
19393
  n_tasks = MIN(n_threads, lm_ggml_nrows(node->src[0]));
18511
19394
  } break;
18512
19395
  case LM_GGML_OP_IM2COL:
19396
+ case LM_GGML_OP_IM2COL_BACK:
18513
19397
  case LM_GGML_OP_CONV_TRANSPOSE_1D:
18514
19398
  case LM_GGML_OP_CONV_TRANSPOSE_2D:
18515
19399
  {
@@ -18517,6 +19401,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
18517
19401
  } break;
18518
19402
  case LM_GGML_OP_POOL_1D:
18519
19403
  case LM_GGML_OP_POOL_2D:
19404
+ case LM_GGML_OP_POOL_2D_BACK:
18520
19405
  {
18521
19406
  n_tasks = 1;
18522
19407
  } break;
@@ -18535,6 +19420,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
18535
19420
  case LM_GGML_OP_WIN_PART:
18536
19421
  case LM_GGML_OP_WIN_UNPART:
18537
19422
  case LM_GGML_OP_GET_REL_POS:
19423
+ case LM_GGML_OP_RWKV_WKV:
18538
19424
  case LM_GGML_OP_MAP_UNARY:
18539
19425
  case LM_GGML_OP_MAP_BINARY:
18540
19426
  case LM_GGML_OP_MAP_CUSTOM1_F32:
@@ -18603,9 +19489,281 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
18603
19489
  return n_tasks;
18604
19490
  }
18605
19491
 
18606
- struct lm_ggml_cplan lm_ggml_graph_plan(const struct lm_ggml_cgraph * cgraph, int n_threads) {
19492
+ static thread_ret_t lm_ggml_graph_compute_secondary_thread(void* data);
19493
+
19494
+ #if defined(_WIN32)
19495
+ #include "windows.h"
19496
+
19497
+ // TODO: support > 64 CPUs
19498
+ bool lm_ggml_thread_apply_affinity(bool * mask) {
19499
+ HANDLE h = GetCurrentThread();
19500
+ uint64_t bitmask = 0ULL;
19501
+
19502
+ assert(LM_GGML_MAX_N_THREADS >= 64);
19503
+
19504
+ for (int32_t i = 0; i < 8; i++) {
19505
+ int32_t idx = i * 8;
19506
+ uint8_t val = 0;
19507
+ val |= mask[idx + 0] << 0;
19508
+ val |= mask[idx + 1] << 1;
19509
+ val |= mask[idx + 2] << 2;
19510
+ val |= mask[idx + 3] << 3;
19511
+ val |= mask[idx + 4] << 4;
19512
+ val |= mask[idx + 5] << 5;
19513
+ val |= mask[idx + 6] << 6;
19514
+ val |= mask[idx + 7] << 7;
19515
+ bitmask |= (uint64_t)val << idx;
19516
+ }
19517
+
19518
+ for (int32_t i = 64; i < LM_GGML_MAX_N_THREADS; i++) {
19519
+ if (mask[i]) {
19520
+ fprintf(stderr, "warn: setting thread-affinity for > 64 CPUs isn't supported on windows!\n");
19521
+ break;
19522
+ }
19523
+ }
19524
+
19525
+ DWORD_PTR m = (DWORD_PTR)bitmask;
19526
+
19527
+ m = SetThreadAffinityMask(h, m);
19528
+
19529
+ return m != 0;
19530
+ }
19531
+
19532
+ static bool lm_ggml_thread_apply_priority(int32_t prio) {
19533
+ // Note that on Windows the Process Priority Class must be updated in order to set Thread priority.
19534
+ // This is up to the applications.
19535
+ DWORD p = THREAD_PRIORITY_NORMAL;
19536
+ switch (prio) {
19537
+ case LM_GGML_SCHED_PRIO_NORMAL: p = THREAD_PRIORITY_NORMAL; break;
19538
+ case LM_GGML_SCHED_PRIO_MEDIUM: p = THREAD_PRIORITY_ABOVE_NORMAL; break;
19539
+ case LM_GGML_SCHED_PRIO_HIGH: p = THREAD_PRIORITY_HIGHEST; break;
19540
+ case LM_GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break;
19541
+ }
19542
+
19543
+ if (prio == LM_GGML_SCHED_PRIO_NORMAL) {
19544
+ // Keep inherited policy/priority
19545
+ return true;
19546
+ }
19547
+
19548
+ if (!SetThreadPriority(GetCurrentThread(), p)) {
19549
+ fprintf(stderr, "warn: failed to set thread priority %d : (%d)\n", prio, (int) GetLastError());
19550
+ return false;
19551
+ }
19552
+
19553
+ return true;
19554
+ }
19555
+
19556
+ #elif defined(__APPLE__)
19557
+ #include <sys/types.h>
19558
+ #include <sys/resource.h>
19559
+
19560
+ static bool lm_ggml_thread_apply_affinity(const bool * mask) {
19561
+ // Not supported on Apple platforms
19562
+ UNUSED(mask);
19563
+ return true;
19564
+ }
19565
+
19566
+ static bool lm_ggml_thread_apply_priority(int32_t prio) {
19567
+ struct sched_param p;
19568
+ int32_t policy = SCHED_OTHER;
19569
+ switch (prio) {
19570
+ case LM_GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break;
19571
+ case LM_GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break;
19572
+ case LM_GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break;
19573
+ case LM_GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break;
19574
+ }
19575
+
19576
+ if (prio == LM_GGML_SCHED_PRIO_NORMAL) {
19577
+ // Keep inherited policy/priority
19578
+ return true;
19579
+ }
19580
+
19581
+ int32_t err = pthread_setschedparam(pthread_self(), policy, &p);
19582
+ if (err != 0) {
19583
+ fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err);
19584
+ return false;
19585
+ }
19586
+
19587
+ return true;
19588
+ }
19589
+
19590
+ #elif defined(__gnu_linux__)
19591
+ // TODO: this may not work on BSD, to be verified
19592
+
19593
+ static bool lm_ggml_thread_apply_affinity(const bool * mask) {
19594
+ cpu_set_t cpuset;
19595
+ int err;
19596
+
19597
+ CPU_ZERO(&cpuset);
19598
+
19599
+ for (uint32_t i = 0; i < LM_GGML_MAX_N_THREADS; i++) {
19600
+ if (mask[i]) {
19601
+ LM_GGML_PRINT_DEBUG("Thread %lx: adding %d to cpuset\n", pthread_self(), i);
19602
+ CPU_SET(i, &cpuset);
19603
+ }
19604
+ }
19605
+
19606
+ #ifdef __ANDROID__
19607
+ err = sched_setaffinity(0, sizeof(cpuset), &cpuset);
19608
+ if (err < 0) {
19609
+ err = errno;
19610
+ }
19611
+ #else
19612
+ err = pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset);
19613
+ #endif
19614
+ if (err != 0) {
19615
+ fprintf(stderr, "warn: failed to set affinity mask 0x%llx : %s (%d)\n", (unsigned long long)mask, strerror(err), err);
19616
+ return false;
19617
+ }
19618
+
19619
+ return true;
19620
+ }
19621
+
19622
+ static bool lm_ggml_thread_apply_priority(int32_t prio) {
19623
+ struct sched_param p;
19624
+ int32_t policy = SCHED_OTHER;
19625
+ switch (prio) {
19626
+ case LM_GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break;
19627
+ case LM_GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break;
19628
+ case LM_GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break;
19629
+ case LM_GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break;
19630
+ }
19631
+
19632
+ if (prio == LM_GGML_SCHED_PRIO_NORMAL) {
19633
+ // Keep inherited policy/priority
19634
+ return true;
19635
+ }
19636
+
19637
+ int32_t err = pthread_setschedparam(pthread_self(), policy, &p);
19638
+ if (err != 0) {
19639
+ fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err);
19640
+ return false;
19641
+ }
19642
+
19643
+ return true;
19644
+ }
19645
+
19646
+ #else // unsupported platforms
19647
+
19648
+ static bool lm_ggml_thread_apply_affinity(const bool * mask) {
19649
+ UNUSED(mask);
19650
+ return true;
19651
+ }
19652
+
19653
+ static bool lm_ggml_thread_apply_priority(int32_t prio) {
19654
+ UNUSED(prio);
19655
+ return true;
19656
+ }
19657
+
19658
+ #endif
19659
+
19660
+ static bool lm_ggml_thread_cpumask_is_valid(const bool * mask) {
19661
+ for (int i = 0; i < LM_GGML_MAX_N_THREADS; i++) {
19662
+ if (mask[i]) { return true; }
19663
+ }
19664
+ return false;
19665
+ }
19666
+
19667
+ static void lm_ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask, bool strict, int32_t* iter) {
19668
+ if (!strict) {
19669
+ memcpy(local_mask, global_mask, LM_GGML_MAX_N_THREADS);
19670
+ return;
19671
+ } else {
19672
+ memset(local_mask, 0, LM_GGML_MAX_N_THREADS);
19673
+ int32_t base_idx = *iter;
19674
+ for (int32_t i = 0; i < LM_GGML_MAX_N_THREADS; i++) {
19675
+ int32_t idx = base_idx + i;
19676
+ if (idx >= LM_GGML_MAX_N_THREADS) {
19677
+ // Just a cheaper modulo
19678
+ idx -= LM_GGML_MAX_N_THREADS;
19679
+ }
19680
+ if (global_mask[idx]) {
19681
+ local_mask[idx] = 1;
19682
+ *iter = idx + 1;
19683
+ return;
19684
+ }
19685
+ }
19686
+ }
19687
+ }
19688
+
19689
+ void lm_ggml_threadpool_free(struct lm_ggml_threadpool* threadpool) {
19690
+ if (!threadpool) return;
19691
+
19692
+ #ifndef LM_GGML_USE_OPENMP
19693
+ struct lm_ggml_compute_state* workers = threadpool->workers;
19694
+ const int n_threads = threadpool->n_threads_max;
19695
+
19696
+ lm_ggml_mutex_lock(&threadpool->mutex);
19697
+
19698
+ threadpool->stop = true;
19699
+ threadpool->pause = false;
19700
+
19701
+ lm_ggml_cond_broadcast(&threadpool->cond);
19702
+ lm_ggml_mutex_unlock(&threadpool->mutex);
19703
+
19704
+ for (int j = 1; j < n_threads; j++) {
19705
+ int32_t rc = lm_ggml_thread_join(workers[j].thrd, NULL);
19706
+ LM_GGML_ASSERT(rc == LM_GGML_EXIT_SUCCESS || rc == LM_GGML_EXIT_ABORTED);
19707
+ UNUSED(rc);
19708
+ }
19709
+
19710
+ lm_ggml_mutex_destroy(&threadpool->mutex);
19711
+ lm_ggml_cond_destroy(&threadpool->cond);
19712
+ #endif // LM_GGML_USE_OPENMP
19713
+
19714
+ LM_GGML_ALIGNED_FREE(threadpool->workers);
19715
+ LM_GGML_ALIGNED_FREE(threadpool);
19716
+ }
19717
+
19718
+ #ifndef LM_GGML_USE_OPENMP
19719
+ // pause/resume must be called under mutex
19720
+ static void lm_ggml_threadpool_pause_locked(struct lm_ggml_threadpool * threadpool) {
19721
+ LM_GGML_PRINT_DEBUG("Pausing threadpool\n");
19722
+ threadpool->pause = true;
19723
+ lm_ggml_cond_broadcast(&threadpool->cond);
19724
+ }
19725
+
19726
+ static void lm_ggml_threadpool_resume_locked(struct lm_ggml_threadpool * threadpool) {
19727
+ LM_GGML_PRINT_DEBUG("Resuming threadpool\n");
19728
+ threadpool->pause = false;
19729
+ lm_ggml_cond_broadcast(&threadpool->cond);
19730
+ }
19731
+ #endif
19732
+
19733
+ void lm_ggml_threadpool_pause(struct lm_ggml_threadpool * threadpool) {
19734
+ #ifndef LM_GGML_USE_OPENMP
19735
+ lm_ggml_mutex_lock(&threadpool->mutex);
19736
+ if (!threadpool->pause) {
19737
+ lm_ggml_threadpool_pause_locked(threadpool);
19738
+ }
19739
+ lm_ggml_mutex_unlock(&threadpool->mutex);
19740
+ #else
19741
+ UNUSED(threadpool);
19742
+ #endif
19743
+ }
19744
+
19745
+ void lm_ggml_threadpool_resume(struct lm_ggml_threadpool * threadpool) {
19746
+ #ifndef LM_GGML_USE_OPENMP
19747
+ lm_ggml_mutex_lock(&threadpool->mutex);
19748
+ if (threadpool->pause) {
19749
+ lm_ggml_threadpool_resume_locked(threadpool);
19750
+ }
19751
+ lm_ggml_mutex_unlock(&threadpool->mutex);
19752
+ #else
19753
+ UNUSED(threadpool);
19754
+ #endif
19755
+ }
19756
+
19757
+ struct lm_ggml_cplan lm_ggml_graph_plan(
19758
+ const struct lm_ggml_cgraph * cgraph,
19759
+ int n_threads,
19760
+ struct lm_ggml_threadpool * threadpool) {
19761
+
19762
+ if (threadpool == NULL) {
19763
+ LM_GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads);
19764
+ }
18607
19765
  if (n_threads <= 0) {
18608
- n_threads = LM_GGML_DEFAULT_N_THREADS;
19766
+ n_threads = threadpool ? threadpool->n_threads_max : LM_GGML_DEFAULT_N_THREADS;
18609
19767
  }
18610
19768
 
18611
19769
  size_t work_size = 0;
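lm_ggml_graph_plan now takes the threadpool as a third argument: passing NULL keeps the old behaviour (lm_ggml_graph_compute creates a disposable pool for that single call, as the hunks further down show), and n_threads <= 0 falls back to the pool's n_threads_max or LM_GGML_DEFAULT_N_THREADS. A minimal caller-side sketch (the helper name and the malloc'd work buffer are illustrative; the ggml header is assumed to be included):

#include <stdint.h>
#include <stdlib.h>

// plan and run one graph without a persistent threadpool
static enum lm_ggml_status compute_once(struct lm_ggml_cgraph * gf, int n_threads) {
    struct lm_ggml_cplan cplan = lm_ggml_graph_plan(gf, n_threads, /*threadpool=*/NULL);

    uint8_t * work = NULL;
    if (cplan.work_size > 0) {
        work = malloc(cplan.work_size);   // any buffer of cplan.work_size bytes will do
    }
    cplan.work_data = work;

    enum lm_ggml_status status = lm_ggml_graph_compute(gf, &cplan);
    free(work);
    return status;
}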
@@ -18761,12 +19919,13 @@ struct lm_ggml_cplan lm_ggml_graph_plan(const struct lm_ggml_cgraph * cgraph, in
18761
19919
  }
18762
19920
 
18763
19921
  if (work_size > 0) {
18764
- work_size += CACHE_LINE_SIZE*(n_threads - 1);
19922
+ work_size += CACHE_LINE_SIZE*(n_threads);
18765
19923
  }
18766
19924
 
18767
- cplan.n_threads = MIN(max_tasks, n_threads);
18768
- cplan.work_size = work_size;
18769
- cplan.work_data = NULL;
19925
+ cplan.threadpool = threadpool;
19926
+ cplan.n_threads = MIN(max_tasks, n_threads);
19927
+ cplan.work_size = work_size;
19928
+ cplan.work_data = NULL;
18770
19929
 
18771
19930
  return cplan;
18772
19931
  }
@@ -18774,17 +19933,17 @@ struct lm_ggml_cplan lm_ggml_graph_plan(const struct lm_ggml_cgraph * cgraph, in
18774
19933
  static thread_ret_t lm_ggml_graph_compute_thread(void * data) {
18775
19934
  struct lm_ggml_compute_state * state = (struct lm_ggml_compute_state *) data;
18776
19935
 
18777
- const struct lm_ggml_cgraph * cgraph = state->shared->cgraph;
18778
- const struct lm_ggml_cplan * cplan = state->shared->cplan;
19936
+ const struct lm_ggml_cgraph * cgraph = state->threadpool->cgraph;
19937
+ const struct lm_ggml_cplan * cplan = state->threadpool->cplan;
18779
19938
 
18780
19939
  set_numa_thread_affinity(state->ith);
18781
19940
 
18782
19941
  struct lm_ggml_compute_params params = {
18783
- /*.ith =*/ state->ith,
18784
- /*.nth =*/ state->shared->n_threads,
18785
- /*.wsize =*/ cplan->work_size,
18786
- /*.wdata =*/ cplan->work_data,
18787
- /*.shared=*/ state->shared,
19942
+ /*.ith =*/ state->ith,
19943
+ /*.nth =*/ state->threadpool->n_threads_cur,
19944
+ /*.wsize =*/ cplan->work_size,
19945
+ /*.wdata =*/ cplan->work_data,
19946
+ /*.threadpool=*/ state->threadpool,
18788
19947
  };
18789
19948
 
18790
19949
  for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
@@ -18793,12 +19952,12 @@ static thread_ret_t lm_ggml_graph_compute_thread(void * data) {
18793
19952
  lm_ggml_compute_forward(&params, node);
18794
19953
 
18795
19954
  if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
18796
- state->shared->ec = LM_GGML_STATUS_ABORTED;
19955
+ state->threadpool->ec = LM_GGML_STATUS_ABORTED;
18797
19956
  }
18798
19957
 
18799
- lm_ggml_barrier(state->shared);
19958
+ lm_ggml_barrier(state->threadpool);
18800
19959
 
18801
- if (state->shared->ec != LM_GGML_STATUS_SUCCESS) {
19960
+ if (state->threadpool->ec != LM_GGML_STATUS_SUCCESS) {
18802
19961
  break;
18803
19962
  }
18804
19963
  }
@@ -18806,24 +19965,243 @@ static thread_ret_t lm_ggml_graph_compute_thread(void * data) {
18806
19965
  return 0;
18807
19966
  }
18808
19967
 
19968
+ #ifndef LM_GGML_USE_OPENMP
19969
+
19970
+ static inline bool lm_ggml_graph_compute_ready(struct lm_ggml_compute_state * state) {
19971
+ struct lm_ggml_threadpool * threadpool = state->threadpool;
19972
+
19973
+ if (state->pending || threadpool->stop || threadpool->pause) { return true; }
19974
+
19975
+ // check for new graph/work
19976
+ int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed);
19977
+ if (new_graph != state->last_graph) {
19978
+ state->pending = (state->ith < threadpool->n_threads_cur);
19979
+ state->last_graph = new_graph;
19980
+ }
19981
+
19982
+ return state->pending;
19983
+ }
19984
+
19985
+ static inline bool lm_ggml_graph_compute_poll_for_work(struct lm_ggml_compute_state * state) {
19986
+ struct lm_ggml_threadpool * threadpool = state->threadpool;
19987
+
19988
+ // This seems to make 0 ... 100 a decent range for polling level across modern processors.
19989
+ // Perhaps, we can adjust it dynamically based on load and things.
19990
+ const uint64_t n_rounds = 1024UL * 128 * threadpool->poll;
19991
+
19992
+ for (uint64_t i=0; !lm_ggml_graph_compute_ready(state) && i<n_rounds; i++) {
19993
+ // No new work. Keep polling.
19994
+ lm_ggml_thread_cpu_relax();
19995
+ }
19996
+
19997
+ return state->pending;
19998
+ }
19999
+
20000
+ static inline bool lm_ggml_graph_compute_check_for_work(struct lm_ggml_compute_state * state) {
20001
+ struct lm_ggml_threadpool * threadpool = state->threadpool;
20002
+
20003
+ if (lm_ggml_graph_compute_poll_for_work(state)) {
20004
+ return state->pending;
20005
+ }
20006
+
20007
+ lm_ggml_mutex_lock_shared(&threadpool->mutex);
20008
+ while (!lm_ggml_graph_compute_ready(state)) {
20009
+ // No new work. Wait for the signal.
20010
+ LM_GGML_PRINT_DEBUG("thread #%d waiting for work\n", state->ith);
20011
+ lm_ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
20012
+ }
20013
+ lm_ggml_mutex_unlock_shared(&threadpool->mutex);
20014
+
20015
+ return state->pending;
20016
+ }
20017
+
20018
+ static thread_ret_t lm_ggml_graph_compute_secondary_thread(void* data) {
20019
+ struct lm_ggml_compute_state * state = (struct lm_ggml_compute_state *) data;
20020
+ struct lm_ggml_threadpool * threadpool = state->threadpool;
20021
+
20022
+ lm_ggml_thread_apply_priority(threadpool->prio);
20023
+ if (lm_ggml_thread_cpumask_is_valid(state->cpumask)) {
20024
+ lm_ggml_thread_apply_affinity(state->cpumask);
20025
+ }
20026
+
20027
+ while (true) {
20028
+ // Check if we need to sleep
20029
+ while (threadpool->pause) {
20030
+ LM_GGML_PRINT_DEBUG("thread #%d inside pause loop\n", state->ith);
20031
+ lm_ggml_mutex_lock_shared(&threadpool->mutex);
20032
+ if (threadpool->pause) {
20033
+ lm_ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
20034
+ }
20035
+ LM_GGML_PRINT_DEBUG("thread #%d resuming after wait\n", state->ith);
20036
+ lm_ggml_mutex_unlock_shared(&threadpool->mutex);
20037
+ }
20038
+
20039
+ // This needs to be checked for after the cond_wait
20040
+ if (threadpool->stop) break;
20041
+
20042
+ // Check if there is new work
20043
+ // The main thread is the only one that can dispatch new work
20044
+
20045
+ lm_ggml_graph_compute_check_for_work(state);
20046
+ if (state->pending) {
20047
+ state->pending = false;
20048
+
20049
+ lm_ggml_graph_compute_thread(state);
20050
+ }
20051
+ }
20052
+
20053
+ return (thread_ret_t) 0;
20054
+ }
20055
+
20056
+ // Start processing new graph
20057
+ static void lm_ggml_graph_compute_kickoff(struct lm_ggml_threadpool * threadpool)
20058
+ {
20059
+ // always take the mutex here because the worker threads are doing hybrid poll/wait
20060
+
20061
+ lm_ggml_mutex_lock(&threadpool->mutex);
20062
+
20063
+ atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_relaxed);
20064
+
20065
+ if (threadpool->pause) {
20066
+ // Update main thread prio and affinity to match the threadpool settings
20067
+ lm_ggml_thread_apply_priority(threadpool->prio);
20068
+ if (lm_ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) {
20069
+ lm_ggml_thread_apply_affinity(threadpool->workers[0].cpumask);
20070
+ }
20071
+
20072
+ // resume does cond broadcast
20073
+ lm_ggml_threadpool_resume_locked(threadpool);
20074
+ } else {
20075
+ lm_ggml_cond_broadcast(&threadpool->cond);
20076
+ }
20077
+
20078
+ lm_ggml_mutex_unlock(&threadpool->mutex);
20079
+ }
20080
+
20081
+ #endif // LM_GGML_USE_OPENMP
20082
+
20083
+ void lm_ggml_threadpool_params_init(struct lm_ggml_threadpool_params * p, int n_threads) {
20084
+ p->n_threads = n_threads;
20085
+ p->prio = 0; // default priority (usually means normal or inherited)
20086
+ p->poll = 50; // hybrid-polling enabled
20087
+ p->strict_cpu = false; // no strict placement (all threads share same cpumask)
20088
+ p->paused = false; // threads are ready to go
20089
+ memset(p->cpumask, 0, LM_GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited)
20090
+ }
20091
+
20092
+ struct lm_ggml_threadpool_params lm_ggml_threadpool_params_default(int n_threads) {
20093
+ struct lm_ggml_threadpool_params p;
20094
+ lm_ggml_threadpool_params_init(&p, n_threads);
20095
+ return p;
20096
+ }
20097
+
20098
+ bool lm_ggml_threadpool_params_match(const struct lm_ggml_threadpool_params * p0, const struct lm_ggml_threadpool_params * p1) {
20099
+ if (p0->n_threads != p1->n_threads ) return false;
20100
+ if (p0->prio != p1->prio ) return false;
20101
+ if (p0->poll != p1->poll ) return false;
20102
+ if (p0->strict_cpu != p1->strict_cpu ) return false;
20103
+ return memcmp(p0->cpumask, p1->cpumask, LM_GGML_MAX_N_THREADS) == 0;
20104
+ }
20105
+
20106
+ static struct lm_ggml_threadpool * lm_ggml_threadpool_new_impl(
20107
+ struct lm_ggml_threadpool_params * tpp,
20108
+ struct lm_ggml_cgraph * cgraph,
20109
+ struct lm_ggml_cplan * cplan) {
20110
+
20111
+ struct lm_ggml_threadpool * threadpool =
20112
+ LM_GGML_ALIGNED_MALLOC(sizeof(struct lm_ggml_threadpool));
20113
+ {
20114
+ threadpool->cgraph = cgraph;
20115
+ threadpool->cplan = cplan;
20116
+ threadpool->n_graph = 0;
20117
+ threadpool->n_barrier = 0;
20118
+ threadpool->n_barrier_passed = 0;
20119
+ threadpool->current_chunk = 0;
20120
+ threadpool->stop = false;
20121
+ threadpool->pause = tpp->paused;
20122
+ threadpool->workers = NULL;
20123
+ threadpool->n_threads_max = tpp->n_threads;
20124
+ threadpool->n_threads_cur = tpp->n_threads;
20125
+ threadpool->poll = tpp->poll;
20126
+ threadpool->prio = tpp->prio;
20127
+ threadpool->ec = LM_GGML_STATUS_SUCCESS;
20128
+ }
20129
+
20130
+ // Allocate and init workers state
20131
+ const size_t workers_size = sizeof(struct lm_ggml_compute_state) * tpp->n_threads;
20132
+ struct lm_ggml_compute_state * workers = LM_GGML_ALIGNED_MALLOC(workers_size);
20133
+
20134
+ memset(workers, 0, workers_size);
20135
+ for (int j = 0; j < tpp->n_threads; j++) {
20136
+ workers[j].threadpool = threadpool;
20137
+ workers[j].ith = j;
20138
+ }
20139
+
20140
+ threadpool->workers = workers;
20141
+
20142
+ #ifndef LM_GGML_USE_OPENMP
20143
+ lm_ggml_mutex_init(&threadpool->mutex);
20144
+ lm_ggml_cond_init(&threadpool->cond);
20145
+
20146
+ // Spin the threads for all workers, and update CPU placements.
20147
+ // Place the main thread last (towards the higher numbered CPU cores).
20148
+
20149
+ int32_t cpumask_iter = 0;
20150
+
20151
+ for (int j = 1; j < tpp->n_threads; j++) {
20152
+ lm_ggml_thread_cpumask_next(tpp->cpumask, workers[j].cpumask, tpp->strict_cpu, &cpumask_iter);
20153
+
20154
+ int32_t rc = lm_ggml_thread_create(&workers[j].thrd, NULL, lm_ggml_graph_compute_secondary_thread, &workers[j]);
20155
+ LM_GGML_ASSERT(rc == 0);
20156
+ }
20157
+
20158
+ lm_ggml_thread_cpumask_next(tpp->cpumask, workers[0].cpumask, tpp->strict_cpu, &cpumask_iter);
20159
+
20160
+ if (!threadpool->pause) {
20161
+ // Update main thread prio and affinity at the start, otherwise we'll do it in resume
20162
+ lm_ggml_thread_apply_priority(threadpool->prio);
20163
+ if (lm_ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) {
20164
+ lm_ggml_thread_apply_affinity(threadpool->workers[0].cpumask);
20165
+ }
20166
+ }
20167
+ #endif // LM_GGML_USE_OPENMP
20168
+
20169
+ return threadpool;
20170
+ }
20171
+
20172
+ struct lm_ggml_threadpool * lm_ggml_threadpool_new(struct lm_ggml_threadpool_params * tpp) {
20173
+ return lm_ggml_threadpool_new_impl(tpp, NULL, NULL);
20174
+ }
20175
+
18809
20176
  enum lm_ggml_status lm_ggml_graph_compute(struct lm_ggml_cgraph * cgraph, struct lm_ggml_cplan * cplan) {
18810
20177
  LM_GGML_ASSERT(cplan);
18811
20178
  LM_GGML_ASSERT(cplan->n_threads > 0);
18812
20179
  LM_GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL);
18813
20180
 
18814
- int n_threads = cplan->n_threads;
18815
-
18816
- struct lm_ggml_compute_state_shared state_shared = {
18817
- /*.cgraph =*/ cgraph,
18818
- /*.cgraph_plan =*/ cplan,
18819
- /*.n_threads =*/ n_threads,
18820
- /*.n_barrier =*/ 0,
18821
- /*.n_barrier_passed =*/ 0,
18822
- /*.abort_callback =*/ NULL,
18823
- /*.abort_callback_data =*/ NULL,
18824
- /*.current_chunk =*/ 0,
18825
- /*.ec =*/ LM_GGML_STATUS_SUCCESS,
18826
- };
20181
+ int n_threads = cplan->n_threads;
20182
+ struct lm_ggml_threadpool * threadpool = cplan->threadpool;
20183
+
20184
+ bool disposable_threadpool = false;
20185
+
20186
+ if (threadpool == NULL) {
20187
+ LM_GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads);
20188
+ disposable_threadpool = true;
20189
+
20190
+ struct lm_ggml_threadpool_params ttp = lm_ggml_threadpool_params_default(n_threads);
20191
+ threadpool = lm_ggml_threadpool_new_impl(&ttp, cgraph, cplan);
20192
+ } else {
20193
+ // Reset some of the parameters that need resetting
20194
+ // No worker threads should be accessing the parameters below at this stage
20195
+ threadpool->cgraph = cgraph;
20196
+ threadpool->cplan = cplan;
20197
+ threadpool->n_threads_cur = n_threads;
20198
+ threadpool->current_chunk = 0;
20199
+ threadpool->ec = LM_GGML_STATUS_SUCCESS;
20200
+ }
20201
+
20202
+ if (n_threads > threadpool->n_threads_max) {
20203
+ LM_GGML_PRINT("WARNING: cplan is requesting more threads than the threadpool contains. Expect a bad time!\n");
20204
+ }
18827
20205
 
18828
20206
  #ifdef LM_GGML_USE_OPENMP
18829
20207
  if (n_threads > 1) {
@@ -18833,63 +20211,36 @@ enum lm_ggml_status lm_ggml_graph_compute(struct lm_ggml_cgraph * cgraph, struct
18833
20211
  {
18834
20212
  // update the number of threads from the actual number of threads that we got from OpenMP
18835
20213
  n_threads = omp_get_num_threads();
18836
- state_shared.n_threads = n_threads;
20214
+ threadpool->n_threads_cur = n_threads;
18837
20215
  }
18838
20216
 
18839
- struct lm_ggml_compute_state worker = {
18840
- .thrd = 0,
18841
- .ith = omp_get_thread_num(),
18842
- .shared = &state_shared,
18843
- };
18844
- lm_ggml_graph_compute_thread(&worker);
20217
+ lm_ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]);
18845
20218
  }
18846
20219
  } else {
18847
- struct lm_ggml_compute_state worker = {
18848
- .thrd = 0,
18849
- .ith = 0,
18850
- .shared = &state_shared,
18851
- };
18852
- lm_ggml_graph_compute_thread(&worker);
20220
+ lm_ggml_graph_compute_thread(&threadpool->workers[0]);
18853
20221
  }
18854
20222
  #else
18855
- struct lm_ggml_compute_state * workers = alloca(sizeof(struct lm_ggml_compute_state)*n_threads);
18856
-
18857
- for (int j = 0; j < n_threads; ++j) {
18858
- workers[j] = (struct lm_ggml_compute_state) {
18859
- .thrd = 0,
18860
- .ith = j,
18861
- .shared = &state_shared,
18862
- };
18863
- }
18864
-
18865
- // create thread pool
18866
- for (int j = 1; j < n_threads; ++j) {
18867
- const int rc = lm_ggml_thread_create(&workers[j].thrd, NULL, lm_ggml_graph_compute_thread, &workers[j]);
18868
- LM_GGML_ASSERT(rc == 0);
18869
- UNUSED(rc);
18870
- }
18871
-
18872
- // this is a work thread too
18873
- lm_ggml_graph_compute_thread(&workers[0]);
20223
+ // Kick all threads to start the new graph
20224
+ lm_ggml_graph_compute_kickoff(threadpool);
18874
20225
 
18875
- // join or kill thread pool
18876
- if (n_threads > 1) {
18877
- for (int j = 1; j < n_threads; j++) {
18878
- const int rc = lm_ggml_thread_join(workers[j].thrd, NULL);
18879
- LM_GGML_ASSERT(rc == 0);
18880
- UNUSED(rc);
18881
- }
18882
- }
20226
+ // This is a work thread too
20227
+ lm_ggml_graph_compute_thread(&threadpool->workers[0]);
18883
20228
  #endif
18884
20229
 
18885
20230
  // don't leave affinity set on the main thread
18886
20231
  clear_numa_thread_affinity();
18887
20232
 
18888
- return state_shared.ec;
20233
+ enum lm_ggml_status ret = threadpool->ec;
20234
+
20235
+ if (disposable_threadpool) {
20236
+ lm_ggml_threadpool_free(threadpool);
20237
+ }
20238
+
20239
+ return ret;
18889
20240
  }
18890
20241
 
18891
20242
  enum lm_ggml_status lm_ggml_graph_compute_with_ctx(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph, int n_threads) {
18892
- struct lm_ggml_cplan cplan = lm_ggml_graph_plan(cgraph, n_threads);
20243
+ struct lm_ggml_cplan cplan = lm_ggml_graph_plan(cgraph, n_threads, NULL);
18893
20244
 
18894
20245
  struct lm_ggml_object * obj = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size);
18895
20246
 
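Putting the threadpool pieces together: a pool can now be created once, shared across many lm_ggml_graph_compute calls, parked while idle, and freed explicitly. A sketch under the assumption that the package's ggml header is included and gf is an already-built graph (work-buffer handling as in the earlier sketch):

#include <stdint.h>
#include <stdlib.h>

static enum lm_ggml_status compute_many(struct lm_ggml_cgraph * gf, int n_threads, int n_iters) {
    struct lm_ggml_threadpool_params tpp = lm_ggml_threadpool_params_default(n_threads);
    tpp.prio       = LM_GGML_SCHED_PRIO_NORMAL;  // keep inherited scheduling
    tpp.strict_cpu = false;                      // all workers share the same cpumask

    struct lm_ggml_threadpool * tp = lm_ggml_threadpool_new(&tpp);

    struct lm_ggml_cplan cplan = lm_ggml_graph_plan(gf, n_threads, tp);
    uint8_t * work = cplan.work_size > 0 ? malloc(cplan.work_size) : NULL;
    cplan.work_data = work;

    enum lm_ggml_status status = LM_GGML_STATUS_SUCCESS;
    for (int i = 0; i < n_iters && status == LM_GGML_STATUS_SUCCESS; ++i) {
        status = lm_ggml_graph_compute(gf, &cplan);  // reuses the same worker threads
    }

    lm_ggml_threadpool_pause(tp);    // park the workers while no work is queued
    lm_ggml_threadpool_resume(tp);   // (both are no-ops when built with OpenMP)
    lm_ggml_threadpool_free(tp);

    free(work);
    return status;
}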
@@ -19030,9 +20381,11 @@ void lm_ggml_graph_export(const struct lm_ggml_cgraph * cgraph, const char * fna
19030
20381
 
19031
20382
  const uint32_t type = tensor->type;
19032
20383
  const uint32_t op = tensor->op;
20384
+ const int32_t flags = tensor->flags;
19033
20385
 
19034
20386
  fwrite(&type, sizeof(uint32_t), 1, fout);
19035
20387
  fwrite(&op, sizeof(uint32_t), 1, fout);
20388
+ fwrite(&flags, sizeof(int32_t), 1, fout);
19036
20389
 
19037
20390
  for (int j = 0; j < LM_GGML_MAX_DIMS; ++j) {
19038
20391
  const uint64_t ne = tensor->ne[j];
@@ -19062,9 +20415,11 @@ void lm_ggml_graph_export(const struct lm_ggml_cgraph * cgraph, const char * fna
19062
20415
 
19063
20416
  const uint32_t type = tensor->type;
19064
20417
  const uint32_t op = tensor->op;
20418
+ const int32_t flags = tensor->flags;
19065
20419
 
19066
20420
  fwrite(&type, sizeof(uint32_t), 1, fout);
19067
20421
  fwrite(&op, sizeof(uint32_t), 1, fout);
20422
+ fwrite(&flags, sizeof(int32_t), 1, fout);
19068
20423
 
19069
20424
  for (int j = 0; j < LM_GGML_MAX_DIMS; ++j) {
19070
20425
  const uint64_t ne = tensor->ne[j];
@@ -19123,6 +20478,14 @@ void lm_ggml_graph_export(const struct lm_ggml_cgraph * cgraph, const char * fna
19123
20478
  }
19124
20479
  }
19125
20480
  }
20481
+
20482
+ // dump the data
20483
+ // TODO: pad this to a 32-byte boundary
20484
+ if ((flags & LM_GGML_TENSOR_FLAG_PARAM)) {
20485
+ const size_t size = lm_ggml_nbytes(tensor);
20486
+
20487
+ fwrite(tensor->data, sizeof(char), size, fout);
20488
+ }
19126
20489
  }
19127
20490
  }
19128
20491
 
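Taken together, the export hunks above change the on-disk layout: every tensor record now carries a signed 32-bit flags word right after its type and op, and tensors marked LM_GGML_TENSOR_FLAG_PARAM additionally get their raw bytes written out (unpadded, per the TODO). Readers built against the old layout will drift by four bytes per tensor. A reader-side sketch of just the widened header, assuming an open FILE * fin and leaving the rest of the record format as the import code below expects:

    // Hedged sketch: the per-tensor header as the new exporter writes it.
    uint32_t type, op;
    int32_t  flags;
    fread(&type,  sizeof(uint32_t), 1, fin);
    fread(&op,    sizeof(uint32_t), 1, fin);
    fread(&flags, sizeof(int32_t),  1, fin);   // new field in this release
    // ne[], nb[], name and op_params follow as before; if (flags & LM_GGML_TENSOR_FLAG_PARAM),
    // the tensor's raw data is also embedded in the file.
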
@@ -19236,10 +20599,12 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_
19236
20599
  {
19237
20600
  uint32_t type;
19238
20601
  uint32_t op;
20602
+ int32_t flags;
19239
20603
 
19240
20604
  for (uint32_t i = 0; i < n_leafs; ++i) {
19241
20605
  type = *(const uint32_t *) ptr; ptr += sizeof(type);
19242
20606
  op = *(const uint32_t *) ptr; ptr += sizeof(op);
20607
+ flags = *(const int32_t *) ptr; ptr += sizeof(flags);
19243
20608
 
19244
20609
  int64_t ne[LM_GGML_MAX_DIMS];
19245
20610
  size_t nb[LM_GGML_MAX_DIMS];
@@ -19257,20 +20622,19 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_
19257
20622
 
19258
20623
  struct lm_ggml_tensor * tensor = lm_ggml_new_tensor(*ctx_eval, (enum lm_ggml_type) type, LM_GGML_MAX_DIMS, ne);
19259
20624
 
19260
- tensor->op = (enum lm_ggml_op) op;
20625
+ tensor->op = (enum lm_ggml_op) op;
20626
+ tensor->flags = flags;
19261
20627
 
19262
20628
  memcpy(tensor->name, ptr, LM_GGML_MAX_NAME); ptr += LM_GGML_MAX_NAME;
19263
20629
  memcpy(tensor->op_params, ptr, LM_GGML_MAX_OP_PARAMS); ptr += LM_GGML_MAX_OP_PARAMS;
19264
20630
 
19265
- tensor->data = (void *) ptr;
19266
-
19267
20631
  for (int j = 0; j < LM_GGML_MAX_DIMS; ++j) {
19268
20632
  tensor->nb[j] = nb[j];
19269
20633
  }
19270
20634
 
19271
- result->leafs[i] = tensor;
20635
+ tensor->data = (void *) ptr; ptr += lm_ggml_nbytes(tensor);
19272
20636
 
19273
- ptr += lm_ggml_nbytes(tensor);
20637
+ result->leafs[i] = tensor;
19274
20638
 
19275
20639
  fprintf(stderr, "%s: loaded leaf %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, lm_ggml_nbytes(tensor));
19276
20640
  }
@@ -19282,10 +20646,12 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_
19282
20646
  {
19283
20647
  uint32_t type;
19284
20648
  uint32_t op;
20649
+ int32_t flags;
19285
20650
 
19286
20651
  for (uint32_t i = 0; i < n_nodes; ++i) {
19287
20652
  type = *(const uint32_t *) ptr; ptr += sizeof(type);
19288
20653
  op = *(const uint32_t *) ptr; ptr += sizeof(op);
20654
+ flags = *(const int32_t *) ptr; ptr += sizeof(flags);
19289
20655
 
19290
20656
  enum lm_ggml_op eop = (enum lm_ggml_op) op;
19291
20657
 
@@ -19375,6 +20741,11 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_
19375
20741
 
19376
20742
  result->nodes[i] = tensor;
19377
20743
 
20744
+ // TODO: tensor data is duplicated due to the lm_ggml_new_tensor call above
20745
+ if (flags & LM_GGML_TENSOR_FLAG_PARAM) {
20746
+ tensor->data = (void *) ptr; ptr += lm_ggml_nbytes(tensor);
20747
+ }
20748
+
19378
20749
  fprintf(stderr, "%s: loaded node %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, lm_ggml_nbytes(tensor));
19379
20750
  }
19380
20751
  }
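On the import side the restored flags make trainable parameters distinguishable after loading, and for such nodes tensor->data now points at the bytes embedded in the file (the duplication noted in the TODO is still open). A minimal sketch of walking an imported graph, with the file name illustrative and error handling omitted:

    // Hedged sketch: locate the parameter tensors of an imported graph.
    struct lm_ggml_context * ctx_data = NULL;
    struct lm_ggml_context * ctx_eval = NULL;
    struct lm_ggml_cgraph  * gf = lm_ggml_graph_import("graph.ggml", &ctx_data, &ctx_eval);
    for (int i = 0; i < gf->n_nodes; ++i) {
        if (gf->nodes[i]->flags & LM_GGML_TENSOR_FLAG_PARAM) {
            // trainable parameter; its data was serialized alongside the graph
        }
    }
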
@@ -19643,6 +21014,7 @@ static enum lm_ggml_opt_result lm_ggml_opt_adam(
19643
21014
  lm_ggml_opt_callback callback,
19644
21015
  void * callback_data) {
19645
21016
  LM_GGML_ASSERT(lm_ggml_is_scalar(f));
21017
+ LM_GGML_ASSERT(f->type == LM_GGML_TYPE_F32);
19646
21018
 
19647
21019
  // these will store the parameters we want to optimize
19648
21020
  struct lm_ggml_tensor * ps[LM_GGML_MAX_PARAMS];
@@ -19684,7 +21056,7 @@ static enum lm_ggml_opt_result lm_ggml_opt_adam(
19684
21056
 
19685
21057
  float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
19686
21058
 
19687
- struct lm_ggml_cplan cplan = lm_ggml_graph_plan(gb, params.n_threads);
21059
+ struct lm_ggml_cplan cplan = lm_ggml_graph_plan(gb, params.n_threads, NULL);
19688
21060
  struct lm_ggml_object * obj = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size);
19689
21061
  cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
19690
21062
 
@@ -20031,7 +21403,7 @@ static enum lm_ggml_opt_result lm_ggml_opt_lbfgs(
20031
21403
  opt->iter = iter;
20032
21404
  }
20033
21405
 
20034
- struct lm_ggml_cplan cplan = lm_ggml_graph_plan(gb, params.n_threads);
21406
+ struct lm_ggml_cplan cplan = lm_ggml_graph_plan(gb, params.n_threads, NULL);
20035
21407
  struct lm_ggml_object * obj = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size);
20036
21408
  cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
20037
21409
 
@@ -20409,6 +21781,8 @@ enum lm_ggml_opt_result lm_ggml_opt(
20409
21781
  struct lm_ggml_context * ctx,
20410
21782
  struct lm_ggml_opt_params params,
20411
21783
  struct lm_ggml_tensor * f) {
21784
+ LM_GGML_ASSERT(f->grad && "lm_ggml_set_param must be called for at least one parent tensor.");
21785
+
20412
21786
  bool free_ctx = false;
20413
21787
  if (ctx == NULL) {
20414
21788
  struct lm_ggml_init_params params_ctx = {
@@ -20463,6 +21837,8 @@ enum lm_ggml_opt_result lm_ggml_opt_resume_g(
20463
21837
  lm_ggml_opt_callback callback,
20464
21838
  void * callback_data) {
20465
21839
 
21840
+ LM_GGML_ASSERT(f->grad && "lm_ggml_set_param must be called for at least one ancestor");
21841
+
20466
21842
  // build forward + backward compute graphs
20467
21843
  enum lm_ggml_opt_result result = LM_GGML_OPT_RESULT_OK;
20468
21844
 
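The new asserts pin down the optimizer's preconditions: the Adam path now insists the objective is F32 as well as scalar, and both lm_ggml_opt and lm_ggml_opt_resume_g require that lm_ggml_set_param was called on at least one ancestor before the objective was built, so that f->grad exists. A minimal sketch that satisfies both, assuming an initialized lm_ggml_context ctx; the one-element quadratic objective and the LM_GGML_OPT_TYPE_ADAM spelling are assumptions of this sketch:

    // Hedged sketch: flag the parameter before building the objective,
    // then optimize a scalar F32 loss.
    struct lm_ggml_tensor * x = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, 1);
    lm_ggml_set_param(ctx, x);                              // gives x, and everything built on it, a grad
    struct lm_ggml_tensor * f = lm_ggml_mul(ctx, x, x);     // scalar and F32: passes both asserts
    enum lm_ggml_opt_result res = lm_ggml_opt(ctx, lm_ggml_opt_default_params(LM_GGML_OPT_TYPE_ADAM), f);
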
@@ -20574,6 +21950,8 @@ size_t lm_ggml_quantize_chunk(
20574
21950
  case LM_GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
20575
21951
  case LM_GGML_TYPE_Q5_K: result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
20576
21952
  case LM_GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21953
+ case LM_GGML_TYPE_TQ1_0: result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21954
+ case LM_GGML_TYPE_TQ2_0: result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
20577
21955
  case LM_GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
20578
21956
  case LM_GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
20579
21957
  case LM_GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
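With the two ternary types wired into the dispatch above, they are produced through the same entry point as the existing k-quants. A minimal sketch, assuming the argument order used by the surrounding calls (start offset, row count, row length, optional importance matrix) and the lm_ggml_row_size helper for sizing the destination:

    // Hedged sketch: quantize nrows rows of n_per_row floats to TQ1_0.
    const size_t row_size = lm_ggml_row_size(LM_GGML_TYPE_TQ1_0, n_per_row);
    void * dst = malloc(row_size * nrows);                    // quantized output
    size_t written = lm_ggml_quantize_chunk(LM_GGML_TYPE_TQ1_0, src, dst,
                                            /*start=*/0, nrows, n_per_row,
                                            /*imatrix=*/NULL); // importance matrix is optional
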
@@ -21550,6 +22928,7 @@ void lm_gguf_set_kv(struct lm_gguf_context * ctx, struct lm_gguf_context * src)
21550
22928
  void lm_gguf_add_tensor(
21551
22929
  struct lm_gguf_context * ctx,
21552
22930
  const struct lm_ggml_tensor * tensor) {
22931
+ LM_GGML_ASSERT(tensor);
21553
22932
  if (lm_gguf_find_tensor(ctx, tensor->name) != -1) {
21554
22933
  LM_GGML_ABORT("duplicated tensor name");
21555
22934
  }
@@ -21909,6 +23288,14 @@ int lm_ggml_cpu_has_arm_fma(void) {
21909
23288
  #endif
21910
23289
  }
21911
23290
 
23291
+ int lm_ggml_cpu_has_riscv_v(void) {
23292
+ #if defined(__riscv_v_intrinsic)
23293
+ return 1;
23294
+ #else
23295
+ return 0;
23296
+ #endif
23297
+ }
23298
+
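The new lm_ggml_cpu_has_riscv_v probe follows the same pattern as the neighboring feature checks, reporting whether the file was compiled with the RISC-V vector intrinsics available. A tiny usage sketch (the printf and <stdio.h> are just for illustration):

    // Hedged sketch: report RISC-V vector support next to an existing probe.
    printf("RISCV_V = %d | METAL = %d\n",
           lm_ggml_cpu_has_riscv_v(), lm_ggml_cpu_has_metal());
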
21912
23299
  int lm_ggml_cpu_has_metal(void) {
21913
23300
  #if defined(LM_GGML_USE_METAL)
21914
23301
  return 1;