llama_cpp 0.1.4 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +36 -0
- data/examples/README.md +60 -0
- data/examples/chat.rb +195 -0
- data/ext/llama_cpp/extconf.rb +26 -1
- data/ext/llama_cpp/llama_cpp.cpp +262 -13
- data/ext/llama_cpp/src/ggml-cuda.cu +2483 -0
- data/ext/llama_cpp/src/ggml-cuda.h +18 -2
- data/ext/llama_cpp/src/ggml-metal.h +64 -0
- data/ext/llama_cpp/src/ggml-metal.m +834 -0
- data/ext/llama_cpp/src/ggml-metal.metal +1436 -0
- data/ext/llama_cpp/src/ggml-opencl.cpp +207 -40
- data/ext/llama_cpp/src/ggml-opencl.h +4 -1
- data/ext/llama_cpp/src/ggml.c +2236 -404
- data/ext/llama_cpp/src/ggml.h +170 -8
- data/ext/llama_cpp/src/k_quants.c +2244 -0
- data/ext/llama_cpp/src/k_quants.h +122 -0
- data/ext/llama_cpp/src/llama-util.h +16 -0
- data/ext/llama_cpp/src/llama.cpp +631 -179
- data/ext/llama_cpp/src/llama.h +51 -11
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +36 -1
- metadata +10 -2
data/ext/llama_cpp/src/ggml.c
CHANGED
@@ -3,6 +3,10 @@
 
 #include "ggml.h"
 
+#ifdef GGML_USE_K_QUANTS
+#include "k_quants.h"
+#endif
+
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
 #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -21,6 +25,10 @@
 #include <float.h>
 #include <limits.h>
 
+#ifdef GGML_USE_METAL
+#include <unistd.h>
+#endif
+
 // if C99 - static_assert is noop
 // ref: https://stackoverflow.com/a/53923785/4039976
 #ifndef static_assert
@@ -121,7 +129,11 @@ typedef void* thread_ret_t;
 #else
 inline static void* ggml_aligned_malloc(size_t size) {
     void* aligned_memory = NULL;
+#ifdef GGML_USE_METAL
+    int result = posix_memalign(&aligned_memory, getpagesize(), size);
+#else
     int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
+#endif
     if (result != 0) {
         // Handle allocation failure
         return NULL;
@@ -403,21 +415,27 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) {
 //
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
-static int64_t timer_freq;
+static int64_t timer_freq, timer_start;
 void ggml_time_init(void) {
-    LARGE_INTEGER frequency;
-    QueryPerformanceFrequency(&frequency);
-    timer_freq = frequency.QuadPart;
+    LARGE_INTEGER t;
+    QueryPerformanceFrequency(&t);
+    timer_freq = t.QuadPart;
+
+    // The multiplication by 1000 or 1000000 below can cause an overflow if timer_freq
+    // and the uptime is high enough.
+    // We subtract the program start time to reduce the likelihood of that happening.
+    QueryPerformanceCounter(&t);
+    timer_start = t.QuadPart;
 }
 int64_t ggml_time_ms(void) {
     LARGE_INTEGER t;
     QueryPerformanceCounter(&t);
-    return (t.QuadPart * 1000) / timer_freq;
+    return ((t.QuadPart-timer_start) * 1000) / timer_freq;
 }
 int64_t ggml_time_us(void) {
     LARGE_INTEGER t;
     QueryPerformanceCounter(&t);
-    return (t.QuadPart * 1000000) / timer_freq;
+    return ((t.QuadPart-timer_start) * 1000000) / timer_freq;
 }
 #else
 void ggml_time_init(void) {}
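
A quick sanity check on the overflow the timer hunk guards against (a standalone sketch, not from the package; it assumes the 10 MHz QueryPerformanceCounter frequency common on recent Windows): scaling a raw tick count to microseconds overflows int64_t after roughly ten days of uptime, which is why the code now subtracts timer_start before multiplying.

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t timer_freq = 10000000;            // assumed QPC ticks per second
        const int64_t max_ticks  = INT64_MAX / 1000000; // largest tick count safe to scale to us
        // prints ~10.7 days: past that, t.QuadPart * 1000000 would overflow
        printf("overflow after ~%.1f days of uptime\n",
               (double) max_ticks / timer_freq / 86400.0);
        return 0;
    }
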
@@ -474,6 +492,8 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 // multiply int8_t, add results pairwise twice
 static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
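
The MM256_SET_M128I macro added above assembles a 256-bit integer vector from two 128-bit halves; spelling it as cast-plus-insert avoids _mm256_set_m128i, which some older compilers do not provide. A minimal standalone check of the behavior (my example, compile with -mavx; not part of the diff):

    #include <immintrin.h>
    #include <stdio.h>

    #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

    int main(void) {
        const __m128i lo = _mm_set1_epi32(1); // becomes the lower 128 bits
        const __m128i hi = _mm_set1_epi32(2); // inserted as the upper 128 bits
        const __m256i v  = MM256_SET_M128I(hi, lo);

        int out[8];
        _mm256_storeu_si256((__m256i *) out, v);
        for (int i = 0; i < 8; ++i) {
            printf("%d ", out[i]); // prints: 1 1 1 1 2 2 2 2
        }
        printf("\n");
        return 0;
    }
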
@@ -533,7 +553,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
 static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
 {
     const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
-    const __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(tmp, 4), tmp);
+    const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
     const __m256i lowMask = _mm256_set1_epi8( 0xF );
     return _mm256_and_si256(lowMask, bytes);
 }
@@ -606,7 +626,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
     bytesh = _mm_or_si128(bytesh, bit_mask);
     bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
     bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
-    return _mm256_set_m128i(bytesh, bytesl);
+    return MM256_SET_M128I(bytesh, bytesl);
 }
 
 // Unpack 32 4-bit fields into 32 bytes
@@ -619,7 +639,7 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
     const __m128i lowMask = _mm_set1_epi8(0xF);
     tmpl = _mm_and_si128(lowMask, tmpl);
     tmph = _mm_and_si128(lowMask, tmph);
-    return _mm256_set_m128i(tmph, tmpl);
+    return MM256_SET_M128I(tmph, tmpl);
 }
 
 // add int16_t pairwise and return as float vector
@@ -627,7 +647,7 @@ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
     const __m128i ones = _mm_set1_epi16(1);
     const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
     const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
-    const __m256i summed_pairs = _mm256_set_m128i(summed_pairsh, summed_pairsl);
+    const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
     return _mm256_cvtepi32_ps(summed_pairs);
 }
 
@@ -1565,6 +1585,48 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .vec_dot_q                = NULL,   // TODO
         .vec_dot_type             = GGML_TYPE_Q8_1,
     },
+#ifdef GGML_USE_K_QUANTS
+    [GGML_TYPE_Q2_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q2_K,
+        .quantize_row_q           = quantize_row_q2_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q2_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q2_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q3_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q3_K,
+        .quantize_row_q           = quantize_row_q3_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q3_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q3_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q4_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q4_K,
+        .quantize_row_q           = quantize_row_q4_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q4_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q5_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q5_K,
+        .quantize_row_q           = quantize_row_q5_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q5_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+    [GGML_TYPE_Q6_K] = {
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q6_K,
+        .quantize_row_q           = quantize_row_q6_K,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q6_K_reference,
+        .quantize_row_q_dot       = quantize_row_q8_K,
+        .vec_dot_q                = ggml_vec_dot_q6_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+    },
+#endif
 };
 
 // For internal test use
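
The k-quant entries above extend quantize_fns with the same mechanism the table already uses: function pointers indexed by the type enum via C99 designated initializers, so a block of entries can sit behind #ifdef GGML_USE_K_QUANTS without renumbering the rest. A self-contained sketch of that dispatch pattern (toy names, not ggml's):

    #include <stdio.h>

    enum qtype { QT_Q4_0, QT_Q8_0, QT_COUNT };

    typedef void (*quantize_row_t)(const float * x, void * y, int k);

    static void quantize_row_q4_0(const float * x, void * y, int k) { (void)x; (void)y; printf("q4_0 row of %d\n", k); }
    static void quantize_row_q8_0(const float * x, void * y, int k) { (void)x; (void)y; printf("q8_0 row of %d\n", k); }

    // designated initializers: entries may appear in any order, or conditionally
    static const quantize_row_t quantize_row[QT_COUNT] = {
        [QT_Q4_0] = quantize_row_q4_0,
        [QT_Q8_0] = quantize_row_q8_0,
    };

    int main(void) {
        quantize_row[QT_Q8_0](NULL, NULL, 32); // dispatch by type id
        return 0;
    }
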
@@ -2290,7 +2352,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
 
         // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps(_mm256_set_m128i(i32_0, i32_1));
+        __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
 
         // Apply the scale, and accumulate
         acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
@@ -2766,7 +2828,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
         __m128i bxh = _mm256_extractf128_si256(bx, 1);
         bxl = _mm_or_si128(bxl, bxhil);
         bxh = _mm_or_si128(bxh, bxhih);
-        bx = _mm256_set_m128i(bxh, bxl);
+        bx = MM256_SET_M128I(bxh, bxl);
 
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
@@ -3022,7 +3084,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         __m128i bxh = _mm256_extractf128_si256(bx, 1);
         bxl = _mm_or_si128(bxl, bxhil);
         bxh = _mm_or_si128(bxh, bxhih);
-        bx = _mm256_set_m128i(bxh, bxl);
+        bx = MM256_SET_M128I(bxh, bxl);
 
         const __m256 dy = _mm256_set1_ps(y[i].d);
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
@@ -3444,11 +3506,19 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q5_1] = QK5_1,
     [GGML_TYPE_Q8_0] = QK8_0,
     [GGML_TYPE_Q8_1] = QK8_1,
+#ifdef GGML_USE_K_QUANTS
+    [GGML_TYPE_Q2_K] = QK_K,
+    [GGML_TYPE_Q3_K] = QK_K,
+    [GGML_TYPE_Q4_K] = QK_K,
+    [GGML_TYPE_Q5_K] = QK_K,
+    [GGML_TYPE_Q6_K] = QK_K,
+    [GGML_TYPE_Q8_K] = QK_K,
+#endif
     [GGML_TYPE_I8]  = 1,
     [GGML_TYPE_I16] = 1,
     [GGML_TYPE_I32] = 1,
 };
-static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 19, "GGML_BLCK_SIZE is outdated");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32] = sizeof(float),
@@ -3459,11 +3529,19 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q5_1] = sizeof(block_q5_1),
     [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
     [GGML_TYPE_Q8_1] = sizeof(block_q8_1),
+#ifdef GGML_USE_K_QUANTS
+    [GGML_TYPE_Q2_K] = sizeof(block_q2_K),
+    [GGML_TYPE_Q3_K] = sizeof(block_q3_K),
+    [GGML_TYPE_Q4_K] = sizeof(block_q4_K),
+    [GGML_TYPE_Q5_K] = sizeof(block_q5_K),
+    [GGML_TYPE_Q6_K] = sizeof(block_q6_K),
+    [GGML_TYPE_Q8_K] = sizeof(block_q8_K),
+#endif
     [GGML_TYPE_I8]  = sizeof(int8_t),
     [GGML_TYPE_I16] = sizeof(int16_t),
     [GGML_TYPE_I32] = sizeof(int32_t),
 };
-static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_SIZE is outdated");
 
 
 static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3475,11 +3553,17 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q5_1] = "q5_1",
     [GGML_TYPE_Q8_0] = "q8_0",
     [GGML_TYPE_Q8_1] = "q8_1",
+    [GGML_TYPE_Q2_K] = "q2_K",
+    [GGML_TYPE_Q3_K] = "q3_K",
+    [GGML_TYPE_Q4_K] = "q4_K",
+    [GGML_TYPE_Q5_K] = "q5_K",
+    [GGML_TYPE_Q6_K] = "q6_K",
+    [GGML_TYPE_Q8_K] = "q8_K",
     [GGML_TYPE_I8]  = "i8",
     [GGML_TYPE_I16] = "i16",
     [GGML_TYPE_I32] = "i32",
 };
-static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated");
+static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_NAME is outdated");
 
 static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32] = false,
@@ -3490,11 +3574,17 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q5_1] = true,
     [GGML_TYPE_Q8_0] = true,
     [GGML_TYPE_Q8_1] = true,
+    [GGML_TYPE_Q2_K] = true,
+    [GGML_TYPE_Q3_K] = true,
+    [GGML_TYPE_Q4_K] = true,
+    [GGML_TYPE_Q5_K] = true,
+    [GGML_TYPE_Q6_K] = true,
+    [GGML_TYPE_Q8_K] = true,
     [GGML_TYPE_I8]  = false,
     [GGML_TYPE_I16] = false,
     [GGML_TYPE_I32] = false,
 };
-static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated");
+static_assert(GGML_TYPE_COUNT == 19, "GGML_IS_QUANTIZED is outdated");
 
 static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "NONE",
@@ -3513,6 +3603,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "SUM_ROWS",
     "MEAN",
     "REPEAT",
+    "REPEAT_BACK",
     "ABS",
     "SGN",
     "NEG",
@@ -3526,6 +3617,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "RMS_NORM_BACK",
 
     "MUL_MAT",
+    "OUT_PROD",
 
     "SCALE",
     "SET",
@@ -3541,6 +3633,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "DIAG_MASK_INF",
     "DIAG_MASK_ZERO",
     "SOFT_MAX",
+    "SOFT_MAX_BACK",
     "ROPE",
     "ROPE_BACK",
     "ALIBI",
@@ -3550,13 +3643,16 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
 
     "FLASH_ATTN",
     "FLASH_FF",
+    "FLASH_ATTN_BACK",
 
     "MAP_UNARY",
     "MAP_BINARY",
-};
 
+    "CROSS_ENTROPY_LOSS",
+    "CROSS_ENTROPY_LOSS_BACK",
+};
 
+static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3575,6 +3671,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "Σx_k",
     "Σx/n",
     "repeat(x)",
+    "repeat_back(x)",
     "abs(x)",
     "sgn(x)",
     "-x",
@@ -3587,6 +3684,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rms_norm(x)",
     "rms_norm_back(x)",
 
+    "X*Y",
     "X*Y",
 
     "x*v",
@@ -3603,6 +3701,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "diag_mask_inf(x)",
     "diag_mask_zero(x)",
     "soft_max(x)",
+    "soft_max_back(x)",
     "rope(x)",
     "rope_back(x)",
     "alibi(x)",
@@ -3612,12 +3711,16 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "flash_attn(x)",
     "flash_ff(x)",
+    "flash_attn_back(x)",
 
     "f(x)",
     "f(x,y)",
+
+    "cross_entropy_loss(x,y)",
+    "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
+static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3631,6 +3734,7 @@ struct ggml_context {
     void * mem_buffer;
     bool   mem_buffer_owned;
     bool   no_alloc;
+    bool   no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
 
     int    n_objects;
 
@@ -3647,26 +3751,6 @@ struct ggml_context_container {
     struct ggml_context context;
 };
 
-//
-// compute types
-//
-
-enum ggml_task_type {
-    GGML_TASK_INIT = 0,
-    GGML_TASK_COMPUTE,
-    GGML_TASK_FINALIZE,
-};
-
-struct ggml_compute_params {
-    enum ggml_task_type type;
-
-    int ith, nth;
-
-    // work buffer for all threads
-    size_t wsize;
-    void * wdata;
-};
-
 //
 // ggml state
 //
@@ -3723,7 +3807,7 @@ int64_t ggml_nelements(const struct ggml_tensor * tensor) {
     return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
 }
 
-int ggml_nrows(const struct ggml_tensor * tensor) {
+int64_t ggml_nrows(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
@@ -3732,7 +3816,20 @@ int ggml_nrows(const struct ggml_tensor * tensor) {
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
-    return (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
+    // this should handle cases where the tensor is not contiguous in memory
+    // probaby just:
+    //
+    //     return tensor->ne[3]*tensor->nb[3]
+    //
+    // is enough, but just in case, adding the second part
+
+    return MAX(tensor->ne[3]*tensor->nb[3], (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type]);
+}
+
+size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return (nrows_split*tensor->ne[0]*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
 }
 
 int ggml_blck_size(enum ggml_type type) {
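
The reworked ggml_nbytes takes the larger of the stride-based size ne[3]*nb[3] and the element-count size, as the code comment explains. A worked check under simple assumptions (strides of a contiguous 4x3 f32 tensor, invented for the example; both expressions give 48 bytes, and the MAX only matters for non-contiguous layouts):

    #include <stdio.h>

    #define MAX(a, b) ((a) > (b) ? (a) : (b))

    int main(void) {
        const long ne[4] = {4, 3, 1, 1};    // elements per dimension
        const long nb[4] = {4, 16, 48, 48}; // byte strides, 4-byte floats
        const long nelements = ne[0]*ne[1]*ne[2]*ne[3];
        const long nbytes    = MAX(ne[3]*nb[3], nelements*4);
        printf("nbytes = %ld\n", nbytes); // 48
        return 0;
    }
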
@@ -3786,6 +3883,15 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
         (t0->ne[3] == t1->ne[3]);
 }
 
+static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        (t0->ne[1] == t1->ne[1]) &&
+        (t0->ne[2] == t1->ne[2]) &&
+        (t0->ne[3] == t1->ne[3]);
+}
+
 bool ggml_is_quantized(enum ggml_type type) {
     return GGML_IS_QUANTIZED[type];
 }
@@ -3801,6 +3907,11 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_Q5_0:          wtype = GGML_TYPE_Q5_0;  break;
         case GGML_FTYPE_MOSTLY_Q5_1:          wtype = GGML_TYPE_Q5_1;  break;
         case GGML_FTYPE_MOSTLY_Q8_0:          wtype = GGML_TYPE_Q8_0;  break;
+        case GGML_FTYPE_MOSTLY_Q2_K:          wtype = GGML_TYPE_Q2_K;  break;
+        case GGML_FTYPE_MOSTLY_Q3_K:          wtype = GGML_TYPE_Q3_K;  break;
+        case GGML_FTYPE_MOSTLY_Q4_K:          wtype = GGML_TYPE_Q4_K;  break;
+        case GGML_FTYPE_MOSTLY_Q5_K:          wtype = GGML_TYPE_Q5_K;  break;
+        case GGML_FTYPE_MOSTLY_Q6_K:          wtype = GGML_TYPE_Q6_K;  break;
         case GGML_FTYPE_UNKNOWN:              wtype = GGML_TYPE_COUNT; break;
         case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
     }
@@ -3814,11 +3925,11 @@ size_t ggml_tensor_overhead(void) {
     return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE + 16;
 }
 
-static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
+bool ggml_is_transposed(const struct ggml_tensor * tensor) {
     return tensor->nb[0] > tensor->nb[1];
 }
 
-static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
+bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
@@ -3828,6 +3939,12 @@ static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
+bool ggml_is_permuted(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
+}
+
 static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
@@ -3967,6 +4084,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         /*.mem_buffer       =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size),
         /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
         /*.no_alloc         =*/ params.no_alloc,
+        /*.no_alloc_save    =*/ params.no_alloc,
         /*.n_objects        =*/ 0,
         /*.objects_begin    =*/ NULL,
         /*.objects_end      =*/ NULL,
@@ -4044,11 +4162,18 @@ size_t ggml_get_mem_size(struct ggml_context * ctx) {
 // operators when using scratch buffers
 // TODO: implement a better way
 void ggml_scratch_save(struct ggml_context * ctx) {
+    // this is needed to allow opt tensors to store their data
+    // TODO: again, need to find a better way
+    ctx->no_alloc_save = ctx->no_alloc;
+    ctx->no_alloc      = false;
+
     ctx->scratch_save = ctx->scratch;
     ctx->scratch.data = NULL;
 }
 
 void ggml_scratch_load(struct ggml_context * ctx) {
+    ctx->no_alloc = ctx->no_alloc_save;
+
     ctx->scratch = ctx->scratch_save;
 }
 
@@ -4157,6 +4282,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
         /*.perf_time_us =*/ 0,
         /*.data         =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
         /*.name         =*/ { 0 },
+        /*.extra        =*/ NULL,
         /*.pad          =*/ { 0 },
     };
 
@@ -4595,7 +4721,7 @@ struct ggml_tensor * ggml_add_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -4635,7 +4761,7 @@ struct ggml_tensor * ggml_add1_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5061,6 +5187,34 @@ struct ggml_tensor * ggml_repeat(
     return result;
 }
 
+// ggml_repeat_back
+
+struct ggml_tensor * ggml_repeat_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_can_repeat(b, a));
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    if (ggml_are_same_shape(a, b) && !is_node) {
+        return a;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
+
+    result->op   = GGML_OP_REPEAT_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_abs
 
 struct ggml_tensor * ggml_abs_impl(
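
ggml_repeat_back is the adjoint of ggml_repeat, which is why ggml_can_repeat(b, a) is asserted with the arguments swapped: the gradient flowing back through a repeat is the sum over all repeated copies. A self-contained 1-D illustration of that reduction (my example, not ggml code):

    #include <stdio.h>

    int main(void) {
        const int ne0 = 3, nr = 2;                    // row width and repeat count
        const float grad_big[6] = {1, 2, 3, 4, 5, 6}; // gradient w.r.t. repeat(x)
        float grad_small[3]     = {0, 0, 0};          // gradient w.r.t. x

        for (int i = 0; i < nr; ++i) {
            for (int k = 0; k < ne0; ++k) {
                grad_small[k] += grad_big[i*ne0 + k]; // accumulate each copy
            }
        }
        printf("%g %g %g\n", grad_small[0], grad_small[1], grad_small[2]); // 5 7 9
        return 0;
    }
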
@@ -5438,6 +5592,32 @@ struct ggml_tensor * ggml_mul_mat(
     return result;
 }
 
+// ggml_out_prod
+
+struct ggml_tensor * ggml_out_prod(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_can_out_prod(a, b));
+    GGML_ASSERT(!ggml_is_transposed(a));
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+
+    result->op   = GGML_OP_OUT_PROD;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_scale
 
 struct ggml_tensor * ggml_scale_impl(
@@ -5450,7 +5630,7 @@ struct ggml_tensor * ggml_scale_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5493,7 +5673,7 @@ struct ggml_tensor * ggml_set_impl(
 
     bool is_node = false;
 
-    if (!inplace && (a->grad || b->grad)) {
+    if (a->grad || b->grad) {
         is_node = true;
     }
 
@@ -5802,14 +5982,18 @@ struct ggml_tensor * ggml_view_1d(
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset);
 
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
+
     result->op   = GGML_OP_VIEW;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
-
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
+    result->opt[0] = offs;
 
     return result;
 }
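
ggml_view_1d now stashes the byte offset in a two-element int32 tensor hung off result->opt[0] rather than in result->padding; the 2*sizeof(int32_t) copy suggests the offset is a 64-bit size_t split across the two slots (an assumption on my part, not stated in the diff). A standalone round-trip of that encoding:

    #include <assert.h>
    #include <stdint.h>
    #include <string.h>

    int main(void) {
        size_t offset = 123456789;                   // byte offset into the source tensor
        int32_t data[2];                             // stands in for offs->data
        memcpy(data, &offset, 2*sizeof(int32_t));    // store, as in ggml_view_1d

        size_t recovered;
        memcpy(&recovered, data, 2*sizeof(int32_t)); // load, as a consumer would
        assert(recovered == offset);
        return 0;
    }
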
@@ -5834,6 +6018,13 @@ struct ggml_tensor * ggml_view_2d(
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset);
 
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
+
     result->nb[1] = nb1;
     result->nb[2] = result->nb[1]*ne1;
     result->nb[3] = result->nb[2];
@@ -5842,10 +6033,7 @@ struct ggml_tensor * ggml_view_2d(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
-
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
+    result->opt[0] = offs;
 
     return result;
 }
@@ -5872,6 +6060,13 @@ struct ggml_tensor * ggml_view_3d(
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset);
 
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
+
     result->nb[1] = nb1;
     result->nb[2] = nb2;
     result->nb[3] = result->nb[2]*ne2;
@@ -5880,10 +6075,7 @@ struct ggml_tensor * ggml_view_3d(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
-
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
+    result->opt[0] = offs;
 
     return result;
 }
@@ -5912,6 +6104,13 @@ struct ggml_tensor * ggml_view_4d(
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset);
 
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    memcpy(offs->data, &offset, 2*sizeof(int32_t));
+
+    ggml_scratch_load(ctx);
+
     result->nb[1] = nb1;
     result->nb[2] = nb2;
     result->nb[3] = nb3;
@@ -5920,10 +6119,7 @@ struct ggml_tensor * ggml_view_4d(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = NULL;
-
-    if (is_node) {
-        memcpy(result->padding, &offset, sizeof(offset));
-    }
+    result->opt[0] = offs;
 
     return result;
 }
@@ -5986,10 +6182,18 @@ struct ggml_tensor * ggml_permute(
     result->src1 = NULL;
 
     if (is_node) {
-        result->padding[0] = axis0;
-        result->padding[1] = axis1;
-        result->padding[2] = axis2;
-        result->padding[3] = axis3;
+        ggml_scratch_save(ctx);
+
+        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
+
+        ((int32_t *) b->data)[0] = axis0;
+        ((int32_t *) b->data)[1] = axis1;
+        ((int32_t *) b->data)[2] = axis2;
+        ((int32_t *) b->data)[3] = axis3;
+
+        ggml_scratch_load(ctx);
+
+        result->opt[0] = b;
     }
 
     return result;
@@ -6229,6 +6433,44 @@ struct ggml_tensor * ggml_soft_max_inplace(
     return ggml_soft_max_impl(ctx, a, true);
 }
 
+
+// ggml_soft_max_back
+
+struct ggml_tensor * ggml_soft_max_back_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool                  inplace) {
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true; // TODO : implement backward pass
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SOFT_MAX_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_soft_max_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_soft_max_back_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, true);
+}
+
 // ggml_rope
 
 struct ggml_tensor * ggml_rope_impl(
@@ -6241,7 +6483,7 @@ struct ggml_tensor * ggml_rope_impl(
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
 
-    if (!inplace && a->grad) {
+    if (a->grad) {
         is_node = true;
     }
 
@@ -6295,8 +6537,7 @@ struct ggml_tensor * ggml_rope_back(
     bool is_node = false;
 
     if (a->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
+        is_node = false; // TODO: implement backward
     }
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
@@ -6461,7 +6702,6 @@ struct ggml_tensor * ggml_flash_attn(
     bool is_node = false;
 
     if (q->grad || k->grad || v->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -6493,7 +6733,6 @@ struct ggml_tensor * ggml_flash_ff(
     bool is_node = false;
 
     if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@@ -6511,6 +6750,71 @@
     return result;
 }
 
+// ggml_flash_attn_back
+
+struct ggml_tensor * ggml_flash_attn_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * d,
+        bool                  masked) {
+    GGML_ASSERT(ggml_can_mul_mat(k, q));
+    // TODO: check if vT can be multiplied by (k*qT)
+
+    // d shape [D,N,ne2,ne3]
+    // q shape [D,N,ne2,ne3]
+    // k shape [D,M,ne2,ne3]
+    // v shape [M,D,ne2,ne3]
+
+    const int64_t   D = q->ne[0];
+    const int64_t   N = q->ne[1];
+    const int64_t   M = k->ne[1];
+    const int64_t ne2 = q->ne[2];
+    const int64_t ne3 = q->ne[3];
+
+    GGML_ASSERT(k->ne[0] == D);
+    GGML_ASSERT(v->ne[0] == M);
+    GGML_ASSERT(v->ne[1] == D);
+    GGML_ASSERT(d->ne[0] == D);
+    GGML_ASSERT(d->ne[1] == N);
+    GGML_ASSERT(k->ne[2] == ne2);
+    GGML_ASSERT(k->ne[3] == ne3);
+    GGML_ASSERT(v->ne[2] == ne2);
+    GGML_ASSERT(v->ne[3] == ne3);
+    GGML_ASSERT(d->ne[2] == ne2);
+    GGML_ASSERT(d->ne[3] == ne3);
+
+    bool is_node = false;
+
+    if (q->grad || k->grad || v->grad) {
+        // when using this operation (in backwards pass) these grads are set.
+        // we don't want to create (big) grad of our result, so is_node is false.
+        is_node = false;
+    }
+
+    // store gradients of q, k and v as continuous tensors concatenated in result.
+    // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
+    // gradq->data = result->data
+    // gradk->data = result->data + nb0*D*N*ne2*ne3
+    // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
+    // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
+    int64_t ne[4] = {D,M+N+M,ne2,ne3};
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op   = GGML_OP_FLASH_ATTN_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = q;
+    result->src1 = k;
+    result->opt[0] = v;
+    result->opt[1] = d;
+    result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
+
+    return result;
+}
+
+
 // ggml_map_unary
 
 struct ggml_tensor * ggml_map_unary_impl_f32(
@@ -6595,6 +6899,50 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
     return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
 }
 
+// ggml_cross_entropy_loss
+
+struct ggml_tensor * ggml_cross_entropy_loss(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
+
+    result->op   = GGML_OP_CROSS_ENTROPY_LOSS;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
+// ggml_cross_entropy_loss_back
+
+struct ggml_tensor * ggml_cross_entropy_loss_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_is_scalar(c));
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
+    result->grad = NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = c;
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 void ggml_set_param(
@@ -7584,6 +7932,11 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
                 ggml_compute_forward_add_q_f32(params, src0, src1, dst);
             } break;
@@ -7887,6 +8240,11 @@ static void ggml_compute_forward_add1(
         case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
                 ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
             } break;
@@ -8009,6 +8367,11 @@ static void ggml_compute_forward_acc(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
         default:
             {
                 GGML_ASSERT(false);
@@ -8127,10 +8490,10 @@ static void ggml_compute_forward_mul_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-#ifdef GGML_USE_CUBLAS
-    if (src1->backend == GGML_BACKEND_CUDA) {
+#ifdef GGML_USE_CLBLAST
+    if (src1->backend == GGML_BACKEND_GPU) {
         if (ith == 0) {
-            ggml_cuda_mul(src0, src1, dst);
+            ggml_cl_mul(src0, src1, dst);
         }
         return;
     }
@@ -8730,29 +9093,122 @@ static void ggml_compute_forward_repeat(
     }
 }
 
-// ggml_compute_forward_abs
+// ggml_compute_forward_repeat_back
 
-static void ggml_compute_forward_abs_f32(
+static void ggml_compute_forward_repeat_back_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
-    assert(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_can_repeat(dst, src0));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    const int64_t ne0  = dst->ne[0];
+    const int64_t ne1  = dst->ne[1];
+    const int64_t ne2  = dst->ne[2];
+    const int64_t ne3  = dst->ne[3];
 
-    assert(dst->nb[0]  == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
 
-    for (int i = 0; i < n; i++) {
-        ggml_vec_abs_f32(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+    const size_t nb0  = dst->nb[0];
+    const size_t nb1  = dst->nb[1];
+    const size_t nb2  = dst->nb[2];
+    const size_t nb3  = dst->nb[3];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne00/ne0);
+    const int nr1 = (int)(ne01/ne1);
+    const int nr2 = (int)(ne02/ne2);
+    const int nr3 = (int)(ne03/ne3);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (ggml_is_contiguous(dst)) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+    } else {
+        for (int k3 = 0; k3 < ne3; k3++) {
+            for (int k2 = 0; k2 < ne2; k2++) {
+                for (int k1 = 0; k1 < ne1; k1++) {
+                    ggml_vec_set_f32(ne0,
+                        (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
+                        0);
+                }
+            }
+        }
+    }
+
+    // TODO: maybe this is not optimal?
+    for (int i3 = 0; i3 < nr3; i3++) {
+        for (int k3 = 0; k3 < ne3; k3++) {
+            for (int i2 = 0; i2 < nr2; i2++) {
+                for (int k2 = 0; k2 < ne2; k2++) {
+                    for (int i1 = 0; i1 < nr1; i1++) {
+                        for (int k1 = 0; k1 < ne1; k1++) {
+                            for (int i0 = 0; i0 < nr0; i0++) {
+                                ggml_vec_acc_f32(ne0,
+                                        (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),
+                                        (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_repeat_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_repeat_back_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_abs
+
+static void ggml_compute_forward_abs_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_abs_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
                 (float *) ((char *) src0->data + i*(src0->nb[1])));
     }
 }
 
@@ -9245,7 +9701,7 @@ static void ggml_compute_forward_rms_norm_f32(
             sum += (ggml_float)(x[i00] * x[i00]);
         }
 
-        float mean = sum/ne00;
+        const float mean = sum/ne00;
 
         float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
@@ -9568,14 +10024,7 @@ static void ggml_compute_forward_mul_mat_f32(
     // nb01 >= nb00 - src0 is not transposed
    // compute by src0 rows
 
-#if defined(GGML_USE_CUBLAS)
-    if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
-        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
-            ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
-        }
-        return;
-    }
-#elif defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -9740,14 +10189,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     // nb01 >= nb00 - src0 is not transposed
     // compute by src0 rows
 
-#if defined(GGML_USE_CUBLAS)
-    if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
-        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
-            ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
-        }
-        return;
-    }
-#elif defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -9952,14 +10394,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     // compute by src0 rows
 
-#if defined(GGML_USE_CUBLAS)
-    if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
-        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
-            ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
-        }
-        return;
-    }
-#elif defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -10102,6 +10537,11 @@ static void ggml_compute_forward_mul_mat(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
                 ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
             } break;
@@ -10120,6 +10560,176 @@ static void ggml_compute_forward_mul_mat(
     }
 }
 
+// ggml_compute_forward_out_prod
+
+
+static void ggml_compute_forward_out_prod_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne10 = src1->ne[0];
+    //const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    // compute by src0 rows
+
+    // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
+    // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
+
+    if (params->type == GGML_TASK_INIT) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by last three dimensions
+
+    // total rows in dst
+    const int64_t nr = ne1*ne2*ne3;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    // dst[:,:,:,:] = 0
+    // for i2,i3:
+    //   for i1:
+    //     for i01:
+    //       for i0:
+    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        // dst indices
+        const int64_t i3 = ir/(ne2*ne1);
+        const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+        const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        const int64_t i02 = i2;
+        const int64_t i03 = i3;
+
+        //const int64_t i10 = i1;
+        const int64_t i12 = i2;
+        const int64_t i13 = i3;
+
+        for (int64_t i01 = 0; i01 < ne01; ++i01) {
+            const int64_t i11 = i01;
+
+            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+            float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
+
+            ggml_vec_mad_f32(ne0, d, s0, *s1);
+            // for (int64_t i0 = 0; i0 < ne0; ++i0) {
+            //     d[i0] += s0[i0] * s1[i1];
+            // }
+        }
+    }
+
+    //int64_t t1 = ggml_perf_time_us();
+    //static int64_t acc = 0;
+    //acc += t1 - t0;
+    //if (t1 - t0 > 10) {
+    //    printf("\n");
+    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
+    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
+    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
+
+    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
+    //}
+}
+
+static void ggml_compute_forward_out_prod(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+            {
+                GGML_ASSERT(false); // todo
+                // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                GGML_ASSERT(false); // todo
+                // ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_out_prod_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_scale
 
 static void ggml_compute_forward_scale_f32(
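
Restating the inner-loop comment of ggml_compute_forward_out_prod_f32 above as a formula, per 2-D slice with the outer dimensions held fixed:

    \[
    \mathrm{dst}_{i_0,\,i_1} \mathrel{+}= \sum_{i_{01}} \mathrm{src0}_{i_0,\,i_{01}} \cdot \mathrm{src1}_{i_1,\,i_{01}}
    \]

that is, dst accumulates src0 times src1 transposed over the shared column index, matching ggml_can_out_prod's requirement that t0->ne[1] == t1->ne[1].
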
@@ -10285,6 +10895,11 @@ static void ggml_compute_forward_set(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
         default:
             {
                 GGML_ASSERT(false);
@@ -10450,6 +11065,11 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             {
                 ggml_compute_forward_get_rows_q(params, src0, src1, dst);
             } break;
@@ -10532,7 +11152,11 @@ static void ggml_compute_forward_get_rows_back_f32(
     GGML_ASSERT(ggml_is_contiguous(opt0));
     GGML_ASSERT(ggml_is_contiguous(dst));
 
-    ggml_compute_forward_dup_same_cont(params, opt0, dst);
+    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+    if (params->type == GGML_TASK_INIT) {
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -10676,8 +11300,8 @@ static void ggml_compute_forward_diag_mask_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst,
         const float value) {
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 2);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_nelements(src1) == 2);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -10685,7 +11309,7 @@ static void ggml_compute_forward_diag_mask_f32(
     const int  n_past  = ((int32_t *) src1->data)[0];
     const bool inplace = (bool)((int32_t *) src1->data)[1];
 
-    assert(n_past >= 0);
+    GGML_ASSERT(n_past >= 0);
 
     if (!inplace && (params->type == GGML_TASK_INIT)) {
         // memcpy needs to be synchronized across threads to avoid race conditions.
@@ -10709,8 +11333,8 @@ static void ggml_compute_forward_diag_mask_f32(
     const int nr = src0->ne[1];
     const int nz = n/nr;
 
-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     for (int k = 0; k < nz; k++) {
         for (int j = ith; j < nr; j += nth) {
@@ -10846,42 +11470,137 @@ static void ggml_compute_forward_soft_max(
     }
 }
 
-// ggml_compute_forward_alibi
+// ggml_compute_forward_soft_max_back
 
-static void ggml_compute_forward_alibi_f32(
+static void ggml_compute_forward_soft_max_back_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 3);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src1, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int n_past = ((int32_t *) src1->data)[0];
-    const int n_head = ((int32_t *) src1->data)[1];
-    const float max_bias = ((float *) src1->data)[2];
+    // TODO: handle transposed/permuted matrices
 
-    assert(n_past >= 0);
+    const int ith = params->ith;
+    const int nth = params->nth;
 
-    const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
-    const int ne1 = src0->ne[1]; // seq_len_without_past
-    //const int ne2 = src0->ne[2]; // n_head -> this is k
-    //const int ne3 = src0->ne[3]; // 1 -> bsz
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
 
-    const int n = ggml_nrows(src0);
-    const int ne2_ne3 = n/ne1; // ne2*ne3
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
 
-    const int nb0 = src0->nb[0];
-    const int nb1 = src0->nb[1];
-    const int nb2 = src0->nb[2];
-    //const int nb3 = src0->nb[3];
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
 
-    assert(nb0 == sizeof(float));
-    assert(ne1 + n_past == ne0);
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float *y  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float *dx = (float *)((char *) dst->data  + i1*dst->nb[1]);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(dy[i]));
+            assert(!isnan(y[i]));
+        }
+#endif
+        // Jii = yi - yi*yi
+        // Jij = -yi*yj
+        // J = diag(y)-y.T*y
+        // dx = J * dy
+        // dxk = sum_i(Jki * dyi)
+        // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
+        // dxk = sum_i(-yk*yi * dyi) + yk*dyk
+        // dxk = -yk * sum_i(yi * dyi) + yk*dyk
+        // dxk = -yk * dot(y, dy) + yk*dyk
+        // dxk = yk * (- dot(y, dy) + dyk)
+        // dxk = yk * (dyk - dot(y, dy))
+        //
+        // post-order:
+        // dot_y_dy := dot(y, dy)
+        // dx := dy
+        // dx := dx - dot_y_dy
+        // dx := dx * y
+
+        // linear runtime, no additional memory
+        float dot_y_dy = 0;
+        ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
+        ggml_vec_cpy_f32 (nc, dx, dy);
+        ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
+        ggml_vec_mul_f32 (nc, dx, dx, y);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(dx[i]));
+            assert(!isinf(dx[i]));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_soft_max_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_alibi
|
+
|
11570
|
+
static void ggml_compute_forward_alibi_f32(
|
11571
|
+
const struct ggml_compute_params * params,
|
11572
|
+
const struct ggml_tensor * src0,
|
11573
|
+
const struct ggml_tensor * src1,
|
11574
|
+
struct ggml_tensor * dst) {
|
11575
|
+
assert(params->ith == 0);
|
11576
|
+
assert(src1->type == GGML_TYPE_I32);
|
11577
|
+
assert(ggml_nelements(src1) == 3);
|
11578
|
+
|
11579
|
+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
11580
|
+
return;
|
11581
|
+
}
|
11582
|
+
|
11583
|
+
const int n_past = ((int32_t *) src1->data)[0];
|
11584
|
+
const int n_head = ((int32_t *) src1->data)[1];
|
11585
|
+
const float max_bias = ((float *) src1->data)[2];
|
11586
|
+
|
11587
|
+
assert(n_past >= 0);
|
11588
|
+
|
11589
|
+
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
|
11590
|
+
const int ne1 = src0->ne[1]; // seq_len_without_past
|
11591
|
+
//const int ne2 = src0->ne[2]; // n_head -> this is k
|
11592
|
+
//const int ne3 = src0->ne[3]; // 1 -> bsz
|
11593
|
+
|
11594
|
+
const int n = ggml_nrows(src0);
|
11595
|
+
const int ne2_ne3 = n/ne1; // ne2*ne3
|
11596
|
+
|
11597
|
+
const int nb0 = src0->nb[0];
|
11598
|
+
const int nb1 = src0->nb[1];
|
11599
|
+
const int nb2 = src0->nb[2];
|
11600
|
+
//const int nb3 = src0->nb[3];
|
11601
|
+
|
11602
|
+
assert(nb0 == sizeof(float));
|
11603
|
+
assert(ne1 + n_past == ne0); (void) n_past;
|
10885
11604
|
|
10886
11605
|
// add alibi to src0 (KQ_scaled)
|
10887
11606
|
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
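The Jacobian reduction in the comments above collapses the full softmax Jacobian J = diag(y) - y y^T into the linear-time form dx_k = y_k * (dy_k - dot(y, dy)). A small standalone check of that identity (hypothetical test code, assuming nothing from ggml):

#include <assert.h>
#include <math.h>

int main(void) {
    const float y[3]  = {0.1f, 0.7f, 0.2f};   // a softmax output (sums to 1)
    const float dy[3] = {1.0f, -2.0f, 0.5f};  // incoming gradient
    float dot = 0.0f;
    for (int i = 0; i < 3; ++i) dot += y[i]*dy[i];
    for (int k = 0; k < 3; ++k) {
        float slow = 0.0f;                    // dx_k = sum_i J_ki * dy_i
        for (int i = 0; i < 3; ++i) {
            const float J = (i == k) ? y[k] - y[k]*y[k] : -y[k]*y[i];
            slow += J*dy[i];
        }
        const float fast = y[k]*(dy[k] - dot); // linear-time form used above
        assert(fabsf(slow - fast) < 1e-6f);
    }
    return 0;
}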
@@ -10996,6 +11715,12 @@ static void ggml_compute_forward_alibi(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -11067,6 +11792,12 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_Q8_K:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -11156,7 +11887,7 @@ static void ggml_compute_forward_rope_f32(
                         theta *= theta_scale;
 
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              float * dst_data  = (float *)((char *)  dst->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
                         const float x0 = src[0];
                         const float x1 = src[1];
@@ -11177,7 +11908,7 @@ static void ggml_compute_forward_rope_f32(
                         const int64_t i0 = ib*n_dims + ic/2;
 
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              float * dst_data  = (float *)((char *)  dst->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
                         const float x0 = src[0];
                         const float x1 = src[n_dims/2];
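Both rope fixes address dst with its own byte strides (nb0..nb3) instead of src0's (nb00..nb03); the two stride sets only coincide when both tensors have identical layout. A sketch of the ggml-style stride convention for a contiguous f32 tensor (illustrative helper, not from the diff):

#include <stddef.h>
#include <stdint.h>

// Byte strides for a contiguous f32 tensor with shape ne[0..3],
// mirroring ggml's nb[] convention: each tensor carries its own nb[].
void f32_strides(const int64_t ne[4], size_t nb[4]) {
    nb[0] = sizeof(float);   // stride between elements in a row
    nb[1] = nb[0]*ne[0];     // stride between rows
    nb[2] = nb[1]*ne[1];     // stride between matrices
    nb[3] = nb[2]*ne[2];     // stride between 3D slices
}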
@@ -12787,6 +13518,414 @@ static void ggml_compute_forward_flash_ff(
     }
 }
 
+// ggml_compute_forward_flash_attn_back
+
+static void ggml_compute_forward_flash_attn_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * d,
+        const bool masked,
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int64_t neq0 = q->ne[0];
+    const int64_t neq1 = q->ne[1];
+    const int64_t neq2 = q->ne[2];
+    const int64_t neq3 = q->ne[3];
+
+    const int64_t nek0 = k->ne[0];
+    const int64_t nek1 = k->ne[1];
+    //const int64_t nek2 = k->ne[2];
+    //const int64_t nek3 = k->ne[3];
+
+    const int64_t nev0 = v->ne[0];
+    const int64_t nev1 = v->ne[1];
+    //const int64_t nev2 = v->ne[2];
+    //const int64_t nev3 = v->ne[3];
+
+    const int64_t ned0 = d->ne[0];
+    const int64_t ned1 = d->ne[1];
+    //const int64_t ned2 = d->ne[2];
+    //const int64_t ned3 = d->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int nbk0 = k->nb[0];
+    const int nbk1 = k->nb[1];
+    const int nbk2 = k->nb[2];
+    const int nbk3 = k->nb[3];
+
+    const int nbq0 = q->nb[0];
+    const int nbq1 = q->nb[1];
+    const int nbq2 = q->nb[2];
+    const int nbq3 = q->nb[3];
+
+    const int nbv0 = v->nb[0];
+    const int nbv1 = v->nb[1];
+    const int nbv2 = v->nb[2];
+    const int nbv3 = v->nb[3];
+
+    const int nbd0 = d->nb[0];
+    const int nbd1 = d->nb[1];
+    const int nbd2 = d->nb[2];
+    const int nbd3 = d->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+    const int64_t P = nek1 - N;
+    const int64_t M = P + N;
+
+    const int Mup  = ggml_up(M, GGML_SOFT_MAX_UNROLL);
+    const int mxDM = MAX(D, Mup);
+
+    // GGML_ASSERT(ne0 == D);
+    // GGML_ASSERT(ne1 == N);
+    GGML_ASSERT(P >= 0);
+
+    GGML_ASSERT(nbq0 == sizeof(float));
+    GGML_ASSERT(nbk0 == sizeof(float));
+    GGML_ASSERT(nbv0 == sizeof(float));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned0 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nek1 == N + P);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    if (params->type == GGML_TASK_INIT) {
+        if (ith == 0) {
+            memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
+        }
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int nr = neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float scale = 1.0f/sqrtf(D);
+
+    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2);
+        const int iq2 = ir - iq3*neq2;
+        for ( int iq1 = 0; iq1 < neq1; ++iq1) {
+
+
+            // not sure about CACHE_LINE_SIZE_F32..
+            // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
+            float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
+            float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
+
+            for (int i = M; i < Mup; ++i) {
+                S[i] = -INFINITY;
+            }
+
+            for (int64_t ic = 0; ic < nek1; ++ic) {
+                // k indices
+                const int ik3 = iq3;
+                const int ik2 = iq2;
+                const int ik1 = ic;
+
+                // S indices
+                const int i1 = ik1;
+
+                ggml_vec_dot_f32(neq0,
+                        S + i1,
+                        (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                        (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+            }
+
+            // scale
+            ggml_vec_scale_f32(nek1, S, scale);
+
+            if (masked) {
+                for (int64_t i = P; i < M; i++) {
+                    if (i > P + iq1) {
+                        S[i] = -INFINITY;
+                    }
+                }
+            }
+
+            // softmax
+            {
+                float max = -INFINITY;
+                ggml_vec_max_f32(M, &max, S);
+
+                ggml_float sum = 0.0;
+                {
+#ifdef GGML_SOFT_MAX_ACCELERATE
+                    max = -max;
+                    vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
+                    vvexpf(SM, SM, &Mup);
+                    ggml_vec_sum_f32(Mup, &sum, SM);
+#else
+                    uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
+                    ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
+
+                    for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
+                        float * SR =  S + i;
+                        float * SW = SM + i;
+
+                        for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
+                            if (SR[j] == -INFINITY) {
+                                SW[j] = 0.0f;
+                            } else {
+                                ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
+                                memcpy(&scvt[j], &s, sizeof(uint16_t));
+                                const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+                                sump[j] += (ggml_float)val;
+                                SW[j] = val;
+                            }
+                        }
+                    }
+
+                    for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
+                        sum += sump[i];
+                    }
+#endif
+                }
+
+                assert(sum > 0.0);
+
+                sum = 1.0/sum;
+                ggml_vec_scale_f32(M, SM, sum);
+
+            }
+
+            // step-by-step explanation
+            {
+                // forward-process                   shape      grads from backward process
+                // parallel_for iq2,iq3:
+                //  k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,iq2,iq3]  += grad[kcur]
+                //  q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]
+                //  v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iq2,iq3]  += grad[vcur]
+                //  for iq1:
+                //   kcur   = k[:D,:M,iq2,iq3]       [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur
+                //   qcur   = q[:D,iq1,iq2,iq3]      [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur
+                //   vcur   = v[:M,:D,iq2,iq3]       [M,D,1,1]  grad[vcur] = grad[S5].T @ S4
+                //   S0     = -Inf                   [D,1,1,1]
+                //  ~S1[i]  = dot(kcur[:D,i], qcur)
+                //   S1     = qcur @ kcur.T          [M,1,1,1]  grad[S1]   = grad[S2] * scale
+                //   S2     = S1 * scale             [M,1,1,1]  grad[S2]   = diag_mask_zero(grad[S3], P)
+                //   S3     = diag_mask_inf(S2, P)   [M,1,1,1]  grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                //   S4     = softmax(S3)            [M,1,1,1]  grad[S4]   = grad[S5] @ vcur
+                //  ~S5[i]  = dot(vcur[:,i], S4)
+                //   S5     = S4 @ vcur.T            [D,1,1,1]  grad[S5]   = d[:D,iq1,iq2,iq3]
+                //  ~dst[i,iq1,iq2,iq3]  = S5[i]              ^
+                //   dst[:D,iq1,iq2,iq3] = S5                 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
+                //   dst                             backward-/ grad[dst]                 = d
+                //
+                // output gradients with their dependencies:
+                //
+                // grad[kcur] = grad[S1].T @ qcur
+                // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                // grad[S4]   = grad[S5] @ vcur
+                // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
+                // grad[qcur] = grad[S1] @ kcur
+                // grad[vcur] = grad[S5].T @ S4
+                // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
+                //
+                // in post-order:
+                //
+                // S1         = qcur @ kcur.T
+                // S2         = S1 * scale
+                // S3         = diag_mask_inf(S2, P)
+                // S4         = softmax(S3)
+                // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
+                // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                // grad[qcur] = grad[S1] @ kcur
+                // grad[kcur] = grad[S1].T @ qcur
+                // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
+                //
+                // using less variables (SM=S4):
+                //
+                // S             = diag_mask_inf(qcur @ kcur.T * scale, P)
+                // SM            = softmax(S)
+                // S             = d[:D,iq1,iq2,iq3] @ vcur
+                // dot_SM_gradSM = dot(SM, S)
+                // S             = SM * (S - dot(SM, S))
+                // S             = diag_mask_zero(S, P) * scale
+                //
+                // grad[q][:D,iq1,iq2,iq3] += S   @ kcur
+                // grad[k][:D,:M,iq2,iq3]  += S.T @ qcur
+                // grad[v][:M,:D,iq2,iq3]  += d[:D,iq1,iq2,iq3].T @ SM
+            }
+
+            // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
+            // S = d[:D,iq1,iq2,iq3] @ vcur
+            // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
+            ggml_vec_set_f32(M, S, 0);
+            for (int64_t ic = 0; ic < D; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                ggml_vec_mad_f32(M,
+                        S,
+                         (float *) ((char *) v->data + (          ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                        *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
+            }
+
+            // S = SM * (S - dot(SM, S))
+            float dot_SM_gradSM = 0;
+            ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
+            ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
+            ggml_vec_mul_f32 (M, S, S, SM);
+
+            // S = diag_mask_zero(S, P) * scale
+            if (masked) {
+                // for (int64_t i = P + iq1 + 1; i < M; i++) {
+                //     S[i] = 0;
+                // }
+                for (int64_t i = P; i < M; i++) {
+                    if (i > P + iq1) {
+                        S[i] = 0;
+                    }
+                }
+            }
+            ggml_vec_scale_f32(M, S, scale);
+
+            void * grad_q = (char *) dst->data;
+            void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
+            void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
+
+            const size_t nbgq1 = nb0*neq0;
+            const size_t nbgq2 = nb0*neq0*neq1;
+            const size_t nbgq3 = nb0*neq0*neq1*neq2;
+
+            const size_t nbgk1 = nb0*nek0;
+            const size_t nbgk2 = nb0*nek0*nek1;
+            const size_t nbgk3 = nb0*nek0*nek1*neq2;
+
+            const size_t nbgv1 = nb0*nev0;
+            const size_t nbgv2 = nb0*nev0*nev1;
+            const size_t nbgv3 = nb0*nev0*nev1*neq2;
+
+            // S    shape [M,1]
+            // SM   shape [M,1]
+            // kcur shape [D,M]
+            // qcur shape [D,1]
+            // vcur shape [M,D]
+            //
+            // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+            // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
+            // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
+            //
+            //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
+            //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
+            for (int64_t ic = 0; ic < M; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                ggml_vec_mad_f32(D,
+                        (float *) ((char *) grad_q  + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
+                        (float *) ((char *) k->data + (ic*nbk1  + i2*nbk2  + i3*nbk3)),
+                        S[ic]);
+            }
+
+            // grad[k][:D,:M,iq2,iq3] += S.T       @ qcur
+            // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
+            // grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]
+            for (int64_t ic = 0; ic < M; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                // ggml_vec_set_f32(D,
+                //         (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
+                //         0);
+                ggml_vec_mad_f32(D,
+                        (float *) ((char *) grad_k  + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
+                        (float *) ((char *) q->data + (i1*nbq1  + i2*nbq2  + i3*nbq3)),
+                        S[ic]);
+            }
+
+            // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T       @ SM
+            // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
+            // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3]         * SM[:M]
+            for (int64_t ic = 0; ic < D; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                // ggml_vec_set_f32(M,
+                //         (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
+                //         0);
+                ggml_vec_mad_f32(M,
+                        (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
+                        SM,
+                        *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_flash_attn_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * d,
+        const bool masked,
+        struct ggml_tensor * dst) {
+    switch (q->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
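Each thread of flash_attn_back works on two scratch rows, S and SM, carved out of params->wdata with a cache-line pad between them. A compact sketch of that layout (illustrative, mirroring the pointer arithmetic above):

#include <stddef.h>

// Per-thread scratch layout: thread ith owns two padded float rows.
static float * thread_S (float * wdata, int ith, size_t mxDM, size_t pad) {
    return wdata + ith*2*(mxDM + pad) + 0*(mxDM + pad);
}
static float * thread_SM(float * wdata, int ith, size_t mxDM, size_t pad) {
    return wdata + ith*2*(mxDM + pad) + 1*(mxDM + pad);
}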
@@ -12849,29 +13988,308 @@ static void ggml_compute_forward_map_binary_f32(
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
-    assert( dst->nb[0] == sizeof(float));
-    assert(src0->nb[0] == sizeof(float));
-    assert(src1->nb[0] == sizeof(float));
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *)  dst->data + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_binary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        const ggml_binary_op_f32_t fun) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_cross_entropy_loss
+
+static void ggml_compute_forward_cross_entropy_loss_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    float * sums = (float *) params->wdata;
+
+    // TODO: handle transposed/permuted matrices
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    if (params->type == GGML_TASK_INIT) {
+        if (ith == 0) {
+            memset(sums, 0, sizeof(float) * (nth + nth * nc));
+        }
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        if (ith == 0) {
+            float * dp = (float *) dst->data;
+            ggml_vec_sum_f32(nth, dp, sums);
+            dp[0] *= -1.0f;
+        }
+        return;
+    }
+
+    const double eps = 1e-9;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float * st = (float *) params->wdata + nth + ith*nc;
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+        // soft_max
+        ggml_float sum = 0.0;
+        {
+            float max = -INFINITY;
+            ggml_vec_max_f32(nc, &max, s0);
+
+            uint16_t scvt;
+            for (int i = 0; i < nc; i++) {
+                if (s0[i] == -INFINITY) {
+                    st[i] = 0.0f;
+                } else {
+                    // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
+                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
+                    memcpy(&scvt, &s, sizeof(scvt));
+                    const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+                    sum += (ggml_float)val;
+                    st[i] = val;
+                }
+            }
+
+            assert(sum > 0.0);
+            // sum = 1.0/sum;
+        }
+        // avoid log(0) by rescaling from [0..1] to [eps..1]
+        sum = (1.0 - eps) / sum;
+        ggml_vec_scale_f32(nc, st, sum);
+        ggml_vec_add1_f32(nc, st, st, eps);
+        ggml_vec_log_f32(nc, st, st);
+        ggml_vec_mul_f32(nc, st, st, s1);
+
+        ggml_vec_sum_f32(nc, sums + ith, st);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(st[i]));
+            assert(!isinf(st[i]));
+        }
+#endif
+    }
+
+}
+
+static void ggml_compute_forward_cross_entropy_loss(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_cross_entropy_loss_back
+
+static void ggml_compute_forward_cross_entropy_loss_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(opt0));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int64_t ith = params->ith;
+    const int64_t nth = params->nth;
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const float eps = 1e-9f;
+
+    // TODO: handle transposed/permuted matrices
+    const int64_t nc = src0->ne[0];
+    const int64_t nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    float * d = (float *) opt0->data;
+
+    for (int64_t i1 = ir0; i1 < ir1; i1++) {
+        float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
+        float * s0  = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float * sm  = (float *) params->wdata + ith*nc;
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+        // step by step explanation:
+        {
+            //float * sums = (float *) params->wdata;
+
+            // forward pass with annotated gradients from backward pass
+            // (built by going in reverse operation order, adding to gradients of current operation args)
+            // st0 = exp(s0-max(s0))                          grad[st0] = grad[st1]*(1.0 - eps)/sum
+            //                                                from softmax_back: grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
+            // ggml_vec_scale_f32(nc, st, sum);               // st1 = st0*/sum = softmax(s0)  grad[st1] = grad[st2]*(1.0 - eps)
+            // ggml_vec_scale_f32(nc, st, (1.0f - eps));      // st2 = st1*(1.0 - eps)         grad[st2] = grad[st3]
+            // ggml_vec_add1_f32(nc, st, st, eps);            // st3 = st2 + eps               grad[st3] = grad[st4]/st3
+            // ggml_vec_log_f32(nc, st, st);                  // st4 = log(st3)                grad[st4] = grad[st5] * s1
+            // ggml_vec_mul_f32(nc, st, st, s1);              // st5 = st4 * s1                grad[st5] = grad[sums[ith]]
+            // ggml_vec_sum_f32(nc, sums + ith, st);          // sums[ith] = st5               grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel]
+
+            // substitute into grad[st1], because we can reuse softmax_back from this point on
+            // grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps))
+            // postorder:
+            // grad[st1] := softmax(s0)
+            // grad[st1] := grad[st1]*(1.0 - eps)
+            // grad[st1] := grad[st1] + eps
+            // grad[st1] := s1 / grad[st1]
+            // grad[st1] := grad[st1]*(1.0-eps)*-grad[cel]
+
+            // src0 gradients by going through softmax_back
+            // grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
+            // from softmax_back:
+            // dxk = yk * (dyk - dot(y, dy))
+            // dot_y_dy := dot(y, dy)
+            // dx := dy
+            // dx := dx - dot_y_dy
+            // dx := dx * y
+            // postorder:
+            // dot_st1_dst1 := dot(st1, grad[st1])
+            // grad[s0] := grad[st1]
+            // grad[s0] := grad[s0] - dot_st1_dst1
+            // grad[s0] := grad[s0] * st1
+
+            // prepend postorder from grad[st1] directly using grad[s0] as memory location, as we will grad[s0] := grad[st1]
+            // sm           := softmax(s0)
+            // grad[s0]     := sm*(1.0 - eps)
+            // grad[s0]     := grad[s0] + eps
+            // grad[s0]     := s1 / grad[s0]
+            // grad[s0]     := grad[s0]*(1.0-eps)*-grad[cel]
+            // dot_st1_dst1 := dot(sm, grad[s0])
+            // grad[s0]     := grad[s0] - dot_st1_dst1
+            // grad[s0]     := grad[s0] * sm
+        }
+
+        // soft_max
+        ggml_float sum = 0.0;
+        {
+            float max = -INFINITY;
+            ggml_vec_max_f32(nc, &max, s0);
+
+            uint16_t scvt;
+            for (int i = 0; i < nc; i++) {
+                if (s0[i] == -INFINITY) {
+                    sm[i] = 0.0f;
+                } else {
+                    // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
+                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
+                    memcpy(&scvt, &s, sizeof(scvt));
+                    const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+                    sum += (ggml_float)val;
+                    sm[i] = val;
+                }
+            }
+
+            assert(sum > 0.0);
+            sum = 1.0/sum;
+        }
 
-    for (int i = 0; i < n; i++) {
-        fun(nc,
-                (float *) ((char *)  dst->data + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])),
-                (float *) ((char *) src1->data + i*(src1->nb[1])));
+        float dot_st1_dst1 = 0;
+        ggml_vec_scale_f32(nc, sm, sum);
+        ggml_vec_cpy_f32  (nc, ds0, sm);
+        ggml_vec_scale_f32(nc, ds0, (1.0f - eps));
+        ggml_vec_add1_f32 (nc, ds0, ds0, eps);
+        ggml_vec_div_f32  (nc, ds0, s1, ds0);
+        ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]);
+        ggml_vec_dot_f32  (nc, &dot_st1_dst1, sm, ds0);
+        ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1);
+        ggml_vec_mul_f32  (nc, ds0, ds0, sm);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(sm[i]));
+            assert(!isinf(sm[i]));
+            assert(!isnan(ds0[i]));
+            assert(!isinf(ds0[i]));
+        }
+#endif
     }
 }
 
-
-static void ggml_compute_forward_map_binary(
+static void ggml_compute_forward_cross_entropy_loss_back(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+                ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst);
             } break;
         default:
             {
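Per row, the forward loss above computes -sum_i s1[i] * log(eps + (1-eps)*softmax(s0)[i]), with the eps rescaling guarding against log(0). A minimal single-row sketch (standalone illustrative code, not the threaded ggml version):

#include <math.h>

float cross_entropy_row(const float * s0, const float * s1, int nc) {
    const double eps = 1e-9;
    float max = -INFINITY;
    for (int i = 0; i < nc; ++i) if (s0[i] > max) max = s0[i];
    double sum = 0.0;                              // softmax denominator
    for (int i = 0; i < nc; ++i) sum += exp(s0[i] - max);
    double loss = 0.0;
    for (int i = 0; i < nc; ++i) {
        const double p = eps + (1.0 - eps)*exp(s0[i] - max)/sum;  // rescaled to [eps..1]
        loss -= s1[i]*log(p);
    }
    return (float) loss;
}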
@@ -12880,11 +14298,21 @@ static void ggml_compute_forward_map_binary(
     }
 }
 
+
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
     GGML_ASSERT(params);
 
+#ifdef GGML_USE_CUBLAS
+    bool skip_cpu = ggml_cuda_compute_forward(params, tensor);
+    if (skip_cpu) {
+        return;
+    }
+    GGML_ASSERT(tensor->src0->backend == GGML_BACKEND_CPU);
+    GGML_ASSERT(tensor->src1 == NULL || tensor->src1->backend == GGML_BACKEND_CPU);
+#endif // GGML_USE_CUBLAS
+
     switch (tensor->op) {
         case GGML_OP_DUP:
             {
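The new CUBLAS block tries the GPU first and only falls through to the CPU kernels when the backend declines the op. The same dispatch shape in miniature (hypothetical stub code, not the ggml API):

#include <stdbool.h>
#include <stdio.h>

typedef struct { int op; } node;

// returns true if the accelerator handled the op (stub for illustration)
static bool backend_compute(node * t) { (void) t; return false; }
static void cpu_compute(node * t) { printf("cpu op %d\n", t->op); }

static void compute(node * t) {
    if (backend_compute(t)) return;  // ran on the accelerator; skip the CPU path
    cpu_compute(t);                  // regular CPU fallback
}

int main(void) { node n = { 1 }; compute(&n); return 0; }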
@@ -12942,6 +14370,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_repeat(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_REPEAT_BACK:
+            {
+                ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_ABS:
             {
                 ggml_compute_forward_abs(params, tensor->src0, tensor);
@@ -12990,6 +14422,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_OUT_PROD:
+            {
+                ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_SCALE:
             {
                 ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
@@ -13046,6 +14482,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_soft_max(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_SOFT_MAX_BACK:
+            {
+                ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_ROPE:
             {
                 ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
@@ -13081,6 +14521,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
             } break;
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
+                GGML_ASSERT(t == 0 || t == 1);
+                bool masked = t != 0;
+                ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
@@ -13093,6 +14540,16 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
             }
             break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            {
+                ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
+            }
+            break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            {
+                ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -13231,11 +14688,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad =
                         ggml_add_impl(ctx,
                                 src0->grad,
-                                ggml_mul(ctx,
-                                    tensor->grad, // this was not catched by test_grad because in test_grad tensor->grad is 1
+                                ggml_scale(ctx,
                                     ggml_div(ctx,
-                                        ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor),
-                                        tensor)),
+                                        tensor->grad,
+                                        tensor),
+                                    ggml_new_f32(ctx, 0.5f)),
                                 inplace);
                 }
             } break;
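The rewritten SQRT backward is the textbook derivative: for y = sqrt(x), dy/dx = 1/(2*sqrt(x)) = 1/(2y), so by the chain rule (a sketch in LaTeX, with g = tensor->grad and y = tensor):

\[ \frac{\partial L}{\partial x} = g \cdot \frac{1}{2\sqrt{x}} = 0.5\,\frac{g}{y}, \]

which is exactly ggml_scale(ggml_div(tensor->grad, tensor), 0.5) above, now applied to the actual incoming gradient rather than assuming it is 1.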
@@ -13281,43 +14738,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    GGML_ASSERT(src0->n_dims == 1 || src0->n_dims == 2);
-                    const int nc  = tensor->ne[0];
-                    const int nr  = tensor->ne[1];
-                    const int nc0 = src0->ne[0];
-                    const int nr0 = src0->ne[1];
-                    const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
-                    const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
-                    // tensor->grad [nc,nr,1,1]
-                    // reshape      [nc0,nc/nc0,nr0,nr/nr0]
-                    // permute      [nc0,nr0,nc/nc0,nr/nr0]
-                    // substitute   [nc0,nr0,ncr,nrr]
-                    // reshape      [nc0*nr0,ncr*nrr,1,1]
-                    // transpose    [ncr*nrr,nc0*nr0,1,1]
-                    // sum rows     [1,nc0*nr0,1,1]
-                    // transpose    [nc0*nr0,1,1]
-                    // reshape      [nc0,nr0,1,1] reshape_1d or reshape_2d
-                    // add to src0->grad
-
-                    int64_t ne[4] = {nc0,ncr,nr0,nrr};
-
-                    struct ggml_tensor* F00 = tensor->grad;
-                    struct ggml_tensor* F01 = ggml_reshape   (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne));
-                    struct ggml_tensor* F02 = ggml_permute   (ctx, F01, 0,2,1,3);
-                    struct ggml_tensor* F03 = ggml_cont      (ctx, F02);
-                    struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr);
-                    struct ggml_tensor* F05 = ggml_transpose (ctx, F04);
-                    struct ggml_tensor* F06 = ggml_cont      (ctx, F05);
-                    struct ggml_tensor* F07 = ggml_sum_rows  (ctx, F06);
-                    struct ggml_tensor* F08 = ggml_transpose (ctx, F07);
-                    struct ggml_tensor* F09 = ggml_cont      (ctx, F08);
-                    struct ggml_tensor* F10 = ggml_reshape   (ctx, F09, src0->grad);
-
-                    src0->grad =
-                        ggml_add_impl(ctx,
-                            src0->grad,
-                            F10,
-                            inplace);
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            ggml_repeat_back(ctx, tensor->grad, src0->grad),
+                            inplace);
+                }
+            } break;
+        case GGML_OP_REPEAT_BACK:
+            {
+                if (src0->grad) {
+                    // TODO: test this
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            ggml_repeat(ctx, tensor->grad, src0->grad),
+                            inplace);
                 }
             } break;
         case GGML_OP_ABS:
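repeat broadcasts src0 across the output, so its adjoint sums every gradient copy back into src0's shape; that is what the new ggml_repeat_back does in one op, replacing the long reshape/permute/sum chain. A minimal 1-D sketch of the pair (hypothetical standalone code):

// 1-D sketch of repeat (forward) and its adjoint repeat_back (illustrative).
void repeat_f32(float * dst, const float * src, int n0, int k) {
    for (int r = 0; r < k; ++r)
        for (int i = 0; i < n0; ++i)
            dst[r*n0 + i] = src[i];          // broadcast src k times
}

void repeat_back_f32(float * dsrc, const float * ddst, int n0, int k) {
    for (int i = 0; i < n0; ++i) dsrc[i] = 0.0f;
    for (int r = 0; r < k; ++r)
        for (int i = 0; i < n0; ++i)
            dsrc[i] += ddst[r*n0 + i];       // each copy contributes to the grad
}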
@@ -13424,38 +14858,37 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
 
                 // necessary for llama
                 if (src0->grad) {
-                    // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
                     src0->grad =
                         ggml_add_impl(ctx,
                                 src0->grad,
-                                //
-                                // ggml_out_prod(ctx, // [n,m]
-                                //     src1,          // [n,p]
-                                //     tensor->grad), // [m,p]
-                                // for now just using A*B==(B.T*A.T).T
-                                ggml_cont(ctx,                      // [n,m]
-                                    ggml_transpose(ctx,             // [n,m]
-                                        ggml_mul_mat(ctx,           // [m,n]
-                                            ggml_cont(ctx,          // [p,m]
-                                                ggml_transpose(ctx, // [p,m]
-                                                    tensor->grad)), // [m,p]
-                                            ggml_cont(ctx,          // [p,n]
-                                                ggml_transpose(ctx, // [p,n]
-                                                    src1))))),      // [n,p]
+                                ggml_out_prod(ctx, // [n,m]
+                                    src1,          // [n,p]
+                                    tensor->grad), // [m,p]
                                 inplace);
                 }
                 if (src1->grad) {
                     src1->grad =
                         ggml_add_impl(ctx,
                                 src1->grad,
-                                //
-                                ggml_mul_mat(ctx,                   // [n,p]
-                                    ggml_cont(ctx,                  // [m,n]
-                                        ggml_transpose(ctx, src0)), // [m,n]
-                                    tensor->grad),                  // [m,p]
+                                // ggml_mul_mat(ctx,                   // [n,p]
+                                //     ggml_cont(ctx,                  // [m,n]
+                                //         ggml_transpose(ctx, src0)), // [m,n]
+                                //     tensor->grad),                  // [m,p]
+
+                                // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
+                                // // avoid transpose of src0, rather transpose smaller tensor->grad
+                                // // and then use ggml_out_prod
+                                ggml_out_prod(ctx,      // [n,p]
+                                    src0,               // [n,m]
+                                    ggml_transpose(ctx, // [p,m]
+                                        tensor->grad)), // [m,p]
                                 inplace);
                 }
             } break;
+        case GGML_OP_OUT_PROD:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_SCALE:
             {
                 // necessary for llama
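In conventional notation the two branches implement the matmul adjoints: for C = A B with upstream gradient G = dL/dC (a sketch in LaTeX),

\[ \frac{\partial L}{\partial A} = G\,B^{\top}, \qquad \frac{\partial L}{\partial B} = A^{\top} G, \]

and the new ggml_out_prod realizes each product directly instead of the transpose-and-cont-heavy mul_mat chain that was removed; the bracketed [n,m]-style comments track the same shape bookkeeping.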
@@ -13557,7 +14990,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     size_t offset;
-                    memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
+
+                    GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0]));
+                    memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
 
                     size_t nb1 = tensor->nb[1];
                     size_t nb2 = tensor->nb[2];
@@ -13584,10 +15019,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    int axis0 = tensor->opt[0]->ne[0] & 0x3;
-                    int axis1 = tensor->opt[0]->ne[1] & 0x3;
-                    int axis2 = tensor->opt[0]->ne[2] & 0x3;
-                    int axis3 = tensor->opt[0]->ne[3] & 0x3;
+                    int32_t * axes = (int32_t *) tensor->opt[0]->data;
+                    int axis0 = axes[0] & 0x3;
+                    int axis1 = axes[1] & 0x3;
+                    int axis2 = axes[2] & 0x3;
+                    int axis3 = axes[3] & 0x3;
                     int axes_backward[4] = {0,0,0,0};
                     axes_backward[axis0] = 0;
                     axes_backward[axis1] = 1;
@@ -13671,50 +15107,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    // y = softmax(x)
-                    //
-                    // Jii = yi - yi*yi
-                    // Jij = -yi*yj
-                    // J = diag(y)-y.*y
-                    // dx = J * dy
-                    // dxk = sum(Jkj * dyk)
-
-                    int64_t ne2[4] = {
-                        tensor->ne[0],
-                        1,
-                        tensor->ne[1]*tensor->ne[2],
-                        tensor->ne[3]
-                    };
-                    struct ggml_tensor * tensor2 = ggml_cont(ctx,
-                        ggml_reshape_4d(ctx,
-                            ggml_cont(ctx, tensor),
-                            ne2[0], ne2[1], ne2[2], ne2[3]));
-
-                    struct ggml_tensor * grad2 = ggml_cont(ctx,
-                        ggml_reshape_4d(ctx,
-                            ggml_cont(ctx, tensor->grad),
-                            ne2[0], ne2[1], ne2[2], ne2[3]));
-
-                    struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
-                        ggml_permute(ctx,                           // [1,ne0,ne1*ne2,ne3]
-                            tensor2,                                // [ne0,1,ne1*ne2,ne3]
-                            1, 0, 2, 3));
-
                     src0->grad =
-                        ggml_add_impl(ctx,
-                            src0->grad,
-                            ggml_reshape(ctx,
-                                ggml_mul_mat(ctx,        // [ne0,1,ne1*ne2,ne3]
-                                    ggml_sub(ctx,        // [ne0,ne0,ne1*ne2,ne3]
-                                        ggml_diag(ctx,   // [ne0,ne0,ne1*ne2,ne3]
-                                            tensor2),    // [ne0,1,ne1*ne2,ne3]
-                                        ggml_mul_mat(ctx,    // [ne0,ne0,ne1*ne2,ne3]
-                                            tensor2_t,       // [1,ne0,ne1*ne2,ne3]
-                                            tensor2_t)),     // [1,ne0,ne1*ne2,ne3]
-                                    grad2),                  // [ne0,1,ne1*ne2,ne3]
-                                src0->grad),
-                            inplace);
+                        ggml_add_impl(ctx, src0->grad,
+                            ggml_soft_max_back(ctx, tensor->grad, tensor),
+                            inplace);
                 }
+
+            } break;
+        case GGML_OP_SOFT_MAX_BACK:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_ROPE:
             {
@@ -13769,17 +15171,190 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_FLASH_ATTN:
             {
-                GGML_ASSERT(false); // not supported
+                struct ggml_tensor * flash_grad = NULL;
+                if (src0->grad || src1->grad || tensor->opt[0]->grad) {
+                    int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
+                    GGML_ASSERT(t == 0 || t == 1);
+                    bool masked = t != 0;
+                    flash_grad =
+                        ggml_flash_attn_back(ctx,
+                            src0,
+                            src1,
+                            tensor->opt[0],
+                            tensor->grad,
+                            masked);
+                }
+
+                if (src0->grad) {
+                    struct ggml_tensor * grad_q = NULL;
+                    const size_t nb0    = flash_grad->nb[0];
+                    const size_t offset = 0;
+                    switch(src0->n_dims) {
+                        case 2:
+                            {
+                                grad_q = ggml_view_2d(ctx,
+                                    flash_grad,
+                                    src0->ne[0],
+                                    src0->ne[1],
+                                    nb0*src0->ne[0],
+                                    offset);
+                            } break;
+                        case 3:
+                            {
+                                grad_q = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    src0->ne[0],
+                                    src0->ne[1],
+                                    src0->ne[2],
+                                    nb0*src0->ne[0],
+                                    nb0*src0->ne[0]*src0->ne[1],
+                                    offset);
+                            } break;
+                        case 4:
+                            {
+                                grad_q = ggml_view_4d(ctx,
+                                    flash_grad,
+                                    src0->ne[0],
+                                    src0->ne[1],
+                                    src0->ne[2],
+                                    src0->ne[3],
+                                    nb0*src0->ne[0],
+                                    nb0*src0->ne[0]*src0->ne[1],
+                                    nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
+                                    offset);
+                            } break;
+                    }
+
+                    src0->grad = ggml_add_impl(ctx,
+                            src0->grad,
+                            grad_q,
+                            inplace);
+                }
+
+                if (src1->grad) {
+                    struct ggml_tensor * grad_k = NULL;
+                    const size_t nb0    = flash_grad->nb[0];
+                    const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
+                    switch(src1->n_dims) {
+                        case 2:
+                            {
+                                grad_k = ggml_view_2d(ctx,
+                                    flash_grad,
+                                    src1->ne[0],
+                                    src1->ne[1],
+                                    nb0*src1->ne[0],
+                                    offset);
+                            } break;
+                        case 3:
+                            {
+                                grad_k = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    src1->ne[0],
+                                    src1->ne[1],
+                                    src1->ne[2],
+                                    nb0*src1->ne[0],
+                                    nb0*src1->ne[0]*src1->ne[1],
+                                    offset);
+                            } break;
+                        case 4:
+                            {
+                                grad_k = ggml_view_4d(ctx,
+                                    flash_grad,
+                                    src1->ne[0],
+                                    src1->ne[1],
+                                    src1->ne[2],
+                                    src1->ne[3],
+                                    nb0*src1->ne[0],
+                                    nb0*src1->ne[0]*src1->ne[1],
+                                    nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
+                                    offset);
+                            } break;
+                    }
+
+                    src1->grad = ggml_add_impl(ctx,
+                            src1->grad,
+                            grad_k,
+                            inplace);
+                }
+
+                struct ggml_tensor * opt0 = tensor->opt[0];
+
+                if (opt0->grad) {
+                    struct ggml_tensor * grad_v = NULL;
+                    const size_t nb0    = flash_grad->nb[0];
+                    const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
+                                        + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
+                    switch(opt0->n_dims) {
+                        case 2:
+                            {
+                                grad_v = ggml_view_2d(ctx,
+                                    flash_grad,
+                                    opt0->ne[0],
+                                    opt0->ne[1],
+                                    nb0*opt0->ne[0],
+                                    offset);
+                            } break;
+                        case 3:
+                            {
+                                grad_v = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    opt0->ne[0],
+                                    opt0->ne[1],
+                                    opt0->ne[2],
+                                    nb0*opt0->ne[0],
+                                    nb0*opt0->ne[0]*opt0->ne[1],
+                                    offset);
+                            } break;
+                        case 4:
+                            {
+                                grad_v = ggml_view_4d(ctx,
+                                    flash_grad,
+                                    opt0->ne[0],
+                                    opt0->ne[1],
+                                    opt0->ne[2],
+                                    opt0->ne[3],
+                                    nb0*opt0->ne[0],
+                                    nb0*opt0->ne[0]*opt0->ne[1],
+                                    nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
+                                    offset);
+                            } break;
+                    }
+
+                    opt0->grad = ggml_add_impl(ctx,
+                            opt0->grad,
+                            grad_v,
+                            inplace);
+                }
             } break;
         case GGML_OP_FLASH_FF:
             {
                 GGML_ASSERT(false); // not supported
             } break;
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
             {
                 GGML_ASSERT(false); // not supported
             } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            {
+                if (src0->grad) {
+                    src0->grad = ggml_add_impl(ctx,
+                                src0->grad,
+                                ggml_cross_entropy_loss_back(ctx,
+                                    src0,
+                                    src1,
+                                    tensor->grad),
+                                inplace);
+                }
+            } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
         case GGML_OP_NONE:
             {
                 // nop
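ggml_flash_attn_back returns all three input gradients packed into one contiguous tensor, and the views above slice it apart: grad_q starts at offset 0, grad_k after q's elements, grad_v after k's. A sketch of the offsets (illustrative helper mirroring the view offsets above):

#include <stddef.h>
#include <stdint.h>

// Offsets of the concatenated [grad_q | grad_k | grad_v] regions inside dst,
// for head dim D, N query rows, M key rows, and ne2*ne3 heads/batches.
void flash_grad_offsets(int64_t D, int64_t N, int64_t M,
                        int64_t ne2, int64_t ne3,
                        size_t * off_q, size_t * off_k, size_t * off_v) {
    const size_t nb0 = sizeof(float);
    *off_q = 0;                            // grad of q: D*N elements per head
    *off_k = nb0*D*N*ne2*ne3;              // grad of k follows all of grad_q
    *off_v = *off_k + nb0*D*M*ne2*ne3;     // grad of v follows all of grad_k
}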
@@ -14156,6 +15731,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
             case GGML_OP_SUM_ROWS:
             case GGML_OP_MEAN:
             case GGML_OP_REPEAT:
+            case GGML_OP_REPEAT_BACK:
             case GGML_OP_ABS:
             case GGML_OP_SGN:
             case GGML_OP_NEG:
@@ -14175,6 +15751,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     node->n_tasks = n_threads;
                 } break;
             case GGML_OP_MUL_MAT:
+            case GGML_OP_OUT_PROD:
                 {
                     node->n_tasks = n_threads;
 
@@ -14191,7 +15768,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
                         node->n_tasks = 1; // TODO: this actually is doing nothing
                                            //       the threads are still spinning
-                        cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
                     }
                     else
 #elif defined(GGML_USE_CLBLAST)
@@ -14258,6 +15834,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 } break;
             case GGML_OP_DIAG_MASK_INF:
             case GGML_OP_SOFT_MAX:
+            case GGML_OP_SOFT_MAX_BACK:
             case GGML_OP_ROPE:
             case GGML_OP_ROPE_BACK:
                 {
@@ -14337,6 +15914,27 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
                     }
 
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_FLASH_ATTN_BACK:
+                {
+                    node->n_tasks = n_threads;
+
+                    size_t cur = 0;
+
+                    const int64_t    D = node->src0->ne[0];
+                    const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+                    const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
+                    if (node->src1->type == GGML_TYPE_F32) {
+                        cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
+                    }
+
+                    if (node->src1->type == GGML_TYPE_F16) {
+                        cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
+                        cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
+                    }
+
                     work_size = MAX(work_size, cur);
                 } break;
             case GGML_OP_MAP_UNARY:
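The FLASH_ATTN_BACK work size reserves two float rows (S and SM) of MAX(D, rounded-up ne11) entries per task, counted twice by the stated overestimate. Plugging in sample numbers (illustrative, with assumed sizes):

#include <stdint.h>
#include <stdio.h>

static int64_t round_up(int64_t x, int64_t m) { return ((x + m - 1)/m)*m; }

int main(void) {
    const int64_t D = 128, ne11 = 512, unroll = 4, n_tasks = 8; // assumed example sizes
    const int64_t up   = round_up(ne11, unroll);
    const int64_t mxDn = (D > up ? D : up) * 2;                 // *2 for S and SM
    const int64_t bytes = 2*(int64_t)sizeof(float)*mxDn*n_tasks; // the x2 overestimate
    printf("%lld bytes\n", (long long) bytes);                  // 2*4*1024*8 = 65536
    return 0;
}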
@@ -14344,6 +15942,22 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 {
                     node->n_tasks = 1;
                 } break;
+            case GGML_OP_CROSS_ENTROPY_LOSS:
+                {
+                    node->n_tasks = n_threads;
+
+                    size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
+
+                    work_size = MAX(work_size, cur);
+                } break;
+            case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+                {
+                    node->n_tasks = n_threads;
+
+                    size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
+
+                    work_size = MAX(work_size, cur);
+                } break;
             case GGML_OP_NONE:
                 {
                     node->n_tasks = 1;
@@ -14581,7 +16195,7 @@ static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fou
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;
 
-    fprintf(fout, "%-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %16p %32s\n",
+    fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
             ggml_type_name(tensor->type),
             ggml_op_name  (tensor->op),
             tensor->n_dims,
@@ -14595,7 +16209,7 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;
 
-    fprintf(fout, "%-6s %-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %8d %16p %32s\n",
+    fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n",
            arg,
            ggml_type_name(tensor->type),
            ggml_op_name  (tensor->op),
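The two fprintf changes swap hard-coded length modifiers for the <inttypes.h> macros, since the printf modifier for int64_t differs across platforms (notably MSVC/MinGW vs. glibc). In miniature:

#include <inttypes.h>
#include <stdio.h>

int main(void) {
    const int64_t ne = 4096;
    // PRId64 expands to "lld", "ld", or "I64d" depending on the platform.
    printf("%8" PRId64 "\n", ne);
    return 0;
}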
@@ -14608,8 +16222,8 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
 }
 
 void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
-    assert(cgraph->work == NULL);
-    assert(cgraph->work_size == 0);
+    //assert(cgraph->work == NULL);
+    //assert(cgraph->work_size == 0);
 
     uint64_t size_eval = 0;
 
@@ -14624,11 +16238,11 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
         FILE * fout = stdout;
 
         fprintf(fout, "\n");
-        fprintf(fout, "%-16s %8x\n",
-        fprintf(fout, "%-16s %8d\n",
-        fprintf(fout, "%-16s %8d\n",
-        fprintf(fout, "%-16s %8d\n",
-        fprintf(fout, "%-16s %
+        fprintf(fout, "%-16s %8x\n", "magic", GGML_FILE_MAGIC);
+        fprintf(fout, "%-16s %8d\n", "version", GGML_FILE_VERSION);
+        fprintf(fout, "%-16s %8d\n", "leafs", cgraph->n_leafs);
+        fprintf(fout, "%-16s %8d\n", "nodes", cgraph->n_nodes);
+        fprintf(fout, "%-16s %" PRIu64 "\n", "eval", size_eval);
 
         // header
         fprintf(fout, "\n");
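Note: with their arguments restored, these five fprintf calls emit an aligned key/value preamble. Assuming GGML_FILE_MAGIC is the fourcc 0x67676d6c ("ggml") and GGML_FILE_VERSION is 1, the output for a hypothetical graph with two leafs and one node would look roughly like:

    magic            67676d6c
    version                 1
    leafs                   2
    nodes                   1
    eval                 1024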
@@ -14830,7 +16444,6 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
     // read file into data
     {
         FILE * fin = fopen(fname, "rb");
-
         if (!fin) {
             fprintf(stderr, "%s: failed to open %s\n", __func__, fname);
             return result;
@@ -14862,7 +16475,11 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
 
         data = ggml_new_tensor_1d(*ctx_data, GGML_TYPE_I8, fsize);
 
-        fread(data->data, sizeof(char), fsize, fin);
+        const size_t ret = fread(data->data, sizeof(char), fsize, fin);
+        if (ret != fsize) {
+            fprintf(stderr, "%s: failed to read %s\n", __func__, fname);
+            return result;
+        }
 
         fclose(fin);
     }
@@ -14970,6 +16587,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
             op     = *(const uint32_t *) ptr; ptr += sizeof(op);
             n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims);
 
+            enum ggml_op eop = (enum ggml_op) op;
+
             int64_t ne[GGML_MAX_DIMS];
             size_t  nb[GGML_MAX_DIMS];
 
@@ -14984,42 +16603,77 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
                 nb[j] = nb_cur;
             }
 
-
-
-            tensor->op = (enum ggml_op) op;
+            uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur); // TODO: not yet used
 
-
+            const char * ptr_name = ptr; ptr += GGML_MAX_NAME;
 
-
+            const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += (2 + GGML_MAX_OPT)*sizeof(int32_t);
 
-
-                tensor->nb[j] = nb[j];
-            }
+            struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };
 
             // parse args
-            {
-
-                &tensor->src0,
-                &tensor->src1,
-            };
+            for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
+                const int32_t arg_idx = ptr_arg_idx[j];
 
-
-
+                if (arg_idx == -1) {
+                    continue;
                 }
 
-
-
+                if (arg_idx < GGML_MAX_NODES) {
+                    args[j] = result.leafs[arg_idx];
+                } else {
+                    args[j] = result.nodes[arg_idx - GGML_MAX_NODES];
+                }
+            }
 
-
-
-
+            // create the tensor
+            // "view" operations are handled differently
+            // TODO: handle inplace ops - currently a copy is always made
+
+            struct ggml_tensor * tensor = NULL;
+
+            switch (eop) {
+                // TODO: implement other view ops
+                case GGML_OP_RESHAPE:
+                    {
+                        tensor = ggml_reshape_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3]);
+                    } break;
+                case GGML_OP_VIEW:
+                    {
+                        tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
+
+                        uint64_t offs;
+                        memcpy(&offs, args[2]->data, sizeof(offs));
+
+                        tensor->data = ((char *) tensor->data) + offs;
+                    } break;
+                case GGML_OP_TRANSPOSE:
+                    {
+                        tensor = ggml_transpose(*ctx_eval, args[0]);
+                    } break;
+                case GGML_OP_PERMUTE:
+                    {
+                        tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
+                    } break;
+                default:
+                    {
+                        tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne);
+
+                        tensor->op = eop;
+                    } break;
+            }
 
-
-
-
-
-
-
+            memcpy(tensor->name, ptr_name, GGML_MAX_NAME);
+
+            for (int j = 0; j < GGML_MAX_DIMS; ++j) {
+                tensor->nb[j] = nb[j];
+            }
+
+            tensor->src0 = args[0];
+            tensor->src1 = args[1];
+
+            for (int j = 0; j < GGML_MAX_OPT; ++j) {
+                tensor->opt[j] = args[2 + j];
             }
 
             result.nodes[i] = tensor;
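Note: the import loop above resolves serialized tensor arguments through a single int32 index: -1 means no argument, values below GGML_MAX_NODES index into result.leafs, and anything else indexes result.nodes after subtracting GGML_MAX_NODES. A minimal sketch of the same convention from the writer's side (a hypothetical encode helper, not part of ggml):

    // encode a graph argument the way ggml_graph_import expects to decode it:
    // leafs occupy [0, GGML_MAX_NODES), nodes are shifted up by GGML_MAX_NODES
    static int32_t encode_arg(int is_leaf, int32_t idx) {
        if (idx < 0) {
            return -1; // "no argument" sentinel, skipped by the decode loop
        }
        return is_leaf ? idx : GGML_MAX_NODES + idx;
    }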
@@ -15279,6 +16933,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
 
 static enum ggml_opt_result ggml_opt_adam(
         struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
         struct ggml_opt_params params,
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
@@ -15304,25 +16959,29 @@ static enum ggml_opt_result ggml_opt_adam(
         }
     }
 
+    if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) {
+        int iter = opt->iter;
+        ggml_opt_init(opt->ctx, opt, params, nx);
+        opt->iter = iter;
+    }
+
     // constants
-    const float
+    const float sched = params.adam.sched;
+    const float decay = params.adam.decay * sched;
+    const float alpha = params.adam.alpha * sched;
     const float beta1 = params.adam.beta1;
     const float beta2 = params.adam.beta2;
     const float eps   = params.adam.eps;
 
-    float * x  =
-    float * g1 =
-    float * g2 =
-    float * m  =
-    float * v  =
-    float * mh =
-    float * vh =
-
-    float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
+    float * x  = opt->adam.x->data;  // view of the parameters
+    float * g1 = opt->adam.g1->data; // gradient
+    float * g2 = opt->adam.g2->data; // gradient squared
+    float * m  = opt->adam.m->data;  // first moment
+    float * v  = opt->adam.v->data;  // second moment
+    float * mh = opt->adam.mh->data; // first moment hat
+    float * vh = opt->adam.vh->data; // second moment hat
 
-    //
-    ggml_vec_set_f32(nx, m, 0.0f);
-    ggml_vec_set_f32(nx, v, 0.0f);
+    float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
 
     // update view
     ggml_opt_get_params(np, ps, x);
@@ -15332,16 +16991,27 @@ static enum ggml_opt_result ggml_opt_adam(
     ggml_set_f32 (f->grad, 1.0f);
     ggml_graph_compute(ctx, gb);
 
-
+    opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
+    opt->adam.fx_best = opt->adam.fx_prev;
     if (pf) {
-        pf[
+        pf[opt->iter % params.past] = opt->adam.fx_prev;
+    }
+
+    // initialize
+    if (opt->just_initialized) {
+        opt->adam.n_no_improvement = 0;
+        opt->just_initialized = false;
     }
 
-
-    float
+    float * fx_best = &opt->adam.fx_best;
+    float * fx_prev = &opt->adam.fx_prev;
+    int * n_no_improvement = &opt->adam.n_no_improvement;
+
+    int iter0 = opt->iter;
 
     // run the optimizer
     for (int t = 0; t < params.adam.n_iter; ++t) {
+        opt->iter = iter0 + t + 1;
         GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
 
         GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0));
@@ -15375,17 +17045,22 @@ static enum ggml_opt_result ggml_opt_adam(
 
         // m^hat = m_t / (1 - beta1^t)
         // v^hat = v_t / (1 - beta2^t)
-        // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
+        // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
+        // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
+        // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
+        // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
+        // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
         ggml_vec_cpy_f32  (nx, mh, m);
         ggml_vec_cpy_f32  (nx, vh, v);
 
-        ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1,
-        ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2,
+        ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter)));
+        ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, opt->iter)));
 
         ggml_vec_sqrt_f32 (nx, vh, vh);
         ggml_vec_acc1_f32 (nx, vh, eps);
 
         ggml_vec_div_f32  (nx, mh, mh, vh);
+        ggml_vec_scale_f32(nx, x, 1.0f - decay);
         ggml_vec_sub_f32  (nx, x, x, mh);
 
         // update the parameters
@@ -15399,7 +17074,7 @@ static enum ggml_opt_result ggml_opt_adam(
         const float fx = ggml_get_f32_1d(f, 0);
 
         // check convergence
-        if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) {
+        if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
             GGML_PRINT_DEBUG("converged\n");
 
             return GGML_OPT_OK;
@@ -15408,32 +17083,32 @@ static enum ggml_opt_result ggml_opt_adam(
         // delta-based convergence test
         if (pf != NULL) {
             // need at least params.past iterations to start checking for convergence
-            if (params.past <= t) {
-                const float rate = (pf[t%params.past] - fx)/fx;
+            if (params.past <= iter0 + t) {
+                const float rate = (pf[(iter0 + t)%params.past] - fx)/fx;
 
                 if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }
 
-            pf[t%params.past] = fx;
+            pf[(iter0 + t)%params.past] = fx;
         }
 
         // check for improvement
         if (params.max_no_improvement > 0) {
-            if (fx_best > fx) {
-                fx_best = fx;
-                n_no_improvement = 0;
+            if (fx_best[0] > fx) {
+                fx_best[0] = fx;
+                n_no_improvement[0] = 0;
             } else {
-                ++n_no_improvement;
+                ++n_no_improvement[0];
 
-                if (n_no_improvement >= params.max_no_improvement) {
+                if (n_no_improvement[0] >= params.max_no_improvement) {
                     return GGML_OPT_OK;
                 }
             }
         }
 
-        fx_prev = fx;
+        fx_prev[0] = fx;
 
         {
             const int64_t t_end_cpu = ggml_cycles();
@@ -15572,6 +17247,7 @@ static enum ggml_opt_result linesearch_backtracking(
 
 static enum ggml_opt_result ggml_opt_lbfgs(
         struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
         struct ggml_opt_params params,
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
@@ -15604,31 +17280,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         }
     }
 
-
-
-
-
-
+    if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) {
+        int iter = opt->iter;
+        ggml_opt_init(ctx, opt, params, nx);
+        opt->iter = iter;
+    }
+
+    float * x  = opt->lbfgs.x->data;  // current parameters
+    float * xp = opt->lbfgs.xp->data; // previous parameters
+    float * g  = opt->lbfgs.g->data;  // current gradient
+    float * gp = opt->lbfgs.gp->data; // previous gradient
+    float * d  = opt->lbfgs.d->data;  // search direction
 
-    float * pf = params.past > 0 ?
+    float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
 
     float fx    = 0.0f; // cost function value
     float xnorm = 0.0f; // ||x||
     float gnorm = 0.0f; // ||g||
-    float step  = 0.0f;
 
     // initialize x from the graph nodes
     ggml_opt_get_params(np, ps, x);
 
     // the L-BFGS memory
-
-
-
-
-        lm[i].ys = 0.0f;
-        lm[i].s  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
-        lm[i].y  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
-    }
+    float * lm_alpha = opt->lbfgs.lmal->data;
+    float * lm_ys    = opt->lbfgs.lmys->data;
+    float * lm_s     = opt->lbfgs.lms->data;
+    float * lm_y     = opt->lbfgs.lmy->data;
 
     // evaluate the function value and its gradient
     {
@@ -15643,12 +17320,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         fx = ggml_get_f32_1d(f, 0);
     }
 
-    if (pf) {
-        pf[0] = fx;
-    }
-
-    float fx_best = fx;
-
     // search direction = -gradient
     ggml_vec_neg_f32(nx, d, g);
 
@@ -15665,26 +17336,43 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         return GGML_OPT_OK;
     }
 
-
-
+    if (opt->just_initialized) {
+        if (pf) {
+            pf[0] = fx;
+        }
+        opt->lbfgs.fx_best = fx;
+
+        // initial step
+        ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d);
+        opt->lbfgs.j = 0;
+        opt->lbfgs.k = 1;
+        opt->lbfgs.end = 0;
+        opt->lbfgs.n_no_improvement = 0;
+        opt->just_initialized = false;
+    }
+
+    float * fx_best = &opt->lbfgs.fx_best;
+    float * step = &opt->lbfgs.step;
+    int * j = &opt->lbfgs.j;
+    int * k = &opt->lbfgs.k;
+    int * end = &opt->lbfgs.end;
+    int * n_no_improvement = &opt->lbfgs.n_no_improvement;
 
-    int
-    int
-    int ls = 0;
-    int end = 0;
-    int bound = 0;
-    int n_no_improvement = 0;
+    int ls = 0;
+    int bound = 0;
 
     float ys   = 0.0f;
     float yy   = 0.0f;
     float beta = 0.0f;
 
+    int it = 0;
+
     while (true) {
         // store the current position and gradient vectors
         ggml_vec_cpy_f32(nx, xp, x);
         ggml_vec_cpy_f32(nx, gp, g);
 
-        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d,
+        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
 
         if (ls < 0) {
             // linesearch failed - go back to the previous point and return
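Note: on the first call, the initial line-search step is set to 1/||d|| via ggml_vec_norm_inv_f32, so the first trial point moves a unit Euclidean distance regardless of the gradient's scale. A sketch of what that helper computes (ggml's version is a static inline over floats):

    #include <math.h>

    // 1/||v||_2 over n elements, accumulated in double;
    // a sketch of the value ggml_vec_norm_inv_f32 stores into step
    static float norm_inv(int n, const float * v) {
        double sum = 0.0;
        for (int i = 0; i < n; ++i) {
            sum += (double) v[i] * v[i];
        }
        return (float) (1.0 / sqrt(sum));
    }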
@@ -15710,32 +17398,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         // delta-based convergence test
         if (pf != NULL) {
             // need at least params.past iterations to start checking for convergence
-            if (params.past <= k) {
-                const float rate = (pf[k%params.past] - fx)/fx;
+            if (params.past <= k[0]) {
+                const float rate = (pf[k[0]%params.past] - fx)/fx;
 
                 if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }
 
-            pf[k%params.past] = fx;
+            pf[k[0]%params.past] = fx;
         }
 
         // check for improvement
         if (params.max_no_improvement > 0) {
-            if (fx < fx_best) {
-                fx_best = fx;
-                n_no_improvement = 0;
+            if (fx < fx_best[0]) {
+                fx_best[0] = fx;
+                n_no_improvement[0] = 0;
             } else {
-                n_no_improvement++;
+                n_no_improvement[0]++;
 
-                if (n_no_improvement >= params.max_no_improvement) {
+                if (n_no_improvement[0] >= params.max_no_improvement) {
                     return GGML_OPT_OK;
                 }
             }
         }
 
-        if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter <
+        if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) {
             // reached the maximum number of iterations
             return GGML_OPT_DID_NOT_CONVERGE;
         }
@@ -15744,50 +17432,51 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
         // y_{k+1} = g_{k+1} - g_{k}.
         //
-        ggml_vec_sub_f32(nx,
-        ggml_vec_sub_f32(nx,
+        ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp);
+        ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp);
 
         // compute scalars ys and yy:
         //     ys = y^t \cdot s -> 1 / \rho.
         //     yy = y^t \cdot y.
         //
-        ggml_vec_dot_f32(nx, &ys,
-        ggml_vec_dot_f32(nx, &yy,
+        ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]);
+        ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
 
-
+        lm_ys[end[0]] = ys;
 
         // find new search direction
         //   ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS
 
-        bound = (m <= k) ? m : k;
-        k++;
-
+        bound = (m <= k[0]) ? m : k[0];
+        k[0]++;
+        it++;
+        end[0] = (end[0] + 1)%m;
 
         // initialize search direction with -g
         ggml_vec_neg_f32(nx, d, g);
 
-        j = end;
+        j[0] = end[0];
         for (int i = 0; i < bound; ++i) {
-            j = (j + m - 1) % m;
+            j[0] = (j[0] + m - 1) % m;
             // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
-            ggml_vec_dot_f32(nx, &
-
+            ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d);
+            lm_alpha[j[0]] /= lm_ys[j[0]];
             // q_{i} = q_{i+1} - \alpha_{i} y_{i}
-            ggml_vec_mad_f32(nx, d,
+            ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
         }
 
         ggml_vec_scale_f32(nx, d, ys/yy);
 
         for (int i = 0; i < bound; ++i) {
             // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
-            ggml_vec_dot_f32(nx, &beta,
-            beta /=
+            ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d);
+            beta /= lm_ys[j[0]];
             // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
-            ggml_vec_mad_f32(nx, d,
-            j = (j + 1)%m;
+            ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
+            j[0] = (j[0] + 1)%m;
         }
 
-        step = 1.0;
+        step[0] = 1.0;
     }
 
     return GGML_OPT_DID_NOT_CONVERGE;
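Note: with the history moved into ggml_opt_context, the L-BFGS memory is four flat tensors instead of an array of structs: ring-buffer entry i keeps its s and y vectors at offset i*nx inside lms/lmy, with lm_alpha[i] and lm_ys[i] as the matching scalars. A hypothetical accessor makes the indexing explicit:

    // view of ring-buffer entry i inside a flat (nx * m) history tensor,
    // e.g. lm_entry(lm_s, end, nx) is the slot written by the update above
    static float * lm_entry(float * base, int i, int64_t nx) {
        return base + (int64_t) i * nx;
    }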
@@ -15812,6 +17501,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
 
             .adam = {
                 .n_iter = 10000,
+                .sched  = 1.000f,
+                .decay  = 0.001f,
                 .alpha  = 0.001f,
                 .beta1  = 0.9f,
                 .beta2  = 0.999f,
@@ -15854,6 +17545,71 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
     return result;
 }
 
+GGML_API void ggml_opt_init(
+        struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
+        struct ggml_opt_params params,
+        int64_t nx) {
+    opt->ctx = ctx;
+    opt->params = params;
+    opt->iter = 0;
+    opt->nx = nx;
+    opt->just_initialized = true;
+    switch (opt->params.type) {
+        case GGML_OPT_ADAM:
+            {
+                opt->adam.x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.m  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.v  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.pf = params.past > 0
+                    ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+                    : NULL;
+                ggml_set_zero(opt->adam.x);
+                ggml_set_zero(opt->adam.g1);
+                ggml_set_zero(opt->adam.g2);
+                ggml_set_zero(opt->adam.m);
+                ggml_set_zero(opt->adam.v);
+                ggml_set_zero(opt->adam.mh);
+                ggml_set_zero(opt->adam.vh);
+                if (opt->adam.pf) {
+                    ggml_set_zero(opt->adam.pf);
+                }
+            } break;
+        case GGML_OPT_LBFGS:
+            {
+                opt->lbfgs.x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.g  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.d  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.pf = params.past > 0
+                    ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+                    : NULL;
+                opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
+                opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
+                opt->lbfgs.lms  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+                opt->lbfgs.lmy  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+                ggml_set_zero(opt->lbfgs.x);
+                ggml_set_zero(opt->lbfgs.xp);
+                ggml_set_zero(opt->lbfgs.g);
+                ggml_set_zero(opt->lbfgs.gp);
+                ggml_set_zero(opt->lbfgs.d);
+                ggml_set_zero(opt->lbfgs.pf);
+                if (opt->lbfgs.pf) {
+                    ggml_set_zero(opt->lbfgs.pf);
+                }
+                ggml_set_zero(opt->lbfgs.lmal);
+                ggml_set_zero(opt->lbfgs.lmys);
+                ggml_set_zero(opt->lbfgs.lms);
+                ggml_set_zero(opt->lbfgs.lmy);
+            } break;
+    }
+}
+
 enum ggml_opt_result ggml_opt(
         struct ggml_context * ctx,
         struct ggml_opt_params params,
@@ -15876,33 +17632,65 @@ enum ggml_opt_result ggml_opt(
 
     enum ggml_opt_result result = GGML_OPT_OK;
 
+    struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
+
+    ggml_opt_init(ctx, opt, params, 0);
+    result = ggml_opt_resume(ctx, opt, f);
+
+    if (free_ctx) {
+        ggml_free(ctx);
+    }
+
+    return result;
+}
+
+enum ggml_opt_result ggml_opt_resume(
+        struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
+        struct ggml_tensor * f) {
+
+    // build forward + backward compute graphs
+    struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
+    struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
+
+    struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
+    struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
+
+    *gf = ggml_build_forward (f);
+    *gb = ggml_build_backward(ctx, gf, true);
+
+    return ggml_opt_resume_g(ctx, opt, f, gf, gb);
+}
+
+enum ggml_opt_result ggml_opt_resume_g(
+        struct ggml_context * ctx,
+        struct ggml_opt_context * opt,
+        struct ggml_tensor * f,
+        struct ggml_cgraph * gf,
+        struct ggml_cgraph * gb) {
+
     // build forward + backward compute graphs
-
-    struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, true);
+    enum ggml_opt_result result = GGML_OPT_OK;
 
-    switch (params.type) {
+    switch (opt->params.type) {
         case GGML_OPT_ADAM:
             {
-                result = ggml_opt_adam(ctx, params, f,
+                result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
             } break;
         case GGML_OPT_LBFGS:
             {
-                result = ggml_opt_lbfgs(ctx, params, f,
+                result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb);
             } break;
     }
 
-    if (params.print_forward_graph) {
-        ggml_graph_print   (
-        ggml_graph_dump_dot(
-    }
-
-    if (params.print_backward_graph) {
-        ggml_graph_print   (&gb);
-        ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot");
+    if (opt->params.print_forward_graph) {
+        ggml_graph_print   (gf);
+        ggml_graph_dump_dot(gf, NULL, "opt-forward.dot");
     }
 
-    if (
-
+    if (opt->params.print_backward_graph) {
+        ggml_graph_print   (gb);
+        ggml_graph_dump_dot(gb, gf, "opt-backward.dot");
     }
 
     return result;
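Note: the refactor splits one-shot optimization from resumable optimization. ggml_opt still works as before by stack-allocating a context, while callers that want warm-started Adam/L-BFGS state can hold a ggml_opt_context themselves. A hedged usage sketch; f, ctx, nx and the epoch loop are assumptions, only the three API calls come from the code above:

    struct ggml_opt_context opt;
    struct ggml_opt_params  params = ggml_opt_default_params(GGML_OPT_ADAM);

    // nx: total number of trainable parameter elements in the graph of f
    ggml_opt_init(ctx, &opt, params, nx);

    for (int epoch = 0; epoch < n_epochs; ++epoch) {
        // refresh inputs/targets here; moments, iteration count and past
        // function values carry over from the previous call
        enum ggml_opt_result res = ggml_opt_resume(ctx, &opt, f);
        if (res != GGML_OPT_OK) {
            break;
        }
    }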
@@ -16070,6 +17858,50 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
                 result = ggml_quantize_q8_0(src + start, block, n, n, hist);
             } break;
+#ifdef GGML_USE_K_QUANTS
+        case GGML_TYPE_Q2_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q2_K * block = (block_q2_K*)dst + start / QK_K;
+                result = ggml_quantize_q2_K(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q3_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q3_K * block = (block_q3_K*)dst + start / QK_K;
+                result = ggml_quantize_q3_K(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q4_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q4_K * block = (block_q4_K*)dst + start / QK_K;
+                result = ggml_quantize_q4_K(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q5_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q5_K * block = (block_q5_K*)dst + start / QK_K;
+                result = ggml_quantize_q5_K(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q6_K:
+            {
+                GGML_ASSERT(start % QK_K == 0);
+                block_q6_K * block = (block_q6_K*)dst + start / QK_K;
+                result = ggml_quantize_q6_K(src + start, block, n, n, hist);
+            } break;
+#endif
+        case GGML_TYPE_F16:
+            {
+                int elemsize = sizeof(ggml_fp16_t);
+                ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
+                result = n * elemsize;
+            } break;
+        case GGML_TYPE_F32:
+            {
+                int elemsize = sizeof(float);
+                result = n * elemsize;
+                memcpy((uint8_t *)dst + start * elemsize, src + start, result);
+            } break;
         default:
             assert(false);
     }
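Note: for the new K-quant cases the start offset must land on a super-block boundary (the GGML_ASSERT above; QK_K is 256 in k_quants.h), and the return value is the number of bytes produced. A minimal caller sketch; the buffer sizes and the 256/768 split are illustrative assumptions:

    #include "ggml.h"
    #include <stdint.h>

    // quantize elements [256, 1024) of src to Q4_K, leaving the first super-block alone
    static size_t quantize_tail(const float * src /* >= 1024 elements */) {
        static uint8_t dst[4096]; // comfortably larger than 4 Q4_K super-blocks
        int64_t hist[16] = {0};   // histogram of quant values, filled by the quantizer
        return ggml_quantize_chunk(GGML_TYPE_Q4_K, src, dst, 256, 768, hist);
    }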