llama_cpp 0.1.4 → 0.2.1
- checksums.yaml +4 -4
- data/CHANGELOG.md +36 -0
- data/examples/README.md +60 -0
- data/examples/chat.rb +195 -0
- data/ext/llama_cpp/extconf.rb +26 -1
- data/ext/llama_cpp/llama_cpp.cpp +262 -13
- data/ext/llama_cpp/src/ggml-cuda.cu +2483 -0
- data/ext/llama_cpp/src/ggml-cuda.h +18 -2
- data/ext/llama_cpp/src/ggml-metal.h +64 -0
- data/ext/llama_cpp/src/ggml-metal.m +834 -0
- data/ext/llama_cpp/src/ggml-metal.metal +1436 -0
- data/ext/llama_cpp/src/ggml-opencl.cpp +207 -40
- data/ext/llama_cpp/src/ggml-opencl.h +4 -1
- data/ext/llama_cpp/src/ggml.c +2236 -404
- data/ext/llama_cpp/src/ggml.h +170 -8
- data/ext/llama_cpp/src/k_quants.c +2244 -0
- data/ext/llama_cpp/src/k_quants.h +122 -0
- data/ext/llama_cpp/src/llama-util.h +16 -0
- data/ext/llama_cpp/src/llama.cpp +631 -179
- data/ext/llama_cpp/src/llama.h +51 -11
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +36 -1
- metadata +10 -2
data/ext/llama_cpp/src/ggml.c
CHANGED
@@ -3,6 +3,10 @@
 
 #include "ggml.h"
 
+#ifdef GGML_USE_K_QUANTS
+#include "k_quants.h"
+#endif
+
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
 #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -21,6 +25,10 @@
 #include <float.h>
 #include <limits.h>
 
+#ifdef GGML_USE_METAL
+#include <unistd.h>
+#endif
+
 // if C99 - static_assert is noop
 // ref: https://stackoverflow.com/a/53923785/4039976
 #ifndef static_assert
@@ -121,7 +129,11 @@ typedef void* thread_ret_t;
 #else
 inline static void* ggml_aligned_malloc(size_t size) {
     void* aligned_memory = NULL;
+#ifdef GGML_USE_METAL
+    int result = posix_memalign(&aligned_memory, getpagesize(), size);
+#else
     int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
+#endif
     if (result != 0) {
         // Handle allocation failure
         return NULL;
@@ -403,21 +415,27 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) {
 //
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
-static int64_t timer_freq;
+static int64_t timer_freq, timer_start;
 void ggml_time_init(void) {
-    LARGE_INTEGER frequency;
-    QueryPerformanceFrequency(&frequency);
-    timer_freq = frequency.QuadPart;
+    LARGE_INTEGER t;
+    QueryPerformanceFrequency(&t);
+    timer_freq = t.QuadPart;
+
+    // The multiplication by 1000 or 1000000 below can cause an overflow if timer_freq
+    // and the uptime is high enough.
+    // We subtract the program start time to reduce the likelihood of that happening.
+    QueryPerformanceCounter(&t);
+    timer_start = t.QuadPart;
 }
 int64_t ggml_time_ms(void) {
     LARGE_INTEGER t;
     QueryPerformanceCounter(&t);
-    return (t.QuadPart * 1000) / timer_freq;
+    return ((t.QuadPart-timer_start) * 1000) / timer_freq;
 }
 int64_t ggml_time_us(void) {
     LARGE_INTEGER t;
     QueryPerformanceCounter(&t);
-    return (t.QuadPart * 1000000) / timer_freq;
+    return ((t.QuadPart-timer_start) * 1000000) / timer_freq;
 }
 #else
 void ggml_time_init(void) {}
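Note (illustrative, not part of the upstream diff): the bound the new `timer_start` subtraction guards against is easy to work out. A minimal sketch, assuming the common 10 MHz performance-counter frequency:

```c
// ggml_time_us() computes t.QuadPart * 1000000 before dividing by timer_freq.
// The product exceeds INT64_MAX once ticks > INT64_MAX / 1000000 ~= 9.2e12,
// and a 10 MHz counter accumulates 9.2e12 ticks after ~922337 s ~= 10.7 days.
// Subtracting timer_start restarts the tick count at process launch, so the
// ~10.7-day budget applies to process runtime instead of system uptime.
```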
@@ -474,6 +492,8 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 // multiply int8_t, add results pairwise twice
 static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
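Note (illustrative, not part of the upstream diff): MM256_SET_M128I packs two 128-bit halves into one 256-bit vector using only AVX cast/insert intrinsics; the hunks that follow swap it in for the `_mm256_set_m128i`-style combinations, which some older toolchains do not provide. A minimal sketch of the layout it produces:

```c
#include <immintrin.h>

#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

// b fills the low 128 bits, a fills the high 128 bits -- the same result
// _mm256_set_m128i(a, b) gives on compilers that provide it.
static inline __m256i combine_halves(__m128i hi, __m128i lo) {
    return MM256_SET_M128I(hi, lo);
}
```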
@@ -533,7 +553,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
 static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
 {
     const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
-    const __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(tmp, 4), tmp);
+    const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
     const __m256i lowMask = _mm256_set1_epi8( 0xF );
     return _mm256_and_si256(lowMask, bytes);
 }
@@ -606,7 +626,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
     bytesh = _mm_or_si128(bytesh, bit_mask);
     bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
     bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
-    return _mm256_set_m128i(bytesh, bytesl);
+    return MM256_SET_M128I(bytesh, bytesl);
 }
 
 // Unpack 32 4-bit fields into 32 bytes
@@ -619,7 +639,7 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
     const __m128i lowMask = _mm_set1_epi8(0xF);
     tmpl = _mm_and_si128(lowMask, tmpl);
     tmph = _mm_and_si128(lowMask, tmph);
-    return _mm256_set_m128i(tmph, tmpl);
+    return MM256_SET_M128I(tmph, tmpl);
 }
 
 // add int16_t pairwise and return as float vector
@@ -627,7 +647,7 @@ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
     const __m128i ones = _mm_set1_epi16(1);
     const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
     const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
-    const __m256i summed_pairs = _mm256_set_m128i(summed_pairsh, summed_pairsl);
+    const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
     return _mm256_cvtepi32_ps(summed_pairs);
 }
@@ -1565,6 +1585,48 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .vec_dot_q                = NULL,   // TODO
         .vec_dot_type             = GGML_TYPE_Q8_1,
     },
+#ifdef GGML_USE_K_QUANTS
+    [GGML_TYPE_Q2_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q2_K,
+        .quantize_row_q           = quantize_row_q2_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q2_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q2_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q3_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q3_K,
+        .quantize_row_q           = quantize_row_q3_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q3_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q3_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q4_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q4_K,
+        .quantize_row_q           = quantize_row_q4_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q4_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q5_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q5_K,
+        .quantize_row_q           = quantize_row_q5_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q5_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q6_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q6_K,
+        .quantize_row_q           = quantize_row_q6_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q6_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q6_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+#endif
 };
 
 // For internal test use
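Note (illustrative, not part of the upstream diff): each entry above wires one quantized type to its kernels, and `.vec_dot_type` names the format the other operand is converted to before the dot product (Q8_K for all k-quants). A sketch of the dispatch pattern with simplified, hypothetical types:

```c
// Simplified stand-ins for ggml's table entries (not the real definitions).
typedef void (*vec_dot_q_t)(int n, float * s, const void * x, const void * y);

struct quant_fns {
    vec_dot_q_t vec_dot_q;    // e.g. ggml_vec_dot_q4_K_q8_K
    int         vec_dot_type; // e.g. GGML_TYPE_Q8_K
};

// Matmul inner loop: the weight row stays in its own format (say Q4_K);
// the activation row is first quantized to fns.vec_dot_type, then:
static float dot_row(struct quant_fns fns, int n, const void * wq, const void * aq) {
    float s = 0.0f;
    fns.vec_dot_q(n, &s, wq, aq);  // run the paired quantized kernel
    return s;
}
```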
@@ -2290,7 +2352,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
 
         // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps(_mm256_set_m128i(i32_0, i32_1));
+        __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
 
         // Apply the scale, and accumulate
         acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
@@ -2766,7 +2828,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
         __m128i bxh = _mm256_extractf128_si256(bx, 1);
         bxl = _mm_or_si128(bxl, bxhil);
         bxh = _mm_or_si128(bxh, bxhih);
-        bx = _mm256_set_m128i(bxh, bxl);
+        bx = MM256_SET_M128I(bxh, bxl);
 
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
@@ -3022,7 +3084,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         __m128i bxh = _mm256_extractf128_si256(bx, 1);
         bxl = _mm_or_si128(bxl, bxhil);
         bxh = _mm_or_si128(bxh, bxhih);
-        bx = _mm256_set_m128i(bxh, bxl);
+        bx = MM256_SET_M128I(bxh, bxl);
 
         const __m256 dy = _mm256_set1_ps(y[i].d);
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
@@ -3444,11 +3506,19 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q5_1] = QK5_1,
     [GGML_TYPE_Q8_0] = QK8_0,
     [GGML_TYPE_Q8_1] = QK8_1,
+#ifdef GGML_USE_K_QUANTS
+    [GGML_TYPE_Q2_K] = QK_K,
+    [GGML_TYPE_Q3_K] = QK_K,
+    [GGML_TYPE_Q4_K] = QK_K,
+    [GGML_TYPE_Q5_K] = QK_K,
+    [GGML_TYPE_Q6_K] = QK_K,
+    [GGML_TYPE_Q8_K] = QK_K,
+#endif
     [GGML_TYPE_I8]   = 1,
     [GGML_TYPE_I16]  = 1,
     [GGML_TYPE_I32]  = 1,
 };
-static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 19, "GGML_BLCK_SIZE is outdated");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32] = sizeof(float),
@@ -3459,11 +3529,19 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q5_1] = sizeof(block_q5_1),
     [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
     [GGML_TYPE_Q8_1] = sizeof(block_q8_1),
+#ifdef GGML_USE_K_QUANTS
+    [GGML_TYPE_Q2_K] = sizeof(block_q2_K),
+    [GGML_TYPE_Q3_K] = sizeof(block_q3_K),
+    [GGML_TYPE_Q4_K] = sizeof(block_q4_K),
+    [GGML_TYPE_Q5_K] = sizeof(block_q5_K),
+    [GGML_TYPE_Q6_K] = sizeof(block_q6_K),
+    [GGML_TYPE_Q8_K] = sizeof(block_q8_K),
+#endif
     [GGML_TYPE_I8]  = sizeof(int8_t),
     [GGML_TYPE_I16] = sizeof(int16_t),
     [GGML_TYPE_I32] = sizeof(int32_t),
 };
-static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_SIZE is outdated");
 
 
 static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3475,11 +3553,17 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q5_1] = "q5_1",
     [GGML_TYPE_Q8_0] = "q8_0",
     [GGML_TYPE_Q8_1] = "q8_1",
+    [GGML_TYPE_Q2_K] = "q2_K",
+    [GGML_TYPE_Q3_K] = "q3_K",
+    [GGML_TYPE_Q4_K] = "q4_K",
+    [GGML_TYPE_Q5_K] = "q5_K",
+    [GGML_TYPE_Q6_K] = "q6_K",
+    [GGML_TYPE_Q8_K] = "q8_K",
     [GGML_TYPE_I8]  = "i8",
     [GGML_TYPE_I16] = "i16",
     [GGML_TYPE_I32] = "i32",
 };
-static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated");
+static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_NAME is outdated");
 
 static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32] = false,
@@ -3490,11 +3574,17 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q5_1] = true,
     [GGML_TYPE_Q8_0] = true,
     [GGML_TYPE_Q8_1] = true,
+    [GGML_TYPE_Q2_K] = true,
+    [GGML_TYPE_Q3_K] = true,
+    [GGML_TYPE_Q4_K] = true,
+    [GGML_TYPE_Q5_K] = true,
+    [GGML_TYPE_Q6_K] = true,
+    [GGML_TYPE_Q8_K] = true,
     [GGML_TYPE_I8]  = false,
     [GGML_TYPE_I16] = false,
     [GGML_TYPE_I32] = false,
 };
-static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated");
+static_assert(GGML_TYPE_COUNT == 19, "GGML_IS_QUANTIZED is outdated");
 
 static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "NONE",
@@ -3513,6 +3603,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "SUM_ROWS",
     "MEAN",
     "REPEAT",
+    "REPEAT_BACK",
     "ABS",
     "SGN",
     "NEG",
@@ -3526,6 +3617,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "RMS_NORM_BACK",
 
     "MUL_MAT",
+    "OUT_PROD",
 
     "SCALE",
     "SET",
@@ -3541,6 +3633,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "DIAG_MASK_INF",
     "DIAG_MASK_ZERO",
     "SOFT_MAX",
+    "SOFT_MAX_BACK",
     "ROPE",
     "ROPE_BACK",
     "ALIBI",
@@ -3550,13 +3643,16 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
 
     "FLASH_ATTN",
     "FLASH_FF",
+    "FLASH_ATTN_BACK",
 
     "MAP_UNARY",
     "MAP_BINARY",
-};
 
-static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
+    "CROSS_ENTROPY_LOSS",
+    "CROSS_ENTROPY_LOSS_BACK",
+};
 
+static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3575,6 +3671,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "Σx_k",
     "Σx/n",
     "repeat(x)",
+    "repeat_back(x)",
     "abs(x)",
     "sgn(x)",
     "-x",
@@ -3587,6 +3684,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rms_norm(x)",
     "rms_norm_back(x)",
 
+    "X*Y",
     "X*Y",
 
     "x*v",
@@ -3603,6 +3701,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "diag_mask_inf(x)",
     "diag_mask_zero(x)",
     "soft_max(x)",
+    "soft_max_back(x)",
     "rope(x)",
     "rope_back(x)",
     "alibi(x)",
@@ -3612,12 +3711,16 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "flash_attn(x)",
     "flash_ff(x)",
+    "flash_attn_back(x)",
 
     "f(x)",
     "f(x,y)",
+
+    "cross_entropy_loss(x,y)",
+    "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
+static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3631,6 +3734,7 @@ struct ggml_context {
     void * mem_buffer;
     bool   mem_buffer_owned;
     bool   no_alloc;
+    bool   no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
 
     int    n_objects;
 
@@ -3647,26 +3751,6 @@ struct ggml_context_container {
     struct ggml_context context;
 };
 
-//
-// compute types
-//
-
-enum ggml_task_type {
-    GGML_TASK_INIT = 0,
-    GGML_TASK_COMPUTE,
-    GGML_TASK_FINALIZE,
-};
-
-struct ggml_compute_params {
-    enum ggml_task_type type;
-
-    int ith, nth;
-
-    // work buffer for all threads
-    size_t wsize;
-    void * wdata;
-};
-
 //
 // ggml state
 //
@@ -3723,7 +3807,7 @@ int64_t ggml_nelements(const struct ggml_tensor * tensor) {
     return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
 }
 
-int ggml_nrows(const struct ggml_tensor * tensor) {
+int64_t ggml_nrows(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
@@ -3732,7 +3816,20 @@ int ggml_nrows(const struct ggml_tensor * tensor) {
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
-    return (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
+    // this should handle cases where the tensor is not contiguous in memory
+    // probaby just:
+    //
+    //     return tensor->ne[3]*tensor->nb[3]
+    //
+    // is enough, but just in case, adding the second part
+
+    return MAX(tensor->ne[3]*tensor->nb[3], (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type]);
+}
+
+size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return (nrows_split*tensor->ne[0]*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
 }
 
 int ggml_blck_size(enum ggml_type type) {
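Note (illustrative, not part of the upstream diff): a quick worked check of the two terms in the new ggml_nbytes():

```c
// Contiguous f32 tensor, ne = {4, 3, 1, 1}, so nb = {4, 16, 48, 48}:
//   ne[3]*nb[3]                   = 1*48   = 48 bytes
//   nelements*TYPE_SIZE/BLCK_SIZE = 12*4/1 = 48 bytes   -> the terms agree
// Same shape, but a padded row stride nb[1] = 32 (nb = {4, 32, 96, 96}):
//   ne[3]*nb[3]                   = 1*96   = 96 bytes
// The element-count term still says 48, so the MAX() is what preserves the
// full 96-byte extent for the non-contiguous case.
```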
@@ -3786,6 +3883,15 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
            (t0->ne[3] == t1->ne[3]);
 }
 
+static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        (t0->ne[1] == t1->ne[1]) &&
+        (t0->ne[2] == t1->ne[2]) &&
+        (t0->ne[3] == t1->ne[3]);
+}
+
 bool ggml_is_quantized(enum ggml_type type) {
     return GGML_IS_QUANTIZED[type];
 }
@@ -3801,6 +3907,11 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_Q5_0:          wtype = GGML_TYPE_Q5_0;  break;
         case GGML_FTYPE_MOSTLY_Q5_1:          wtype = GGML_TYPE_Q5_1;  break;
         case GGML_FTYPE_MOSTLY_Q8_0:          wtype = GGML_TYPE_Q8_0;  break;
+        case GGML_FTYPE_MOSTLY_Q2_K:          wtype = GGML_TYPE_Q2_K;  break;
+        case GGML_FTYPE_MOSTLY_Q3_K:          wtype = GGML_TYPE_Q3_K;  break;
+        case GGML_FTYPE_MOSTLY_Q4_K:          wtype = GGML_TYPE_Q4_K;  break;
+        case GGML_FTYPE_MOSTLY_Q5_K:          wtype = GGML_TYPE_Q5_K;  break;
+        case GGML_FTYPE_MOSTLY_Q6_K:          wtype = GGML_TYPE_Q6_K;  break;
         case GGML_FTYPE_UNKNOWN:              wtype = GGML_TYPE_COUNT; break;
         case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
     }
@@ -3814,11 +3925,11 @@ size_t ggml_tensor_overhead(void) {
     return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE + 16;
 }
 
-static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
+bool ggml_is_transposed(const struct ggml_tensor * tensor) {
     return tensor->nb[0] > tensor->nb[1];
 }
 
-static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
+bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
@@ -3828,6 +3939,12 @@ static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
+bool ggml_is_permuted(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
+}
+
 static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
@@ -3967,6 +4084,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         /*.mem_buffer       =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size),
         /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
         /*.no_alloc         =*/ params.no_alloc,
+        /*.no_alloc_save    =*/ params.no_alloc,
         /*.n_objects        =*/ 0,
         /*.objects_begin    =*/ NULL,
         /*.objects_end      =*/ NULL,
@@ -4044,11 +4162,18 @@ size_t ggml_get_mem_size(struct ggml_context * ctx) {
 // operators when using scratch buffers
 // TODO: implement a better way
 void ggml_scratch_save(struct ggml_context * ctx) {
+    // this is needed to allow opt tensors to store their data
+    // TODO: again, need to find a better way
+    ctx->no_alloc_save = ctx->no_alloc;
+    ctx->no_alloc      = false;
+
     ctx->scratch_save = ctx->scratch;
     ctx->scratch.data = NULL;
 }
 
 void ggml_scratch_load(struct ggml_context * ctx) {
+    ctx->no_alloc = ctx->no_alloc_save;
+
     ctx->scratch = ctx->scratch_save;
 }
 
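Note (illustrative, not part of the upstream diff): the pattern this change enables shows up in the view and permute hunks later in this diff. Small metadata tensors must live in the context's own memory and must be allocated even while scratch buffers or no_alloc are active:

```c
// ctx is assumed in scope; compare ggml_permute further down in the diff.
ggml_scratch_save(ctx);   // now also clears no_alloc, not just the scratch
struct ggml_tensor * meta = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
ggml_scratch_load(ctx);   // restores both the scratch and the saved no_alloc
```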
@@ -4157,6 +4282,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
         /*.perf_time_us =*/ 0,
         /*.data         =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
         /*.name         =*/ { 0 },
+        /*.extra        =*/ NULL,
         /*.pad          =*/ { 0 },
     };
 
@@ -4595,7 +4721,7 @@ struct ggml_tensor * ggml_add_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -4635,7 +4761,7 @@ struct ggml_tensor * ggml_add1_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5061,6 +5187,34 @@ struct ggml_tensor * ggml_repeat(
     return result;
 }
 
+// ggml_repeat_back
+
+struct ggml_tensor * ggml_repeat_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_can_repeat(b, a));
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    if (ggml_are_same_shape(a, b) && !is_node) {
+        return a;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
+
+    result->op   = GGML_OP_REPEAT_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_abs
 
 struct ggml_tensor * ggml_abs_impl(
@@ -5438,6 +5592,32 @@ struct ggml_tensor * ggml_mul_mat(
     return result;
 }
 
+// ggml_out_prod
+
+struct ggml_tensor * ggml_out_prod(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_can_out_prod(a, b));
+    GGML_ASSERT(!ggml_is_transposed(a));
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+
+    result->op   = GGML_OP_OUT_PROD;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_scale
 
 struct ggml_tensor * ggml_scale_impl(
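Note (illustrative, not part of the upstream diff): the shape rule above matches the forward implementation added later in this diff, whose inner loop is `dst[i0,i1] += src0[i0,i01] * src1[i1,i01]`:

```c
// With a->ne = {I, K, ...} and b->ne = {J, K, ...} (shared ne[1] == K, as
// ggml_can_out_prod requires), the result gets ne = {I, J, ...}:
//   dst[i0, i1] = sum over i01 < K of a[i0, i01] * b[i1, i01]
// i.e. a sum of K outer products of column pairs -- the shape of the
// dW = dY * X^T term needed by a matmul backward pass.
```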
@@ -5450,7 +5630,7 @@ struct ggml_tensor * ggml_scale_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5493,7 +5673,7 @@ struct ggml_tensor * ggml_set_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5802,14 +5982,18 @@ struct ggml_tensor * ggml_view_1d(
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset);
 
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
+
     result->op   = GGML_OP_VIEW;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
-
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
+    result->opt[0] = offs;
 
     return result;
 }
@@ -5834,6 +6018,13 @@ struct ggml_tensor * ggml_view_2d(
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset);
 
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
+
     result->nb[1] = nb1;
     result->nb[2] = result->nb[1]*ne1;
     result->nb[3] = result->nb[2];
@@ -5842,10 +6033,7 @@ struct ggml_tensor * ggml_view_2d(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
-
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
+    result->opt[0] = offs;
 
     return result;
 }
@@ -5872,6 +6060,13 @@ struct ggml_tensor * ggml_view_3d(
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset);
 
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
+
     result->nb[1] = nb1;
     result->nb[2] = nb2;
     result->nb[3] = result->nb[2]*ne2;
@@ -5880,10 +6075,7 @@ struct ggml_tensor * ggml_view_3d(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
-
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
+    result->opt[0] = offs;
 
     return result;
 }
@@ -5912,6 +6104,13 @@ struct ggml_tensor * ggml_view_4d(
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset);
 
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
+
     result->nb[1] = nb1;
     result->nb[2] = nb2;
     result->nb[3] = nb3;
@@ -5920,10 +6119,7 @@ struct ggml_tensor * ggml_view_4d(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
-
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
+    result->opt[0] = offs;
 
     return result;
 }
@@ -5986,10 +6182,18 @@ struct ggml_tensor * ggml_permute(
     result->src1 = NULL;
 
     if (is_node) {
-        result->padding[0] = axis0;
-        result->padding[1] = axis1;
-        result->padding[2] = axis2;
-        result->padding[3] = axis3;
+        ggml_scratch_save(ctx);
+
+        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
+
+        ((int32_t *) b->data)[0] = axis0;
+        ((int32_t *) b->data)[1] = axis1;
+        ((int32_t *) b->data)[2] = axis2;
+        ((int32_t *) b->data)[3] = axis3;
+
+        ggml_scratch_load(ctx);
+
+        result->opt[0] = b;
     }
 
     return result;
@@ -6229,6 +6433,44 @@ struct ggml_tensor * ggml_soft_max_inplace(
     return ggml_soft_max_impl(ctx, a, true);
 }
 
+
+// ggml_soft_max_back
+
+struct ggml_tensor * ggml_soft_max_back_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        bool inplace) {
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true; // TODO : implement backward pass
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SOFT_MAX_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_soft_max_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_soft_max_back_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, true);
+}
+
 // ggml_rope
 
 struct ggml_tensor * ggml_rope_impl(
@@ -6241,7 +6483,7 @@ struct ggml_tensor * ggml_rope_impl(
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
 
-    if (!inplace && a->grad) {
+    if (a->grad) {
         is_node = true;
     }
 
@@ -6295,8 +6537,7 @@ struct ggml_tensor * ggml_rope_back(
     bool is_node = false;
 
     if (a->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
+        is_node = false; // TODO: implement backward
     }
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
@@ -6461,7 +6702,6 @@ struct ggml_tensor * ggml_flash_attn(
     bool is_node = false;
 
     if (q->grad || k->grad || v->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -6493,7 +6733,6 @@ struct ggml_tensor * ggml_flash_ff(
     bool is_node = false;
 
     if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -6511,6 +6750,71 @@ struct ggml_tensor * ggml_flash_ff(
     return result;
 }
 
+// ggml_flash_attn_back
+
+struct ggml_tensor * ggml_flash_attn_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * q,
+        struct ggml_tensor * k,
+        struct ggml_tensor * v,
+        struct ggml_tensor * d,
+        bool masked) {
+    GGML_ASSERT(ggml_can_mul_mat(k, q));
+    // TODO: check if vT can be multiplied by (k*qT)
+
+    // d shape [D,N,ne2,ne3]
+    // q shape [D,N,ne2,ne3]
+    // k shape [D,M,ne2,ne3]
+    // v shape [M,D,ne2,ne3]
+
+    const int64_t   D = q->ne[0];
+    const int64_t   N = q->ne[1];
+    const int64_t   M = k->ne[1];
+    const int64_t ne2 = q->ne[2];
+    const int64_t ne3 = q->ne[3];
+
+    GGML_ASSERT(k->ne[0] == D);
+    GGML_ASSERT(v->ne[0] == M);
+    GGML_ASSERT(v->ne[1] == D);
+    GGML_ASSERT(d->ne[0] == D);
+    GGML_ASSERT(d->ne[1] == N);
+    GGML_ASSERT(k->ne[2] == ne2);
+    GGML_ASSERT(k->ne[3] == ne3);
+    GGML_ASSERT(v->ne[2] == ne2);
+    GGML_ASSERT(v->ne[3] == ne3);
+    GGML_ASSERT(d->ne[2] == ne2);
+    GGML_ASSERT(d->ne[3] == ne3);
+
+    bool is_node = false;
+
+    if (q->grad || k->grad || v->grad) {
+        // when using this operation (in backwards pass) these grads are set.
+        // we don't want to create (big) grad of our result, so is_node is false.
+        is_node = false;
+    }
+
+    // store gradients of q, k and v as continuous tensors concatenated in result.
+    // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
+    // gradq->data = result->data
+    // gradk->data = result->data + nb0*D*N*ne2*ne3
+    // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
+    // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
+    int64_t ne[4] = {D,M+N+M,ne2,ne3};
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op   = GGML_OP_FLASH_ATTN_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = q;
+    result->src1 = k;
+    result->opt[0] = v;
+    result->opt[1] = d;
+    result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
+
+    return result;
+}
+
+
 // ggml_map_unary
 
 struct ggml_tensor * ggml_map_unary_impl_f32(
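Note (illustrative, not part of the upstream diff): a small arithmetic check on the packed layout the comments above describe:

```c
// Per (ne2, ne3) slice:
//   gradq: D*N floats at offset 0
//   gradk: D*M floats at offset nb0*D*N*ne2*ne3
//   gradv: M*D floats at offset nb0*(D*N + D*M)*ne2*ne3
// total = D*N + D*M + M*D = D*(M+N+M), which matches the single
// allocation ne = {D, M+N+M, ne2, ne3}, so the three gradients can be
// recovered from one buffer with views at those offsets.
```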
@@ -6595,6 +6899,50 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
     return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
 }
 
+// ggml_cross_entropy_loss
+
+struct ggml_tensor * ggml_cross_entropy_loss(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
+
+    result->op   = GGML_OP_CROSS_ENTROPY_LOSS;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_cross_entropy_loss_back
+
+struct ggml_tensor * ggml_cross_entropy_loss_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        struct ggml_tensor * c) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_is_scalar(c));
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
+    result->grad = NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = c;
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 void ggml_set_param(
@@ -7584,6 +7932,11 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
                 ggml_compute_forward_add_q_f32(params, src0, src1, dst);
             } break;
@@ -7887,6 +8240,11 @@ static void ggml_compute_forward_add1(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
                 ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
             } break;
@@ -8009,6 +8367,11 @@ static void ggml_compute_forward_acc(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
         default:
             {
                 GGML_ASSERT(false);
@@ -8127,10 +8490,10 @@ static void ggml_compute_forward_mul_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-#ifdef GGML_USE_CUBLAS
-    if (src1->backend == GGML_BACKEND_CUDA) {
+#ifdef GGML_USE_CLBLAST
+    if (src1->backend == GGML_BACKEND_GPU) {
         if (ith == 0) {
-            ggml_cuda_mul(src0, src1, dst);
+            ggml_cl_mul(src0, src1, dst);
         }
         return;
     }
@@ -8730,29 +9093,122 @@ static void ggml_compute_forward_repeat(
     }
 }
 
-// ggml_compute_forward_abs
+// ggml_compute_forward_repeat_back
 
-static void ggml_compute_forward_abs_f32(
+static void ggml_compute_forward_repeat_back_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
-    assert(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_can_repeat(dst, src0));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    const int64_t ne0  = dst->ne[0];
+    const int64_t ne1  = dst->ne[1];
+    const int64_t ne2  = dst->ne[2];
+    const int64_t ne3  = dst->ne[3];
 
-    assert(dst->nb[0]  == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
 
-    for (int i = 0; i < n; i++) {
-        ggml_vec_abs_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    const size_t nb0  = dst->nb[0];
+    const size_t nb1  = dst->nb[1];
+    const size_t nb2  = dst->nb[2];
+    const size_t nb3  = dst->nb[3];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne00/ne0);
+    const int nr1 = (int)(ne01/ne1);
+    const int nr2 = (int)(ne02/ne2);
+    const int nr3 = (int)(ne03/ne3);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (ggml_is_contiguous(dst)) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+    } else {
+        for (int k3 = 0; k3 < ne3; k3++) {
+            for (int k2 = 0; k2 < ne2; k2++) {
+                for (int k1 = 0; k1 < ne1; k1++) {
+                    ggml_vec_set_f32(ne0,
+                        (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
+                        0);
+                }
+            }
+        }
+    }
+
+    // TODO: maybe this is not optimal?
+    for (int i3 = 0; i3 < nr3; i3++) {
+        for (int k3 = 0; k3 < ne3; k3++) {
+            for (int i2 = 0; i2 < nr2; i2++) {
+                for (int k2 = 0; k2 < ne2; k2++) {
+                    for (int i1 = 0; i1 < nr1; i1++) {
+                        for (int k1 = 0; k1 < ne1; k1++) {
+                            for (int i0 = 0; i0 < nr0; i0++) {
+                                ggml_vec_acc_f32(ne0,
+                                        (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),
+                                        (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_repeat_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_repeat_back_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_abs
+
+static void ggml_compute_forward_abs_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_abs_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
     }
 }
 
@@ -9245,7 +9701,7 @@ static void ggml_compute_forward_rms_norm_f32(
             sum += (ggml_float)(x[i00] * x[i00]);
         }
 
-        float mean = sum/ne00;
+        const float mean = sum/ne00;
 
         float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
@@ -9568,14 +10024,7 @@ static void ggml_compute_forward_mul_mat_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_CUBLAS)
-    if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
-        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
-            ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
-        }
-        return;
-    }
-#elif defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -9740,14 +10189,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_CUBLAS)
-    if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
-        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
-            ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
-        }
-        return;
-    }
-#elif defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -9952,14 +10394,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_CUBLAS)
-    if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
-        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
-            ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
-        }
-        return;
-    }
-#elif defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -10102,6 +10537,11 @@ static void ggml_compute_forward_mul_mat(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
                 ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
             } break;
@@ -10120,6 +10560,176 @@ static void ggml_compute_forward_mul_mat(
     }
 }
 
+// ggml_compute_forward_out_prod
+
+
+static void ggml_compute_forward_out_prod_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne10 = src1->ne[0];
+    //const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+    // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
+    // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+
+    if (params->type == GGML_TASK_INIT) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by last three dimensions
+
+    // total rows in dst
+    const int64_t nr = ne1*ne2*ne3;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    // dst[:,:,:,:] = 0
+    // for i2,i3:
+    //   for i1:
+    //     for i01:
+    //       for i0:
+    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        // dst indices
+        const int64_t i3 = ir/(ne2*ne1);
+        const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+        const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        const int64_t i02 = i2;
+        const int64_t i03 = i3;
+
+        //const int64_t i10 = i1;
+        const int64_t i12 = i2;
+        const int64_t i13 = i3;
+
+        for (int64_t i01 = 0; i01 < ne01; ++i01) {
+            const int64_t i11 = i01;
+
+            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+            float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
+
+            ggml_vec_mad_f32(ne0, d, s0, *s1);
+            // for (int64_t i0 = 0; i0 < ne0; ++i0) {
+            //     d[i0] += s0[i0] * s1[i1];
+            // }
+        }
+    }
+
+    //int64_t t1 = ggml_perf_time_us();
+    //static int64_t acc = 0;
+    //acc += t1 - t0;
+    //if (t1 - t0 > 10) {
+    //    printf("\n");
+    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
+
+    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+    //}
+}
+
+static void ggml_compute_forward_out_prod(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+            {
+                GGML_ASSERT(false); // todo
+                // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                GGML_ASSERT(false); // todo
+                // ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_out_prod_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_scale
 
 static void ggml_compute_forward_scale_f32(
@@ -10285,6 +10895,11 @@ static void ggml_compute_forward_set(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
         default:
             {
                 GGML_ASSERT(false);
@@ -10450,6 +11065,11 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
                 ggml_compute_forward_get_rows_q(params, src0, src1, dst);
             } break;
@@ -10532,7 +11152,11 @@ static void ggml_compute_forward_get_rows_back_f32(
     GGML_ASSERT(ggml_is_contiguous(opt0));
     GGML_ASSERT(ggml_is_contiguous(dst));
 
-    ggml_compute_forward_dup_same_cont(params, opt0, dst);
+    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+    if (params->type == GGML_TASK_INIT) {
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -10676,8 +11300,8 @@ static void ggml_compute_forward_diag_mask_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst,
         const float value) {
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 2);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_nelements(src1) == 2);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -10685,7 +11309,7 @@ static void ggml_compute_forward_diag_mask_f32(
     const int n_past  = ((int32_t *) src1->data)[0];
     const bool inplace = (bool)((int32_t *) src1->data)[1];
 
-    assert(n_past >= 0);
+    GGML_ASSERT(n_past >= 0);
 
     if (!inplace && (params->type == GGML_TASK_INIT)) {
         // memcpy needs to be synchronized across threads to avoid race conditions.
@@ -10709,8 +11333,8 @@ static void ggml_compute_forward_diag_mask_f32(
     const int nr = src0->ne[1];
     const int nz = n/nr;
 
-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     for (int k = 0; k < nz; k++) {
         for (int j = ith; j < nr; j += nth) {
@@ -10846,42 +11470,137 @@ static void ggml_compute_forward_soft_max(
     }
 }
 
-// ggml_compute_forward_alibi
+// ggml_compute_forward_soft_max_back
 
-static void ggml_compute_forward_alibi_f32(
+static void ggml_compute_forward_soft_max_back_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 3);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src1, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int n_past = ((int32_t *) src1->data)[0];
-    const int n_head = ((int32_t *) src1->data)[1];
-    const float max_bias = ((float *) src1->data)[2];
+    // TODO: handle transposed/permuted matrices
 
-    assert(n_past >= 0);
+    const int ith = params->ith;
+    const int nth = params->nth;
 
-    const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
-    const int ne1 = src0->ne[1]; // seq_len_without_past
-    //const int ne2 = src0->ne[2]; // n_head -> this is k
-    //const int ne3 = src0->ne[3]; // 1 -> bsz
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
 
-    const int n  = ggml_nrows(src0);
-    const int ne2_ne3 = n/ne1; // ne2*ne3
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
 
-    const int nb0 = src0->nb[0];
-    const int nb1 = src0->nb[1];
-    const int nb2 = src0->nb[2];
-    //const int nb3 = src0->nb[3];
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
 
-    assert(nb0 == sizeof(float));
-    assert(ne1 + n_past == ne0); (void) n_past;
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float *y  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float *dx = (float *)((char *) dst->data  + i1*dst->nb[1]);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(dy[i]));
+            assert(!isnan(y[i]));
+        }
+#endif
+        // Jii = yi - yi*yi
+        // Jij = -yi*yj
+        // J = diag(y)-y.T*y
+        // dx = J * dy
+        // dxk = sum_i(Jki * dyi)
+        // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
+        // dxk = sum_i(-yk*yi * dyi) + yk*dyk
+        // dxk = -yk * sum_i(yi * dyi) + yk*dyk
+        // dxk = -yk * dot(y, dy) + yk*dyk
+        // dxk = yk * (- dot(y, dy) + dyk)
+        // dxk = yk * (dyk - dot(y, dy))
+        //
+        // post-order:
+        // dot_y_dy := dot(y, dy)
+        // dx := dy
+        // dx := dx - dot_y_dy
+        // dx := dx * y
+
+        // linear runtime, no additional memory
+        float dot_y_dy = 0;
+        ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
+        ggml_vec_cpy_f32 (nc, dx, dy);
+        ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
+        ggml_vec_mul_f32 (nc, dx, dx, y);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(dx[i]));
+            assert(!isinf(dx[i]));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_soft_max_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_alibi
+
+static void ggml_compute_forward_alibi_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
|
11576
|
+
assert(src1->type == GGML_TYPE_I32);
|
11577
|
+
assert(ggml_nelements(src1) == 3);
|
11578
|
+
|
11579
|
+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
11580
|
+
return;
|
11581
|
+
}
|
11582
|
+
|
11583
|
+
const int n_past = ((int32_t *) src1->data)[0];
|
11584
|
+
const int n_head = ((int32_t *) src1->data)[1];
|
11585
|
+
const float max_bias = ((float *) src1->data)[2];
|
11586
|
+
|
11587
|
+
assert(n_past >= 0);
|
11588
|
+
|
11589
|
+
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
|
11590
|
+
const int ne1 = src0->ne[1]; // seq_len_without_past
|
11591
|
+
//const int ne2 = src0->ne[2]; // n_head -> this is k
|
11592
|
+
//const int ne3 = src0->ne[3]; // 1 -> bsz
|
11593
|
+
|
11594
|
+
const int n = ggml_nrows(src0);
|
11595
|
+
const int ne2_ne3 = n/ne1; // ne2*ne3
|
11596
|
+
|
11597
|
+
const int nb0 = src0->nb[0];
|
11598
|
+
const int nb1 = src0->nb[1];
|
11599
|
+
const int nb2 = src0->nb[2];
|
11600
|
+
//const int nb3 = src0->nb[3];
|
11601
|
+
|
11602
|
+
assert(nb0 == sizeof(float));
|
11603
|
+
assert(ne1 + n_past == ne0); (void) n_past;
|
10885
11604
|
|
10886
11605
|
// add alibi to src0 (KQ_scaled)
|
10887
11606
|
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
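The soft_max_back kernel added above implements the softmax Jacobian-vector product derived in its comments: dx[k] = y[k] * (dy[k] - dot(y, dy)). A minimal standalone sketch of the same math outside of ggml (plain C; the function name is illustrative, not part of the library):

#include <stddef.h>

// Given y = softmax(x) and the upstream gradient dy, compute
// dx[k] = y[k] * (dy[k] - dot(y, dy)) -- linear time, no extra memory,
// the same sequence as the vec_dot/vec_cpy/vec_acc1/vec_mul calls above.
static void softmax_backward_sketch(size_t n, float * dx, const float * y, const float * dy) {
    float dot_y_dy = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        dot_y_dy += y[i] * dy[i];
    }
    for (size_t i = 0; i < n; ++i) {
        dx[i] = y[i] * (dy[i] - dot_y_dy);
    }
}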
@@ -10996,6 +11715,12 @@ static void ggml_compute_forward_alibi(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -11067,6 +11792,12 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -11156,7 +11887,7 @@ static void ggml_compute_forward_rope_f32(
                         theta *= theta_scale;

                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              float * dst_data  = (float *)((char *)  dst->data +
+                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

                         const float x0 = src[0];
                         const float x1 = src[1];
@@ -11177,7 +11908,7 @@ static void ggml_compute_forward_rope_f32(
                             const int64_t i0 = ib*n_dims + ic/2;

                             const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                                  float * dst_data  = (float *)((char *)  dst->data +
+                                  float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

                             const float x0 = src[0];
                             const float x1 = src[n_dims/2];
@@ -12787,6 +13518,414 @@ static void ggml_compute_forward_flash_ff(
     }
 }

+// ggml_compute_forward_flash_attn_back
+
+static void ggml_compute_forward_flash_attn_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * d,
+        const bool masked,
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int64_t neq0 = q->ne[0];
+    const int64_t neq1 = q->ne[1];
+    const int64_t neq2 = q->ne[2];
+    const int64_t neq3 = q->ne[3];
+
+    const int64_t nek0 = k->ne[0];
+    const int64_t nek1 = k->ne[1];
+    //const int64_t nek2 = k->ne[2];
+    //const int64_t nek3 = k->ne[3];
+
+    const int64_t nev0 = v->ne[0];
+    const int64_t nev1 = v->ne[1];
+    //const int64_t nev2 = v->ne[2];
+    //const int64_t nev3 = v->ne[3];
+
+    const int64_t ned0 = d->ne[0];
+    const int64_t ned1 = d->ne[1];
+    //const int64_t ned2 = d->ne[2];
+    //const int64_t ned3 = d->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int nbk0 = k->nb[0];
+    const int nbk1 = k->nb[1];
+    const int nbk2 = k->nb[2];
+    const int nbk3 = k->nb[3];
+
+    const int nbq0 = q->nb[0];
+    const int nbq1 = q->nb[1];
+    const int nbq2 = q->nb[2];
+    const int nbq3 = q->nb[3];
+
+    const int nbv0 = v->nb[0];
+    const int nbv1 = v->nb[1];
+    const int nbv2 = v->nb[2];
+    const int nbv3 = v->nb[3];
+
+    const int nbd0 = d->nb[0];
+    const int nbd1 = d->nb[1];
+    const int nbd2 = d->nb[2];
+    const int nbd3 = d->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+    const int64_t P = nek1 - N;
+    const int64_t M = P + N;
+
+    const int Mup  = ggml_up(M, GGML_SOFT_MAX_UNROLL);
+    const int mxDM = MAX(D, Mup);
+
+    // GGML_ASSERT(ne0 == D);
+    // GGML_ASSERT(ne1 == N);
+    GGML_ASSERT(P >= 0);
+
+    GGML_ASSERT(nbq0 == sizeof(float));
+    GGML_ASSERT(nbk0 == sizeof(float));
+    GGML_ASSERT(nbv0 == sizeof(float));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned0 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nek1 == N + P);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    if (params->type == GGML_TASK_INIT) {
+        if (ith == 0) {
+            memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
+        }
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int nr = neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float scale = 1.0f/sqrtf(D);
+
+    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2);
+        const int iq2 = ir - iq3*neq2;
+        for ( int iq1 = 0; iq1 < neq1; ++iq1) {
+
+
+            // not sure about CACHE_LINE_SIZE_F32..
+            // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
+            float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
+            float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
+
+            for (int i = M; i < Mup; ++i) {
+                S[i] = -INFINITY;
+            }
+
+            for (int64_t ic = 0; ic < nek1; ++ic) {
+                // k indices
+                const int ik3 = iq3;
+                const int ik2 = iq2;
+                const int ik1 = ic;
+
+                // S indices
+                const int i1 = ik1;
+
+                ggml_vec_dot_f32(neq0,
+                        S + i1,
+                        (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                        (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+            }
+
+            // scale
+            ggml_vec_scale_f32(nek1, S, scale);
+
+            if (masked) {
+                for (int64_t i = P; i < M; i++) {
+                    if (i > P + iq1) {
+                        S[i] = -INFINITY;
+                    }
+                }
+            }
+
+            // softmax
+            {
+                float max = -INFINITY;
+                ggml_vec_max_f32(M, &max, S);
+
+                ggml_float sum = 0.0;
+                {
+#ifdef GGML_SOFT_MAX_ACCELERATE
+                    max = -max;
+                    vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
+                    vvexpf(SM, SM, &Mup);
+                    ggml_vec_sum_f32(Mup, &sum, SM);
+#else
+                    uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
+                    ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
+
+                    for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
+                        float * SR =  S + i;
+                        float * SW = SM + i;
+
+                        for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
+                            if (SR[j] == -INFINITY) {
+                                SW[j] = 0.0f;
+                            } else {
+                                ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
+                                memcpy(&scvt[j], &s, sizeof(uint16_t));
+                                const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+                                sump[j] += (ggml_float)val;
+                                SW[j] = val;
+                            }
+                        }
+                    }
+
+                    for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
+                        sum += sump[i];
+                    }
+#endif
+                }
+
+                assert(sum > 0.0);
+
+                sum = 1.0/sum;
+                ggml_vec_scale_f32(M, SM, sum);
+
+            }
+
+            // step-by-step explanation
+            {
+                // forward-process                   shape      grads from backward process
+                // parallel_for iq2,iq3:
+                //  k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,iq2,iq3]  += grad[kcur]
+                //  q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]
+                //  v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iq2,iq3]  += grad[vcur]
+                //  for iq1:
+                //   kcur  = k[:D,:M,iq2,iq3]        [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur
+                //   qcur  = q[:D,iq1,iq2,iq3]       [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur
+                //   vcur  = v[:M,:D,iq2,iq3]        [M,D,1,1]  grad[vcur] = grad[S5].T @ S4
+                //   S0    = -Inf                    [D,1,1,1]
+                //  ~S1[i] = dot(kcur[:D,i], qcur)
+                //   S1    = qcur @ kcur.T           [M,1,1,1]  grad[S1]   = grad[S2] * scale
+                //   S2    = S1 * scale              [M,1,1,1]  grad[S2]   = diag_mask_zero(grad[S3], P)
+                //   S3    = diag_mask_inf(S2, P)    [M,1,1,1]  grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                //   S4    = softmax(S3)             [M,1,1,1]  grad[S4]   = grad[S5] @ vcur
+                //  ~S5[i] = dot(vcur[:,i], S4)
+                //   S5    = S4 @ vcur.T             [D,1,1,1]  grad[S5]   = d[:D,iq1,iq2,iq3]
+                //  ~dst[i,iq1,iq2,iq3]  = S5[i]              ^
+                //   dst[:D,iq1,iq2,iq3] = S5                 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
+                // dst                              backward-/ grad[dst]                 = d
+                //
+                // output gradients with their dependencies:
+                //
+                // grad[kcur] = grad[S1].T @ qcur
+                // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                // grad[S4]   = grad[S5] @ vcur
+                // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
+                // grad[qcur] = grad[S1]   @ kcur
+                // grad[vcur] = grad[S5].T @ S4
+                // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
+                //
+                // in post-order:
+                //
+                // S1         = qcur @ kcur.T
+                // S2         = S1 * scale
+                // S3         = diag_mask_inf(S2, P)
+                // S4         = softmax(S3)
+                // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
+                // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                // grad[qcur] = grad[S1]   @ kcur
+                // grad[kcur] = grad[S1].T @ qcur
+                // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
+                //
+                // using less variables (SM=S4):
+                //
+                // S             = diag_mask_inf(qcur @ kcur.T * scale, P)
+                // SM            = softmax(S)
+                // S             = d[:D,iq1,iq2,iq3] @ vcur
+                // dot_SM_gradSM = dot(SM, S)
+                // S             = SM * (S - dot(SM, S))
+                // S             = diag_mask_zero(S, P) * scale
+                //
+                // grad[q][:D,iq1,iq2,iq3] += S   @ kcur
+                // grad[k][:D,:M,iq2,iq3]  += S.T @ qcur
+                // grad[v][:M,:D,iq2,iq3]  += d[:D,iq1,iq2,iq3].T @ SM
+            }
+
+            // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
+            // S = d[:D,iq1,iq2,iq3] @ vcur
+            // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
+            ggml_vec_set_f32(M, S, 0);
+            for (int64_t ic = 0; ic < D; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                ggml_vec_mad_f32(M,
+                        S,
+                         (float *) ((char *) v->data + (          ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                        *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
+            }
+
+            // S = SM * (S - dot(SM, S))
+            float dot_SM_gradSM = 0;
+            ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
+            ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
+            ggml_vec_mul_f32 (M, S, S, SM);
+
+            // S = diag_mask_zero(S, P) * scale
+            if (masked) {
+                // for (int64_t i = P + iq1 + 1; i < M; i++) {
+                //     S[i] = 0;
+                // }
+                for (int64_t i = P; i < M; i++) {
+                    if (i > P + iq1) {
+                        S[i] = 0;
+                    }
+                }
+            }
+            ggml_vec_scale_f32(M, S, scale);
+
+            void * grad_q = (char *) dst->data;
+            void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
+            void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
+
+            const size_t nbgq1 = nb0*neq0;
+            const size_t nbgq2 = nb0*neq0*neq1;
+            const size_t nbgq3 = nb0*neq0*neq1*neq2;
+
+            const size_t nbgk1 = nb0*nek0;
+            const size_t nbgk2 = nb0*nek0*nek1;
+            const size_t nbgk3 = nb0*nek0*nek1*neq2;
+
+            const size_t nbgv1 = nb0*nev0;
+            const size_t nbgv2 = nb0*nev0*nev1;
+            const size_t nbgv3 = nb0*nev0*nev1*neq2;
+
+            // S    shape [M,1]
+            // SM   shape [M,1]
+            // kcur shape [D,M]
+            // qcur shape [D,1]
+            // vcur shape [M,D]
+            //
+            // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+            // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
+            // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
+            //
+            //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
+            //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
+            for (int64_t ic = 0; ic < M; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                ggml_vec_mad_f32(D,
+                        (float *) ((char *) grad_q  + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
+                        (float *) ((char *) k->data + (ic*nbk1  + i2*nbk2  + i3*nbk3)),
+                        S[ic]);
+            }
+
+            // grad[k][:D,:M,iq2,iq3] += S.T       @ qcur
+            // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
+            // grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]
+            for (int64_t ic = 0; ic < M; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                // ggml_vec_set_f32(D,
+                //         (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
+                //         0);
+                ggml_vec_mad_f32(D,
+                        (float *) ((char *) grad_k  + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
+                        (float *) ((char *) q->data + (i1*nbq1  + i2*nbq2  + i3*nbq3)),
+                        S[ic]);
+            }
+
+            // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T       @ SM
+            // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
+            // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3]         * SM[:M]
+            for (int64_t ic = 0; ic < D; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                // ggml_vec_set_f32(M,
+                //         (float *) ((char *) grad_v + (          ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
+                //         0);
+                ggml_vec_mad_f32(M,
+                         (float *) ((char *) grad_v + (          ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
+                        SM,
+                        *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_flash_attn_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * d,
+        const bool masked,
+        struct ggml_tensor * dst) {
+    switch (q->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_map_unary

 static void ggml_compute_forward_map_unary_f32(
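Note how ggml_compute_forward_flash_attn_back_f32 packs all three input gradients into the single dst tensor: grad_q first, then grad_k, then grad_v, at byte offsets derived from the q and k shapes. A hedged sketch of that layout arithmetic (plain C; the names are illustrative and nb0 is the float element size, as asserted in the kernel):

#include <stddef.h>

// dst = [ grad_q : D*N*n2*n3 elems | grad_k : D*M*n2*n3 elems | grad_v : rest ]
static void flash_grad_offsets_sketch(char * dst_data, size_t nb0,
                                      size_t D, size_t N, size_t M,
                                      size_t n2, size_t n3,
                                      void ** grad_q, void ** grad_k, void ** grad_v) {
    *grad_q = dst_data;                                    // gradient w.r.t. q
    *grad_k = dst_data + nb0*D*N*n2*n3;                    // gradient w.r.t. k
    *grad_v = dst_data + nb0*D*N*n2*n3 + nb0*D*M*n2*n3;    // gradient w.r.t. v
}

The GGML_OP_FLASH_ATTN backward case further below slices these three regions back out of flash_grad with ggml_view_2d/3d/4d at exactly these offsets.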
@@ -12849,29 +13988,308 @@ static void ggml_compute_forward_map_binary_f32(
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];

-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
-    assert(src1->nb[0] == sizeof(float));
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_binary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        const ggml_binary_op_f32_t fun) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_cross_entropy_loss
+
+static void ggml_compute_forward_cross_entropy_loss_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    float * sums = (float *) params->wdata;
+
+    // TODO: handle transposed/permuted matrices
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    if (params->type == GGML_TASK_INIT) {
+        if (ith == 0) {
+            memset(sums, 0, sizeof(float) * (nth + nth * nc));
+        }
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        if (ith == 0) {
+            float * dp = (float *) dst->data;
+            ggml_vec_sum_f32(nth, dp, sums);
+            dp[0] *= -1.0f;
+        }
+        return;
+    }
+
+    const double eps = 1e-9;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float * st = (float *) params->wdata + nth + ith*nc;
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+        // soft_max
+        ggml_float sum = 0.0;
+        {
+            float max = -INFINITY;
+            ggml_vec_max_f32(nc, &max, s0);
+
+            uint16_t scvt;
+            for (int i = 0; i < nc; i++) {
+                if (s0[i] == -INFINITY) {
+                    st[i] = 0.0f;
+                } else {
+                    // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
+                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
+                    memcpy(&scvt, &s, sizeof(scvt));
+                    const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+                    sum += (ggml_float)val;
+                    st[i] = val;
+                }
+            }
+
+            assert(sum > 0.0);
+            // sum = 1.0/sum;
+        }
+        // avoid log(0) by rescaling from [0..1] to [eps..1]
+        sum = (1.0 - eps) / sum;
+        ggml_vec_scale_f32(nc, st, sum);
+        ggml_vec_add1_f32(nc, st, st, eps);
+        ggml_vec_log_f32(nc, st, st);
+        ggml_vec_mul_f32(nc, st, st, s1);
+
+        ggml_vec_sum_f32(nc, sums + ith, st);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(st[i]));
+            assert(!isinf(st[i]));
+        }
+#endif
+    }
+
+}
+
+static void ggml_compute_forward_cross_entropy_loss(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_cross_entropy_loss_back
+
+static void ggml_compute_forward_cross_entropy_loss_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(opt0));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int64_t ith = params->ith;
+    const int64_t nth = params->nth;
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const float eps = 1e-9f;
+
+    // TODO: handle transposed/permuted matrices
+    const int64_t nc = src0->ne[0];
+    const int64_t nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    float * d = (float *) opt0->data;
+
+    for (int64_t i1 = ir0; i1 < ir1; i1++) {
+        float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
+        float * s0  = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float * sm  = (float *) params->wdata + ith*nc;
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+        // step by step explanation:
+        {
+            //float * sums = (float *) params->wdata;
+
+            // forward pass with annotated gradients from backward pass
+            // (built by going in reverse operation order, adding to gradients of current operation args)
+            // st0 = exp(s0-max(s0))                                                       grad[st0] = grad[st1]*(1.0 - eps)/sum
+            // from softmax_back:                                                          grad[s0]  = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
+            // ggml_vec_scale_f32(nc, st, sum);           // st1 = st0*/sum = softmax(s0)  grad[st1] = grad[st2]*(1.0 - eps)
+            // ggml_vec_scale_f32(nc, st, (1.0f - eps));  // st2 = st1*(1.0 - eps)         grad[st2] = grad[st3]
+            // ggml_vec_add1_f32(nc, st, st, eps);        // st3 = st2 + eps               grad[st3] = grad[st4]/st3
+            // ggml_vec_log_f32(nc, st, st);              // st4 = log(st3)                grad[st4] = grad[st5] * s1
+            // ggml_vec_mul_f32(nc, st, st, s1);          // st5 = st4 * s1                grad[st5] = grad[sums[ith]]
+            // ggml_vec_sum_f32(nc, sums + ith, st);      // sums[ith] = st5               grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel]
+
+            // substitute into grad[st1], because we can reuse softmax_back from this point on
+            // grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps))
+            // postorder:
+            // grad[st1] := softmax(s0)
+            // grad[st1] := grad[st1]*(1.0 - eps)
+            // grad[st1] := grad[st1] + eps
+            // grad[st1] := s1 / grad[st1]
+            // grad[st1] := grad[st1]*(1.0-eps)*-grad[cel]
+
+            // src0 gradients by going through softmax_back
+            // grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
+            // from softmax_back:
+            // dxk = yk * (dyk - dot(y, dy))
+            // dot_y_dy := dot(y, dy)
+            // dx := dy
+            // dx := dx - dot_y_dy
+            // dx := dx * y
+            // postorder:
+            // dot_st1_dst1 := dot(st1, grad[st1])
+            // grad[s0] := grad[st1]
+            // grad[s0] := grad[s0] - dot_st1_dst1
+            // grad[s0] := grad[s0] * st1
+
+            // prepend postorder from grad[st1] directly using grad[s0] as memory location, as we will grad[s0] := grad[st1]
+            // sm           := softmax(s0)
+            // grad[s0]     := sm*(1.0 - eps)
+            // grad[s0]     := grad[s0] + eps
+            // grad[s0]     := s1 / grad[s0]
+            // grad[s0]     := grad[s0]*(1.0-eps)*-grad[cel]
+            // dot_st1_dst1 := dot(sm, grad[s0])
+            // grad[s0]     := grad[s0] - dot_st1_dst1
+            // grad[s0]     := grad[s0] * sm
+        }
+
+        // soft_max
+        ggml_float sum = 0.0;
+        {
+            float max = -INFINITY;
+            ggml_vec_max_f32(nc, &max, s0);
+
+            uint16_t scvt;
+            for (int i = 0; i < nc; i++) {
+                if (s0[i] == -INFINITY) {
+                    sm[i] = 0.0f;
+                } else {
+                    // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
+                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
+                    memcpy(&scvt, &s, sizeof(scvt));
+                    const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+                    sum += (ggml_float)val;
+                    sm[i] = val;
+                }
+            }
+
+            assert(sum > 0.0);
+            sum = 1.0/sum;
+        }

-    for (int i = 0; i < n; i++) {
-        fun(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])),
-                (float *) ((char *) src1->data + i*(src1->nb[1])));
+        float dot_st1_dst1 = 0;
+        ggml_vec_scale_f32(nc, sm, sum);
+        ggml_vec_cpy_f32  (nc, ds0, sm);
+        ggml_vec_scale_f32(nc, ds0, (1.0f - eps));
+        ggml_vec_add1_f32 (nc, ds0, ds0, eps);
+        ggml_vec_div_f32  (nc, ds0, s1, ds0);
+        ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]);
+        ggml_vec_dot_f32  (nc, &dot_st1_dst1, sm, ds0);
+        ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1);
+        ggml_vec_mul_f32  (nc, ds0, ds0, sm);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(sm[i]));
+            assert(!isinf(sm[i]));
+            assert(!isnan(ds0[i]));
+            assert(!isinf(ds0[i]));
+        }
+#endif
     }
 }

-
-static void ggml_compute_forward_map_binary(
+static void ggml_compute_forward_cross_entropy_loss_back(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+                ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst);
             } break;
         default:
             {
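Per row, the loss computed above amounts to -sum_i s1[i] * log(eps + (1 - eps) * softmax(s0)[i]); rescaling the probabilities from [0..1] to [eps..1] guards the log against zero. A minimal one-row sketch (plain C; the name is illustrative, and it calls expf directly where the kernel goes through ggml's fp16 exp table):

#include <math.h>
#include <stddef.h>

static float cross_entropy_row_sketch(size_t n, const float * s0, const float * s1) {
    const float eps = 1e-9f;

    // max-subtraction for a numerically stable softmax
    float max = -INFINITY;
    for (size_t i = 0; i < n; ++i) if (s0[i] > max) max = s0[i];

    float sum = 0.0f;
    for (size_t i = 0; i < n; ++i) sum += expf(s0[i] - max);

    float loss = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        const float p = eps + (1.0f - eps) * (expf(s0[i] - max) / sum);
        loss += s1[i] * logf(p);   // mirrors the vec_log + vec_mul + vec_sum steps above
    }
    return -loss;                  // the FINALIZE step negates the summed partials
}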
@@ -12880,11 +14298,21 @@ static void ggml_compute_forward_map_binary(
     }
 }

+
 /////////////////////////////////

 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
     GGML_ASSERT(params);

+#ifdef GGML_USE_CUBLAS
+    bool skip_cpu = ggml_cuda_compute_forward(params, tensor);
+    if (skip_cpu) {
+        return;
+    }
+    GGML_ASSERT(tensor->src0->backend == GGML_BACKEND_CPU);
+    GGML_ASSERT(tensor->src1 == NULL || tensor->src1->backend == GGML_BACKEND_CPU);
+#endif // GGML_USE_CUBLAS
+
     switch (tensor->op) {
         case GGML_OP_DUP:
             {
@@ -12942,6 +14370,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_repeat(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_REPEAT_BACK:
+            {
+                ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_ABS:
             {
                 ggml_compute_forward_abs(params, tensor->src0, tensor);
@@ -12990,6 +14422,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_OUT_PROD:
+            {
+                ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_SCALE:
             {
                 ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
@@ -13046,6 +14482,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_soft_max(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_SOFT_MAX_BACK:
+            {
+                ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_ROPE:
             {
                 ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
@@ -13081,6 +14521,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
             } break;
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
+                GGML_ASSERT(t == 0 || t == 1);
+                bool masked = t != 0;
+                ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
@@ -13093,6 +14540,16 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
             }
             break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            {
+                ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
+            }
+            break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            {
+                ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -13231,11 +14688,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad =
                         ggml_add_impl(ctx,
                                 src0->grad,
-                                ggml_mul(ctx,
-                                    tensor->grad, // this was not catched by test_grad because in test_grad tensor->grad is 1
+                                ggml_scale(ctx,
                                     ggml_div(ctx,
-                                        ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor),
-                                        tensor)
+                                        tensor->grad,
+                                        tensor),
+                                    ggml_new_f32(ctx, 0.5f)),
                                 inplace);
                 }
             } break;
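The corrected sqrt backward now matches the calculus: for y = sqrt(x), dy/dx = 1/(2*sqrt(x)) = 1/(2*y), so the contribution to grad[x] is 0.5 * grad[y] / y, hence the ggml_scale by 0.5f wrapped around ggml_div above. The scalar form of the rule (illustrative sketch, not library code):

// y = sqrtf(x)  =>  dx = 0.5f * dy / y
static float sqrt_backward_sketch(float y, float dy) {
    return 0.5f * dy / y;
}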
@@ -13281,43 +14738,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-
-
-
-
-
-
-
-
-
-                    //
-
-
-
-
-                    // transpose    [nc0*nr0,1,1]
-                    // reshape      [nc0,nr0,1,1] reshape_1d or reshape_2d
-                    // add to src0->grad
-
-                    int64_t ne[4] = {nc0,ncr,nr0,nrr};
-
-                    struct ggml_tensor* F00 = tensor->grad;
-                    struct ggml_tensor* F01 = ggml_reshape   (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne));
-                    struct ggml_tensor* F02 = ggml_permute   (ctx, F01, 0,2,1,3);
-                    struct ggml_tensor* F03 = ggml_cont      (ctx, F02);
-                    struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr);
-                    struct ggml_tensor* F05 = ggml_transpose (ctx, F04);
-                    struct ggml_tensor* F06 = ggml_cont      (ctx, F05);
-                    struct ggml_tensor* F07 = ggml_sum_rows  (ctx, F06);
-                    struct ggml_tensor* F08 = ggml_transpose (ctx, F07);
-                    struct ggml_tensor* F09 = ggml_cont      (ctx, F08);
-                    struct ggml_tensor* F10 = ggml_reshape   (ctx, F09, src0->grad);
-
-                    src0->grad =
-                        ggml_add_impl(ctx,
-                            src0->grad,
-                            F10,
-                            inplace);
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            ggml_repeat_back(ctx, tensor->grad, src0->grad),
+                            inplace);
+                }
+        } break;
+        case GGML_OP_REPEAT_BACK:
+            {
+                if (src0->grad) {
+                    // TODO: test this
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            ggml_repeat(ctx, tensor->grad, src0->grad),
+                            inplace);
                 }
             } break;
         case GGML_OP_ABS:
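ggml_repeat_back collapses the deleted reshape/permute/transpose chain: repeating copies each source element to several positions, so the backward pass just sums the incoming gradient over all of those copies back into the source shape (and, as the new GGML_OP_REPEAT_BACK case shows, the backward of that sum is plain repeat again). A minimal 1-D sketch of the tile-sum (plain C, illustrative name):

#include <stddef.h>

// src of length n was tiled r times to length n*r; the gradient of each
// src element is the sum of the gradients of its r copies.
static void repeat_back_sketch(size_t n, size_t r, float * grad_src, const float * grad_out) {
    for (size_t i = 0; i < n; ++i) {
        float acc = 0.0f;
        for (size_t j = 0; j < r; ++j) {
            acc += grad_out[j*n + i];
        }
        grad_src[i] += acc;
    }
}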
@@ -13424,38 +14858,37 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor

                 // necessary for llama
                 if (src0->grad) {
-                    // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
                     src0->grad =
                         ggml_add_impl(ctx,
                             src0->grad,
-                            //
-
-
-                            //     tensor->grad),          // [m,p]
-                            // for now just using A*B==(B.T*A.T).T
-                            ggml_cont(ctx,                      // [n,m]
-                                ggml_transpose(ctx,             // [n,m]
-                                    ggml_mul_mat(ctx,           // [m,n]
-                                        ggml_cont(ctx,          // [p,m]
-                                            ggml_transpose(ctx, // [p,m]
-                                                tensor->grad)), // [m,p]
-                                        ggml_cont(ctx,          // [p,n]
-                                            ggml_transpose(ctx, // [p,n]
-                                                src1))))),      // [n,p]
+                            ggml_out_prod(ctx,   // [n,m]
+                                src1,            // [n,p]
+                                tensor->grad),   // [m,p]
                             inplace);
                 }
                 if (src1->grad) {
                     src1->grad =
                         ggml_add_impl(ctx,
                             src1->grad,
-                            //
-                            ggml_mul_mat(ctx,                   // [n,p]
-                                ggml_cont(ctx,                  // [m,n]
-                                    ggml_transpose(ctx, src0)), // [m,n]
-                                tensor->grad),                  // [m,p]
+                            // ggml_mul_mat(ctx,                   // [n,p]
+                            //     ggml_cont(ctx,                  // [m,n]
+                            //         ggml_transpose(ctx, src0)), // [m,n]
+                            //     tensor->grad),                  // [m,p]
+
+                            // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
+                            // // avoid transpose of src0, rather transpose smaller tensor->grad
+                            // // and then use ggml_out_prod
+                            ggml_out_prod(ctx,      // [n,p]
+                                src0,               // [n,m]
+                                ggml_transpose(ctx, // [p,m]
+                                    tensor->grad)), // [m,p]
                             inplace);
                 }
             } break;
+        case GGML_OP_OUT_PROD:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_SCALE:
             {
                 // necessary for llama
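With GGML_OP_OUT_PROD available, the mul_mat gradients drop the A*B==(B.T*A.T).T transpose gymnastics. Going by the shape comments above, out_prod takes a with ne=[n,p] and b with ne=[m,p] to c with ne=[n,m], contracting over the shared p dimension. A dense sketch of that contraction (plain C; a reading of the shape comments, not ggml's actual kernel, and ne0 is assumed to be the fastest-varying index):

#include <stddef.h>

// c[i,j] += sum_k a[i,k] * b[j,k], with i < n, j < m, k < p,
// each 2-D array stored with its first ('ne0') index contiguous.
static void out_prod_sketch(size_t n, size_t m, size_t p,
                            float * c, const float * a, const float * b) {
    for (size_t j = 0; j < m; ++j) {
        for (size_t i = 0; i < n; ++i) {
            float acc = 0.0f;
            for (size_t k = 0; k < p; ++k) {
                acc += a[k*n + i] * b[k*m + j];
            }
            c[j*n + i] += acc;
        }
    }
}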
@@ -13557,7 +14990,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     size_t offset;
-                    memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
+
+                    GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0]));
+                    memcpy(&offset, tensor->opt[0]->data, sizeof(offset));

                     size_t nb1 = tensor->nb[1];
                     size_t nb2 = tensor->nb[2];
@@ -13584,10 +15019,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-
-                    int
-                    int
-                    int
+                    int32_t * axes = (int32_t *) tensor->opt[0]->data;
+                    int axis0 = axes[0] & 0x3;
+                    int axis1 = axes[1] & 0x3;
+                    int axis2 = axes[2] & 0x3;
+                    int axis3 = axes[3] & 0x3;
                     int axes_backward[4] = {0,0,0,0};
                     axes_backward[axis0] = 0;
                     axes_backward[axis1] = 1;
@@ -13671,50 +15107,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    // y = softmax(x)
-                    //
-                    // Jii = yi - yi*yi
-                    // Jij = -yi*yj
-                    // J = diag(y)-y.*y
-                    // dx = J * dy
-                    // dxk = sum(Jkj * dyk)
-
-                    int64_t ne2[4] = {
-                        tensor->ne[0],
-                        1,
-                        tensor->ne[1]*tensor->ne[2],
-                        tensor->ne[3]
-                    };
-                    struct ggml_tensor * tensor2 = ggml_cont(ctx,
-                        ggml_reshape_4d(ctx,
-                            ggml_cont(ctx, tensor),
-                            ne2[0], ne2[1], ne2[2], ne2[3]));
-
-                    struct ggml_tensor * grad2 = ggml_cont(ctx,
-                        ggml_reshape_4d(ctx,
-                            ggml_cont(ctx, tensor->grad),
-                            ne2[0], ne2[1], ne2[2], ne2[3]));
-
-                    struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
-                        ggml_permute(ctx,                           // [1,ne0,ne1*ne2,ne3]
-                            tensor2,                                // [ne0,1,ne1*ne2,ne3]
-                            1, 0, 2, 3));
-
                     src0->grad =
-                        ggml_add_impl(ctx,
-
-
-                            ggml_mul_mat(ctx,         // [ne0,1,ne1*ne2,ne3]
-                                ggml_sub(ctx,         // [ne0,ne0,ne1*ne2,ne3]
-                                    ggml_diag(ctx,    // [ne0,ne0,ne1*ne2,ne3]
-                                        tensor2),     // [ne0,1,ne1*ne2,ne3]
-                                    ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3]
-                                        tensor2_t,    // [1,ne0,ne1*ne2,ne3]
-                                        tensor2_t)),  // [1,ne0,ne1*ne2,ne3]
-                                grad2),               // [ne0,1,ne1*ne2,ne3]
-                            src0->grad),
-                            inplace);
+                        ggml_add_impl(ctx, src0->grad,
+                            ggml_soft_max_back(ctx, tensor->grad, tensor),
+                            inplace);
                 }
+
+            } break;
+        case GGML_OP_SOFT_MAX_BACK:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_ROPE:
             {
@@ -13769,17 +15171,190 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_FLASH_ATTN:
             {
-                GGML_ASSERT(false); // not supported
+                struct ggml_tensor * flash_grad = NULL;
+                if (src0->grad || src1->grad || tensor->opt[0]->grad) {
+                    int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
+                    GGML_ASSERT(t == 0 || t == 1);
+                    bool masked = t != 0;
+                    flash_grad =
+                        ggml_flash_attn_back(ctx,
+                            src0,
+                            src1,
+                            tensor->opt[0],
+                            tensor->grad,
+                            masked);
+                }
+
+                if (src0->grad) {
+                    struct ggml_tensor * grad_q = NULL;
+                    const size_t nb0    = flash_grad->nb[0];
+                    const size_t offset = 0;
+                    switch(src0->n_dims) {
+                        case 2:
+                            {
+                                grad_q = ggml_view_2d(ctx,
+                                    flash_grad,
+                                    src0->ne[0],
+                                    src0->ne[1],
+                                    nb0*src0->ne[0],
+                                    offset);
+                            } break;
+                        case 3:
+                            {
+                                grad_q = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    src0->ne[0],
+                                    src0->ne[1],
+                                    src0->ne[2],
+                                    nb0*src0->ne[0],
+                                    nb0*src0->ne[0]*src0->ne[1],
+                                    offset);
+                            } break;
+                        case 4:
+                            {
+                                grad_q = ggml_view_4d(ctx,
+                                    flash_grad,
+                                    src0->ne[0],
+                                    src0->ne[1],
+                                    src0->ne[2],
+                                    src0->ne[3],
+                                    nb0*src0->ne[0],
+                                    nb0*src0->ne[0]*src0->ne[1],
+                                    nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
+                                    offset);
+                            } break;
+                    }
+
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            grad_q,
+                            inplace);
+                }
+
+                if (src1->grad) {
+                    struct ggml_tensor * grad_k = NULL;
+                    const size_t nb0    = flash_grad->nb[0];
+                    const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
+                    switch(src1->n_dims) {
+                        case 2:
+                            {
+                                grad_k = ggml_view_2d(ctx,
+                                    flash_grad,
+                                    src1->ne[0],
+                                    src1->ne[1],
+                                    nb0*src1->ne[0],
+                                    offset);
+                            } break;
+                        case 3:
+                            {
+                                grad_k = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    src1->ne[0],
+                                    src1->ne[1],
+                                    src1->ne[2],
+                                    nb0*src1->ne[0],
+                                    nb0*src1->ne[0]*src1->ne[1],
+                                    offset);
+                            } break;
+                        case 4:
+                            {
+                                grad_k = ggml_view_4d(ctx,
+                                    flash_grad,
+                                    src1->ne[0],
+                                    src1->ne[1],
+                                    src1->ne[2],
+                                    src1->ne[3],
+                                    nb0*src1->ne[0],
+                                    nb0*src1->ne[0]*src1->ne[1],
+                                    nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
+                                    offset);
+                            } break;
+                    }
+
+                    src1->grad = ggml_add_impl(ctx,
+                            src1->grad,
+                            grad_k,
+                            inplace);
+                }
+
+                struct ggml_tensor * opt0 = tensor->opt[0];
+
+                if (opt0->grad) {
+                    struct ggml_tensor * grad_v = NULL;
+                    const size_t nb0    = flash_grad->nb[0];
+                    const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
+                                        + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
+                    switch(opt0->n_dims) {
+                        case 2:
+                            {
+                                grad_v = ggml_view_2d(ctx,
+                                    flash_grad,
+                                    opt0->ne[0],
+                                    opt0->ne[1],
+                                    nb0*opt0->ne[0],
+                                    offset);
+                            } break;
+                        case 3:
+                            {
+                                grad_v = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    opt0->ne[0],
+                                    opt0->ne[1],
+                                    opt0->ne[2],
+                                    nb0*opt0->ne[0],
+                                    nb0*opt0->ne[0]*opt0->ne[1],
+                                    offset);
+                            } break;
+                        case 4:
+                            {
+                                grad_v = ggml_view_4d(ctx,
+                                    flash_grad,
+                                    opt0->ne[0],
+                                    opt0->ne[1],
+                                    opt0->ne[2],
+                                    opt0->ne[3],
+                                    nb0*opt0->ne[0],
+                                    nb0*opt0->ne[0]*opt0->ne[1],
+                                    nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
+                                    offset);
+                            } break;
+                    }
+
+                    opt0->grad = ggml_add_impl(ctx,
+                            opt0->grad,
+                            grad_v,
+                            inplace);
+                }
             } break;
         case GGML_OP_FLASH_FF:
             {
                 GGML_ASSERT(false); // not supported
             } break;
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
             {
                 GGML_ASSERT(false); // not supported
             } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            {
+                if (src0->grad) {
+                    src0->grad = ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_cross_entropy_loss_back(ctx,
+                                    src0,
+                                    src1,
+                                    tensor->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -14156,6 +15731,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_REPEAT:
+        case GGML_OP_REPEAT_BACK:
         case GGML_OP_ABS:
         case GGML_OP_SGN:
         case GGML_OP_NEG:
@@ -14175,6 +15751,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 node->n_tasks = n_threads;
             } break;
         case GGML_OP_MUL_MAT:
+        case GGML_OP_OUT_PROD:
             {
                 node->n_tasks = n_threads;

@@ -14191,7 +15768,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
                     node->n_tasks = 1; // TODO: this actually is doing nothing
                                        //       the threads are still spinning
-                    cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
                 }
                 else
 #elif defined(GGML_USE_CLBLAST)
@@ -14258,6 +15834,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
             } break;
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
+        case GGML_OP_SOFT_MAX_BACK:
         case GGML_OP_ROPE:
         case GGML_OP_ROPE_BACK:
             {
@@ -14337,6 +15914,27 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
                 }

+                work_size = MAX(work_size, cur);
+            } break;
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                node->n_tasks = n_threads;
+
+                size_t cur = 0;
+
+                const int64_t    D = node->src0->ne[0];
+                const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+                const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
+                if (node->src1->type == GGML_TYPE_F32) {
+                    cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
+                    cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
+                }
+
+                if (node->src1->type == GGML_TYPE_F16) {
+                    cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
+                    cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
+                }
+
                 work_size = MAX(work_size, cur);
             } break;
         case GGML_OP_MAP_UNARY:
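The sizing above reserves two float scratch rows (S and SM) of length MAX(D, nek1 rounded up to GGML_SOFT_MAX_UNROLL) per task, then doubles the result in the way the comment flags as an overestimate. The same arithmetic as a standalone helper (illustrative; the unroll factor is assumed to be a small power of two, as in ggml):

#include <stddef.h>

static size_t flash_attn_back_wsize_sketch(size_t D, size_t ne11,
                                           size_t unroll, size_t n_tasks) {
    const size_t ne11_up = ((ne11 + unroll - 1) / unroll) * unroll; // like ggml_up()
    const size_t mxDn    = (D > ne11_up ? D : ne11_up) * 2;         // *2: the S and SM rows
    return 2 * sizeof(float) * mxDn * n_tasks;                      // second *2 mirrors the overestimate above
}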
@@ -14344,6 +15942,22 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
             {
                 node->n_tasks = 1;
             } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            {
+                node->n_tasks = n_threads;
+
+                size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
+
+                work_size = MAX(work_size, cur);
+            } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            {
+                node->n_tasks = n_threads;
+
+                size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
+
+                work_size = MAX(work_size, cur);
+            } break;
         case GGML_OP_NONE:
             {
                 node->n_tasks = 1;
@@ -14581,7 +16195,7 @@ static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fou
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;

-    fprintf(fout, "%-6s %-12s %8d %
+    fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
             ggml_type_name(tensor->type),
             ggml_op_name  (tensor->op),
             tensor->n_dims,
@@ -14595,7 +16209,7 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;

-    fprintf(fout, "%-6s %-6s %-12s %8d %
+    fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n",
             arg,
             ggml_type_name(tensor->type),
             ggml_op_name  (tensor->op),
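The fprintf fixes complete these format strings with the portable PRId64 macro from <inttypes.h>, which expands to the correct conversion specifier for 64-bit integers on every platform (e.g. %lld with MSVC, %ld on 64-bit Linux). A minimal usage sketch with illustrative values:

#include <inttypes.h>
#include <stdio.h>

int main(void) {
    const int64_t  ne0       = 4096;
    const uint64_t size_eval = 123456789;
    // the string literal is pasted together around the macro expansion
    printf("ne0 = %" PRId64 ", eval = %" PRIu64 "\n", ne0, size_eval);
    return 0;
}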
@@ -14608,8 +16222,8 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
 }

 void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
-    assert(cgraph->work      == NULL);
-    assert(cgraph->work_size == 0);
+    //assert(cgraph->work      == NULL);
+    //assert(cgraph->work_size == 0);

     uint64_t size_eval = 0;

@@ -14624,11 +16238,11 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
     FILE * fout = stdout;

     fprintf(fout, "\n");
-    fprintf(fout, "%-16s %8x\n",
-    fprintf(fout, "%-16s %8d\n",
-    fprintf(fout, "%-16s %8d\n",
-    fprintf(fout, "%-16s %8d\n",
-    fprintf(fout, "%-16s %
+    fprintf(fout, "%-16s %8x\n", "magic",   GGML_FILE_MAGIC);
+    fprintf(fout, "%-16s %8d\n", "version", GGML_FILE_VERSION);
+    fprintf(fout, "%-16s %8d\n", "leafs",   cgraph->n_leafs);
+    fprintf(fout, "%-16s %8d\n", "nodes",   cgraph->n_nodes);
+    fprintf(fout, "%-16s %" PRIu64 "\n", "eval", size_eval);

     // header
     fprintf(fout, "\n");
@@ -14830,7 +16444,6 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
     // read file into data
     {
         FILE * fin = fopen(fname, "rb");
-
         if (!fin) {
             fprintf(stderr, "%s: failed to open %s\n", __func__, fname);
             return result;
@@ -14862,7 +16475,11 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **

         data = ggml_new_tensor_1d(*ctx_data, GGML_TYPE_I8, fsize);

-        fread(data->data, sizeof(char), fsize, fin);
+        const size_t ret = fread(data->data, sizeof(char), fsize, fin);
+        if (ret != fsize) {
+            fprintf(stderr, "%s: failed to read %s\n", __func__, fname);
+            return result;
+        }

         fclose(fin);
     }
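This change turns a silently ignored short read into a hard failure. The same checked-read pattern, extracted into a reusable helper (the helper itself is illustrative, not part of ggml):

    #include <stdio.h>
    #include <stdbool.h>

    // Returns true only if exactly `size` bytes were read from `fin`.
    static bool read_exact(FILE * fin, void * buf, size_t size) {
        return fread(buf, 1, size, fin) == size;
    }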
@@ -14970,6 +16587,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
             op = *(const uint32_t *) ptr; ptr += sizeof(op);
             n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims);

+            enum ggml_op eop = (enum ggml_op) op;
+
             int64_t ne[GGML_MAX_DIMS];
             size_t nb[GGML_MAX_DIMS];

@@ -14984,42 +16603,77 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
                 nb[j] = nb_cur;
             }

-
-
-            tensor->op = (enum ggml_op) op;
+            uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur); // TODO: not yet used

-
+            const char * ptr_name = ptr; ptr += GGML_MAX_NAME;

-
+            const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += (2 + GGML_MAX_OPT)*sizeof(int32_t);

-
-            tensor->nb[j] = nb[j];
-            }
+            struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };

             // parse args
-            {
-
-                &tensor->src0,
-                &tensor->src1,
-            };
+            for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
+                const int32_t arg_idx = ptr_arg_idx[j];

-
-
+                if (arg_idx == -1) {
+                    continue;
                 }

-
-
+                if (arg_idx < GGML_MAX_NODES) {
+                    args[j] = result.leafs[arg_idx];
+                } else {
+                    args[j] = result.nodes[arg_idx - GGML_MAX_NODES];
+                }
+            }

-
-
-
+            // create the tensor
+            // "view" operations are handled differently
+            // TODO: handle inplace ops - currently a copy is always made
+
+            struct ggml_tensor * tensor = NULL;
+
+            switch (eop) {
+                // TODO: implement other view ops
+                case GGML_OP_RESHAPE:
+                    {
+                        tensor = ggml_reshape_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3]);
+                    } break;
+                case GGML_OP_VIEW:
+                    {
+                        tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
+
+                        uint64_t offs;
+                        memcpy(&offs, args[2]->data, sizeof(offs));
+
+                        tensor->data = ((char *) tensor->data) + offs;
+                    } break;
+                case GGML_OP_TRANSPOSE:
+                    {
+                        tensor = ggml_transpose(*ctx_eval, args[0]);
+                    } break;
+                case GGML_OP_PERMUTE:
+                    {
+                        tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
+                    } break;
+                default:
+                    {
+                        tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne);
+
+                        tensor->op = eop;
+                    } break;
+            }

-
-
-
-
-
-
+            memcpy(tensor->name, ptr_name, GGML_MAX_NAME);
+
+            for (int j = 0; j < GGML_MAX_DIMS; ++j) {
+                tensor->nb[j] = nb[j];
+            }
+
+            tensor->src0 = args[0];
+            tensor->src1 = args[1];
+
+            for (int j = 0; j < GGML_MAX_OPT; ++j) {
+                tensor->opt[j] = args[2 + j];
             }

             result.nodes[i] = tensor;
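The import path now resolves serialized argument indices against two tables: values below GGML_MAX_NODES index the leaf tensors, larger values index the computed nodes (rebased by GGML_MAX_NODES), and -1 marks an unused slot. Isolated as a sketch (the function itself is illustrative):

    #include <stdint.h>
    #include <stddef.h>

    struct ggml_tensor; // opaque here

    static struct ggml_tensor * resolve_arg(int32_t arg_idx,
                                            struct ggml_tensor ** leafs,
                                            struct ggml_tensor ** nodes,
                                            int32_t max_nodes) {
        if (arg_idx == -1)       return NULL;           // unused slot
        if (arg_idx < max_nodes) return leafs[arg_idx]; // leaf table
        return nodes[arg_idx - max_nodes];              // node table, rebased
    }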
@@ -15279,6 +16933,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g

 static enum ggml_opt_result ggml_opt_adam(
         struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
         struct ggml_opt_params params,
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
@@ -15304,25 +16959,29 @@ static enum ggml_opt_result ggml_opt_adam(
         }
     }

+    if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) {
+        int iter = opt->iter;
+        ggml_opt_init(opt->ctx, opt, params, nx);
+        opt->iter = iter;
+    }
+
     // constants
-    const float
+    const float sched = params.adam.sched;
+    const float decay = params.adam.decay * sched;
+    const float alpha = params.adam.alpha * sched;
     const float beta1 = params.adam.beta1;
     const float beta2 = params.adam.beta2;
     const float eps   = params.adam.eps;

-    float * x =
-    float * g1 =
-    float * g2 =
-    float * m =
-    float * v =
-    float * mh =
-    float * vh =
-
-    float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
+    float * x  = opt->adam.x->data;  // view of the parameters
+    float * g1 = opt->adam.g1->data; // gradient
+    float * g2 = opt->adam.g2->data; // gradient squared
+    float * m  = opt->adam.m->data;  // first moment
+    float * v  = opt->adam.v->data;  // second moment
+    float * mh = opt->adam.mh->data; // first moment hat
+    float * vh = opt->adam.vh->data; // second moment hat

-    //
-    ggml_vec_set_f32(nx, m, 0.0f);
-    ggml_vec_set_f32(nx, v, 0.0f);
+    float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values

     // update view
     ggml_opt_get_params(np, ps, x);
@@ -15332,16 +16991,27 @@ static enum ggml_opt_result ggml_opt_adam(
     ggml_set_f32      (f->grad, 1.0f);
     ggml_graph_compute(ctx, gb);

-
+    opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
+    opt->adam.fx_best = opt->adam.fx_prev;
     if (pf) {
-        pf[
+        pf[opt->iter % params.past] = opt->adam.fx_prev;
+    }
+
+    // initialize
+    if (opt->just_initialized) {
+        opt->adam.n_no_improvement = 0;
+        opt->just_initialized = false;
     }

-
-    float
+    float * fx_best = &opt->adam.fx_best;
+    float * fx_prev = &opt->adam.fx_prev;
+    int * n_no_improvement = &opt->adam.n_no_improvement;
+
+    int iter0 = opt->iter;

     // run the optimizer
     for (int t = 0; t < params.adam.n_iter; ++t) {
+        opt->iter = iter0 + t + 1;
         GGML_PRINT_DEBUG ("=== iter %d ===\n", t);

         GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0));
@@ -15375,17 +17045,22 @@ static enum ggml_opt_result ggml_opt_adam(

         // m^hat = m_t / (1 - beta1^t)
         // v^hat = v_t / (1 - beta2^t)
-        // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
+        // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
+        // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
+        // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
+        // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
+        // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
         ggml_vec_cpy_f32  (nx, mh, m);
         ggml_vec_cpy_f32  (nx, vh, v);

-        ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1,
-        ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2,
+        ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter)));
+        ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, opt->iter)));

         ggml_vec_sqrt_f32 (nx, vh, vh);
         ggml_vec_acc1_f32 (nx, vh, eps);

         ggml_vec_div_f32  (nx, mh, mh, vh);
+        ggml_vec_scale_f32(nx, x, 1.0f - decay);
         ggml_vec_sub_f32  (nx, x, x, mh);

         // update the parameters
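The comment block above derives why a plain scale-then-subtract implements decoupled weight decay. The same update for a single scalar parameter, with sched already folded into alpha and decay exactly as in the code (a sketch, not the ggml kernel):

    #include <math.h>

    // t is the 1-based iteration count (opt->iter in the diff above).
    static float adam_decay_step(float x, float m, float v,
                                 float alpha, float decay,
                                 float beta1, float beta2, float eps, int t) {
        const float mh = m * alpha / (1.0f - powf(beta1, (float) t)); // bias-corrected, pre-scaled by alpha
        const float vh = v / (1.0f - powf(beta2, (float) t));         // bias-corrected second moment
        x *= 1.0f - decay;                                            // decoupled weight decay
        return x - mh / (sqrtf(vh) + eps);                            // Adam step
    }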
@@ -15399,7 +17074,7 @@ static enum ggml_opt_result ggml_opt_adam(
         const float fx = ggml_get_f32_1d(f, 0);

         // check convergence
-        if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) {
+        if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
             GGML_PRINT_DEBUG("converged\n");

             return GGML_OPT_OK;
@@ -15408,32 +17083,32 @@ static enum ggml_opt_result ggml_opt_adam(
         // delta-based convergence test
         if (pf != NULL) {
             // need at least params.past iterations to start checking for convergence
-            if (params.past <= t) {
-                const float rate = (pf[t%params.past] - fx)/fx;
+            if (params.past <= iter0 + t) {
+                const float rate = (pf[(iter0 + t)%params.past] - fx)/fx;

                 if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }

-            pf[t%params.past] = fx;
+            pf[(iter0 + t)%params.past] = fx;
         }

         // check for improvement
         if (params.max_no_improvement > 0) {
-            if (fx_best > fx) {
-                fx_best = fx;
-                n_no_improvement = 0;
+            if (fx_best[0] > fx) {
+                fx_best[0] = fx;
+                n_no_improvement[0] = 0;
             } else {
-                ++n_no_improvement;
+                ++n_no_improvement[0];

-                if (n_no_improvement >= params.max_no_improvement) {
+                if (n_no_improvement[0] >= params.max_no_improvement) {
                     return GGML_OPT_OK;
                 }
             }
         }

-        fx_prev = fx;
+        fx_prev[0] = fx;

         {
             const int64_t t_end_cpu = ggml_cycles();
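With fx_prev, fx_best and n_no_improvement now living behind pointers into the ggml_opt_context, both stopping rules survive across resumed runs. The delta-based test by itself, as a sketch:

    #include <math.h>
    #include <stdbool.h>

    // pf: ring buffer of the last `past` function values; iter: global iteration.
    // Returns true once the relative change over `past` iterations drops below delta.
    static bool converged_delta(float * pf, int past, int iter, float fx, float delta) {
        bool done = false;
        if (past <= iter) {
            const float rate = (pf[iter % past] - fx) / fx;
            done = fabsf(rate) < delta;
        }
        pf[iter % past] = fx; // record the newest value
        return done;
    }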
@@ -15572,6 +17247,7 @@ static enum ggml_opt_result linesearch_backtracking(

 static enum ggml_opt_result ggml_opt_lbfgs(
         struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
         struct ggml_opt_params params,
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
@@ -15604,31 +17280,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         }
     }

-
-
-
-
-
+    if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) {
+        int iter = opt->iter;
+        ggml_opt_init(ctx, opt, params, nx);
+        opt->iter = iter;
+    }
+
+    float * x  = opt->lbfgs.x->data;  // current parameters
+    float * xp = opt->lbfgs.xp->data; // previous parameters
+    float * g  = opt->lbfgs.g->data;  // current gradient
+    float * gp = opt->lbfgs.gp->data; // previous gradient
+    float * d  = opt->lbfgs.d->data;  // search direction

-    float * pf = params.past > 0 ?
+    float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values

     float fx    = 0.0f; // cost function value
     float xnorm = 0.0f; // ||x||
     float gnorm = 0.0f; // ||g||
-    float step  = 0.0f;

     // initialize x from the graph nodes
     ggml_opt_get_params(np, ps, x);

     // the L-BFGS memory
-
-
-
-
-        lm[i].ys = 0.0f;
-        lm[i].s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
-        lm[i].y = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
-    }
+    float * lm_alpha = opt->lbfgs.lmal->data;
+    float * lm_ys    = opt->lbfgs.lmys->data;
+    float * lm_s     = opt->lbfgs.lms->data;
+    float * lm_y     = opt->lbfgs.lmy->data;

     // evaluate the function value and its gradient
     {
@@ -15643,12 +17320,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         fx = ggml_get_f32_1d(f, 0);
     }

-    if (pf) {
-        pf[0] = fx;
-    }
-
-    float fx_best = fx;
-
     // search direction = -gradient
     ggml_vec_neg_f32(nx, d, g);

@@ -15665,26 +17336,43 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         return GGML_OPT_OK;
     }

-
-
+    if (opt->just_initialized) {
+        if (pf) {
+            pf[0] = fx;
+        }
+        opt->lbfgs.fx_best = fx;
+
+        // initial step
+        ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d);
+        opt->lbfgs.j = 0;
+        opt->lbfgs.k = 1;
+        opt->lbfgs.end = 0;
+        opt->lbfgs.n_no_improvement = 0;
+        opt->just_initialized = false;
+    }
+
+    float * fx_best = &opt->lbfgs.fx_best;
+    float * step = &opt->lbfgs.step;
+    int * j = &opt->lbfgs.j;
+    int * k = &opt->lbfgs.k;
+    int * end = &opt->lbfgs.end;
+    int * n_no_improvement = &opt->lbfgs.n_no_improvement;

-    int
-    int
-    int ls = 0;
-    int end = 0;
-    int bound = 0;
-    int n_no_improvement = 0;
+    int ls    = 0;
+    int bound = 0;

     float ys = 0.0f;
     float yy = 0.0f;
     float beta = 0.0f;

+    int it = 0;
+
     while (true) {
         // store the current position and gradient vectors
         ggml_vec_cpy_f32(nx, xp, x);
         ggml_vec_cpy_f32(nx, gp, g);

-        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d,
+        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);

         if (ls < 0) {
             // linesearch failed - go back to the previous point and return
@@ -15710,32 +17398,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         // delta-based convergence test
         if (pf != NULL) {
             // need at least params.past iterations to start checking for convergence
-            if (params.past <= k) {
-                const float rate = (pf[k%params.past] - fx)/fx;
+            if (params.past <= k[0]) {
+                const float rate = (pf[k[0]%params.past] - fx)/fx;

                 if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }

-            pf[k%params.past] = fx;
+            pf[k[0]%params.past] = fx;
         }

         // check for improvement
         if (params.max_no_improvement > 0) {
-            if (fx < fx_best) {
-                fx_best = fx;
-                n_no_improvement = 0;
+            if (fx < fx_best[0]) {
+                fx_best[0] = fx;
+                n_no_improvement[0] = 0;
             } else {
-                n_no_improvement++;
+                n_no_improvement[0]++;

-                if (n_no_improvement >= params.max_no_improvement) {
+                if (n_no_improvement[0] >= params.max_no_improvement) {
                     return GGML_OPT_OK;
                 }
             }
         }

-        if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter <
+        if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) {
             // reached the maximum number of iterations
             return GGML_OPT_DID_NOT_CONVERGE;
         }
@@ -15744,50 +17432,51 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
         // y_{k+1} = g_{k+1} - g_{k}.
         //
-        ggml_vec_sub_f32(nx,
-        ggml_vec_sub_f32(nx,
+        ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp);
+        ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp);

         // compute scalars ys and yy:
         // ys = y^t \cdot s -> 1 / \rho.
         // yy = y^t \cdot y.
         //
-        ggml_vec_dot_f32(nx, &ys,
-        ggml_vec_dot_f32(nx, &yy,
+        ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]);
+        ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);

-
+        lm_ys[end[0]] = ys;

         // find new search direction
         // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS

-        bound = (m <= k) ? m : k;
-        k++;
-
+        bound = (m <= k[0]) ? m : k[0];
+        k[0]++;
+        it++;
+        end[0] = (end[0] + 1)%m;

         // initialize search direction with -g
         ggml_vec_neg_f32(nx, d, g);

-        j = end;
+        j[0] = end[0];
         for (int i = 0; i < bound; ++i) {
-            j = (j + m - 1) % m;
+            j[0] = (j[0] + m - 1) % m;
             // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
-            ggml_vec_dot_f32(nx, &
-
+            ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d);
+            lm_alpha[j[0]] /= lm_ys[j[0]];
             // q_{i} = q_{i+1} - \alpha_{i} y_{i}
-            ggml_vec_mad_f32(nx, d,
+            ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
         }

         ggml_vec_scale_f32(nx, d, ys/yy);

         for (int i = 0; i < bound; ++i) {
             // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
-            ggml_vec_dot_f32(nx, &beta,
-            beta /=
+            ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d);
+            beta /= lm_ys[j[0]];
             // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
-            ggml_vec_mad_f32(nx, d,
-            j = (j + 1)%m;
+            ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
+            j[0] = (j[0] + 1)%m;
         }

-        step = 1.0;
+        step[0] = 1.0;
     }

     return GGML_OPT_DID_NOT_CONVERGE;
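The flat lm_s/lm_y/lm_ys/lm_alpha buffers above implement the classic L-BFGS two-loop recursion over the last `bound` (s, y) pairs. A self-contained sketch of that recursion, with plain loops standing in for the ggml_vec_* helpers (d enters as -g and leaves as the search direction; end is the newest history slot):

    #include <stdint.h>

    static void lbfgs_two_loop(int64_t nx, float * d,
                               const float * lm_s, const float * lm_y,
                               const float * lm_ys, float * lm_alpha,
                               int m, int bound, int end, float ys, float yy) {
        int j = end;
        // first loop: walk backward through the history, peeling off alpha_j
        for (int i = 0; i < bound; ++i) {
            j = (j + m - 1) % m;
            float dot = 0.0f;
            for (int64_t c = 0; c < nx; ++c) dot += lm_s[j*nx + c] * d[c];
            lm_alpha[j] = dot / lm_ys[j];
            for (int64_t c = 0; c < nx; ++c) d[c] -= lm_alpha[j] * lm_y[j*nx + c];
        }
        // scale by the initial Hessian approximation ys/yy
        for (int64_t c = 0; c < nx; ++c) d[c] *= ys / yy;
        // second loop: walk forward, adding back (alpha_j - beta_j) * s_j
        for (int i = 0; i < bound; ++i) {
            float beta = 0.0f;
            for (int64_t c = 0; c < nx; ++c) beta += lm_y[j*nx + c] * d[c];
            beta /= lm_ys[j];
            for (int64_t c = 0; c < nx; ++c) d[c] += (lm_alpha[j] - beta) * lm_s[j*nx + c];
            j = (j + 1) % m;
        }
    }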
@@ -15812,6 +17501,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {

                 .adam = {
                     .n_iter = 10000,
+                    .sched  = 1.000f,
+                    .decay  = 0.001f,
                     .alpha  = 0.001f,
                     .beta1  = 0.9f,
                     .beta2  = 0.999f,
@@ -15854,6 +17545,71 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
     return result;
 }

+GGML_API void ggml_opt_init(
+        struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
+        struct ggml_opt_params params,
+        int64_t nx) {
+    opt->ctx = ctx;
+    opt->params = params;
+    opt->iter = 0;
+    opt->nx = nx;
+    opt->just_initialized = true;
+    switch (opt->params.type) {
+        case GGML_OPT_ADAM:
+            {
+                opt->adam.x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.m  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.v  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.pf = params.past > 0
+                    ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+                    : NULL;
+                ggml_set_zero(opt->adam.x);
+                ggml_set_zero(opt->adam.g1);
+                ggml_set_zero(opt->adam.g2);
+                ggml_set_zero(opt->adam.m);
+                ggml_set_zero(opt->adam.v);
+                ggml_set_zero(opt->adam.mh);
+                ggml_set_zero(opt->adam.vh);
+                if (opt->adam.pf) {
+                    ggml_set_zero(opt->adam.pf);
+                }
+            } break;
+        case GGML_OPT_LBFGS:
+            {
+                opt->lbfgs.x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.g  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.d  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.pf = params.past > 0
+                    ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+                    : NULL;
+                opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
+                opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
+                opt->lbfgs.lms  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+                opt->lbfgs.lmy  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+                ggml_set_zero(opt->lbfgs.x);
+                ggml_set_zero(opt->lbfgs.xp);
+                ggml_set_zero(opt->lbfgs.g);
+                ggml_set_zero(opt->lbfgs.gp);
+                ggml_set_zero(opt->lbfgs.d);
+                if (opt->lbfgs.pf) {
+                    ggml_set_zero(opt->lbfgs.pf);
+                }
+                ggml_set_zero(opt->lbfgs.lmal);
+                ggml_set_zero(opt->lbfgs.lmys);
+                ggml_set_zero(opt->lbfgs.lms);
+                ggml_set_zero(opt->lbfgs.lmy);
+            } break;
+    }
+}
+
 enum ggml_opt_result ggml_opt(
         struct ggml_context * ctx,
         struct ggml_opt_params params,
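Taken together, ggml_opt_init plus the ggml_opt_resume* entry points (next hunk) make optimization restartable: the Adam moments, past-values ring buffer and L-BFGS memory are tensors owned by the caller's context, so a second call picks up where the first stopped. A hedged usage sketch, assuming ctx, f and nx have been set up by the caller:

    #include "ggml.h"

    // ctx: initialized ggml context; f: scalar loss tensor; nx: parameter count.
    static enum ggml_opt_result train_two_segments(struct ggml_context * ctx,
                                                   struct ggml_tensor * f,
                                                   int64_t nx) {
        struct ggml_opt_context opt;
        struct ggml_opt_params params = ggml_opt_default_params(GGML_OPT_ADAM);

        ggml_opt_init(ctx, &opt, params, nx);                     // allocates x, m, v, pf, ... in ctx
        enum ggml_opt_result res = ggml_opt_resume(ctx, &opt, f); // first n_iter iterations
        if (res == GGML_OPT_OK) {
            return res;
        }
        return ggml_opt_resume(ctx, &opt, f);                     // continues with warm Adam state
    }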
@@ -15876,33 +17632,65 @@ enum ggml_opt_result ggml_opt(

     enum ggml_opt_result result = GGML_OPT_OK;

+    struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
+
+    ggml_opt_init(ctx, opt, params, 0);
+    result = ggml_opt_resume(ctx, opt, f);
+
+    if (free_ctx) {
+        ggml_free(ctx);
+    }
+
+    return result;
+}
+
+enum ggml_opt_result ggml_opt_resume(
+        struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
+        struct ggml_tensor * f) {
+
+    // build forward + backward compute graphs
+    struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
+    struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
+
+    struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
+    struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
+
+    *gf = ggml_build_forward (f);
+    *gb = ggml_build_backward(ctx, gf, true);
+
+    return ggml_opt_resume_g(ctx, opt, f, gf, gb);
+}
+
+enum ggml_opt_result ggml_opt_resume_g(
+        struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
+        struct ggml_tensor * f,
+        struct ggml_cgraph * gf,
+        struct ggml_cgraph * gb) {
+
     // build forward + backward compute graphs
-
-    struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, true);
+    enum ggml_opt_result result = GGML_OPT_OK;

-    switch (params.type) {
+    switch (opt->params.type) {
         case GGML_OPT_ADAM:
             {
-                result = ggml_opt_adam(ctx, params, f,
+                result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
             } break;
         case GGML_OPT_LBFGS:
             {
-                result = ggml_opt_lbfgs(ctx, params, f,
+                result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb);
             } break;
     }

-    if (params.print_forward_graph) {
-        ggml_graph_print   (
-        ggml_graph_dump_dot(
-    }
-
-    if (params.print_backward_graph) {
-        ggml_graph_print   (&gb);
-        ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot");
+    if (opt->params.print_forward_graph) {
+        ggml_graph_print   (gf);
+        ggml_graph_dump_dot(gf, NULL, "opt-forward.dot");
     }

-    if (
-
+    if (opt->params.print_backward_graph) {
+        ggml_graph_print   (gb);
+        ggml_graph_dump_dot(gb, gf, "opt-backward.dot");
     }

     return result;
@@ -16070,6 +17858,50 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
                 result = ggml_quantize_q8_0(src + start, block, n, n, hist);
             } break;
+#ifdef GGML_USE_K_QUANTS
+        case GGML_TYPE_Q2_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q2_K * block = (block_q2_K*)dst + start / QK_K;
+                result = ggml_quantize_q2_K(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q3_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q3_K * block = (block_q3_K*)dst + start / QK_K;
+                result = ggml_quantize_q3_K(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q4_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q4_K * block = (block_q4_K*)dst + start / QK_K;
+                result = ggml_quantize_q4_K(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q5_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q5_K * block = (block_q5_K*)dst + start / QK_K;
+                result = ggml_quantize_q5_K(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q6_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q6_K * block = (block_q6_K*)dst + start / QK_K;
+                result = ggml_quantize_q6_K(src + start, block, n, n, hist);
+            } break;
+#endif
+        case GGML_TYPE_F16:
+            {
+                int elemsize = sizeof(ggml_fp16_t);
+                ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
+                result = n * elemsize;
+            } break;
+        case GGML_TYPE_F32:
+            {
+                int elemsize = sizeof(float);
+                result = n * elemsize;
+                memcpy((uint8_t *)dst + start * elemsize, src + start, result);
+            } break;
         default:
             assert(false);
     }
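The new cases route the k-quant types through ggml_quantize_chunk and add raw F16/F32 passthroughs. A hedged usage sketch follows; the trailing parameters of ggml_quantize_chunk are cut off in the hunk header above, so the start/n/hist arguments below follow the ggml.h declaration shipped with this version and should be double-checked against it:

    #include "ggml.h"

    // Quantize n floats (n a multiple of QK_K) starting at offset 0 into dst.
    static size_t quantize_q4_K_buffer(const float * src, void * dst, int n) {
        int64_t hist[16] = {0}; // per-bucket histogram filled by the quantizer
        // start must be a multiple of QK_K - the new cases assert this
        return ggml_quantize_chunk(GGML_TYPE_Q4_K, src, dst, /*start =*/ 0, n, hist);
    }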