cui-llama.rn 1.1.6 → 1.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/java/com/rnllama/LlamaContext.java +11 -3
- package/android/src/main/jni.cpp +28 -4
- package/cpp/common.cpp +3 -0
- package/cpp/common.h +2 -0
- package/cpp/ggml-aarch64.c +1794 -1368
- package/cpp/ggml-alloc.c +6 -0
- package/cpp/ggml-backend-impl.h +10 -9
- package/cpp/ggml-backend.c +25 -0
- package/cpp/ggml-backend.h +2 -1
- package/cpp/ggml-cpu-impl.h +614 -0
- package/cpp/ggml-impl.h +13 -609
- package/cpp/ggml-metal.m +1 -0
- package/cpp/ggml-quants.c +1 -0
- package/cpp/ggml.c +457 -144
- package/cpp/ggml.h +37 -8
- package/cpp/llama-impl.h +2 -0
- package/cpp/llama-sampling.cpp +7 -5
- package/cpp/llama-vocab.cpp +1 -5
- package/cpp/llama-vocab.h +9 -5
- package/cpp/llama.cpp +202 -30
- package/cpp/llama.h +2 -0
- package/cpp/log.cpp +1 -1
- package/cpp/log.h +2 -0
- package/cpp/sampling.cpp +9 -1
- package/cpp/sgemm.cpp +1 -0
- package/cpp/unicode.cpp +1 -0
- package/lib/commonjs/index.js +8 -1
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/index.js +8 -1
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/index.d.ts +1 -1
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/index.ts +18 -4
package/cpp/ggml.c
CHANGED
@@ -1,7 +1,9 @@
 #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
 #define _USE_MATH_DEFINES // For M_PI on MSVC

+#include "ggml-backend.h"
 #include "ggml-impl.h"
+#include "ggml-cpu-impl.h"
 #include "ggml-quants.h"
 #include "ggml.h"
 #include "ggml-aarch64.h"
@@ -61,6 +63,25 @@ int lm_ggml_sve_cnt_b = 0;
 #pragma warning(disable: 4702)
 #endif

+// Note: once we move threading into a separate C++ file
+// will use std::hardware_destructive_interference_size instead of hardcoding it here
+// and we'll use C++ attribute syntax.
+#define LM_GGML_CACHE_LINE 64
+
+#if defined(__clang__) || defined(__GNUC__)
+#define LM_GGML_CACHE_ALIGN __attribute__((aligned(LM_GGML_CACHE_LINE)))
+#endif
+
+#if defined(__has_feature)
+#if __has_feature(thread_sanitizer)
+#define LM_GGML_TSAN_ENABLED 1
+#endif
+#else // __has_feature
+#if defined(__SANITIZE_THREAD__)
+#define LM_GGML_TSAN_ENABLED 1
+#endif
+#endif // __has_feature
+
 #if defined(_WIN32)

 #define WIN32_LEAN_AND_MEAN
@@ -70,6 +91,8 @@ int lm_ggml_sve_cnt_b = 0;
 #include <windows.h>

 #if !defined(__clang__)
+#define LM_GGML_CACHE_ALIGN __declspec(align(LM_GGML_CACHE_LINE))
+
 typedef volatile LONG atomic_int;
 typedef atomic_int atomic_bool;
 typedef atomic_int atomic_flag;
@@ -112,6 +135,9 @@ static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
 static void atomic_flag_clear(atomic_flag * ptr) {
 InterlockedExchange(ptr, 0);
 }
+static void atomic_thread_fence(memory_order mo) {
+MemoryBarrier();
+}
 #else // clang
 #include <stdatomic.h>
 #endif
