cui-llama.rn 1.1.6 → 1.2.0

This diff covers the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
package/cpp/ggml.c CHANGED
@@ -1,7 +1,9 @@
1
1
  #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
2
2
  #define _USE_MATH_DEFINES // For M_PI on MSVC
3
3
 
4
+ #include "ggml-backend.h"
4
5
  #include "ggml-impl.h"
6
+ #include "ggml-cpu-impl.h"
5
7
  #include "ggml-quants.h"
6
8
  #include "ggml.h"
7
9
  #include "ggml-aarch64.h"
@@ -61,6 +63,25 @@ int lm_ggml_sve_cnt_b = 0;
61
63
  #pragma warning(disable: 4702)
62
64
  #endif
63
65
 
66
+ // Note: once we move threading into a separate C++ file
67
+ // we'll use std::hardware_destructive_interference_size instead of hardcoding it here
68
+ // and we'll use C++ attribute syntax.
69
+ #define LM_GGML_CACHE_LINE 64
70
+
71
+ #if defined(__clang__) || defined(__GNUC__)
72
+ #define LM_GGML_CACHE_ALIGN __attribute__((aligned(LM_GGML_CACHE_LINE)))
73
+ #endif
74
+
75
+ #if defined(__has_feature)
76
+ #if __has_feature(thread_sanitizer)
77
+ #define LM_GGML_TSAN_ENABLED 1
78
+ #endif
79
+ #else // __has_feature
80
+ #if defined(__SANITIZE_THREAD__)
81
+ #define LM_GGML_TSAN_ENABLED 1
82
+ #endif
83
+ #endif // __has_feature
84
+
64
85
  #if defined(_WIN32)
65
86
 
66
87
  #define WIN32_LEAN_AND_MEAN
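The hunk above hardcodes a 64-byte cache line, defines LM_GGML_CACHE_ALIGN so that hot atomics can be placed on their own cache lines, and adds a small ThreadSanitizer detection block. As a minimal standalone sketch of why the alignment matters (C11, with illustrative struct and field names that are not part of the package), giving two heavily written counters their own cache lines keeps one thread's stores from invalidating the line another thread is spinning on:

#include <stdatomic.h>

#define CACHE_LINE 64

struct sync_counters {
    _Alignas(CACHE_LINE) atomic_int n_barrier;        // written by every thread entering the barrier
    _Alignas(CACHE_LINE) atomic_int n_barrier_passed; // read in the spin loop, bumped once per barrier
};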
@@ -70,6 +91,8 @@ int lm_ggml_sve_cnt_b = 0;
70
91
  #include <windows.h>
71
92
 
72
93
  #if !defined(__clang__)
94
+ #define LM_GGML_CACHE_ALIGN __declspec(align(LM_GGML_CACHE_LINE))
95
+
73
96
  typedef volatile LONG atomic_int;
74
97
  typedef atomic_int atomic_bool;
75
98
  typedef atomic_int atomic_flag;
@@ -112,6 +135,9 @@ static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
112
135
  static void atomic_flag_clear(atomic_flag * ptr) {
113
136
  InterlockedExchange(ptr, 0);
114
137
  }
138
+ static void atomic_thread_fence(memory_order mo) {
139
+ MemoryBarrier();
140
+ }
115
141
  #else // clang
116
142
  #include <stdatomic.h>
117
143
  #endif
@@ -287,7 +313,6 @@ void lm_ggml_abort(const char * file, int line, const char * fmt, ...) {
287
313
  #define LM_GGML_DEBUG 0
288
314
  #define LM_GGML_GELU_FP16
289
315
  #define LM_GGML_GELU_QUICK_FP16
290
- #define LM_GGML_N_TASKS_MAX (-1)
291
316
 
292
317
  #define LM_GGML_SOFT_MAX_UNROLL 4
293
318
  #define LM_GGML_VEC_DOT_UNROLL 2
@@ -2005,17 +2030,18 @@ struct lm_ggml_threadpool {
2005
2030
 
2006
2031
  // synchronization primitives
2007
2032
  atomic_int n_graph; // incremented when there is work to be done (i.e. each graph)
2008
- atomic_int n_barrier;
2009
- atomic_int n_barrier_passed;
2033
+ atomic_int LM_GGML_CACHE_ALIGN n_barrier;
2034
+ atomic_int LM_GGML_CACHE_ALIGN n_barrier_passed;
2010
2035
  atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
2011
2036
 
2012
2037
  // these are atomic as an annotation for thread-sanitizer
2013
2038
  atomic_bool stop; // Used for stopping the threadpool altogether
2014
2039
  atomic_bool pause; // Used for pausing the threadpool or individual threads
2040
+ atomic_bool abort; // Used for aborting processing of a graph
2015
2041
 
2016
2042
  struct lm_ggml_compute_state * workers; // per thread state
2017
2043
  int n_threads_max; // number of threads in the pool
2018
- int n_threads_cur; // number of threads used in the current graph
2044
+ atomic_int n_threads_cur; // number of threads used in the current graph
2019
2045
 
2020
2046
  int32_t prio; // Scheduling priority
2021
2047
  uint32_t poll; // Polling level (0 - no polling)
@@ -2995,9 +3021,10 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = {
2995
3021
 
2996
3022
  "CROSS_ENTROPY_LOSS",
2997
3023
  "CROSS_ENTROPY_LOSS_BACK",
3024
+ "OPT_STEP_ADAMW",
2998
3025
  };
2999
3026
 
3000
- static_assert(LM_GGML_OP_COUNT == 79, "LM_GGML_OP_COUNT != 79");
3027
+ static_assert(LM_GGML_OP_COUNT == 80, "LM_GGML_OP_COUNT != 80");
3001
3028
 
3002
3029
  static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
3003
3030
  "none",
@@ -3088,9 +3115,10 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
3088
3115
 
3089
3116
  "cross_entropy_loss(x,y)",
3090
3117
  "cross_entropy_loss_back(x,y)",
3118
+ "adamw(x)",
3091
3119
  };
3092
3120
 
3093
- static_assert(LM_GGML_OP_COUNT == 79, "LM_GGML_OP_COUNT != 79");
3121
+ static_assert(LM_GGML_OP_COUNT == 80, "LM_GGML_OP_COUNT != 80");
3094
3122
 
3095
3123
  static_assert(LM_GGML_OP_POOL_COUNT == 2, "LM_GGML_OP_POOL_COUNT != 2");
3096
3124
 
@@ -3177,41 +3205,43 @@ inline static void lm_ggml_critical_section_start(void) {
3177
3205
  }
3178
3206
  }
3179
3207
 
3180
- #ifdef LM_GGML_USE_OPENMP
3181
- static void lm_ggml_barrier(struct lm_ggml_threadpool * threadpool) {
3182
- if (threadpool->n_threads_cur == 1) {
3208
+ static void lm_ggml_barrier(struct lm_ggml_threadpool * tp) {
3209
+ int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed);
3210
+ if (n_threads == 1) {
3183
3211
  return;
3184
3212
  }
3185
3213
 
3214
+ #ifdef LM_GGML_USE_OPENMP
3186
3215
  #pragma omp barrier
3187
- }
3188
3216
  #else
3189
- static void lm_ggml_barrier(struct lm_ggml_threadpool * threadpool) {
3190
- if (threadpool->n_threads_cur == 1) {
3191
- return;
3192
- }
3217
+ int n_passed = atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed);
3193
3218
 
3194
- atomic_int * n_barrier = &threadpool->n_barrier;
3195
- atomic_int * n_barrier_passed = &threadpool->n_barrier_passed;
3219
+ // enter barrier (full seq-cst fence)
3220
+ int n_barrier = atomic_fetch_add_explicit(&tp->n_barrier, 1, memory_order_seq_cst);
3196
3221
 
3197
- int n_threads = threadpool->n_threads_cur;
3198
- int passed_old = atomic_load_explicit(n_barrier_passed, memory_order_relaxed);
3199
-
3200
- if (atomic_fetch_add(n_barrier, 1) == n_threads - 1) {
3222
+ if (n_barrier == (n_threads - 1)) {
3201
3223
  // last thread
3202
- atomic_store(n_barrier, 0);
3203
- atomic_fetch_add_explicit(n_barrier_passed, 1, memory_order_relaxed);
3204
- } else {
3205
- // wait for other threads
3206
- while (true) {
3207
- if (atomic_load_explicit(n_barrier_passed, memory_order_relaxed) != passed_old) {
3208
- return;
3209
- }
3210
- lm_ggml_thread_cpu_relax();
3211
- }
3224
+ atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed);
3225
+
3226
+ // exit barrier (full seq-cst fence)
3227
+ atomic_fetch_add_explicit(&tp->n_barrier_passed, 1, memory_order_seq_cst);
3228
+ return;
3212
3229
  }
3213
- }
3230
+
3231
+ // wait for other threads
3232
+ while (atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed) == n_passed) {
3233
+ lm_ggml_thread_cpu_relax();
3234
+ }
3235
+
3236
+ // exit barrier (full seq-cst fence)
3237
+ // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead
3238
+ #ifdef LM_GGML_TSAN_ENABLED
3239
+ atomic_fetch_add_explicit(&tp->n_barrier_passed, 0, memory_order_seq_cst);
3240
+ #else
3241
+ atomic_thread_fence(memory_order_seq_cst);
3242
+ #endif
3214
3243
  #endif
3244
+ }
3215
3245
 
3216
3246
  // TODO: make this somehow automatically executed
3217
3247
  // some sort of "sentry" mechanism
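The rewritten lm_ggml_barrier above folds the OpenMP and non-OpenMP paths into one function, reads n_threads_cur atomically, and uses sequentially consistent operations on barrier entry and exit. A minimal standalone sketch of the same counting-barrier pattern, using plain C11 atomics and hypothetical names (the real code adds the TSAN fallback and the OpenMP branch shown in the hunk), looks like this:

#include <stdatomic.h>

typedef struct {
    atomic_int n_arrived;  // threads that have reached the barrier
    atomic_int n_passed;   // number of completed barrier generations
    int        n_threads;
} barrier_t;

static void barrier_wait(barrier_t * b) {
    int passed_old = atomic_load_explicit(&b->n_passed, memory_order_relaxed);

    if (atomic_fetch_add(&b->n_arrived, 1) == b->n_threads - 1) {
        // last thread to arrive: reset the arrival count, then open the next generation
        atomic_store_explicit(&b->n_arrived, 0, memory_order_relaxed);
        atomic_fetch_add(&b->n_passed, 1);  // seq-cst RMW also publishes the reset
        return;
    }

    // everyone else spins until the generation counter changes
    while (atomic_load_explicit(&b->n_passed, memory_order_relaxed) == passed_old) {
        // cpu_relax()/pause would go here
    }
    atomic_thread_fence(memory_order_seq_cst); // pairs with the seq-cst RMW above
}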
@@ -4097,7 +4127,11 @@ static void lm_ggml_set_op_params_f32(struct lm_ggml_tensor * tensor, uint32_t i
4097
4127
  }
4098
4128
 
4099
4129
  struct lm_ggml_tensor * lm_ggml_set_zero(struct lm_ggml_tensor * tensor) {
4100
- memset(tensor->data, 0, lm_ggml_nbytes(tensor));
4130
+ if (tensor->buffer) {
4131
+ lm_ggml_backend_tensor_memset(tensor, 0, 0, lm_ggml_nbytes(tensor));
4132
+ } else {
4133
+ memset(tensor->data, 0, lm_ggml_nbytes(tensor));
4134
+ }
4101
4135
  return tensor;
4102
4136
  }
4103
4137
 
@@ -8323,11 +8357,46 @@ struct lm_ggml_tensor * lm_ggml_cross_entropy_loss_back(
8323
8357
  return result;
8324
8358
  }
8325
8359
 
8326
- ////////////////////////////////////////////////////////////////////////////////
8360
+ // opt_step_adamw
8327
8361
 
8328
- void lm_ggml_set_param(
8362
+ struct lm_ggml_tensor * lm_ggml_opt_step_adamw(
8329
8363
  struct lm_ggml_context * ctx,
8330
- struct lm_ggml_tensor * tensor) {
8364
+ struct lm_ggml_tensor * a,
8365
+ float alpha,
8366
+ float beta1,
8367
+ float beta2,
8368
+ float eps,
8369
+ float wd) {
8370
+ LM_GGML_ASSERT(a->grad);
8371
+ LM_GGML_ASSERT(alpha > 0.0f);
8372
+ LM_GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
8373
+ LM_GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
8374
+ LM_GGML_ASSERT(eps >= 0.0f);
8375
+ LM_GGML_ASSERT(wd >= 0.0f && wd <= 1.0f);
8376
+
8377
+ struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a);
8378
+
8379
+ result->op = LM_GGML_OP_OPT_STEP_ADAMW;
8380
+ result->grad = NULL;
8381
+ result->src[0] = a;
8382
+ result->src[1] = a->grad;
8383
+ result->src[2] = lm_ggml_dup_tensor(ctx, a->grad);
8384
+ result->src[3] = lm_ggml_dup_tensor(ctx, a->grad);
8385
+
8386
+ const int64_t iter = 1;
8387
+ memcpy(&result->op_params[0], &iter, sizeof(int64_t));
8388
+ lm_ggml_set_op_params_f32(result, 2, alpha);
8389
+ lm_ggml_set_op_params_f32(result, 3, beta1);
8390
+ lm_ggml_set_op_params_f32(result, 4, beta2);
8391
+ lm_ggml_set_op_params_f32(result, 5, eps);
8392
+ lm_ggml_set_op_params_f32(result, 6, wd);
8393
+
8394
+ return result;
8395
+ }
8396
+
8397
+ ////////////////////////////////////////////////////////////////////////////////
8398
+
8399
+ void lm_ggml_set_param(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor) {
8331
8400
  tensor->flags |= LM_GGML_TENSOR_FLAG_PARAM;
8332
8401
 
8333
8402
  LM_GGML_ASSERT(tensor->grad == NULL);
@@ -8335,6 +8404,13 @@ void lm_ggml_set_param(
8335
8404
  lm_ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
8336
8405
  }
8337
8406
 
8407
+ void lm_ggml_set_loss(struct lm_ggml_tensor * tensor) {
8408
+ LM_GGML_ASSERT(lm_ggml_is_scalar(tensor));
8409
+ LM_GGML_ASSERT(tensor->type == LM_GGML_TYPE_F32);
8410
+ LM_GGML_ASSERT(tensor->grad);
8411
+ tensor->flags |= LM_GGML_TENSOR_FLAG_LOSS;
8412
+ }
8413
+
8338
8414
  // lm_ggml_compute_forward_dup
8339
8415
 
8340
8416
  static void lm_ggml_compute_forward_dup_same_cont(
@@ -17409,7 +17485,7 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back_f32(
17409
17485
  const int64_t ir0 = dr*ith;
17410
17486
  const int64_t ir1 = MIN(ir0 + dr, nr);
17411
17487
 
17412
- float * d = (float *) opt0->data;
17488
+ const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
17413
17489
 
17414
17490
  for (int64_t i1 = ir0; i1 < ir1; i1++) {
17415
17491
  float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
@@ -17433,7 +17509,7 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back_f32(
17433
17509
 
17434
17510
  // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
17435
17511
  lm_ggml_vec_sub_f32(nc, ds0, ds0, s1);
17436
- lm_ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr);
17512
+ lm_ggml_vec_scale_f32(nc, ds0, d_by_nr);
17437
17513
 
17438
17514
  #ifndef NDEBUG
17439
17515
  for (int i = 0; i < nc; ++i) {
@@ -17462,6 +17538,94 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back(
17462
17538
  }
17463
17539
  }
17464
17540
 
17541
+ static void lm_ggml_compute_forward_opt_step_adamw_f32(
17542
+ const struct lm_ggml_compute_params * params,
17543
+ struct lm_ggml_tensor * dst) {
17544
+
17545
+ const struct lm_ggml_tensor * src0 = dst->src[0];
17546
+ const struct lm_ggml_tensor * src0_grad = dst->src[1];
17547
+ const struct lm_ggml_tensor * src0_grad_m = dst->src[2];
17548
+ const struct lm_ggml_tensor * src0_grad_v = dst->src[3];
17549
+ LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src0_grad));
17550
+
17551
+ const int ith = params->ith;
17552
+ const int nth = params->nth;
17553
+
17554
+ const int nr = lm_ggml_nrows(src0);
17555
+
17556
+ LM_GGML_TENSOR_UNARY_OP_LOCALS
17557
+ LM_GGML_ASSERT(nb00 == sizeof(float));
17558
+
17559
+ // rows per thread
17560
+ const int dr = (nr + nth - 1)/nth;
17561
+
17562
+ // row range for this thread
17563
+ const int ir0 = dr*ith;
17564
+ const int ir1 = MIN(ir0 + dr, nr);
17565
+
17566
+ /* const float gnorm = 1.0f; */
17567
+ int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t));
17568
+ const float alpha = lm_ggml_get_op_params_f32(dst, 2);
17569
+ const float beta1 = lm_ggml_get_op_params_f32(dst, 3);
17570
+ const float beta2 = lm_ggml_get_op_params_f32(dst, 4);
17571
+ const float eps = lm_ggml_get_op_params_f32(dst, 5);
17572
+ const float wd = lm_ggml_get_op_params_f32(dst, 6);
17573
+
17574
+ const float beta1h = alpha/(1.0f - powf(beta1, iter));
17575
+ const float beta2h = 1.0f/(1.0f - powf(beta2, iter));
17576
+
17577
+ for (int ir = ir0; ir < ir1; ++ir) {
17578
+ const int64_t i03 = ir/(ne02*ne01);
17579
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
17580
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
17581
+
17582
+ const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
17583
+
17584
+ float * w = (float *) ((char *) src0->data + offset); // weight
17585
+ const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
17586
+ float * m = (float *) ((char *) src0_grad_m->data + offset);
17587
+ float * v = (float *) ((char *) src0_grad_v->data + offset);
17588
+
17589
+ for (int i00 = 0; i00 < ne00; ++i00) {
17590
+ m[i00] = m[i00]*beta1 + g[i00]*(1.0f - beta1);
17591
+ v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
17592
+
17593
+ const float mh = m[i00]*beta1h;
17594
+ const float vh = sqrtf(v[i00]*beta2h) + eps;
17595
+
17596
+ // The weight decay is applied independently of the Adam momenta m and v.
17597
+ // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
17598
+ // See: https://arxiv.org/pdf/1711.05101v3.pdf
17599
+ w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh;
17600
+ }
17601
+ }
17602
+
17603
+ lm_ggml_barrier(params->threadpool);
17604
+ if (ith != 0) {
17605
+ return;
17606
+ }
17607
+
17608
+ iter++;
17609
+ memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
17610
+ }
17611
+
17612
+ static void lm_ggml_compute_forward_opt_step_adamw(
17613
+ const struct lm_ggml_compute_params * params,
17614
+ struct lm_ggml_tensor * dst) {
17615
+
17616
+ const struct lm_ggml_tensor * src0 = dst->src[0];
17617
+
17618
+ switch (src0->type) {
17619
+ case LM_GGML_TYPE_F32:
17620
+ {
17621
+ lm_ggml_compute_forward_opt_step_adamw_f32(params, dst);
17622
+ } break;
17623
+ default:
17624
+ {
17625
+ LM_GGML_ABORT("fatal error");
17626
+ }
17627
+ }
17628
+ }
17465
17629
  /////////////////////////////////
17466
17630
 
17467
17631
  static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * tensor) {
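For reference, the inner loop of lm_ggml_compute_forward_opt_step_adamw_f32 above implements the standard AdamW update with decoupled weight decay. In the notation of the paper linked in the comment (t = iter, g = gradient, w = weight):

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
\hat{m}_t = m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t)
w_t = (1 - \alpha \cdot wd)\, w_{t-1} - \alpha\, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)

The code folds the learning rate and the bias corrections into beta1h = alpha/(1 - beta1^t) and beta2h = 1/(1 - beta2^t), so the update reduces to w -= mh/vh with mh = m*beta1h and vh = sqrtf(v*beta2h) + eps; after the barrier, thread 0 alone increments iter so that every thread sees the same t on the next step.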
@@ -17807,6 +17971,11 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru
17807
17971
  lm_ggml_compute_forward_cross_entropy_loss_back(params, tensor);
17808
17972
  }
17809
17973
  break;
17974
+ case LM_GGML_OP_OPT_STEP_ADAMW:
17975
+ {
17976
+ lm_ggml_compute_forward_opt_step_adamw(params, tensor);
17977
+ }
17978
+ break;
17810
17979
  case LM_GGML_OP_NONE:
17811
17980
  {
17812
17981
  // nop
@@ -17961,7 +18130,7 @@ void lm_ggml_build_backward_gradient_checkpointing(
17961
18130
  struct lm_ggml_tensor * * checkpoints,
17962
18131
  int n_checkpoints) {
17963
18132
  lm_ggml_graph_cpy(gf, gb_tmp);
17964
- lm_ggml_build_backward_expand(ctx, gf, gb_tmp, true);
18133
+ lm_ggml_build_backward_expand(ctx, gf, gb_tmp, false, true);
17965
18134
 
17966
18135
  if (n_checkpoints <= 0) {
17967
18136
  lm_ggml_graph_cpy(gb_tmp, gb);
@@ -17999,42 +18168,93 @@ void lm_ggml_build_backward_gradient_checkpointing(
17999
18168
  lm_ggml_hash_map_free(replacements);
18000
18169
  }
18001
18170
 
18002
- // functions to change gradients considering the case that input a might be initial gradient with zero value
18003
-
18004
- static struct lm_ggml_tensor * lm_ggml_add_or_set(struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, struct lm_ggml_hash_set * zero_table) {
18171
+ // utility functions to change gradients
18172
+ // if a is in acc_table, modify gradients in-place and mark result as gradient accumulator
18173
+ // else if a is in zero_table, replace a
18174
+ // else, just add/subtract/etc. the gradients
18175
+
18176
+ static struct lm_ggml_tensor * lm_ggml_add_or_set(
18177
+ struct lm_ggml_context * ctx,
18178
+ struct lm_ggml_tensor * a,
18179
+ struct lm_ggml_tensor * b,
18180
+ struct lm_ggml_hash_set * zero_table,
18181
+ struct lm_ggml_hash_set * acc_table) {
18182
+ if (lm_ggml_hash_contains(acc_table, a)) {
18183
+ struct lm_ggml_tensor * ret = lm_ggml_add_impl(ctx, a, b, true);
18184
+ const size_t insert_result = lm_ggml_hash_insert(acc_table, ret);
18185
+ LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL);
18186
+ LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS);
18187
+ return ret;
18188
+ }
18005
18189
  if (lm_ggml_hash_contains(zero_table, a)) {
18006
18190
  return b;
18007
- } else {
18008
- return lm_ggml_add_impl(ctx, a, b, false);
18009
18191
  }
18192
+ return lm_ggml_add_impl(ctx, a, b, false);
18010
18193
  }
18011
18194
 
18012
- static struct lm_ggml_tensor * lm_ggml_acc_or_set(struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct lm_ggml_hash_set * zero_table) {
18195
+ static struct lm_ggml_tensor * lm_ggml_acc_or_set(
18196
+ struct lm_ggml_context * ctx,
18197
+ struct lm_ggml_tensor * a,
18198
+ struct lm_ggml_tensor * b,
18199
+ const size_t nb1,
18200
+ const size_t nb2,
18201
+ const size_t nb3,
18202
+ const size_t offset,
18203
+ struct lm_ggml_hash_set * zero_table,
18204
+ struct lm_ggml_hash_set * acc_table) {
18205
+ if (lm_ggml_hash_contains(acc_table, a)) {
18206
+ struct lm_ggml_tensor * ret = lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
18207
+ const size_t insert_result = lm_ggml_hash_insert(acc_table, ret);
18208
+ LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL);
18209
+ LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS);
18210
+ return ret;
18211
+ }
18013
18212
  if (lm_ggml_hash_contains(zero_table, a)) {
18014
- struct lm_ggml_tensor * a_zero = lm_ggml_scale(ctx, a, 0.0f);
18213
+ struct lm_ggml_tensor * a_zero = lm_ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
18015
18214
  return lm_ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
18016
- } else {
18017
- return lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
18018
18215
  }
18216
+ return lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
18019
18217
  }
18020
18218
 
18021
- static struct lm_ggml_tensor * lm_ggml_add1_or_set(struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, struct lm_ggml_hash_set * zero_table) {
18219
+ static struct lm_ggml_tensor * lm_ggml_add1_or_set(
18220
+ struct lm_ggml_context * ctx,
18221
+ struct lm_ggml_tensor * a,
18222
+ struct lm_ggml_tensor * b,
18223
+ struct lm_ggml_hash_set * zero_table,
18224
+ struct lm_ggml_hash_set * acc_table) {
18225
+ if (lm_ggml_hash_contains(acc_table, a)) {
18226
+ struct lm_ggml_tensor * ret = lm_ggml_add1_impl(ctx, a, b, true);
18227
+ const size_t insert_result = lm_ggml_hash_insert(acc_table, ret);
18228
+ LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL);
18229
+ LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS);
18230
+ return ret;
18231
+ }
18022
18232
  if (lm_ggml_hash_contains(zero_table, a)) {
18023
18233
  return lm_ggml_repeat(ctx, b, a);
18024
- } else {
18025
- return lm_ggml_add1_impl(ctx, a, b, false);
18026
18234
  }
18235
+ return lm_ggml_add1_impl(ctx, a, b, false);
18027
18236
  }
18028
18237
 
18029
- static struct lm_ggml_tensor * lm_ggml_sub_or_set(struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, struct lm_ggml_hash_set * zero_table) {
18238
+ static struct lm_ggml_tensor * lm_ggml_sub_or_set(
18239
+ struct lm_ggml_context * ctx,
18240
+ struct lm_ggml_tensor * a,
18241
+ struct lm_ggml_tensor * b,
18242
+ struct lm_ggml_hash_set * zero_table,
18243
+ struct lm_ggml_hash_set * acc_table) {
18244
+ if (lm_ggml_hash_contains(acc_table, a)) {
18245
+ struct lm_ggml_tensor * ret = lm_ggml_sub_impl(ctx, a, b, true);
18246
+ const size_t insert_result = lm_ggml_hash_insert(acc_table, ret);
18247
+ LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL);
18248
+ LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS);
18249
+ return ret;
18250
+ }
18030
18251
  if (lm_ggml_hash_contains(zero_table, a)) {
18031
18252
  return lm_ggml_neg(ctx, b);
18032
- } else {
18033
- return lm_ggml_sub_impl(ctx, a, b, false);
18034
18253
  }
18254
+ return lm_ggml_sub_impl(ctx, a, b, false);
18035
18255
  }
18036
18256
 
18037
- static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor, struct lm_ggml_hash_set * zero_table) {
18257
+ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor, struct lm_ggml_hash_set * zero_table, struct lm_ggml_hash_set * acc_table) {
18038
18258
  struct lm_ggml_tensor * src0 = tensor->src[0];
18039
18259
  struct lm_ggml_tensor * src1 = tensor->src[1];
18040
18260
  struct lm_ggml_tensor * src2 = tensor->src[2];
@@ -18043,38 +18263,38 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18043
18263
  case LM_GGML_OP_DUP:
18044
18264
  {
18045
18265
  if (src0->grad) {
18046
- src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
18266
+ src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
18047
18267
  }
18048
18268
  } break;
18049
18269
  case LM_GGML_OP_ADD:
18050
18270
  {
18051
18271
  if (src0->grad) {
18052
- src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
18272
+ src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
18053
18273
  }
18054
18274
  if (src1->grad) {
18055
18275
  if (lm_ggml_are_same_shape(src0, src1)) {
18056
- src1->grad = lm_ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
18276
+ src1->grad = lm_ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
18057
18277
  } else {
18058
- src1->grad = lm_ggml_add_or_set(ctx, src1->grad, lm_ggml_repeat_back(ctx, tensor->grad, src1), zero_table);
18278
+ src1->grad = lm_ggml_add_or_set(ctx, src1->grad, lm_ggml_repeat_back(ctx, tensor->grad, src1), zero_table, acc_table);
18059
18279
  }
18060
18280
  }
18061
18281
  } break;
18062
18282
  case LM_GGML_OP_ADD1:
18063
18283
  {
18064
18284
  if (src0->grad) {
18065
- src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
18285
+ src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
18066
18286
  }
18067
18287
  if (src1->grad) {
18068
18288
  src1->grad = lm_ggml_add_or_set(ctx,
18069
18289
  src1->grad,
18070
18290
  lm_ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
18071
- zero_table);
18291
+ zero_table, acc_table);
18072
18292
  }
18073
18293
  } break;
18074
18294
  case LM_GGML_OP_ACC:
18075
18295
  {
18076
18296
  if (src0->grad) {
18077
- src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
18297
+ src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
18078
18298
  }
18079
18299
  if (src1->grad) {
18080
18300
  const size_t nb1 = ((int32_t *) tensor->op_params)[0];
@@ -18096,16 +18316,16 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18096
18316
  lm_ggml_reshape(ctx,
18097
18317
  lm_ggml_cont(ctx, tensor_grad_view),
18098
18318
  src1->grad),
18099
- zero_table);
18319
+ zero_table, acc_table);
18100
18320
  }
18101
18321
  } break;
18102
18322
  case LM_GGML_OP_SUB:
18103
18323
  {
18104
18324
  if (src0->grad) {
18105
- src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
18325
+ src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
18106
18326
  }
18107
18327
  if (src1->grad) {
18108
- src1->grad = lm_ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
18328
+ src1->grad = lm_ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
18109
18329
  }
18110
18330
  } break;
18111
18331
  case LM_GGML_OP_MUL:
@@ -18115,14 +18335,14 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18115
18335
  lm_ggml_add_or_set(ctx,
18116
18336
  src0->grad,
18117
18337
  lm_ggml_mul(ctx, src1, tensor->grad),
18118
- zero_table);
18338
+ zero_table, acc_table);
18119
18339
  }
18120
18340
  if (src1->grad) {
18121
18341
  src1->grad =
18122
18342
  lm_ggml_add_or_set(ctx,
18123
18343
  src1->grad,
18124
18344
  lm_ggml_mul(ctx, src0, tensor->grad),
18125
- zero_table);
18345
+ zero_table, acc_table);
18126
18346
  }
18127
18347
  } break;
18128
18348
  case LM_GGML_OP_DIV:
@@ -18132,7 +18352,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18132
18352
  lm_ggml_add_or_set(ctx,
18133
18353
  src0->grad,
18134
18354
  lm_ggml_div(ctx, tensor->grad, src1),
18135
- zero_table);
18355
+ zero_table, acc_table);
18136
18356
  }
18137
18357
  if (src1->grad) {
18138
18358
  src1->grad =
@@ -18141,7 +18361,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18141
18361
  lm_ggml_mul(ctx,
18142
18362
  tensor->grad,
18143
18363
  lm_ggml_div(ctx, tensor, src1)),
18144
- zero_table);
18364
+ zero_table, acc_table);
18145
18365
  }
18146
18366
  } break;
18147
18367
  case LM_GGML_OP_SQR:
@@ -18153,7 +18373,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18153
18373
  lm_ggml_scale(ctx,
18154
18374
  lm_ggml_mul(ctx, src0, tensor->grad),
18155
18375
  2.0f),
18156
- zero_table);
18376
+ zero_table, acc_table);
18157
18377
  }
18158
18378
  } break;
18159
18379
  case LM_GGML_OP_SQRT:
@@ -18167,7 +18387,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18167
18387
  tensor->grad,
18168
18388
  tensor),
18169
18389
  0.5f),
18170
- zero_table);
18390
+ zero_table, acc_table);
18171
18391
  }
18172
18392
  } break;
18173
18393
  case LM_GGML_OP_LOG:
@@ -18179,7 +18399,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18179
18399
  lm_ggml_div(ctx,
18180
18400
  tensor->grad,
18181
18401
  src0),
18182
- zero_table);
18402
+ zero_table, acc_table);
18183
18403
  }
18184
18404
  } break;
18185
18405
  case LM_GGML_OP_SIN:
@@ -18191,7 +18411,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18191
18411
  lm_ggml_mul(ctx,
18192
18412
  tensor->grad,
18193
18413
  lm_ggml_cos(ctx, src0)),
18194
- zero_table);
18414
+ zero_table, acc_table);
18195
18415
  }
18196
18416
  } break;
18197
18417
  case LM_GGML_OP_COS:
@@ -18203,7 +18423,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18203
18423
  lm_ggml_mul(ctx,
18204
18424
  tensor->grad,
18205
18425
  lm_ggml_sin(ctx, src0)),
18206
- zero_table);
18426
+ zero_table, acc_table);
18207
18427
  }
18208
18428
  } break;
18209
18429
  case LM_GGML_OP_SUM:
@@ -18213,7 +18433,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18213
18433
  lm_ggml_add1_or_set(ctx,
18214
18434
  src0->grad,
18215
18435
  tensor->grad,
18216
- zero_table);
18436
+ zero_table, acc_table);
18217
18437
  }
18218
18438
  } break;
18219
18439
  case LM_GGML_OP_SUM_ROWS:
@@ -18225,7 +18445,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18225
18445
  lm_ggml_repeat(ctx,
18226
18446
  tensor->grad,
18227
18447
  src0->grad),
18228
- zero_table);
18448
+ zero_table, acc_table);
18229
18449
  }
18230
18450
  } break;
18231
18451
  case LM_GGML_OP_MEAN:
@@ -18240,7 +18460,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18240
18460
  src0->grad = lm_ggml_add_or_set(ctx,
18241
18461
  src0->grad,
18242
18462
  lm_ggml_repeat_back(ctx, tensor->grad, src0->grad),
18243
- zero_table);
18463
+ zero_table, acc_table);
18244
18464
  }
18245
18465
  } break;
18246
18466
  case LM_GGML_OP_REPEAT_BACK:
@@ -18250,7 +18470,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18250
18470
  src0->grad = lm_ggml_add_or_set(ctx,
18251
18471
  src0->grad,
18252
18472
  lm_ggml_repeat(ctx, tensor->grad, src0->grad),
18253
- zero_table);
18473
+ zero_table, acc_table);
18254
18474
  }
18255
18475
  } break;
18256
18476
  case LM_GGML_OP_CONCAT:
@@ -18275,7 +18495,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18275
18495
  src0->grad = lm_ggml_add_or_set(ctx,
18276
18496
  src0->grad,
18277
18497
  lm_ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
18278
- zero_table);
18498
+ zero_table, acc_table);
18279
18499
  }
18280
18500
  } break;
18281
18501
  case LM_GGML_OP_RMS_NORM_BACK:
@@ -18323,7 +18543,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18323
18543
  lm_ggml_add_or_set(ctx,
18324
18544
  src0->grad, // [n,m,q1,r1]
18325
18545
  s1_tg, // [n,m,q1,r1]
18326
- zero_table);
18546
+ zero_table, acc_table);
18327
18547
  }
18328
18548
  if (src1->grad) {
18329
18549
  src1->grad =
@@ -18341,7 +18561,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18341
18561
  src0, // [n,m,q1,r1]
18342
18562
  lm_ggml_transpose(ctx, // [p,m,qq,rr]
18343
18563
  tensor->grad)), // [m,p,qq,rr]
18344
- zero_table);
18564
+ zero_table, acc_table);
18345
18565
  }
18346
18566
  } break;
18347
18567
  case LM_GGML_OP_MUL_MAT_ID:
@@ -18363,7 +18583,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18363
18583
  lm_ggml_add_or_set(ctx,
18364
18584
  src0->grad,
18365
18585
  lm_ggml_scale_impl(ctx, tensor->grad, s, false),
18366
- zero_table);
18586
+ zero_table, acc_table);
18367
18587
  }
18368
18588
  } break;
18369
18589
  case LM_GGML_OP_SET:
@@ -18392,7 +18612,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18392
18612
  tensor->grad,
18393
18613
  lm_ggml_neg(ctx, tensor_grad_view),
18394
18614
  nb1, nb2, nb3, offset, false),
18395
- zero_table);
18615
+ zero_table, acc_table);
18396
18616
  }
18397
18617
 
18398
18618
  if (src1->grad) {
@@ -18402,7 +18622,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18402
18622
  lm_ggml_reshape(ctx,
18403
18623
  lm_ggml_cont(ctx, tensor_grad_view),
18404
18624
  src1->grad),
18405
- zero_table);
18625
+ zero_table, acc_table);
18406
18626
  }
18407
18627
  } break;
18408
18628
  case LM_GGML_OP_CPY:
@@ -18413,7 +18633,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18413
18633
  // tensor = src0 * 1 + src1 * 0
18414
18634
  if (src0->grad) {
18415
18635
  // dsrc0 = dtensor * 1
18416
- src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
18636
+ src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
18417
18637
  }
18418
18638
  if (src1->grad) {
18419
18639
  // dsrc1 = dtensor * 0 -> noop
@@ -18425,7 +18645,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18425
18645
  if (src0->grad) {
18426
18646
  LM_GGML_ASSERT(lm_ggml_is_contiguous(src0->grad));
18427
18647
  LM_GGML_ASSERT(lm_ggml_is_contiguous(tensor->grad));
18428
- src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
18648
+ src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
18429
18649
  }
18430
18650
  } break;
18431
18651
  case LM_GGML_OP_RESHAPE:
@@ -18439,7 +18659,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18439
18659
  ? tensor->grad
18440
18660
  : lm_ggml_cont(ctx, tensor->grad),
18441
18661
  src0->grad),
18442
- zero_table);
18662
+ zero_table, acc_table);
18443
18663
  }
18444
18664
  } break;
18445
18665
  case LM_GGML_OP_VIEW:
@@ -18468,7 +18688,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18468
18688
  nb3 = (nb3 / n0) * ng;
18469
18689
  }
18470
18690
 
18471
- src0->grad = lm_ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
18691
+ src0->grad = lm_ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table, acc_table);
18472
18692
  }
18473
18693
  } break;
18474
18694
  case LM_GGML_OP_PERMUTE:
@@ -18493,7 +18713,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18493
18713
  axes_backward[1],
18494
18714
  axes_backward[2],
18495
18715
  axes_backward[3]),
18496
- zero_table);
18716
+ zero_table, acc_table);
18497
18717
  }
18498
18718
  } break;
18499
18719
  case LM_GGML_OP_TRANSPOSE:
@@ -18503,7 +18723,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18503
18723
  src0->grad =
18504
18724
  lm_ggml_add_or_set(ctx, src0->grad,
18505
18725
  lm_ggml_transpose(ctx, tensor->grad),
18506
- zero_table);
18726
+ zero_table, acc_table);
18507
18727
  }
18508
18728
  } break;
18509
18729
  case LM_GGML_OP_GET_ROWS:
@@ -18515,7 +18735,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18515
18735
  // last lm_ggml_get_rows_back argument src0->grad is only
18516
18736
  // necessary to setup correct output shape
18517
18737
  lm_ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
18518
- zero_table);
18738
+ zero_table, acc_table);
18519
18739
  }
18520
18740
  if (src1->grad) {
18521
18741
  // noop
@@ -18539,7 +18759,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18539
18759
  /* lm_ggml_diag_mask_inf_impl() shouldn't be here */
18540
18760
  /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
18541
18761
  lm_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
18542
- zero_table);
18762
+ zero_table, acc_table);
18543
18763
  }
18544
18764
  } break;
18545
18765
  case LM_GGML_OP_DIAG_MASK_ZERO:
@@ -18550,7 +18770,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18550
18770
  src0->grad =
18551
18771
  lm_ggml_add_or_set(ctx, src0->grad,
18552
18772
  lm_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
18553
- zero_table);
18773
+ zero_table, acc_table);
18554
18774
  }
18555
18775
  } break;
18556
18776
  case LM_GGML_OP_SOFT_MAX:
@@ -18560,7 +18780,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18560
18780
  src0->grad =
18561
18781
  lm_ggml_add_or_set(ctx, src0->grad,
18562
18782
  lm_ggml_soft_max_back(ctx, tensor->grad, tensor),
18563
- zero_table);
18783
+ zero_table, acc_table);
18564
18784
  }
18565
18785
 
18566
18786
  } break;
@@ -18601,7 +18821,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18601
18821
  attn_factor,
18602
18822
  beta_fast,
18603
18823
  beta_slow),
18604
- zero_table);
18824
+ zero_table, acc_table);
18605
18825
  }
18606
18826
  } break;
18607
18827
  case LM_GGML_OP_ROPE_BACK:
@@ -18637,7 +18857,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18637
18857
  beta_fast,
18638
18858
  beta_slow,
18639
18859
  false),
18640
- zero_table);
18860
+ zero_table, acc_table);
18641
18861
  }
18642
18862
  } break;
18643
18863
  case LM_GGML_OP_CLAMP:
@@ -18662,7 +18882,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18662
18882
  src1->grad = lm_ggml_add_or_set(ctx,
18663
18883
  src1->grad,
18664
18884
  lm_ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D),
18665
- zero_table);
18885
+ zero_table, acc_table);
18666
18886
  }
18667
18887
  } break;
18668
18888
  case LM_GGML_OP_IM2COL_BACK:
@@ -18691,7 +18911,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18691
18911
  src0->grad = lm_ggml_add_or_set(ctx,
18692
18912
  src0->grad,
18693
18913
  lm_ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1),
18694
- zero_table);
18914
+ zero_table, acc_table);
18695
18915
  }
18696
18916
  } break;
18697
18917
  case LM_GGML_OP_POOL_2D_BACK:
@@ -18756,7 +18976,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18756
18976
  src0->grad = lm_ggml_add_or_set(ctx,
18757
18977
  src0->grad,
18758
18978
  grad_q,
18759
- zero_table);
18979
+ zero_table, acc_table);
18760
18980
  }
18761
18981
  if (src1->grad) {
18762
18982
  struct lm_ggml_tensor * view_k = lm_ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
@@ -18764,7 +18984,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18764
18984
  src1->grad = lm_ggml_add_or_set(ctx,
18765
18985
  src1->grad,
18766
18986
  grad_k,
18767
- zero_table);
18987
+ zero_table, acc_table);
18768
18988
  }
18769
18989
  if (src2->grad) {
18770
18990
  struct lm_ggml_tensor * view_v = lm_ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
@@ -18772,7 +18992,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18772
18992
  src2->grad = lm_ggml_add_or_set(ctx,
18773
18993
  src2->grad,
18774
18994
  grad_v,
18775
- zero_table);
18995
+ zero_table, acc_table);
18776
18996
  }
18777
18997
  } break;
18778
18998
  case LM_GGML_OP_FLASH_ATTN_BACK:
@@ -18798,7 +19018,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18798
19018
  lm_ggml_mul(ctx,
18799
19019
  lm_ggml_sgn(ctx, src0),
18800
19020
  tensor->grad),
18801
- zero_table);
19021
+ zero_table, acc_table);
18802
19022
  }
18803
19023
  } break;
18804
19024
  case LM_GGML_UNARY_OP_SGN:
@@ -18810,7 +19030,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18810
19030
  case LM_GGML_UNARY_OP_NEG:
18811
19031
  {
18812
19032
  if (src0->grad) {
18813
- src0->grad = lm_ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
19033
+ src0->grad = lm_ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
18814
19034
  }
18815
19035
  } break;
18816
19036
  case LM_GGML_UNARY_OP_STEP:
@@ -18835,7 +19055,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18835
19055
  lm_ggml_mul(ctx,
18836
19056
  lm_ggml_step(ctx, src0),
18837
19057
  tensor->grad),
18838
- zero_table);
19058
+ zero_table, acc_table);
18839
19059
  }
18840
19060
  } break;
18841
19061
  case LM_GGML_UNARY_OP_SIGMOID:
@@ -18857,7 +19077,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18857
19077
  src0->grad = lm_ggml_add_or_set(ctx,
18858
19078
  src0->grad,
18859
19079
  lm_ggml_silu_back(ctx, src0, tensor->grad),
18860
- zero_table);
19080
+ zero_table, acc_table);
18861
19081
  }
18862
19082
  } break;
18863
19083
  case LM_GGML_UNARY_OP_EXP:
@@ -18866,7 +19086,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18866
19086
  src0->grad = lm_ggml_add_or_set(ctx,
18867
19087
  src0->grad,
18868
19088
  lm_ggml_mul(ctx, tensor, tensor->grad),
18869
- zero_table);
19089
+ zero_table, acc_table);
18870
19090
  }
18871
19091
  } break;
18872
19092
  default:
@@ -18896,13 +19116,17 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18896
19116
  src0,
18897
19117
  src1,
18898
19118
  tensor->grad),
18899
- zero_table);
19119
+ zero_table, acc_table);
18900
19120
  }
18901
19121
  } break;
18902
19122
  case LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK:
18903
19123
  {
18904
19124
  LM_GGML_ABORT("fatal error"); // not supported
18905
19125
  }
19126
+ case LM_GGML_OP_OPT_STEP_ADAMW:
19127
+ {
19128
+ LM_GGML_ABORT("fatal error"); // not supported
19129
+ }
18906
19130
  case LM_GGML_OP_NONE:
18907
19131
  {
18908
19132
  // nop
@@ -18992,7 +19216,7 @@ void lm_ggml_build_forward_expand(struct lm_ggml_cgraph * cgraph, struct lm_ggml
18992
19216
  lm_ggml_build_forward_impl(cgraph, tensor, true);
18993
19217
  }
18994
19218
 
18995
- void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * gf, struct lm_ggml_cgraph * gb, bool keep) {
19219
+ void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * gf, struct lm_ggml_cgraph * gb, bool accumulate, bool keep) {
18996
19220
  LM_GGML_ASSERT(gf->n_nodes > 0);
18997
19221
  LM_GGML_ASSERT(gf->grads);
18998
19222
 
@@ -19008,21 +19232,35 @@ void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_
19008
19232
  }
19009
19233
  }
19010
19234
 
19011
- // remember original gradients which start with zero values
19235
+ // keep tables of original gradients for replacement/accumulation logic
19012
19236
  struct lm_ggml_hash_set zero_table = lm_ggml_hash_set_new(gf->size);
19237
+ struct lm_ggml_hash_set acc_table = lm_ggml_hash_set_new(gf->size);
19013
19238
  for (int i = 0; i < gf->n_nodes; i++) {
19014
- if (gf->grads[i]) {
19015
- lm_ggml_hash_insert(&zero_table, gf->grads[i]);
19239
+ struct lm_ggml_tensor * node = gf->nodes[i];
19240
+
19241
+ if (node->grad) {
19242
+ {
19243
+ const size_t insert_result = lm_ggml_hash_insert(&zero_table, node->grad);
19244
+ LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL);
19245
+ LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS);
19246
+ }
19247
+
19248
+ // only gradients of trainable parameters should be accumulated
19249
+ if (accumulate && (node->flags & LM_GGML_TENSOR_FLAG_PARAM)) {
19250
+ const size_t insert_result = lm_ggml_hash_insert(&acc_table, node->grad);
19251
+ LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL);
19252
+ LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS);
19253
+ }
19016
19254
  }
19017
19255
  }
19018
19256
 
19019
19257
  for (int i = gf->n_nodes - 1; i >= 0; i--) {
19020
19258
  struct lm_ggml_tensor * node = gf->nodes[i];
19021
19259
 
19022
- // inplace operations to add gradients are not created by lm_ggml_compute_backward
19260
+ // inplace operations to add gradients are not created by lm_ggml_compute_backward except for gradient accumulation
19023
19261
  // use allocator to automatically make inplace operations
19024
19262
  if (node->grad) {
19025
- lm_ggml_compute_backward(ctx, node, &zero_table);
19263
+ lm_ggml_compute_backward(ctx, node, &zero_table, &acc_table);
19026
19264
  }
19027
19265
  }
19028
19266
 
@@ -19036,8 +19274,30 @@ void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_
19036
19274
  }
19037
19275
 
19038
19276
  lm_ggml_hash_set_free(&zero_table);
19277
+ lm_ggml_hash_set_free(&acc_table);
19278
+ }
19279
+
19280
+ void lm_ggml_build_opt_adamw(
19281
+ struct lm_ggml_context * ctx,
19282
+ struct lm_ggml_cgraph * gf,
19283
+ struct lm_ggml_cgraph * gb,
19284
+ float alpha,
19285
+ float beta1,
19286
+ float beta2,
19287
+ float eps,
19288
+ float wd) {
19289
+ for (int i = 0; i < gf->n_nodes; i++) {
19290
+ struct lm_ggml_tensor * node = gf->nodes[i];
19291
+
19292
+ if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) {
19293
+ LM_GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
19294
+ struct lm_ggml_tensor * opt_step = lm_ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd);
19295
+ lm_ggml_build_forward_expand(gb, opt_step);
19296
+ }
19297
+ }
19039
19298
  }
19040
19299
 
19300
+
19041
19301
  static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
19042
19302
  void * ptr = *p;
19043
19303
  ptr = (void *) LM_GGML_PAD((uintptr_t) ptr, align);
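Taken together with lm_ggml_set_param, lm_ggml_set_loss and the new accumulate flag on lm_ggml_build_backward_expand, the lm_ggml_build_opt_adamw helper above lets a full training step be expressed as a single graph. A rough sketch of how the pieces compose (hyperparameter values are illustrative; creation of ctx, the weights and the scalar F32 loss node is assumed to have happened elsewhere, with lm_ggml_set_param called on the weights before the forward ops were built and lm_ggml_set_loss called on the loss node):

// gf holds the forward graph ending in the scalar loss node
lm_ggml_build_forward_expand(gf, loss);

struct lm_ggml_cgraph * gb = lm_ggml_graph_dup(ctx, gf);
lm_ggml_build_backward_expand(ctx, gf, gb, /*accumulate =*/ false, /*keep =*/ true);
lm_ggml_build_opt_adamw(ctx, gf, gb, /*alpha =*/ 1e-3f, /*beta1 =*/ 0.9f,
                        /*beta2 =*/ 0.999f, /*eps =*/ 1e-8f, /*wd =*/ 0.0f);

lm_ggml_graph_reset(gb);  // zero the grads, set d(loss) = 1, reset the AdamW iteration and momenta
// evaluating gb now runs forward, backward and one AdamW step for every parameter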
@@ -19165,10 +19425,28 @@ void lm_ggml_graph_reset(struct lm_ggml_cgraph * cgraph) {
19165
19425
  LM_GGML_ASSERT(cgraph->grads != NULL);
19166
19426
 
19167
19427
  for (int i = 0; i < cgraph->n_nodes; i++) {
19168
- struct lm_ggml_tensor * grad = cgraph->grads[i];
19428
+ struct lm_ggml_tensor * node = cgraph->nodes[i];
19429
+
19430
+ // initial gradients of loss should be 1, 0 otherwise
19431
+ if (node->grad) {
19432
+ if (node->flags & LM_GGML_TENSOR_FLAG_LOSS) {
19433
+ LM_GGML_ASSERT(node->grad->buffer);
19434
+ LM_GGML_ASSERT(node->type == LM_GGML_TYPE_F32);
19435
+ LM_GGML_ASSERT(lm_ggml_is_scalar(node));
19436
+
19437
+ const float onef = 1.0f;
19438
+ lm_ggml_backend_tensor_set(node->grad, &onef, 0, lm_ggml_nbytes(node->grad));
19439
+ } else {
19440
+ lm_ggml_set_zero(node->grad);
19441
+ }
19442
+ }
19169
19443
 
19170
- if (grad) {
19171
- lm_ggml_set_zero(grad);
19444
+ LM_GGML_ASSERT(node);
19445
+ if (node->op == LM_GGML_OP_OPT_STEP_ADAMW) {
19446
+ // set iteration to 1 and clear momenta
19447
+ lm_ggml_set_op_params_i32(node, 0, 1);
19448
+ lm_ggml_set_zero(node->src[2]);
19449
+ lm_ggml_set_zero(node->src[3]);
19172
19450
  }
19173
19451
  }
19174
19452
  }
@@ -19461,6 +19739,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
19461
19739
  } break;
19462
19740
  case LM_GGML_OP_CROSS_ENTROPY_LOSS:
19463
19741
  case LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK:
19742
+ case LM_GGML_OP_OPT_STEP_ADAMW:
19464
19743
  {
19465
19744
  n_tasks = n_threads;
19466
19745
  } break;
@@ -19756,8 +20035,8 @@ void lm_ggml_threadpool_resume(struct lm_ggml_threadpool * threadpool) {
19756
20035
 
19757
20036
  struct lm_ggml_cplan lm_ggml_graph_plan(
19758
20037
  const struct lm_ggml_cgraph * cgraph,
19759
- int n_threads,
19760
- struct lm_ggml_threadpool * threadpool) {
20038
+ int n_threads,
20039
+ struct lm_ggml_threadpool * threadpool) {
19761
20040
 
19762
20041
  if (threadpool == NULL) {
19763
20042
  LM_GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads);
@@ -19932,34 +20211,33 @@ struct lm_ggml_cplan lm_ggml_graph_plan(
19932
20211
 
19933
20212
  static thread_ret_t lm_ggml_graph_compute_thread(void * data) {
19934
20213
  struct lm_ggml_compute_state * state = (struct lm_ggml_compute_state *) data;
20214
+ struct lm_ggml_threadpool * tp = state->threadpool;
19935
20215
 
19936
- const struct lm_ggml_cgraph * cgraph = state->threadpool->cgraph;
19937
- const struct lm_ggml_cplan * cplan = state->threadpool->cplan;
20216
+ const struct lm_ggml_cgraph * cgraph = tp->cgraph;
20217
+ const struct lm_ggml_cplan * cplan = tp->cplan;
19938
20218
 
19939
20219
  set_numa_thread_affinity(state->ith);
19940
20220
 
19941
20221
  struct lm_ggml_compute_params params = {
19942
20222
  /*.ith =*/ state->ith,
19943
- /*.nth =*/ state->threadpool->n_threads_cur,
20223
+ /*.nth =*/ atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed),
19944
20224
  /*.wsize =*/ cplan->work_size,
19945
20225
  /*.wdata =*/ cplan->work_data,
19946
- /*.threadpool=*/ state->threadpool,
20226
+ /*.threadpool=*/ tp,
19947
20227
  };
19948
20228
 
19949
- for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
20229
+ for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) {
19950
20230
  struct lm_ggml_tensor * node = cgraph->nodes[node_n];
19951
20231
 
19952
20232
  lm_ggml_compute_forward(&params, node);
19953
20233
 
19954
- if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
19955
- state->threadpool->ec = LM_GGML_STATUS_ABORTED;
20234
+ if (state->ith == 0 && cplan->abort_callback &&
20235
+ cplan->abort_callback(cplan->abort_callback_data)) {
20236
+ tp->abort = true;
20237
+ tp->ec = LM_GGML_STATUS_ABORTED;
19956
20238
  }
19957
20239
 
19958
20240
  lm_ggml_barrier(state->threadpool);
19959
-
19960
- if (state->threadpool->ec != LM_GGML_STATUS_SUCCESS) {
19961
- break;
19962
- }
19963
20241
  }
19964
20242
 
19965
20243
  return 0;
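The reworked compute loop above replaces the per-node check of threadpool->ec with a dedicated abort flag: thread 0 polls the user abort callback, sets tp->abort, and every worker exits at the top of the loop, so all threads still pass the same barriers before leaving. A stripped-down sketch of the pattern (hypothetical names, with a stub standing in for cplan->abort_callback):

#include <stdatomic.h>
#include <stdbool.h>

static atomic_bool abort_requested;

static bool abort_callback_stub(void) { return false; } // stands in for cplan->abort_callback

static void compute_nodes(int ith, int n_nodes) {
    for (int node = 0; node < n_nodes && !atomic_load(&abort_requested); node++) {
        // ... compute graph node `node` ...

        if (ith == 0 && abort_callback_stub()) {
            atomic_store(&abort_requested, true); // workers notice it on the next iteration
        }

        // a barrier here keeps all threads in lock-step on the same node
    }
}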
@@ -19967,7 +20245,15 @@ static thread_ret_t lm_ggml_graph_compute_thread(void * data) {
19967
20245
 
19968
20246
  #ifndef LM_GGML_USE_OPENMP
19969
20247
 
19970
- static inline bool lm_ggml_graph_compute_ready(struct lm_ggml_compute_state * state) {
20248
+ // check if thread is active
20249
+ static inline bool lm_ggml_graph_compute_thread_active(struct lm_ggml_compute_state * state) {
20250
+ struct lm_ggml_threadpool * threadpool = state->threadpool;
20251
+ int n_threads = atomic_load_explicit(&threadpool->n_threads_cur, memory_order_relaxed);
20252
+ return (state->ith < n_threads);
20253
+ }
20254
+
20255
+ // check if thread is ready to proceed (exit from polling or sleeping)
20256
+ static inline bool lm_ggml_graph_compute_thread_ready(struct lm_ggml_compute_state * state) {
19971
20257
  struct lm_ggml_threadpool * threadpool = state->threadpool;
19972
20258
 
19973
20259
  if (state->pending || threadpool->stop || threadpool->pause) { return true; }
@@ -19975,21 +20261,37 @@ static inline bool lm_ggml_graph_compute_ready(struct lm_ggml_compute_state * st
19975
20261
  // check for new graph/work
19976
20262
  int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed);
19977
20263
  if (new_graph != state->last_graph) {
19978
- state->pending = (state->ith < threadpool->n_threads_cur);
20264
+ state->pending = lm_ggml_graph_compute_thread_active(state);
19979
20265
  state->last_graph = new_graph;
19980
20266
  }
19981
20267
 
19982
20268
  return state->pending;
19983
20269
  }
19984
20270
 
20271
+ // sync thread state after polling
20272
+ static inline void lm_ggml_graph_compute_thread_sync(struct lm_ggml_compute_state * state) {
20273
+ // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead
20274
+ #ifdef LM_GGML_TSAN_ENABLED
20275
+ atomic_fetch_add_explicit(&state->threadpool->n_graph, 0, memory_order_seq_cst);
20276
+ #else
20277
+ atomic_thread_fence(memory_order_seq_cst);
20278
+ #endif
20279
+ UNUSED(state);
20280
+ }
20281
+
19985
20282
  static inline bool lm_ggml_graph_compute_poll_for_work(struct lm_ggml_compute_state * state) {
19986
20283
  struct lm_ggml_threadpool * threadpool = state->threadpool;
19987
20284
 
20285
+ // Skip polling for unused threads
20286
+ if (!lm_ggml_graph_compute_thread_active(state)) {
20287
+ return state->pending;
20288
+ }
20289
+
19988
20290
  // This seems to make 0 ... 100 a decent range for polling level across modern processors.
19989
20291
  // Perhaps, we can adjust it dynamically based on load and things.
19990
20292
  const uint64_t n_rounds = 1024UL * 128 * threadpool->poll;
19991
20293
 
19992
- for (uint64_t i=0; !lm_ggml_graph_compute_ready(state) && i<n_rounds; i++) {
20294
+ for (uint64_t i=0; !lm_ggml_graph_compute_thread_ready(state) && i < n_rounds; i++) {
19993
20295
  // No new work. Keep polling.
19994
20296
  lm_ggml_thread_cpu_relax();
19995
20297
  }
@@ -20001,13 +20303,14 @@ static inline bool lm_ggml_graph_compute_check_for_work(struct lm_ggml_compute_s
20001
20303
  struct lm_ggml_threadpool * threadpool = state->threadpool;
20002
20304
 
20003
20305
  if (lm_ggml_graph_compute_poll_for_work(state)) {
20306
+ lm_ggml_graph_compute_thread_sync(state);
20004
20307
  return state->pending;
20005
20308
  }
20006
20309
 
20007
20310
  lm_ggml_mutex_lock_shared(&threadpool->mutex);
20008
- while (!lm_ggml_graph_compute_ready(state)) {
20311
+ while (!lm_ggml_graph_compute_thread_ready(state)) {
20009
20312
  // No new work. Wait for the signal.
20010
- LM_GGML_PRINT_DEBUG("thread #%d waiting for work\n", state->ith);
20313
+ LM_GGML_PRINT_DEBUG("thread #%d waiting for work (sleeping)\n", state->ith);
20011
20314
  lm_ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
20012
20315
  }
20013
20316
  lm_ggml_mutex_unlock_shared(&threadpool->mutex);
@@ -20054,13 +20357,20 @@ static thread_ret_t lm_ggml_graph_compute_secondary_thread(void* data) {
20054
20357
  }
20055
20358
 
20056
20359
  // Start processing new graph
20057
- static void lm_ggml_graph_compute_kickoff(struct lm_ggml_threadpool * threadpool)
20360
+ static void lm_ggml_graph_compute_kickoff(struct lm_ggml_threadpool * threadpool, int n_threads)
20058
20361
  {
20059
- // always take the mutex here because the worker threads are doing hybrid poll/wait
20362
+ // Always take the mutex here because the worker threads are doing hybrid poll/wait
20060
20363
 
20061
20364
  lm_ggml_mutex_lock(&threadpool->mutex);
20062
20365
 
20063
- atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_relaxed);
20366
+ LM_GGML_PRINT_DEBUG("threadpool: n_threads_cur %d n_threads %d\n", threadpool->n_threads_cur, n_threads);
20367
+
20368
+ // Update the number of active threads
20369
+ atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);
20370
+
20371
+ // Indicate the graph is ready to be processed
20372
+ // We need the full seq-cst fence here because of the polling threads (used in thread_sync)
20373
+ atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_seq_cst);
20064
20374
 
20065
20375
  if (threadpool->pause) {
20066
20376
  // Update main thread prio and affinity to match the threadpool settings
@@ -20119,6 +20429,7 @@ static struct lm_ggml_threadpool * lm_ggml_threadpool_new_impl(
20119
20429
  threadpool->current_chunk = 0;
20120
20430
  threadpool->stop = false;
20121
20431
  threadpool->pause = tpp->paused;
20432
+ threadpool->abort = false;
20122
20433
  threadpool->workers = NULL;
20123
20434
  threadpool->n_threads_max = tpp->n_threads;
20124
20435
  threadpool->n_threads_cur = tpp->n_threads;
@@ -20194,15 +20505,11 @@ enum lm_ggml_status lm_ggml_graph_compute(struct lm_ggml_cgraph * cgraph, struct
20194
20505
  // No worker threads should be accessing the parameters below at this stage
20195
20506
  threadpool->cgraph = cgraph;
20196
20507
  threadpool->cplan = cplan;
20197
- threadpool->n_threads_cur = n_threads;
20198
20508
  threadpool->current_chunk = 0;
20509
+ threadpool->abort = false;
20199
20510
  threadpool->ec = LM_GGML_STATUS_SUCCESS;
20200
20511
  }
20201
20512
 
20202
- if (n_threads > threadpool->n_threads_max) {
20203
- LM_GGML_PRINT("WARNING: cplan is requesting more threads than the threadpool contains. Expect a bad time!\n");
20204
- }
20205
-
20206
20513
  #ifdef LM_GGML_USE_OPENMP
20207
20514
  if (n_threads > 1) {
20208
20515
  #pragma omp parallel num_threads(n_threads)
@@ -20211,17 +20518,23 @@ enum lm_ggml_status lm_ggml_graph_compute(struct lm_ggml_cgraph * cgraph, struct
20211
20518
  {
20212
20519
  // update the number of threads from the actual number of threads that we got from OpenMP
20213
20520
  n_threads = omp_get_num_threads();
20214
- threadpool->n_threads_cur = n_threads;
20521
+ atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);
20215
20522
  }
20216
20523
 
20217
20524
  lm_ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]);
20218
20525
  }
20219
20526
  } else {
20527
+ atomic_store_explicit(&threadpool->n_threads_cur, 1, memory_order_relaxed);
20220
20528
  lm_ggml_graph_compute_thread(&threadpool->workers[0]);
20221
20529
  }
20222
20530
  #else
20531
+ if (n_threads > threadpool->n_threads_max) {
20532
+ LM_GGML_PRINT("WARNING: cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads_max);
20533
+ n_threads = threadpool->n_threads_max;
20534
+ }
20535
+
20223
20536
  // Kick all threads to start the new graph
20224
- lm_ggml_graph_compute_kickoff(threadpool);
20537
+ lm_ggml_graph_compute_kickoff(threadpool, n_threads);
20225
20538
 
20226
20539
  // This is a work thread too
20227
20540
  lm_ggml_graph_compute_thread(&threadpool->workers[0]);
@@ -21823,7 +22136,7 @@ enum lm_ggml_opt_result lm_ggml_opt_resume(
21823
22136
  lm_ggml_build_forward_expand(gf, f);
21824
22137
 
21825
22138
  struct lm_ggml_cgraph * gb = lm_ggml_graph_dup(ctx, gf);
21826
- lm_ggml_build_backward_expand(ctx, gf, gb, true);
22139
+ lm_ggml_build_backward_expand(ctx, gf, gb, false, true);
21827
22140
 
21828
22141
  return lm_ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
21829
22142
  }