llama_cpp 0.1.4 → 0.2.1 — diff of the gem's bundled ggml.c

@@ -3,6 +3,10 @@
3
3
 
4
4
  #include "ggml.h"
5
5
 
6
+ #ifdef GGML_USE_K_QUANTS
7
+ #include "k_quants.h"
8
+ #endif
9
+
6
10
  #if defined(_MSC_VER) || defined(__MINGW32__)
7
11
  #include <malloc.h> // using malloc.h with MSC/MINGW
8
12
  #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -21,6 +25,10 @@
21
25
  #include <float.h>
22
26
  #include <limits.h>
23
27
 
28
+ #ifdef GGML_USE_METAL
29
+ #include <unistd.h>
30
+ #endif
31
+
24
32
  // if C99 - static_assert is noop
25
33
  // ref: https://stackoverflow.com/a/53923785/4039976
26
34
  #ifndef static_assert
@@ -121,7 +129,11 @@ typedef void* thread_ret_t;
121
129
  #else
122
130
  inline static void* ggml_aligned_malloc(size_t size) {
123
131
  void* aligned_memory = NULL;
132
+ #ifdef GGML_USE_METAL
133
+ int result = posix_memalign(&aligned_memory, getpagesize(), size);
134
+ #else
124
135
  int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
136
+ #endif
125
137
  if (result != 0) {
126
138
  // Handle allocation failure
127
139
  return NULL;
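The Metal branch aligns allocations to the OS page size instead of GGML_MEM_ALIGN, presumably because Metal shared buffers must start on page boundaries; the new <unistd.h> include above is there to provide getpagesize(). A minimal standalone sketch of that allocation path (illustrative only, not the exact ggml code):

#include <stdlib.h>   // posix_memalign
#include <unistd.h>   // getpagesize

// Allocate a buffer aligned to the OS page size, as the GGML_USE_METAL
// branch of ggml_aligned_malloc now does.
static void * alloc_page_aligned(size_t size) {
    void * ptr = NULL;
    const size_t align = (size_t) getpagesize(); // e.g. 4096, or 16384 on Apple Silicon
    if (posix_memalign(&ptr, align, size) != 0) {
        return NULL; // allocation failure, mirroring ggml_aligned_malloc
    }
    return ptr;
}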
@@ -403,21 +415,27 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) {
403
415
  //
404
416
 
405
417
  #if defined(_MSC_VER) || defined(__MINGW32__)
406
- static int64_t timer_freq;
418
+ static int64_t timer_freq, timer_start;
407
419
  void ggml_time_init(void) {
408
- LARGE_INTEGER frequency;
409
- QueryPerformanceFrequency(&frequency);
410
- timer_freq = frequency.QuadPart;
420
+ LARGE_INTEGER t;
421
+ QueryPerformanceFrequency(&t);
422
+ timer_freq = t.QuadPart;
423
+
424
+ // The multiplication by 1000 or 1000000 below can cause an overflow if timer_freq
425
+ // and the uptime is high enough.
426
+ // We subtract the program start time to reduce the likelihood of that happening.
427
+ QueryPerformanceCounter(&t);
428
+ timer_start = t.QuadPart;
411
429
  }
412
430
  int64_t ggml_time_ms(void) {
413
431
  LARGE_INTEGER t;
414
432
  QueryPerformanceCounter(&t);
415
- return (t.QuadPart * 1000) / timer_freq;
433
+ return ((t.QuadPart-timer_start) * 1000) / timer_freq;
416
434
  }
417
435
  int64_t ggml_time_us(void) {
418
436
  LARGE_INTEGER t;
419
437
  QueryPerformanceCounter(&t);
420
- return (t.QuadPart * 1000000) / timer_freq;
438
+ return ((t.QuadPart-timer_start) * 1000000) / timer_freq;
421
439
  }
422
440
  #else
423
441
  void ggml_time_init(void) {}
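The saved timer_start guards against overflow in the millisecond/microsecond conversions. A rough worked example, assuming the common 10 MHz QueryPerformanceCounter frequency (an assumption; the actual frequency is hardware dependent):

  INT64_MAX                ≈ 9.22e18
  counter ticks per second = 1e7
  t.QuadPart * 1000000     ≈ uptime_seconds * 1e13
  overflow threshold       ≈ 9.22e18 / 1e13 ≈ 9.2e5 s ≈ 10.7 days

Without the subtraction the microsecond path can overflow once the system has been up for roughly 10–11 days; subtracting the program start time means only the process runtime matters.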
@@ -474,6 +492,8 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
474
492
  // quantization
475
493
  //
476
494
 
495
+ #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
496
+
477
497
  #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
478
498
  // multiply int8_t, add results pairwise twice
479
499
  static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
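MM256_SET_M128I builds a 256-bit vector from two 128-bit halves via cast-plus-insert; the likely motivation (my assumption, the diff itself does not say) is that some otherwise AVX-capable compilers do not provide the _mm256_set_m128i intrinsic. A quick equivalence sketch:

#include <immintrin.h>

// Produces { lo in bits 0..127, hi in bits 128..255 }, i.e. the same value
// as _mm256_set_m128i(hi, lo) where that intrinsic is available.
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

static inline __m256i combine_hi_lo(__m128i hi, __m128i lo) {
    return MM256_SET_M128I(hi, lo); // requires AVX
}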
@@ -533,7 +553,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
533
553
  static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
534
554
  {
535
555
  const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
536
- const __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(tmp, 4), tmp);
556
+ const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
537
557
  const __m256i lowMask = _mm256_set1_epi8( 0xF );
538
558
  return _mm256_and_si256(lowMask, bytes);
539
559
  }
@@ -606,7 +626,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
606
626
  bytesh = _mm_or_si128(bytesh, bit_mask);
607
627
  bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
608
628
  bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
609
- return _mm256_set_m128i(bytesh, bytesl);
629
+ return MM256_SET_M128I(bytesh, bytesl);
610
630
  }
611
631
 
612
632
  // Unpack 32 4-bit fields into 32 bytes
@@ -619,7 +639,7 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
619
639
  const __m128i lowMask = _mm_set1_epi8(0xF);
620
640
  tmpl = _mm_and_si128(lowMask, tmpl);
621
641
  tmph = _mm_and_si128(lowMask, tmph);
622
- return _mm256_set_m128i(tmph, tmpl);
642
+ return MM256_SET_M128I(tmph, tmpl);
623
643
  }
624
644
 
625
645
  // add int16_t pairwise and return as float vector
@@ -627,7 +647,7 @@ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
627
647
  const __m128i ones = _mm_set1_epi16(1);
628
648
  const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
629
649
  const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
630
- const __m256i summed_pairs = _mm256_set_m128i(summed_pairsh, summed_pairsl);
650
+ const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
631
651
  return _mm256_cvtepi32_ps(summed_pairs);
632
652
  }
633
653
 
@@ -1565,6 +1585,48 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1565
1585
  .vec_dot_q = NULL, // TODO
1566
1586
  .vec_dot_type = GGML_TYPE_Q8_1,
1567
1587
  },
1588
+ #ifdef GGML_USE_K_QUANTS
1589
+ [GGML_TYPE_Q2_K] = {
1590
+ .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q2_K,
1591
+ .quantize_row_q = quantize_row_q2_K,
1592
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q2_K_reference,
1593
+ .quantize_row_q_dot = quantize_row_q8_K,
1594
+ .vec_dot_q = ggml_vec_dot_q2_K_q8_K,
1595
+ .vec_dot_type = GGML_TYPE_Q8_K,
1596
+ },
1597
+ [GGML_TYPE_Q3_K] = {
1598
+ .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q3_K,
1599
+ .quantize_row_q = quantize_row_q3_K,
1600
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q3_K_reference,
1601
+ .quantize_row_q_dot = quantize_row_q8_K,
1602
+ .vec_dot_q = ggml_vec_dot_q3_K_q8_K,
1603
+ .vec_dot_type = GGML_TYPE_Q8_K,
1604
+ },
1605
+ [GGML_TYPE_Q4_K] = {
1606
+ .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q4_K,
1607
+ .quantize_row_q = quantize_row_q4_K,
1608
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_K_reference,
1609
+ .quantize_row_q_dot = quantize_row_q8_K,
1610
+ .vec_dot_q = ggml_vec_dot_q4_K_q8_K,
1611
+ .vec_dot_type = GGML_TYPE_Q8_K,
1612
+ },
1613
+ [GGML_TYPE_Q5_K] = {
1614
+ .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q5_K,
1615
+ .quantize_row_q = quantize_row_q5_K,
1616
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_K_reference,
1617
+ .quantize_row_q_dot = quantize_row_q8_K,
1618
+ .vec_dot_q = ggml_vec_dot_q5_K_q8_K,
1619
+ .vec_dot_type = GGML_TYPE_Q8_K,
1620
+ },
1621
+ [GGML_TYPE_Q6_K] = {
1622
+ .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q6_K,
1623
+ .quantize_row_q = quantize_row_q6_K,
1624
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q6_K_reference,
1625
+ .quantize_row_q_dot = quantize_row_q8_K,
1626
+ .vec_dot_q = ggml_vec_dot_q6_K_q8_K,
1627
+ .vec_dot_type = GGML_TYPE_Q8_K,
1628
+ },
1629
+ #endif
1568
1630
  };
1569
1631
 
1570
1632
  // For internal test use
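The k-quant entries plug into the same quantize_fns dispatch table as the existing formats, so the generic add/mul_mat/get_rows paths pick them up with no per-type code. A hedged sketch of dequantizing a row through the table, using the internal test hook mentioned in the comment above (treat the exact signature as an assumption):

#include "ggml.h"   // quantize_fns_t, ggml_internal_get_quantize_fn (internal test hook)

// Sketch only: dequantize one row of quantized data via the dispatch table.
// k is the number of elements and is assumed to be a multiple of the block size.
static void dequantize_row_generic(enum ggml_type type, const void * src, float * dst, int k) {
    quantize_fns_t fns = ggml_internal_get_quantize_fn(type);
    GGML_ASSERT(fns.dequantize_row_q != NULL);
    fns.dequantize_row_q(src, dst, k);   // e.g. dequantize_row_q4_K for GGML_TYPE_Q4_K
}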
@@ -2290,7 +2352,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2290
2352
  const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
2291
2353
 
2292
2354
  // Convert int32_t to float
2293
- __m256 p = _mm256_cvtepi32_ps(_mm256_set_m128i(i32_0, i32_1));
2355
+ __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
2294
2356
 
2295
2357
  // Apply the scale, and accumulate
2296
2358
  acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
@@ -2766,7 +2828,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
2766
2828
  __m128i bxh = _mm256_extractf128_si256(bx, 1);
2767
2829
  bxl = _mm_or_si128(bxl, bxhil);
2768
2830
  bxh = _mm_or_si128(bxh, bxhih);
2769
- bx = _mm256_set_m128i(bxh, bxl);
2831
+ bx = MM256_SET_M128I(bxh, bxl);
2770
2832
 
2771
2833
  const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2772
2834
 
@@ -3022,7 +3084,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
3022
3084
  __m128i bxh = _mm256_extractf128_si256(bx, 1);
3023
3085
  bxl = _mm_or_si128(bxl, bxhil);
3024
3086
  bxh = _mm_or_si128(bxh, bxhih);
3025
- bx = _mm256_set_m128i(bxh, bxl);
3087
+ bx = MM256_SET_M128I(bxh, bxl);
3026
3088
 
3027
3089
  const __m256 dy = _mm256_set1_ps(y[i].d);
3028
3090
  const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
@@ -3444,11 +3506,19 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
3444
3506
  [GGML_TYPE_Q5_1] = QK5_1,
3445
3507
  [GGML_TYPE_Q8_0] = QK8_0,
3446
3508
  [GGML_TYPE_Q8_1] = QK8_1,
3509
+ #ifdef GGML_USE_K_QUANTS
3510
+ [GGML_TYPE_Q2_K] = QK_K,
3511
+ [GGML_TYPE_Q3_K] = QK_K,
3512
+ [GGML_TYPE_Q4_K] = QK_K,
3513
+ [GGML_TYPE_Q5_K] = QK_K,
3514
+ [GGML_TYPE_Q6_K] = QK_K,
3515
+ [GGML_TYPE_Q8_K] = QK_K,
3516
+ #endif
3447
3517
  [GGML_TYPE_I8] = 1,
3448
3518
  [GGML_TYPE_I16] = 1,
3449
3519
  [GGML_TYPE_I32] = 1,
3450
3520
  };
3451
- static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated");
3521
+ static_assert(GGML_TYPE_COUNT == 19, "GGML_BLCK_SIZE is outdated");
3452
3522
 
3453
3523
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
3454
3524
  [GGML_TYPE_F32] = sizeof(float),
@@ -3459,11 +3529,19 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
3459
3529
  [GGML_TYPE_Q5_1] = sizeof(block_q5_1),
3460
3530
  [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
3461
3531
  [GGML_TYPE_Q8_1] = sizeof(block_q8_1),
3532
+ #ifdef GGML_USE_K_QUANTS
3533
+ [GGML_TYPE_Q2_K] = sizeof(block_q2_K),
3534
+ [GGML_TYPE_Q3_K] = sizeof(block_q3_K),
3535
+ [GGML_TYPE_Q4_K] = sizeof(block_q4_K),
3536
+ [GGML_TYPE_Q5_K] = sizeof(block_q5_K),
3537
+ [GGML_TYPE_Q6_K] = sizeof(block_q6_K),
3538
+ [GGML_TYPE_Q8_K] = sizeof(block_q8_K),
3539
+ #endif
3462
3540
  [GGML_TYPE_I8] = sizeof(int8_t),
3463
3541
  [GGML_TYPE_I16] = sizeof(int16_t),
3464
3542
  [GGML_TYPE_I32] = sizeof(int32_t),
3465
3543
  };
3466
- static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated");
3544
+ static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_SIZE is outdated");
3467
3545
 
3468
3546
 
3469
3547
  static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3475,11 +3553,17 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
3475
3553
  [GGML_TYPE_Q5_1] = "q5_1",
3476
3554
  [GGML_TYPE_Q8_0] = "q8_0",
3477
3555
  [GGML_TYPE_Q8_1] = "q8_1",
3556
+ [GGML_TYPE_Q2_K] = "q2_K",
3557
+ [GGML_TYPE_Q3_K] = "q3_K",
3558
+ [GGML_TYPE_Q4_K] = "q4_K",
3559
+ [GGML_TYPE_Q5_K] = "q5_K",
3560
+ [GGML_TYPE_Q6_K] = "q6_K",
3561
+ [GGML_TYPE_Q8_K] = "q8_K",
3478
3562
  [GGML_TYPE_I8] = "i8",
3479
3563
  [GGML_TYPE_I16] = "i16",
3480
3564
  [GGML_TYPE_I32] = "i32",
3481
3565
  };
3482
- static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated");
3566
+ static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_NAME is outdated");
3483
3567
 
3484
3568
  static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
3485
3569
  [GGML_TYPE_F32] = false,
@@ -3490,11 +3574,17 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
3490
3574
  [GGML_TYPE_Q5_1] = true,
3491
3575
  [GGML_TYPE_Q8_0] = true,
3492
3576
  [GGML_TYPE_Q8_1] = true,
3577
+ [GGML_TYPE_Q2_K] = true,
3578
+ [GGML_TYPE_Q3_K] = true,
3579
+ [GGML_TYPE_Q4_K] = true,
3580
+ [GGML_TYPE_Q5_K] = true,
3581
+ [GGML_TYPE_Q6_K] = true,
3582
+ [GGML_TYPE_Q8_K] = true,
3493
3583
  [GGML_TYPE_I8] = false,
3494
3584
  [GGML_TYPE_I16] = false,
3495
3585
  [GGML_TYPE_I32] = false,
3496
3586
  };
3497
- static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated");
3587
+ static_assert(GGML_TYPE_COUNT == 19, "GGML_IS_QUANTIZED is outdated");
3498
3588
 
3499
3589
  static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3500
3590
  "NONE",
@@ -3513,6 +3603,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3513
3603
  "SUM_ROWS",
3514
3604
  "MEAN",
3515
3605
  "REPEAT",
3606
+ "REPEAT_BACK",
3516
3607
  "ABS",
3517
3608
  "SGN",
3518
3609
  "NEG",
@@ -3526,6 +3617,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3526
3617
  "RMS_NORM_BACK",
3527
3618
 
3528
3619
  "MUL_MAT",
3620
+ "OUT_PROD",
3529
3621
 
3530
3622
  "SCALE",
3531
3623
  "SET",
@@ -3541,6 +3633,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3541
3633
  "DIAG_MASK_INF",
3542
3634
  "DIAG_MASK_ZERO",
3543
3635
  "SOFT_MAX",
3636
+ "SOFT_MAX_BACK",
3544
3637
  "ROPE",
3545
3638
  "ROPE_BACK",
3546
3639
  "ALIBI",
@@ -3550,13 +3643,16 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3550
3643
 
3551
3644
  "FLASH_ATTN",
3552
3645
  "FLASH_FF",
3646
+ "FLASH_ATTN_BACK",
3553
3647
 
3554
3648
  "MAP_UNARY",
3555
3649
  "MAP_BINARY",
3556
- };
3557
3650
 
3558
- static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
3651
+ "CROSS_ENTROPY_LOSS",
3652
+ "CROSS_ENTROPY_LOSS_BACK",
3653
+ };
3559
3654
 
3655
+ static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57");
3560
3656
 
3561
3657
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3562
3658
  "none",
@@ -3575,6 +3671,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3575
3671
  "Σx_k",
3576
3672
  "Σx/n",
3577
3673
  "repeat(x)",
3674
+ "repeat_back(x)",
3578
3675
  "abs(x)",
3579
3676
  "sgn(x)",
3580
3677
  "-x",
@@ -3587,6 +3684,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3587
3684
  "rms_norm(x)",
3588
3685
  "rms_norm_back(x)",
3589
3686
 
3687
+ "X*Y",
3590
3688
  "X*Y",
3591
3689
 
3592
3690
  "x*v",
@@ -3603,6 +3701,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3603
3701
  "diag_mask_inf(x)",
3604
3702
  "diag_mask_zero(x)",
3605
3703
  "soft_max(x)",
3704
+ "soft_max_back(x)",
3606
3705
  "rope(x)",
3607
3706
  "rope_back(x)",
3608
3707
  "alibi(x)",
@@ -3612,12 +3711,16 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3612
3711
 
3613
3712
  "flash_attn(x)",
3614
3713
  "flash_ff(x)",
3714
+ "flash_attn_back(x)",
3615
3715
 
3616
3716
  "f(x)",
3617
3717
  "f(x,y)",
3718
+
3719
+ "cross_entropy_loss(x,y)",
3720
+ "cross_entropy_loss_back(x,y)",
3618
3721
  };
3619
3722
 
3620
- static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
3723
+ static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57");
3621
3724
 
3622
3725
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
3623
3726
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
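The static_asserts move from 51 to 57 because exactly six ops are added in this release (REPEAT_BACK, OUT_PROD, SOFT_MAX_BACK, FLASH_ATTN_BACK, CROSS_ENTROPY_LOSS, CROSS_ENTROPY_LOSS_BACK): 51 + 6 = 57. Likewise the type-table asserts go from 13 to 19 for the six new quantization types (Q2_K, Q3_K, Q4_K, Q5_K, Q6_K, Q8_K): 13 + 6 = 19.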
@@ -3631,6 +3734,7 @@ struct ggml_context {
3631
3734
  void * mem_buffer;
3632
3735
  bool mem_buffer_owned;
3633
3736
  bool no_alloc;
3737
+ bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
3634
3738
 
3635
3739
  int n_objects;
3636
3740
 
@@ -3647,26 +3751,6 @@ struct ggml_context_container {
3647
3751
  struct ggml_context context;
3648
3752
  };
3649
3753
 
3650
- //
3651
- // compute types
3652
- //
3653
-
3654
- enum ggml_task_type {
3655
- GGML_TASK_INIT = 0,
3656
- GGML_TASK_COMPUTE,
3657
- GGML_TASK_FINALIZE,
3658
- };
3659
-
3660
- struct ggml_compute_params {
3661
- enum ggml_task_type type;
3662
-
3663
- int ith, nth;
3664
-
3665
- // work buffer for all threads
3666
- size_t wsize;
3667
- void * wdata;
3668
- };
3669
-
3670
3754
  //
3671
3755
  // ggml state
3672
3756
  //
@@ -3723,7 +3807,7 @@ int64_t ggml_nelements(const struct ggml_tensor * tensor) {
3723
3807
  return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
3724
3808
  }
3725
3809
 
3726
- int ggml_nrows(const struct ggml_tensor * tensor) {
3810
+ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
3727
3811
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3728
3812
 
3729
3813
  return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
@@ -3732,7 +3816,20 @@ int ggml_nrows(const struct ggml_tensor * tensor) {
3732
3816
  size_t ggml_nbytes(const struct ggml_tensor * tensor) {
3733
3817
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3734
3818
 
3735
- return (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
3819
+ // this should handle cases where the tensor is not contiguous in memory
3820
+ // probaby just:
3821
+ //
3822
+ // return tensor->ne[3]*tensor->nb[3]
3823
+ //
3824
+ // is enough, but just in case, adding the second part
3825
+
3826
+ return MAX(tensor->ne[3]*tensor->nb[3], (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type]);
3827
+ }
3828
+
3829
+ size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
3830
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3831
+
3832
+ return (nrows_split*tensor->ne[0]*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
3736
3833
  }
3737
3834
 
3738
3835
  int ggml_blck_size(enum ggml_type type) {
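The new ggml_nbytes takes the larger of the stride-based size ne[3]*nb[3] and the packed size, so views and row-padded tensors are no longer under-counted. A small worked example with hypothetical F32 shapes:

  contiguous 4x3 tensor : ne = {4,3,1,1}, nb = {4,16,48,48}  -> ne[3]*nb[3] = 48 = packed size (12 floats)
  rows padded to 32 B   : ne = {4,3,1,1}, nb = {4,32,96,96}  -> ne[3]*nb[3] = 96 > packed 48, so MAX picks 96

Copying ggml_nbytes(t) bytes of the padded tensor therefore covers the padding instead of truncating the last rows.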
@@ -3786,6 +3883,15 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
3786
3883
  (t0->ne[3] == t1->ne[3]);
3787
3884
  }
3788
3885
 
3886
+ static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
3887
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3888
+
3889
+ return
3890
+ (t0->ne[1] == t1->ne[1]) &&
3891
+ (t0->ne[2] == t1->ne[2]) &&
3892
+ (t0->ne[3] == t1->ne[3]);
3893
+ }
3894
+
3789
3895
  bool ggml_is_quantized(enum ggml_type type) {
3790
3896
  return GGML_IS_QUANTIZED[type];
3791
3897
  }
@@ -3801,6 +3907,11 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
3801
3907
  case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
3802
3908
  case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
3803
3909
  case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
3910
+ case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break;
3911
+ case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
3912
+ case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
3913
+ case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
3914
+ case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
3804
3915
  case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
3805
3916
  case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
3806
3917
  }
@@ -3814,11 +3925,11 @@ size_t ggml_tensor_overhead(void) {
3814
3925
  return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE + 16;
3815
3926
  }
3816
3927
 
3817
- static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
3928
+ bool ggml_is_transposed(const struct ggml_tensor * tensor) {
3818
3929
  return tensor->nb[0] > tensor->nb[1];
3819
3930
  }
3820
3931
 
3821
- static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
3932
+ bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
3822
3933
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3823
3934
 
3824
3935
  return
@@ -3828,6 +3939,12 @@ static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
3828
3939
  tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
3829
3940
  }
3830
3941
 
3942
+ bool ggml_is_permuted(const struct ggml_tensor * tensor) {
3943
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3944
+
3945
+ return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
3946
+ }
3947
+
3831
3948
  static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
3832
3949
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3833
3950
 
@@ -3967,6 +4084,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3967
4084
  /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size),
3968
4085
  /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
3969
4086
  /*.no_alloc =*/ params.no_alloc,
4087
+ /*.no_alloc_save =*/ params.no_alloc,
3970
4088
  /*.n_objects =*/ 0,
3971
4089
  /*.objects_begin =*/ NULL,
3972
4090
  /*.objects_end =*/ NULL,
@@ -4044,11 +4162,18 @@ size_t ggml_get_mem_size(struct ggml_context * ctx) {
4044
4162
  // operators when using scratch buffers
4045
4163
  // TODO: implement a better way
4046
4164
  void ggml_scratch_save(struct ggml_context * ctx) {
4165
+ // this is needed to allow opt tensors to store their data
4166
+ // TODO: again, need to find a better way
4167
+ ctx->no_alloc_save = ctx->no_alloc;
4168
+ ctx->no_alloc = false;
4169
+
4047
4170
  ctx->scratch_save = ctx->scratch;
4048
4171
  ctx->scratch.data = NULL;
4049
4172
  }
4050
4173
 
4051
4174
  void ggml_scratch_load(struct ggml_context * ctx) {
4175
+ ctx->no_alloc = ctx->no_alloc_save;
4176
+
4052
4177
  ctx->scratch = ctx->scratch_save;
4053
4178
  }
4054
4179
 
@@ -4157,6 +4282,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
4157
4282
  /*.perf_time_us =*/ 0,
4158
4283
  /*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
4159
4284
  /*.name =*/ { 0 },
4285
+ /*.extra =*/ NULL,
4160
4286
  /*.pad =*/ { 0 },
4161
4287
  };
4162
4288
 
@@ -4595,7 +4721,7 @@ struct ggml_tensor * ggml_add_impl(
4595
4721
 
4596
4722
  bool is_node = false;
4597
4723
 
4598
- if (!inplace && (a->grad || b->grad)) {
4724
+ if (a->grad || b->grad) {
4599
4725
  is_node = true;
4600
4726
  }
4601
4727
 
@@ -4635,7 +4761,7 @@ struct ggml_tensor * ggml_add1_impl(
4635
4761
 
4636
4762
  bool is_node = false;
4637
4763
 
4638
- if (!inplace && (a->grad || b->grad)) {
4764
+ if (a->grad || b->grad) {
4639
4765
  is_node = true;
4640
4766
  }
4641
4767
 
@@ -5061,6 +5187,34 @@ struct ggml_tensor * ggml_repeat(
5061
5187
  return result;
5062
5188
  }
5063
5189
 
5190
+ // ggml_repeat_back
5191
+
5192
+ struct ggml_tensor * ggml_repeat_back(
5193
+ struct ggml_context * ctx,
5194
+ struct ggml_tensor * a,
5195
+ struct ggml_tensor * b) {
5196
+ GGML_ASSERT(ggml_can_repeat(b, a));
5197
+
5198
+ bool is_node = false;
5199
+
5200
+ if (a->grad) {
5201
+ is_node = true;
5202
+ }
5203
+
5204
+ if (ggml_are_same_shape(a, b) && !is_node) {
5205
+ return a;
5206
+ }
5207
+
5208
+ struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
5209
+
5210
+ result->op = GGML_OP_REPEAT_BACK;
5211
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5212
+ result->src0 = a;
5213
+ result->src1 = b;
5214
+
5215
+ return result;
5216
+ }
5217
+
5064
5218
  // ggml_abs
5065
5219
 
5066
5220
  struct ggml_tensor * ggml_abs_impl(
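ggml_repeat_back is the reduction counterpart of ggml_repeat: in the backward pass every element of the repeated tensor is accumulated back into the source position it was copied from (the forward kernel later in this diff does this with ggml_vec_acc_f32). A tiny 1-D example with made-up values:

  repeat([1, 2]) to length 6             -> [1, 2, 1, 2, 1, 2]
  repeat_back([g0,g1,g2,g3,g4,g5]) to 2  -> [g0+g2+g4, g1+g3+g5]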
@@ -5438,6 +5592,32 @@ struct ggml_tensor * ggml_mul_mat(
5438
5592
  return result;
5439
5593
  }
5440
5594
 
5595
+ // ggml_out_prod
5596
+
5597
+ struct ggml_tensor * ggml_out_prod(
5598
+ struct ggml_context * ctx,
5599
+ struct ggml_tensor * a,
5600
+ struct ggml_tensor * b) {
5601
+ GGML_ASSERT(ggml_can_out_prod(a, b));
5602
+ GGML_ASSERT(!ggml_is_transposed(a));
5603
+
5604
+ bool is_node = false;
5605
+
5606
+ if (a->grad || b->grad) {
5607
+ is_node = true;
5608
+ }
5609
+
5610
+ const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] };
5611
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
5612
+
5613
+ result->op = GGML_OP_OUT_PROD;
5614
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5615
+ result->src0 = a;
5616
+ result->src1 = b;
5617
+
5618
+ return result;
5619
+ }
5620
+
5441
5621
  // ggml_scale
5442
5622
 
5443
5623
  struct ggml_tensor * ggml_scale_impl(
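ggml_out_prod(a, b) accumulates outer products over the shared inner index, producing a tensor of shape [a->ne[0], b->ne[0], ...]; this is the shape needed for the weight gradient of mul_mat. A naive reference loop for the 2-D case, matching the dst[i0,i1] += src0[i0,i01] * src1[i1,i01] comment in the forward kernel further down (function and variable names are mine, not ggml API):

// a is [a_ne0, k], b is [b_ne0, k], dst is [a_ne0, b_ne0]; ne[0] is the fastest-moving index.
static void out_prod_ref(const float * a, const float * b, float * dst,
                         int a_ne0, int b_ne0, int k) {
    for (int j = 0; j < b_ne0; ++j) {
        for (int i = 0; i < a_ne0; ++i) {
            float sum = 0.0f;
            for (int c = 0; c < k; ++c) {
                sum += a[c*a_ne0 + i] * b[c*b_ne0 + j];   // dst = A * B^T in ggml's layout
            }
            dst[j*a_ne0 + i] = sum;
        }
    }
}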
@@ -5450,7 +5630,7 @@ struct ggml_tensor * ggml_scale_impl(
5450
5630
 
5451
5631
  bool is_node = false;
5452
5632
 
5453
- if (!inplace && (a->grad || b->grad)) {
5633
+ if (a->grad || b->grad) {
5454
5634
  is_node = true;
5455
5635
  }
5456
5636
 
@@ -5493,7 +5673,7 @@ struct ggml_tensor * ggml_set_impl(
5493
5673
 
5494
5674
  bool is_node = false;
5495
5675
 
5496
- if (!inplace && (a->grad || b->grad)) {
5676
+ if (a->grad || b->grad) {
5497
5677
  is_node = true;
5498
5678
  }
5499
5679
 
@@ -5802,14 +5982,18 @@ struct ggml_tensor * ggml_view_1d(
5802
5982
 
5803
5983
  struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset);
5804
5984
 
5985
+ ggml_scratch_save(ctx);
5986
+
5987
+ struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
5988
+ memcpy(offs->data, &offset, 2*sizeof(int32_t));
5989
+
5990
+ ggml_scratch_load(ctx);
5991
+
5805
5992
  result->op = GGML_OP_VIEW;
5806
5993
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5807
5994
  result->src0 = a;
5808
5995
  result->src1 = NULL;
5809
-
5810
- if (is_node) {
5811
- memcpy(result->padding, &offset, sizeof(offset));
5812
- }
5996
+ result->opt[0] = offs;
5813
5997
 
5814
5998
  return result;
5815
5999
  }
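All four ggml_view_* variants now record the byte offset unconditionally in a 2-element I32 tensor stored in result->opt[0] (the 64-bit offset is memcpy'd across the two int32 slots), instead of stashing it in padding only when a gradient exists. A hedged sketch of reading it back, assuming a 64-bit size_t as the 2*sizeof(int32_t) copy implies:

#include <string.h>
#include "ggml.h"

// Recover the byte offset recorded by ggml_view_1d/2d/3d/4d (tensor->op == GGML_OP_VIEW).
static size_t view_byte_offset(const struct ggml_tensor * view) {
    size_t offset = 0;
    memcpy(&offset, view->opt[0]->data, 2*sizeof(int32_t));
    return offset;
}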
@@ -5834,6 +6018,13 @@ struct ggml_tensor * ggml_view_2d(
5834
6018
 
5835
6019
  struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset);
5836
6020
 
6021
+ ggml_scratch_save(ctx);
6022
+
6023
+ struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
6024
+ memcpy(offs->data, &offset, 2*sizeof(int32_t));
6025
+
6026
+ ggml_scratch_load(ctx);
6027
+
5837
6028
  result->nb[1] = nb1;
5838
6029
  result->nb[2] = result->nb[1]*ne1;
5839
6030
  result->nb[3] = result->nb[2];
@@ -5842,10 +6033,7 @@ struct ggml_tensor * ggml_view_2d(
5842
6033
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5843
6034
  result->src0 = a;
5844
6035
  result->src1 = NULL;
5845
-
5846
- if (is_node) {
5847
- memcpy(result->padding, &offset, sizeof(offset));
5848
- }
6036
+ result->opt[0] = offs;
5849
6037
 
5850
6038
  return result;
5851
6039
  }
@@ -5872,6 +6060,13 @@ struct ggml_tensor * ggml_view_3d(
5872
6060
 
5873
6061
  struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset);
5874
6062
 
6063
+ ggml_scratch_save(ctx);
6064
+
6065
+ struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
6066
+ memcpy(offs->data, &offset, 2*sizeof(int32_t));
6067
+
6068
+ ggml_scratch_load(ctx);
6069
+
5875
6070
  result->nb[1] = nb1;
5876
6071
  result->nb[2] = nb2;
5877
6072
  result->nb[3] = result->nb[2]*ne2;
@@ -5880,10 +6075,7 @@ struct ggml_tensor * ggml_view_3d(
5880
6075
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5881
6076
  result->src0 = a;
5882
6077
  result->src1 = NULL;
5883
-
5884
- if (is_node) {
5885
- memcpy(result->padding, &offset, sizeof(offset));
5886
- }
6078
+ result->opt[0] = offs;
5887
6079
 
5888
6080
  return result;
5889
6081
  }
@@ -5912,6 +6104,13 @@ struct ggml_tensor * ggml_view_4d(
5912
6104
 
5913
6105
  struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset);
5914
6106
 
6107
+ ggml_scratch_save(ctx);
6108
+
6109
+ struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
6110
+ memcpy(offs->data, &offset, 2*sizeof(int32_t));
6111
+
6112
+ ggml_scratch_load(ctx);
6113
+
5915
6114
  result->nb[1] = nb1;
5916
6115
  result->nb[2] = nb2;
5917
6116
  result->nb[3] = nb3;
@@ -5920,10 +6119,7 @@ struct ggml_tensor * ggml_view_4d(
5920
6119
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5921
6120
  result->src0 = a;
5922
6121
  result->src1 = NULL;
5923
-
5924
- if (is_node) {
5925
- memcpy(result->padding, &offset, sizeof(offset));
5926
- }
6122
+ result->opt[0] = offs;
5927
6123
 
5928
6124
  return result;
5929
6125
  }
@@ -5986,10 +6182,18 @@ struct ggml_tensor * ggml_permute(
5986
6182
  result->src1 = NULL;
5987
6183
 
5988
6184
  if (is_node) {
5989
- result->padding[0] = axis0;
5990
- result->padding[1] = axis1;
5991
- result->padding[2] = axis2;
5992
- result->padding[3] = axis3;
6185
+ ggml_scratch_save(ctx);
6186
+
6187
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
6188
+
6189
+ ((int32_t *) b->data)[0] = axis0;
6190
+ ((int32_t *) b->data)[1] = axis1;
6191
+ ((int32_t *) b->data)[2] = axis2;
6192
+ ((int32_t *) b->data)[3] = axis3;
6193
+
6194
+ ggml_scratch_load(ctx);
6195
+
6196
+ result->opt[0] = b;
5993
6197
  }
5994
6198
 
5995
6199
  return result;
@@ -6229,6 +6433,44 @@ struct ggml_tensor * ggml_soft_max_inplace(
6229
6433
  return ggml_soft_max_impl(ctx, a, true);
6230
6434
  }
6231
6435
 
6436
+
6437
+ // ggml_soft_max_back
6438
+
6439
+ struct ggml_tensor * ggml_soft_max_back_impl(
6440
+ struct ggml_context * ctx,
6441
+ struct ggml_tensor * a,
6442
+ struct ggml_tensor * b,
6443
+ bool inplace) {
6444
+ bool is_node = false;
6445
+
6446
+ if (a->grad || b->grad) {
6447
+ is_node = true; // TODO : implement backward pass
6448
+ }
6449
+
6450
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
6451
+
6452
+ result->op = GGML_OP_SOFT_MAX_BACK;
6453
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6454
+ result->src0 = a;
6455
+ result->src1 = b;
6456
+
6457
+ return result;
6458
+ }
6459
+
6460
+ struct ggml_tensor * ggml_soft_max_back(
6461
+ struct ggml_context * ctx,
6462
+ struct ggml_tensor * a,
6463
+ struct ggml_tensor * b) {
6464
+ return ggml_soft_max_back_impl(ctx, a, b, false);
6465
+ }
6466
+
6467
+ struct ggml_tensor * ggml_soft_max_back_inplace(
6468
+ struct ggml_context * ctx,
6469
+ struct ggml_tensor * a,
6470
+ struct ggml_tensor * b) {
6471
+ return ggml_soft_max_back_impl(ctx, a, b, true);
6472
+ }
6473
+
6232
6474
  // ggml_rope
6233
6475
 
6234
6476
  struct ggml_tensor * ggml_rope_impl(
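The matching compute kernel later in this diff implements the softmax Jacobian-vector product; written out, with y the softmax output and dy the incoming gradient:

  dx_k = \sum_i J_{ki}\, dy_i, \qquad J_{ki} = y_k(\delta_{ki} - y_i)
       \;\Longrightarrow\; dx_k = y_k\Big(dy_k - \sum_i y_i\, dy_i\Big) = y_k\,(dy_k - y\cdot dy)

which is exactly the dot-subtract-multiply sequence (ggml_vec_dot_f32, ggml_vec_acc1_f32, ggml_vec_mul_f32) used in ggml_compute_forward_soft_max_back_f32.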
@@ -6241,7 +6483,7 @@ struct ggml_tensor * ggml_rope_impl(
6241
6483
  GGML_ASSERT(n_past >= 0);
6242
6484
  bool is_node = false;
6243
6485
 
6244
- if (!inplace && a->grad) {
6486
+ if (a->grad) {
6245
6487
  is_node = true;
6246
6488
  }
6247
6489
 
@@ -6295,8 +6537,7 @@ struct ggml_tensor * ggml_rope_back(
6295
6537
  bool is_node = false;
6296
6538
 
6297
6539
  if (a->grad) {
6298
- GGML_ASSERT(false); // TODO: implement backward
6299
- is_node = true;
6540
+ is_node = false; // TODO: implement backward
6300
6541
  }
6301
6542
 
6302
6543
  struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
@@ -6461,7 +6702,6 @@ struct ggml_tensor * ggml_flash_attn(
6461
6702
  bool is_node = false;
6462
6703
 
6463
6704
  if (q->grad || k->grad || v->grad) {
6464
- GGML_ASSERT(false); // TODO: implement backward
6465
6705
  is_node = true;
6466
6706
  }
6467
6707
 
@@ -6493,7 +6733,6 @@ struct ggml_tensor * ggml_flash_ff(
6493
6733
  bool is_node = false;
6494
6734
 
6495
6735
  if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
6496
- GGML_ASSERT(false); // TODO: implement backward
6497
6736
  is_node = true;
6498
6737
  }
6499
6738
 
@@ -6511,6 +6750,71 @@ struct ggml_tensor * ggml_flash_ff(
6511
6750
  return result;
6512
6751
  }
6513
6752
 
6753
+ // ggml_flash_attn_back
6754
+
6755
+ struct ggml_tensor * ggml_flash_attn_back(
6756
+ struct ggml_context * ctx,
6757
+ struct ggml_tensor * q,
6758
+ struct ggml_tensor * k,
6759
+ struct ggml_tensor * v,
6760
+ struct ggml_tensor * d,
6761
+ bool masked) {
6762
+ GGML_ASSERT(ggml_can_mul_mat(k, q));
6763
+ // TODO: check if vT can be multiplied by (k*qT)
6764
+
6765
+ // d shape [D,N,ne2,ne3]
6766
+ // q shape [D,N,ne2,ne3]
6767
+ // k shape [D,M,ne2,ne3]
6768
+ // v shape [M,D,ne2,ne3]
6769
+
6770
+ const int64_t D = q->ne[0];
6771
+ const int64_t N = q->ne[1];
6772
+ const int64_t M = k->ne[1];
6773
+ const int64_t ne2 = q->ne[2];
6774
+ const int64_t ne3 = q->ne[3];
6775
+
6776
+ GGML_ASSERT(k->ne[0] == D);
6777
+ GGML_ASSERT(v->ne[0] == M);
6778
+ GGML_ASSERT(v->ne[1] == D);
6779
+ GGML_ASSERT(d->ne[0] == D);
6780
+ GGML_ASSERT(d->ne[1] == N);
6781
+ GGML_ASSERT(k->ne[2] == ne2);
6782
+ GGML_ASSERT(k->ne[3] == ne3);
6783
+ GGML_ASSERT(v->ne[2] == ne2);
6784
+ GGML_ASSERT(v->ne[3] == ne3);
6785
+ GGML_ASSERT(d->ne[2] == ne2);
6786
+ GGML_ASSERT(d->ne[3] == ne3);
6787
+
6788
+ bool is_node = false;
6789
+
6790
+ if (q->grad || k->grad || v->grad) {
6791
+ // when using this operation (in backwards pass) these grads are set.
6792
+ // we don't want to create (big) grad of our result, so is_node is false.
6793
+ is_node = false;
6794
+ }
6795
+
6796
+ // store gradients of q, k and v as continuous tensors concatenated in result.
6797
+ // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
6798
+ // gradq->data = result->data
6799
+ // gradk->data = result->data + nb0*D*N*ne2*ne3
6800
+ // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
6801
+ // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
6802
+ int64_t ne[4] = {D,M+N+M,ne2,ne3};
6803
+
6804
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6805
+
6806
+ result->op = GGML_OP_FLASH_ATTN_BACK;
6807
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6808
+ result->src0 = q;
6809
+ result->src1 = k;
6810
+ result->opt[0] = v;
6811
+ result->opt[1] = d;
6812
+ result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
6813
+
6814
+ return result;
6815
+ }
6816
+
6817
+
6514
6818
  // ggml_map_unary
6515
6819
 
6516
6820
  struct ggml_tensor * ggml_map_unary_impl_f32(
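ggml_flash_attn_back packs grad(q), grad(k) and grad(v) into a single [D, N+M+M, ne2, ne3] tensor; the caller is expected to carve it back into three views at the offsets listed in the comment. A hypothetical sketch of that split (D, N, M, ne2, ne3 as defined above; not an exact ggml-internal call sequence):

// Split the concatenated result of ggml_flash_attn_back into the three gradients.
struct ggml_tensor * r      = ggml_flash_attn_back(ctx, q, k, v, d, /*masked=*/true);
const size_t         esz    = ggml_element_size(r);                        // nb0 in the comment
struct ggml_tensor * grad_q = ggml_view_1d(ctx, r, D*N*ne2*ne3, 0);
struct ggml_tensor * grad_k = ggml_view_1d(ctx, r, D*M*ne2*ne3, esz*D*N*ne2*ne3);
struct ggml_tensor * grad_v = ggml_view_1d(ctx, r, M*D*ne2*ne3, esz*(D*N + D*M)*ne2*ne3);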
@@ -6595,6 +6899,50 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
6595
6899
  return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
6596
6900
  }
6597
6901
 
6902
+ // ggml_cross_entropy_loss
6903
+
6904
+ struct ggml_tensor * ggml_cross_entropy_loss(
6905
+ struct ggml_context * ctx,
6906
+ struct ggml_tensor * a,
6907
+ struct ggml_tensor * b) {
6908
+ GGML_ASSERT(ggml_are_same_shape(a, b));
6909
+ bool is_node = false;
6910
+
6911
+ if (a->grad || b->grad) {
6912
+ is_node = true;
6913
+ }
6914
+
6915
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
6916
+
6917
+ result->op = GGML_OP_CROSS_ENTROPY_LOSS;
6918
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6919
+ result->src0 = a;
6920
+ result->src1 = b;
6921
+
6922
+ return result;
6923
+ }
6924
+
6925
+ // ggml_cross_entropy_loss_back
6926
+
6927
+ struct ggml_tensor * ggml_cross_entropy_loss_back(
6928
+ struct ggml_context * ctx,
6929
+ struct ggml_tensor * a,
6930
+ struct ggml_tensor * b,
6931
+ struct ggml_tensor * c) {
6932
+ GGML_ASSERT(ggml_are_same_shape(a, b));
6933
+ GGML_ASSERT(ggml_is_scalar(c));
6934
+
6935
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
6936
+
6937
+ result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
6938
+ result->grad = NULL;
6939
+ result->src0 = a;
6940
+ result->src1 = b;
6941
+ result->opt[0] = c;
6942
+
6943
+ return result;
6944
+ }
6945
+
6598
6946
  ////////////////////////////////////////////////////////////////////////////////
6599
6947
 
6600
6948
  void ggml_set_param(
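ggml_cross_entropy_loss reduces two same-shaped tensors to a single scalar; up to the scaling and numerical-stabilization details of the compute kernel (which are not part of this hunk, so take the exact constants as unspecified), the quantity follows the usual softmax cross entropy

  L(a, b) = -\sum_r \sum_i b_{r,i}\,\log\big(\operatorname{softmax}(a_r)_i\big)

and ggml_cross_entropy_loss_back(a, b, c) returns a gradient with the shape of a, scaled by the scalar c.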
@@ -7584,6 +7932,11 @@ static void ggml_compute_forward_add(
7584
7932
  case GGML_TYPE_Q5_0:
7585
7933
  case GGML_TYPE_Q5_1:
7586
7934
  case GGML_TYPE_Q8_0:
7935
+ case GGML_TYPE_Q2_K:
7936
+ case GGML_TYPE_Q3_K:
7937
+ case GGML_TYPE_Q4_K:
7938
+ case GGML_TYPE_Q5_K:
7939
+ case GGML_TYPE_Q6_K:
7587
7940
  {
7588
7941
  ggml_compute_forward_add_q_f32(params, src0, src1, dst);
7589
7942
  } break;
@@ -7887,6 +8240,11 @@ static void ggml_compute_forward_add1(
7887
8240
  case GGML_TYPE_Q5_1:
7888
8241
  case GGML_TYPE_Q8_0:
7889
8242
  case GGML_TYPE_Q8_1:
8243
+ case GGML_TYPE_Q2_K:
8244
+ case GGML_TYPE_Q3_K:
8245
+ case GGML_TYPE_Q4_K:
8246
+ case GGML_TYPE_Q5_K:
8247
+ case GGML_TYPE_Q6_K:
7890
8248
  {
7891
8249
  ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
7892
8250
  } break;
@@ -8009,6 +8367,11 @@ static void ggml_compute_forward_acc(
8009
8367
  case GGML_TYPE_Q5_1:
8010
8368
  case GGML_TYPE_Q8_0:
8011
8369
  case GGML_TYPE_Q8_1:
8370
+ case GGML_TYPE_Q2_K:
8371
+ case GGML_TYPE_Q3_K:
8372
+ case GGML_TYPE_Q4_K:
8373
+ case GGML_TYPE_Q5_K:
8374
+ case GGML_TYPE_Q6_K:
8012
8375
  default:
8013
8376
  {
8014
8377
  GGML_ASSERT(false);
@@ -8127,10 +8490,10 @@ static void ggml_compute_forward_mul_f32(
8127
8490
  const int ith = params->ith;
8128
8491
  const int nth = params->nth;
8129
8492
 
8130
- #ifdef GGML_USE_CUBLAS
8131
- if (src1->backend == GGML_BACKEND_CUDA) {
8493
+ #ifdef GGML_USE_CLBLAST
8494
+ if (src1->backend == GGML_BACKEND_GPU) {
8132
8495
  if (ith == 0) {
8133
- ggml_cuda_mul(src0, src1, dst);
8496
+ ggml_cl_mul(src0, src1, dst);
8134
8497
  }
8135
8498
  return;
8136
8499
  }
@@ -8730,29 +9093,122 @@ static void ggml_compute_forward_repeat(
8730
9093
  }
8731
9094
  }
8732
9095
 
8733
- // ggml_compute_forward_abs
9096
+ // ggml_compute_forward_repeat_back
8734
9097
 
8735
- static void ggml_compute_forward_abs_f32(
9098
+ static void ggml_compute_forward_repeat_back_f32(
8736
9099
  const struct ggml_compute_params * params,
8737
9100
  const struct ggml_tensor * src0,
8738
9101
  struct ggml_tensor * dst) {
8739
- assert(params->ith == 0);
8740
- assert(ggml_are_same_shape(src0, dst));
9102
+ GGML_ASSERT(params->ith == 0);
9103
+ GGML_ASSERT(ggml_can_repeat(dst, src0));
8741
9104
 
8742
9105
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
8743
9106
  return;
8744
9107
  }
8745
9108
 
8746
- const int n = ggml_nrows(src0);
8747
- const int nc = src0->ne[0];
9109
+ const int64_t ne0 = dst->ne[0];
9110
+ const int64_t ne1 = dst->ne[1];
9111
+ const int64_t ne2 = dst->ne[2];
9112
+ const int64_t ne3 = dst->ne[3];
8748
9113
 
8749
- assert(dst->nb[0] == sizeof(float));
8750
- assert(src0->nb[0] == sizeof(float));
9114
+ const int64_t ne00 = src0->ne[0];
9115
+ const int64_t ne01 = src0->ne[1];
9116
+ const int64_t ne02 = src0->ne[2];
9117
+ const int64_t ne03 = src0->ne[3];
8751
9118
 
8752
- for (int i = 0; i < n; i++) {
8753
- ggml_vec_abs_f32(nc,
8754
- (float *) ((char *) dst->data + i*( dst->nb[1])),
8755
- (float *) ((char *) src0->data + i*(src0->nb[1])));
9119
+ const size_t nb0 = dst->nb[0];
9120
+ const size_t nb1 = dst->nb[1];
9121
+ const size_t nb2 = dst->nb[2];
9122
+ const size_t nb3 = dst->nb[3];
9123
+
9124
+ const size_t nb00 = src0->nb[0];
9125
+ const size_t nb01 = src0->nb[1];
9126
+ const size_t nb02 = src0->nb[2];
9127
+ const size_t nb03 = src0->nb[3];
9128
+
9129
+ // guaranteed to be an integer due to the check in ggml_can_repeat
9130
+ const int nr0 = (int)(ne00/ne0);
9131
+ const int nr1 = (int)(ne01/ne1);
9132
+ const int nr2 = (int)(ne02/ne2);
9133
+ const int nr3 = (int)(ne03/ne3);
9134
+
9135
+ // TODO: support for transposed / permuted tensors
9136
+ GGML_ASSERT(nb0 == sizeof(float));
9137
+ GGML_ASSERT(nb00 == sizeof(float));
9138
+
9139
+ if (ggml_is_contiguous(dst)) {
9140
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
9141
+ } else {
9142
+ for (int k3 = 0; k3 < ne3; k3++) {
9143
+ for (int k2 = 0; k2 < ne2; k2++) {
9144
+ for (int k1 = 0; k1 < ne1; k1++) {
9145
+ ggml_vec_set_f32(ne0,
9146
+ (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
9147
+ 0);
9148
+ }
9149
+ }
9150
+ }
9151
+ }
9152
+
9153
+ // TODO: maybe this is not optimal?
9154
+ for (int i3 = 0; i3 < nr3; i3++) {
9155
+ for (int k3 = 0; k3 < ne3; k3++) {
9156
+ for (int i2 = 0; i2 < nr2; i2++) {
9157
+ for (int k2 = 0; k2 < ne2; k2++) {
9158
+ for (int i1 = 0; i1 < nr1; i1++) {
9159
+ for (int k1 = 0; k1 < ne1; k1++) {
9160
+ for (int i0 = 0; i0 < nr0; i0++) {
9161
+ ggml_vec_acc_f32(ne0,
9162
+ (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1),
9163
+ (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
9164
+ }
9165
+ }
9166
+ }
9167
+ }
9168
+ }
9169
+ }
9170
+ }
9171
+ }
9172
+
9173
+ static void ggml_compute_forward_repeat_back(
9174
+ const struct ggml_compute_params * params,
9175
+ const struct ggml_tensor * src0,
9176
+ struct ggml_tensor * dst) {
9177
+ switch (src0->type) {
9178
+ case GGML_TYPE_F32:
9179
+ {
9180
+ ggml_compute_forward_repeat_back_f32(params, src0, dst);
9181
+ } break;
9182
+ default:
9183
+ {
9184
+ GGML_ASSERT(false);
9185
+ } break;
9186
+ }
9187
+ }
9188
+
9189
+ // ggml_compute_forward_abs
9190
+
9191
+ static void ggml_compute_forward_abs_f32(
9192
+ const struct ggml_compute_params * params,
9193
+ const struct ggml_tensor * src0,
9194
+ struct ggml_tensor * dst) {
9195
+ assert(params->ith == 0);
9196
+ assert(ggml_are_same_shape(src0, dst));
9197
+
9198
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
9199
+ return;
9200
+ }
9201
+
9202
+ const int n = ggml_nrows(src0);
9203
+ const int nc = src0->ne[0];
9204
+
9205
+ assert(dst->nb[0] == sizeof(float));
9206
+ assert(src0->nb[0] == sizeof(float));
9207
+
9208
+ for (int i = 0; i < n; i++) {
9209
+ ggml_vec_abs_f32(nc,
9210
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
9211
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
8756
9212
  }
8757
9213
  }
8758
9214
 
@@ -9245,7 +9701,7 @@ static void ggml_compute_forward_rms_norm_f32(
9245
9701
  sum += (ggml_float)(x[i00] * x[i00]);
9246
9702
  }
9247
9703
 
9248
- float mean = sum/ne00;
9704
+ const float mean = sum/ne00;
9249
9705
 
9250
9706
  float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
9251
9707
 
@@ -9568,14 +10024,7 @@ static void ggml_compute_forward_mul_mat_f32(
9568
10024
  // nb01 >= nb00 - src0 is not transposed
9569
10025
  // compute by src0 rows
9570
10026
 
9571
- #if defined(GGML_USE_CUBLAS)
9572
- if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
9573
- if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
9574
- ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
9575
- }
9576
- return;
9577
- }
9578
- #elif defined(GGML_USE_CLBLAST)
10027
+ #if defined(GGML_USE_CLBLAST)
9579
10028
  if (ggml_cl_can_mul_mat(src0, src1, dst)) {
9580
10029
  if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
9581
10030
  ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -9740,14 +10189,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
9740
10189
  // nb01 >= nb00 - src0 is not transposed
9741
10190
  // compute by src0 rows
9742
10191
 
9743
- #if defined(GGML_USE_CUBLAS)
9744
- if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
9745
- if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
9746
- ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
9747
- }
9748
- return;
9749
- }
9750
- #elif defined(GGML_USE_CLBLAST)
10192
+ #if defined(GGML_USE_CLBLAST)
9751
10193
  if (ggml_cl_can_mul_mat(src0, src1, dst)) {
9752
10194
  if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
9753
10195
  ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -9952,14 +10394,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
9952
10394
  // nb01 >= nb00 - src0 is not transposed
9953
10395
  // compute by src0 rows
9954
10396
 
9955
- #if defined(GGML_USE_CUBLAS)
9956
- if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
9957
- if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
9958
- ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
9959
- }
9960
- return;
9961
- }
9962
- #elif defined(GGML_USE_CLBLAST)
10397
+ #if defined(GGML_USE_CLBLAST)
9963
10398
  if (ggml_cl_can_mul_mat(src0, src1, dst)) {
9964
10399
  if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
9965
10400
  ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -10102,6 +10537,11 @@ static void ggml_compute_forward_mul_mat(
10102
10537
  case GGML_TYPE_Q5_1:
10103
10538
  case GGML_TYPE_Q8_0:
10104
10539
  case GGML_TYPE_Q8_1:
10540
+ case GGML_TYPE_Q2_K:
10541
+ case GGML_TYPE_Q3_K:
10542
+ case GGML_TYPE_Q4_K:
10543
+ case GGML_TYPE_Q5_K:
10544
+ case GGML_TYPE_Q6_K:
10105
10545
  {
10106
10546
  ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
10107
10547
  } break;
@@ -10120,6 +10560,176 @@ static void ggml_compute_forward_mul_mat(
10120
10560
  }
10121
10561
  }
10122
10562
 
10563
+ // ggml_compute_forward_out_prod
10564
+
10565
+
10566
+ static void ggml_compute_forward_out_prod_f32(
10567
+ const struct ggml_compute_params * params,
10568
+ const struct ggml_tensor * src0,
10569
+ const struct ggml_tensor * src1,
10570
+ struct ggml_tensor * dst) {
10571
+ int64_t t0 = ggml_perf_time_us();
10572
+ UNUSED(t0);
10573
+
10574
+ const int64_t ne00 = src0->ne[0];
10575
+ const int64_t ne01 = src0->ne[1];
10576
+ const int64_t ne02 = src0->ne[2];
10577
+ const int64_t ne03 = src0->ne[3];
10578
+
10579
+ const int64_t ne10 = src1->ne[0];
10580
+ //const int64_t ne11 = src1->ne[1];
10581
+ const int64_t ne12 = src1->ne[2];
10582
+ const int64_t ne13 = src1->ne[3];
10583
+
10584
+ const int64_t ne0 = dst->ne[0];
10585
+ const int64_t ne1 = dst->ne[1];
10586
+ const int64_t ne2 = dst->ne[2];
10587
+ const int64_t ne3 = dst->ne[3];
10588
+
10589
+ const int nb00 = src0->nb[0];
10590
+ const int nb01 = src0->nb[1];
10591
+ const int nb02 = src0->nb[2];
10592
+ const int nb03 = src0->nb[3];
10593
+
10594
+ const int nb10 = src1->nb[0];
10595
+ const int nb11 = src1->nb[1];
10596
+ const int nb12 = src1->nb[2];
10597
+ const int nb13 = src1->nb[3];
10598
+
10599
+ const int nb0 = dst->nb[0];
10600
+ const int nb1 = dst->nb[1];
10601
+ const int nb2 = dst->nb[2];
10602
+ const int nb3 = dst->nb[3];
10603
+
10604
+ const int ith = params->ith;
10605
+ const int nth = params->nth;
10606
+
10607
+ GGML_ASSERT(ne02 == ne12);
10608
+ GGML_ASSERT(ne03 == ne13);
10609
+ GGML_ASSERT(ne2 == ne12);
10610
+ GGML_ASSERT(ne3 == ne13);
10611
+
10612
+ // we don't support permuted src0 or src1
10613
+ GGML_ASSERT(nb00 == sizeof(float));
10614
+
10615
+ // dst cannot be transposed or permuted
10616
+ GGML_ASSERT(nb0 == sizeof(float));
10617
+ // GGML_ASSERT(nb0 <= nb1);
10618
+ // GGML_ASSERT(nb1 <= nb2);
10619
+ // GGML_ASSERT(nb2 <= nb3);
10620
+
10621
+ GGML_ASSERT(ne0 == ne00);
10622
+ GGML_ASSERT(ne1 == ne10);
10623
+ GGML_ASSERT(ne2 == ne02);
10624
+ GGML_ASSERT(ne3 == ne03);
10625
+
10626
+ // nb01 >= nb00 - src0 is not transposed
10627
+ // compute by src0 rows
10628
+
10629
+ // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
10630
+ // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
10631
+
10632
+ if (params->type == GGML_TASK_INIT) {
10633
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
10634
+ return;
10635
+ }
10636
+
10637
+ if (params->type == GGML_TASK_FINALIZE) {
10638
+ return;
10639
+ }
10640
+
10641
+ // parallelize by last three dimensions
10642
+
10643
+ // total rows in dst
10644
+ const int64_t nr = ne1*ne2*ne3;
10645
+
10646
+ // rows per thread
10647
+ const int64_t dr = (nr + nth - 1)/nth;
10648
+
10649
+ // row range for this thread
10650
+ const int64_t ir0 = dr*ith;
10651
+ const int64_t ir1 = MIN(ir0 + dr, nr);
10652
+
10653
+ // dst[:,:,:,:] = 0
10654
+ // for i2,i3:
10655
+ // for i1:
10656
+ // for i01:
10657
+ // for i0:
10658
+ // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
10659
+
10660
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
10661
+ // dst indices
10662
+ const int64_t i3 = ir/(ne2*ne1);
10663
+ const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
10664
+ const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
10665
+
10666
+ const int64_t i02 = i2;
10667
+ const int64_t i03 = i3;
10668
+
10669
+ //const int64_t i10 = i1;
10670
+ const int64_t i12 = i2;
10671
+ const int64_t i13 = i3;
10672
+
10673
+ for (int64_t i01 = 0; i01 < ne01; ++i01) {
10674
+ const int64_t i11 = i01;
10675
+
10676
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
10677
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
10678
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
10679
+
10680
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
10681
+ // for (int64_t i0 = 0; i0 < ne0; ++i0) {
10682
+ // d[i0] += s0[i0] * s1[i1];
10683
+ // }
10684
+ }
10685
+ }
10686
+
10687
+ //int64_t t1 = ggml_perf_time_us();
10688
+ //static int64_t acc = 0;
10689
+ //acc += t1 - t0;
10690
+ //if (t1 - t0 > 10) {
10691
+ // printf("\n");
10692
+ // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
10693
+ // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
10694
+ // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
10695
+ // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
10696
+
10697
+ // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
10698
+ //}
10699
+ }
10700
+
10701
+ static void ggml_compute_forward_out_prod(
10702
+ const struct ggml_compute_params * params,
10703
+ const struct ggml_tensor * src0,
10704
+ const struct ggml_tensor * src1,
10705
+ struct ggml_tensor * dst) {
10706
+ switch (src0->type) {
10707
+ case GGML_TYPE_Q4_0:
10708
+ case GGML_TYPE_Q4_1:
10709
+ case GGML_TYPE_Q5_0:
10710
+ case GGML_TYPE_Q5_1:
10711
+ case GGML_TYPE_Q8_0:
10712
+ case GGML_TYPE_Q8_1:
10713
+ {
10714
+ GGML_ASSERT(false); // todo
10715
+ // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
10716
+ } break;
10717
+ case GGML_TYPE_F16:
10718
+ {
10719
+ GGML_ASSERT(false); // todo
10720
+ // ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst);
10721
+ } break;
10722
+ case GGML_TYPE_F32:
10723
+ {
10724
+ ggml_compute_forward_out_prod_f32(params, src0, src1, dst);
10725
+ } break;
10726
+ default:
10727
+ {
10728
+ GGML_ASSERT(false);
10729
+ } break;
10730
+ }
10731
+ }
10732
+
10123
10733
  // ggml_compute_forward_scale
10124
10734
 
10125
10735
  static void ggml_compute_forward_scale_f32(
@@ -10285,6 +10895,11 @@ static void ggml_compute_forward_set(
10285
10895
  case GGML_TYPE_Q5_1:
10286
10896
  case GGML_TYPE_Q8_0:
10287
10897
  case GGML_TYPE_Q8_1:
10898
+ case GGML_TYPE_Q2_K:
10899
+ case GGML_TYPE_Q3_K:
10900
+ case GGML_TYPE_Q4_K:
10901
+ case GGML_TYPE_Q5_K:
10902
+ case GGML_TYPE_Q6_K:
10288
10903
  default:
10289
10904
  {
10290
10905
  GGML_ASSERT(false);
@@ -10450,6 +11065,11 @@ static void ggml_compute_forward_get_rows(
10450
11065
  case GGML_TYPE_Q5_1:
10451
11066
  case GGML_TYPE_Q8_0:
10452
11067
  case GGML_TYPE_Q8_1:
11068
+ case GGML_TYPE_Q2_K:
11069
+ case GGML_TYPE_Q3_K:
11070
+ case GGML_TYPE_Q4_K:
11071
+ case GGML_TYPE_Q5_K:
11072
+ case GGML_TYPE_Q6_K:
10453
11073
  {
10454
11074
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
10455
11075
  } break;
@@ -10532,7 +11152,11 @@ static void ggml_compute_forward_get_rows_back_f32(
10532
11152
  GGML_ASSERT(ggml_is_contiguous(opt0));
10533
11153
  GGML_ASSERT(ggml_is_contiguous(dst));
10534
11154
 
10535
- ggml_compute_forward_dup_same_cont(params, opt0, dst);
11155
+ // ggml_compute_forward_dup_same_cont(params, opt0, dst);
11156
+
11157
+ if (params->type == GGML_TASK_INIT) {
11158
+ memset(dst->data, 0, ggml_nbytes(dst));
11159
+ }
10536
11160
 
10537
11161
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
10538
11162
  return;
@@ -10676,8 +11300,8 @@ static void ggml_compute_forward_diag_mask_f32(
10676
11300
  const struct ggml_tensor * src1,
10677
11301
  struct ggml_tensor * dst,
10678
11302
  const float value) {
10679
- assert(src1->type == GGML_TYPE_I32);
10680
- assert(ggml_nelements(src1) == 2);
11303
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
11304
+ GGML_ASSERT(ggml_nelements(src1) == 2);
10681
11305
 
10682
11306
  const int ith = params->ith;
10683
11307
  const int nth = params->nth;
@@ -10685,7 +11309,7 @@ static void ggml_compute_forward_diag_mask_f32(
10685
11309
  const int n_past = ((int32_t *) src1->data)[0];
10686
11310
  const bool inplace = (bool)((int32_t *) src1->data)[1];
10687
11311
 
10688
- assert(n_past >= 0);
11312
+ GGML_ASSERT(n_past >= 0);
10689
11313
 
10690
11314
  if (!inplace && (params->type == GGML_TASK_INIT)) {
10691
11315
  // memcpy needs to be synchronized across threads to avoid race conditions.
@@ -10709,8 +11333,8 @@ static void ggml_compute_forward_diag_mask_f32(
10709
11333
  const int nr = src0->ne[1];
10710
11334
  const int nz = n/nr;
10711
11335
 
10712
- assert( dst->nb[0] == sizeof(float));
10713
- assert(src0->nb[0] == sizeof(float));
11336
+ GGML_ASSERT( dst->nb[0] == sizeof(float));
11337
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
10714
11338
 
10715
11339
  for (int k = 0; k < nz; k++) {
10716
11340
  for (int j = ith; j < nr; j += nth) {
@@ -10846,42 +11470,137 @@ static void ggml_compute_forward_soft_max(
10846
11470
  }
10847
11471
  }
10848
11472
 
10849
- // ggml_compute_forward_alibi
11473
+ // ggml_compute_forward_soft_max_back
10850
11474
 
10851
- static void ggml_compute_forward_alibi_f32(
11475
+ static void ggml_compute_forward_soft_max_back_f32(
10852
11476
  const struct ggml_compute_params * params,
10853
11477
  const struct ggml_tensor * src0,
10854
11478
  const struct ggml_tensor * src1,
10855
11479
  struct ggml_tensor * dst) {
10856
- assert(params->ith == 0);
10857
- assert(src1->type == GGML_TYPE_I32);
10858
- assert(ggml_nelements(src1) == 3);
11480
+ GGML_ASSERT(ggml_is_contiguous(src0));
11481
+ GGML_ASSERT(ggml_is_contiguous(src1));
11482
+ GGML_ASSERT(ggml_is_contiguous(dst));
11483
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
11484
+ GGML_ASSERT(ggml_are_same_shape(src1, dst));
10859
11485
 
10860
11486
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
10861
11487
  return;
10862
11488
  }
10863
11489
 
10864
- const int n_past = ((int32_t *) src1->data)[0];
10865
- const int n_head = ((int32_t *) src1->data)[1];
10866
- const float max_bias = ((float *) src1->data)[2];
11490
+ // TODO: handle transposed/permuted matrices
10867
11491
 
10868
- assert(n_past >= 0);
11492
+ const int ith = params->ith;
11493
+ const int nth = params->nth;
10869
11494
 
10870
- const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
10871
- const int ne1 = src0->ne[1]; // seq_len_without_past
10872
- //const int ne2 = src0->ne[2]; // n_head -> this is k
10873
- //const int ne3 = src0->ne[3]; // 1 -> bsz
11495
+ const int nc = src0->ne[0];
11496
+ const int nr = ggml_nrows(src0);
10874
11497
 
10875
- const int n = ggml_nrows(src0);
10876
- const int ne2_ne3 = n/ne1; // ne2*ne3
11498
+ // rows per thread
11499
+ const int dr = (nr + nth - 1)/nth;
10877
11500
 
10878
- const int nb0 = src0->nb[0];
10879
- const int nb1 = src0->nb[1];
10880
- const int nb2 = src0->nb[2];
10881
- //const int nb3 = src0->nb[3];
11501
+ // row range for this thread
11502
+ const int ir0 = dr*ith;
11503
+ const int ir1 = MIN(ir0 + dr, nr);
10882
11504
 
10883
- assert(nb0 == sizeof(float));
10884
- assert(ne1 + n_past == ne0); (void) n_past;
11505
+ for (int i1 = ir0; i1 < ir1; i1++) {
11506
+ float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
11507
+ float *y = (float *)((char *) src1->data + i1*src1->nb[1]);
11508
+ float *dx = (float *)((char *) dst->data + i1*dst->nb[1]);
11509
+
11510
+ #ifndef NDEBUG
11511
+ for (int i = 0; i < nc; ++i) {
11512
+ //printf("p[%d] = %f\n", i, p[i]);
11513
+ assert(!isnan(dy[i]));
11514
+ assert(!isnan(y[i]));
11515
+ }
11516
+ #endif
11517
+ // Jii = yi - yi*yi
11518
+ // Jij = -yi*yj
11519
+ // J = diag(y)-y.T*y
11520
+ // dx = J * dy
11521
+ // dxk = sum_i(Jki * dyi)
11522
+ // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
11523
+ // dxk = sum_i(-yk*yi * dyi) + yk*dyk
11524
+ // dxk = -yk * sum_i(yi * dyi) + yk*dyk
11525
+ // dxk = -yk * dot(y, dy) + yk*dyk
11526
+ // dxk = yk * (- dot(y, dy) + dyk)
11527
+ // dxk = yk * (dyk - dot(y, dy))
11528
+ //
11529
+ // post-order:
11530
+ // dot_y_dy := dot(y, dy)
11531
+ // dx := dy
11532
+ // dx := dx - dot_y_dy
11533
+ // dx := dx * y
11534
+
11535
+ // linear runtime, no additional memory
11536
+ float dot_y_dy = 0;
11537
+ ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
11538
+ ggml_vec_cpy_f32 (nc, dx, dy);
11539
+ ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
11540
+ ggml_vec_mul_f32 (nc, dx, dx, y);
11541
+
11542
+ #ifndef NDEBUG
11543
+ for (int i = 0; i < nc; ++i) {
11544
+ assert(!isnan(dx[i]));
11545
+ assert(!isinf(dx[i]));
11546
+ }
11547
+ #endif
11548
+ }
11549
+ }
11550
+
11551
+ static void ggml_compute_forward_soft_max_back(
11552
+ const struct ggml_compute_params * params,
11553
+ const struct ggml_tensor * src0,
11554
+ const struct ggml_tensor * src1,
11555
+ struct ggml_tensor * dst) {
11556
+ switch (src0->type) {
11557
+ case GGML_TYPE_F32:
11558
+ {
11559
+ ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst);
11560
+ } break;
11561
+ default:
11562
+ {
11563
+ GGML_ASSERT(false);
11564
+ } break;
11565
+ }
11566
+ }
11567
+
11568
+ // ggml_compute_forward_alibi
11569
+
11570
+ static void ggml_compute_forward_alibi_f32(
11571
+ const struct ggml_compute_params * params,
11572
+ const struct ggml_tensor * src0,
11573
+ const struct ggml_tensor * src1,
11574
+ struct ggml_tensor * dst) {
11575
+ assert(params->ith == 0);
11576
+ assert(src1->type == GGML_TYPE_I32);
11577
+ assert(ggml_nelements(src1) == 3);
11578
+
11579
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
11580
+ return;
11581
+ }
11582
+
11583
+ const int n_past = ((int32_t *) src1->data)[0];
11584
+ const int n_head = ((int32_t *) src1->data)[1];
11585
+ const float max_bias = ((float *) src1->data)[2];
11586
+
11587
+ assert(n_past >= 0);
11588
+
11589
+ const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
11590
+ const int ne1 = src0->ne[1]; // seq_len_without_past
11591
+ //const int ne2 = src0->ne[2]; // n_head -> this is k
11592
+ //const int ne3 = src0->ne[3]; // 1 -> bsz
11593
+
11594
+ const int n = ggml_nrows(src0);
11595
+ const int ne2_ne3 = n/ne1; // ne2*ne3
11596
+
11597
+ const int nb0 = src0->nb[0];
11598
+ const int nb1 = src0->nb[1];
11599
+ const int nb2 = src0->nb[2];
11600
+ //const int nb3 = src0->nb[3];
11601
+
11602
+ assert(nb0 == sizeof(float));
11603
+ assert(ne1 + n_past == ne0); (void) n_past;
10885
11604
 
10886
11605
  // add alibi to src0 (KQ_scaled)
10887
11606
  const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -10996,6 +11715,12 @@ static void ggml_compute_forward_alibi(
10996
11715
  case GGML_TYPE_Q5_1:
10997
11716
  case GGML_TYPE_Q8_0:
10998
11717
  case GGML_TYPE_Q8_1:
11718
+ case GGML_TYPE_Q2_K:
11719
+ case GGML_TYPE_Q3_K:
11720
+ case GGML_TYPE_Q4_K:
11721
+ case GGML_TYPE_Q5_K:
11722
+ case GGML_TYPE_Q6_K:
11723
+ case GGML_TYPE_Q8_K:
10999
11724
  case GGML_TYPE_I8:
11000
11725
  case GGML_TYPE_I16:
11001
11726
  case GGML_TYPE_I32:
@@ -11067,6 +11792,12 @@ static void ggml_compute_forward_clamp(
11067
11792
  case GGML_TYPE_Q5_1:
11068
11793
  case GGML_TYPE_Q8_0:
11069
11794
  case GGML_TYPE_Q8_1:
11795
+ case GGML_TYPE_Q2_K:
11796
+ case GGML_TYPE_Q3_K:
11797
+ case GGML_TYPE_Q4_K:
11798
+ case GGML_TYPE_Q5_K:
11799
+ case GGML_TYPE_Q6_K:
11800
+ case GGML_TYPE_Q8_K:
11070
11801
  case GGML_TYPE_I8:
11071
11802
  case GGML_TYPE_I16:
11072
11803
  case GGML_TYPE_I32:
@@ -11156,7 +11887,7 @@ static void ggml_compute_forward_rope_f32(
11156
11887
  theta *= theta_scale;
11157
11888
 
11158
11889
  const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11159
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11890
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11160
11891
 
11161
11892
  const float x0 = src[0];
11162
11893
  const float x1 = src[1];
@@ -11177,7 +11908,7 @@ static void ggml_compute_forward_rope_f32(
11177
11908
  const int64_t i0 = ib*n_dims + ic/2;
11178
11909
 
11179
11910
  const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
11180
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11911
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
11181
11912
 
11182
11913
  const float x0 = src[0];
11183
11914
  const float x1 = src[n_dims/2];
@@ -12787,6 +13518,414 @@ static void ggml_compute_forward_flash_ff(
12787
13518
  }
12788
13519
  }
12789
13520
 
13521
+ // ggml_compute_forward_flash_attn_back
13522
+
13523
+ static void ggml_compute_forward_flash_attn_back_f32(
13524
+ const struct ggml_compute_params * params,
13525
+ const struct ggml_tensor * q,
13526
+ const struct ggml_tensor * k,
13527
+ const struct ggml_tensor * v,
13528
+ const struct ggml_tensor * d,
13529
+ const bool masked,
13530
+ struct ggml_tensor * dst) {
13531
+ int64_t t0 = ggml_perf_time_us();
13532
+ UNUSED(t0);
13533
+
13534
+ const int64_t neq0 = q->ne[0];
13535
+ const int64_t neq1 = q->ne[1];
13536
+ const int64_t neq2 = q->ne[2];
13537
+ const int64_t neq3 = q->ne[3];
13538
+
13539
+ const int64_t nek0 = k->ne[0];
13540
+ const int64_t nek1 = k->ne[1];
13541
+ //const int64_t nek2 = k->ne[2];
13542
+ //const int64_t nek3 = k->ne[3];
13543
+
13544
+ const int64_t nev0 = v->ne[0];
13545
+ const int64_t nev1 = v->ne[1];
13546
+ //const int64_t nev2 = v->ne[2];
13547
+ //const int64_t nev3 = v->ne[3];
13548
+
13549
+ const int64_t ned0 = d->ne[0];
13550
+ const int64_t ned1 = d->ne[1];
13551
+ //const int64_t ned2 = d->ne[2];
13552
+ //const int64_t ned3 = d->ne[3];
13553
+
13554
+ const int64_t ne0 = dst->ne[0];
13555
+ const int64_t ne1 = dst->ne[1];
13556
+ const int64_t ne2 = dst->ne[2];
13557
+ const int64_t ne3 = dst->ne[3];
13558
+
13559
+ const int nbk0 = k->nb[0];
13560
+ const int nbk1 = k->nb[1];
13561
+ const int nbk2 = k->nb[2];
13562
+ const int nbk3 = k->nb[3];
13563
+
13564
+ const int nbq0 = q->nb[0];
13565
+ const int nbq1 = q->nb[1];
13566
+ const int nbq2 = q->nb[2];
13567
+ const int nbq3 = q->nb[3];
13568
+
13569
+ const int nbv0 = v->nb[0];
13570
+ const int nbv1 = v->nb[1];
13571
+ const int nbv2 = v->nb[2];
13572
+ const int nbv3 = v->nb[3];
13573
+
13574
+ const int nbd0 = d->nb[0];
13575
+ const int nbd1 = d->nb[1];
13576
+ const int nbd2 = d->nb[2];
13577
+ const int nbd3 = d->nb[3];
13578
+
13579
+ const int nb0 = dst->nb[0];
13580
+ const int nb1 = dst->nb[1];
13581
+ const int nb2 = dst->nb[2];
13582
+ const int nb3 = dst->nb[3];
13583
+
13584
+ const int ith = params->ith;
13585
+ const int nth = params->nth;
13586
+
13587
+ const int64_t D = neq0;
13588
+ const int64_t N = neq1;
13589
+ const int64_t P = nek1 - N;
13590
+ const int64_t M = P + N;
13591
+
13592
+ const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
13593
+ const int mxDM = MAX(D, Mup);
13594
+
13595
+ // GGML_ASSERT(ne0 == D);
13596
+ // GGML_ASSERT(ne1 == N);
13597
+ GGML_ASSERT(P >= 0);
13598
+
13599
+ GGML_ASSERT(nbq0 == sizeof(float));
13600
+ GGML_ASSERT(nbk0 == sizeof(float));
13601
+ GGML_ASSERT(nbv0 == sizeof(float));
13602
+
13603
+ GGML_ASSERT(neq0 == D);
13604
+ GGML_ASSERT(nek0 == D);
13605
+ GGML_ASSERT(nev1 == D);
13606
+ GGML_ASSERT(ned0 == D);
13607
+
13608
+ GGML_ASSERT(neq1 == N);
13609
+ GGML_ASSERT(nek1 == N + P);
13610
+ GGML_ASSERT(nev1 == D);
13611
+ GGML_ASSERT(ned1 == N);
13612
+
13613
+ // dst cannot be transposed or permuted
13614
+ GGML_ASSERT(nb0 == sizeof(float));
13615
+ GGML_ASSERT(nb0 <= nb1);
13616
+ GGML_ASSERT(nb1 <= nb2);
13617
+ GGML_ASSERT(nb2 <= nb3);
13618
+
13619
+ if (params->type == GGML_TASK_INIT) {
13620
+ if (ith == 0) {
13621
+ memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
13622
+ }
13623
+ return;
13624
+ }
13625
+
13626
+ if (params->type == GGML_TASK_FINALIZE) {
13627
+ return;
13628
+ }
13629
+
13630
+ // parallelize by q rows using ggml_vec_dot_f32
13631
+
13632
+ // total rows in q
13633
+ const int nr = neq2*neq3;
13634
+
13635
+ // rows per thread
13636
+ const int dr = (nr + nth - 1)/nth;
13637
+
13638
+ // row range for this thread
13639
+ const int ir0 = dr*ith;
13640
+ const int ir1 = MIN(ir0 + dr, nr);
13641
+
13642
+ const float scale = 1.0f/sqrtf(D);
13643
+
13644
+ //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
13645
+
13646
+ for (int ir = ir0; ir < ir1; ++ir) {
13647
+ // q indices
13648
+ const int iq3 = ir/(neq2);
13649
+ const int iq2 = ir - iq3*neq2;
13650
+ for ( int iq1 = 0; iq1 < neq1; ++iq1) {
13651
+
13652
+
13653
+ // not sure about CACHE_LINE_SIZE_F32..
13654
+ // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
13655
+ float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
13656
+ float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
13657
+
13658
+ for (int i = M; i < Mup; ++i) {
13659
+ S[i] = -INFINITY;
13660
+ }
13661
+
13662
+ for (int64_t ic = 0; ic < nek1; ++ic) {
13663
+ // k indices
13664
+ const int ik3 = iq3;
13665
+ const int ik2 = iq2;
13666
+ const int ik1 = ic;
13667
+
13668
+ // S indices
13669
+ const int i1 = ik1;
13670
+
13671
+ ggml_vec_dot_f32(neq0,
13672
+ S + i1,
13673
+ (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
13674
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
13675
+ }
13676
+
13677
+ // scale
13678
+ ggml_vec_scale_f32(nek1, S, scale);
13679
+
13680
+ if (masked) {
13681
+ for (int64_t i = P; i < M; i++) {
13682
+ if (i > P + iq1) {
13683
+ S[i] = -INFINITY;
13684
+ }
13685
+ }
13686
+ }
13687
+
13688
+ // softmax
13689
+ {
13690
+ float max = -INFINITY;
13691
+ ggml_vec_max_f32(M, &max, S);
13692
+
13693
+ ggml_float sum = 0.0;
13694
+ {
13695
+ #ifdef GGML_SOFT_MAX_ACCELERATE
13696
+ max = -max;
13697
+ vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
13698
+ vvexpf(SM, SM, &Mup);
13699
+ ggml_vec_sum_f32(Mup, &sum, SM);
13700
+ #else
13701
+ uint16_t scvt[GGML_SOFT_MAX_UNROLL];
13702
+ ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
13703
+
13704
+ for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
13705
+ float * SR = S + i;
13706
+ float * SW = SM + i;
13707
+
13708
+ for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
13709
+ if (SR[j] == -INFINITY) {
13710
+ SW[j] = 0.0f;
13711
+ } else {
13712
+ ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
13713
+ memcpy(&scvt[j], &s, sizeof(uint16_t));
13714
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
13715
+ sump[j] += (ggml_float)val;
13716
+ SW[j] = val;
13717
+ }
13718
+ }
13719
+ }
13720
+
13721
+ for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
13722
+ sum += sump[i];
13723
+ }
13724
+ #endif
13725
+ }
13726
+
13727
+ assert(sum > 0.0);
13728
+
13729
+ sum = 1.0/sum;
13730
+ ggml_vec_scale_f32(M, SM, sum);
13731
+
13732
+ }
13733
+
13734
+ // step-by-step explanation
13735
+ {
13736
+ // forward-process shape grads from backward process
13737
+ // parallel_for iq2,iq3:
13738
+ // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur]
13739
+ // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
13740
+ // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur]
13741
+ // for iq1:
13742
+ // kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
13743
+ // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
13744
+ // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
13745
+ // S0 = -Inf [D,1,1,1]
13746
+ // ~S1[i] = dot(kcur[:D,i], qcur)
13747
+ // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
13748
+ // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
13749
+ // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
13750
+ // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
13751
+ // ~S5[i] = dot(vcur[:,i], S4)
13752
+ // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
13753
+ // ~dst[i,iq1,iq2,iq3] = S5[i] ^
13754
+ // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
13755
+ // dst backward-/ grad[dst] = d
13756
+ //
13757
+ // output gradients with their dependencies:
13758
+ //
13759
+ // grad[kcur] = grad[S1].T @ qcur
13760
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
13761
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
13762
+ // grad[S4] = grad[S5] @ vcur
13763
+ // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
13764
+ // grad[qcur] = grad[S1] @ kcur
13765
+ // grad[vcur] = grad[S5].T @ S4
13766
+ // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
13767
+ //
13768
+ // in post-order:
13769
+ //
13770
+ // S1 = qcur @ kcur.T
13771
+ // S2 = S1 * scale
13772
+ // S3 = diag_mask_inf(S2, P)
13773
+ // S4 = softmax(S3)
13774
+ // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
13775
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
13776
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
13777
+ // grad[qcur] = grad[S1] @ kcur
13778
+ // grad[kcur] = grad[S1].T @ qcur
13779
+ // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
13780
+ //
13781
+ // using less variables (SM=S4):
13782
+ //
13783
+ // S = diag_mask_inf(qcur @ kcur.T * scale, P)
13784
+ // SM = softmax(S)
13785
+ // S = d[:D,iq1,iq2,iq3] @ vcur
13786
+ // dot_SM_gradSM = dot(SM, S)
13787
+ // S = SM * (S - dot(SM, S))
13788
+ // S = diag_mask_zero(S, P) * scale
13789
+ //
13790
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
13791
+ // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
13792
+ // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
13793
+ }
13794
+
13795
+ // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
13796
+ // S = d[:D,iq1,iq2,iq3] @ vcur
13797
+ // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
13798
+ ggml_vec_set_f32(M, S, 0);
13799
+ for (int64_t ic = 0; ic < D; ++ic) {
13800
+ // dst indices
13801
+ const int i1 = iq1;
13802
+ const int i2 = iq2;
13803
+ const int i3 = iq3;
13804
+
13805
+ ggml_vec_mad_f32(M,
13806
+ S,
13807
+ (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
13808
+ *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
13809
+ }
13810
+
13811
+ // S = SM * (S - dot(SM, S))
13812
+ float dot_SM_gradSM = 0;
13813
+ ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
13814
+ ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
13815
+ ggml_vec_mul_f32 (M, S, S, SM);
13816
+
13817
+ // S = diag_mask_zero(S, P) * scale
13818
+ if (masked) {
13819
+ // for (int64_t i = P + iq1 + 1; i < M; i++) {
13820
+ // S[i] = 0;
13821
+ // }
13822
+ for (int64_t i = P; i < M; i++) {
13823
+ if (i > P + iq1) {
13824
+ S[i] = 0;
13825
+ }
13826
+ }
13827
+ }
13828
+ ggml_vec_scale_f32(M, S, scale);
13829
+
13830
+ void * grad_q = (char *) dst->data;
13831
+ void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
13832
+ void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
13833
+
13834
+ const size_t nbgq1 = nb0*neq0;
13835
+ const size_t nbgq2 = nb0*neq0*neq1;
13836
+ const size_t nbgq3 = nb0*neq0*neq1*neq2;
13837
+
13838
+ const size_t nbgk1 = nb0*nek0;
13839
+ const size_t nbgk2 = nb0*nek0*nek1;
13840
+ const size_t nbgk3 = nb0*nek0*nek1*neq2;
13841
+
13842
+ const size_t nbgv1 = nb0*nev0;
13843
+ const size_t nbgv2 = nb0*nev0*nev1;
13844
+ const size_t nbgv3 = nb0*nev0*nev1*neq2;
13845
+
13846
+ // S shape [M,1]
13847
+ // SM shape [M,1]
13848
+ // kcur shape [D,M]
13849
+ // qcur shape [D,1]
13850
+ // vcur shape [M,D]
13851
+ //
13852
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
13853
+ // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
13854
+ // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
13855
+ //
13856
+ //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
13857
+ //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
13858
+ for (int64_t ic = 0; ic < M; ++ic) {
13859
+ // dst indices
13860
+ const int i1 = iq1;
13861
+ const int i2 = iq2;
13862
+ const int i3 = iq3;
13863
+
13864
+ ggml_vec_mad_f32(D,
13865
+ (float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
13866
+ (float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)),
13867
+ S[ic]);
13868
+ }
13869
+
13870
+ // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
13871
+ // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
13872
+ // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
13873
+ for (int64_t ic = 0; ic < M; ++ic) {
13874
+ // dst indices
13875
+ const int i1 = iq1;
13876
+ const int i2 = iq2;
13877
+ const int i3 = iq3;
13878
+
13879
+ // ggml_vec_set_f32(D,
13880
+ // (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
13881
+ // 0);
13882
+ ggml_vec_mad_f32(D,
13883
+ (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
13884
+ (float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)),
13885
+ S[ic]);
13886
+ }
13887
+
13888
+ // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
13889
+ // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
13890
+ // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M]
13891
+ for (int64_t ic = 0; ic < D; ++ic) {
13892
+ // dst indices
13893
+ const int i1 = iq1;
13894
+ const int i2 = iq2;
13895
+ const int i3 = iq3;
13896
+
13897
+ // ggml_vec_set_f32(M,
13898
+ // (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
13899
+ // 0);
13900
+ ggml_vec_mad_f32(M,
13901
+ (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
13902
+ SM,
13903
+ *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
13904
+ }
13905
+ }
13906
+ }
13907
+ }
13908
+
13909
+ static void ggml_compute_forward_flash_attn_back(
13910
+ const struct ggml_compute_params * params,
13911
+ const struct ggml_tensor * q,
13912
+ const struct ggml_tensor * k,
13913
+ const struct ggml_tensor * v,
13914
+ const struct ggml_tensor * d,
13915
+ const bool masked,
13916
+ struct ggml_tensor * dst) {
13917
+ switch (q->type) {
13918
+ case GGML_TYPE_F32:
13919
+ {
13920
+ ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst);
13921
+ } break;
13922
+ default:
13923
+ {
13924
+ GGML_ASSERT(false);
13925
+ } break;
13926
+ }
13927
+ }
13928
+
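Condensing the step-by-step comment block inside the kernel, the per-query-row recipe (same notation as those comments: qcur of length D, kcur of shape [D,M], vcur of shape [M,D], upstream gradient d, scale = 1/sqrt(D)) is:

    SM      = softmax(diag_mask_inf(qcur @ kcur.T * scale, P))    // forward softmax, recomputed
    grad_SM = d @ vcur
    grad_S  = SM * (grad_SM - dot(SM, grad_SM))                   // same identity as soft_max_back
    grad_S  = diag_mask_zero(grad_S, P) * scale
    grad[q][:D,iq1,iq2,iq3] += grad_S @ kcur
    grad[k][:D,:M,iq2,iq3]  += grad_S.T @ qcur
    grad[v][:M,:D,iq2,iq3]  += d.T @ SM

Only the two scratch rows S and SM are needed per thread; the forward attention weights are recomputed from q and k rather than stored, the usual flash-attention trade of extra compute for a small, fixed scratch footprint.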
12790
13929
  // ggml_compute_forward_map_unary
12791
13930
 
12792
13931
  static void ggml_compute_forward_map_unary_f32(
@@ -12849,29 +13988,308 @@ static void ggml_compute_forward_map_binary_f32(
12849
13988
  const int n = ggml_nrows(src0);
12850
13989
  const int nc = src0->ne[0];
12851
13990
 
12852
- assert( dst->nb[0] == sizeof(float));
12853
- assert(src0->nb[0] == sizeof(float));
12854
- assert(src1->nb[0] == sizeof(float));
13991
+ assert( dst->nb[0] == sizeof(float));
13992
+ assert(src0->nb[0] == sizeof(float));
13993
+ assert(src1->nb[0] == sizeof(float));
13994
+
13995
+ for (int i = 0; i < n; i++) {
13996
+ fun(nc,
13997
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
13998
+ (float *) ((char *) src0->data + i*(src0->nb[1])),
13999
+ (float *) ((char *) src1->data + i*(src1->nb[1])));
14000
+ }
14001
+ }
14002
+
14003
+
14004
+ static void ggml_compute_forward_map_binary(
14005
+ const struct ggml_compute_params * params,
14006
+ const struct ggml_tensor * src0,
14007
+ const struct ggml_tensor * src1,
14008
+ struct ggml_tensor * dst,
14009
+ const ggml_binary_op_f32_t fun) {
14010
+ switch (src0->type) {
14011
+ case GGML_TYPE_F32:
14012
+ {
14013
+ ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
14014
+ } break;
14015
+ default:
14016
+ {
14017
+ GGML_ASSERT(false);
14018
+ } break;
14019
+ }
14020
+ }
14021
+
14022
+ // ggml_compute_forward_cross_entropy_loss
14023
+
14024
+ static void ggml_compute_forward_cross_entropy_loss_f32(
14025
+ const struct ggml_compute_params * params,
14026
+ const struct ggml_tensor * src0,
14027
+ const struct ggml_tensor * src1,
14028
+ struct ggml_tensor * dst) {
14029
+ GGML_ASSERT(ggml_is_contiguous(src0));
14030
+ GGML_ASSERT(ggml_is_contiguous(src1));
14031
+ GGML_ASSERT(ggml_is_scalar(dst));
14032
+ GGML_ASSERT(ggml_are_same_shape(src0, src1));
14033
+
14034
+ const int ith = params->ith;
14035
+ const int nth = params->nth;
14036
+
14037
+ float * sums = (float *) params->wdata;
14038
+
14039
+ // TODO: handle transposed/permuted matrices
14040
+ const int nc = src0->ne[0];
14041
+ const int nr = ggml_nrows(src0);
14042
+
14043
+ if (params->type == GGML_TASK_INIT) {
14044
+ if (ith == 0) {
14045
+ memset(sums, 0, sizeof(float) * (nth + nth * nc));
14046
+ }
14047
+ return;
14048
+ }
14049
+
14050
+ if (params->type == GGML_TASK_FINALIZE) {
14051
+ if (ith == 0) {
14052
+ float * dp = (float *) dst->data;
14053
+ ggml_vec_sum_f32(nth, dp, sums);
14054
+ dp[0] *= -1.0f;
14055
+ }
14056
+ return;
14057
+ }
14058
+
14059
+ const double eps = 1e-9;
14060
+
14061
+ // rows per thread
14062
+ const int dr = (nr + nth - 1)/nth;
14063
+
14064
+ // row range for this thread
14065
+ const int ir0 = dr*ith;
14066
+ const int ir1 = MIN(ir0 + dr, nr);
14067
+
14068
+ for (int i1 = ir0; i1 < ir1; i1++) {
14069
+ float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
14070
+ float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
14071
+ float * st = (float *) params->wdata + nth + ith*nc;
14072
+
14073
+ #ifndef NDEBUG
14074
+ for (int i = 0; i < nc; ++i) {
14075
+ //printf("p[%d] = %f\n", i, p[i]);
14076
+ assert(!isnan(s0[i]));
14077
+ assert(!isnan(s1[i]));
14078
+ }
14079
+ #endif
14080
+ // soft_max
14081
+ ggml_float sum = 0.0;
14082
+ {
14083
+ float max = -INFINITY;
14084
+ ggml_vec_max_f32(nc, &max, s0);
14085
+
14086
+ uint16_t scvt;
14087
+ for (int i = 0; i < nc; i++) {
14088
+ if (s0[i] == -INFINITY) {
14089
+ st[i] = 0.0f;
14090
+ } else {
14091
+ // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
14092
+ ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
14093
+ memcpy(&scvt, &s, sizeof(scvt));
14094
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
14095
+ sum += (ggml_float)val;
14096
+ st[i] = val;
14097
+ }
14098
+ }
14099
+
14100
+ assert(sum > 0.0);
14101
+ // sum = 1.0/sum;
14102
+ }
14103
+ // avoid log(0) by rescaling from [0..1] to [eps..1]
14104
+ sum = (1.0 - eps) / sum;
14105
+ ggml_vec_scale_f32(nc, st, sum);
14106
+ ggml_vec_add1_f32(nc, st, st, eps);
14107
+ ggml_vec_log_f32(nc, st, st);
14108
+ ggml_vec_mul_f32(nc, st, st, s1);
14109
+
14110
+ ggml_vec_sum_f32(nc, sums + ith, st);
14111
+
14112
+ #ifndef NDEBUG
14113
+ for (int i = 0; i < nc; ++i) {
14114
+ assert(!isnan(st[i]));
14115
+ assert(!isinf(st[i]));
14116
+ }
14117
+ #endif
14118
+ }
14119
+
14120
+ }
14121
+
14122
+ static void ggml_compute_forward_cross_entropy_loss(
14123
+ const struct ggml_compute_params * params,
14124
+ const struct ggml_tensor * src0,
14125
+ const struct ggml_tensor * src1,
14126
+ struct ggml_tensor * dst) {
14127
+ switch (src0->type) {
14128
+ case GGML_TYPE_F32:
14129
+ {
14130
+ ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst);
14131
+ } break;
14132
+ default:
14133
+ {
14134
+ GGML_ASSERT(false);
14135
+ } break;
14136
+ }
14137
+ }
14138
+
14139
+ // ggml_compute_forward_cross_entropy_loss_back
14140
+
14141
+ static void ggml_compute_forward_cross_entropy_loss_back_f32(
14142
+ const struct ggml_compute_params * params,
14143
+ const struct ggml_tensor * src0,
14144
+ const struct ggml_tensor * src1,
14145
+ const struct ggml_tensor * opt0,
14146
+ struct ggml_tensor * dst) {
14147
+ GGML_ASSERT(ggml_is_contiguous(dst));
14148
+ GGML_ASSERT(ggml_is_contiguous(src0));
14149
+ GGML_ASSERT(ggml_is_contiguous(src1));
14150
+ GGML_ASSERT(ggml_is_contiguous(opt0));
14151
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
14152
+
14153
+ const int64_t ith = params->ith;
14154
+ const int64_t nth = params->nth;
14155
+
14156
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
14157
+ return;
14158
+ }
14159
+
14160
+ const float eps = 1e-9f;
14161
+
14162
+ // TODO: handle transposed/permuted matrices
14163
+ const int64_t nc = src0->ne[0];
14164
+ const int64_t nr = ggml_nrows(src0);
14165
+
14166
+ // rows per thread
14167
+ const int64_t dr = (nr + nth - 1)/nth;
14168
+
14169
+ // row range for this thread
14170
+ const int64_t ir0 = dr*ith;
14171
+ const int64_t ir1 = MIN(ir0 + dr, nr);
14172
+
14173
+ float * d = (float *) opt0->data;
14174
+
14175
+ for (int64_t i1 = ir0; i1 < ir1; i1++) {
14176
+ float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
14177
+ float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
14178
+ float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
14179
+ float * sm = (float *) params->wdata + ith*nc;
14180
+
14181
+ #ifndef NDEBUG
14182
+ for (int i = 0; i < nc; ++i) {
14183
+ //printf("p[%d] = %f\n", i, p[i]);
14184
+ assert(!isnan(s0[i]));
14185
+ assert(!isnan(s1[i]));
14186
+ }
14187
+ #endif
14188
+ // step by step explanation:
14189
+ {
14190
+ //float * sums = (float *) params->wdata;
14191
+
14192
+ // forward pass with annotated gradients from backward pass
14193
+ // (built by going in reverse operation order, adding to gradients of current operation args)
14194
+ // st0 = exp(s0-max(s0)) grad[st0] = grad[st1]*(1.0 - eps)/sum
14195
+ // from softmax_back: grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
14196
+ // ggml_vec_scale_f32(nc, st, sum); // st1 = st0/sum = softmax(s0) grad[st1] = grad[st2]*(1.0 - eps)
14197
+ // ggml_vec_scale_f32(nc, st, (1.0f - eps)); // st2 = st1*(1.0 - eps) grad[st2] = grad[st3]
14198
+ // ggml_vec_add1_f32(nc, st, st, eps); // st3 = st2 + eps grad[st3] = grad[st4]/st3
14199
+ // ggml_vec_log_f32(nc, st, st); // st4 = log(st3) grad[st4] = grad[st5] * s1
14200
+ // ggml_vec_mul_f32(nc, st, st, s1); // st5 = st4 * s1 grad[st5] = grad[sums[ith]]
14201
+ // ggml_vec_sum_f32(nc, sums + ith, st); // sums[ith] = st5 grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel]
14202
+
14203
+ // substitute into grad[st1], because we can reuse softmax_back from this point on
14204
+ // grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps))
14205
+ // postorder:
14206
+ // grad[st1] := softmax(s0)
14207
+ // grad[st1] := grad[st1]*(1.0 - eps)
14208
+ // grad[st1] := grad[st1] + eps
14209
+ // grad[st1] := s1 / grad[st1]
14210
+ // grad[st1] := grad[st1]*(1.0-eps)*-grad[cel]
14211
+
14212
+ // src0 gradients by going through softmax_back
14213
+ // grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
14214
+ // from softmax_back:
14215
+ // dxk = yk * (dyk - dot(y, dy))
14216
+ // dot_y_dy := dot(y, dy)
14217
+ // dx := dy
14218
+ // dx := dx - dot_y_dy
14219
+ // dx := dx * y
14220
+ // postorder:
14221
+ // dot_st1_dst1 := dot(st1, grad[st1])
14222
+ // grad[s0] := grad[st1]
14223
+ // grad[s0] := grad[s0] - dot_st1_dst1
14224
+ // grad[s0] := grad[s0] * st1
14225
+
14226
+ // prepend postorder from grad[st1] directly using grad[s0] as memory location, since we will set grad[s0] := grad[st1]
14227
+ // sm := softmax(s0)
14228
+ // grad[s0] := sm*(1.0 - eps)
14229
+ // grad[s0] := grad[s0] + eps
14230
+ // grad[s0] := s1 / grad[s0]
14231
+ // grad[s0] := grad[s0]*(1.0-eps)*-grad[cel]
14232
+ // dot_st1_dst1 := dot(sm, grad[s0])
14233
+ // grad[s0] := grad[s0] - dot_st1_dst1
14234
+ // grad[s0] := grad[s0] * sm
14235
+ }
14236
+
14237
+ // soft_max
14238
+ ggml_float sum = 0.0;
14239
+ {
14240
+ float max = -INFINITY;
14241
+ ggml_vec_max_f32(nc, &max, s0);
14242
+
14243
+ uint16_t scvt;
14244
+ for (int i = 0; i < nc; i++) {
14245
+ if (s0[i] == -INFINITY) {
14246
+ sm[i] = 0.0f;
14247
+ } else {
14248
+ // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
14249
+ ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
14250
+ memcpy(&scvt, &s, sizeof(scvt));
14251
+ const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
14252
+ sum += (ggml_float)val;
14253
+ sm[i] = val;
14254
+ }
14255
+ }
14256
+
14257
+ assert(sum > 0.0);
14258
+ sum = 1.0/sum;
14259
+ }
12855
14260
 
12856
- for (int i = 0; i < n; i++) {
12857
- fun(nc,
12858
- (float *) ((char *) dst->data + i*( dst->nb[1])),
12859
- (float *) ((char *) src0->data + i*(src0->nb[1])),
12860
- (float *) ((char *) src1->data + i*(src1->nb[1])));
14261
+ float dot_st1_dst1 = 0;
14262
+ ggml_vec_scale_f32(nc, sm, sum);
14263
+ ggml_vec_cpy_f32 (nc, ds0, sm);
14264
+ ggml_vec_scale_f32(nc, ds0, (1.0f - eps));
14265
+ ggml_vec_add1_f32 (nc, ds0, ds0, eps);
14266
+ ggml_vec_div_f32 (nc, ds0, s1, ds0);
14267
+ ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]);
14268
+ ggml_vec_dot_f32 (nc, &dot_st1_dst1, sm, ds0);
14269
+ ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1);
14270
+ ggml_vec_mul_f32 (nc, ds0, ds0, sm);
14271
+
14272
+ #ifndef NDEBUG
14273
+ for (int i = 0; i < nc; ++i) {
14274
+ assert(!isnan(sm[i]));
14275
+ assert(!isinf(sm[i]));
14276
+ assert(!isnan(ds0[i]));
14277
+ assert(!isinf(ds0[i]));
14278
+ }
14279
+ #endif
12861
14280
  }
12862
14281
  }
12863
14282
 
12864
-
12865
- static void ggml_compute_forward_map_binary(
14283
+ static void ggml_compute_forward_cross_entropy_loss_back(
12866
14284
  const struct ggml_compute_params * params,
12867
14285
  const struct ggml_tensor * src0,
12868
14286
  const struct ggml_tensor * src1,
12869
- struct ggml_tensor * dst,
12870
- const ggml_binary_op_f32_t fun) {
14287
+ const struct ggml_tensor * opt0,
14288
+ struct ggml_tensor * dst) {
12871
14289
  switch (src0->type) {
12872
14290
  case GGML_TYPE_F32:
12873
14291
  {
12874
- ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
14292
+ ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst);
12875
14293
  } break;
12876
14294
  default:
12877
14295
  {
@@ -12880,11 +14298,21 @@ static void ggml_compute_forward_map_binary(
12880
14298
  }
12881
14299
  }
12882
14300
 
14301
+
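Following the annotated derivation in the comments, the value the backward kernel leaves in each row of dst can be summarized as (p = softmax(s0), t = s1, d_0 the incoming scalar gradient stored in opt0):

    g_i = \frac{-(1-\epsilon)\, d_0\, t_i}{\epsilon + (1-\epsilon)\, p_i},
    \qquad
    \frac{\partial L}{\partial s_{0,i}} = p_i\Big(g_i - \sum_j p_j\, g_j\Big)

i.e. the gradient of the eps-smoothed log is pushed through the same softmax-backward identity used by ggml_compute_forward_soft_max_back_f32, with sm holding p and ds0 holding g in the vector-op sequence above.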
12883
14302
  /////////////////////////////////
12884
14303
 
12885
14304
  static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
12886
14305
  GGML_ASSERT(params);
12887
14306
 
14307
+ #ifdef GGML_USE_CUBLAS
14308
+ bool skip_cpu = ggml_cuda_compute_forward(params, tensor);
14309
+ if (skip_cpu) {
14310
+ return;
14311
+ }
14312
+ GGML_ASSERT(tensor->src0->backend == GGML_BACKEND_CPU);
14313
+ GGML_ASSERT(tensor->src1 == NULL || tensor->src1->backend == GGML_BACKEND_CPU);
14314
+ #endif // GGML_USE_CUBLAS
14315
+
12888
14316
  switch (tensor->op) {
12889
14317
  case GGML_OP_DUP:
12890
14318
  {
@@ -12942,6 +14370,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
12942
14370
  {
12943
14371
  ggml_compute_forward_repeat(params, tensor->src0, tensor);
12944
14372
  } break;
14373
+ case GGML_OP_REPEAT_BACK:
14374
+ {
14375
+ ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
14376
+ } break;
12945
14377
  case GGML_OP_ABS:
12946
14378
  {
12947
14379
  ggml_compute_forward_abs(params, tensor->src0, tensor);
@@ -12990,6 +14422,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
12990
14422
  {
12991
14423
  ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
12992
14424
  } break;
14425
+ case GGML_OP_OUT_PROD:
14426
+ {
14427
+ ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
14428
+ } break;
12993
14429
  case GGML_OP_SCALE:
12994
14430
  {
12995
14431
  ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
@@ -13046,6 +14482,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13046
14482
  {
13047
14483
  ggml_compute_forward_soft_max(params, tensor->src0, tensor);
13048
14484
  } break;
14485
+ case GGML_OP_SOFT_MAX_BACK:
14486
+ {
14487
+ ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
14488
+ } break;
13049
14489
  case GGML_OP_ROPE:
13050
14490
  {
13051
14491
  ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
@@ -13081,6 +14521,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13081
14521
  {
13082
14522
  ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
13083
14523
  } break;
14524
+ case GGML_OP_FLASH_ATTN_BACK:
14525
+ {
14526
+ int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
14527
+ GGML_ASSERT(t == 0 || t == 1);
14528
+ bool masked = t != 0;
14529
+ ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
14530
+ } break;
13084
14531
  case GGML_OP_MAP_UNARY:
13085
14532
  {
13086
14533
  const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
@@ -13093,6 +14540,16 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
13093
14540
  ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
13094
14541
  }
13095
14542
  break;
14543
+ case GGML_OP_CROSS_ENTROPY_LOSS:
14544
+ {
14545
+ ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
14546
+ }
14547
+ break;
14548
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
14549
+ {
14550
+ ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
14551
+ }
14552
+ break;
13096
14553
  case GGML_OP_NONE:
13097
14554
  {
13098
14555
  // nop
@@ -13231,11 +14688,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13231
14688
  src0->grad =
13232
14689
  ggml_add_impl(ctx,
13233
14690
  src0->grad,
13234
- ggml_mul(ctx,
13235
- tensor->grad, // this was not catched by test_grad because in test_grad tensor->grad is 1
14691
+ ggml_scale(ctx,
13236
14692
  ggml_div(ctx,
13237
- ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor),
13238
- tensor)),
14693
+ tensor->grad,
14694
+ tensor),
14695
+ ggml_new_f32(ctx, 0.5f)),
13239
14696
  inplace);
13240
14697
  }
13241
14698
  } break;
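The rewritten expression here is the square-root backward (the 0.5 * grad / tensor pattern identifies it; the case label itself falls outside this hunk). Both the old and the new form compute the same derivative:

    y = \sqrt{x} \;\Rightarrow\; \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y}\cdot\frac{1}{2\sqrt{x}} = \frac{1}{2y}\cdot\frac{\partial L}{\partial y}

but ggml_scale(ggml_div(tensor->grad, tensor), 0.5) gets there without building and repeating a constant 0.5 tensor.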
@@ -13281,43 +14738,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13281
14738
  {
13282
14739
  // necessary for llama
13283
14740
  if (src0->grad) {
13284
- GGML_ASSERT(src0->n_dims == 1 || src0->n_dims == 2);
13285
- const int nc = tensor->ne[0];
13286
- const int nr = tensor->ne[1];
13287
- const int nc0 = src0->ne[0];
13288
- const int nr0 = src0->ne[1];
13289
- const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat
13290
- const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat
13291
- // tensor->grad [nc,nr,1,1]
13292
- // reshape [nc0,nc/nc0,nr0,nr/nr0]
13293
- // permute [nc0,nr0,nc/nc0,nr/nr0]
13294
- // substitute [nc0,nr0,ncr,nrr]
13295
- // reshape [nc0*nr0,ncr*nrr,1,1]
13296
- // transpose [ncr*nrr,nc0*nr0,1,1]
13297
- // sum rows [1,nc0*nr0,1,1]
13298
- // transpose [nc0*nr0,1,1]
13299
- // reshape [nc0,nr0,1,1] reshape_1d or reshape_2d
13300
- // add to src0->grad
13301
-
13302
- int64_t ne[4] = {nc0,ncr,nr0,nrr};
13303
-
13304
- struct ggml_tensor* F00 = tensor->grad;
13305
- struct ggml_tensor* F01 = ggml_reshape (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne));
13306
- struct ggml_tensor* F02 = ggml_permute (ctx, F01, 0,2,1,3);
13307
- struct ggml_tensor* F03 = ggml_cont (ctx, F02);
13308
- struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr);
13309
- struct ggml_tensor* F05 = ggml_transpose (ctx, F04);
13310
- struct ggml_tensor* F06 = ggml_cont (ctx, F05);
13311
- struct ggml_tensor* F07 = ggml_sum_rows (ctx, F06);
13312
- struct ggml_tensor* F08 = ggml_transpose (ctx, F07);
13313
- struct ggml_tensor* F09 = ggml_cont (ctx, F08);
13314
- struct ggml_tensor* F10 = ggml_reshape (ctx, F09, src0->grad);
13315
-
13316
- src0->grad =
13317
- ggml_add_impl(ctx,
13318
- src0->grad,
13319
- F10,
13320
- inplace);
14741
+ src0->grad = ggml_add_impl(ctx,
14742
+ src0->grad,
14743
+ ggml_repeat_back(ctx, tensor->grad, src0->grad),
14744
+ inplace);
14745
+ }
14746
+ } break;
14747
+ case GGML_OP_REPEAT_BACK:
14748
+ {
14749
+ if (src0->grad) {
14750
+ // TODO: test this
14751
+ src0->grad = ggml_add_impl(ctx,
14752
+ src0->grad,
14753
+ ggml_repeat(ctx, tensor->grad, src0->grad),
14754
+ inplace);
13321
14755
  }
13322
14756
  } break;
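GGML_OP_REPEAT's backward now delegates to the new ggml_repeat_back op instead of the removed reshape/permute/sum-rows chain, and GGML_OP_REPEAT_BACK's own backward (still marked "TODO: test this") is plain ggml_repeat, since the two ops are adjoint. As a minimal sketch of what the adjoint has to do, assuming contiguous row-major float data and 2-D tensors (repeat_back_2d_sketch is a hypothetical helper for illustration, not a ggml function):

    // Every repeated copy of a source element sends its gradient back to the
    // one source element it was copied from.
    // src_grad is [nc0 x nr0], grad is [nc0*ncr x nr0*nrr], both dense row-major.
    static void repeat_back_2d_sketch(float * src_grad, const float * grad,
                                      int nc0, int nr0, int ncr, int nrr) {
        const int nc = nc0*ncr;
        const int nr = nr0*nrr;
        for (int r = 0; r < nr; ++r) {
            for (int c = 0; c < nc; ++c) {
                src_grad[(r % nr0)*nc0 + (c % nc0)] += grad[r*nc + c];
            }
        }
    }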
13323
14757
  case GGML_OP_ABS:
@@ -13424,38 +14858,37 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13424
14858
 
13425
14859
  // necessary for llama
13426
14860
  if (src0->grad) {
13427
- // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad);
13428
14861
  src0->grad =
13429
14862
  ggml_add_impl(ctx,
13430
14863
  src0->grad,
13431
- // ds0 = dt.dot(s1.T)
13432
- // ggml_out_prod(ctx, // [n,m]
13433
- // src1, // [n,p]
13434
- // tensor->grad), // [m,p]
13435
- // for now just using A*B==(B.T*A.T).T
13436
- ggml_cont(ctx, // [n,m]
13437
- ggml_transpose(ctx, // [n,m]
13438
- ggml_mul_mat(ctx, // [m,n]
13439
- ggml_cont(ctx, // [p,m]
13440
- ggml_transpose(ctx, // [p,m]
13441
- tensor->grad)), // [m,p]
13442
- ggml_cont(ctx, // [p,n]
13443
- ggml_transpose(ctx, // [p,n]
13444
- src1))))), // [n,p]
14864
+ ggml_out_prod(ctx, // [n,m]
14865
+ src1, // [n,p]
14866
+ tensor->grad), // [m,p]
13445
14867
  inplace);
13446
14868
  }
13447
14869
  if (src1->grad) {
13448
14870
  src1->grad =
13449
14871
  ggml_add_impl(ctx,
13450
14872
  src1->grad,
13451
- // ds1 = s0.T.dot(dt):
13452
- ggml_mul_mat(ctx, // [n,p]
13453
- ggml_cont(ctx, // [m,n]
13454
- ggml_transpose(ctx, src0)), // [m,n]
13455
- tensor->grad), // [m,p]
14873
+ // ggml_mul_mat(ctx, // [n,p]
14874
+ // ggml_cont(ctx, // [m,n]
14875
+ // ggml_transpose(ctx, src0)), // [m,n]
14876
+ // tensor->grad), // [m,p]
14877
+
14878
+ // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
14879
+ // // avoid transpose of src0, rather transpose smaller tensor->grad
14880
+ // // and then use ggml_out_prod
14881
+ ggml_out_prod(ctx, // [n,p]
14882
+ src0, // [n,m]
14883
+ ggml_transpose(ctx, // [p,m]
14884
+ tensor->grad)), // [m,p]
13456
14885
  inplace);
13457
14886
  }
13458
14887
  } break;
14888
+ case GGML_OP_OUT_PROD:
14889
+ {
14890
+ GGML_ASSERT(false); // TODO: not implemented
14891
+ } break;
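The matrix-multiplication backward now uses the new GGML_OP_OUT_PROD instead of emulating outer products with cont/transpose/mul_mat. In plain matrix notation (leaving aside ggml's storage convention for mul_mat operands), for C = A B with upstream gradient G = dL/dC:

    \frac{\partial L}{\partial A} = G\,B^{\top}, \qquad \frac{\partial L}{\partial B} = A^{\top} G

Both right-hand sides are outer-product shaped, which is what ggml_out_prod supplies directly: src0->grad accumulates ggml_out_prod(src1, tensor->grad), and src1->grad accumulates ggml_out_prod(src0, ggml_transpose(tensor->grad)) so that the transpose lands on the small gradient rather than on src0, which in llama is usually the larger operand (the previous mul_mat formulation stays behind as a comment for reference). GGML_OP_OUT_PROD itself has no backward yet, hence the GGML_ASSERT(false) placeholder.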
13459
14892
  case GGML_OP_SCALE:
13460
14893
  {
13461
14894
  // necessary for llama
@@ -13557,7 +14990,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13557
14990
  // necessary for llama
13558
14991
  if (src0->grad) {
13559
14992
  size_t offset;
13560
- memcpy(&offset, tensor->padding, sizeof(offset));
14993
+
14994
+ GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0]));
14995
+ memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
13561
14996
 
13562
14997
  size_t nb1 = tensor->nb[1];
13563
14998
  size_t nb2 = tensor->nb[2];
@@ -13584,10 +15019,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13584
15019
  {
13585
15020
  // necessary for llama
13586
15021
  if (src0->grad) {
13587
- int axis0 = tensor->padding[0] & 0x3;
13588
- int axis1 = tensor->padding[1] & 0x3;
13589
- int axis2 = tensor->padding[2] & 0x3;
13590
- int axis3 = tensor->padding[3] & 0x3;
15022
+ int32_t * axes = (int32_t *) tensor->opt[0]->data;
15023
+ int axis0 = axes[0] & 0x3;
15024
+ int axis1 = axes[1] & 0x3;
15025
+ int axis2 = axes[2] & 0x3;
15026
+ int axis3 = axes[3] & 0x3;
13591
15027
  int axes_backward[4] = {0,0,0,0};
13592
15028
  axes_backward[axis0] = 0;
13593
15029
  axes_backward[axis1] = 1;
@@ -13671,50 +15107,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13671
15107
  {
13672
15108
  // necessary for llama
13673
15109
  if (src0->grad) {
13674
- // y = softmax(x)
13675
- //
13676
- // Jii = yi - yi*yi
13677
- // Jij = -yi*yj
13678
- // J = diag(y)-y.*y
13679
- // dx = J * dy
13680
- // dxk = sum(Jkj * dyk)
13681
-
13682
- int64_t ne2[4] = {
13683
- tensor->ne[0],
13684
- 1,
13685
- tensor->ne[1]*tensor->ne[2],
13686
- tensor->ne[3]
13687
- };
13688
- struct ggml_tensor * tensor2 = ggml_cont(ctx,
13689
- ggml_reshape_4d(ctx,
13690
- ggml_cont(ctx, tensor),
13691
- ne2[0], ne2[1], ne2[2], ne2[3]));
13692
-
13693
- struct ggml_tensor * grad2 = ggml_cont(ctx,
13694
- ggml_reshape_4d(ctx,
13695
- ggml_cont(ctx, tensor->grad),
13696
- ne2[0], ne2[1], ne2[2], ne2[3]));
13697
-
13698
- struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3]
13699
- ggml_permute(ctx, // [1,ne0,ne1*ne2,ne3]
13700
- tensor2, // [ne0,1,ne1*ne2,ne3]
13701
- 1, 0, 2, 3));
13702
-
13703
15110
  src0->grad =
13704
- ggml_add_impl(ctx,
13705
- src0->grad, // [ne0,ne1,ne2,ne3]
13706
- ggml_reshape(ctx, // [ne0,ne1,ne2,ne3]
13707
- ggml_mul_mat(ctx, // [ne0,1,ne1*ne2,ne3]
13708
- ggml_sub(ctx, // [ne0,ne0,ne1*ne2,ne3]
13709
- ggml_diag(ctx, // [ne0,ne0,ne1*ne2,ne3]
13710
- tensor2), // [ne0,1,ne1*ne2,ne3]
13711
- ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3]
13712
- tensor2_t, // [1,ne0,ne1*ne2,ne3]
13713
- tensor2_t)), // [1,ne0,ne1*ne2,ne3]
13714
- grad2), // [ne0,1,ne1*ne2,ne3]
13715
- src0->grad),
13716
- inplace);
15111
+ ggml_add_impl(ctx, src0->grad,
15112
+ ggml_soft_max_back(ctx, tensor->grad, tensor),
15113
+ inplace);
13717
15114
  }
15115
+
15116
+ } break;
15117
+ case GGML_OP_SOFT_MAX_BACK:
15118
+ {
15119
+ GGML_ASSERT(false); // TODO: not implemented
13718
15120
  } break;
13719
15121
  case GGML_OP_ROPE:
13720
15122
  {
@@ -13769,17 +15171,190 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
13769
15171
  } break;
13770
15172
  case GGML_OP_FLASH_ATTN:
13771
15173
  {
13772
- GGML_ASSERT(false); // not supported
15174
+ struct ggml_tensor * flash_grad = NULL;
15175
+ if (src0->grad || src1->grad || tensor->opt[0]->grad) {
15176
+ int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
15177
+ GGML_ASSERT(t == 0 || t == 1);
15178
+ bool masked = t != 0;
15179
+ flash_grad =
15180
+ ggml_flash_attn_back(ctx,
15181
+ src0,
15182
+ src1,
15183
+ tensor->opt[0],
15184
+ tensor->grad,
15185
+ masked);
15186
+ }
15187
+
15188
+ if (src0->grad) {
15189
+ struct ggml_tensor * grad_q = NULL;
15190
+ const size_t nb0 = flash_grad->nb[0];
15191
+ const size_t offset = 0;
15192
+ switch(src0->n_dims) {
15193
+ case 2:
15194
+ {
15195
+ grad_q = ggml_view_2d(ctx,
15196
+ flash_grad,
15197
+ src0->ne[0],
15198
+ src0->ne[1],
15199
+ nb0*src0->ne[0],
15200
+ offset);
15201
+ } break;
15202
+ case 3:
15203
+ {
15204
+ grad_q = ggml_view_3d(ctx,
15205
+ flash_grad,
15206
+ src0->ne[0],
15207
+ src0->ne[1],
15208
+ src0->ne[2],
15209
+ nb0*src0->ne[0],
15210
+ nb0*src0->ne[0]*src0->ne[1],
15211
+ offset);
15212
+ } break;
15213
+ case 4:
15214
+ {
15215
+ grad_q = ggml_view_4d(ctx,
15216
+ flash_grad,
15217
+ src0->ne[0],
15218
+ src0->ne[1],
15219
+ src0->ne[2],
15220
+ src0->ne[3],
15221
+ nb0*src0->ne[0],
15222
+ nb0*src0->ne[0]*src0->ne[1],
15223
+ nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
15224
+ offset);
15225
+ } break;
15226
+ }
15227
+
15228
+ src0->grad = ggml_add_impl(ctx,
15229
+ src0->grad,
15230
+ grad_q,
15231
+ inplace);
15232
+ }
15233
+
15234
+ if (src1->grad) {
15235
+ struct ggml_tensor * grad_k = NULL;
15236
+ const size_t nb0 = flash_grad->nb[0];
15237
+ const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
15238
+ switch(src1->n_dims) {
15239
+ case 2:
15240
+ {
15241
+ grad_k = ggml_view_2d(ctx,
15242
+ flash_grad,
15243
+ src1->ne[0],
15244
+ src1->ne[1],
15245
+ nb0*src1->ne[0],
15246
+ offset);
15247
+ } break;
15248
+ case 3:
15249
+ {
15250
+ grad_k = ggml_view_3d(ctx,
15251
+ flash_grad,
15252
+ src1->ne[0],
15253
+ src1->ne[1],
15254
+ src1->ne[2],
15255
+ nb0*src1->ne[0],
15256
+ nb0*src1->ne[0]*src1->ne[1],
15257
+ offset);
15258
+ } break;
15259
+ case 4:
15260
+ {
15261
+ grad_k = ggml_view_4d(ctx,
15262
+ flash_grad,
15263
+ src1->ne[0],
15264
+ src1->ne[1],
15265
+ src1->ne[2],
15266
+ src1->ne[3],
15267
+ nb0*src1->ne[0],
15268
+ nb0*src1->ne[0]*src1->ne[1],
15269
+ nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
15270
+ offset);
15271
+ } break;
15272
+ }
15273
+
15274
+ src1->grad = ggml_add_impl(ctx,
15275
+ src1->grad,
15276
+ grad_k,
15277
+ inplace);
15278
+ }
15279
+
15280
+ struct ggml_tensor * opt0 = tensor->opt[0];
15281
+
15282
+ if (opt0->grad) {
15283
+ struct ggml_tensor * grad_v = NULL;
15284
+ const size_t nb0 = flash_grad->nb[0];
15285
+ const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
15286
+ + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
15287
+ switch(opt0->n_dims) {
15288
+ case 2:
15289
+ {
15290
+ grad_v = ggml_view_2d(ctx,
15291
+ flash_grad,
15292
+ opt0->ne[0],
15293
+ opt0->ne[1],
15294
+ nb0*opt0->ne[0],
15295
+ offset);
15296
+ } break;
15297
+ case 3:
15298
+ {
15299
+ grad_v = ggml_view_3d(ctx,
15300
+ flash_grad,
15301
+ opt0->ne[0],
15302
+ opt0->ne[1],
15303
+ opt0->ne[2],
15304
+ nb0*opt0->ne[0],
15305
+ nb0*opt0->ne[0]*opt0->ne[1],
15306
+ offset);
15307
+ } break;
15308
+ case 4:
15309
+ {
15310
+ grad_v = ggml_view_4d(ctx,
15311
+ flash_grad,
15312
+ opt0->ne[0],
15313
+ opt0->ne[1],
15314
+ opt0->ne[2],
15315
+ opt0->ne[3],
15316
+ nb0*opt0->ne[0],
15317
+ nb0*opt0->ne[0]*opt0->ne[1],
15318
+ nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
15319
+ offset);
15320
+ } break;
15321
+ }
15322
+
15323
+ opt0->grad = ggml_add_impl(ctx,
15324
+ opt0->grad,
15325
+ grad_v,
15326
+ inplace);
15327
+ }
13773
15328
  } break;
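ggml_flash_attn_back returns the three gradients packed into a single tensor, and the views above just carve that buffer up. Writing nelem(t) for the product of a tensor's four ne[] values and using the element size nb0, the layout assumed by the offsets is:

    offset(grad_q) = 0
    offset(grad_k) = nb0 * nelem(src0)                     // right after grad_q
    offset(grad_v) = nb0 * (nelem(src0) + nelem(src1))     // right after grad_k

which matches the grad_q/grad_k/grad_v pointers computed inside ggml_compute_forward_flash_attn_back_f32 (there the same offsets appear as nb0*D*N*neq2*neq3 and nb0*D*M*neq2*neq3, with D the head size, N the query rows and M the key rows).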
13774
15329
  case GGML_OP_FLASH_FF:
13775
15330
  {
13776
15331
  GGML_ASSERT(false); // not supported
13777
15332
  } break;
15333
+ case GGML_OP_FLASH_ATTN_BACK:
15334
+ {
15335
+ GGML_ASSERT(false); // not supported
15336
+ } break;
13778
15337
  case GGML_OP_MAP_UNARY:
13779
15338
  case GGML_OP_MAP_BINARY:
13780
15339
  {
13781
15340
  GGML_ASSERT(false); // not supported
13782
15341
  } break;
15342
+ case GGML_OP_CROSS_ENTROPY_LOSS:
15343
+ {
15344
+ if (src0->grad) {
15345
+ src0->grad = ggml_add_impl(ctx,
15346
+ src0->grad,
15347
+ ggml_cross_entropy_loss_back(ctx,
15348
+ src0,
15349
+ src1,
15350
+ tensor->grad),
15351
+ inplace);
15352
+ }
15353
+ } break;
15354
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
15355
+ {
15356
+ GGML_ASSERT(false); // not supported
15357
+ } break;
13783
15358
  case GGML_OP_NONE:
13784
15359
  {
13785
15360
  // nop
@@ -14156,6 +15731,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14156
15731
  case GGML_OP_SUM_ROWS:
14157
15732
  case GGML_OP_MEAN:
14158
15733
  case GGML_OP_REPEAT:
15734
+ case GGML_OP_REPEAT_BACK:
14159
15735
  case GGML_OP_ABS:
14160
15736
  case GGML_OP_SGN:
14161
15737
  case GGML_OP_NEG:
@@ -14175,6 +15751,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14175
15751
  node->n_tasks = n_threads;
14176
15752
  } break;
14177
15753
  case GGML_OP_MUL_MAT:
15754
+ case GGML_OP_OUT_PROD:
14178
15755
  {
14179
15756
  node->n_tasks = n_threads;
14180
15757
 
@@ -14191,7 +15768,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14191
15768
  if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
14192
15769
  node->n_tasks = 1; // TODO: this actually is doing nothing
14193
15770
  // the threads are still spinning
14194
- cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
14195
15771
  }
14196
15772
  else
14197
15773
  #elif defined(GGML_USE_CLBLAST)
@@ -14258,6 +15834,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14258
15834
  } break;
14259
15835
  case GGML_OP_DIAG_MASK_INF:
14260
15836
  case GGML_OP_SOFT_MAX:
15837
+ case GGML_OP_SOFT_MAX_BACK:
14261
15838
  case GGML_OP_ROPE:
14262
15839
  case GGML_OP_ROPE_BACK:
14263
15840
  {
@@ -14337,6 +15914,27 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14337
15914
  cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
14338
15915
  }
14339
15916
 
15917
+ work_size = MAX(work_size, cur);
15918
+ } break;
15919
+ case GGML_OP_FLASH_ATTN_BACK:
15920
+ {
15921
+ node->n_tasks = n_threads;
15922
+
15923
+ size_t cur = 0;
15924
+
15925
+ const int64_t D = node->src0->ne[0];
15926
+ const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
15927
+ const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
15928
+ if (node->src1->type == GGML_TYPE_F32) {
15929
+ cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
15930
+ cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
15931
+ }
15932
+
15933
+ if (node->src1->type == GGML_TYPE_F16) {
15934
+ cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
15935
+ cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
15936
+ }
15937
+
14340
15938
  work_size = MAX(work_size, cur);
14341
15939
  } break;
14342
15940
  case GGML_OP_MAP_UNARY:
@@ -14344,6 +15942,22 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14344
15942
  {
14345
15943
  node->n_tasks = 1;
14346
15944
  } break;
15945
+ case GGML_OP_CROSS_ENTROPY_LOSS:
15946
+ {
15947
+ node->n_tasks = n_threads;
15948
+
15949
+ size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
15950
+
15951
+ work_size = MAX(work_size, cur);
15952
+ } break;
15953
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
15954
+ {
15955
+ node->n_tasks = n_threads;
15956
+
15957
+ size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
15958
+
15959
+ work_size = MAX(work_size, cur);
15960
+ } break;
14347
15961
  case GGML_OP_NONE:
14348
15962
  {
14349
15963
  node->n_tasks = 1;
@@ -14581,7 +16195,7 @@ static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fou
14581
16195
  const int64_t * ne = tensor->ne;
14582
16196
  const size_t * nb = tensor->nb;
14583
16197
 
14584
- fprintf(fout, "%-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %16p %16s\n",
16198
+ fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
14585
16199
  ggml_type_name(tensor->type),
14586
16200
  ggml_op_name (tensor->op),
14587
16201
  tensor->n_dims,
@@ -14595,7 +16209,7 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
14595
16209
  const int64_t * ne = tensor->ne;
14596
16210
  const size_t * nb = tensor->nb;
14597
16211
 
14598
- fprintf(fout, "%-6s %-6s %-12s %8d %8lld %8lld %8lld %8lld %16zu %16zu %16zu %16zu %8d %16p %16s\n",
16212
+ fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n",
14599
16213
  arg,
14600
16214
  ggml_type_name(tensor->type),
14601
16215
  ggml_op_name (tensor->op),
@@ -14608,8 +16222,8 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
14608
16222
  }
14609
16223
 
14610
16224
  void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
14611
- assert(cgraph->work == NULL);
14612
- assert(cgraph->work_size == 0);
16225
+ //assert(cgraph->work == NULL);
16226
+ //assert(cgraph->work_size == 0);
14613
16227
 
14614
16228
  uint64_t size_eval = 0;
14615
16229
 
@@ -14624,11 +16238,11 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
14624
16238
  FILE * fout = stdout;
14625
16239
 
14626
16240
  fprintf(fout, "\n");
14627
- fprintf(fout, "%-16s %8x\n", "magic", GGML_FILE_MAGIC);
14628
- fprintf(fout, "%-16s %8d\n", "version", GGML_FILE_VERSION);
14629
- fprintf(fout, "%-16s %8d\n", "leafs", cgraph->n_leafs);
14630
- fprintf(fout, "%-16s %8d\n", "nodes", cgraph->n_nodes);
14631
- fprintf(fout, "%-16s %8llu\n", "eval", size_eval);
16241
+ fprintf(fout, "%-16s %8x\n", "magic", GGML_FILE_MAGIC);
16242
+ fprintf(fout, "%-16s %8d\n", "version", GGML_FILE_VERSION);
16243
+ fprintf(fout, "%-16s %8d\n", "leafs", cgraph->n_leafs);
16244
+ fprintf(fout, "%-16s %8d\n", "nodes", cgraph->n_nodes);
16245
+ fprintf(fout, "%-16s %" PRIu64 "\n", "eval", size_eval);
14632
16246
 
14633
16247
  // header
14634
16248
  fprintf(fout, "\n");
@@ -14830,7 +16444,6 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
14830
16444
  // read file into data
14831
16445
  {
14832
16446
  FILE * fin = fopen(fname, "rb");
14833
-
14834
16447
  if (!fin) {
14835
16448
  fprintf(stderr, "%s: failed to open %s\n", __func__, fname);
14836
16449
  return result;
@@ -14862,7 +16475,11 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
14862
16475
 
14863
16476
  data = ggml_new_tensor_1d(*ctx_data, GGML_TYPE_I8, fsize);
14864
16477
 
14865
- fread(data->data, sizeof(char), fsize, fin);
16478
+ const size_t ret = fread(data->data, sizeof(char), fsize, fin);
16479
+ if (ret != fsize) {
16480
+ fprintf(stderr, "%s: failed to read %s\n", __func__, fname);
16481
+ return result;
16482
+ }
14866
16483
 
14867
16484
  fclose(fin);
14868
16485
  }
@@ -14970,6 +16587,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
14970
16587
  op = *(const uint32_t *) ptr; ptr += sizeof(op);
14971
16588
  n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims);
14972
16589
 
16590
+ enum ggml_op eop = (enum ggml_op) op;
16591
+
14973
16592
  int64_t ne[GGML_MAX_DIMS];
14974
16593
  size_t nb[GGML_MAX_DIMS];
14975
16594
 
@@ -14984,42 +16603,77 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
14984
16603
  nb[j] = nb_cur;
14985
16604
  }
14986
16605
 
14987
- struct ggml_tensor * tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne);
14988
-
14989
- tensor->op = (enum ggml_op) op;
16606
+ uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur); // TODO: not yet used
14990
16607
 
14991
- uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur);
16608
+ const char * ptr_name = ptr; ptr += GGML_MAX_NAME;
14992
16609
 
14993
- memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME;
16610
+ const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += (2 + GGML_MAX_OPT)*sizeof(int32_t);
14994
16611
 
14995
- for (int j = 0; j < GGML_MAX_DIMS; ++j) {
14996
- tensor->nb[j] = nb[j];
14997
- }
16612
+ struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };
14998
16613
 
14999
16614
  // parse args
15000
- {
15001
- struct ggml_tensor ** args[2 + GGML_MAX_OPT] = {
15002
- &tensor->src0,
15003
- &tensor->src1,
15004
- };
16615
+ for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
16616
+ const int32_t arg_idx = ptr_arg_idx[j];
15005
16617
 
15006
- for (int j = 0; j < GGML_MAX_OPT; ++j) {
15007
- args[2 + j] = &tensor->opt[j];
16618
+ if (arg_idx == -1) {
16619
+ continue;
15008
16620
  }
15009
16621
 
15010
- for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
15011
- const int32_t arg_idx = *(const int32_t *) ptr; ptr += sizeof(arg_idx);
16622
+ if (arg_idx < GGML_MAX_NODES) {
16623
+ args[j] = result.leafs[arg_idx];
16624
+ } else {
16625
+ args[j] = result.nodes[arg_idx - GGML_MAX_NODES];
16626
+ }
16627
+ }
15012
16628
 
15013
- if (arg_idx == -1) {
15014
- continue;
15015
- }
16629
+ // create the tensor
16630
+ // "view" operations are handled differently
16631
+ // TODO: handle inplace ops - currently a copy is always made
16632
+
16633
+ struct ggml_tensor * tensor = NULL;
16634
+
16635
+ switch (eop) {
16636
+ // TODO: implement other view ops
16637
+ case GGML_OP_RESHAPE:
16638
+ {
16639
+ tensor = ggml_reshape_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3]);
16640
+ } break;
16641
+ case GGML_OP_VIEW:
16642
+ {
16643
+ tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
16644
+
16645
+ uint64_t offs;
16646
+ memcpy(&offs, args[2]->data, sizeof(offs));
16647
+
16648
+ tensor->data = ((char *) tensor->data) + offs;
16649
+ } break;
16650
+ case GGML_OP_TRANSPOSE:
16651
+ {
16652
+ tensor = ggml_transpose(*ctx_eval, args[0]);
16653
+ } break;
16654
+ case GGML_OP_PERMUTE:
16655
+ {
16656
+ tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
16657
+ } break;
16658
+ default:
16659
+ {
16660
+ tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne);
16661
+
16662
+ tensor->op = eop;
16663
+ } break;
16664
+ }
15016
16665
 
15017
- if (arg_idx < GGML_MAX_NODES) {
15018
- *args[j] = result.leafs[arg_idx];
15019
- } else {
15020
- *args[j] = result.nodes[arg_idx - GGML_MAX_NODES];
15021
- }
15022
- }
16666
+ memcpy(tensor->name, ptr_name, GGML_MAX_NAME);
16667
+
16668
+ for (int j = 0; j < GGML_MAX_DIMS; ++j) {
16669
+ tensor->nb[j] = nb[j];
16670
+ }
16671
+
16672
+ tensor->src0 = args[0];
16673
+ tensor->src1 = args[1];
16674
+
16675
+ for (int j = 0; j < GGML_MAX_OPT; ++j) {
16676
+ tensor->opt[j] = args[2 + j];
15023
16677
  }
15024
16678
 
15025
16679
  result.nodes[i] = tensor;
@@ -15279,6 +16933,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
15279
16933
 
15280
16934
  static enum ggml_opt_result ggml_opt_adam(
15281
16935
  struct ggml_context * ctx,
16936
+ struct ggml_opt_context * opt,
15282
16937
  struct ggml_opt_params params,
15283
16938
  struct ggml_tensor * f,
15284
16939
  struct ggml_cgraph * gf,
@@ -15304,25 +16959,29 @@ static enum ggml_opt_result ggml_opt_adam(
15304
16959
  }
15305
16960
  }
15306
16961
 
16962
+ if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) {
16963
+ int iter = opt->iter;
16964
+ ggml_opt_init(opt->ctx, opt, params, nx);
16965
+ opt->iter = iter;
16966
+ }
16967
+
15307
16968
  // constants
15308
- const float alpha = params.adam.alpha;
16969
+ const float sched = params.adam.sched;
16970
+ const float decay = params.adam.decay * sched;
16971
+ const float alpha = params.adam.alpha * sched;
15309
16972
  const float beta1 = params.adam.beta1;
15310
16973
  const float beta2 = params.adam.beta2;
15311
16974
  const float eps = params.adam.eps;
15312
16975
 
15313
- float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // view of the parameters
15314
- float * g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient
15315
- float * g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient squared
15316
- float * m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment
15317
- float * v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment
15318
- float * mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment hat
15319
- float * vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment hat
15320
-
15321
- float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
16976
+ float * x = opt->adam.x->data; // view of the parameters
16977
+ float * g1 = opt->adam.g1->data; // gradient
16978
+ float * g2 = opt->adam.g2->data; // gradient squared
16979
+ float * m = opt->adam.m->data; // first moment
16980
+ float * v = opt->adam.v->data; // second moment
16981
+ float * mh = opt->adam.mh->data; // first moment hat
16982
+ float * vh = opt->adam.vh->data; // second moment hat
15322
16983
 
15323
- // initialize
15324
- ggml_vec_set_f32(nx, m, 0.0f);
15325
- ggml_vec_set_f32(nx, v, 0.0f);
16984
+ float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
15326
16985
 
15327
16986
  // update view
15328
16987
  ggml_opt_get_params(np, ps, x);
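This hunk is where the new hyperparameters take effect: sched is a per-call schedule multiplier applied to both the learning rate and the weight decay, and the working buffers now come from the persistent opt context instead of freshly allocated tensors. A hedged configuration sketch (values are examples, not recommendations):

    // Illustrative Adam setup using the fields introduced above.
    struct ggml_opt_params params = ggml_opt_default_params(GGML_OPT_ADAM);
    params.adam.sched = 0.5f;    // scales both alpha and decay for this call
    params.adam.decay = 0.01f;   // decoupled weight decay; 0.0f skips the x *= (1 - decay) step
    params.adam.alpha = 1e-3f;   // base learning rate before the sched multiplier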
@@ -15332,16 +16991,27 @@ static enum ggml_opt_result ggml_opt_adam(
15332
16991
  ggml_set_f32 (f->grad, 1.0f);
15333
16992
  ggml_graph_compute(ctx, gb);
15334
16993
 
15335
- float fx_prev = ggml_get_f32_1d(f, 0);
16994
+ opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
16995
+ opt->adam.fx_best = opt->adam.fx_prev;
15336
16996
  if (pf) {
15337
- pf[0] = fx_prev;
16997
+ pf[opt->iter % params.past] = opt->adam.fx_prev;
16998
+ }
16999
+
17000
+ // initialize
17001
+ if (opt->just_initialized) {
17002
+ opt->adam.n_no_improvement = 0;
17003
+ opt->just_initialized = false;
15338
17004
  }
15339
17005
 
15340
- int n_no_improvement = 0;
15341
- float fx_best = fx_prev;
17006
+ float * fx_best = &opt->adam.fx_best;
17007
+ float * fx_prev = &opt->adam.fx_prev;
17008
+ int * n_no_improvement = &opt->adam.n_no_improvement;
17009
+
17010
+ int iter0 = opt->iter;
15342
17011
 
15343
17012
  // run the optimizer
15344
17013
  for (int t = 0; t < params.adam.n_iter; ++t) {
17014
+ opt->iter = iter0 + t + 1;
15345
17015
  GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
15346
17016
 
15347
17017
  GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0));
@@ -15375,17 +17045,22 @@ static enum ggml_opt_result ggml_opt_adam(
15375
17045
 
15376
17046
  // m^hat = m_t / (1 - beta1^t)
15377
17047
  // v^hat = v_t / (1 - beta2^t)
15378
- // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
17048
+ // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
17049
+ // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
17050
+ // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
17051
+ // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
17052
+ // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
15379
17053
  ggml_vec_cpy_f32 (nx, mh, m);
15380
17054
  ggml_vec_cpy_f32 (nx, vh, v);
15381
17055
 
15382
- ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1)));
15383
- ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, t + 1)));
17056
+ ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter)));
17057
+ ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, opt->iter)));
15384
17058
 
15385
17059
  ggml_vec_sqrt_f32 (nx, vh, vh);
15386
17060
  ggml_vec_acc1_f32 (nx, vh, eps);
15387
17061
 
15388
17062
  ggml_vec_div_f32 (nx, mh, mh, vh);
17063
+ ggml_vec_scale_f32(nx, x, 1.0f - decay);
15389
17064
  ggml_vec_sub_f32 (nx, x, x, mh);
15390
17065
 
15391
17066
  // update the parameters
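Written out per element, the last line of the derivation in the comments above is equivalent to the following sketch. Here mh and vh stand for the bias-corrected moments, and alpha/decay are assumed to be pre-multiplied by sched as in the code; this restates the formula, not the exact in-place ggml_vec_* sequence:

    #include <math.h>
    #include <stdint.h>

    // x_t = x_{t-1}*(1 - sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
    static void adam_decayed_step(int64_t n, float * x,
                                  const float * mh, const float * vh,
                                  float alpha, float decay, float eps) {
        for (int64_t i = 0; i < n; ++i) {
            x[i] = x[i]*(1.0f - decay) - alpha*mh[i]/(sqrtf(vh[i]) + eps);
        }
    }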
@@ -15399,7 +17074,7 @@ static enum ggml_opt_result ggml_opt_adam(
15399
17074
  const float fx = ggml_get_f32_1d(f, 0);
15400
17075
 
15401
17076
  // check convergence
15402
- if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) {
17077
+ if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
15403
17078
  GGML_PRINT_DEBUG("converged\n");
15404
17079
 
15405
17080
  return GGML_OPT_OK;
@@ -15408,32 +17083,32 @@ static enum ggml_opt_result ggml_opt_adam(
15408
17083
  // delta-based convergence test
15409
17084
  if (pf != NULL) {
15410
17085
  // need at least params.past iterations to start checking for convergence
15411
- if (params.past <= t) {
15412
- const float rate = (pf[t%params.past] - fx)/fx;
17086
+ if (params.past <= iter0 + t) {
17087
+ const float rate = (pf[(iter0 + t)%params.past] - fx)/fx;
15413
17088
 
15414
17089
  if (fabsf(rate) < params.delta) {
15415
17090
  return GGML_OPT_OK;
15416
17091
  }
15417
17092
  }
15418
17093
 
15419
- pf[t%params.past] = fx;
17094
+ pf[(iter0 + t)%params.past] = fx;
15420
17095
  }
15421
17096
 
15422
17097
  // check for improvement
15423
17098
  if (params.max_no_improvement > 0) {
15424
- if (fx_best > fx) {
15425
- fx_best = fx;
15426
- n_no_improvement = 0;
17099
+ if (fx_best[0] > fx) {
17100
+ fx_best[0] = fx;
17101
+ n_no_improvement[0] = 0;
15427
17102
  } else {
15428
- ++n_no_improvement;
17103
+ ++n_no_improvement[0];
15429
17104
 
15430
- if (n_no_improvement >= params.max_no_improvement) {
17105
+ if (n_no_improvement[0] >= params.max_no_improvement) {
15431
17106
  return GGML_OPT_OK;
15432
17107
  }
15433
17108
  }
15434
17109
  }
15435
17110
 
15436
- fx_prev = fx;
17111
+ fx_prev[0] = fx;
15437
17112
 
15438
17113
  {
15439
17114
  const int64_t t_end_cpu = ggml_cycles();
@@ -15572,6 +17247,7 @@ static enum ggml_opt_result linesearch_backtracking(
15572
17247
 
15573
17248
  static enum ggml_opt_result ggml_opt_lbfgs(
15574
17249
  struct ggml_context * ctx,
17250
+ struct ggml_opt_context * opt,
15575
17251
  struct ggml_opt_params params,
15576
17252
  struct ggml_tensor * f,
15577
17253
  struct ggml_cgraph * gf,
@@ -15604,31 +17280,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15604
17280
  }
15605
17281
  }
15606
17282
 
15607
- float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current parameters
15608
- float * xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous parameters
15609
- float * g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current gradient
15610
- float * gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous gradient
15611
- float * d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // search direction
17283
+ if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) {
17284
+ int iter = opt->iter;
17285
+ ggml_opt_init(ctx, opt, params, nx);
17286
+ opt->iter = iter;
17287
+ }
17288
+
17289
+ float * x = opt->lbfgs.x->data; // current parameters
17290
+ float * xp = opt->lbfgs.xp->data; // previous parameters
17291
+ float * g = opt->lbfgs.g->data; // current gradient
17292
+ float * gp = opt->lbfgs.gp->data; // previous gradient
17293
+ float * d = opt->lbfgs.d->data; // search direction
15612
17294
 
15613
- float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values
17295
+ float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
15614
17296
 
15615
17297
  float fx = 0.0f; // cost function value
15616
17298
  float xnorm = 0.0f; // ||x||
15617
17299
  float gnorm = 0.0f; // ||g||
15618
- float step = 0.0f;
15619
17300
 
15620
17301
  // initialize x from the graph nodes
15621
17302
  ggml_opt_get_params(np, ps, x);
15622
17303
 
15623
17304
  // the L-BFGS memory
15624
- struct ggml_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_lbfgs_iteration_data)*m);
15625
-
15626
- for (int i = 0; i < m; ++i) {
15627
- lm[i].alpha = 0.0f;
15628
- lm[i].ys = 0.0f;
15629
- lm[i].s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
15630
- lm[i].y = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data;
15631
- }
17305
+ float * lm_alpha = opt->lbfgs.lmal->data;
17306
+ float * lm_ys = opt->lbfgs.lmys->data;
17307
+ float * lm_s = opt->lbfgs.lms->data;
17308
+ float * lm_y = opt->lbfgs.lmy->data;
15632
17309
 
15633
17310
  // evaluate the function value and its gradient
15634
17311
  {
@@ -15643,12 +17320,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15643
17320
  fx = ggml_get_f32_1d(f, 0);
15644
17321
  }
15645
17322
 
15646
- if (pf) {
15647
- pf[0] = fx;
15648
- }
15649
-
15650
- float fx_best = fx;
15651
-
15652
17323
  // search direction = -gradient
15653
17324
  ggml_vec_neg_f32(nx, d, g);
15654
17325
 
@@ -15665,26 +17336,43 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15665
17336
  return GGML_OPT_OK;
15666
17337
  }
15667
17338
 
15668
- // initial step
15669
- ggml_vec_norm_inv_f32(nx, &step, d);
17339
+ if (opt->just_initialized) {
17340
+ if (pf) {
17341
+ pf[0] = fx;
17342
+ }
17343
+ opt->lbfgs.fx_best = fx;
17344
+
17345
+ // initial step
17346
+ ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d);
17347
+ opt->lbfgs.j = 0;
17348
+ opt->lbfgs.k = 1;
17349
+ opt->lbfgs.end = 0;
17350
+ opt->lbfgs.n_no_improvement = 0;
17351
+ opt->just_initialized = false;
17352
+ }
17353
+
17354
+ float * fx_best = &opt->lbfgs.fx_best;
17355
+ float * step = &opt->lbfgs.step;
17356
+ int * j = &opt->lbfgs.j;
17357
+ int * k = &opt->lbfgs.k;
17358
+ int * end = &opt->lbfgs.end;
17359
+ int * n_no_improvement = &opt->lbfgs.n_no_improvement;
15670
17360
 
15671
- int j = 0;
15672
- int k = 1;
15673
- int ls = 0;
15674
- int end = 0;
15675
- int bound = 0;
15676
- int n_no_improvement = 0;
17361
+ int ls = 0;
17362
+ int bound = 0;
15677
17363
 
15678
17364
  float ys = 0.0f;
15679
17365
  float yy = 0.0f;
15680
17366
  float beta = 0.0f;
15681
17367
 
17368
+ int it = 0;
17369
+
15682
17370
  while (true) {
15683
17371
  // store the current position and gradient vectors
15684
17372
  ggml_vec_cpy_f32(nx, xp, x);
15685
17373
  ggml_vec_cpy_f32(nx, gp, g);
15686
17374
 
15687
- ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps);
17375
+ ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
15688
17376
 
15689
17377
  if (ls < 0) {
15690
17378
  // linesearch failed - go back to the previous point and return
@@ -15710,32 +17398,32 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15710
17398
  // delta-based convergence test
15711
17399
  if (pf != NULL) {
15712
17400
  // need at least params.past iterations to start checking for convergence
15713
- if (params.past <= k) {
15714
- const float rate = (pf[k%params.past] - fx)/fx;
17401
+ if (params.past <= k[0]) {
17402
+ const float rate = (pf[k[0]%params.past] - fx)/fx;
15715
17403
 
15716
17404
  if (fabsf(rate) < params.delta) {
15717
17405
  return GGML_OPT_OK;
15718
17406
  }
15719
17407
  }
15720
17408
 
15721
- pf[k%params.past] = fx;
17409
+ pf[k[0]%params.past] = fx;
15722
17410
  }
15723
17411
 
15724
17412
  // check for improvement
15725
17413
  if (params.max_no_improvement > 0) {
15726
- if (fx < fx_best) {
15727
- fx_best = fx;
15728
- n_no_improvement = 0;
17414
+ if (fx < fx_best[0]) {
17415
+ fx_best[0] = fx;
17416
+ n_no_improvement[0] = 0;
15729
17417
  } else {
15730
- n_no_improvement++;
17418
+ n_no_improvement[0]++;
15731
17419
 
15732
- if (n_no_improvement >= params.max_no_improvement) {
17420
+ if (n_no_improvement[0] >= params.max_no_improvement) {
15733
17421
  return GGML_OPT_OK;
15734
17422
  }
15735
17423
  }
15736
17424
  }
15737
17425
 
15738
- if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) {
17426
+ if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) {
15739
17427
  // reached the maximum number of iterations
15740
17428
  return GGML_OPT_DID_NOT_CONVERGE;
15741
17429
  }
@@ -15744,50 +17432,51 @@ static enum ggml_opt_result ggml_opt_lbfgs(
15744
17432
  // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
15745
17433
  // y_{k+1} = g_{k+1} - g_{k}.
15746
17434
  //
15747
- ggml_vec_sub_f32(nx, lm[end].s, x, xp);
15748
- ggml_vec_sub_f32(nx, lm[end].y, g, gp);
17435
+ ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp);
17436
+ ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp);
15749
17437
 
15750
17438
  // compute scalars ys and yy:
15751
17439
  // ys = y^t \cdot s -> 1 / \rho.
15752
17440
  // yy = y^t \cdot y.
15753
17441
  //
15754
- ggml_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s);
15755
- ggml_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y);
17442
+ ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]);
17443
+ ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
15756
17444
 
15757
- lm[end].ys = ys;
17445
+ lm_ys[end[0]] = ys;
15758
17446
 
15759
17447
  // find new search direction
15760
17448
  // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS
15761
17449
 
15762
- bound = (m <= k) ? m : k;
15763
- k++;
15764
- end = (end + 1)%m;
17450
+ bound = (m <= k[0]) ? m : k[0];
17451
+ k[0]++;
17452
+ it++;
17453
+ end[0] = (end[0] + 1)%m;
15765
17454
 
15766
17455
  // initialize search direction with -g
15767
17456
  ggml_vec_neg_f32(nx, d, g);
15768
17457
 
15769
- j = end;
17458
+ j[0] = end[0];
15770
17459
  for (int i = 0; i < bound; ++i) {
15771
- j = (j + m - 1) % m;
17460
+ j[0] = (j[0] + m - 1) % m;
15772
17461
  // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
15773
- ggml_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d);
15774
- lm[j].alpha /= lm[j].ys;
17462
+ ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d);
17463
+ lm_alpha[j[0]] /= lm_ys[j[0]];
15775
17464
  // q_{i} = q_{i+1} - \alpha_{i} y_{i}
15776
- ggml_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha);
17465
+ ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
15777
17466
  }
15778
17467
 
15779
17468
  ggml_vec_scale_f32(nx, d, ys/yy);
15780
17469
 
15781
17470
  for (int i = 0; i < bound; ++i) {
15782
17471
  // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
15783
- ggml_vec_dot_f32(nx, &beta, lm[j].y, d);
15784
- beta /= lm[j].ys;
17472
+ ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d);
17473
+ beta /= lm_ys[j[0]];
15785
17474
  // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
15786
- ggml_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta);
15787
- j = (j + 1)%m;
17475
+ ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
17476
+ j[0] = (j[0] + 1)%m;
15788
17477
  }
15789
17478
 
15790
- step = 1.0;
17479
+ step[0] = 1.0;
15791
17480
  }
15792
17481
 
15793
17482
  return GGML_OPT_DID_NOT_CONVERGE;
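The two-loop recursion above now indexes flat (m x nx) buffers instead of an array of per-correction structs: correction i of s and y starts at element i*nx. Illustrative accessors (not part of the ggml API) that make the indexing explicit:

    #include <stdint.h>

    // lm_s and lm_y each hold m corrections of nx floats stored back to back.
    static inline float * lbfgs_s(float * lm_s, int i, int64_t nx) { return lm_s + (int64_t) i*nx; }
    static inline float * lbfgs_y(float * lm_y, int i, int64_t nx) { return lm_y + (int64_t) i*nx; }

    // e.g. the s_{k+1} update above could equally be written as:
    //   ggml_vec_sub_f32(nx, lbfgs_s(lm_s, end[0], nx), x, xp);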
@@ -15812,6 +17501,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
15812
17501
 
15813
17502
  .adam = {
15814
17503
  .n_iter = 10000,
17504
+ .sched = 1.000f,
17505
+ .decay = 0.001f,
15815
17506
  .alpha = 0.001f,
15816
17507
  .beta1 = 0.9f,
15817
17508
  .beta2 = 0.999f,
@@ -15854,6 +17545,71 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
15854
17545
  return result;
15855
17546
  }
15856
17547
 
17548
+ GGML_API void ggml_opt_init(
17549
+ struct ggml_context * ctx,
17550
+ struct ggml_opt_context * opt,
17551
+ struct ggml_opt_params params,
17552
+ int64_t nx) {
17553
+ opt->ctx = ctx;
17554
+ opt->params = params;
17555
+ opt->iter = 0;
17556
+ opt->nx = nx;
17557
+ opt->just_initialized = true;
17558
+ switch (opt->params.type) {
17559
+ case GGML_OPT_ADAM:
17560
+ {
17561
+ opt->adam.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17562
+ opt->adam.g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17563
+ opt->adam.g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17564
+ opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17565
+ opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17566
+ opt->adam.mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17567
+ opt->adam.vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17568
+ opt->adam.pf = params.past > 0
17569
+ ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
17570
+ : NULL;
17571
+ ggml_set_zero(opt->adam.x);
17572
+ ggml_set_zero(opt->adam.g1);
17573
+ ggml_set_zero(opt->adam.g2);
17574
+ ggml_set_zero(opt->adam.m);
17575
+ ggml_set_zero(opt->adam.v);
17576
+ ggml_set_zero(opt->adam.mh);
17577
+ ggml_set_zero(opt->adam.vh);
17578
+ if (opt->adam.pf) {
17579
+ ggml_set_zero(opt->adam.pf);
17580
+ }
17581
+ } break;
17582
+ case GGML_OPT_LBFGS:
17583
+ {
17584
+ opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17585
+ opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17586
+ opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17587
+ opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17588
+ opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
17589
+ opt->lbfgs.pf = params.past > 0
17590
+ ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
17591
+ : NULL;
17592
+ opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
17593
+ opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
17594
+ opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
17595
+ opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
17596
+ ggml_set_zero(opt->lbfgs.x);
17597
+ ggml_set_zero(opt->lbfgs.xp);
17598
+ ggml_set_zero(opt->lbfgs.g);
17599
+ ggml_set_zero(opt->lbfgs.gp);
17600
+ ggml_set_zero(opt->lbfgs.d);
17601
+ ggml_set_zero(opt->lbfgs.pf);
17602
+ if (opt->lbfgs.pf) {
17603
+ ggml_set_zero(opt->lbfgs.pf);
17604
+ }
17605
+ ggml_set_zero(opt->lbfgs.lmal);
17606
+ ggml_set_zero(opt->lbfgs.lmys);
17607
+ ggml_set_zero(opt->lbfgs.lms);
17608
+ ggml_set_zero(opt->lbfgs.lmy);
17609
+ } break;
17610
+ }
17611
+ }
17612
+
15857
17613
  enum ggml_opt_result ggml_opt(
15858
17614
  struct ggml_context * ctx,
15859
17615
  struct ggml_opt_params params,
@@ -15876,33 +17632,65 @@ enum ggml_opt_result ggml_opt(
15876
17632
 
15877
17633
  enum ggml_opt_result result = GGML_OPT_OK;
15878
17634
 
17635
+ struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
17636
+
17637
+ ggml_opt_init(ctx, opt, params, 0);
17638
+ result = ggml_opt_resume(ctx, opt, f);
17639
+
17640
+ if (free_ctx) {
17641
+ ggml_free(ctx);
17642
+ }
17643
+
17644
+ return result;
17645
+ }
17646
+
17647
+ enum ggml_opt_result ggml_opt_resume(
17648
+ struct ggml_context * ctx,
17649
+ struct ggml_opt_context * opt,
17650
+ struct ggml_tensor * f) {
17651
+
17652
+ // build forward + backward compute graphs
17653
+ struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
17654
+ struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
17655
+
17656
+ struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
17657
+ struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
17658
+
17659
+ *gf = ggml_build_forward (f);
17660
+ *gb = ggml_build_backward(ctx, gf, true);
17661
+
17662
+ return ggml_opt_resume_g(ctx, opt, f, gf, gb);
17663
+ }
17664
+
17665
+ enum ggml_opt_result ggml_opt_resume_g(
17666
+ struct ggml_context * ctx,
17667
+ struct ggml_opt_context * opt,
17668
+ struct ggml_tensor * f,
17669
+ struct ggml_cgraph * gf,
17670
+ struct ggml_cgraph * gb) {
17671
+
15879
17672
  // build forward + backward compute graphs
15880
- struct ggml_cgraph gf = ggml_build_forward (f);
15881
- struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, true);
17673
+ enum ggml_opt_result result = GGML_OPT_OK;
15882
17674
 
15883
- switch (params.type) {
17675
+ switch (opt->params.type) {
15884
17676
  case GGML_OPT_ADAM:
15885
17677
  {
15886
- result = ggml_opt_adam(ctx, params, f, &gf, &gb);
17678
+ result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
15887
17679
  } break;
15888
17680
  case GGML_OPT_LBFGS:
15889
17681
  {
15890
- result = ggml_opt_lbfgs(ctx, params, f, &gf, &gb);
17682
+ result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb);
15891
17683
  } break;
15892
17684
  }
15893
17685
 
15894
- if (params.print_forward_graph) {
15895
- ggml_graph_print (&gf);
15896
- ggml_graph_dump_dot(&gf, NULL, "opt-forward.dot");
15897
- }
15898
-
15899
- if (params.print_backward_graph) {
15900
- ggml_graph_print (&gb);
15901
- ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot");
17686
+ if (opt->params.print_forward_graph) {
17687
+ ggml_graph_print (gf);
17688
+ ggml_graph_dump_dot(gf, NULL, "opt-forward.dot");
15902
17689
  }
15903
17690
 
15904
- if (free_ctx) {
15905
- ggml_free(ctx);
17691
+ if (opt->params.print_backward_graph) {
17692
+ ggml_graph_print (gb);
17693
+ ggml_graph_dump_dot(gb, gf, "opt-backward.dot");
15906
17694
  }
15907
17695
 
15908
17696
  return result;
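The ggml_opt / ggml_opt_resume split above lets optimizer state survive across calls: ggml_opt keeps the old one-shot behaviour, while ggml_opt_resume reuses a caller-owned ggml_opt_context. A minimal usage sketch, assuming ctx, nx, f and n_batches exist in the caller (the loop structure and names are illustrative):

    // Keep Adam moments and the iteration counter across repeated calls.
    struct ggml_opt_params params = ggml_opt_default_params(GGML_OPT_ADAM);

    struct ggml_opt_context opt;
    ggml_opt_init(ctx, &opt, params, nx);          // allocates the persistent state tensors in ctx

    for (int batch = 0; batch < n_batches; ++batch) {
        // ... rebuild the scalar loss tensor f for this batch ...
        enum ggml_opt_result res = ggml_opt_resume(ctx, &opt, f);
        if (res != GGML_OPT_OK && res != GGML_OPT_DID_NOT_CONVERGE) {
            break;                                 // otherwise opt.iter and the moments carry over
        }
    }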
@@ -16070,6 +17858,50 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
16070
17858
  block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
16071
17859
  result = ggml_quantize_q8_0(src + start, block, n, n, hist);
16072
17860
  } break;
17861
+ #ifdef GGML_USE_K_QUANTS
17862
+ case GGML_TYPE_Q2_K:
17863
+ {
17864
+ GGML_ASSERT(start % QK_K == 0);
17865
+ block_q2_K * block = (block_q2_K*)dst + start / QK_K;
17866
+ result = ggml_quantize_q2_K(src + start, block, n, n, hist);
17867
+ } break;
17868
+ case GGML_TYPE_Q3_K:
17869
+ {
17870
+ GGML_ASSERT(start % QK_K == 0);
17871
+ block_q3_K * block = (block_q3_K*)dst + start / QK_K;
17872
+ result = ggml_quantize_q3_K(src + start, block, n, n, hist);
17873
+ } break;
17874
+ case GGML_TYPE_Q4_K:
17875
+ {
17876
+ GGML_ASSERT(start % QK_K == 0);
17877
+ block_q4_K * block = (block_q4_K*)dst + start / QK_K;
17878
+ result = ggml_quantize_q4_K(src + start, block, n, n, hist);
17879
+ } break;
17880
+ case GGML_TYPE_Q5_K:
17881
+ {
17882
+ GGML_ASSERT(start % QK_K == 0);
17883
+ block_q5_K * block = (block_q5_K*)dst + start / QK_K;
17884
+ result = ggml_quantize_q5_K(src + start, block, n, n, hist);
17885
+ } break;
17886
+ case GGML_TYPE_Q6_K:
17887
+ {
17888
+ GGML_ASSERT(start % QK_K == 0);
17889
+ block_q6_K * block = (block_q6_K*)dst + start / QK_K;
17890
+ result = ggml_quantize_q6_K(src + start, block, n, n, hist);
17891
+ } break;
17892
+ #endif
17893
+ case GGML_TYPE_F16:
17894
+ {
17895
+ int elemsize = sizeof(ggml_fp16_t);
17896
+ ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
17897
+ result = n * elemsize;
17898
+ } break;
17899
+ case GGML_TYPE_F32:
17900
+ {
17901
+ int elemsize = sizeof(float);
17902
+ result = n * elemsize;
17903
+ memcpy((uint8_t *)dst + start * elemsize, src + start, result);
17904
+ } break;
16073
17905
  default:
16074
17906
  assert(false);
16075
17907
  }
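A short usage sketch for the new k-quant branches above; the caller provides a destination buffer sized for n/QK_K blocks and a histogram array (16 buckets matches the 4-bit value range of Q4_K). The helper name is illustrative:

    #include "ggml.h"

    // Quantize n fp32 values (n % QK_K == 0) into Q4_K blocks at dst.
    static size_t quantize_q4_K_chunk(const float * src_f32, void * dst, int n) {
        int64_t hist[16] = {0};                    // value histogram filled by the quantizer
        return ggml_quantize_chunk(GGML_TYPE_Q4_K, src_f32, dst,
                                   /*start=*/0, n, hist);
    }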