@@ -287,7 +313,6 @@ void lm_ggml_abort(const char * file, int line, const char * fmt, ...) {
 #define LM_GGML_DEBUG 0
 #define LM_GGML_GELU_FP16
 #define LM_GGML_GELU_QUICK_FP16
-#define LM_GGML_N_TASKS_MAX (-1)

 #define LM_GGML_SOFT_MAX_UNROLL 4
 #define LM_GGML_VEC_DOT_UNROLL 2
@@ -2005,17 +2030,18 @@ struct lm_ggml_threadpool {

 // synchronization primitives
 atomic_int n_graph; // incremented when there is work to be done (i.e each graph)
-atomic_int n_barrier;
-atomic_int n_barrier_passed;
+atomic_int LM_GGML_CACHE_ALIGN n_barrier;
+atomic_int LM_GGML_CACHE_ALIGN n_barrier_passed;
 atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.

 // these are atomic as an annotation for thread-sanitizer
 atomic_bool stop; // Used for stopping the threadpool altogether
 atomic_bool pause; // Used for pausing the threadpool or individual threads
+atomic_bool abort; // Used for aborting processing of a graph

 struct lm_ggml_compute_state * workers; // per thread state
 int n_threads_max; // number of threads in the pool
-
+atomic_int n_threads_cur; // number of threads used in the current graph

 int32_t prio; // Scheduling priority
 uint32_t poll; // Polling level (0 - no polling)
@@ -2995,9 +3021,10 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = {

 "CROSS_ENTROPY_LOSS",
 "CROSS_ENTROPY_LOSS_BACK",
+"OPT_STEP_ADAMW",
 };

-static_assert(LM_GGML_OP_COUNT ==
+static_assert(LM_GGML_OP_COUNT == 80, "LM_GGML_OP_COUNT != 80");

 static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
 "none",
@@ -3088,9 +3115,10 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {

 "cross_entropy_loss(x,y)",
 "cross_entropy_loss_back(x,y)",
+"adamw(x)",
 };

-static_assert(LM_GGML_OP_COUNT ==
+static_assert(LM_GGML_OP_COUNT == 80, "LM_GGML_OP_COUNT != 80");

 static_assert(LM_GGML_OP_POOL_COUNT == 2, "LM_GGML_OP_POOL_COUNT != 2");

@@ -3177,41 +3205,43 @@ inline static void lm_ggml_critical_section_start(void) {
 }
 }

-
-
-if (
+static void lm_ggml_barrier(struct lm_ggml_threadpool * tp) {
+int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed);
+if (n_threads == 1) {
 return;
 }

+#ifdef LM_GGML_USE_OPENMP
 #pragma omp barrier
-}
 #else
-
-if (threadpool->n_threads_cur == 1) {
-return;
-}
+int n_passed = atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed);

-
-
+// enter barrier (full seq-cst fence)
+int n_barrier = atomic_fetch_add_explicit(&tp->n_barrier, 1, memory_order_seq_cst);

-
-int passed_old = atomic_load_explicit(n_barrier_passed, memory_order_relaxed);
-
-if (atomic_fetch_add(n_barrier, 1) == n_threads - 1) {
+if (n_barrier == (n_threads - 1)) {
 // last thread
-
-
-
-
-
-if (atomic_load_explicit(n_barrier_passed, memory_order_relaxed) != passed_old) {
-return;
-}
-lm_ggml_thread_cpu_relax();
-}
+atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed);
+
+// exit barrier (fill seq-cst fence)
+atomic_fetch_add_explicit(&tp->n_barrier_passed, 1, memory_order_seq_cst);
+return;
 }
-
+
+// wait for other threads
+while (atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed) == n_passed) {
+lm_ggml_thread_cpu_relax();
+}
+
+// exit barrier (full seq-cst fence)
+// TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead
+#ifdef LM_GGML_TSAN_ENABLED
+atomic_fetch_add_explicit(&tp->n_barrier_passed, 0, memory_order_seq_cst);
+#else
+atomic_thread_fence(memory_order_seq_cst);
+#endif
 #endif
+}

 // TODO: make this somehow automatically executed
 // some sort of "sentry" mechanism
@@ -4097,7 +4127,11 @@
 }

 struct lm_ggml_tensor * lm_ggml_set_zero(struct lm_ggml_tensor * tensor) {
-
+if (tensor->buffer) {
+lm_ggml_backend_tensor_memset(tensor, 0, 0, lm_ggml_nbytes(tensor));
+} else {
+memset(tensor->data, 0, lm_ggml_nbytes(tensor));
+}
 return tensor;
 }

@@ -8323,11 +8357,46 @@ struct lm_ggml_tensor * lm_ggml_cross_entropy_loss_back(
 return result;
 }

-
+// opt_step_adamw

