cui-llama.rn 1.1.2 → 1.1.4

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
package/cpp/ggml.c CHANGED
@@ -69,23 +69,42 @@ int lm_ggml_sve_cnt_b = 0;
69
69
  #endif
70
70
  #include <windows.h>
71
71
 
72
+ #if !defined(__clang__)
72
73
  typedef volatile LONG atomic_int;
73
74
  typedef atomic_int atomic_bool;
74
75
  typedef atomic_int atomic_flag;
75
76
 
76
77
  #define ATOMIC_FLAG_INIT 0
77
78
 
79
+ typedef enum {
80
+ memory_order_relaxed,
81
+ memory_order_consume,
82
+ memory_order_acquire,
83
+ memory_order_release,
84
+ memory_order_acq_rel,
85
+ memory_order_seq_cst
86
+ } memory_order;
87
+
78
88
  static void atomic_store(atomic_int * ptr, LONG val) {
79
89
  InterlockedExchange(ptr, val);
80
90
  }
91
+ static void atomic_store_explicit(atomic_int * ptr, LONG val, memory_order mo) {
92
+ // TODO: add support for explicit memory order
93
+ InterlockedExchange(ptr, val);
94
+ }
81
95
  static LONG atomic_load(atomic_int * ptr) {
82
96
  return InterlockedCompareExchange(ptr, 0, 0);
83
97
  }
98
+ static LONG atomic_load_explicit(atomic_int * ptr, memory_order mo) {
99
+ // TODO: add support for explicit memory order
100
+ return InterlockedCompareExchange(ptr, 0, 0);
101
+ }
84
102
  static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
85
103
  return InterlockedExchangeAdd(ptr, inc);
86
104
  }
87
- static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
88
- return atomic_fetch_add(ptr, -(dec));
105
+ static LONG atomic_fetch_add_explicit(atomic_int * ptr, LONG inc, memory_order mo) {
106
+ // TODO: add support for explicit memory order
107
+ return InterlockedExchangeAdd(ptr, inc);
89
108
  }
90
109
  static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
91
110
  return InterlockedExchange(ptr, 1);
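
The hunks above give the MSVC (non-clang) Windows build a small C11-atomics shim: a memory_order enum plus _explicit variants of the Interlocked-based wrappers (the memory-order argument is currently ignored, per the TODO comments), while clang on Windows keeps using <stdatomic.h>. A minimal sketch of how code elsewhere in the file can use these wrappers; the counter and function names here are illustrative, not from the package:

// Illustrative only: a relaxed progress counter built on the wrappers above.
static atomic_int n_done = 0;

static void mark_done(void) {
    atomic_fetch_add_explicit(&n_done, 1, memory_order_relaxed); // atomicity only, no ordering
}

static LONG count_done(void) {
    return atomic_load_explicit(&n_done, memory_order_relaxed);
}
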
@@ -93,6 +112,9 @@ static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
93
112
  static void atomic_flag_clear(atomic_flag * ptr) {
94
113
  InterlockedExchange(ptr, 0);
95
114
  }
115
+ #else // clang
116
+ #include <stdatomic.h>
117
+ #endif
96
118
 
97
119
  typedef HANDLE pthread_t;
98
120
 
@@ -121,8 +143,13 @@ static int sched_yield (void) {
121
143
  return 0;
122
144
  }
123
145
  #else
146
+
124
147
  #include <pthread.h>
125
148
  #include <stdatomic.h>
149
+ #include <sched.h>
150
+ #if defined(__FreeBSD__)
151
+ #include <pthread_np.h>
152
+ #endif
126
153
 
127
154
  typedef void * thread_ret_t;
128
155
 
@@ -1027,7 +1054,31 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = {
1027
1054
  .ncols = 8,
1028
1055
  .gemv = lm_ggml_gemv_q4_0_8x8_q8_0,
1029
1056
  .gemm = lm_ggml_gemm_q4_0_8x8_q8_0,
1030
- }
1057
+ },
1058
+ [LM_GGML_TYPE_TQ1_0] = {
1059
+ .type_name = "tq1_0",
1060
+ .blck_size = QK_K,
1061
+ .type_size = sizeof(block_tq1_0),
1062
+ .is_quantized = true,
1063
+ .to_float = (lm_ggml_to_float_t) dequantize_row_tq1_0,
1064
+ .from_float = quantize_row_tq1_0,
1065
+ .from_float_ref = (lm_ggml_from_float_t) quantize_row_tq1_0_ref,
1066
+ .vec_dot = lm_ggml_vec_dot_tq1_0_q8_K,
1067
+ .vec_dot_type = LM_GGML_TYPE_Q8_K,
1068
+ .nrows = 1,
1069
+ },
1070
+ [LM_GGML_TYPE_TQ2_0] = {
1071
+ .type_name = "tq2_0",
1072
+ .blck_size = QK_K,
1073
+ .type_size = sizeof(block_tq2_0),
1074
+ .is_quantized = true,
1075
+ .to_float = (lm_ggml_to_float_t) dequantize_row_tq2_0,
1076
+ .from_float = quantize_row_tq2_0,
1077
+ .from_float_ref = (lm_ggml_from_float_t) quantize_row_tq2_0_ref,
1078
+ .vec_dot = lm_ggml_vec_dot_tq2_0_q8_K,
1079
+ .vec_dot_type = LM_GGML_TYPE_Q8_K,
1080
+ .nrows = 1,
1081
+ },
1031
1082
  };
1032
1083
 
1033
1084
  // For internal test use
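
The hunk above registers the new ternary quantization types TQ1_0 and TQ2_0 in the type-traits table, wiring up their block size (QK_K), dequantization, quantization, and Q8_K dot-product kernels. A minimal sketch of how the traits table is typically consulted, assuming the lm_-prefixed equivalent of ggml's internal type-traits accessor; the buffer names are illustrative:

// Illustrative: dequantize one TQ2_0 block through the traits table.
const lm_ggml_type_traits_t traits = lm_ggml_internal_get_type_traits(LM_GGML_TYPE_TQ2_0);
float out[QK_K];                         // traits.blck_size == QK_K for TQ2_0
traits.to_float(block_ptr, out, QK_K);   // dispatches to dequantize_row_tq2_0
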
@@ -1868,28 +1919,102 @@ struct lm_ggml_context_container {
1868
1919
  struct lm_ggml_context context;
1869
1920
  };
1870
1921
 
1871
- struct lm_ggml_compute_state_shared {
1872
- const struct lm_ggml_cgraph * cgraph;
1873
- const struct lm_ggml_cplan * cplan;
1922
+ //
1923
+ // Threading defs
1924
+ //
1925
+
1926
+ typedef pthread_t lm_ggml_thread_t;
1927
+
1928
+ #if defined(_WIN32)
1929
+
1930
+ typedef CONDITION_VARIABLE lm_ggml_cond_t;
1931
+ typedef SRWLOCK lm_ggml_mutex_t;
1932
+
1933
+ #define lm_ggml_mutex_init(m) InitializeSRWLock(m)
1934
+ #define lm_ggml_mutex_destroy(m)
1935
+ #define lm_ggml_mutex_lock(m) AcquireSRWLockExclusive(m)
1936
+ #define lm_ggml_mutex_unlock(m) ReleaseSRWLockExclusive(m)
1937
+ #define lm_ggml_mutex_lock_shared(m) AcquireSRWLockShared(m)
1938
+ #define lm_ggml_mutex_unlock_shared(m) ReleaseSRWLockShared(m)
1939
+
1940
+ #define lm_ggml_cond_init(c) InitializeConditionVariable(c)
1941
+ #define lm_ggml_cond_destroy(c)
1942
+ #define lm_ggml_cond_wait(c, m) SleepConditionVariableSRW(c, m, INFINITE, CONDITION_VARIABLE_LOCKMODE_SHARED)
1943
+ #define lm_ggml_cond_broadcast(c) WakeAllConditionVariable(c)
1944
+
1945
+ #define lm_ggml_thread_create pthread_create
1946
+ #define lm_ggml_thread_join pthread_join
1947
+
1948
+ #else
1949
+
1950
+ typedef pthread_cond_t lm_ggml_cond_t;
1951
+ typedef pthread_mutex_t lm_ggml_mutex_t;
1952
+
1953
+ #define lm_ggml_mutex_init(m) pthread_mutex_init(m, NULL)
1954
+ #define lm_ggml_mutex_destroy(m) pthread_mutex_destroy(m)
1955
+ #define lm_ggml_mutex_lock(m) pthread_mutex_lock(m)
1956
+ #define lm_ggml_mutex_unlock(m) pthread_mutex_unlock(m)
1957
+ #define lm_ggml_mutex_lock_shared(m) pthread_mutex_lock(m)
1958
+ #define lm_ggml_mutex_unlock_shared(m) pthread_mutex_unlock(m)
1959
+
1960
+ #define lm_ggml_lock_init(x) UNUSED(x)
1961
+ #define lm_ggml_lock_destroy(x) UNUSED(x)
1962
+ #if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64))
1963
+ #define lm_ggml_lock_lock(x) _mm_pause()
1964
+ #else
1965
+ #define lm_ggml_lock_lock(x) UNUSED(x)
1966
+ #endif
1967
+ #define lm_ggml_lock_unlock(x) UNUSED(x)
1968
+
1969
+ #define LM_GGML_LOCK_INITIALIZER 0
1970
+ #define lm_ggml_cond_init(c) pthread_cond_init(c, NULL)
1971
+ #define lm_ggml_cond_destroy(c) pthread_cond_destroy(c)
1972
+ #define lm_ggml_cond_wait(c, m) pthread_cond_wait(c, m)
1973
+ #define lm_ggml_cond_broadcast(c) pthread_cond_broadcast(c)
1974
+
1975
+ #define lm_ggml_thread_create pthread_create
1976
+ #define lm_ggml_thread_join pthread_join
1977
+
1978
+ #endif
1979
+
1980
+ // Threadpool def
1981
+ struct lm_ggml_threadpool {
1982
+ lm_ggml_mutex_t mutex; // mutex for cond.var
1983
+ lm_ggml_cond_t cond; // cond.var for waiting for new work
1874
1984
 
1875
- int n_threads;
1985
+ struct lm_ggml_cgraph * cgraph;
1986
+ struct lm_ggml_cplan * cplan;
1876
1987
 
1877
1988
  // synchronization primitives
1989
+ atomic_int n_graph; // incremented when there is work to be done (i.e each graph)
1878
1990
  atomic_int n_barrier;
1879
1991
  atomic_int n_barrier_passed;
1992
+ atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
1880
1993
 
1881
- lm_ggml_abort_callback abort_callback; // abort lm_ggml_graph_compute when true
1882
- void * abort_callback_data;
1994
+ // these are atomic as an annotation for thread-sanitizer
1995
+ atomic_bool stop; // Used for stopping the threadpool altogether
1996
+ atomic_bool pause; // Used for pausing the threadpool or individual threads
1883
1997
 
1884
- atomic_int current_chunk; // currently processing chunk during mul_mat, shared between all the threads
1998
+ struct lm_ggml_compute_state * workers; // per thread state
1999
+ int n_threads_max; // number of threads in the pool
2000
+ int n_threads_cur; // number of threads used in the current graph
2001
+
2002
+ int32_t prio; // Scheduling priority
2003
+ uint32_t poll; // Polling level (0 - no polling)
1885
2004
 
1886
2005
  enum lm_ggml_status ec;
1887
2006
  };
1888
2007
 
2008
+ // Per-thread state
1889
2009
  struct lm_ggml_compute_state {
2010
+ #ifndef LM_GGML_USE_OPENMP
1890
2011
  lm_ggml_thread_t thrd;
2012
+ bool cpumask[LM_GGML_MAX_N_THREADS];
2013
+ int last_graph;
2014
+ bool pending;
2015
+ #endif
2016
+ struct lm_ggml_threadpool * threadpool;
1891
2017
  int ith;
1892
- struct lm_ggml_compute_state_shared * shared;
1893
2018
  };
1894
2019
 
1895
2020
  struct lm_ggml_compute_params {
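
The large hunk above replaces lm_ggml_compute_state_shared with a persistent lm_ggml_threadpool plus portable mutex/condition-variable macros (SRWLOCK/CONDITION_VARIABLE on Windows, pthreads elsewhere) and per-thread state (CPU mask, last_graph, pending). A rough sketch of the wake-up pattern these macros support when new work is posted; this is illustrative only, not the package's exact kickoff logic:

// Illustrative: notify sleeping workers that a new graph is ready.
struct lm_ggml_threadpool * tp = /* ... */;
lm_ggml_mutex_lock(&tp->mutex);
atomic_fetch_add_explicit(&tp->n_graph, 1, memory_order_relaxed); // bump the work counter
lm_ggml_cond_broadcast(&tp->cond);                                // wake all waiting workers
lm_ggml_mutex_unlock(&tp->mutex);
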
@@ -1900,7 +2025,7 @@ struct lm_ggml_compute_params {
1900
2025
  size_t wsize;
1901
2026
  void * wdata;
1902
2027
 
1903
- struct lm_ggml_compute_state_shared * shared;
2028
+ struct lm_ggml_threadpool * threadpool;
1904
2029
  };
1905
2030
 
1906
2031
  //
@@ -2310,7 +2435,9 @@ inline static void lm_ggml_vec_scale_f16(const int n, lm_ggml_fp16_t * y, const
2310
2435
  inline static void lm_ggml_vec_norm_f32 (const int n, float * s, const float * x) { lm_ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
2311
2436
  inline static void lm_ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
2312
2437
  inline static void lm_ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
2313
- inline static void lm_ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
2438
+ inline static void lm_ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
2439
+ inline static void lm_ggml_vec_sin_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]); }
2440
+ inline static void lm_ggml_vec_cos_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]); }
2314
2441
  inline static void lm_ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
2315
2442
  inline static void lm_ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
2316
2443
  inline static void lm_ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
@@ -2322,6 +2449,7 @@ inline static void lm_ggml_vec_sigmoid_f32 (const int n, float * y, const float
2322
2449
  // TODO: optimize performance
2323
2450
  inline static void lm_ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
2324
2451
  inline static void lm_ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
2452
+ inline static void lm_ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
2325
2453
 
2326
2454
  static const float GELU_COEF_A = 0.044715f;
2327
2455
  static const float GELU_QUICK_COEF = -1.702f;
@@ -2669,6 +2797,19 @@ static lm_ggml_float lm_ggml_vec_soft_max_f32(const int n, float * y, const floa
2669
2797
  return sum;
2670
2798
  }
2671
2799
 
2800
+ static lm_ggml_float lm_ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
2801
+ // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_i)
2802
+
2803
+ int i = 0;
2804
+ lm_ggml_float sum = 0;
2805
+ for (; i < n; ++i) {
2806
+ float val = x[i] - max;
2807
+ y[i] = val;
2808
+ sum += (lm_ggml_float)expf(val);
2809
+ }
2810
+ return sum = (lm_ggml_float)logf(sum);
2811
+ }
2812
+
2672
2813
  inline static float lm_ggml_silu_backward_f32(float x, float dy) {
2673
2814
  const float s = 1.0f/(1.0f + expf(-x));
2674
2815
  return dy*s*(1.0f + x*(1.0f - s));
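
For reference, the lm_ggml_vec_log_soft_max_f32 helper added above relies on the standard identity log(softmax_i(x)) = (x_i - max_j x_j) - log(sum_j exp(x_j - max_j x_j)): it writes x_i - max into y[i] and returns the log of the accumulated exponential sum, which the caller is expected to subtract from each y[i] to obtain the log-probabilities.
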
@@ -2760,6 +2901,8 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = {
2760
2901
  "SQR",
2761
2902
  "SQRT",
2762
2903
  "LOG",
2904
+ "SIN",
2905
+ "COS",
2763
2906
  "SUM",
2764
2907
  "SUM_ROWS",
2765
2908
  "MEAN",
@@ -2797,9 +2940,11 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = {
2797
2940
  "CLAMP",
2798
2941
  "CONV_TRANSPOSE_1D",
2799
2942
  "IM2COL",
2943
+ "IM2COL_BACK",
2800
2944
  "CONV_TRANSPOSE_2D",
2801
2945
  "POOL_1D",
2802
2946
  "POOL_2D",
2947
+ "POOL_2D_BACK",
2803
2948
  "UPSCALE",
2804
2949
  "PAD",
2805
2950
  "ARANGE",
@@ -2815,6 +2960,7 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = {
2815
2960
  "WIN_UNPART",
2816
2961
  "GET_REL_POS",
2817
2962
  "ADD_REL_POS",
2963
+ "RWKV_WKV",
2818
2964
 
2819
2965
  "UNARY",
2820
2966
 
@@ -2833,7 +2979,7 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = {
2833
2979
  "CROSS_ENTROPY_LOSS_BACK",
2834
2980
  };
2835
2981
 
2836
- static_assert(LM_GGML_OP_COUNT == 74, "LM_GGML_OP_COUNT != 74");
2982
+ static_assert(LM_GGML_OP_COUNT == 79, "LM_GGML_OP_COUNT != 79");
2837
2983
 
2838
2984
  static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
2839
2985
  "none",
@@ -2848,6 +2994,8 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
2848
2994
  "x^2",
2849
2995
  "√x",
2850
2996
  "log(x)",
2997
+ "sin(x)",
2998
+ "cos(x)",
2851
2999
  "Σx",
2852
3000
  "Σx_k",
2853
3001
  "Σx/n",
@@ -2885,9 +3033,11 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
2885
3033
  "clamp(x)",
2886
3034
  "conv_transpose_1d(x)",
2887
3035
  "im2col(x)",
3036
+ "im2col_back(x)",
2888
3037
  "conv_transpose_2d(x)",
2889
3038
  "pool_1d(x)",
2890
3039
  "pool_2d(x)",
3040
+ "pool_2d_back(x)",
2891
3041
  "upscale(x)",
2892
3042
  "pad(x)",
2893
3043
  "arange(start, stop, step)",
@@ -2903,6 +3053,7 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
2903
3053
  "win_unpart(x)",
2904
3054
  "get_rel_pos(x)",
2905
3055
  "add_rel_pos(x)",
3056
+ "rwkv_wkv(k, v, r, tf, td, s)",
2906
3057
 
2907
3058
  "unary(x)",
2908
3059
 
@@ -2921,7 +3072,7 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
2921
3072
  "cross_entropy_loss_back(x,y)",
2922
3073
  };
2923
3074
 
2924
- static_assert(LM_GGML_OP_COUNT == 74, "LM_GGML_OP_COUNT != 74");
3075
+ static_assert(LM_GGML_OP_COUNT == 79, "LM_GGML_OP_COUNT != 79");
2925
3076
 
2926
3077
  static_assert(LM_GGML_OP_POOL_COUNT == 2, "LM_GGML_OP_POOL_COUNT != 2");
2927
3078
 
@@ -2940,14 +3091,28 @@ static const char * LM_GGML_UNARY_OP_NAME[LM_GGML_UNARY_OP_COUNT] = {
2940
3091
  "SILU",
2941
3092
  "HARDSWISH",
2942
3093
  "HARDSIGMOID",
3094
+ "EXP",
2943
3095
  };
2944
3096
 
2945
- static_assert(LM_GGML_UNARY_OP_COUNT == 13, "LM_GGML_UNARY_OP_COUNT != 13");
3097
+ static_assert(LM_GGML_UNARY_OP_COUNT == 14, "LM_GGML_UNARY_OP_COUNT != 14");
2946
3098
 
2947
3099
 
2948
3100
  static_assert(sizeof(struct lm_ggml_object)%LM_GGML_MEM_ALIGN == 0, "lm_ggml_object size must be a multiple of LM_GGML_MEM_ALIGN");
2949
3101
  static_assert(sizeof(struct lm_ggml_tensor)%LM_GGML_MEM_ALIGN == 0, "lm_ggml_tensor size must be a multiple of LM_GGML_MEM_ALIGN");
2950
3102
 
3103
+ // Helpers for polling loops
3104
+ #if defined(__aarch64__) && ( defined(__clang__) || defined(__GNUC__) )
3105
+ static inline void lm_ggml_thread_cpu_relax(void) {
3106
+ __asm__ volatile("yield" ::: "memory");
3107
+ }
3108
+ #elif defined(__x86_64__)
3109
+ static inline void lm_ggml_thread_cpu_relax(void) {
3110
+ _mm_pause();
3111
+ }
3112
+ #else
3113
+ static inline void lm_ggml_thread_cpu_relax(void) {;}
3114
+ #endif
3115
+
2951
3116
  //
2952
3117
  // NUMA support
2953
3118
  //
@@ -2995,42 +3160,36 @@ inline static void lm_ggml_critical_section_start(void) {
2995
3160
  }
2996
3161
 
2997
3162
  #ifdef LM_GGML_USE_OPENMP
2998
- static void lm_ggml_barrier(struct lm_ggml_compute_state_shared * shared) {
2999
- if (shared->n_threads == 1) {
3163
+ static void lm_ggml_barrier(struct lm_ggml_threadpool * threadpool) {
3164
+ if (threadpool->n_threads_cur == 1) {
3000
3165
  return;
3001
3166
  }
3002
3167
 
3003
3168
  #pragma omp barrier
3004
3169
  }
3005
3170
  #else
3006
- static void lm_ggml_barrier(struct lm_ggml_compute_state_shared * shared) {
3007
- if (shared->n_threads == 1) {
3171
+ static void lm_ggml_barrier(struct lm_ggml_threadpool * threadpool) {
3172
+ if (threadpool->n_threads_cur == 1) {
3008
3173
  return;
3009
3174
  }
3010
3175
 
3011
- atomic_int * n_barrier = &shared->n_barrier;
3012
- atomic_int * n_barrier_passed = &shared->n_barrier_passed;
3176
+ atomic_int * n_barrier = &threadpool->n_barrier;
3177
+ atomic_int * n_barrier_passed = &threadpool->n_barrier_passed;
3013
3178
 
3014
- int n_threads = shared->n_threads;
3015
- int passed_old = atomic_load(n_barrier_passed);
3179
+ int n_threads = threadpool->n_threads_cur;
3180
+ int passed_old = atomic_load_explicit(n_barrier_passed, memory_order_relaxed);
3016
3181
 
3017
3182
  if (atomic_fetch_add(n_barrier, 1) == n_threads - 1) {
3018
3183
  // last thread
3019
3184
  atomic_store(n_barrier, 0);
3020
- atomic_fetch_add(n_barrier_passed, 1);
3185
+ atomic_fetch_add_explicit(n_barrier_passed, 1, memory_order_relaxed);
3021
3186
  } else {
3022
3187
  // wait for other threads
3023
- const int n_spin_before_sleep = 100000;
3024
3188
  while (true) {
3025
- for (int i = 0; i < n_spin_before_sleep; i++) {
3026
- if (atomic_load(n_barrier_passed) != passed_old) {
3027
- return;
3028
- }
3029
- #if defined(__SSE3__)
3030
- _mm_pause();
3031
- #endif
3189
+ if (atomic_load_explicit(n_barrier_passed, memory_order_relaxed) != passed_old) {
3190
+ return;
3032
3191
  }
3033
- sched_yield();
3192
+ lm_ggml_thread_cpu_relax();
3034
3193
  }
3035
3194
  }
3036
3195
  }
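
The reworked lm_ggml_barrier above spins on n_barrier_passed with relaxed atomics and the new lm_ggml_thread_cpu_relax() (yield / _mm_pause) instead of the previous spin-then-sched_yield loop, and it now takes the threadpool rather than the old shared state. A rough sketch of how such a per-op barrier is used by the CPU workers; the surrounding function name is assumed from context:

// Illustrative: every worker computes its slice of a node, then waits for the others.
lm_ggml_compute_forward(&params, node);   // params.ith / params.nth pick this thread's slice
lm_ggml_barrier(params.threadpool);       // all writes to node are visible before the next node
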
@@ -3767,6 +3926,7 @@ static struct lm_ggml_tensor * lm_ggml_new_tensor_impl(
3767
3926
  }
3768
3927
 
3769
3928
  struct lm_ggml_object * const obj_new = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_TENSOR, LM_GGML_TENSOR_SIZE + obj_alloc_size);
3929
+ LM_GGML_ASSERT(obj_new);
3770
3930
 
3771
3931
  // TODO: for recoverable errors, we would need to free the data allocated from the scratch buffer here
3772
3932
 
@@ -4486,8 +4646,6 @@ static struct lm_ggml_tensor * lm_ggml_add_impl(
4486
4646
  bool is_node = false;
4487
4647
 
4488
4648
  if (!inplace && (a->grad || b->grad)) {
4489
- // TODO: support backward pass for broadcasting
4490
- LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b));
4491
4649
  is_node = true;
4492
4650
  }
4493
4651
 
@@ -4661,11 +4819,13 @@ static struct lm_ggml_tensor * lm_ggml_sub_impl(
4661
4819
  struct lm_ggml_tensor * a,
4662
4820
  struct lm_ggml_tensor * b,
4663
4821
  bool inplace) {
4664
- LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b));
4822
+ LM_GGML_ASSERT(lm_ggml_can_repeat(b, a));
4665
4823
 
4666
4824
  bool is_node = false;
4667
4825
 
4668
4826
  if (!inplace && (a->grad || b->grad)) {
4827
+ // TODO: support backward pass for broadcasting
4828
+ LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b));
4669
4829
  is_node = true;
4670
4830
  }
4671
4831
 
@@ -4880,6 +5040,72 @@ struct lm_ggml_tensor * lm_ggml_log_inplace(
4880
5040
  return lm_ggml_log_impl(ctx, a, true);
4881
5041
  }
4882
5042
 
5043
+ // lm_ggml_sin
5044
+
5045
+ static struct lm_ggml_tensor * lm_ggml_sin_impl(
5046
+ struct lm_ggml_context * ctx,
5047
+ struct lm_ggml_tensor * a,
5048
+ bool inplace) {
5049
+ bool is_node = false;
5050
+
5051
+ if (!inplace && (a->grad)) {
5052
+ is_node = true;
5053
+ }
5054
+
5055
+ struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
5056
+
5057
+ result->op = LM_GGML_OP_SIN;
5058
+ result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5059
+ result->src[0] = a;
5060
+
5061
+ return result;
5062
+ }
5063
+
5064
+ struct lm_ggml_tensor * lm_ggml_sin(
5065
+ struct lm_ggml_context * ctx,
5066
+ struct lm_ggml_tensor * a) {
5067
+ return lm_ggml_sin_impl(ctx, a, false);
5068
+ }
5069
+
5070
+ struct lm_ggml_tensor * lm_ggml_sin_inplace(
5071
+ struct lm_ggml_context * ctx,
5072
+ struct lm_ggml_tensor * a) {
5073
+ return lm_ggml_sin_impl(ctx, a, true);
5074
+ }
5075
+
5076
+ // lm_ggml_cos
5077
+
5078
+ static struct lm_ggml_tensor * lm_ggml_cos_impl(
5079
+ struct lm_ggml_context * ctx,
5080
+ struct lm_ggml_tensor * a,
5081
+ bool inplace) {
5082
+ bool is_node = false;
5083
+
5084
+ if (!inplace && (a->grad)) {
5085
+ is_node = true;
5086
+ }
5087
+
5088
+ struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
5089
+
5090
+ result->op = LM_GGML_OP_COS;
5091
+ result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5092
+ result->src[0] = a;
5093
+
5094
+ return result;
5095
+ }
5096
+
5097
+ struct lm_ggml_tensor * lm_ggml_cos(
5098
+ struct lm_ggml_context * ctx,
5099
+ struct lm_ggml_tensor * a) {
5100
+ return lm_ggml_cos_impl(ctx, a, false);
5101
+ }
5102
+
5103
+ struct lm_ggml_tensor * lm_ggml_cos_inplace(
5104
+ struct lm_ggml_context * ctx,
5105
+ struct lm_ggml_tensor * a) {
5106
+ return lm_ggml_cos_impl(ctx, a, true);
5107
+ }
5108
+
4883
5109
  // lm_ggml_sum
4884
5110
 
4885
5111
  struct lm_ggml_tensor * lm_ggml_sum(
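
The hunk above adds element-wise lm_ggml_sin / lm_ggml_cos graph ops (with in-place variants). A minimal usage sketch, assuming the usual lm_-prefixed ggml context and graph helpers (lm_ggml_init, lm_ggml_new_tensor_1d, lm_ggml_new_graph, lm_ggml_build_forward_expand):

// Illustrative usage of the new ops (error handling omitted).
struct lm_ggml_init_params ip = { .mem_size = 16*1024*1024, .mem_buffer = NULL, .no_alloc = false };
struct lm_ggml_context * ctx = lm_ggml_init(ip);
struct lm_ggml_tensor  * x   = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, 8);
struct lm_ggml_tensor  * y   = lm_ggml_sin(ctx, x);   // y[i] = sinf(x[i])
struct lm_ggml_cgraph  * gf  = lm_ggml_new_graph(ctx);
lm_ggml_build_forward_expand(gf, y);
// fill x->data, then run gf with the CPU backend as usual
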
@@ -5041,6 +5267,7 @@ struct lm_ggml_tensor * lm_ggml_concat(
5041
5267
  bool is_node = false;
5042
5268
 
5043
5269
  if (a->grad || b->grad) {
5270
+ LM_GGML_ABORT("fatal error"); // TODO: implement
5044
5271
  is_node = true;
5045
5272
  }
5046
5273
 
@@ -5162,6 +5389,7 @@ struct lm_ggml_tensor * lm_ggml_leaky_relu(
5162
5389
  bool is_node = false;
5163
5390
 
5164
5391
  if (!inplace && (a->grad)) {
5392
+ LM_GGML_ABORT("fatal error"); // TODO: not implemented
5165
5393
  is_node = true;
5166
5394
  }
5167
5395
 
@@ -5269,6 +5497,19 @@ struct lm_ggml_tensor * lm_ggml_hardsigmoid(
5269
5497
  return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_HARDSIGMOID);
5270
5498
  }
5271
5499
 
5500
+ // ggml exp
5501
+ struct lm_ggml_tensor * lm_ggml_exp(
5502
+ struct lm_ggml_context * ctx,
5503
+ struct lm_ggml_tensor * a) {
5504
+ return lm_ggml_unary(ctx, a, LM_GGML_UNARY_OP_EXP);
5505
+ }
5506
+
5507
+ struct lm_ggml_tensor * lm_ggml_exp_inplace(
5508
+ struct lm_ggml_context * ctx,
5509
+ struct lm_ggml_tensor * a) {
5510
+ return lm_ggml_unary_inplace(ctx, a, LM_GGML_UNARY_OP_EXP);
5511
+ }
5512
+
5272
5513
  // lm_ggml_norm
5273
5514
 
5274
5515
  static struct lm_ggml_tensor * lm_ggml_norm_impl(
@@ -5587,6 +5828,7 @@ static struct lm_ggml_tensor * lm_ggml_set_impl(
5587
5828
  // make a view of the destination
5588
5829
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
5589
5830
 
5831
+ LM_GGML_ASSERT(offset < (size_t)(1 << 30));
5590
5832
  int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
5591
5833
  lm_ggml_set_op_params(result, params, sizeof(params));
5592
5834
 
@@ -6544,14 +6786,12 @@ struct lm_ggml_tensor * lm_ggml_rope_back(
6544
6786
  LM_GGML_ASSERT(lm_ggml_is_vector(b));
6545
6787
  LM_GGML_ASSERT(b->type == LM_GGML_TYPE_I32);
6546
6788
  LM_GGML_ASSERT(a->ne[2] == b->ne[0]);
6547
- LM_GGML_ASSERT(c == NULL && "freq factors not implemented yet");
6548
-
6549
- LM_GGML_ASSERT((mode & 4) == 0 && "lm_ggml_rope_back() for ChatGLM not implemented yet");
6550
6789
 
6551
6790
  bool is_node = false;
6552
6791
 
6553
6792
  if (a->grad) {
6554
- is_node = false; // TODO: implement backward
6793
+ LM_GGML_ASSERT(false && "backwards pass not implemented");
6794
+ is_node = false;
6555
6795
  }
6556
6796
 
6557
6797
  struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a);
@@ -6569,6 +6809,7 @@ struct lm_ggml_tensor * lm_ggml_rope_back(
6569
6809
  result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6570
6810
  result->src[0] = a;
6571
6811
  result->src[1] = b;
6812
+ result->src[2] = c;
6572
6813
 
6573
6814
  return result;
6574
6815
  }
@@ -6727,17 +6968,20 @@ struct lm_ggml_tensor * lm_ggml_im2col(
6727
6968
  LM_GGML_ASSERT(a->ne[2] == b->ne[2]);
6728
6969
  } else {
6729
6970
  LM_GGML_ASSERT(a->ne[1] == b->ne[1]);
6971
+ LM_GGML_ASSERT(b->ne[3] == 1);
6730
6972
  }
6731
6973
  bool is_node = false;
6732
6974
 
6733
- if (a->grad || b->grad) {
6734
- LM_GGML_ABORT("fatal error"); // TODO: implement backward
6975
+ if (/*a->grad ||*/ b->grad) { // a is only used for its shape, not its data
6735
6976
  is_node = true;
6736
6977
  }
6737
6978
 
6738
6979
  const int64_t OH = is_2D ? lm_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
6739
6980
  const int64_t OW = lm_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
6740
6981
 
6982
+ LM_GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a");
6983
+ LM_GGML_ASSERT((OW > 0) && "b too small compared to a");
6984
+
6741
6985
  const int64_t ne[4] = {
6742
6986
  is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
6743
6987
  OW,
@@ -6757,6 +7001,37 @@ struct lm_ggml_tensor * lm_ggml_im2col(
6757
7001
  return result;
6758
7002
  }
6759
7003
 
7004
+ struct lm_ggml_tensor * lm_ggml_im2col_back(
7005
+ struct lm_ggml_context * ctx,
7006
+ struct lm_ggml_tensor * a,
7007
+ struct lm_ggml_tensor * b,
7008
+ int64_t * ne,
7009
+ int s0,
7010
+ int s1,
7011
+ int p0,
7012
+ int p1,
7013
+ int d0,
7014
+ int d1,
7015
+ bool is_2D) {
7016
+
7017
+ bool is_node = false;
7018
+
7019
+ if (/*a->grad ||*/ b->grad) { // a is only used for its shape, not its data
7020
+ is_node = true;
7021
+ }
7022
+
7023
+ struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne);
7024
+ int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
7025
+ lm_ggml_set_op_params(result, params, sizeof(params));
7026
+
7027
+ result->op = LM_GGML_OP_IM2COL_BACK;
7028
+ result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7029
+ result->src[0] = a;
7030
+ result->src[1] = b;
7031
+
7032
+ return result;
7033
+ }
7034
+
6760
7035
  // a: [OC,IC, KH, KW]
6761
7036
  // b: [N, IC, IH, IW]
6762
7037
  // result: [N, OC, OH, OW]
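
The new asserts in lm_ggml_im2col above reject kernels that are larger than the (padded) input, i.e. cases where the computed output extent would be non-positive. For reference, the conventional dilated-convolution output size used by lm_ggml_calc_conv_output_size is:

// OW = (IW + 2*p0 - d0*(KW - 1) - 1)/s0 + 1
// example: IW = 8, KW = 3, s0 = 2, p0 = 1, d0 = 1  ->  OW = (8 + 2 - 2 - 1)/2 + 1 = 4
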
@@ -6770,7 +7045,7 @@ struct lm_ggml_tensor * lm_ggml_conv_2d(
6770
7045
  int p1,
6771
7046
  int d0,
6772
7047
  int d1) {
6773
- struct lm_ggml_tensor * im2col = lm_ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, LM_GGML_TYPE_F16); // [N, OH, OW, IC * KH * KW]
7048
+ struct lm_ggml_tensor * im2col = lm_ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW]
6774
7049
 
6775
7050
  struct lm_ggml_tensor * result =
6776
7051
  lm_ggml_mul_mat(ctx,
@@ -6896,17 +7171,17 @@ struct lm_ggml_tensor * lm_ggml_pool_2d(
6896
7171
  bool is_node = false;
6897
7172
 
6898
7173
  if (a->grad) {
6899
- LM_GGML_ABORT("fatal error"); // TODO: implement backward
6900
7174
  is_node = true;
6901
7175
  }
6902
7176
 
6903
7177
  struct lm_ggml_tensor * result;
6904
- const int64_t ne[3] = {
7178
+ const int64_t ne[4] = {
6905
7179
  lm_ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
6906
7180
  lm_ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
6907
7181
  a->ne[2],
7182
+ a->ne[3],
6908
7183
  };
6909
- result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 3, ne);
7184
+ result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne);
6910
7185
 
6911
7186
  int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
6912
7187
  lm_ggml_set_op_params(result, params, sizeof(params));
@@ -6917,6 +7192,37 @@ struct lm_ggml_tensor * lm_ggml_pool_2d(
6917
7192
  return result;
6918
7193
  }
6919
7194
 
7195
+ struct lm_ggml_tensor * lm_ggml_pool_2d_back(
7196
+ struct lm_ggml_context * ctx,
7197
+ struct lm_ggml_tensor * a,
7198
+ struct lm_ggml_tensor * af,
7199
+ enum lm_ggml_op_pool op,
7200
+ int k0,
7201
+ int k1,
7202
+ int s0,
7203
+ int s1,
7204
+ float p0,
7205
+ float p1) {
7206
+
7207
+ bool is_node = false;
7208
+
7209
+ if (a->grad) {
7210
+ is_node = true;
7211
+ }
7212
+
7213
+ struct lm_ggml_tensor * result;
7214
+ result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, af->ne);
7215
+
7216
+ int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
7217
+ lm_ggml_set_op_params(result, params, sizeof(params));
7218
+
7219
+ result->op = LM_GGML_OP_POOL_2D_BACK;
7220
+ result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7221
+ result->src[0] = a;
7222
+ result->src[1] = af;
7223
+ return result;
7224
+ }
7225
+
6920
7226
  // lm_ggml_upscale
6921
7227
 
6922
7228
  static struct lm_ggml_tensor * lm_ggml_upscale_impl(
@@ -7057,6 +7363,11 @@ struct lm_ggml_tensor * lm_ggml_argsort(
7057
7363
  enum lm_ggml_sort_order order) {
7058
7364
  bool is_node = false;
7059
7365
 
7366
+ if (a->grad) {
7367
+ LM_GGML_ABORT("fatal error"); // TODO: not implemented
7368
+ is_node = true;
7369
+ }
7370
+
7060
7371
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_I32, LM_GGML_MAX_DIMS, a->ne);
7061
7372
 
7062
7373
  lm_ggml_set_op_params_i32(result, 0, (int32_t) order);
@@ -7467,6 +7778,59 @@ struct lm_ggml_tensor * lm_ggml_add_rel_pos_inplace(
7467
7778
  return lm_ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
7468
7779
  }
7469
7780
 
7781
+ // lm_ggml_rwkv_wkv
7782
+
7783
+ struct lm_ggml_tensor * lm_ggml_rwkv_wkv(
7784
+ struct lm_ggml_context * ctx,
7785
+ struct lm_ggml_tensor * k,
7786
+ struct lm_ggml_tensor * v,
7787
+ struct lm_ggml_tensor * r,
7788
+ struct lm_ggml_tensor * tf,
7789
+ struct lm_ggml_tensor * td,
7790
+ struct lm_ggml_tensor * state) {
7791
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(k));
7792
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(v));
7793
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(r));
7794
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(tf));
7795
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(td));
7796
+ LM_GGML_ASSERT(lm_ggml_is_contiguous(state));
7797
+
7798
+ const int64_t S = k->ne[0];
7799
+ const int64_t H = k->ne[2];
7800
+ const int64_t n_tokens = k->ne[3];
7801
+ const int64_t n_seqs = state->ne[1];
7802
+ {
7803
+ LM_GGML_ASSERT(k->ne[1] == 1);
7804
+ LM_GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
7805
+ LM_GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
7806
+ // TODO: RWKV v4 and v5
7807
+ LM_GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
7808
+ LM_GGML_ASSERT(lm_ggml_nelements(state) == S * S * H * n_seqs);
7809
+ }
7810
+
7811
+ bool is_node = false;
7812
+
7813
+ if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
7814
+ LM_GGML_ABORT("fatal error"); // TODO: implement backward
7815
+ is_node = true;
7816
+ }
7817
+
7818
+ // concat output and new_state
7819
+ const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
7820
+ struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne);
7821
+
7822
+ result->op = LM_GGML_OP_RWKV_WKV;
7823
+ result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7824
+ result->src[0] = k;
7825
+ result->src[1] = v;
7826
+ result->src[2] = r;
7827
+ result->src[3] = tf;
7828
+ result->src[4] = td;
7829
+ result->src[5] = state;
7830
+
7831
+ return result;
7832
+ }
7833
+
7470
7834
  // lm_ggml_unary
7471
7835
 
7472
7836
  static struct lm_ggml_tensor * lm_ggml_unary_impl(
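
The new lm_ggml_rwkv_wkv op above packs its two results into one F32 tensor of shape { S*H, n_tokens + S*n_seqs }: the first n_tokens rows hold the per-token output and the remaining S*n_seqs rows hold the updated state. A minimal sketch of splitting them back apart, assuming the usual lm_-prefixed view helper:

// Illustrative: separate the token output from the new state.
struct lm_ggml_tensor * wkv = lm_ggml_rwkv_wkv(ctx, k, v, r, tf, td, state);
struct lm_ggml_tensor * out = lm_ggml_view_2d(ctx, wkv, S*H, n_tokens, wkv->nb[1], 0);
struct lm_ggml_tensor * new_state =
    lm_ggml_view_2d(ctx, wkv, S*H, S*n_seqs, wkv->nb[1], n_tokens*wkv->nb[1]);
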
@@ -7965,8 +8329,7 @@ static void lm_ggml_compute_forward_dup_same_cont(
7965
8329
  LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0));
7966
8330
  LM_GGML_ASSERT(src0->type == dst->type);
7967
8331
 
7968
- const size_t nb00 = src0->nb[0];
7969
- const size_t nb0 = dst->nb[0];
8332
+ const size_t nb0 = lm_ggml_type_size(src0->type);
7970
8333
 
7971
8334
  const int ith = params->ith; // thread index
7972
8335
  const int nth = params->nth; // number of threads
@@ -7980,8 +8343,8 @@ static void lm_ggml_compute_forward_dup_same_cont(
7980
8343
  if (ie0 < ie1) {
7981
8344
  memcpy(
7982
8345
  ((char *) dst->data + ie0*nb0),
7983
- ((char *) src0->data + ie0*nb00),
7984
- (ie1 - ie0) * lm_ggml_type_size(src0->type));
8346
+ ((char *) src0->data + ie0*nb0),
8347
+ (ie1 - ie0) * nb0);
7985
8348
  }
7986
8349
  }
7987
8350
 
@@ -7998,11 +8361,6 @@ static void lm_ggml_compute_forward_dup_f16(
7998
8361
  const int ith = params->ith; // thread index
7999
8362
  const int nth = params->nth; // number of threads
8000
8363
 
8001
- if (lm_ggml_is_contiguous(src0) && lm_ggml_is_contiguous(dst) && src0->type == dst->type) {
8002
- lm_ggml_compute_forward_dup_same_cont(params, dst);
8003
- return;
8004
- }
8005
-
8006
8364
  // parallelize by rows
8007
8365
  const int nr = ne01;
8008
8366
  // number of rows per thread
@@ -8267,11 +8625,6 @@ static void lm_ggml_compute_forward_dup_bf16(
8267
8625
  const int ith = params->ith; // thread index
8268
8626
  const int nth = params->nth; // number of threads
8269
8627
 
8270
- if (lm_ggml_is_contiguous(src0) && lm_ggml_is_contiguous(dst) && src0->type == dst->type) {
8271
- lm_ggml_compute_forward_dup_same_cont(params, dst);
8272
- return;
8273
- }
8274
-
8275
8628
  // parallelize by rows
8276
8629
  const int nr = ne01;
8277
8630
  // number of rows per thread
@@ -8623,11 +8976,6 @@ static void lm_ggml_compute_forward_dup_f32(
8623
8976
  const int ith = params->ith; // thread index
8624
8977
  const int nth = params->nth; // number of threads
8625
8978
 
8626
- if (lm_ggml_is_contiguous(src0) && lm_ggml_is_contiguous(dst) && src0->type == dst->type) {
8627
- lm_ggml_compute_forward_dup_same_cont(params, dst);
8628
- return;
8629
- }
8630
-
8631
8979
  // parallelize by rows
8632
8980
  const int nr = ne01;
8633
8981
  // number of rows per thread
@@ -8937,13 +9285,13 @@ static void lm_ggml_compute_forward_dup_bytes(
8937
9285
  LM_GGML_ASSERT(lm_ggml_nelements(dst) == lm_ggml_nelements(src0));
8938
9286
  LM_GGML_ASSERT(src0->type == dst->type);
8939
9287
 
9288
+ LM_GGML_TENSOR_UNARY_OP_LOCALS;
9289
+
8940
9290
  if (lm_ggml_is_contiguous(src0) && lm_ggml_is_contiguous(dst)) {
8941
9291
  lm_ggml_compute_forward_dup_same_cont(params, dst);
8942
9292
  return;
8943
9293
  }
8944
9294
 
8945
- LM_GGML_TENSOR_UNARY_OP_LOCALS;
8946
-
8947
9295
  const size_t type_size = lm_ggml_type_size(src0->type);
8948
9296
  const int ith = params->ith; // thread index
8949
9297
  const int nth = params->nth; // number of threads
@@ -9564,6 +9912,8 @@ static void lm_ggml_compute_forward_add(
9564
9912
  case LM_GGML_TYPE_Q4_K:
9565
9913
  case LM_GGML_TYPE_Q5_K:
9566
9914
  case LM_GGML_TYPE_Q6_K:
9915
+ case LM_GGML_TYPE_TQ1_0:
9916
+ case LM_GGML_TYPE_TQ2_0:
9567
9917
  case LM_GGML_TYPE_IQ2_XXS:
9568
9918
  case LM_GGML_TYPE_IQ2_XS:
9569
9919
  case LM_GGML_TYPE_IQ3_XXS:
@@ -9942,6 +10292,8 @@ static void lm_ggml_compute_forward_add1(
9942
10292
  case LM_GGML_TYPE_Q4_K:
9943
10293
  case LM_GGML_TYPE_Q5_K:
9944
10294
  case LM_GGML_TYPE_Q6_K:
10295
+ case LM_GGML_TYPE_TQ1_0:
10296
+ case LM_GGML_TYPE_TQ2_0:
9945
10297
  case LM_GGML_TYPE_IQ2_XXS:
9946
10298
  case LM_GGML_TYPE_IQ2_XS:
9947
10299
  case LM_GGML_TYPE_IQ3_XXS:
@@ -9993,7 +10345,7 @@ static void lm_ggml_compute_forward_acc_f32(
9993
10345
  ((char *) src0->data),
9994
10346
  lm_ggml_nbytes(dst));
9995
10347
  }
9996
- lm_ggml_barrier(params->shared);
10348
+ lm_ggml_barrier(params->threadpool);
9997
10349
  }
9998
10350
 
9999
10351
  const int ith = params->ith;
@@ -10070,6 +10422,8 @@ static void lm_ggml_compute_forward_acc(
10070
10422
  case LM_GGML_TYPE_Q4_K:
10071
10423
  case LM_GGML_TYPE_Q5_K:
10072
10424
  case LM_GGML_TYPE_Q6_K:
10425
+ case LM_GGML_TYPE_TQ1_0:
10426
+ case LM_GGML_TYPE_TQ2_0:
10073
10427
  case LM_GGML_TYPE_IQ2_XXS:
10074
10428
  case LM_GGML_TYPE_IQ2_XS:
10075
10429
  case LM_GGML_TYPE_IQ3_XXS:
@@ -10098,11 +10452,10 @@ static void lm_ggml_compute_forward_sub_f32(
10098
10452
  const struct lm_ggml_tensor * src0 = dst->src[0];
10099
10453
  const struct lm_ggml_tensor * src1 = dst->src[1];
10100
10454
 
10101
- if (params->ith != 0) {
10102
- return;
10103
- }
10455
+ assert(lm_ggml_can_repeat(src1, src0) && lm_ggml_are_same_shape(src0, dst));
10104
10456
 
10105
- assert(lm_ggml_are_same_shape(src0, src1) && lm_ggml_are_same_shape(src0, dst));
10457
+ const int ith = params->ith;
10458
+ const int nth = params->nth;
10106
10459
 
10107
10460
  const int nr = lm_ggml_nrows(src0);
10108
10461
 
@@ -10111,40 +10464,55 @@ static void lm_ggml_compute_forward_sub_f32(
10111
10464
  LM_GGML_ASSERT( nb0 == sizeof(float));
10112
10465
  LM_GGML_ASSERT(nb00 == sizeof(float));
10113
10466
 
10467
+ // rows per thread
10468
+ const int dr = (nr + nth - 1)/nth;
10469
+
10470
+ // row range for this thread
10471
+ const int ir0 = dr*ith;
10472
+ const int ir1 = MIN(ir0 + dr, nr);
10473
+
10114
10474
  if (nb10 == sizeof(float)) {
10115
- for (int ir = 0; ir < nr; ++ir) {
10116
- // src0, src1 and dst are same shape => same indices
10117
- const int i3 = ir/(ne2*ne1);
10118
- const int i2 = (ir - i3*ne2*ne1)/ne1;
10119
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
10475
+ for (int ir = ir0; ir < ir1; ++ir) {
10476
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
10477
+ const int64_t i03 = ir/(ne02*ne01);
10478
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
10479
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
10120
10480
 
10121
- #ifdef LM_GGML_USE_ACCELERATE
10122
- vDSP_vsub(
10123
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
10124
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
10125
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
10126
- ne0);
10127
- #else
10128
- lm_ggml_vec_sub_f32(ne0,
10129
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
10130
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
10131
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
10481
+ const int64_t i13 = i03 % ne13;
10482
+ const int64_t i12 = i02 % ne12;
10483
+ const int64_t i11 = i01 % ne11;
10484
+ const int64_t nr0 = ne00 / ne10;
10485
+
10486
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
10487
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
10488
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
10489
+
10490
+ for (int64_t r = 0; r < nr0; ++r) {
10491
+ #ifdef LM_GGML_USE_ACCELERATE
10492
+ vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
10493
+ #else
10494
+ lm_ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
10132
10495
  #endif
10133
- // }
10134
- // }
10496
+ }
10135
10497
  }
10136
10498
  } else {
10137
10499
  // src1 is not contiguous
10138
- for (int ir = 0; ir < nr; ++ir) {
10139
- // src0, src1 and dst are same shape => same indices
10140
- const int i3 = ir/(ne2*ne1);
10141
- const int i2 = (ir - i3*ne2*ne1)/ne1;
10142
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
10500
+ for (int ir = ir0; ir < ir1; ++ir) {
10501
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
10502
+ const int64_t i03 = ir/(ne02*ne01);
10503
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
10504
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
10505
+
10506
+ const int64_t i13 = i03 % ne13;
10507
+ const int64_t i12 = i02 % ne12;
10508
+ const int64_t i11 = i01 % ne11;
10509
+
10510
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
10511
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
10143
10512
 
10144
- float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
10145
- float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
10146
- for (int i0 = 0; i0 < ne0; i0++) {
10147
- float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
10513
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
10514
+ const int64_t i10 = i0 % ne10;
10515
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
10148
10516
 
10149
10517
  dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
10150
10518
  }
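
lm_ggml_compute_forward_sub_f32 above now mirrors the other broadcasting binary ops: rows are split across threads and src1 may broadcast over src0 (lm_ggml_can_repeat), with the source row picked by the usual modulo mapping. A small worked example of that mapping:

// Example: src0 has ne = {8, 4, 2, 1} and src1 has ne = {8, 1, 1, 1}.
// For src0 row (i01 = 3, i02 = 1, i03 = 0) the matching src1 row is
//   i11 = i01 % ne11 = 3 % 1 = 0,  i12 = i02 % ne12 = 0,  i13 = i03 % ne13 = 0,
// i.e. the single src1 row is re-used for every src0 row, and nr0 = ne00/ne10 = 1.
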
@@ -10490,9 +10858,9 @@ static void lm_ggml_compute_forward_log(
10490
10858
  }
10491
10859
  }
10492
10860
 
10493
- // lm_ggml_compute_forward_sum
10861
+ // lm_ggml_compute_forward_sin
10494
10862
 
10495
- static void lm_ggml_compute_forward_sum_f32(
10863
+ static void lm_ggml_compute_forward_sin_f32(
10496
10864
  const struct lm_ggml_compute_params * params,
10497
10865
  struct lm_ggml_tensor * dst) {
10498
10866
 
@@ -10502,8 +10870,95 @@ static void lm_ggml_compute_forward_sum_f32(
10502
10870
  return;
10503
10871
  }
10504
10872
 
10505
- assert(lm_ggml_is_scalar(dst));
10873
+ LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst));
10874
+
10875
+ const int n = lm_ggml_nrows(src0);
10876
+ const int nc = src0->ne[0];
10877
+
10878
+ LM_GGML_ASSERT( dst->nb[0] == sizeof(float));
10879
+ LM_GGML_ASSERT(src0->nb[0] == sizeof(float));
10880
+
10881
+ for (int i = 0; i < n; i++) {
10882
+ lm_ggml_vec_sin_f32(nc,
10883
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
10884
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
10885
+ }
10886
+ }
10887
+
10888
+ static void lm_ggml_compute_forward_sin(
10889
+ const struct lm_ggml_compute_params * params,
10890
+ struct lm_ggml_tensor * dst) {
10891
+
10892
+ const struct lm_ggml_tensor * src0 = dst->src[0];
10893
+
10894
+ switch (src0->type) {
10895
+ case LM_GGML_TYPE_F32:
10896
+ {
10897
+ lm_ggml_compute_forward_sin_f32(params, dst);
10898
+ } break;
10899
+ default:
10900
+ {
10901
+ LM_GGML_ABORT("fatal error");
10902
+ }
10903
+ }
10904
+ }
10905
+
10906
+ // lm_ggml_compute_forward_cos
10907
+
10908
+ static void lm_ggml_compute_forward_cos_f32(
10909
+ const struct lm_ggml_compute_params * params,
10910
+ struct lm_ggml_tensor * dst) {
10911
+
10912
+ const struct lm_ggml_tensor * src0 = dst->src[0];
10913
+
10914
+ if (params->ith != 0) {
10915
+ return;
10916
+ }
10917
+
10918
+ LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst));
10919
+
10920
+ const int n = lm_ggml_nrows(src0);
10921
+ const int nc = src0->ne[0];
10922
+
10923
+ LM_GGML_ASSERT( dst->nb[0] == sizeof(float));
10924
+ LM_GGML_ASSERT(src0->nb[0] == sizeof(float));
10925
+
10926
+ for (int i = 0; i < n; i++) {
10927
+ lm_ggml_vec_cos_f32(nc,
10928
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
10929
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
10930
+ }
10931
+ }
10932
+
10933
+ static void lm_ggml_compute_forward_cos(
10934
+ const struct lm_ggml_compute_params * params,
10935
+ struct lm_ggml_tensor * dst) {
10936
+
10937
+ const struct lm_ggml_tensor * src0 = dst->src[0];
10938
+
10939
+ switch (src0->type) {
10940
+ case LM_GGML_TYPE_F32:
10941
+ {
10942
+ lm_ggml_compute_forward_cos_f32(params, dst);
10943
+ } break;
10944
+ default:
10945
+ {
10946
+ LM_GGML_ABORT("fatal error");
10947
+ }
10948
+ }
10949
+ }
10950
+
10951
+ // lm_ggml_compute_forward_sum
10952
+
10953
+ static void lm_ggml_compute_forward_sum_f32(
10954
+ const struct lm_ggml_compute_params * params,
10955
+ struct lm_ggml_tensor * dst) {
10956
+
10957
+ const struct lm_ggml_tensor * src0 = dst->src[0];
10506
10958
 
10959
+ if (params->ith != 0) {
10960
+ return;
10961
+ }
10507
10962
 
10508
10963
  assert(lm_ggml_is_scalar(dst));
10509
10964
  assert(src0->nb[0] == sizeof(float));
@@ -11762,6 +12217,48 @@ static void lm_ggml_compute_forward_hardsigmoid(
11762
12217
  }
11763
12218
  }
11764
12219
 
12220
+ static void lm_ggml_compute_forward_exp_f32(
12221
+ const struct lm_ggml_compute_params * params,
12222
+ struct lm_ggml_tensor * dst) {
12223
+
12224
+ const struct lm_ggml_tensor * src0 = dst->src[0];
12225
+
12226
+ if (params->ith != 0) {
12227
+ return;
12228
+ }
12229
+
12230
+ assert(lm_ggml_is_contiguous_1(src0));
12231
+ assert(lm_ggml_is_contiguous_1(dst));
12232
+ assert(lm_ggml_are_same_shape(src0, dst));
12233
+
12234
+ const int n = lm_ggml_nrows(src0);
12235
+ const int nc = src0->ne[0];
12236
+
12237
+ for (int i = 0; i < n; i++) {
12238
+ lm_ggml_vec_exp_f32(nc,
12239
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
12240
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
12241
+ }
12242
+ }
12243
+
12244
+ static void lm_ggml_compute_forward_exp(
12245
+ const struct lm_ggml_compute_params * params,
12246
+ struct lm_ggml_tensor * dst) {
12247
+
12248
+ const struct lm_ggml_tensor * src0 = dst->src[0];
12249
+
12250
+ switch (src0->type) {
12251
+ case LM_GGML_TYPE_F32:
12252
+ {
12253
+ lm_ggml_compute_forward_exp_f32(params, dst);
12254
+ } break;
12255
+ default:
12256
+ {
12257
+ LM_GGML_ABORT("fatal error");
12258
+ }
12259
+ }
12260
+ }
12261
+
11765
12262
 
11766
12263
  // lm_ggml_compute_forward_norm
11767
12264
 
@@ -12363,10 +12860,10 @@ UseGgmlGemm1:;
12363
12860
 
12364
12861
  if (ith == 0) {
12365
12862
  // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
12366
- atomic_store(&params->shared->current_chunk, nth);
12863
+ atomic_store_explicit(&params->threadpool->current_chunk, nth, memory_order_relaxed);
12367
12864
  }
12368
12865
 
12369
- lm_ggml_barrier(params->shared);
12866
+ lm_ggml_barrier(params->threadpool);
12370
12867
 
12371
12868
  #if LM_GGML_USE_LLAMAFILE
12372
12869
  if (src1->type != vec_dot_type) {
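
In the hunk above, the first thread seeds threadpool->current_chunk with nth so that every thread starts on its own chunk (ith) without extra coordination; once a thread finishes a chunk it claims the next one with a relaxed fetch-add, as shown in the next hunk below. The claiming loop has roughly this shape; the chunk-count names are illustrative:

// Illustrative work-claiming loop over mat-mul chunks.
int current_chunk = ith;                         // chunks 0..nth-1 are pre-assigned
while (current_chunk < nchunk0*nchunk1) {
    // ... compute the dst block that corresponds to current_chunk ...
    current_chunk = atomic_fetch_add_explicit(&params->threadpool->current_chunk, 1,
                                              memory_order_relaxed);
}
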
@@ -12474,7 +12971,7 @@ UseGgmlGemm2:;
12474
12971
  break;
12475
12972
  }
12476
12973
 
12477
- current_chunk = atomic_fetch_add(&params->shared->current_chunk, 1);
12974
+ current_chunk = atomic_fetch_add_explicit(&params->threadpool->current_chunk, 1, memory_order_relaxed);
12478
12975
  }
12479
12976
  }
12480
12977
 
@@ -12569,7 +13066,7 @@ static void lm_ggml_compute_forward_mul_mat_id(
12569
13066
  }
12570
13067
  }
12571
13068
 
12572
- lm_ggml_barrier(params->shared);
13069
+ lm_ggml_barrier(params->threadpool);
12573
13070
 
12574
13071
  // compute each matrix multiplication in sequence
12575
13072
  for (int cur_a = 0; cur_a < n_as; ++cur_a) {
@@ -12723,7 +13220,7 @@ static void lm_ggml_compute_forward_out_prod_f32(
12723
13220
  if (ith == 0) {
12724
13221
  lm_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
12725
13222
  }
12726
- lm_ggml_barrier(params->shared);
13223
+ lm_ggml_barrier(params->threadpool);
12727
13224
 
12728
13225
  // dst[:,:,:,:] = 0
12729
13226
  // for i2,i3:
@@ -12841,7 +13338,7 @@ static void lm_ggml_compute_forward_out_prod_q_f32(
12841
13338
  if (ith == 0) {
12842
13339
  lm_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
12843
13340
  }
12844
- lm_ggml_barrier(params->shared);
13341
+ lm_ggml_barrier(params->threadpool);
12845
13342
 
12846
13343
  // parallelize by last three dimensions
12847
13344
 
@@ -12907,6 +13404,8 @@ static void lm_ggml_compute_forward_out_prod(
12907
13404
  case LM_GGML_TYPE_Q4_K:
12908
13405
  case LM_GGML_TYPE_Q5_K:
12909
13406
  case LM_GGML_TYPE_Q6_K:
13407
+ case LM_GGML_TYPE_TQ1_0:
13408
+ case LM_GGML_TYPE_TQ2_0:
12910
13409
  case LM_GGML_TYPE_IQ2_XXS:
12911
13410
  case LM_GGML_TYPE_IQ2_XS:
12912
13411
  case LM_GGML_TYPE_IQ3_XXS:
@@ -13027,7 +13526,7 @@ static void lm_ggml_compute_forward_set_f32(
13027
13526
  ((char *) src0->data),
13028
13527
  lm_ggml_nbytes(dst));
13029
13528
  }
13030
- lm_ggml_barrier(params->shared);
13529
+ lm_ggml_barrier(params->threadpool);
13031
13530
  }
13032
13531
 
13033
13532
  const int ith = params->ith;
@@ -13095,6 +13594,8 @@ static void lm_ggml_compute_forward_set(
13095
13594
  case LM_GGML_TYPE_Q4_K:
13096
13595
  case LM_GGML_TYPE_Q5_K:
13097
13596
  case LM_GGML_TYPE_Q6_K:
13597
+ case LM_GGML_TYPE_TQ1_0:
13598
+ case LM_GGML_TYPE_TQ2_0:
13098
13599
  case LM_GGML_TYPE_IQ2_XXS:
13099
13600
  case LM_GGML_TYPE_IQ2_XS:
13100
13601
  case LM_GGML_TYPE_IQ3_XXS:
@@ -13208,7 +13709,7 @@ static void lm_ggml_compute_forward_get_rows_q(
13208
13709
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
13209
13710
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
13210
13711
 
13211
- assert(i01 >= 0 && i01 < ne01);
13712
+ LM_GGML_ASSERT(i01 >= 0 && i01 < ne01);
13212
13713
 
13213
13714
  dequantize_row_q(
13214
13715
  (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
@@ -13249,7 +13750,7 @@ static void lm_ggml_compute_forward_get_rows_f16(
13249
13750
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
13250
13751
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
13251
13752
 
13252
- assert(i01 >= 0 && i01 < ne01);
13753
+ LM_GGML_ASSERT(i01 >= 0 && i01 < ne01);
13253
13754
 
13254
13755
  lm_ggml_fp16_to_fp32_row(
13255
13756
  (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
@@ -13290,7 +13791,7 @@ static void lm_ggml_compute_forward_get_rows_bf16(
13290
13791
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
13291
13792
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
13292
13793
 
13293
- assert(i01 >= 0 && i01 < ne01);
13794
+ LM_GGML_ASSERT(i01 >= 0 && i01 < ne01);
13294
13795
 
13295
13796
  lm_ggml_bf16_to_fp32_row(
13296
13797
  (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
@@ -13331,7 +13832,7 @@ static void lm_ggml_compute_forward_get_rows_f32(
13331
13832
  const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
13332
13833
  const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
13333
13834
 
13334
- assert(i01 >= 0 && i01 < ne01);
13835
+ LM_GGML_ASSERT(i01 >= 0 && i01 < ne01);
13335
13836
 
13336
13837
  lm_ggml_vec_cpy_f32(nc,
13337
13838
  (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
@@ -13357,6 +13858,8 @@ static void lm_ggml_compute_forward_get_rows(
13357
13858
  case LM_GGML_TYPE_Q4_K:
13358
13859
  case LM_GGML_TYPE_Q5_K:
13359
13860
  case LM_GGML_TYPE_Q6_K:
13861
+ case LM_GGML_TYPE_TQ1_0:
13862
+ case LM_GGML_TYPE_TQ2_0:
13360
13863
  case LM_GGML_TYPE_IQ2_XXS:
13361
13864
  case LM_GGML_TYPE_IQ2_XS:
13362
13865
  case LM_GGML_TYPE_IQ3_XXS:
@@ -13606,7 +14109,7 @@ static void lm_ggml_compute_forward_diag_mask_f32(
13606
14109
  ((char *) src0->data),
13607
14110
  lm_ggml_nbytes(dst));
13608
14111
  }
13609
- lm_ggml_barrier(params->shared);
14112
+ lm_ggml_barrier(params->threadpool);
13610
14113
  }
13611
14114
 
13612
14115
  // TODO: handle transposed/permuted matrices
@@ -13946,6 +14449,8 @@ static void lm_ggml_compute_forward_clamp(
13946
14449
  case LM_GGML_TYPE_Q4_K:
13947
14450
  case LM_GGML_TYPE_Q5_K:
13948
14451
  case LM_GGML_TYPE_Q6_K:
14452
+ case LM_GGML_TYPE_TQ1_0:
14453
+ case LM_GGML_TYPE_TQ2_0:
13949
14454
  case LM_GGML_TYPE_IQ2_XXS:
13950
14455
  case LM_GGML_TYPE_IQ2_XS:
13951
14456
  case LM_GGML_TYPE_IQ3_XXS:
@@ -14382,7 +14887,7 @@ static void lm_ggml_compute_forward_conv_transpose_1d_f16_f32(
14382
14887
  // need to zero dst since we are accumulating into it
14383
14888
  memset(dst->data, 0, lm_ggml_nbytes(dst));
14384
14889
  }
14385
- lm_ggml_barrier(params->shared);
14890
+ lm_ggml_barrier(params->threadpool);
14386
14891
 
14387
14892
  const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
14388
14893
 
@@ -14470,7 +14975,7 @@ static void lm_ggml_compute_forward_conv_transpose_1d_f32(
14470
14975
  // need to zero dst since we are accumulating into it
14471
14976
  memset(dst->data, 0, lm_ggml_nbytes(dst));
14472
14977
  }
14473
- lm_ggml_barrier(params->shared);
14978
+ lm_ggml_barrier(params->threadpool);
14474
14979
 
14475
14980
  const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
14476
14981
 
@@ -14525,6 +15030,7 @@ static void lm_ggml_compute_forward_conv_transpose_1d(
14525
15030
  }
14526
15031
  }
14527
15032
 
15033
+ // lm_ggml_compute_forward_im2col_f32
14528
15034
  // src0: kernel [OC, IC, KH, KW]
14529
15035
  // src1: image [N, IC, IH, IW]
14530
15036
  // dst: result [N, OH, OW, IC*KH*KW]
@@ -14535,7 +15041,6 @@ static void lm_ggml_compute_forward_im2col_f32(
14535
15041
  const struct lm_ggml_tensor * src0 = dst->src[0];
14536
15042
  const struct lm_ggml_tensor * src1 = dst->src[1];
14537
15043
 
14538
- LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16);
14539
15044
  LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
14540
15045
  LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32);
14541
15046
 
@@ -14566,7 +15071,6 @@ static void lm_ggml_compute_forward_im2col_f32(
14566
15071
  int ofs0 = is_2D ? nb13 : nb12;
14567
15072
  int ofs1 = is_2D ? nb12 : nb11;
14568
15073
 
14569
- LM_GGML_ASSERT(nb00 == sizeof(lm_ggml_fp16_t));
14570
15074
  LM_GGML_ASSERT(nb10 == sizeof(float));
14571
15075
 
14572
15076
  // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
@@ -14602,6 +15106,7 @@ static void lm_ggml_compute_forward_im2col_f32(
14602
15106
  }
14603
15107
 
14604
15108
 
15109
+ // lm_ggml_compute_forward_im2col_f16
14605
15110
  // src0: kernel [OC, IC, KH, KW]
14606
15111
  // src1: image [N, IC, IH, IW]
14607
15112
  // dst: result [N, OH, OW, IC*KH*KW]
@@ -14697,6 +15202,99 @@ static void lm_ggml_compute_forward_im2col(
14697
15202
  }
14698
15203
  }
14699
15204
 
15205
+ // lm_ggml_compute_forward_im2col_back_f32
15206
+
15207
+ static void lm_ggml_compute_forward_im2col_back_f32(
15208
+ const struct lm_ggml_compute_params * params,
15209
+ struct lm_ggml_tensor * dst) {
15210
+
15211
+ const struct lm_ggml_tensor * src0 = dst->src[0];
15212
+ const struct lm_ggml_tensor * src1 = dst->src[1];
15213
+
15214
+ LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
15215
+ LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32);
15216
+
15217
+ LM_GGML_TENSOR_BINARY_OP_LOCALS;
15218
+
15219
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
15220
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
15221
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
15222
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
15223
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
15224
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
15225
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
15226
+
15227
+ const int ith = params->ith;
15228
+ const int nth = params->nth;
15229
+
15230
+ const int64_t N = is_2D ? ne3 : ne2;
15231
+ const int64_t IC = is_2D ? ne2 : ne1;
15232
+ const int64_t IH = is_2D ? ne1 : 1;
15233
+ const int64_t IW = ne0;
15234
+
15235
+ const int64_t KH = is_2D ? ne01 : 1;
15236
+ const int64_t KW = ne00;
15237
+
15238
+ const int64_t OH = is_2D ? ne12 : 1;
15239
+ const int64_t OW = ne11;
15240
+
15241
+ int ofs0 = is_2D ? nb3 : nb2;
15242
+ int ofs1 = is_2D ? nb2 : nb1;
15243
+
15244
+ LM_GGML_ASSERT(nb0 == sizeof(float));
15245
+
15246
+ // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
15247
+ {
15248
+ float * const wdata = (float *) dst->data;
15249
+
15250
+ for (int64_t in = 0; in < N; in++) {
15251
+ for (int64_t iic = ith; iic < IC; iic += nth) {
15252
+ for (int64_t iih = 0; iih < IH; iih++) {
15253
+ for (int64_t iiw = 0; iiw < IW; iiw++) {
15254
+
15255
+ // micro kernel
15256
+ float grad = 0.0f;
15257
+ for (int64_t ikh = 0; ikh < KH; ikh++) {
15258
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
15259
+ // For s0 > 1 some values were skipped over in the forward pass.
15260
+ // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
15261
+ const int64_t tmpw = (iiw + p0 - ikw*d0);
15262
+ if (tmpw % s0 != 0) {
15263
+ continue;
15264
+ }
15265
+ const int64_t iow = tmpw / s0;
15266
+
15267
+ // Equivalent logic as above except for s1.
15268
+ int64_t ioh;
15269
+ if (is_2D) {
15270
+ const int64_t tmph = iih + p1 - ikh*d1;
15271
+
15272
+ if (tmph % s1 != 0) {
15273
+ continue;
15274
+ }
15275
+
15276
+ ioh = tmph / s1;
15277
+ } else {
15278
+ ioh = 0;
15279
+ }
15280
+
15281
+ if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
15282
+ continue;
15283
+ }
15284
+
15285
+ const float * const src_data = (const float *) src1->data
15286
+ + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
15287
+ grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
15288
+ }
15289
+ }
15290
+ float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
15291
+ dst_data[iih*IW + iiw] = grad;
15292
+ }
15293
+ }
15294
+ }
15295
+ }
15296
+ }
15297
+ }
14700
15298
 
14701
15299
  // lm_ggml_compute_forward_conv_transpose_2d
14702
15300
 
@@ -14757,7 +15355,7 @@ static void lm_ggml_compute_forward_conv_transpose_2d(
14757
15355
 
14758
15356
  memset(dst->data, 0, lm_ggml_nbytes(dst));
14759
15357
  }
14760
- lm_ggml_barrier(params->shared);
15358
+ lm_ggml_barrier(params->threadpool);
14761
15359
 
14762
15360
  const int32_t stride = lm_ggml_get_op_params_i32(dst, 0);
14763
15361
 
@@ -14939,45 +15537,167 @@ static void lm_ggml_compute_forward_pool_2d(
14939
15537
  }
14940
15538
  }
14941
15539
 
14942
- // lm_ggml_compute_forward_upscale
15540
+ // lm_ggml_compute_forward_pool_2d_back
14943
15541
 
14944
- static void lm_ggml_compute_forward_upscale_f32(
14945
- const struct lm_ggml_compute_params * params,
14946
- struct lm_ggml_tensor * dst) {
15542
+ static void lm_ggml_compute_forward_pool_2d_back(
15543
+ const struct lm_ggml_compute_params * params,
15544
+ struct lm_ggml_tensor * dst) {
14947
15545
 
14948
- const struct lm_ggml_tensor * src0 = dst->src[0];
15546
+ const struct lm_ggml_tensor * src = dst->src[0];
15547
+ const struct lm_ggml_tensor * dstf = dst->src[1]; // forward tensor of dst
14949
15548
 
14950
- LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
15549
+ assert(dst->type == LM_GGML_TYPE_F32 || dst->type == LM_GGML_TYPE_F16);
14951
15550
 
14952
- const int ith = params->ith;
14953
- const int nth = params->nth;
15551
+ if (params->ith != 0) {
15552
+ return;
15553
+ }
14954
15554
 
14955
- LM_GGML_TENSOR_UNARY_OP_LOCALS
15555
+ const int32_t * opts = (const int32_t *)dst->op_params;
15556
+ enum lm_ggml_op_pool op = opts[0];
15557
+ const int k0 = opts[1];
15558
+ const int k1 = opts[2];
15559
+ const int s0 = opts[3];
15560
+ const int s1 = opts[4];
15561
+ const int p0 = opts[5];
15562
+ const int p1 = opts[6];
14956
15563
 
14957
- const float sf0 = (float)ne0/src0->ne[0];
14958
- const float sf1 = (float)ne1/src0->ne[1];
14959
- const float sf2 = (float)ne2/src0->ne[2];
14960
- const float sf3 = (float)ne3/src0->ne[3];
15564
+ char * cdata = (char *) dst->data;
15565
+ const char * cdataf = (const char *) dstf->data;
15566
+ const char * const data_end = cdata + lm_ggml_nbytes(dst);
14961
15567
 
14962
- // TODO: optimize
15568
+ LM_GGML_ASSERT(params->ith == 0);
15569
+ memset(cdata, 0, lm_ggml_nbytes(dst));
14963
15570
 
14964
- for (int64_t i3 = 0; i3 < ne3; i3++) {
14965
- const int64_t i03 = i3 / sf3;
14966
- for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
14967
- const int64_t i02 = i2 / sf2;
14968
- for (int64_t i1 = 0; i1 < ne1; i1++) {
14969
- const int64_t i01 = i1 / sf1;
14970
- for (int64_t i0 = 0; i0 < ne0; i0++) {
14971
- const int64_t i00 = i0 / sf0;
15571
+ const int64_t px = src->ne[0];
15572
+ const int64_t py = src->ne[1];
15573
+ const int64_t pa = px * py;
14972
15574
 
14973
- const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
14974
- float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
15575
+ const float * splane = (const float *) src->data;
14975
15576
 
14976
- *y = *x;
14977
- }
14978
- }
14979
- }
14980
- }
15577
+ const int ka = k0 * k1;
15578
+ const int offset0 = -p0;
15579
+ const int offset1 = -p1;
15580
+
15581
+ while (cdata < data_end) {
15582
+ for (int oy = 0; oy < py; ++oy) {
15583
+ const float * const srow = splane + oy * px;
15584
+ for (int ox = 0; ox < px; ++ox) {
15585
+ const float grad0 = srow[ox];
15586
+
15587
+ const int ix = offset0 + ox * s0;
15588
+ const int iy = offset1 + oy * s1;
15589
+
15590
+ if (op == LM_GGML_OP_POOL_MAX) {
15591
+ float maxval = -FLT_MAX;
15592
+ int kxmax = -1;
15593
+ int kymax = -1;
15594
+
15595
+ for (int ky = 0; ky < k1; ++ky) {
15596
+ if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
15597
+ continue;
15598
+ }
15599
+ const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));
15600
+ for (int kx = 0; kx < k0; ++kx) {
15601
+ int j = ix + kx;
15602
+ if (j < 0 || j >= dst->ne[0]) {
15603
+ continue;
15604
+ }
15605
+
15606
+ const float val = dst->type == LM_GGML_TYPE_F32 ?
15607
+ ((const float *) drowf)[j] : LM_GGML_FP16_TO_FP32(((const lm_ggml_fp16_t *) drowf)[j]);
15608
+ if (val <= maxval) {
15609
+ continue;
15610
+ }
15611
+
15612
+ maxval = val;
15613
+ kxmax = kx;
15614
+ kymax = ky;
15615
+ }
15616
+ }
15617
+
15618
+ if (kxmax == -1 || kymax == -1) {
15619
+ continue;
15620
+ }
15621
+
15622
+ void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));
15623
+ const int j = ix + kxmax;
15624
+ if (dst->type == LM_GGML_TYPE_F32) {
15625
+ ((float *) drow)[j] += grad0;
15626
+ } else {
15627
+ ((lm_ggml_fp16_t *) drow)[j] = LM_GGML_FP32_TO_FP16(grad0 + LM_GGML_FP16_TO_FP32(((const lm_ggml_fp16_t *) drow)[j]));
15628
+ }
15629
+ } else if (op == LM_GGML_OP_POOL_AVG) {
15630
+ const float grad = grad0 / ka;
15631
+
15632
+ for (int ky = 0; ky < k1; ++ky) {
15633
+ if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
15634
+ continue;
15635
+ }
15636
+ void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));
15637
+ for (int kx = 0; kx < k0; ++kx) {
15638
+ int j = ix + kx;
15639
+ if (j < 0 || j >= dst->ne[0]) {
15640
+ continue;
15641
+ }
15642
+
15643
+ if (dst->type == LM_GGML_TYPE_F32) {
15644
+ ((float *) drow)[j] += grad;
15645
+ } else {
15646
+ ((lm_ggml_fp16_t *) drow)[j] += LM_GGML_FP32_TO_FP16(grad);
15647
+ }
15648
+ }
15649
+ }
15650
+ } else {
15651
+ LM_GGML_ASSERT(false);
15652
+ }
15653
+ }
15654
+ }
15655
+
15656
+ cdata += dst->nb[2];
15657
+ cdataf += dst->nb[2];
15658
+ splane += pa;
15659
+ }
15660
+ }
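
For reference, the two branches above route the incoming gradient differently: max pooling sends the whole gradient to the cell that produced the window maximum, while average pooling spreads it evenly over the k0*k1 window. A simplified single-window F32 sketch under those assumptions (standalone helper, ignoring padding and the F16 path handled above):

    #include <stdbool.h>

    // win:  the k values the forward pass pooled over
    // dwin: the gradient accumulator for the same positions
    static void pool_window_back_f32(float grad0, const float * win, float * dwin,
                                     int k, bool is_max) {
        if (is_max) {
            int imax = 0;
            for (int i = 1; i < k; ++i) {
                if (win[i] > win[imax]) {
                    imax = i;           // first maximum wins, matching the "val <= maxval" check above
                }
            }
            dwin[imax] += grad0;        // all of the gradient goes to the argmax
        } else {
            for (int i = 0; i < k; ++i) {
                dwin[i] += grad0 / k;   // evenly distributed for average pooling
            }
        }
    }
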
15661
+
15662
+ // lm_ggml_compute_forward_upscale
15663
+
15664
+ static void lm_ggml_compute_forward_upscale_f32(
15665
+ const struct lm_ggml_compute_params * params,
15666
+ struct lm_ggml_tensor * dst) {
15667
+
15668
+ const struct lm_ggml_tensor * src0 = dst->src[0];
15669
+
15670
+ LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
15671
+
15672
+ const int ith = params->ith;
15673
+ const int nth = params->nth;
15674
+
15675
+ LM_GGML_TENSOR_UNARY_OP_LOCALS
15676
+
15677
+ const float sf0 = (float)ne0/src0->ne[0];
15678
+ const float sf1 = (float)ne1/src0->ne[1];
15679
+ const float sf2 = (float)ne2/src0->ne[2];
15680
+ const float sf3 = (float)ne3/src0->ne[3];
15681
+
15682
+ // TODO: optimize
15683
+
15684
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
15685
+ const int64_t i03 = i3 / sf3;
15686
+ for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
15687
+ const int64_t i02 = i2 / sf2;
15688
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
15689
+ const int64_t i01 = i1 / sf1;
15690
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
15691
+ const int64_t i00 = i0 / sf0;
15692
+
15693
+ const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
15694
+ float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
15695
+
15696
+ *y = *x;
15697
+ }
15698
+ }
15699
+ }
15700
+ }
14981
15701
  }
14982
15702
 
14983
15703
  static void lm_ggml_compute_forward_upscale(
@@ -15503,7 +16223,7 @@ static void lm_ggml_compute_forward_flash_attn_back_f32(
15503
16223
  if (ith == 0) {
15504
16224
  memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
15505
16225
  }
15506
- lm_ggml_barrier(params->shared);
16226
+ lm_ggml_barrier(params->threadpool);
15507
16227
 
15508
16228
  const int64_t elem_q = lm_ggml_nelements(q);
15509
16229
  const int64_t elem_k = lm_ggml_nelements(k);
@@ -16125,6 +16845,10 @@ static void lm_ggml_compute_forward_unary(
16125
16845
  {
16126
16846
  lm_ggml_compute_forward_hardsigmoid(params, dst);
16127
16847
  } break;
16848
+ case LM_GGML_UNARY_OP_EXP:
16849
+ {
16850
+ lm_ggml_compute_forward_exp(params, dst);
16851
+ } break;
16128
16852
  default:
16129
16853
  {
16130
16854
  LM_GGML_ABORT("fatal error");
@@ -16194,7 +16918,7 @@ static void lm_ggml_compute_forward_add_rel_pos_f32(
16194
16918
  if (params->ith == 0) {
16195
16919
  memcpy((char *) dst->data, (char *) src0->data, lm_ggml_nbytes(dst));
16196
16920
  }
16197
- lm_ggml_barrier(params->shared);
16921
+ lm_ggml_barrier(params->threadpool);
16198
16922
  }
16199
16923
  // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
16200
16924
 
@@ -16260,6 +16984,96 @@ static void lm_ggml_compute_forward_add_rel_pos(
16260
16984
  }
16261
16985
  }
16262
16986
 
16987
+ // lm_ggml_compute_forward_rwkv_wkv
16988
+
16989
+ static void lm_ggml_compute_forward_rwkv_wkv_f32(
16990
+ const struct lm_ggml_compute_params * params,
16991
+ struct lm_ggml_tensor * dst) {
16992
+ const size_t T = dst->src[1]->ne[3];
16993
+ const size_t C = dst->ne[0];
16994
+ const size_t H = dst->src[1]->ne[2];
16995
+ const size_t n_seqs = dst->src[5]->ne[1];
16996
+
16997
+ float * dst_data = (float *) dst->data;
16998
+ float * state = ((float *) dst->data) + C * T;
16999
+
17000
+ if (params->ith != 0) {
17001
+ return;
17002
+ }
17003
+
17004
+ memset(dst_data, 0, T * C * sizeof(float));
17005
+
17006
+ float * k = (float *) dst->src[0]->data;
17007
+ float * v = (float *) dst->src[1]->data;
17008
+ float * r = (float *) dst->src[2]->data;
17009
+ float * time_faaaa = (float *) dst->src[3]->data;
17010
+ float * time_decay = (float *) dst->src[4]->data;
17011
+
17012
+ size_t t_stride = H * (C / H);
17013
+
17014
+ size_t h_stride = C / H;
17015
+ size_t h_stride_2d = (C / H) * (C / H);
17016
+
17017
+ // basically fused operations:
17018
+ // dst = r @ (time_faaaa * (k @ v) + state),
17019
+ // state = time_decay * state + (k @ v),
17020
+ // recursive through each token
17021
+ for (size_t t = 0; t < T; t++) {
17022
+ size_t t_offset = t * t_stride;
17023
+ size_t state_offset = (C / H) * C * (t / (T / n_seqs));
17024
+ float * state_cur = state + state_offset;
17025
+ float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
17026
+
17027
+ for (size_t h = 0; h < H; h++) {
17028
+ size_t h_offset = h * h_stride;
17029
+ size_t t_h_offset = t_offset + h_offset;
17030
+ size_t h_2d_offset = h * h_stride_2d;
17031
+
17032
+ for (size_t i = 0; i < C / H; i++) {
17033
+ size_t t_h_i_offset = t_h_offset + i;
17034
+ size_t h_i_offset = h_offset + i;
17035
+ size_t h_2d_i_offset = h_2d_offset + i * h_stride;
17036
+
17037
+ float k_val = k[t_h_i_offset];
17038
+ float r_val = r[t_h_i_offset];
17039
+ float time_faaaa_val = time_faaaa[h_i_offset];
17040
+ // RWKV v6: different time_decay for each token.
17041
+ float time_decay_val = time_decay[t_h_i_offset];
17042
+
17043
+ for (size_t j = 0; j < C / H; j ++) {
17044
+ size_t t_h_j_offset = t_h_offset + j;
17045
+ size_t h_2d_i_j_offset = h_2d_i_offset + j;
17046
+
17047
+ float v_val = v[t_h_j_offset];
17048
+ float kv_val = v_val * k_val;
17049
+ float prev_state_val = state_prev[h_2d_i_j_offset];
17050
+ float temp_val = kv_val * time_faaaa_val + prev_state_val;
17051
+ dst_data[t_h_j_offset] += temp_val * r_val;
17052
+ state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
17053
+ }
17054
+ }
17055
+ }
17056
+ }
17057
+ }
17058
+
17059
+ static void lm_ggml_compute_forward_rwkv_wkv(
17060
+ const struct lm_ggml_compute_params * params,
17061
+ struct lm_ggml_tensor * dst) {
17062
+
17063
+ const struct lm_ggml_tensor * src0 = dst->src[0];
17064
+
17065
+ switch (src0->type) {
17066
+ case LM_GGML_TYPE_F32:
17067
+ {
17068
+ lm_ggml_compute_forward_rwkv_wkv_f32(params, dst);
17069
+ } break;
17070
+ default:
17071
+ {
17072
+ LM_GGML_ABORT("fatal error");
17073
+ }
17074
+ }
17075
+ }
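
Reading the fused loops above per head, with u = time_faaaa, w_t = time_decay for token t, and S the (C/H) x (C/H) state matrix, the recurrence they implement can be written as (my reading of the code, not a formula stated by the package):

    \[
    \mathrm{out}_t = r_t^{\top}\left(\operatorname{diag}(u)\,k_t v_t^{\top} + S_{t-1}\right),
    \qquad
    S_t = \operatorname{diag}(w_t)\,S_{t-1} + k_t v_t^{\top}
    \]

with the state carried across the T/n_seqs tokens of each sequence; only at the first token of a sequence is S_{t-1} taken from dst->src[5], as the state_prev selection above shows.
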
17076
+
16263
17077
  // lm_ggml_compute_forward_map_unary
16264
17078
 
16265
17079
  static void lm_ggml_compute_forward_map_unary_f32(
@@ -16479,9 +17293,7 @@ static void lm_ggml_compute_forward_cross_entropy_loss_f32(
16479
17293
  if (ith == 0) {
16480
17294
  memset(sums, 0, sizeof(float) * (nth + nth * nc));
16481
17295
  }
16482
- lm_ggml_barrier(params->shared);
16483
-
16484
- const double eps = 1e-9;
17296
+ lm_ggml_barrier(params->threadpool);
16485
17297
 
16486
17298
  // rows per thread
16487
17299
  const int dr = (nr + nth - 1)/nth;
@@ -16503,20 +17315,15 @@ static void lm_ggml_compute_forward_cross_entropy_loss_f32(
16503
17315
  }
16504
17316
  #endif
16505
17317
 
16506
- // soft_max
16507
17318
  float max = -INFINITY;
16508
17319
  lm_ggml_vec_max_f32(nc, &max, s0);
16509
- lm_ggml_float sum = lm_ggml_vec_soft_max_f32(nc, st, s0, max);
16510
- assert(sum > 0.0);
16511
- sum = (1.0 - eps) / sum;
17320
+ lm_ggml_float sum = lm_ggml_vec_log_soft_max_f32(nc, st, s0, max);
17321
+ assert(sum >= 0.0);
16512
17322
 
16513
- // avoid log(0) by rescaling from [0..1] to [eps..1]
16514
- lm_ggml_vec_scale_f32(nc, st, sum);
16515
- lm_ggml_vec_add1_f32(nc, st, st, eps);
16516
- lm_ggml_vec_log_f32(nc, st, st);
17323
+ lm_ggml_vec_add1_f32(nc, st, st, -sum);
16517
17324
  lm_ggml_vec_mul_f32(nc, st, st, s1);
16518
17325
 
16519
- float st_sum = 0;
17326
+ float st_sum = 0.0f;
16520
17327
  lm_ggml_vec_sum_f32(nc, &st_sum, st);
16521
17328
  sums[ith] += st_sum;
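
The rewritten loss path above computes a numerically stable log-softmax directly instead of the earlier eps-rescaled softmax followed by a log. Using the standard identity (not specific to this package)

    \[
    \log\operatorname{softmax}(x)_i = x_i - \max_j x_j - \log\sum_j e^{\,x_j - \max_j x_j}
    \]

lm_ggml_vec_log_soft_max_f32 presumably returns the log-sum-exp term, which is >= 0 because the max element contributes e^0 = 1; that explains the relaxed assert(sum >= 0.0), and the following vec_add1 with -sum completes the subtraction, removing the need for the old eps clamp.
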
16522
17329
 
@@ -16527,7 +17334,7 @@ static void lm_ggml_compute_forward_cross_entropy_loss_f32(
16527
17334
  }
16528
17335
  #endif
16529
17336
  }
16530
- lm_ggml_barrier(params->shared);
17337
+ lm_ggml_barrier(params->threadpool);
16531
17338
 
16532
17339
  if (ith == 0) {
16533
17340
  float * dp = (float *) dst->data;
@@ -16573,8 +17380,6 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back_f32(
16573
17380
  const int64_t ith = params->ith;
16574
17381
  const int64_t nth = params->nth;
16575
17382
 
16576
- const double eps = 1e-9;
16577
-
16578
17383
  // TODO: handle transposed/permuted matrices
16579
17384
  const int64_t nc = src0->ne[0];
16580
17385
  const int64_t nr = lm_ggml_nrows(src0);
@@ -16606,11 +17411,9 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back_f32(
16606
17411
  lm_ggml_vec_max_f32(nc, &max, s0);
16607
17412
  lm_ggml_float sum = lm_ggml_vec_soft_max_f32(nc, ds0, s0, max);
16608
17413
  assert(sum > 0.0);
16609
- sum = (1.0 - eps) / sum;
17414
+ lm_ggml_vec_scale_f32(nc, ds0, 1.0/sum);
16610
17415
 
16611
17416
  // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
16612
- lm_ggml_vec_scale_f32(nc, ds0, sum);
16613
- lm_ggml_vec_add1_f32(nc, ds0, ds0, eps);
16614
17417
  lm_ggml_vec_sub_f32(nc, ds0, ds0, s1);
16615
17418
  lm_ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr);
16616
17419
 
@@ -16691,6 +17494,14 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru
16691
17494
  {
16692
17495
  lm_ggml_compute_forward_log(params, tensor);
16693
17496
  } break;
17497
+ case LM_GGML_OP_SIN:
17498
+ {
17499
+ lm_ggml_compute_forward_sin(params, tensor);
17500
+ } break;
17501
+ case LM_GGML_OP_COS:
17502
+ {
17503
+ lm_ggml_compute_forward_cos(params, tensor);
17504
+ } break;
16694
17505
  case LM_GGML_OP_SUM:
16695
17506
  {
16696
17507
  lm_ggml_compute_forward_sum(params, tensor);
@@ -16831,6 +17642,10 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru
16831
17642
  {
16832
17643
  lm_ggml_compute_forward_im2col(params, tensor);
16833
17644
  } break;
17645
+ case LM_GGML_OP_IM2COL_BACK:
17646
+ {
17647
+ lm_ggml_compute_forward_im2col_back_f32(params, tensor);
17648
+ } break;
16834
17649
  case LM_GGML_OP_CONV_TRANSPOSE_2D:
16835
17650
  {
16836
17651
  lm_ggml_compute_forward_conv_transpose_2d(params, tensor);
@@ -16843,6 +17658,10 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru
16843
17658
  {
16844
17659
  lm_ggml_compute_forward_pool_2d(params, tensor);
16845
17660
  } break;
17661
+ case LM_GGML_OP_POOL_2D_BACK:
17662
+ {
17663
+ lm_ggml_compute_forward_pool_2d_back(params, tensor);
17664
+ } break;
16846
17665
  case LM_GGML_OP_UPSCALE:
16847
17666
  {
16848
17667
  lm_ggml_compute_forward_upscale(params, tensor);
@@ -16906,6 +17725,10 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru
16906
17725
  {
16907
17726
  lm_ggml_compute_forward_add_rel_pos(params, tensor);
16908
17727
  } break;
17728
+ case LM_GGML_OP_RWKV_WKV:
17729
+ {
17730
+ lm_ggml_compute_forward_rwkv_wkv(params, tensor);
17731
+ } break;
16909
17732
  case LM_GGML_OP_MAP_UNARY:
16910
17733
  {
16911
17734
  lm_ggml_unary_op_f32_t fun;
@@ -17211,7 +18034,11 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
17211
18034
  src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
17212
18035
  }
17213
18036
  if (src1->grad) {
17214
- src1->grad = lm_ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
18037
+ if (lm_ggml_are_same_shape(src0, src1)) {
18038
+ src1->grad = lm_ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
18039
+ } else {
18040
+ src1->grad = lm_ggml_add_or_set(ctx, src1->grad, lm_ggml_repeat_back(ctx, tensor->grad, src1), zero_table);
18041
+ }
17215
18042
  }
17216
18043
  } break;
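
The new branch above handles broadcast adds: when src1 was implicitly repeated to src0's shape in the forward pass, the gradient flowing back to src1 must be summed over the repeated positions, which is what lm_ggml_repeat_back provides. In short, if y = a + repeat(b), then

    \[
    \frac{\partial L}{\partial b} = \sum_{\text{repeats of } b} \frac{\partial L}{\partial y}
    \]
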
17217
18044
  case LM_GGML_OP_ADD1:
@@ -17337,6 +18164,30 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
17337
18164
  zero_table);
17338
18165
  }
17339
18166
  } break;
18167
+ case LM_GGML_OP_SIN:
18168
+ {
18169
+ if (src0->grad) {
18170
+ src0->grad =
18171
+ lm_ggml_add_or_set(ctx,
18172
+ src0->grad,
18173
+ lm_ggml_mul(ctx,
18174
+ tensor->grad,
18175
+ lm_ggml_cos(ctx, src0)),
18176
+ zero_table);
18177
+ }
18178
+ } break;
18179
+ case LM_GGML_OP_COS:
18180
+ {
18181
+ if (src0->grad) {
18182
+ src0->grad =
18183
+ lm_ggml_sub_or_set(ctx,
18184
+ src0->grad,
18185
+ lm_ggml_mul(ctx,
18186
+ tensor->grad,
18187
+ lm_ggml_sin(ctx, src0)),
18188
+ zero_table);
18189
+ }
18190
+ } break;
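
The two new cases follow directly from the elementary derivatives

    \[
    \frac{d}{dx}\sin x = \cos x, \qquad \frac{d}{dx}\cos x = -\sin x
    \]

which is why the SIN case accumulates tensor->grad * cos(src0) with add_or_set while the COS case uses sub_or_set to account for the minus sign.
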
17340
18191
  case LM_GGML_OP_SUM:
17341
18192
  {
17342
18193
  if (src0->grad) {
@@ -17509,14 +18360,10 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
17509
18360
  if (src0->grad || src1->grad) {
17510
18361
  LM_GGML_ASSERT(src0->type == tensor->type);
17511
18362
  LM_GGML_ASSERT(tensor->grad->type == tensor->type);
17512
- LM_GGML_ASSERT(tensor->grad->type == src1->grad->type);
18363
+ LM_GGML_ASSERT(!src1->grad || src1->grad->type == tensor->grad->type);
17513
18364
 
17514
18365
  tensor_grad_view = lm_ggml_view_4d(ctx,
17515
- tensor->grad,
17516
- src1->grad->ne[0],
17517
- src1->grad->ne[1],
17518
- src1->grad->ne[2],
17519
- src1->grad->ne[3],
18366
+ tensor->grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
17520
18367
  nb1, nb2, nb3, offset);
17521
18368
  }
17522
18369
 
@@ -17585,9 +18432,9 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
17585
18432
 
17586
18433
  memcpy(&offset, tensor->op_params, sizeof(offset));
17587
18434
 
17588
- size_t nb1 = tensor->nb[1];
17589
- size_t nb2 = tensor->nb[2];
17590
- size_t nb3 = tensor->nb[3];
18435
+ size_t nb1 = tensor->nb[1];
18436
+ size_t nb2 = tensor->nb[2];
18437
+ size_t nb3 = tensor->nb[3];
17591
18438
 
17592
18439
  if (src0->type != src0->grad->type) {
17593
18440
  // gradient is typically F32, but src0 could be other type
@@ -17784,6 +18631,23 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
17784
18631
  LM_GGML_ABORT("fatal error"); // TODO: not implemented
17785
18632
  }
17786
18633
  case LM_GGML_OP_IM2COL:
18634
+ {
18635
+ if (src1->grad) {
18636
+ const int32_t s0 = lm_ggml_get_op_params_i32(tensor, 0);
18637
+ const int32_t s1 = lm_ggml_get_op_params_i32(tensor, 1);
18638
+ const int32_t p0 = lm_ggml_get_op_params_i32(tensor, 2);
18639
+ const int32_t p1 = lm_ggml_get_op_params_i32(tensor, 3);
18640
+ const int32_t d0 = lm_ggml_get_op_params_i32(tensor, 4);
18641
+ const int32_t d1 = lm_ggml_get_op_params_i32(tensor, 5);
18642
+ const bool is_2D = lm_ggml_get_op_params_i32(tensor, 6) == 1;
18643
+
18644
+ src1->grad = lm_ggml_add_or_set(ctx,
18645
+ src1->grad,
18646
+ lm_ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D),
18647
+ zero_table);
18648
+ }
18649
+ } break;
18650
+ case LM_GGML_OP_IM2COL_BACK:
17787
18651
  {
17788
18652
  LM_GGML_ABORT("fatal error"); // TODO: not implemented
17789
18653
  }
@@ -17796,6 +18660,23 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
17796
18660
  LM_GGML_ABORT("fatal error"); // TODO: not implemented
17797
18661
  }
17798
18662
  case LM_GGML_OP_POOL_2D:
18663
+ {
18664
+ if (src0->grad) {
18665
+ const enum lm_ggml_op_pool op = lm_ggml_get_op_params_i32(tensor, 0);
18666
+ const int32_t k0 = lm_ggml_get_op_params_i32(tensor, 1);
18667
+ const int32_t k1 = lm_ggml_get_op_params_i32(tensor, 2);
18668
+ const int32_t s0 = lm_ggml_get_op_params_i32(tensor, 3);
18669
+ const int32_t s1 = lm_ggml_get_op_params_i32(tensor, 4);
18670
+ const int32_t p0 = lm_ggml_get_op_params_i32(tensor, 5);
18671
+ const int32_t p1 = lm_ggml_get_op_params_i32(tensor, 6);
18672
+
18673
+ src0->grad = lm_ggml_add_or_set(ctx,
18674
+ src0->grad,
18675
+ lm_ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1),
18676
+ zero_table);
18677
+ }
18678
+ } break;
18679
+ case LM_GGML_OP_POOL_2D_BACK:
17799
18680
  {
17800
18681
  LM_GGML_ABORT("fatal error"); // TODO: not implemented
17801
18682
  }
@@ -17961,12 +18842,22 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
17961
18842
  zero_table);
17962
18843
  }
17963
18844
  } break;
18845
+ case LM_GGML_UNARY_OP_EXP:
18846
+ {
18847
+ if (src0->grad) {
18848
+ src0->grad = lm_ggml_add_or_set(ctx,
18849
+ src0->grad,
18850
+ lm_ggml_mul(ctx, tensor, tensor->grad),
18851
+ zero_table);
18852
+ }
18853
+ } break;
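
Similarly, since d/dx e^x = e^x, the EXP backward can reuse the forward output: the gradient contribution is simply tensor * tensor->grad, with no recomputation of the exponential.
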
17964
18854
  default:
17965
18855
  LM_GGML_ABORT("fatal error");
17966
18856
  }
17967
18857
  } break;
17968
18858
  case LM_GGML_OP_GET_REL_POS:
17969
18859
  case LM_GGML_OP_ADD_REL_POS:
18860
+ case LM_GGML_OP_RWKV_WKV:
17970
18861
  case LM_GGML_OP_MAP_UNARY:
17971
18862
  case LM_GGML_OP_MAP_BINARY:
17972
18863
  case LM_GGML_OP_MAP_CUSTOM1_F32:
@@ -18085,6 +18976,7 @@ void lm_ggml_build_forward_expand(struct lm_ggml_cgraph * cgraph, struct lm_ggml
18085
18976
 
18086
18977
  void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * gf, struct lm_ggml_cgraph * gb, bool keep) {
18087
18978
  LM_GGML_ASSERT(gf->n_nodes > 0);
18979
+ LM_GGML_ASSERT(gf->grads);
18088
18980
 
18089
18981
  // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph
18090
18982
  if (keep) {
@@ -18238,7 +19130,8 @@ void lm_ggml_graph_cpy(struct lm_ggml_cgraph * src, struct lm_ggml_cgraph * dst)
18238
19130
  }
18239
19131
 
18240
19132
  for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
18241
- if (src->visited_hash_set.keys[i]) {
19133
+ // copy all hashset keys (tensors) that are in use
19134
+ if (lm_ggml_bitset_get(src->visited_hash_set.used, i)) {
18242
19135
  lm_ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
18243
19136
  }
18244
19137
  }
@@ -18268,65 +19161,6 @@ void lm_ggml_graph_clear(struct lm_ggml_cgraph * cgraph) {
18268
19161
  lm_ggml_hash_set_reset(&cgraph->visited_hash_set);
18269
19162
  }
18270
19163
 
18271
- //
18272
- // thread data
18273
- //
18274
- // synchronization is done via busy loops
18275
- // I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops
18276
- //
18277
-
18278
- #ifdef __APPLE__
18279
-
18280
- //#include <os/lock.h>
18281
- //
18282
- //typedef os_unfair_lock lm_ggml_lock_t;
18283
- //
18284
- //#define lm_ggml_lock_init(x) UNUSED(x)
18285
- //#define lm_ggml_lock_destroy(x) UNUSED(x)
18286
- //#define lm_ggml_lock_lock os_unfair_lock_lock
18287
- //#define lm_ggml_lock_unlock os_unfair_lock_unlock
18288
- //
18289
- //#define LM_GGML_LOCK_INITIALIZER OS_UNFAIR_LOCK_INIT
18290
-
18291
- typedef int lm_ggml_lock_t;
18292
-
18293
- #define lm_ggml_lock_init(x) UNUSED(x)
18294
- #define lm_ggml_lock_destroy(x) UNUSED(x)
18295
- #define lm_ggml_lock_lock(x) UNUSED(x)
18296
- #define lm_ggml_lock_unlock(x) UNUSED(x)
18297
-
18298
- #define LM_GGML_LOCK_INITIALIZER 0
18299
-
18300
- #define lm_ggml_thread_create pthread_create
18301
- #define lm_ggml_thread_join pthread_join
18302
-
18303
- #else
18304
-
18305
- //typedef pthread_spinlock_t lm_ggml_lock_t;
18306
-
18307
- //#define lm_ggml_lock_init(x) pthread_spin_init(x, PTHREAD_PROCESS_PRIVATE)
18308
- //#define lm_ggml_lock_destroy pthread_spin_destroy
18309
- //#define lm_ggml_lock_lock pthread_spin_lock
18310
- //#define lm_ggml_lock_unlock pthread_spin_unlock
18311
-
18312
- typedef int lm_ggml_lock_t;
18313
-
18314
- #define lm_ggml_lock_init(x) UNUSED(x)
18315
- #define lm_ggml_lock_destroy(x) UNUSED(x)
18316
- #if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64))
18317
- #define lm_ggml_lock_lock(x) _mm_pause()
18318
- #else
18319
- #define lm_ggml_lock_lock(x) UNUSED(x)
18320
- #endif
18321
- #define lm_ggml_lock_unlock(x) UNUSED(x)
18322
-
18323
- #define LM_GGML_LOCK_INITIALIZER 0
18324
-
18325
- #define lm_ggml_thread_create pthread_create
18326
- #define lm_ggml_thread_join pthread_join
18327
-
18328
- #endif
18329
-
18330
19164
  // Android's libc implementation "bionic" does not support setting affinity
18331
19165
  #if defined(__gnu_linux__)
18332
19166
  static void set_numa_thread_affinity(int thread_n) {
@@ -18424,6 +19258,8 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
18424
19258
  case LM_GGML_OP_SQR:
18425
19259
  case LM_GGML_OP_SQRT:
18426
19260
  case LM_GGML_OP_LOG:
19261
+ case LM_GGML_OP_SIN:
19262
+ case LM_GGML_OP_COS:
18427
19263
  case LM_GGML_OP_SUM:
18428
19264
  case LM_GGML_OP_SUM_ROWS:
18429
19265
  case LM_GGML_OP_MEAN:
@@ -18446,6 +19282,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
18446
19282
  case LM_GGML_UNARY_OP_SIGMOID:
18447
19283
  case LM_GGML_UNARY_OP_HARDSWISH:
18448
19284
  case LM_GGML_UNARY_OP_HARDSIGMOID:
19285
+ case LM_GGML_UNARY_OP_EXP:
18449
19286
  {
18450
19287
  n_tasks = 1;
18451
19288
  } break;
@@ -18510,6 +19347,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
18510
19347
  n_tasks = MIN(n_threads, lm_ggml_nrows(node->src[0]));
18511
19348
  } break;
18512
19349
  case LM_GGML_OP_IM2COL:
19350
+ case LM_GGML_OP_IM2COL_BACK:
18513
19351
  case LM_GGML_OP_CONV_TRANSPOSE_1D:
18514
19352
  case LM_GGML_OP_CONV_TRANSPOSE_2D:
18515
19353
  {
@@ -18517,6 +19355,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
18517
19355
  } break;
18518
19356
  case LM_GGML_OP_POOL_1D:
18519
19357
  case LM_GGML_OP_POOL_2D:
19358
+ case LM_GGML_OP_POOL_2D_BACK:
18520
19359
  {
18521
19360
  n_tasks = 1;
18522
19361
  } break;
@@ -18535,6 +19374,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
18535
19374
  case LM_GGML_OP_WIN_PART:
18536
19375
  case LM_GGML_OP_WIN_UNPART:
18537
19376
  case LM_GGML_OP_GET_REL_POS:
19377
+ case LM_GGML_OP_RWKV_WKV:
18538
19378
  case LM_GGML_OP_MAP_UNARY:
18539
19379
  case LM_GGML_OP_MAP_BINARY:
18540
19380
  case LM_GGML_OP_MAP_CUSTOM1_F32:
@@ -18603,9 +19443,281 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
18603
19443
  return n_tasks;
18604
19444
  }
18605
19445
 
18606
- struct lm_ggml_cplan lm_ggml_graph_plan(const struct lm_ggml_cgraph * cgraph, int n_threads) {
19446
+ static thread_ret_t lm_ggml_graph_compute_secondary_thread(void* data);
19447
+
19448
+ #if defined(_WIN32)
19449
+ #include "windows.h"
19450
+
19451
+ // TODO: support > 64 CPUs
19452
+ bool lm_ggml_thread_apply_affinity(bool * mask) {
19453
+ HANDLE h = GetCurrentThread();
19454
+ uint64_t bitmask = 0ULL;
19455
+
19456
+ assert(LM_GGML_MAX_N_THREADS >= 64);
19457
+
19458
+ for (int32_t i = 0; i < 8; i++) {
19459
+ int32_t idx = i * 8;
19460
+ uint8_t val = 0;
19461
+ val |= mask[idx + 0] << 0;
19462
+ val |= mask[idx + 1] << 1;
19463
+ val |= mask[idx + 2] << 2;
19464
+ val |= mask[idx + 3] << 3;
19465
+ val |= mask[idx + 4] << 4;
19466
+ val |= mask[idx + 5] << 5;
19467
+ val |= mask[idx + 6] << 6;
19468
+ val |= mask[idx + 7] << 7;
19469
+ bitmask |= (uint64_t)val << idx;
19470
+ }
19471
+
19472
+ for (int32_t i = 64; i < LM_GGML_MAX_N_THREADS; i++) {
19473
+ if (mask[i]) {
19474
+ fprintf(stderr, "warn: setting thread-affinity for > 64 CPUs isn't supported on windows!\n");
19475
+ break;
19476
+ }
19477
+ }
19478
+
19479
+ DWORD_PTR m = (DWORD_PTR)bitmask;
19480
+
19481
+ m = SetThreadAffinityMask(h, m);
19482
+
19483
+ return m != 0;
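
The byte-wise packing loop above folds the first 64 entries of the boolean mask into a single 64-bit affinity mask. A more direct equivalent, shown only to illustrate what the loop computes (not a proposed change to the code):

    #include <stdbool.h>
    #include <stdint.h>

    // Pack mask[0..63] into bit i of the result; produces the same value as the
    // two-level byte/bit loop used above.
    static uint64_t pack_affinity_mask(const bool * mask) {
        uint64_t bitmask = 0;
        for (int i = 0; i < 64; ++i) {
            if (mask[i]) {
                bitmask |= 1ULL << i;
            }
        }
        return bitmask;
    }
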
19484
+ }
19485
+
19486
+ static bool lm_ggml_thread_apply_priority(int32_t prio) {
19487
+ // Note that on Windows the Process Priority Class must be updated in order to set Thread priority.
19488
+ // This is up to the applications.
19489
+ DWORD p = THREAD_PRIORITY_NORMAL;
19490
+ switch (prio) {
19491
+ case LM_GGML_SCHED_PRIO_NORMAL: p = THREAD_PRIORITY_NORMAL; break;
19492
+ case LM_GGML_SCHED_PRIO_MEDIUM: p = THREAD_PRIORITY_ABOVE_NORMAL; break;
19493
+ case LM_GGML_SCHED_PRIO_HIGH: p = THREAD_PRIORITY_HIGHEST; break;
19494
+ case LM_GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break;
19495
+ }
19496
+
19497
+ if (prio == LM_GGML_SCHED_PRIO_NORMAL) {
19498
+ // Keep inherited policy/priority
19499
+ return true;
19500
+ }
19501
+
19502
+ if (!SetThreadPriority(GetCurrentThread(), p)) {
19503
+ fprintf(stderr, "warn: failed to set thread priority %d : (%d)\n", prio, (int) GetLastError());
19504
+ return false;
19505
+ }
19506
+
19507
+ return true;
19508
+ }
19509
+
19510
+ #elif defined(__APPLE__)
19511
+ #include <sys/types.h>
19512
+ #include <sys/resource.h>
19513
+
19514
+ static bool lm_ggml_thread_apply_affinity(const bool * mask) {
19515
+ // Not supported on Apple platforms
19516
+ UNUSED(mask);
19517
+ return true;
19518
+ }
19519
+
19520
+ static bool lm_ggml_thread_apply_priority(int32_t prio) {
19521
+ struct sched_param p;
19522
+ int32_t policy = SCHED_OTHER;
19523
+ switch (prio) {
19524
+ case LM_GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break;
19525
+ case LM_GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break;
19526
+ case LM_GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break;
19527
+ case LM_GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break;
19528
+ }
19529
+
19530
+ if (prio == LM_GGML_SCHED_PRIO_NORMAL) {
19531
+ // Keep inherited policy/priority
19532
+ return true;
19533
+ }
19534
+
19535
+ int32_t err = pthread_setschedparam(pthread_self(), policy, &p);
19536
+ if (err != 0) {
19537
+ fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err);
19538
+ return false;
19539
+ }
19540
+
19541
+ return true;
19542
+ }
19543
+
19544
+ #elif defined(__gnu_linux__)
19545
+ // TODO: this may not work on BSD, to be verified
19546
+
19547
+ static bool lm_ggml_thread_apply_affinity(const bool * mask) {
19548
+ cpu_set_t cpuset;
19549
+ int err;
19550
+
19551
+ CPU_ZERO(&cpuset);
19552
+
19553
+ for (uint32_t i = 0; i < LM_GGML_MAX_N_THREADS; i++) {
19554
+ if (mask[i]) {
19555
+ LM_GGML_PRINT_DEBUG("Thread %lx: adding %d to cpuset\n", pthread_self(), i);
19556
+ CPU_SET(i, &cpuset);
19557
+ }
19558
+ }
19559
+
19560
+ #ifdef __ANDROID__
19561
+ err = sched_setaffinity(0, sizeof(cpuset), &cpuset);
19562
+ if (err < 0) {
19563
+ err = errno;
19564
+ }
19565
+ #else
19566
+ err = pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset);
19567
+ #endif
19568
+ if (err != 0) {
19569
+ fprintf(stderr, "warn: failed to set affinity mask 0x%llx : %s (%d)\n", (unsigned long long)mask, strerror(err), err);
19570
+ return false;
19571
+ }
19572
+
19573
+ return true;
19574
+ }
19575
+
19576
+ static bool lm_ggml_thread_apply_priority(int32_t prio) {
19577
+ struct sched_param p;
19578
+ int32_t policy = SCHED_OTHER;
19579
+ switch (prio) {
19580
+ case LM_GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break;
19581
+ case LM_GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break;
19582
+ case LM_GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break;
19583
+ case LM_GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break;
19584
+ }
19585
+
19586
+ if (prio == LM_GGML_SCHED_PRIO_NORMAL) {
19587
+ // Keep inherited policy/priority
19588
+ return true;
19589
+ }
19590
+
19591
+ int32_t err = pthread_setschedparam(pthread_self(), policy, &p);
19592
+ if (err != 0) {
19593
+ fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err);
19594
+ return false;
19595
+ }
19596
+
19597
+ return true;
19598
+ }
19599
+
19600
+ #else // unsupported platforms
19601
+
19602
+ static bool lm_ggml_thread_apply_affinity(const bool * mask) {
19603
+ UNUSED(mask);
19604
+ return true;
19605
+ }
19606
+
19607
+ static bool lm_ggml_thread_apply_priority(int32_t prio) {
19608
+ UNUSED(prio);
19609
+ return true;
19610
+ }
19611
+
19612
+ #endif
19613
+
19614
+ static bool lm_ggml_thread_cpumask_is_valid(const bool * mask) {
19615
+ for (int i = 0; i < LM_GGML_MAX_N_THREADS; i++) {
19616
+ if (mask[i]) { return true; }
19617
+ }
19618
+ return false;
19619
+ }
19620
+
19621
+ static void lm_ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask, bool strict, int32_t* iter) {
19622
+ if (!strict) {
19623
+ memcpy(local_mask, global_mask, LM_GGML_MAX_N_THREADS);
19624
+ return;
19625
+ } else {
19626
+ memset(local_mask, 0, LM_GGML_MAX_N_THREADS);
19627
+ int32_t base_idx = *iter;
19628
+ for (int32_t i = 0; i < LM_GGML_MAX_N_THREADS; i++) {
19629
+ int32_t idx = base_idx + i;
19630
+ if (idx >= LM_GGML_MAX_N_THREADS) {
19631
+ // Just a cheaper modulo
19632
+ idx -= LM_GGML_MAX_N_THREADS;
19633
+ }
19634
+ if (global_mask[idx]) {
19635
+ local_mask[idx] = 1;
19636
+ *iter = idx + 1;
19637
+ return;
19638
+ }
19639
+ }
19640
+ }
19641
+ }
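
To make the strict vs. non-strict behavior above concrete, a small worked example with hypothetical mask values:

    // global_mask enables CPUs 2 and 5 only, strict_cpu = true, *iter starts at 0:
    //   1st call: scans 0,1,2 -> sets local_mask[2], *iter = 3
    //   2nd call: scans 3,4,5 -> sets local_mask[5], *iter = 6
    //   3rd call: scans 6.. , wraps modulo LM_GGML_MAX_N_THREADS, and selects CPU 2 again
    // With strict_cpu = false every worker simply receives a copy of the full global mask.
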
19642
+
19643
+ void lm_ggml_threadpool_free(struct lm_ggml_threadpool* threadpool) {
19644
+ if (!threadpool) return;
19645
+
19646
+ #ifndef LM_GGML_USE_OPENMP
19647
+ struct lm_ggml_compute_state* workers = threadpool->workers;
19648
+ const int n_threads = threadpool->n_threads_max;
19649
+
19650
+ lm_ggml_mutex_lock(&threadpool->mutex);
19651
+
19652
+ threadpool->stop = true;
19653
+ threadpool->pause = false;
19654
+
19655
+ lm_ggml_cond_broadcast(&threadpool->cond);
19656
+ lm_ggml_mutex_unlock(&threadpool->mutex);
19657
+
19658
+ for (int j = 1; j < n_threads; j++) {
19659
+ int32_t rc = lm_ggml_thread_join(workers[j].thrd, NULL);
19660
+ LM_GGML_ASSERT(rc == LM_GGML_EXIT_SUCCESS || rc == LM_GGML_EXIT_ABORTED);
19661
+ UNUSED(rc);
19662
+ }
19663
+
19664
+ lm_ggml_mutex_destroy(&threadpool->mutex);
19665
+ lm_ggml_cond_destroy(&threadpool->cond);
19666
+ #endif // LM_GGML_USE_OPENMP
19667
+
19668
+ LM_GGML_ALIGNED_FREE(threadpool->workers);
19669
+ LM_GGML_ALIGNED_FREE(threadpool);
19670
+ }
19671
+
19672
+ #ifndef LM_GGML_USE_OPENMP
19673
+ // pause/resume must be called under mutex
19674
+ static void lm_ggml_threadpool_pause_locked(struct lm_ggml_threadpool * threadpool) {
19675
+ LM_GGML_PRINT_DEBUG("Pausing threadpool\n");
19676
+ threadpool->pause = true;
19677
+ lm_ggml_cond_broadcast(&threadpool->cond);
19678
+ }
19679
+
19680
+ static void lm_ggml_threadpool_resume_locked(struct lm_ggml_threadpool * threadpool) {
19681
+ LM_GGML_PRINT_DEBUG("Resuming threadpool\n");
19682
+ threadpool->pause = false;
19683
+ lm_ggml_cond_broadcast(&threadpool->cond);
19684
+ }
19685
+ #endif
19686
+
19687
+ void lm_ggml_threadpool_pause(struct lm_ggml_threadpool * threadpool) {
19688
+ #ifndef LM_GGML_USE_OPENMP
19689
+ lm_ggml_mutex_lock(&threadpool->mutex);
19690
+ if (!threadpool->pause) {
19691
+ lm_ggml_threadpool_pause_locked(threadpool);
19692
+ }
19693
+ lm_ggml_mutex_unlock(&threadpool->mutex);
19694
+ #else
19695
+ UNUSED(threadpool);
19696
+ #endif
19697
+ }
19698
+
19699
+ void lm_ggml_threadpool_resume(struct lm_ggml_threadpool * threadpool) {
19700
+ #ifndef LM_GGML_USE_OPENMP
19701
+ lm_ggml_mutex_lock(&threadpool->mutex);
19702
+ if (threadpool->pause) {
19703
+ lm_ggml_threadpool_resume_locked(threadpool);
19704
+ }
19705
+ lm_ggml_mutex_unlock(&threadpool->mutex);
19706
+ #else
19707
+ UNUSED(threadpool);
19708
+ #endif
19709
+ }
19710
+
19711
+ struct lm_ggml_cplan lm_ggml_graph_plan(
19712
+ const struct lm_ggml_cgraph * cgraph,
19713
+ int n_threads,
19714
+ struct lm_ggml_threadpool * threadpool) {
19715
+
19716
+ if (threadpool == NULL) {
19717
+ LM_GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads);
19718
+ }
18607
19719
  if (n_threads <= 0) {
18608
- n_threads = LM_GGML_DEFAULT_N_THREADS;
19720
+ n_threads = threadpool ? threadpool->n_threads_max : LM_GGML_DEFAULT_N_THREADS;
18609
19721
  }
18610
19722
 
18611
19723
  size_t work_size = 0;
@@ -18761,12 +19873,13 @@ struct lm_ggml_cplan lm_ggml_graph_plan(const struct lm_ggml_cgraph * cgraph, in
18761
19873
  }
18762
19874
 
18763
19875
  if (work_size > 0) {
18764
- work_size += CACHE_LINE_SIZE*(n_threads - 1);
19876
+ work_size += CACHE_LINE_SIZE*(n_threads);
18765
19877
  }
18766
19878
 
18767
- cplan.n_threads = MIN(max_tasks, n_threads);
18768
- cplan.work_size = work_size;
18769
- cplan.work_data = NULL;
19879
+ cplan.threadpool = threadpool;
19880
+ cplan.n_threads = MIN(max_tasks, n_threads);
19881
+ cplan.work_size = work_size;
19882
+ cplan.work_data = NULL;
18770
19883
 
18771
19884
  return cplan;
18772
19885
  }
@@ -18774,17 +19887,17 @@ struct lm_ggml_cplan lm_ggml_graph_plan(const struct lm_ggml_cgraph * cgraph, in
18774
19887
  static thread_ret_t lm_ggml_graph_compute_thread(void * data) {
18775
19888
  struct lm_ggml_compute_state * state = (struct lm_ggml_compute_state *) data;
18776
19889
 
18777
- const struct lm_ggml_cgraph * cgraph = state->shared->cgraph;
18778
- const struct lm_ggml_cplan * cplan = state->shared->cplan;
19890
+ const struct lm_ggml_cgraph * cgraph = state->threadpool->cgraph;
19891
+ const struct lm_ggml_cplan * cplan = state->threadpool->cplan;
18779
19892
 
18780
19893
  set_numa_thread_affinity(state->ith);
18781
19894
 
18782
19895
  struct lm_ggml_compute_params params = {
18783
- /*.ith =*/ state->ith,
18784
- /*.nth =*/ state->shared->n_threads,
18785
- /*.wsize =*/ cplan->work_size,
18786
- /*.wdata =*/ cplan->work_data,
18787
- /*.shared=*/ state->shared,
19896
+ /*.ith =*/ state->ith,
19897
+ /*.nth =*/ state->threadpool->n_threads_cur,
19898
+ /*.wsize =*/ cplan->work_size,
19899
+ /*.wdata =*/ cplan->work_data,
19900
+ /*.threadpool=*/ state->threadpool,
18788
19901
  };
18789
19902
 
18790
19903
  for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
@@ -18793,12 +19906,12 @@ static thread_ret_t lm_ggml_graph_compute_thread(void * data) {
18793
19906
  lm_ggml_compute_forward(&params, node);
18794
19907
 
18795
19908
  if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
18796
- state->shared->ec = LM_GGML_STATUS_ABORTED;
19909
+ state->threadpool->ec = LM_GGML_STATUS_ABORTED;
18797
19910
  }
18798
19911
 
18799
- lm_ggml_barrier(state->shared);
19912
+ lm_ggml_barrier(state->threadpool);
18800
19913
 
18801
- if (state->shared->ec != LM_GGML_STATUS_SUCCESS) {
19914
+ if (state->threadpool->ec != LM_GGML_STATUS_SUCCESS) {
18802
19915
  break;
18803
19916
  }
18804
19917
  }
@@ -18806,24 +19919,243 @@ static thread_ret_t lm_ggml_graph_compute_thread(void * data) {
18806
19919
  return 0;
18807
19920
  }
18808
19921
 
19922
+ #ifndef LM_GGML_USE_OPENMP
19923
+
19924
+ static inline bool lm_ggml_graph_compute_ready(struct lm_ggml_compute_state * state) {
19925
+ struct lm_ggml_threadpool * threadpool = state->threadpool;
19926
+
19927
+ if (state->pending || threadpool->stop || threadpool->pause) { return true; }
19928
+
19929
+ // check for new graph/work
19930
+ int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed);
19931
+ if (new_graph != state->last_graph) {
19932
+ state->pending = (state->ith < threadpool->n_threads_cur);
19933
+ state->last_graph = new_graph;
19934
+ }
19935
+
19936
+ return state->pending;
19937
+ }
19938
+
19939
+ static inline bool lm_ggml_graph_compute_poll_for_work(struct lm_ggml_compute_state * state) {
19940
+ struct lm_ggml_threadpool * threadpool = state->threadpool;
19941
+
19942
+ // This seems to make 0 ... 100 a decent range for polling level across modern processors.
19943
+ // Perhaps, we can adjust it dynamically based on load and things.
19944
+ const uint64_t n_rounds = 1024UL * 128 * threadpool->poll;
19945
+
19946
+ for (uint64_t i=0; !lm_ggml_graph_compute_ready(state) && i<n_rounds; i++) {
19947
+ // No new work. Keep polling.
19948
+ lm_ggml_thread_cpu_relax();
19949
+ }
19950
+
19951
+ return state->pending;
19952
+ }
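
For scale: with the default poll level of 50 set in lm_ggml_threadpool_params_init further down, the busy-wait budget is n_rounds = 1024 * 128 * 50 = 6,553,600 calls to lm_ggml_thread_cpu_relax() before the worker falls back to the mutex/cond wait in lm_ggml_graph_compute_check_for_work; poll = 0 disables polling entirely (n_rounds = 0) and poll = 100 doubles the budget to 13,107,200.
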
19953
+
19954
+ static inline bool lm_ggml_graph_compute_check_for_work(struct lm_ggml_compute_state * state) {
19955
+ struct lm_ggml_threadpool * threadpool = state->threadpool;
19956
+
19957
+ if (lm_ggml_graph_compute_poll_for_work(state)) {
19958
+ return state->pending;
19959
+ }
19960
+
19961
+ lm_ggml_mutex_lock_shared(&threadpool->mutex);
19962
+ while (!lm_ggml_graph_compute_ready(state)) {
19963
+ // No new work. Wait for the signal.
19964
+ LM_GGML_PRINT_DEBUG("thread #%d waiting for work\n", state->ith);
19965
+ lm_ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
19966
+ }
19967
+ lm_ggml_mutex_unlock_shared(&threadpool->mutex);
19968
+
19969
+ return state->pending;
19970
+ }
19971
+
19972
+ static thread_ret_t lm_ggml_graph_compute_secondary_thread(void* data) {
19973
+ struct lm_ggml_compute_state * state = (struct lm_ggml_compute_state *) data;
19974
+ struct lm_ggml_threadpool * threadpool = state->threadpool;
19975
+
19976
+ lm_ggml_thread_apply_priority(threadpool->prio);
19977
+ if (lm_ggml_thread_cpumask_is_valid(state->cpumask)) {
19978
+ lm_ggml_thread_apply_affinity(state->cpumask);
19979
+ }
19980
+
19981
+ while (true) {
19982
+ // Check if we need to sleep
19983
+ while (threadpool->pause) {
19984
+ LM_GGML_PRINT_DEBUG("thread #%d inside pause loop\n", state->ith);
19985
+ lm_ggml_mutex_lock_shared(&threadpool->mutex);
19986
+ if (threadpool->pause) {
19987
+ lm_ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
19988
+ }
19989
+ LM_GGML_PRINT_DEBUG("thread #%d resuming after wait\n", state->ith);
19990
+ lm_ggml_mutex_unlock_shared(&threadpool->mutex);
19991
+ }
19992
+
19993
+ // This needs to be checked for after the cond_wait
19994
+ if (threadpool->stop) break;
19995
+
19996
+ // Check if there is new work
19997
+ // The main thread is the only one that can dispatch new work
19998
+
19999
+ lm_ggml_graph_compute_check_for_work(state);
20000
+ if (state->pending) {
20001
+ state->pending = false;
20002
+
20003
+ lm_ggml_graph_compute_thread(state);
20004
+ }
20005
+ }
20006
+
20007
+ return (thread_ret_t) 0;
20008
+ }
20009
+
20010
+ // Start processing new graph
20011
+ static void lm_ggml_graph_compute_kickoff(struct lm_ggml_threadpool * threadpool)
20012
+ {
20013
+ // always take the mutex here because the worker threads are doing hybrid poll/wait
20014
+
20015
+ lm_ggml_mutex_lock(&threadpool->mutex);
20016
+
20017
+ atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_relaxed);
20018
+
20019
+ if (threadpool->pause) {
20020
+ // Update main thread prio and affinity to match the threadpool settings
20021
+ lm_ggml_thread_apply_priority(threadpool->prio);
20022
+ if (lm_ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) {
20023
+ lm_ggml_thread_apply_affinity(threadpool->workers[0].cpumask);
20024
+ }
20025
+
20026
+ // resume does cond broadcast
20027
+ lm_ggml_threadpool_resume_locked(threadpool);
20028
+ } else {
20029
+ lm_ggml_cond_broadcast(&threadpool->cond);
20030
+ }
20031
+
20032
+ lm_ggml_mutex_unlock(&threadpool->mutex);
20033
+ }
20034
+
20035
+ #endif // LM_GGML_USE_OPENMP
20036
+
20037
+ void lm_ggml_threadpool_params_init(struct lm_ggml_threadpool_params * p, int n_threads) {
20038
+ p->n_threads = n_threads;
20039
+ p->prio = 0; // default priority (usually means normal or inherited)
20040
+ p->poll = 50; // hybrid-polling enabled
20041
+ p->strict_cpu = false; // no strict placement (all threads share same cpumask)
20042
+ p->paused = false; // threads are ready to go
20043
+ memset(p->cpumask, 0, LM_GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited)
20044
+ }
20045
+
20046
+ struct lm_ggml_threadpool_params lm_ggml_threadpool_params_default(int n_threads) {
20047
+ struct lm_ggml_threadpool_params p;
20048
+ lm_ggml_threadpool_params_init(&p, n_threads);
20049
+ return p;
20050
+ }
20051
+
20052
+ bool lm_ggml_threadpool_params_match(const struct lm_ggml_threadpool_params * p0, const struct lm_ggml_threadpool_params * p1) {
20053
+ if (p0->n_threads != p1->n_threads ) return false;
20054
+ if (p0->prio != p1->prio ) return false;
20055
+ if (p0->poll != p1->poll ) return false;
20056
+ if (p0->strict_cpu != p1->strict_cpu ) return false;
20057
+ return memcmp(p0->cpumask, p1->cpumask, LM_GGML_MAX_N_THREADS) == 0;
20058
+ }
20059
+
20060
+ static struct lm_ggml_threadpool * lm_ggml_threadpool_new_impl(
20061
+ struct lm_ggml_threadpool_params * tpp,
20062
+ struct lm_ggml_cgraph * cgraph,
20063
+ struct lm_ggml_cplan * cplan) {
20064
+
20065
+ struct lm_ggml_threadpool * threadpool =
20066
+ LM_GGML_ALIGNED_MALLOC(sizeof(struct lm_ggml_threadpool));
20067
+ {
20068
+ threadpool->cgraph = cgraph;
20069
+ threadpool->cplan = cplan;
20070
+ threadpool->n_graph = 0;
20071
+ threadpool->n_barrier = 0;
20072
+ threadpool->n_barrier_passed = 0;
20073
+ threadpool->current_chunk = 0;
20074
+ threadpool->stop = false;
20075
+ threadpool->pause = tpp->paused;
20076
+ threadpool->workers = NULL;
20077
+ threadpool->n_threads_max = tpp->n_threads;
20078
+ threadpool->n_threads_cur = tpp->n_threads;
20079
+ threadpool->poll = tpp->poll;
20080
+ threadpool->prio = tpp->prio;
20081
+ threadpool->ec = LM_GGML_STATUS_SUCCESS;
20082
+ }
20083
+
20084
+ // Allocate and init workers state
20085
+ const size_t workers_size = sizeof(struct lm_ggml_compute_state) * tpp->n_threads;
20086
+ struct lm_ggml_compute_state * workers = LM_GGML_ALIGNED_MALLOC(workers_size);
20087
+
20088
+ memset(workers, 0, workers_size);
20089
+ for (int j = 0; j < tpp->n_threads; j++) {
20090
+ workers[j].threadpool = threadpool;
20091
+ workers[j].ith = j;
20092
+ }
20093
+
20094
+ threadpool->workers = workers;
20095
+
20096
+ #ifndef LM_GGML_USE_OPENMP
20097
+ lm_ggml_mutex_init(&threadpool->mutex);
20098
+ lm_ggml_cond_init(&threadpool->cond);
20099
+
20100
+ // Spin the threads for all workers, and update CPU placements.
20101
+ // Place the main thread last (towards the higher numbered CPU cores).
20102
+
20103
+ int32_t cpumask_iter = 0;
20104
+
20105
+ for (int j = 1; j < tpp->n_threads; j++) {
20106
+ lm_ggml_thread_cpumask_next(tpp->cpumask, workers[j].cpumask, tpp->strict_cpu, &cpumask_iter);
20107
+
20108
+ int32_t rc = lm_ggml_thread_create(&workers[j].thrd, NULL, lm_ggml_graph_compute_secondary_thread, &workers[j]);
20109
+ LM_GGML_ASSERT(rc == 0);
20110
+ }
20111
+
20112
+ lm_ggml_thread_cpumask_next(tpp->cpumask, workers[0].cpumask, tpp->strict_cpu, &cpumask_iter);
20113
+
20114
+ if (!threadpool->pause) {
20115
+ // Update main thread prio and affinity at the start, otherwise we'll do it in resume
20116
+ lm_ggml_thread_apply_priority(threadpool->prio);
20117
+ if (lm_ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) {
20118
+ lm_ggml_thread_apply_affinity(threadpool->workers[0].cpumask);
20119
+ }
20120
+ }
20121
+ #endif // LM_GGML_USE_OPENMP
20122
+
20123
+ return threadpool;
20124
+ }
20125
+
20126
+ struct lm_ggml_threadpool * lm_ggml_threadpool_new(struct lm_ggml_threadpool_params * tpp) {
20127
+ return lm_ggml_threadpool_new_impl(tpp, NULL, NULL);
20128
+ }
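
Putting the new API together, a minimal usage sketch against the functions added in this diff; the graph and work-buffer variables (graph, work_buffer) are placeholders and error handling is omitted:

    // Build a persistent threadpool once, then reuse it across graph computations.
    struct lm_ggml_threadpool_params tpp = lm_ggml_threadpool_params_default(8);
    tpp.strict_cpu = true;  // optional: pin each worker to its own CPU via the cpumask logic above

    struct lm_ggml_threadpool * tp = lm_ggml_threadpool_new(&tpp);

    struct lm_ggml_cplan cplan = lm_ggml_graph_plan(graph, 8, tp);  // graph: an existing lm_ggml_cgraph *
    cplan.work_data = work_buffer;  // caller-provided buffer of at least cplan.work_size bytes

    enum lm_ggml_status status = lm_ggml_graph_compute(graph, &cplan);

    lm_ggml_threadpool_free(tp);    // joins the workers and releases the pool

If no threadpool is passed (cplan.threadpool == NULL), lm_ggml_graph_compute creates and frees a disposable one internally, as the code below shows.
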
20129
+
18809
20130
  enum lm_ggml_status lm_ggml_graph_compute(struct lm_ggml_cgraph * cgraph, struct lm_ggml_cplan * cplan) {
18810
20131
  LM_GGML_ASSERT(cplan);
18811
20132
  LM_GGML_ASSERT(cplan->n_threads > 0);
18812
20133
  LM_GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL);
18813
20134
 
18814
- int n_threads = cplan->n_threads;
18815
-
18816
- struct lm_ggml_compute_state_shared state_shared = {
18817
- /*.cgraph =*/ cgraph,
18818
- /*.cgraph_plan =*/ cplan,
18819
- /*.n_threads =*/ n_threads,
18820
- /*.n_barrier =*/ 0,
18821
- /*.n_barrier_passed =*/ 0,
18822
- /*.abort_callback =*/ NULL,
18823
- /*.abort_callback_data =*/ NULL,
18824
- /*.current_chunk =*/ 0,
18825
- /*.ec =*/ LM_GGML_STATUS_SUCCESS,
18826
- };
20135
+ int n_threads = cplan->n_threads;
20136
+ struct lm_ggml_threadpool * threadpool = cplan->threadpool;
20137
+
20138
+ bool disposable_threadpool = false;
20139
+
20140
+ if (threadpool == NULL) {
20141
+ LM_GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads);
20142
+ disposable_threadpool = true;
20143
+
20144
+ struct lm_ggml_threadpool_params ttp = lm_ggml_threadpool_params_default(n_threads);
20145
+ threadpool = lm_ggml_threadpool_new_impl(&ttp, cgraph, cplan);
20146
+ } else {
20147
+ // Reset some of the parameters that need resetting
20148
+ // No worker threads should be accessing the parameters below at this stage
20149
+ threadpool->cgraph = cgraph;
20150
+ threadpool->cplan = cplan;
20151
+ threadpool->n_threads_cur = n_threads;
20152
+ threadpool->current_chunk = 0;
20153
+ threadpool->ec = LM_GGML_STATUS_SUCCESS;
20154
+ }
20155
+
20156
+ if (n_threads > threadpool->n_threads_max) {
20157
+ LM_GGML_PRINT("WARNING: cplan is requesting more threads than the threadpool contains. Expect a bad time!\n");
20158
+ }
18827
20159
 
18828
20160
  #ifdef LM_GGML_USE_OPENMP
18829
20161
  if (n_threads > 1) {
@@ -18833,63 +20165,36 @@ enum lm_ggml_status lm_ggml_graph_compute(struct lm_ggml_cgraph * cgraph, struct
18833
20165
  {
18834
20166
  // update the number of threads from the actual number of threads that we got from OpenMP
18835
20167
  n_threads = omp_get_num_threads();
18836
- state_shared.n_threads = n_threads;
20168
+ threadpool->n_threads_cur = n_threads;
18837
20169
  }
18838
20170
 
18839
- struct lm_ggml_compute_state worker = {
18840
- .thrd = 0,
18841
- .ith = omp_get_thread_num(),
18842
- .shared = &state_shared,
18843
- };
18844
- lm_ggml_graph_compute_thread(&worker);
20171
+ lm_ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]);
18845
20172
  }
18846
20173
  } else {
18847
- struct lm_ggml_compute_state worker = {
18848
- .thrd = 0,
18849
- .ith = 0,
18850
- .shared = &state_shared,
18851
- };
18852
- lm_ggml_graph_compute_thread(&worker);
20174
+ lm_ggml_graph_compute_thread(&threadpool->workers[0]);
18853
20175
  }
18854
20176
  #else
18855
- struct lm_ggml_compute_state * workers = alloca(sizeof(struct lm_ggml_compute_state)*n_threads);
20177
+ // Kick all threads to start the new graph
20178
+ lm_ggml_graph_compute_kickoff(threadpool);
18856
20179
 
18857
- for (int j = 0; j < n_threads; ++j) {
18858
- workers[j] = (struct lm_ggml_compute_state) {
18859
- .thrd = 0,
18860
- .ith = j,
18861
- .shared = &state_shared,
18862
- };
18863
- }
18864
-
18865
- // create thread pool
18866
- for (int j = 1; j < n_threads; ++j) {
18867
- const int rc = lm_ggml_thread_create(&workers[j].thrd, NULL, lm_ggml_graph_compute_thread, &workers[j]);
18868
- LM_GGML_ASSERT(rc == 0);
18869
- UNUSED(rc);
18870
- }
18871
-
18872
- // this is a work thread too
18873
- lm_ggml_graph_compute_thread(&workers[0]);
18874
-
18875
- // join or kill thread pool
18876
- if (n_threads > 1) {
18877
- for (int j = 1; j < n_threads; j++) {
18878
- const int rc = lm_ggml_thread_join(workers[j].thrd, NULL);
18879
- LM_GGML_ASSERT(rc == 0);
18880
- UNUSED(rc);
18881
- }
18882
- }
20180
+ // This is a work thread too
20181
+ lm_ggml_graph_compute_thread(&threadpool->workers[0]);
18883
20182
  #endif
18884
20183
 
18885
20184
  // don't leave affinity set on the main thread
18886
20185
  clear_numa_thread_affinity();
18887
20186
 
18888
- return state_shared.ec;
20187
+ enum lm_ggml_status ret = threadpool->ec;
20188
+
20189
+ if (disposable_threadpool) {
20190
+ lm_ggml_threadpool_free(threadpool);
20191
+ }
20192
+
20193
+ return ret;
18889
20194
  }
18890
20195
 
18891
20196
  enum lm_ggml_status lm_ggml_graph_compute_with_ctx(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph, int n_threads) {
18892
- struct lm_ggml_cplan cplan = lm_ggml_graph_plan(cgraph, n_threads);
20197
+ struct lm_ggml_cplan cplan = lm_ggml_graph_plan(cgraph, n_threads, NULL);
18893
20198
 
18894
20199
  struct lm_ggml_object * obj = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size);
18895
20200
 
@@ -19030,9 +20335,11 @@ void lm_ggml_graph_export(const struct lm_ggml_cgraph * cgraph, const char * fna
19030
20335
 
19031
20336
  const uint32_t type = tensor->type;
19032
20337
  const uint32_t op = tensor->op;
20338
+ const int32_t flags = tensor->flags;
19033
20339
 
19034
20340
  fwrite(&type, sizeof(uint32_t), 1, fout);
19035
20341
  fwrite(&op, sizeof(uint32_t), 1, fout);
20342
+ fwrite(&flags, sizeof(int32_t), 1, fout);
19036
20343
 
19037
20344
  for (int j = 0; j < LM_GGML_MAX_DIMS; ++j) {
19038
20345
  const uint64_t ne = tensor->ne[j];
@@ -19062,9 +20369,11 @@ void lm_ggml_graph_export(const struct lm_ggml_cgraph * cgraph, const char * fna
19062
20369
 
19063
20370
  const uint32_t type = tensor->type;
19064
20371
  const uint32_t op = tensor->op;
20372
+ const int32_t flags = tensor->flags;
19065
20373
 
19066
20374
  fwrite(&type, sizeof(uint32_t), 1, fout);
19067
20375
  fwrite(&op, sizeof(uint32_t), 1, fout);
20376
+ fwrite(&flags, sizeof(int32_t), 1, fout);
19068
20377
 
19069
20378
  for (int j = 0; j < LM_GGML_MAX_DIMS; ++j) {
19070
20379
  const uint64_t ne = tensor->ne[j];
@@ -19123,6 +20432,14 @@ void lm_ggml_graph_export(const struct lm_ggml_cgraph * cgraph, const char * fna
19123
20432
  }
19124
20433
  }
19125
20434
  }
20435
+
20436
+ // dump the data
20437
+ // TODO: pad this to 32 byte boundary
20438
+ if ((flags & LM_GGML_TENSOR_FLAG_PARAM)) {
20439
+ const size_t size = lm_ggml_nbytes(tensor);
20440
+
20441
+ fwrite(tensor->data, sizeof(char), size, fout);
20442
+ }
19126
20443
  }
19127
20444
  }
19128
20445
 
@@ -19236,10 +20553,12 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_
19236
20553
  {
19237
20554
  uint32_t type;
19238
20555
  uint32_t op;
20556
+ int32_t flags;
19239
20557
 
19240
20558
  for (uint32_t i = 0; i < n_leafs; ++i) {
19241
20559
  type = *(const uint32_t *) ptr; ptr += sizeof(type);
19242
20560
  op = *(const uint32_t *) ptr; ptr += sizeof(op);
20561
+ flags = *(const int32_t *) ptr; ptr += sizeof(flags);
19243
20562
 
19244
20563
  int64_t ne[LM_GGML_MAX_DIMS];
19245
20564
  size_t nb[LM_GGML_MAX_DIMS];
@@ -19257,20 +20576,19 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_
19257
20576
 
19258
20577
  struct lm_ggml_tensor * tensor = lm_ggml_new_tensor(*ctx_eval, (enum lm_ggml_type) type, LM_GGML_MAX_DIMS, ne);
19259
20578
 
19260
- tensor->op = (enum lm_ggml_op) op;
20579
+ tensor->op = (enum lm_ggml_op) op;
20580
+ tensor->flags = flags;
19261
20581
 
19262
20582
  memcpy(tensor->name, ptr, LM_GGML_MAX_NAME); ptr += LM_GGML_MAX_NAME;
19263
20583
  memcpy(tensor->op_params, ptr, LM_GGML_MAX_OP_PARAMS); ptr += LM_GGML_MAX_OP_PARAMS;
19264
20584
 
19265
- tensor->data = (void *) ptr;
19266
-
19267
20585
  for (int j = 0; j < LM_GGML_MAX_DIMS; ++j) {
19268
20586
  tensor->nb[j] = nb[j];
19269
20587
  }
19270
20588
 
19271
- result->leafs[i] = tensor;
20589
+ tensor->data = (void *) ptr; ptr += lm_ggml_nbytes(tensor);
19272
20590
 
19273
- ptr += lm_ggml_nbytes(tensor);
20591
+ result->leafs[i] = tensor;
19274
20592
 
19275
20593
  fprintf(stderr, "%s: loaded leaf %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, lm_ggml_nbytes(tensor));
19276
20594
  }
@@ -19282,10 +20600,12 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_
19282
20600
  {
19283
20601
  uint32_t type;
19284
20602
  uint32_t op;
20603
+ int32_t flags;
19285
20604
 
19286
20605
  for (uint32_t i = 0; i < n_nodes; ++i) {
19287
20606
  type = *(const uint32_t *) ptr; ptr += sizeof(type);
19288
20607
  op = *(const uint32_t *) ptr; ptr += sizeof(op);
20608
+ flags = *(const int32_t *) ptr; ptr += sizeof(flags);
19289
20609
 
19290
20610
  enum lm_ggml_op eop = (enum lm_ggml_op) op;
19291
20611
 
@@ -19375,6 +20695,11 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_
19375
20695
 
19376
20696
  result->nodes[i] = tensor;
19377
20697
 
20698
+ // TODO tensor data is duplicated due to lm_ggml_new_tensor call above
20699
+ if (flags & LM_GGML_TENSOR_FLAG_PARAM) {
20700
+ tensor->data = (void *) ptr; ptr += lm_ggml_nbytes(tensor);
20701
+ }
20702
+
19378
20703
  fprintf(stderr, "%s: loaded node %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, lm_ggml_nbytes(tensor));
19379
20704
  }
19380
20705
  }
@@ -19643,6 +20968,7 @@ static enum lm_ggml_opt_result lm_ggml_opt_adam(
19643
20968
  lm_ggml_opt_callback callback,
19644
20969
  void * callback_data) {
19645
20970
  LM_GGML_ASSERT(lm_ggml_is_scalar(f));
20971
+ LM_GGML_ASSERT(f->type == LM_GGML_TYPE_F32);
19646
20972
 
19647
20973
  // these will store the parameters we want to optimize
19648
20974
  struct lm_ggml_tensor * ps[LM_GGML_MAX_PARAMS];
@@ -19684,7 +21010,7 @@ static enum lm_ggml_opt_result lm_ggml_opt_adam(
19684
21010
 
19685
21011
  float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
19686
21012
 
19687
- struct lm_ggml_cplan cplan = lm_ggml_graph_plan(gb, params.n_threads);
21013
+ struct lm_ggml_cplan cplan = lm_ggml_graph_plan(gb, params.n_threads, NULL);
19688
21014
  struct lm_ggml_object * obj = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size);
19689
21015
  cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
19690
21016
 
@@ -20031,7 +21357,7 @@ static enum lm_ggml_opt_result lm_ggml_opt_lbfgs(
20031
21357
  opt->iter = iter;
20032
21358
  }
20033
21359
 
20034
- struct lm_ggml_cplan cplan = lm_ggml_graph_plan(gb, params.n_threads);
21360
+ struct lm_ggml_cplan cplan = lm_ggml_graph_plan(gb, params.n_threads, NULL);
20035
21361
  struct lm_ggml_object * obj = lm_ggml_new_object(ctx, LM_GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size);
20036
21362
  cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
20037
21363
 
@@ -20409,6 +21735,8 @@ enum lm_ggml_opt_result lm_ggml_opt(
20409
21735
  struct lm_ggml_context * ctx,
20410
21736
  struct lm_ggml_opt_params params,
20411
21737
  struct lm_ggml_tensor * f) {
21738
+ LM_GGML_ASSERT(f->grad && "lm_ggml_set_param called for at least one parent tensor.");
21739
+
20412
21740
  bool free_ctx = false;
20413
21741
  if (ctx == NULL) {
20414
21742
  struct lm_ggml_init_params params_ctx = {
@@ -20463,6 +21791,8 @@ enum lm_ggml_opt_result lm_ggml_opt_resume_g(
20463
21791
  lm_ggml_opt_callback callback,
20464
21792
  void * callback_data) {
20465
21793
 
21794
+ LM_GGML_ASSERT(f->grad && "lm_ggml_set_param must be called for at least one ancestor");
21795
+
20466
21796
  // build forward + backward compute graphs
20467
21797
  enum lm_ggml_opt_result result = LM_GGML_OPT_RESULT_OK;
20468
21798
 
@@ -20574,6 +21904,8 @@ size_t lm_ggml_quantize_chunk(
20574
21904
  case LM_GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
20575
21905
  case LM_GGML_TYPE_Q5_K: result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
20576
21906
  case LM_GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21907
+ case LM_GGML_TYPE_TQ1_0: result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21908
+ case LM_GGML_TYPE_TQ2_0: result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
20577
21909
  case LM_GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
20578
21910
  case LM_GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
20579
21911
  case LM_GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
@@ -21550,6 +22882,7 @@ void lm_gguf_set_kv(struct lm_gguf_context * ctx, struct lm_gguf_context * src)
21550
22882
  void lm_gguf_add_tensor(
21551
22883
  struct lm_gguf_context * ctx,
21552
22884
  const struct lm_ggml_tensor * tensor) {
22885
+ LM_GGML_ASSERT(tensor);
21553
22886
  if (lm_gguf_find_tensor(ctx, tensor->name) != -1) {
21554
22887
  LM_GGML_ABORT("duplicated tensor name");
21555
22888
  }