-
+struct lm_ggml_tensor * lm_ggml_opt_step_adamw(
 struct lm_ggml_context * ctx,
-struct lm_ggml_tensor
+struct lm_ggml_tensor * a,
+float alpha,
+float beta1,
+float beta2,
+float eps,
+float wd) {
+LM_GGML_ASSERT(a->grad);
+LM_GGML_ASSERT(alpha > 0.0f);
+LM_GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
+LM_GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
+LM_GGML_ASSERT(eps >= 0.0f);
+LM_GGML_ASSERT(wd >= 0.0f && wd <= 1.0f);
+
+struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a);
+
+result->op = LM_GGML_OP_OPT_STEP_ADAMW;
+result->grad = NULL;
+result->src[0] = a;
+result->src[1] = a->grad;
+result->src[2] = lm_ggml_dup_tensor(ctx, a->grad);
+result->src[3] = lm_ggml_dup_tensor(ctx, a->grad);
+
+const int64_t iter = 1;
+memcpy(&result->op_params[0], &iter, sizeof(int64_t));
+lm_ggml_set_op_params_f32(result, 2, alpha);
+lm_ggml_set_op_params_f32(result, 3, beta1);
+lm_ggml_set_op_params_f32(result, 4, beta2);
+lm_ggml_set_op_params_f32(result, 5, eps);
+lm_ggml_set_op_params_f32(result, 6, wd);
+
+return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void lm_ggml_set_param(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor) {
 tensor->flags |= LM_GGML_TENSOR_FLAG_PARAM;

 LM_GGML_ASSERT(tensor->grad == NULL);
@@ -8335,6 +8404,13 @@ void lm_ggml_set_param(
 lm_ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
 }

+void lm_ggml_set_loss(struct lm_ggml_tensor * tensor) {
+LM_GGML_ASSERT(lm_ggml_is_scalar(tensor));
+LM_GGML_ASSERT(tensor->type == LM_GGML_TYPE_F32);
+LM_GGML_ASSERT(tensor->grad);
+tensor->flags |= LM_GGML_TENSOR_FLAG_LOSS;
+}
+
 // lm_ggml_compute_forward_dup

 static void lm_ggml_compute_forward_dup_same_cont(
@@ -17409,7 +17485,7 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back_f32(
 const int64_t ir0 = dr*ith;
 const int64_t ir1 = MIN(ir0 + dr, nr);

-float
+const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;

 for (int64_t i1 = ir0; i1 < ir1; i1++) {
 float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
@@ -17433,7 +17509,7 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back_f32(

 // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
 lm_ggml_vec_sub_f32(nc, ds0, ds0, s1);
-lm_ggml_vec_scale_f32(nc, ds0,
+lm_ggml_vec_scale_f32(nc, ds0, d_by_nr);

 #ifndef NDEBUG
 for (int i = 0; i < nc; ++i) {
@@ -17462,6 +17538,94 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back(
 }
 }

+static void lm_ggml_compute_forward_opt_step_adamw_f32(
+const struct lm_ggml_compute_params * params,
+struct lm_ggml_tensor * dst) {
+
+const struct lm_ggml_tensor * src0 = dst->src[0];
+const struct lm_ggml_tensor * src0_grad = dst->src[1];
+const struct lm_ggml_tensor * src0_grad_m = dst->src[2];
+const struct lm_ggml_tensor * src0_grad_v = dst->src[3];
+LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src0_grad));
+
+const int ith = params->ith;
+const int nth = params->nth;
+
+const int nr = lm_ggml_nrows(src0);
+
+LM_GGML_TENSOR_UNARY_OP_LOCALS
+LM_GGML_ASSERT(nb00 == sizeof(float));
+
+// rows per thread
+const int dr = (nr + nth - 1)/nth;
+
+// row range for this thread
+const int ir0 = dr*ith;
+const int ir1 = MIN(ir0 + dr, nr);
+
+/* const float gnorm = 1.0f; */
+int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t));
+const float alpha = lm_ggml_get_op_params_f32(dst, 2);
+const float beta1 = lm_ggml_get_op_params_f32(dst, 3);
+const float beta2 = lm_ggml_get_op_params_f32(dst, 4);
+const float eps = lm_ggml_get_op_params_f32(dst, 5);
+const float wd = lm_ggml_get_op_params_f32(dst, 6);
+
+const float beta1h = alpha/(1.0f - powf(beta1, iter));
+const float beta2h = 1.0f/(1.0f - powf(beta2, iter));
+
+for (int ir = ir0; ir < ir1; ++ir) {
+const int64_t i03 = ir/(ne02*ne01);
+const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
+
+float * w = (float *) ((char *) src0->data + offset); // weight
+const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
+float * m = (float *) ((char *) src0_grad_m->data + offset);
+float * v = (float *) ((char *) src0_grad_v->data + offset);
+
+for (int i00 = 0; i00 < ne00; ++i00) {
+m[i00] = m[i00]*beta1 + g[i00]*(1.0f - beta1);
+v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
+
+const float mh = m[i00]*beta1h;
+const float vh = sqrtf(v[i00]*beta2h) + eps;
+
+// The weight decay is applied independently of the Adam momenta m and v.
+// This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
+// See: https://arxiv.org/pdf/1711.05101v3.pdf
+w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh;
+}
+}
+
+lm_ggml_barrier(params->threadpool);
+if (ith != 0) {
+return;
+}
+
+iter++;
+memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
+}
+
+static void lm_ggml_compute_forward_opt_step_adamw(
+const struct lm_ggml_compute_params * params,
+struct lm_ggml_tensor * dst) {
+
+const struct lm_ggml_tensor * src0 = dst->src[0];
+
+switch (src0->type) {
+case LM_GGML_TYPE_F32:
+{
+lm_ggml_compute_forward_opt_step_adamw_f32(params, dst);
+} break;
+default:
+{
+LM_GGML_ABORT("fatal error");
+}
+}
+}
 /////////////////////////////////

 static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * tensor) {
@@ -17807,6 +17971,11 @@
 lm_ggml_compute_forward_cross_entropy_loss_back(params, tensor);
 }
 break;
+case LM_GGML_OP_OPT_STEP_ADAMW:
+{
+lm_ggml_compute_forward_opt_step_adamw(params, tensor);
+}
+break;
 case LM_GGML_OP_NONE:
 {
 // nop
@@ -17961,7 +18130,7 @@ void lm_ggml_build_backward_gradient_checkpointing(
 struct lm_ggml_tensor * * checkpoints,
 int n_checkpoints) {
 lm_ggml_graph_cpy(gf, gb_tmp);
-lm_ggml_build_backward_expand(ctx, gf, gb_tmp, true);
+lm_ggml_build_backward_expand(ctx, gf, gb_tmp, false, true);

 if (n_checkpoints <= 0) {
 lm_ggml_graph_cpy(gb_tmp, gb);
@@ -17999,42 +18168,93 @@ void lm_ggml_build_backward_gradient_checkpointing(
 lm_ggml_hash_map_free(replacements);
 }

-// functions to change gradients
-
-
+// utility functions to change gradients
+// if a is in acc_table, modify gradients in-place and mark result as gradient accumulator
+// else if a is in zero_table, replace a
+// else, just add/subtract/etc. the gradients
+
+static struct lm_ggml_tensor * lm_ggml_add_or_set(
+struct lm_ggml_context * ctx,
+struct lm_ggml_tensor * a,
+struct lm_ggml_tensor * b,
+struct lm_ggml_hash_set * zero_table,
+struct lm_ggml_hash_set * acc_table) {
+if (lm_ggml_hash_contains(acc_table, a)) {
+struct lm_ggml_tensor * ret = lm_ggml_add_impl(ctx, a, b, true);
+const size_t insert_result = lm_ggml_hash_insert(acc_table, ret);
+LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL);
+LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS);
+return ret;
+}
 if (lm_ggml_hash_contains(zero_table, a)) {
 return b;
-} else {
-return lm_ggml_add_impl(ctx, a, b, false);
 }
+return lm_ggml_add_impl(ctx, a, b, false);
 }

-static struct lm_ggml_tensor * lm_ggml_acc_or_set(
+static struct lm_ggml_tensor * lm_ggml_acc_or_set(
+struct lm_ggml_context * ctx,
+struct lm_ggml_tensor * a,
+struct lm_ggml_tensor * b,
+const size_t nb1,
+const size_t nb2,
+const size_t nb3,
+const size_t offset,
+struct lm_ggml_hash_set * zero_table,
+struct lm_ggml_hash_set * acc_table) {
+if (lm_ggml_hash_contains(acc_table, a)) {
+struct lm_ggml_tensor * ret = lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
+const size_t insert_result = lm_ggml_hash_insert(acc_table, ret);
+LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL);
+LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS);
+return ret;
+}
 if (lm_ggml_hash_contains(zero_table, a)) {
-struct lm_ggml_tensor * a_zero = lm_ggml_scale(ctx, a, 0.0f);
+struct lm_ggml_tensor * a_zero = lm_ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
 return lm_ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
-} else {
-return lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
 }
+return lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
 }

-static struct lm_ggml_tensor * lm_ggml_add1_or_set(
+static struct lm_ggml_tensor * lm_ggml_add1_or_set(
+struct lm_ggml_context * ctx,
+struct lm_ggml_tensor * a,
+struct lm_ggml_tensor * b,
+struct lm_ggml_hash_set * zero_table,
+struct lm_ggml_hash_set * acc_table) {
+if (lm_ggml_hash_contains(acc_table, a)) {
+struct lm_ggml_tensor * ret = lm_ggml_add1_impl(ctx, a, b, true);
+const size_t insert_result = lm_ggml_hash_insert(acc_table, ret);
+LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL);
+LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS);
+return ret;
+}
 if (lm_ggml_hash_contains(zero_table, a)) {
 return lm_ggml_repeat(ctx, b, a);
-} else {
-return lm_ggml_add1_impl(ctx, a, b, false);
 }
+return lm_ggml_add1_impl(ctx, a, b, false);
 }

-static struct lm_ggml_tensor * lm_ggml_sub_or_set(
+static struct lm_ggml_tensor * lm_ggml_sub_or_set(
+struct lm_ggml_context * ctx,
+struct lm_ggml_tensor * a,
+struct lm_ggml_tensor * b,
+struct lm_ggml_hash_set * zero_table,
+struct lm_ggml_hash_set * acc_table) {
+if (lm_ggml_hash_contains(acc_table, a)) {
+struct lm_ggml_tensor * ret = lm_ggml_sub_impl(ctx, a, b, true);
+const size_t insert_result = lm_ggml_hash_insert(acc_table, ret);
+LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL);
+LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS);
+return ret;
+}
 if (lm_ggml_hash_contains(zero_table, a)) {
 return lm_ggml_neg(ctx, b);
-} else {
-return lm_ggml_sub_impl(ctx, a, b, false);
 }
+return lm_ggml_sub_impl(ctx, a, b, false);
 }

-static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor, struct lm_ggml_hash_set * zero_table) {
+static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor, struct lm_ggml_hash_set * zero_table, struct lm_ggml_hash_set * acc_table) {
 struct lm_ggml_tensor * src0 = tensor->src[0];
 struct lm_ggml_tensor * src1 = tensor->src[1];
 struct lm_ggml_tensor * src2 = tensor->src[2];
@@ -18043,38 +18263,38 @@
 case LM_GGML_OP_DUP:
 {
 if (src0->grad) {
-src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_ADD:
 {
 if (src0->grad) {
-src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
 }
 if (src1->grad) {
 if (lm_ggml_are_same_shape(src0, src1)) {
-src1->grad = lm_ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
+src1->grad = lm_ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
 } else {
-src1->grad = lm_ggml_add_or_set(ctx, src1->grad, lm_ggml_repeat_back(ctx, tensor->grad, src1), zero_table);
+src1->grad = lm_ggml_add_or_set(ctx, src1->grad, lm_ggml_repeat_back(ctx, tensor->grad, src1), zero_table, acc_table);
 }
 }
 } break;
 case LM_GGML_OP_ADD1:
 {
 if (src0->grad) {
-src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
 }
 if (src1->grad) {
 src1->grad = lm_ggml_add_or_set(ctx,
 src1->grad,
 lm_ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_ACC:
 {
 if (src0->grad) {
-src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
 }
 if (src1->grad) {
 const size_t nb1 = ((int32_t *) tensor->op_params)[0];
@@ -18096,16 +18316,16 @@
 lm_ggml_reshape(ctx,
 lm_ggml_cont(ctx, tensor_grad_view),
 src1->grad),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_SUB:
 {
 if (src0->grad) {
-src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
 }
 if (src1->grad) {
-src1->grad = lm_ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
+src1->grad = lm_ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_MUL:
@@ -18115,14 +18335,14 @@
 lm_ggml_add_or_set(ctx,
 src0->grad,
 lm_ggml_mul(ctx, src1, tensor->grad),
-zero_table);
+zero_table, acc_table);
 }
 if (src1->grad) {
 src1->grad =
 lm_ggml_add_or_set(ctx,
 src1->grad,
 lm_ggml_mul(ctx, src0, tensor->grad),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_DIV:
@@ -18132,7 +18352,7 @@
 lm_ggml_add_or_set(ctx,
 src0->grad,
 lm_ggml_div(ctx, tensor->grad, src1),
-zero_table);
+zero_table, acc_table);
 }
 if (src1->grad) {
 src1->grad =
@@ -18141,7 +18361,7 @@
 lm_ggml_mul(ctx,
 tensor->grad,
 lm_ggml_div(ctx, tensor, src1)),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_SQR:
@@ -18153,7 +18373,7 @@
 lm_ggml_scale(ctx,
 lm_ggml_mul(ctx, src0, tensor->grad),
 2.0f),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_SQRT:
@@ -18167,7 +18387,7 @@
 tensor->grad,
 tensor),
 0.5f),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_LOG:
@@ -18179,7 +18399,7 @@
 lm_ggml_div(ctx,
 tensor->grad,
 src0),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_SIN:
@@ -18191,7 +18411,7 @@
 lm_ggml_mul(ctx,
 tensor->grad,
 lm_ggml_cos(ctx, src0)),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_COS:
@@ -18203,7 +18423,7 @@
 lm_ggml_mul(ctx,
 tensor->grad,
 lm_ggml_sin(ctx, src0)),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_SUM:
@@ -18213,7 +18433,7 @@
 lm_ggml_add1_or_set(ctx,
 src0->grad,
 tensor->grad,
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_SUM_ROWS:
@@ -18225,7 +18445,7 @@
 lm_ggml_repeat(ctx,
 tensor->grad,
 src0->grad),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_MEAN:
@@ -18240,7 +18460,7 @@
 src0->grad = lm_ggml_add_or_set(ctx,
 src0->grad,
 lm_ggml_repeat_back(ctx, tensor->grad, src0->grad),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_REPEAT_BACK:
@@ -18250,7 +18470,7 @@
 src0->grad = lm_ggml_add_or_set(ctx,
 src0->grad,
 lm_ggml_repeat(ctx, tensor->grad, src0->grad),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_CONCAT:
@@ -18275,7 +18495,7 @@
 src0->grad = lm_ggml_add_or_set(ctx,
 src0->grad,
 lm_ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_RMS_NORM_BACK:
@@ -18323,7 +18543,7 @@
 lm_ggml_add_or_set(ctx,
 src0->grad, // [n,m,q1,r1]
 s1_tg, // [n,m,q1,r1]
-zero_table);
+zero_table, acc_table);
 }
 if (src1->grad) {
 src1->grad =
@@ -18341,7 +18561,7 @@
 src0, // [n,m,q1,r1]
 lm_ggml_transpose(ctx, // [p,m,qq,rr]
 tensor->grad)), // [m,p,qq,rr]
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_MUL_MAT_ID:
@@ -18363,7 +18583,7 @@
 lm_ggml_add_or_set(ctx,
 src0->grad,
 lm_ggml_scale_impl(ctx, tensor->grad, s, false),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_SET:
@@ -18392,7 +18612,7 @@
 tensor->grad,
 lm_ggml_neg(ctx, tensor_grad_view),
 nb1, nb2, nb3, offset, false),
-zero_table);
+zero_table, acc_table);
 }

 if (src1->grad) {
@@ -18402,7 +18622,7 @@
 lm_ggml_reshape(ctx,
 lm_ggml_cont(ctx, tensor_grad_view),
 src1->grad),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_CPY:
@@ -18413,7 +18633,7 @@
 // tensor = src0 * 1 + src1 * 0
 if (src0->grad) {
 // dsrc0 = dtensor * 1
-src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
 }
 if (src1->grad) {
 // dsrc1 = dtensor * 0 -> noop
@@ -18425,7 +18645,7 @@
 if (src0->grad) {
 LM_GGML_ASSERT(lm_ggml_is_contiguous(src0->grad));
 LM_GGML_ASSERT(lm_ggml_is_contiguous(tensor->grad));
-src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+src0->grad = lm_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_RESHAPE:
@@ -18439,7 +18659,7 @@
 ? tensor->grad
 : lm_ggml_cont(ctx, tensor->grad),
 src0->grad),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_VIEW:
@@ -18468,7 +18688,7 @@
 nb3 = (nb3 / n0) * ng;
 }

-src0->grad = lm_ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
+src0->grad = lm_ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_PERMUTE:
@@ -18493,7 +18713,7 @@
 axes_backward[1],
 axes_backward[2],
 axes_backward[3]),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_TRANSPOSE:
@@ -18503,7 +18723,7 @@
 src0->grad =
 lm_ggml_add_or_set(ctx, src0->grad,
 lm_ggml_transpose(ctx, tensor->grad),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_GET_ROWS:
@@ -18515,7 +18735,7 @@
 // last lm_ggml_get_rows_back argument src0->grad is only
 // necessary to setup correct output shape
 lm_ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
-zero_table);
+zero_table, acc_table);
 }
 if (src1->grad) {
 // noop
@@ -18539,7 +18759,7 @@
 /* lm_ggml_diag_mask_inf_impl() shouldn't be here */
 /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
 lm_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_DIAG_MASK_ZERO:
@@ -18550,7 +18770,7 @@
 src0->grad =
 lm_ggml_add_or_set(ctx, src0->grad,
 lm_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_SOFT_MAX:
@@ -18560,7 +18780,7 @@
 src0->grad =
 lm_ggml_add_or_set(ctx, src0->grad,
 lm_ggml_soft_max_back(ctx, tensor->grad, tensor),
-zero_table);
+zero_table, acc_table);
 }

 } break;
@@ -18601,7 +18821,7 @@
 attn_factor,
 beta_fast,
 beta_slow),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_ROPE_BACK:
@@ -18637,7 +18857,7 @@
 beta_fast,
 beta_slow,
 false),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_CLAMP:
@@ -18662,7 +18882,7 @@
 src1->grad = lm_ggml_add_or_set(ctx,
 src1->grad,
 lm_ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_IM2COL_BACK:
@@ -18691,7 +18911,7 @@
 src0->grad = lm_ggml_add_or_set(ctx,
 src0->grad,
 lm_ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_POOL_2D_BACK:
@@ -18756,7 +18976,7 @@
 src0->grad = lm_ggml_add_or_set(ctx,
 src0->grad,
 grad_q,
-zero_table);
+zero_table, acc_table);
 }
 if (src1->grad) {
 struct lm_ggml_tensor * view_k = lm_ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
@@ -18764,7 +18984,7 @@
 src1->grad = lm_ggml_add_or_set(ctx,
 src1->grad,
 grad_k,
-zero_table);
+zero_table, acc_table);
 }
 if (src2->grad) {
 struct lm_ggml_tensor * view_v = lm_ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
@@ -18772,7 +18992,7 @@
 src2->grad = lm_ggml_add_or_set(ctx,
 src2->grad,
 grad_v,
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_FLASH_ATTN_BACK:
@@ -18798,7 +19018,7 @@
 lm_ggml_mul(ctx,
 lm_ggml_sgn(ctx, src0),
 tensor->grad),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_UNARY_OP_SGN:
@@ -18810,7 +19030,7 @@
 case LM_GGML_UNARY_OP_NEG:
 {
 if (src0->grad) {
-src0->grad = lm_ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
+src0->grad = lm_ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
 }
 } break;
 case LM_GGML_UNARY_OP_STEP:
@@ -18835,7 +19055,7 @@
 lm_ggml_mul(ctx,
 lm_ggml_step(ctx, src0),
 tensor->grad),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_UNARY_OP_SIGMOID:
@@ -18857,7 +19077,7 @@
 src0->grad = lm_ggml_add_or_set(ctx,
 src0->grad,
 lm_ggml_silu_back(ctx, src0, tensor->grad),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_UNARY_OP_EXP:
@@ -18866,7 +19086,7 @@
 src0->grad = lm_ggml_add_or_set(ctx,
 src0->grad,
 lm_ggml_mul(ctx, tensor, tensor->grad),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 default:
@@ -18896,13 +19116,17 @@
 src0,
 src1,
 tensor->grad),
-zero_table);
+zero_table, acc_table);
 }
 } break;
 case LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK:
 {
 LM_GGML_ABORT("fatal error"); // not supported
 }
+case LM_GGML_OP_OPT_STEP_ADAMW:
+{
+LM_GGML_ABORT("fatal error"); // not supported
+}
 case LM_GGML_OP_NONE:
 {
 // nop
@@ -18992,7 +19216,7 @@
 lm_ggml_build_forward_impl(cgraph, tensor, true);
 }

-void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * gf, struct lm_ggml_cgraph * gb, bool keep) {
+void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * gf, struct lm_ggml_cgraph * gb, bool accumulate, bool keep) {
 LM_GGML_ASSERT(gf->n_nodes > 0);
 LM_GGML_ASSERT(gf->grads);

@@ -19008,21 +19232,35 @@
 }
 }

-//
+// keep tables of original gradients for replacement/accumulation logic
 struct lm_ggml_hash_set zero_table = lm_ggml_hash_set_new(gf->size);
+struct lm_ggml_hash_set acc_table = lm_ggml_hash_set_new(gf->size);
 for (int i = 0; i < gf->n_nodes; i++) {
-
-
+struct lm_ggml_tensor * node = gf->nodes[i];
+
+if (node->grad) {
+{
+const size_t insert_result = lm_ggml_hash_insert(&zero_table, node->grad);
+LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL);
+LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS);
+}
+
+// only gradients of trainable parameters should be accumulated
+if (accumulate && (node->flags & LM_GGML_TENSOR_FLAG_PARAM)) {
+const size_t insert_result = lm_ggml_hash_insert(&acc_table, node->grad);
+LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_FULL);
+LM_GGML_ASSERT(insert_result != LM_GGML_HASHSET_ALREADY_EXISTS);
+}
 }
 }

 for (int i = gf->n_nodes - 1; i >= 0; i--) {
 struct lm_ggml_tensor * node = gf->nodes[i];

-// inplace operations to add gradients are not created by lm_ggml_compute_backward
+// inplace operations to add gradients are not created by lm_ggml_compute_backward except for gradient accumulation
 // use allocator to automatically make inplace operations
 if (node->grad) {
-lm_ggml_compute_backward(ctx, node, &zero_table);
+lm_ggml_compute_backward(ctx, node, &zero_table, &acc_table);
 }
 }

@@ -19036,8 +19274,30 @@
 }

 lm_ggml_hash_set_free(&zero_table);
+lm_ggml_hash_set_free(&acc_table);
+}
+
+void lm_ggml_build_opt_adamw(
+struct lm_ggml_context * ctx,
+struct lm_ggml_cgraph * gf,
+struct lm_ggml_cgraph * gb,
+float alpha,
+float beta1,
+float beta2,
+float eps,
+float wd) {
+for (int i = 0; i < gf->n_nodes; i++) {
+struct lm_ggml_tensor * node = gf->nodes[i];
+
+if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) {
+LM_GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
+struct lm_ggml_tensor * opt_step = lm_ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd);
+lm_ggml_build_forward_expand(gb, opt_step);
+}
+}
 }

+
 static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
 void * ptr = *p;
 ptr = (void *) LM_GGML_PAD((uintptr_t) ptr, align);
@@ -19165,10 +19425,28 @@ void lm_ggml_graph_reset(struct lm_ggml_cgraph * cgraph) {
 LM_GGML_ASSERT(cgraph->grads != NULL);

 for (int i = 0; i < cgraph->n_nodes; i++) {
-struct lm_ggml_tensor *
+struct lm_ggml_tensor * node = cgraph->nodes[i];
+
+// initial gradients of loss should be 1, 0 otherwise
+if (node->grad) {
+if (node->flags & LM_GGML_TENSOR_FLAG_LOSS) {
+LM_GGML_ASSERT(node->grad->buffer);
+LM_GGML_ASSERT(node->type == LM_GGML_TYPE_F32);
+LM_GGML_ASSERT(lm_ggml_is_scalar(node));
+
+const float onef = 1.0f;
+lm_ggml_backend_tensor_set(node->grad, &onef, 0, lm_ggml_nbytes(node->grad));
+} else {
+lm_ggml_set_zero(node->grad);
+}
+}

-
-
+LM_GGML_ASSERT(node);
+if (node->op == LM_GGML_OP_OPT_STEP_ADAMW) {
+// set iteration to 1 and clear momenta
+lm_ggml_set_op_params_i32(node, 0, 1);
+lm_ggml_set_zero(node->src[2]);
+lm_ggml_set_zero(node->src[3]);
 }
 }
 }
@@ -19461,6 +19739,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
 } break;
 case LM_GGML_OP_CROSS_ENTROPY_LOSS:
 case LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+case LM_GGML_OP_OPT_STEP_ADAMW:
 {
 n_tasks = n_threads;
 } break;
@@ -19756,8 +20035,8 @@ void lm_ggml_threadpool_resume(struct lm_ggml_threadpool * threadpool) {

 struct lm_ggml_cplan lm_ggml_graph_plan(
 const struct lm_ggml_cgraph * cgraph,
-
-
+int n_threads,
+struct lm_ggml_threadpool * threadpool) {

 if (threadpool == NULL) {
 LM_GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads);
@@ -19932,34 +20211,33 @@ struct lm_ggml_cplan lm_ggml_graph_plan(

 static thread_ret_t lm_ggml_graph_compute_thread(void * data) {
 struct lm_ggml_compute_state * state = (struct lm_ggml_compute_state *) data;
+struct lm_ggml_threadpool * tp = state->threadpool;

-const struct lm_ggml_cgraph * cgraph =
-const struct lm_ggml_cplan * cplan =
+const struct lm_ggml_cgraph * cgraph = tp->cgraph;
+const struct lm_ggml_cplan * cplan = tp->cplan;

 set_numa_thread_affinity(state->ith);

 struct lm_ggml_compute_params params = {
 /*.ith =*/ state->ith,
-/*.nth =*/
+/*.nth =*/ atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed),
 /*.wsize =*/ cplan->work_size,
 /*.wdata =*/ cplan->work_data,
-/*.threadpool=*/
+/*.threadpool=*/ tp,
 };

-for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
+for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) {
 struct lm_ggml_tensor * node = cgraph->nodes[node_n];

 lm_ggml_compute_forward(&params, node);

-if (state->ith == 0 && cplan->abort_callback &&
-
+if (state->ith == 0 && cplan->abort_callback &&
+cplan->abort_callback(cplan->abort_callback_data)) {
+tp->abort = true;
+tp->ec = LM_GGML_STATUS_ABORTED;
 }

 lm_ggml_barrier(state->threadpool);
-
-if (state->threadpool->ec != LM_GGML_STATUS_SUCCESS) {
-break;
-}
 }

 return 0;
@@ -19967,7 +20245,15 @@ static thread_ret_t lm_ggml_graph_compute_thread(void * data) {

 #ifndef LM_GGML_USE_OPENMP

-
+// check if thread is active
+static inline bool lm_ggml_graph_compute_thread_active(struct lm_ggml_compute_state * state) {
+struct lm_ggml_threadpool * threadpool = state->threadpool;
+int n_threads = atomic_load_explicit(&threadpool->n_threads_cur, memory_order_relaxed);
+return (state->ith < n_threads);
+}
+
+// check if thread is ready to proceed (exit from polling or sleeping)
+static inline bool lm_ggml_graph_compute_thread_ready(struct lm_ggml_compute_state * state) {
 struct lm_ggml_threadpool * threadpool = state->threadpool;

 if (state->pending || threadpool->stop || threadpool->pause) { return true; }
@@ -19975,21 +20261,37 @@
 // check for new graph/work
 int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed);
 if (new_graph != state->last_graph) {
-state->pending = (state
+state->pending = lm_ggml_graph_compute_thread_active(state);
 state->last_graph = new_graph;
 }

 return state->pending;
 }

+// sync thread state after polling
+static inline void lm_ggml_graph_compute_thread_sync(struct lm_ggml_compute_state * state) {
+// TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead
+#ifdef LM_GGML_TSAN_ENABLED
+atomic_fetch_add_explicit(&state->threadpool->n_graph, 0, memory_order_seq_cst);
+#else
+atomic_thread_fence(memory_order_seq_cst);
+#endif
+UNUSED(state);
+}
+
 static inline bool lm_ggml_graph_compute_poll_for_work(struct lm_ggml_compute_state * state) {
 struct lm_ggml_threadpool * threadpool = state->threadpool;

+// Skip polling for unused threads
+if (!lm_ggml_graph_compute_thread_active(state)) {
+return state->pending;
+}
+
 // This seems to make 0 ... 100 a decent range for polling level across modern processors.
 // Perhaps, we can adjust it dynamically based on load and things.
 const uint64_t n_rounds = 1024UL * 128 * threadpool->poll;

-for (uint64_t i=0; !
+for (uint64_t i=0; !lm_ggml_graph_compute_thread_ready(state) && i < n_rounds; i++) {
 // No new work. Keep polling.
 lm_ggml_thread_cpu_relax();
 }
@@ -20001,13 +20303,14 @@
 struct lm_ggml_threadpool * threadpool = state->threadpool;

 if (lm_ggml_graph_compute_poll_for_work(state)) {
+lm_ggml_graph_compute_thread_sync(state);
 return state->pending;
 }

 lm_ggml_mutex_lock_shared(&threadpool->mutex);
-while (!
+while (!lm_ggml_graph_compute_thread_ready(state)) {
 // No new work. Wait for the signal.
-LM_GGML_PRINT_DEBUG("thread #%d waiting for work\n", state->ith);
+LM_GGML_PRINT_DEBUG("thread #%d waiting for work (sleeping)\n", state->ith);
 lm_ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
 }
 lm_ggml_mutex_unlock_shared(&threadpool->mutex);
@@ -20054,13 +20357,20 @@ static thread_ret_t lm_ggml_graph_compute_secondary_thread(void* data) {
 }

 // Start processing new graph
-static void lm_ggml_graph_compute_kickoff(struct lm_ggml_threadpool * threadpool)
+static void lm_ggml_graph_compute_kickoff(struct lm_ggml_threadpool * threadpool, int n_threads)
 {
-//
+// Always take the mutex here because the worker threads are doing hybrid poll/wait

 lm_ggml_mutex_lock(&threadpool->mutex);

-
+LM_GGML_PRINT_DEBUG("threadpool: n_threads_cur %d n_threads %d\n", threadpool->n_threads_cur, n_threads);
+
+// Update the number of active threads
+atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);
+
+// Indicate the graph is ready to be processed
+// We need the full seq-cst fence here because of the polling threads (used in thread_sync)
+atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_seq_cst);

 if (threadpool->pause) {
 // Update main thread prio and affinity to match the threadpool settings
@@ -20119,6 +20429,7 @@ static struct lm_ggml_threadpool * lm_ggml_threadpool_new_impl(
 threadpool->current_chunk = 0;
 threadpool->stop = false;
 threadpool->pause = tpp->paused;
+threadpool->abort = false;
 threadpool->workers = NULL;
 threadpool->n_threads_max = tpp->n_threads;
 threadpool->n_threads_cur = tpp->n_threads;
@@ -20194,15 +20505,11 @@
 // No worker threads should be accessing the parameters below at this stage
 threadpool->cgraph = cgraph;
 threadpool->cplan = cplan;
-threadpool->n_threads_cur = n_threads;
 threadpool->current_chunk = 0;
+threadpool->abort = false;
 threadpool->ec = LM_GGML_STATUS_SUCCESS;
 }

-if (n_threads > threadpool->n_threads_max) {
-LM_GGML_PRINT("WARNING: cplan is requesting more threads than the threadpool contains. Expect a bad time!\n");
-}
-
 #ifdef LM_GGML_USE_OPENMP
 if (n_threads > 1) {
 #pragma omp parallel num_threads(n_threads)
@@ -20211,17 +20518,23 @@
 {
 // update the number of threads from the actual number of threads that we got from OpenMP
 n_threads = omp_get_num_threads();
-threadpool->n_threads_cur
+atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);
 }

 lm_ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]);
 }
 } else {
+atomic_store_explicit(&threadpool->n_threads_cur, 1, memory_order_relaxed);
 lm_ggml_graph_compute_thread(&threadpool->workers[0]);
 }
 #else
+if (n_threads > threadpool->n_threads_max) {
+LM_GGML_PRINT("WARNING: cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads_max);
+n_threads = threadpool->n_threads_max;
+}
+
 // Kick all threads to start the new graph
-lm_ggml_graph_compute_kickoff(threadpool);
+lm_ggml_graph_compute_kickoff(threadpool, n_threads);

 // This is a work thread too
 lm_ggml_graph_compute_thread(&threadpool->workers[0]);
@@ -21823,7 +22136,7 @@ enum lm_ggml_opt_result lm_ggml_opt_resume(
 lm_ggml_build_forward_expand(gf, f);

 struct lm_ggml_cgraph * gb = lm_ggml_graph_dup(ctx, gf);
-lm_ggml_build_backward_expand(ctx, gf, gb, true);
+lm_ggml_build_backward_expand(ctx, gf, gb, false, true);

 return lm_ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
 }